Repository: NVlabs/stylegan
Branch: master
Commit: 1e0d5c781384
Files: 31
Total size: 310.7 KB
Directory structure:
gitextract_gssg1h6e/
├── LICENSE.txt
├── README.md
├── config.py
├── dataset_tool.py
├── dnnlib/
│   ├── __init__.py
│   ├── submission/
│   │   ├── __init__.py
│   │   ├── _internal/
│   │   │   └── run.py
│   │   ├── run_context.py
│   │   └── submit.py
│   ├── tflib/
│   │   ├── __init__.py
│   │   ├── autosummary.py
│   │   ├── network.py
│   │   ├── optimizer.py
│   │   └── tfutil.py
│   └── util.py
├── generate_figures.py
├── metrics/
│   ├── __init__.py
│   ├── frechet_inception_distance.py
│   ├── linear_separability.py
│   ├── metric_base.py
│   └── perceptual_path_length.py
├── pretrained_example.py
├── run_metrics.py
├── train.py
└── training/
    ├── __init__.py
    ├── dataset.py
    ├── loss.py
    ├── misc.py
    ├── networks_progan.py
    ├── networks_stylegan.py
    └── training_loop.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE.txt
================================================
Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
Attribution-NonCommercial 4.0 International
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution-NonCommercial 4.0 International Public
License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
d. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
e. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
f. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
g. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
h. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
i. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.
j. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
k. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
l. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and
b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
4. If You Share Adapted Material You produce, the Adapter's
License You apply must not prevent recipients of the Adapted
Material from complying with this Public License.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material; and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the "Licensor." The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.
================================================
FILE: README.md
================================================
## StyleGAN — Official TensorFlow Implementation
**Picture:** *These people are not real – they were produced by our generator that allows control over different aspects of the image.*
This repository contains the official TensorFlow implementation of the following paper:
> **A Style-Based Generator Architecture for Generative Adversarial Networks**<br>
> Tero Karras (NVIDIA), Samuli Laine (NVIDIA), Timo Aila (NVIDIA)<br>
> https://arxiv.org/abs/1812.04948
>
> **Abstract:** *We propose an alternative generator architecture for generative adversarial networks, borrowing from style transfer literature. The new architecture leads to an automatically learned, unsupervised separation of high-level attributes (e.g., pose and identity when trained on human faces) and stochastic variation in the generated images (e.g., freckles, hair), and it enables intuitive, scale-specific control of the synthesis. The new generator improves the state-of-the-art in terms of traditional distribution quality metrics, leads to demonstrably better interpolation properties, and also better disentangles the latent factors of variation. To quantify interpolation quality and disentanglement, we propose two new, automated methods that are applicable to any generator architecture. Finally, we introduce a new, highly varied and high-quality dataset of human faces.*
For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)
**★★★ NEW: [StyleGAN2-ADA-PyTorch](https://github.com/NVlabs/stylegan2-ada-pytorch) is now available; see the full list of versions [here](https://nvlabs.github.io/stylegan2/versions.html) ★★★**
## Resources
Material related to our paper is available via the following links:
- Paper: https://arxiv.org/abs/1812.04948
- Video: https://youtu.be/kSLJriaOumA
- Code: https://github.com/NVlabs/stylegan
- FFHQ: https://github.com/NVlabs/ffhq-dataset
Additional material can be found on Google Drive:
| Path | Description
| :--- | :----------
| [StyleGAN](https://drive.google.com/open?id=1uka3a1noXHAydRPRbknqwKVGODvnmUBX) | Main folder.
| ├ [stylegan-paper.pdf](https://drive.google.com/open?id=1v-HkF3Ehrpon7wVIx4r5DLcko_U_V6Lt) | High-quality version of the paper PDF.
| ├ [stylegan-video.mp4](https://drive.google.com/open?id=1uzwkZHQX_9pYg1i0d1Nbe3D9xPO8-qBf) | High-quality version of the result video.
| ├ [images](https://drive.google.com/open?id=1-l46akONUWF6LCpDoeq63H53rD7MeiTd) | Example images produced using our generator.
| │ ├ [representative-images](https://drive.google.com/open?id=1ToY5P4Vvf5_c3TyUizQ8fckFFoFtBvD8) | High-quality images to be used in articles, blog posts, etc.
| │ └ [100k-generated-images](https://drive.google.com/open?id=100DJ0QXyG89HZzB4w2Cbyf4xjNK54cQ1) | 100,000 generated images for different amounts of truncation.
| │    ├ [ffhq-1024x1024](https://drive.google.com/open?id=14lm8VRN1pr4g_KVe6_LvyDX1PObst6d4) | Generated using Flickr-Faces-HQ dataset at 1024×1024.
| │    ├ [bedrooms-256x256](https://drive.google.com/open?id=1Vxz9fksw4kgjiHrvHkX4Hze4dyThFW6t) | Generated using LSUN Bedroom dataset at 256×256.
| │    ├ [cars-512x384](https://drive.google.com/open?id=1MFCvOMdLE2_mpeLPTiDw5dxc2CRuKkzS) | Generated using LSUN Car dataset at 512×384.
| │    └ [cats-256x256](https://drive.google.com/open?id=1gq-Gj3GRFiyghTPKhp8uDMA9HV_0ZFWQ) | Generated using LSUN Cat dataset at 256×256.
| ├ [videos](https://drive.google.com/open?id=1N8pOd_Bf8v89NGUaROdbD8-ayLPgyRRo) | Example videos produced using our generator.
| │ └ [high-quality-video-clips](https://drive.google.com/open?id=1NFO7_vH0t98J13ckJYFd7kuaTkyeRJ86) | Individual segments of the result video as high-quality MP4.
| ├ [ffhq-dataset](https://drive.google.com/open?id=1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP) | Raw data for the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset).
| └ [networks](https://drive.google.com/open?id=1MASQyN5m0voPcx7-9K0r5gObhvvPups7) | Pre-trained networks as pickled instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py).
|    ├ [stylegan-ffhq-1024x1024.pkl](https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ) | StyleGAN trained with Flickr-Faces-HQ dataset at 1024×1024.
|    ├ [stylegan-celebahq-1024x1024.pkl](https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf) | StyleGAN trained with CelebA-HQ dataset at 1024×1024.
|    ├ [stylegan-bedrooms-256x256.pkl](https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF) | StyleGAN trained with LSUN Bedroom dataset at 256×256.
|    ├ [stylegan-cars-512x384.pkl](https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3) | StyleGAN trained with LSUN Car dataset at 512×384.
|    ├ [stylegan-cats-256x256.pkl](https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ) | StyleGAN trained with LSUN Cat dataset at 256×256.
|    └ [metrics](https://drive.google.com/open?id=1MvYdWCBuMfnoYGptRH-AgKLbPTsIQLhl) | Auxiliary networks for the quality and disentanglement metrics.
|       ├ [inception_v3_features.pkl](https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn) | Standard [Inception-v3](https://arxiv.org/abs/1512.00567) classifier that outputs a raw feature vector.
|       ├ [vgg16_zhang_perceptual.pkl](https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2) | Standard [LPIPS](https://arxiv.org/abs/1801.03924) metric to estimate perceptual similarity.
|       ├ [celebahq-classifier-00-male.pkl](https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX) | Binary classifier trained to detect a single attribute of CelebA-HQ.
|       └ ⋯ | Please see the file listing for remaining networks.
## Licenses
All material, excluding the Flickr-Faces-HQ dataset, is made available under [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license by NVIDIA Corporation. You can **use, redistribute, and adapt** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our paper** and **indicating any changes** that you've made.
For license information regarding the FFHQ dataset, please refer to the [Flickr-Faces-HQ repository](https://github.com/NVlabs/ffhq-dataset).
`inception_v3_features.pkl` and `inception_v3_softmax.pkl` are derived from the pre-trained [Inception-v3](https://arxiv.org/abs/1512.00567) network by Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, and Zbigniew Wojna. The network was originally shared under [Apache 2.0](https://github.com/tensorflow/models/blob/master/LICENSE) license on the [TensorFlow Models](https://github.com/tensorflow/models) repository.
`vgg16.pkl` and `vgg16_zhang_perceptual.pkl` are derived from the pre-trained [VGG-16](https://arxiv.org/abs/1409.1556) network by Karen Simonyan and Andrew Zisserman. The network was originally shared under [Creative Commons BY 4.0](https://creativecommons.org/licenses/by/4.0/) license on the [Very Deep Convolutional Networks for Large-Scale Visual Recognition](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) project page.
`vgg16_zhang_perceptual.pkl` is further derived from the pre-trained [LPIPS](https://arxiv.org/abs/1801.03924) weights by Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, and Oliver Wang. The weights were originally shared under [BSD 2-Clause "Simplified" License](https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE) on the [PerceptualSimilarity](https://github.com/richzhang/PerceptualSimilarity) repository.
## System requirements
* Both Linux and Windows are supported, but we strongly recommend Linux for performance and compatibility reasons.
* 64-bit Python 3.6 installation. We recommend Anaconda3 with numpy 1.14.3 or newer.
* TensorFlow 1.10.0 or newer with GPU support.
* One or more high-end NVIDIA GPUs with at least 11GB of DRAM. We recommend NVIDIA DGX-1 with 8 Tesla V100 GPUs.
* NVIDIA driver 391.35 or newer, CUDA toolkit 9.0 or newer, cuDNN 7.3.1 or newer.
## Using pre-trained networks
A minimal example of using a pre-trained StyleGAN generator is given in [pretrained_example.py](./pretrained_example.py). When executed, the script downloads a pre-trained StyleGAN generator from Google Drive and uses it to generate an image:
```
> python pretrained_example.py
Downloading https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ .... done
Gs            Params      OutputShape           WeightShape
---           ---         ---                   ---
latents_in    -           (?, 512)              -
...
images_out    -           (?, 3, 1024, 1024)    -
---           ---         ---                   ---
Total         26219627
> ls results
example.png # https://drive.google.com/uc?id=1UDLT_zb-rof9kKH0GwiJW_bS9MoZi8oP
```
A more advanced example is given in [generate_figures.py](./generate_figures.py). The script reproduces the figures from our paper in order to illustrate style mixing, noise inputs, and truncation:
```
> python generate_figures.py
results/figure02-uncurated-ffhq.png # https://drive.google.com/uc?id=1U3r1xgcD7o-Fd0SBRpq8PXYajm7_30cu
results/figure03-style-mixing.png # https://drive.google.com/uc?id=1U-nlMDtpnf1RcYkaFQtbh5oxnhA97hy6
results/figure04-noise-detail.png # https://drive.google.com/uc?id=1UX3m39u_DTU6eLnEW6MqGzbwPFt2R9cG
results/figure05-noise-components.png # https://drive.google.com/uc?id=1UQKPcvYVeWMRccGMbs2pPD9PVv1QDyp_
results/figure08-truncation-trick.png # https://drive.google.com/uc?id=1ULea0C12zGlxdDQFNLXOWZCHi3QNfk_v
results/figure10-uncurated-bedrooms.png # https://drive.google.com/uc?id=1UEBnms1XMfj78OHj3_cx80mUf_m9DUJr
results/figure11-uncurated-cars.png # https://drive.google.com/uc?id=1UO-4JtAs64Kun5vIj10UXqAJ1d5Ir1Ke
results/figure12-uncurated-cats.png # https://drive.google.com/uc?id=1USnJc14prlu3QAYxstrtlfXC9sDWPA-W
```
The pre-trained networks are stored as standard pickle files on Google Drive:
```
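import pickle
import dnnlib
import dnnlib.tflib
import config

dnnlib.tflib.init_tf() # Create the default tf.Session that pickle.load() relies on.
# Load pre-trained network.
url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
    _G, _D, Gs = pickle.load(f)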
# _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
# _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
# Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
```
The above code downloads the file and unpickles it to yield 3 instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py). To generate images, you will typically want to use `Gs` – the other two networks are provided for completeness. In order for `pickle.load()` to work, you will need to have the `dnnlib` source directory in your PYTHONPATH and a `tf.Session` set as default. The session can be initialized by calling `dnnlib.tflib.init_tf()`.
There are three ways to use the pre-trained generator:
1. Use `Gs.run()` for immediate-mode operation where the inputs and outputs are numpy arrays:
```
# Pick latent vector.
rnd = np.random.RandomState(5)
latents = rnd.randn(1, Gs.input_shape[1])
# Generate image.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)
```
The first argument is a batch of latent vectors of shape `[num, 512]`. The second argument is reserved for class labels (not used by StyleGAN). The remaining keyword arguments are optional and can be used to further modify the operation (see below). The output is a batch of images, whose format is dictated by the `output_transform` argument.
2. Use `Gs.get_output_for()` to incorporate the generator as a part of a larger TensorFlow expression:
```
latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True)
images = tflib.convert_images_to_uint8(images)
result_expr.append(inception_clone.get_output_for(images))
```
The above code is from [metrics/frechet_inception_distance.py](./metrics/frechet_inception_distance.py). It generates a batch of random images and feeds them directly to the [Inception-v3](https://arxiv.org/abs/1512.00567) network without having to convert the data to numpy arrays in between.
3. Look up `Gs.components.mapping` and `Gs.components.synthesis` to access individual sub-networks of the generator. Similar to `Gs`, the sub-networks are represented as independent instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py):
```
src_latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds]) # list, since np.stack expects a sequence
src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
```
The above code is from [generate_figures.py](./generate_figures.py). It first transforms a batch of latent vectors into the intermediate *W* space using the mapping network and then turns these vectors into a batch of images using the synthesis network. The `dlatents` array stores a separate copy of the same *w* vector for each layer of the synthesis network to facilitate style mixing.
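To make the layout of the `dlatents` array concrete, the following NumPy-only sketch mimics the per-layer copy and a simple style mix without loading any network. It is an illustration under stated assumptions, not code from this repository: the layer count of 18 and the helpers `broadcast_w` and `mix_styles` are hypothetical, and the random vectors merely stand in for real mapping-network outputs.

```python
import numpy as np

NUM_LAYERS = 18  # assumed number of synthesis layers (hypothetical, for illustration)
W_DIM = 512      # matches the [num, 512] latent shape described above

def broadcast_w(w, num_layers=NUM_LAYERS):
    """Copy each w vector over a new layer axis: [batch, dim] -> [batch, layers, dim]."""
    return np.tile(w[:, np.newaxis, :], (1, num_layers, 1))

def mix_styles(dst_dlatents, src_dlatents, layers):
    """Return a copy of dst_dlatents whose selected layers are taken from src_dlatents."""
    idx = list(layers)
    mixed = dst_dlatents.copy()
    mixed[:, idx, :] = src_dlatents[:, idx, :]
    return mixed

# Random stand-ins for mapping-network outputs (hypothetical data).
rnd = np.random.RandomState(0)
src = broadcast_w(rnd.randn(1, W_DIM))
dst = broadcast_w(rnd.randn(1, W_DIM))

# Take the first four layers from the source and keep the rest from the destination.
mixed = mix_styles(dst, src, layers=range(0, 4))
```

Because every layer starts out with the same *w*, replacing a subset of layers is all that is needed to combine two sources; an array shaped like `mixed` is what would then be handed to the synthesis network.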
The exact details of the generator are defined in [training/networks_stylegan.py](./training/networks_stylegan.py) (see `G_style`, `G_mapping`, and `G_synthesis`). The following keyword arguments can be specified to modify the behavior when calling `run()` and `get_output_for()`:
* `truncation_psi` and `truncation_cutoff` control the truncation trick that is performed by default when using `Gs` (ψ=0.7, cutoff=8). It can be disabled by setting `truncation_psi=1` or `is_validation=True`, and the image quality can be further improved at the cost of variation by setting e.g. `truncation_psi=0.5`. Note that truncation is always disabled when using the sub-networks directly. The average *w* needed to manually perform the truncation trick can be looked up using `Gs.get_var('dlatent_avg')`.
* `randomize_noise` determines whether to re-randomize the noise inputs for each generated image (`True`, default) or whether to use specific noise values for the entire minibatch (`False`). The specific values can be accessed via the `tf.Variable` instances that are found using `[var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]`.
* When using the mapping network directly, you can specify `dlatent_broadcast=None` to disable the automatic duplication of `dlatents` over the layers of the synthesis network.
* Runtime performance can be fine-tuned via `structure='fixed'` and `dtype='float16'`. The former disables support for progressive growing, which is not needed for a fully-trained generator, and the latter performs all computation using half-precision floating point arithmetic.
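When the sub-networks are used directly, the truncation trick has to be applied by hand. The following is a minimal NumPy sketch under stated assumptions: `truncate_dlatents` is a hypothetical helper, the interpolation `w_avg + (w - w_avg) * psi` on layers below the cutoff is our reading of the default behavior described above, and the placeholder arrays stand in for `Gs.get_var('dlatent_avg')` and for the output of the mapping network.

```python
import numpy as np

def truncate_dlatents(dlatents, dlatent_avg, psi=0.7, cutoff=8):
    """Pull the first `cutoff` layers of each w towards the average w.

    dlatents:    array of shape [batch, layer, component]
    dlatent_avg: array of shape [component]
    """
    num_layers = dlatents.shape[1]
    layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
    # psi for layers below the cutoff, 1.0 (no truncation) for the rest.
    coefs = np.where(layer_idx < cutoff, psi, 1.0)
    return dlatent_avg + (dlatents - dlatent_avg) * coefs

# Placeholder data for illustration only.
dlatent_avg = np.zeros(512, dtype=np.float32)
dlatents = np.ones((2, 18, 512), dtype=np.float32)
truncated = truncate_dlatents(dlatents, dlatent_avg, psi=0.7, cutoff=8)
```

Setting `psi=1` leaves the input unchanged, which matches the way truncation is disabled in the description above.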
## Preparing datasets for training
The training and evaluation scripts operate on datasets stored as multi-resolution TFRecords. Each dataset is represented by a directory containing the same image data in several resolutions to enable efficient streaming. There is a separate `*.tfrecords` file for each resolution, and if the dataset contains labels, they are stored in a separate file as well. By default, the scripts expect to find the datasets at `datasets/<NAME>/<NAME>-<RESOLUTION>.tfrecords`. The directory can be changed by editing [config.py](./config.py):
```
result_dir = 'results'
data_dir = 'datasets'
cache_dir = 'cache'
```
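As a rough illustration of the directory layout, the sketch below lists the files one would expect for a dataset of a given resolution. It assumes the naming used by `TFRecordExporter` in [dataset_tool.py](./dataset_tool.py), where the resolution suffix is written as `-rNN` with `NN` equal to log2 of the resolution; `tfrecord_filenames` is a hypothetical helper and not part of the repository.

```python
import os

def tfrecord_filenames(data_dir, name, resolution):
    """List the expected per-resolution TFRecords files, from full size down to 4x4."""
    resolution_log2 = resolution.bit_length() - 1
    assert resolution == 2 ** resolution_log2, 'resolution must be a power of two'
    return [
        os.path.join(data_dir, name, '%s-r%02d.tfrecords' % (name, resolution_log2 - lod))
        for lod in range(resolution_log2 - 1)
    ]

filenames = tfrecord_filenames('datasets', 'ffhq', 1024)
for filename in filenames:
    print(filename)
```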
To obtain the FFHQ dataset (`datasets/ffhq`), please refer to the [Flickr-Faces-HQ repository](https://github.com/NVlabs/ffhq-dataset).
To obtain the CelebA-HQ dataset (`datasets/celebahq`), please refer to the [Progressive GAN repository](https://github.com/tkarras/progressive_growing_of_gans).
To obtain other datasets, including LSUN, please consult their corresponding project pages. The datasets can be converted to multi-resolution TFRecords using the provided [dataset_tool.py](./dataset_tool.py):
```
> python dataset_tool.py create_lsun datasets/lsun-bedroom-full ~/lsun/bedroom_lmdb --resolution 256
> python dataset_tool.py create_lsun_wide datasets/lsun-car-512x384 ~/lsun/car_lmdb --width 512 --height 384
> python dataset_tool.py create_lsun datasets/lsun-cat-full ~/lsun/cat_lmdb --resolution 256
> python dataset_tool.py create_cifar10 datasets/cifar10 ~/cifar10
> python dataset_tool.py create_from_images datasets/custom-dataset ~/custom-images
```
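The lower resolutions are derived from the full-resolution image by repeated 2×2 averaging. The sketch below mirrors the loop in `TFRecordExporter.add_image()` but returns the levels as NumPy arrays instead of writing TFRecords; `build_pyramid` is a hypothetical name used only for this illustration.

```python
import numpy as np

def build_pyramid(img):
    """Return all resolution levels of a square, power-of-two CHW uint8 image.

    The first entry is the input resolution and the last one is 4x4.
    """
    resolution_log2 = int(np.log2(img.shape[1]))
    assert img.shape[1] == img.shape[2] == 2 ** resolution_log2
    levels = []
    cur = img
    for lod in range(resolution_log2 - 1):
        if lod:
            # Average each 2x2 block in float to halve the resolution.
            cur = cur.astype(np.float32)
            cur = (cur[:, 0::2, 0::2] + cur[:, 0::2, 1::2] + cur[:, 1::2, 0::2] + cur[:, 1::2, 1::2]) * 0.25
        # Quantize a copy for storage; the float image is carried to the next level.
        levels.append(np.rint(cur).clip(0, 255).astype(np.uint8))
    return levels

image = np.full((3, 32, 32), 200, dtype=np.uint8)
levels = build_pyramid(image)
```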
## Training networks
Once the datasets are set up, you can train your own StyleGAN networks as follows:
1. Edit [train.py](./train.py) to specify the dataset and training configuration by uncommenting or editing specific lines.
2. Run the training script with `python train.py`.
3. The results are written to a newly created directory `results/<ID>-<DESCRIPTION>`.
4. The training may take several days (or weeks) to complete, depending on the configuration.
By default, `train.py` is configured to train the highest-quality StyleGAN (configuration F in Table 1) for the FFHQ dataset at 1024×1024 resolution using 8 GPUs. Please note that we have used 8 GPUs in all of our experiments. Training with fewer GPUs may not produce identical results – if you wish to compare against our technique, we strongly recommend using the same number of GPUs.
Expected training times for the default configuration using Tesla V100 GPUs:
| GPUs | 1024×1024 | 512×512 | 256×256 |
| :--- | :-------------- | :------------ | :------------ |
| 1 | 41 days 4 hours | 24 days 21 hours | 14 days 22 hours |
| 2 | 21 days 22 hours | 13 days 7 hours | 9 days 5 hours |
| 4 | 11 days 8 hours | 7 days 0 hours | 4 days 21 hours |
| 8 | 6 days 14 hours | 4 days 10 hours | 3 days 8 hours |
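To put the table into perspective, the multi-GPU speedup can be computed directly from the listed times. The sketch below uses only the 1024×1024 column; `scaling_report` is a hypothetical helper, and the figures it prints are derived from the table rather than measured separately.

```python
# Wall-clock training times in hours, copied from the 1024x1024 column above.
train_hours = {
    1: 41 * 24 + 4,
    2: 21 * 24 + 22,
    4: 11 * 24 + 8,
    8: 6 * 24 + 14,
}

def scaling_report(hours_by_gpus):
    """Map GPU count to (speedup over one GPU, parallel efficiency)."""
    base = hours_by_gpus[1]
    report = {}
    for gpus, hours in sorted(hours_by_gpus.items()):
        speedup = base / hours
        report[gpus] = (speedup, speedup / gpus)
    return report

report = scaling_report(train_hours)
for gpus, (speedup, efficiency) in report.items():
    print('%d GPU(s): %.2fx speedup, %.0f%% parallel efficiency' % (gpus, speedup, 100 * efficiency))
```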
## Evaluating quality and disentanglement
The quality and disentanglement metrics used in our paper can be evaluated using [run_metrics.py](./run_metrics.py). By default, the script will evaluate the Fréchet Inception Distance (`fid50k`) for the pre-trained FFHQ generator and write the results into a newly created directory under `results`. The exact behavior can be changed by uncommenting or editing specific lines in [run_metrics.py](./run_metrics.py).
Expected evaluation time and results for the pre-trained FFHQ generator using one Tesla V100 GPU:
| Metric | Time | Result | Description
| :----- | :--- | :----- | :----------
| fid50k | 16 min | 4.4159 | Fréchet Inception Distance using 50,000 images.
| ppl_zfull | 55 min | 664.8854 | Perceptual Path Length for full paths in *Z*.
| ppl_wfull | 55 min | 233.3059 | Perceptual Path Length for full paths in *W*.
| ppl_zend | 55 min | 666.1057 | Perceptual Path Length for path endpoints in *Z*.
| ppl_wend | 55 min | 197.2266 | Perceptual Path Length for path endpoints in *W*.
| ls | 10 hours | z: 165.0106<br>w: 3.7447 | Linear Separability in *Z* and *W*.
Please note that the exact results may vary from run to run due to the non-deterministic nature of TensorFlow.
## Acknowledgements
We thank Jaakko Lehtinen, David Luebke, and Tuomas Kynkäänniemi for in-depth discussions and helpful comments; Janne Hellsten, Tero Kuosmanen, and Pekka Jänis for compute infrastructure and help with the code release.
================================================
FILE: config.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Global configuration."""
#----------------------------------------------------------------------------
# Paths.
result_dir = 'results'
data_dir = 'datasets'
cache_dir = 'cache'
run_dir_ignore = ['results', 'datasets', 'cache']
#----------------------------------------------------------------------------
================================================
FILE: dataset_tool.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN."""
# pylint: disable=too-many-lines
import os
import sys
import glob
import argparse
import threading
import six.moves.queue as Queue # pylint: disable=import-error
import traceback
import numpy as np
import tensorflow as tf
import PIL.Image
import dnnlib.tflib as tflib
from training import dataset
#----------------------------------------------------------------------------
def error(msg):
    print('Error: ' + msg)
    sys.exit(1)
#----------------------------------------------------------------------------
class TFRecordExporter:
    def __init__(self, tfrecord_dir, expected_images, print_progress=True, progress_interval=10):
        self.tfrecord_dir = tfrecord_dir
        self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
        self.expected_images = expected_images
        self.cur_images = 0
        self.shape = None
        self.resolution_log2 = None
        self.tfr_writers = []
        self.print_progress = print_progress
        self.progress_interval = progress_interval
        if self.print_progress:
            print('Creating dataset "%s"' % tfrecord_dir)
        if not os.path.isdir(self.tfrecord_dir):
            os.makedirs(self.tfrecord_dir)
        assert os.path.isdir(self.tfrecord_dir)
    def close(self):
        if self.print_progress:
            print('%-40s\r' % 'Flushing data...', end='', flush=True)
        for tfr_writer in self.tfr_writers:
            tfr_writer.close()
        self.tfr_writers = []
        if self.print_progress:
            print('%-40s\r' % '', end='', flush=True)
            print('Added %d images.' % self.cur_images)
    def choose_shuffled_order(self): # Note: Images and labels must be added in shuffled order.
        order = np.arange(self.expected_images)
        np.random.RandomState(123).shuffle(order)
        return order
    def add_image(self, img):
        if self.print_progress and self.cur_images % self.progress_interval == 0:
            print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True)
        if self.shape is None:
            # First image: fix the dataset shape and open one writer per resolution level.
            self.shape = img.shape
            self.resolution_log2 = int(np.log2(self.shape[1]))
            assert self.shape[0] in [1, 3]
            assert self.shape[1] == self.shape[2]
            assert self.shape[1] == 2**self.resolution_log2
            tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
            for lod in range(self.resolution_log2 - 1):
                tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod)
                self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
        assert img.shape == self.shape
        for lod, tfr_writer in enumerate(self.tfr_writers):
            if lod:
                # Halve the resolution by averaging each 2x2 block.
                img = img.astype(np.float32)
                img = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] + img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25
            quant = np.rint(img).clip(0, 255).astype(np.uint8)
            ex = tf.train.Example(features=tf.train.Features(feature={
                'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=quant.shape)),
                'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tobytes()]))})) # tobytes() replaces the deprecated tostring()
            tfr_writer.write(ex.SerializeToString())
        self.cur_images += 1
    def add_labels(self, labels):
        if self.print_progress:
            print('%-40s\r' % 'Saving labels...', end='', flush=True)
        assert labels.shape[0] == self.cur_images
        with open(self.tfr_prefix + '-rxx.labels', 'wb') as f:
            np.save(f, labels.astype(np.float32))
    def __enter__(self):
        return self
    def __exit__(self, *args):
        self.close()
#----------------------------------------------------------------------------
class ExceptionInfo(object):
    def __init__(self):
        self.value = sys.exc_info()[1]
        self.traceback = traceback.format_exc()
#----------------------------------------------------------------------------
class WorkerThread(threading.Thread):
    def __init__(self, task_queue):
        threading.Thread.__init__(self)
        self.task_queue = task_queue
    def run(self):
        while True:
            func, args, result_queue = self.task_queue.get()
            if func is None:
                break
            try:
                result = func(*args)
            except:
                result = ExceptionInfo()
            result_queue.put((result, args))
#----------------------------------------------------------------------------
class ThreadPool(object):
    def __init__(self, num_threads):
        assert num_threads >= 1
        self.task_queue = Queue.Queue()
        self.result_queues = dict()
        self.num_threads = num_threads
        for _idx in range(self.num_threads):
            thread = WorkerThread(self.task_queue)
            thread.daemon = True
            thread.start()
    def add_task(self, func, args=()):
        assert hasattr(func, '__call__') # must be a function
        if func not in self.result_queues:
            self.result_queues[func] = Queue.Queue()
        self.task_queue.put((func, args, self.result_queues[func]))
    def get_result(self, func): # returns (result, args)
        result, args = self.result_queues[func].get()
        if isinstance(result, ExceptionInfo):
            print('\n\nWorker thread caught an exception:\n' + result.traceback)
            raise result.value
        return result, args
    def finish(self):
        for _idx in range(self.num_threads):
            self.task_queue.put((None, (), None))
    def __enter__(self): # for 'with' statement
        return self
    def __exit__(self, *excinfo):
        self.finish()
    def process_items_concurrently(self, item_iterator, process_func=lambda x: x, pre_func=lambda x: x, post_func=lambda x: x, max_items_in_flight=None):
        if max_items_in_flight is None: max_items_in_flight = self.num_threads * 4
        assert max_items_in_flight >= 1
        results = []
        retire_idx = [0]
        def task_func(prepared, _idx):
            return process_func(prepared)
        def retire_result():
            processed, (_prepared, idx) = self.get_result(task_func)
            results[idx] = processed
            while retire_idx[0] < len(results) and results[retire_idx[0]] is not None:
                yield post_func(results[retire_idx[0]])
                results[retire_idx[0]] = None
                retire_idx[0] += 1
        for idx, item in enumerate(item_iterator):
            prepared = pre_func(item)
            results.append(None)
            self.add_task(func=task_func, args=(prepared, idx))
            while retire_idx[0] < idx - max_items_in_flight + 2:
                for res in retire_result(): yield res
        while retire_idx[0] < len(results):
            for res in retire_result(): yield res
#----------------------------------------------------------------------------
def display(tfrecord_dir):
    print('Loading dataset "%s"' % tfrecord_dir)
    tflib.init_tf({'gpu_options.allow_growth': True})
    dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size='full', repeat=False, shuffle_mb=0)
    tflib.init_uninitialized_vars()
    import cv2 # pip install opencv-python
    idx = 0
    while True:
        try:
            images, labels = dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            break
        if idx == 0:
            print('Displaying images')
            cv2.namedWindow('dataset_tool')
            print('Press SPACE or ENTER to advance, ESC to exit')
        print('\nidx = %-8d\nlabel = %s' % (idx, labels[0].tolist()))
        cv2.imshow('dataset_tool', images[0].transpose(1, 2, 0)[:, :, ::-1]) # CHW => HWC, RGB => BGR
        idx += 1
        if cv2.waitKey() == 27:
            break
    print('\nDisplayed %d images.' % idx)
#----------------------------------------------------------------------------
def extract(tfrecord_dir, output_dir):
    print('Loading dataset "%s"' % tfrecord_dir)
    tflib.init_tf({'gpu_options.allow_growth': True})
    dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size=0, repeat=False, shuffle_mb=0)
    tflib.init_uninitialized_vars()
    print('Extracting images to "%s"' % output_dir)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    idx = 0
    while True:
        if idx % 10 == 0:
            print('%d\r' % idx, end='', flush=True)
        try:
            images, _labels = dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            break
        if images.shape[1] == 1:
            img = PIL.Image.fromarray(images[0][0], 'L')
        else:
            img = PIL.Image.fromarray(images[0].transpose(1, 2, 0), 'RGB')
        img.save(os.path.join(output_dir, 'img%08d.png' % idx))
        idx += 1
    print('Extracted %d images.' % idx)
#----------------------------------------------------------------------------
def compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels):
    max_label_size = 0 if ignore_labels else 'full'
    print('Loading dataset "%s"' % tfrecord_dir_a)
    tflib.init_tf({'gpu_options.allow_growth': True})
    dset_a = dataset.TFRecordDataset(tfrecord_dir_a, max_label_size=max_label_size, repeat=False, shuffle_mb=0)
    print('Loading dataset "%s"' % tfrecord_dir_b)
    dset_b = dataset.TFRecordDataset(tfrecord_dir_b, max_label_size=max_label_size, repeat=False, shuffle_mb=0)
    tflib.init_uninitialized_vars()
    print('Comparing datasets')
    idx = 0
    identical_images = 0
    identical_labels = 0
    while True:
        if idx % 100 == 0:
            print('%d\r' % idx, end='', flush=True)
        try:
            images_a, labels_a = dset_a.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            images_a, labels_a = None, None
        try:
            images_b, labels_b = dset_b.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            images_b, labels_b = None, None
        if images_a is None or images_b is None:
            if images_a is not None or images_b is not None:
                print('Datasets contain different number of images')
            break
        if images_a.shape == images_b.shape and np.all(images_a == images_b):
            identical_images += 1
        else:
            print('Image %d is different' % idx)
        if labels_a.shape == labels_b.shape and np.all(labels_a == labels_b):
            identical_labels += 1
        else:
            print('Label %d is different' % idx)
        idx += 1
    print('Identical images: %d / %d' % (identical_images, idx))
    if not ignore_labels:
        print('Identical labels: %d / %d' % (identical_labels, idx))
#----------------------------------------------------------------------------
def create_mnist(tfrecord_dir, mnist_dir):
    print('Loading MNIST from "%s"' % mnist_dir)
    import gzip
    with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:
        images = np.frombuffer(file.read(), np.uint8, offset=16)
    with gzip.open(os.path.join(mnist_dir, 'train-labels-idx1-ubyte.gz'), 'rb') as file:
        labels = np.frombuffer(file.read(), np.uint8, offset=8)
    images = images.reshape(-1, 1, 28, 28)
    images = np.pad(images, [(0,0), (0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0
    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])
#----------------------------------------------------------------------------
def create_mnistrgb(tfrecord_dir, mnist_dir, num_images=1000000, random_seed=123):
    print('Loading MNIST from "%s"' % mnist_dir)
    import gzip
    with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:
        images = np.frombuffer(file.read(), np.uint8, offset=16)
    images = images.reshape(-1, 28, 28)
    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    with TFRecordExporter(tfrecord_dir, num_images) as tfr:
        rnd = np.random.RandomState(random_seed)
        for _idx in range(num_images):
            tfr.add_image(images[rnd.randint(images.shape[0], size=3)])
#----------------------------------------------------------------------------
def create_cifar10(tfrecord_dir, cifar10_dir):
    print('Loading CIFAR-10 from "%s"' % cifar10_dir)
    import pickle
    images = []
    labels = []
    for batch in range(1, 6):
        with open(os.path.join(cifar10_dir, 'data_batch_%d' % batch), 'rb') as file:
            data = pickle.load(file, encoding='latin1')
        images.append(data['data'].reshape(-1, 3, 32, 32))
        labels.append(data['labels'])
    images = np.concatenate(images)
    labels = np.concatenate(labels)
    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64] # default integer width is platform-dependent
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0
    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])
#----------------------------------------------------------------------------
def create_cifar100(tfrecord_dir, cifar100_dir):
    print('Loading CIFAR-100 from "%s"' % cifar100_dir)
    import pickle
    with open(os.path.join(cifar100_dir, 'train'), 'rb') as file:
        data = pickle.load(file, encoding='latin1')
    images = data['data'].reshape(-1, 3, 32, 32)
    labels = np.array(data['fine_labels'])
    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64] # default integer width is platform-dependent
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 99
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0
    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])
#----------------------------------------------------------------------------
def create_svhn(tfrecord_dir, svhn_dir):
    print('Loading SVHN from "%s"' % svhn_dir)
    import pickle
    images = []
    labels = []
    for batch in range(1, 4):
        with open(os.path.join(svhn_dir, 'train_%d.pkl' % batch), 'rb') as file:
            data = pickle.load(file, encoding='latin1')
        images.append(data[0])
        labels.append(data[1])
    images = np.concatenate(images)
    labels = np.concatenate(labels)
    assert images.shape == (73257, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (73257,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9
    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
    onehot[np.arange(labels.size), labels] = 1.0
    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            tfr.add_image(images[order[idx]])
        tfr.add_labels(onehot[order])
#----------------------------------------------------------------------------
def create_lsun(tfrecord_dir, lmdb_dir, resolution=256, max_images=None):
    print('Loading LSUN dataset from "%s"' % lmdb_dir)
    import lmdb # pip install lmdb # pylint: disable=import-error
    import cv2 # pip install opencv-python
    import io
    with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
        total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter
        if max_images is None:
            max_images = total_images
        with TFRecordExporter(tfrecord_dir, max_images) as tfr:
            for _idx, (_key, value) in enumerate(txn.cursor()):
                try:
                    try:
                        img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1) # frombuffer replaces the deprecated fromstring
                        if img is None:
                            raise IOError('cv2.imdecode failed')
                        img = img[:, :, ::-1] # BGR => RGB
                    except IOError:
                        img = np.asarray(PIL.Image.open(io.BytesIO(value)))
                    # Center-crop to a square, then resize to the target resolution.
                    crop = np.min(img.shape[:2])
                    img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
                    img = PIL.Image.fromarray(img, 'RGB')
                    img = img.resize((resolution, resolution), PIL.Image.LANCZOS) # same filter as the removed ANTIALIAS alias
                    img = np.asarray(img)
                    img = img.transpose([2, 0, 1]) # HWC => CHW
                    tfr.add_image(img)
                except:
                    print(sys.exc_info()[1])
                if tfr.cur_images == max_images:
                    break
#----------------------------------------------------------------------------
def create_lsun_wide(tfrecord_dir, lmdb_dir, width=512, height=384, max_images=None):
    assert width == 2 ** int(np.round(np.log2(width)))
    assert height <= width
    print('Loading LSUN dataset from "%s"' % lmdb_dir)
    import lmdb # pip install lmdb # pylint: disable=import-error
    import cv2 # pip install opencv-python
    import io
    with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
        total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter
        if max_images is None:
            max_images = total_images
        with TFRecordExporter(tfrecord_dir, max_images, print_progress=False) as tfr:
            for idx, (_key, value) in enumerate(txn.cursor()):
                try:
                    try:
                        img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1) # frombuffer replaces the deprecated fromstring
                        if img is None:
                            raise IOError('cv2.imdecode failed')
                        img = img[:, :, ::-1] # BGR => RGB
                    except IOError:
                        img = np.asarray(PIL.Image.open(io.BytesIO(value)))
                    # Skip images that are too small for the requested output size.
                    ch = int(np.round(width * img.shape[0] / img.shape[1]))
                    if img.shape[1] < width or ch < height:
                        continue
                    img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
                    img = PIL.Image.fromarray(img, 'RGB')
                    img = img.resize((width, height), PIL.Image.LANCZOS) # same filter as the removed ANTIALIAS alias
                    img = np.asarray(img)
                    img = img.transpose([2, 0, 1]) # HWC => CHW
                    # Pad vertically with zeros to obtain a square canvas.
                    canvas = np.zeros([3, width, width], dtype=np.uint8)
                    canvas[:, (width - height) // 2 : (width + height) // 2] = img
                    tfr.add_image(canvas)
                    print('\r%d / %d => %d ' % (idx + 1, total_images, tfr.cur_images), end='')
                except:
                    print(sys.exc_info()[1])
                if tfr.cur_images == max_images:
                    break
    print()
#----------------------------------------------------------------------------
def create_celeba(tfrecord_dir, celeba_dir, cx=89, cy=121):
    print('Loading CelebA from "%s"' % celeba_dir)
    glob_pattern = os.path.join(celeba_dir, 'img_align_celeba_png', '*.png')
    image_filenames = sorted(glob.glob(glob_pattern))
    expected_images = 202599
    if len(image_filenames) != expected_images:
        error('Expected to find %d images' % expected_images)
    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        order = tfr.choose_shuffled_order()
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            assert img.shape == (218, 178, 3)
            img = img[cy - 64 : cy + 64, cx - 64 : cx + 64]
            img = img.transpose(2, 0, 1) # HWC => CHW
            tfr.add_image(img)
#----------------------------------------------------------------------------
def create_from_images(tfrecord_dir, image_dir, shuffle):
    print('Loading images from "%s"' % image_dir)
    image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
    if len(image_filenames) == 0:
        error('No input images found')
    img = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1
    if img.shape[1] != resolution:
        error('Input images must have the same width and height')
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        error('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        error('Input images must be stored as RGB or grayscale')
    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            if channels == 1:
                img = img[np.newaxis, :, :] # HW => CHW
            else:
                img = img.transpose([2, 0, 1]) # HWC => CHW
            tfr.add_image(img)
#----------------------------------------------------------------------------
def create_from_hdf5(tfrecord_dir, hdf5_filename, shuffle):
    print('Loading HDF5 archive from "%s"' % hdf5_filename)
    import h5py # conda install h5py
    with h5py.File(hdf5_filename, 'r') as hdf5_file:
        hdf5_data = max([value for key, value in hdf5_file.items() if key.startswith('data')], key=lambda lod: lod.shape[3])
        with TFRecordExporter(tfrecord_dir, hdf5_data.shape[0]) as tfr:
            order = tfr.choose_shuffled_order() if shuffle else np.arange(hdf5_data.shape[0])
            for idx in range(order.size):
                tfr.add_image(hdf5_data[order[idx]])
            npy_filename = os.path.splitext(hdf5_filename)[0] + '-labels.npy'
            if os.path.isfile(npy_filename):
                tfr.add_labels(np.load(npy_filename)[order])
#----------------------------------------------------------------------------
def execute_cmdline(argv):
    prog = argv[0]
    parser = argparse.ArgumentParser(
        prog = prog,
        description = 'Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN.',
        epilog = 'Type "%s <command> -h" for more information.' % prog)
    subparsers = parser.add_subparsers(dest='command')
    subparsers.required = True
    def add_command(cmd, desc, example=None):
        epilog = 'Example: %s %s' % (prog, example) if example is not None else None
        return subparsers.add_parser(cmd, description=desc, help=desc, epilog=epilog)
    p = add_command('display', 'Display images in dataset.',
                    'display datasets/mnist')
    p.add_argument('tfrecord_dir', help='Directory containing dataset')
    p = add_command('extract', 'Extract images from dataset.',
                    'extract datasets/mnist mnist-images')
    p.add_argument('tfrecord_dir', help='Directory containing dataset')
    p.add_argument('output_dir', help='Directory to extract the images into')
    p = add_command('compare', 'Compare two datasets.',
                    'compare datasets/mydataset datasets/mnist')
    p.add_argument('tfrecord_dir_a', help='Directory containing first dataset')
    p.add_argument('tfrecord_dir_b', help='Directory containing second dataset')
    p.add_argument('--ignore_labels', help='Ignore labels (default: 0)', type=int, default=0)
    p = add_command('create_mnist', 'Create dataset for MNIST.',
                    'create_mnist datasets/mnist ~/downloads/mnist')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('mnist_dir', help='Directory containing MNIST')
    p = add_command('create_mnistrgb', 'Create dataset for MNIST-RGB.',
                    'create_mnistrgb datasets/mnistrgb ~/downloads/mnist')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('mnist_dir', help='Directory containing MNIST')
    p.add_argument('--num_images', help='Number of composite images to create (default: 1000000)', type=int, default=1000000)
    p.add_argument('--random_seed', help='Random seed (default: 123)', type=int, default=123)
    p = add_command('create_cifar10', 'Create dataset for CIFAR-10.',
                    'create_cifar10 datasets/cifar10 ~/downloads/cifar10')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('cifar10_dir', help='Directory containing CIFAR-10')
    p = add_command('create_cifar100', 'Create dataset for CIFAR-100.',
                    'create_cifar100 datasets/cifar100 ~/downloads/cifar100')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('cifar100_dir', help='Directory containing CIFAR-100')
    p = add_command('create_svhn', 'Create dataset for SVHN.',
                    'create_svhn datasets/svhn ~/downloads/svhn')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('svhn_dir', help='Directory containing SVHN')
    p = add_command('create_lsun', 'Create dataset for single LSUN category.',
                    'create_lsun datasets/lsun-car-100k ~/downloads/lsun/car_lmdb --resolution 256 --max_images 100000')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('lmdb_dir', help='Directory containing LMDB database')
    p.add_argument('--resolution', help='Output resolution (default: 256)', type=int, default=256)
    p.add_argument('--max_images', help='Maximum number of images (default: none)', type=int, default=None)
    p = add_command('create_lsun_wide', 'Create LSUN dataset with non-square aspect ratio.',
                    'create_lsun_wide datasets/lsun-car-512x384 ~/downloads/lsun/car_lmdb --width 512 --height 384')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('lmdb_dir', help='Directory containing LMDB database')
    p.add_argument('--width', help='Output width (default: 512)', type=int, default=512)
    p.add_argument('--height', help='Output height (default: 384)', type=int, default=384)
    p.add_argument('--max_images', help='Maximum number of images (default: none)', type=int, default=None)
    p = add_command('create_celeba', 'Create dataset for CelebA.',
                    'create_celeba datasets/celeba ~/downloads/celeba')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('celeba_dir', help='Directory containing CelebA')
    p.add_argument('--cx', help='Center X coordinate (default: 89)', type=int, default=89)
    p.add_argument('--cy', help='Center Y coordinate (default: 121)', type=int, default=121)
    p = add_command('create_from_images', 'Create dataset from a directory full of images.',
                    'create_from_images datasets/mydataset myimagedir')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('image_dir', help='Directory containing the images')
    p.add_argument('--shuffle', help='Randomize image order (default: 1)', type=int, default=1)
    p = add_command('create_from_hdf5', 'Create dataset from legacy HDF5 archive.',
                    'create_from_hdf5 datasets/celebahq ~/downloads/celeba-hq-1024x1024.h5')
    p.add_argument('tfrecord_dir', help='New dataset directory to be created')
    p.add_argument('hdf5_filename', help='HDF5 archive containing the images')
    p.add_argument('--shuffle', help='Randomize image order (default: 1)', type=int, default=1)
    args = parser.parse_args(argv[1:] if len(argv) > 1 else ['-h'])
    func = globals()[args.command]
    del args.command
    func(**vars(args))
#----------------------------------------------------------------------------
if __name__ == "__main__":
    execute_cmdline(sys.argv)
#----------------------------------------------------------------------------
================================================
FILE: dnnlib/__init__.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
from . import submission
from .submission.run_context import RunContext
from .submission.submit import SubmitTarget
from .submission.submit import PathType
from .submission.submit import SubmitConfig
from .submission.submit import get_path_from_template
from .submission.submit import submit_run
from .util import EasyDict
submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
================================================
FILE: dnnlib/submission/__init__.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
from . import run_context
from . import submit
================================================
FILE: dnnlib/submission/_internal/run.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Helper for launching run functions in computing clusters.
During the submit process, this file is copied to the appropriate run dir.
When the job is launched in the cluster, this module is the first thing that
is run inside the docker container.
"""
import os
import pickle
import sys
# PYTHONPATH should have been set so that the run_dir/src is in it
import dnnlib
def main():
    if not len(sys.argv) >= 4:
        raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!")

    run_dir = str(sys.argv[1])
    task_name = str(sys.argv[2])
    host_name = str(sys.argv[3])

    submit_config_path = os.path.join(run_dir, "submit_config.pkl")

    # SubmitConfig should have been pickled to the run dir
    if not os.path.exists(submit_config_path):
        raise RuntimeError("SubmitConfig pickle file does not exist!")

    submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb"))
    dnnlib.submission.submit.set_user_name_override(submit_config.user_name)

    submit_config.task_name = task_name
    submit_config.host_name = host_name

    dnnlib.submission.submit.run_wrapper(submit_config)

if __name__ == "__main__":
    main()
================================================
FILE: dnnlib/submission/run_context.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Helpers for managing the run/training loop."""
import datetime
import json
import os
import pprint
import time
import types
from typing import Any
from . import submit
class RunContext(object):
    """Helper class for managing the run/training loop.

    The context will hide the implementation details of a basic run/training loop.
    It will set things up properly, tell if the run should be stopped, and then clean up.
    The user should call update periodically and use should_stop to determine if the run should be stopped.

    Args:
        submit_config: The SubmitConfig that is used for the current run.
        config_module: The whole config module that is used for the current run.
        max_epoch: Optional cached value for the max_epoch variable used in update.
    """

    def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None):
        self.submit_config = submit_config
        self.should_stop_flag = False
        self.has_closed = False
        self.start_time = time.time()
        self.last_update_time = time.time()
        self.last_update_interval = 0.0
        self.max_epoch = max_epoch

        # pretty print all the relevant content of the config module to a text file
        if config_module is not None:
            with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f:
                filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))}
                pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)

        # write out details about the run to a text file
        self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
        with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f:
            pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

    def __enter__(self) -> "RunContext":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
        """Do general housekeeping and keep the state of the context up-to-date.
        Should be called often enough but not in a tight loop."""
        assert not self.has_closed

        self.last_update_interval = time.time() - self.last_update_time
        self.last_update_time = time.time()

        if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
            self.should_stop_flag = True

        max_epoch_val = self.max_epoch if max_epoch is None else max_epoch

    def should_stop(self) -> bool:
        """Tell whether a stopping condition has been triggered one way or another."""
        return self.should_stop_flag

    def get_time_since_start(self) -> float:
        """How much time has passed since the creation of the context."""
        return time.time() - self.start_time

    def get_time_since_last_update(self) -> float:
        """How much time has passed since the last call to update."""
        return time.time() - self.last_update_time

    def get_last_update_interval(self) -> float:
        """How much time passed between the previous two calls to update."""
        return self.last_update_interval

    def close(self) -> None:
        """Close the context and clean up.
        Should only be called once."""
        if not self.has_closed:
            # update the run.txt with stopping time
            self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
            with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
                pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)

            self.has_closed = True
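# Illustrative sketch, not part of the original file: RunContext.update() sets its
# stop flag once a file named abort.txt appears in the run dir. The standalone
# example below mimics only that polling idea without importing dnnlib. The names
# abort_requested and run_until_aborted are hypothetical helpers made up for this
# sketch, and the "training loop" is a plain counter.

```python
import os
import tempfile


def abort_requested(run_dir: str) -> bool:
    """Return True once an abort.txt file exists in run_dir."""
    return os.path.exists(os.path.join(run_dir, "abort.txt"))


def run_until_aborted(run_dir: str, max_steps: int) -> int:
    """Run up to max_steps iterations, stopping early if abort.txt appears."""
    steps_done = 0
    for step in range(max_steps):
        if abort_requested(run_dir):
            break
        steps_done += 1
        if step == 2:
            # Simulate an operator requesting a stop partway through the run.
            open(os.path.join(run_dir, "abort.txt"), "w").close()
    return steps_done


with tempfile.TemporaryDirectory() as demo_dir:
    demo_steps = run_until_aborted(demo_dir, max_steps=10)
```

# In a real run the check would be made through RunContext.update() and
# RunContext.should_stop() rather than a helper like the one above.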
================================================
FILE: dnnlib/submission/submit.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Submit a function to be run either locally or in a computing cluster."""
import copy
import io
import os
import pathlib
import pickle
import platform
import pprint
import re
import shutil
import time
import traceback
import zipfile
from enum import Enum
from .. import util
from ..util import EasyDict
class SubmitTarget(Enum):
    """The target where the function should be run.

    LOCAL: Run it locally.
    """
    LOCAL = 1

class PathType(Enum):
    """Determines the format in which a path should be written.

    WINDOWS: Format with Windows style.
    LINUX: Format with Linux/Posix style.
    AUTO: Use current OS type to select either WINDOWS or LINUX.
    """
    WINDOWS = 1
    LINUX = 2
    AUTO = 3
_user_name_override = None
class SubmitConfig(util.EasyDict):
    """Strongly typed config dict needed to submit runs.

    Attributes:
        run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template.
        run_desc: Description of the run. Will be used in the run dir and task name.
        run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.
        run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.
        submit_target: Submit target enum value. Used to select where the run is actually launched.
        num_gpus: Number of GPUs used/requested for the run.
        print_info: Whether to print debug information when submitting.
        ask_confirmation: Whether to ask a confirmation before submitting.
        run_id: Automatically populated value during submit.
        run_name: Automatically populated value during submit.
        run_dir: Automatically populated value during submit.
        run_func_name: Automatically populated value during submit.
        run_func_kwargs: Automatically populated value during submit.
        user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value.
        task_name: Automatically populated value during submit.
        host_name: Automatically populated value during submit.
    """

    def __init__(self):
        super().__init__()

        # run (set these)
        self.run_dir_root = ""  # should always be passed through get_path_from_template
        self.run_desc = ""
        self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"]
        self.run_dir_extra_files = None

        # submit (set these)
        self.submit_target = SubmitTarget.LOCAL
        self.num_gpus = 1
        self.print_info = False
        self.ask_confirmation = False

        # (automatically populated)
        self.run_id = None
        self.run_name = None
        self.run_dir = None
        self.run_func_name = None
        self.run_func_kwargs = None
        self.user_name = None
        self.task_name = None
        self.host_name = "localhost"
def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
    """Replace tags in the given path template and return either Windows or Linux formatted path."""
    # automatically select path type depending on running OS
    if path_type == PathType.AUTO:
        if platform.system() == "Windows":
            path_type = PathType.WINDOWS
        elif platform.system() == "Linux":
            path_type = PathType.LINUX
        else:
            raise RuntimeError("Unknown platform")

    path_template = path_template.replace("<USERNAME>", get_user_name())

    # return correctly formatted path
    if path_type == PathType.WINDOWS:
        return str(pathlib.PureWindowsPath(path_template))
    elif path_type == PathType.LINUX:
        return str(pathlib.PurePosixPath(path_template))
    else:
        raise RuntimeError("Unknown platform")
def get_template_from_path(path: str) -> str:
    """Convert a normal path back to its template representation."""
    # replace all path parts with the template tags
    path = path.replace("\\", "/")
    return path

def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
    """Convert a normal path to a template and then convert it back to a normal path with the given path type."""
    path_template = get_template_from_path(path)
    path = get_path_from_template(path_template, path_type)
    return path
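# Illustrative sketch, not part of the original file: the path helpers above rely
# on pathlib's pure path classes to render the same template in either Windows or
# Posix style. The function format_path, the user name "alice", and the template
# string are hypothetical and chosen only for this example; the real code obtains
# the user name from get_user_name() and selects the style through PathType.

```python
import pathlib


def format_path(path_template: str, user_name: str, windows: bool) -> str:
    """Substitute the <USERNAME> tag and render the path in the requested style."""
    path = path_template.replace("<USERNAME>", user_name)
    if windows:
        return str(pathlib.PureWindowsPath(path))
    return str(pathlib.PurePosixPath(path))


posix_path = format_path("results/<USERNAME>/runs", "alice", windows=False)
windows_path = format_path("results/<USERNAME>/runs", "alice", windows=True)
```

# Pure paths never touch the file system, so the conversion works the same way
# regardless of the operating system the script runs on.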
def set_user_name_override(name: str) -> None:
    """Set the global username override value."""
    global _user_name_override
    _user_name_override = name

def get_user_name():
    """Get the current user name."""
    if _user_name_override is not None:
        return _user_name_override
    elif platform.system() == "Windows":
        return os.getlogin()
    elif platform.system() == "Linux":
        try:
            import pwd  # pylint: disable=import-error
            return pwd.getpwuid(os.geteuid()).pw_name  # pylint: disable=no-member
        except:
            return "unknown"
    else:
        raise RuntimeError("Unknown platform")
def _create_run_dir_local(submit_config: SubmitConfig) -> str:
    """Create a new run dir with increasing ID number at the start."""
    run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)

    if not os.path.exists(run_dir_root):
        print("Creating the run dir root: {}".format(run_dir_root))
        os.makedirs(run_dir_root)

    submit_config.run_id = _get_next_run_id_local(run_dir_root)
    submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc)
    run_dir = os.path.join(run_dir_root, submit_config.run_name)

    if os.path.exists(run_dir):
        raise RuntimeError("The run dir already exists! ({0})".format(run_dir))

    print("Creating the run dir: {}".format(run_dir))
    os.makedirs(run_dir)

    return run_dir
def _get_next_run_id_local(run_dir_root: str) -> int:
    """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names."""
    dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))]
    r = re.compile("^\\d+")  # match one or more digits at the start of the string
    run_id = 0

    for dir_name in dir_names:
        m = r.match(dir_name)

        if m is not None:
            i = int(m.group())
            run_id = max(run_id, i + 1)

    return run_id
def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None:
    """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable."""
    print("Copying files to the run dir")
    files = []

    run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)
    assert '.' in submit_config.run_func_name
    for _idx in range(submit_config.run_func_name.count('.') - 1):
        run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)
    files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)

    dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib")
    files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)

    if submit_config.run_dir_extra_files is not None:
        files += submit_config.run_dir_extra_files

    files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files]
    files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))]

    util.copy_files_and_create_dirs(files)

    pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb"))

    with open(os.path.join(run_dir, "submit_config.txt"), "w") as f:
        pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)
def run_wrapper(submit_config: SubmitConfig) -> None:
    """Wrap the actual run function call for handling logging, exceptions, typing, etc."""
    is_local = submit_config.submit_target == SubmitTarget.LOCAL

    checker = None

    # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing
    if is_local:
        logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True)
    else:  # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh)
        logger = util.Logger(file_name=None, should_flush=True)

    import dnnlib
    dnnlib.submit_config = submit_config

    try:
        print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
        start_time = time.time()
        util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
        print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
    except:
        if is_local:
            raise
        else:
            traceback.print_exc()

            log_src = os.path.join(submit_config.run_dir, "log.txt")
            log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
            shutil.copyfile(log_src, log_dst)
    finally:
        open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()

    dnnlib.submit_config = None
    logger.close()

    if checker is not None:
        checker.stop()
def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
    """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place."""
    submit_config = copy.copy(submit_config)

    if submit_config.user_name is None:
        submit_config.user_name = get_user_name()

    submit_config.run_func_name = run_func_name
    submit_config.run_func_kwargs = run_func_kwargs

    assert submit_config.submit_target == SubmitTarget.LOCAL
    if submit_config.submit_target in {SubmitTarget.LOCAL}:
        run_dir = _create_run_dir_local(submit_config)

        submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
        submit_config.run_dir = run_dir
        _populate_run_dir(run_dir, submit_config)

    if submit_config.print_info:
        print("\nSubmit config:\n")
        pprint.pprint(submit_config, indent=4, width=200, compact=False)
        print()

    if submit_config.ask_confirmation:
        if not util.ask_yes_no("Continue submitting the job?"):
            return

    run_wrapper(submit_config)
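# Illustrative sketch, not part of the original file: run directories are named
# "<5-digit id>-<description>", and the next id is one more than the largest
# numeric prefix already present under the run dir root. The standalone example
# below reproduces that numbering rule in a temporary directory. The function
# next_run_id and the directory names used here are hypothetical.

```python
import os
import re
import tempfile


def next_run_id(run_dir_root: str) -> int:
    """Return one more than the largest leading number among the subdirectories."""
    pattern = re.compile(r"^\d+")
    run_id = 0
    for name in os.listdir(run_dir_root):
        if not os.path.isdir(os.path.join(run_dir_root, name)):
            continue  # plain files are ignored, only directories count
        match = pattern.match(name)
        if match is not None:
            run_id = max(run_id, int(match.group()) + 1)
    return run_id


with tempfile.TemporaryDirectory() as root:
    empty_id = next_run_id(root)  # no run dirs yet
    for name in ["00000-first-run", "00007-later-run", "notes"]:
        os.makedirs(os.path.join(root, name))
    demo_id = next_run_id(root)
    demo_name = "{0:05d}-{1}".format(demo_id, "example")
```

# Gaps in the numbering are not reused: the id always continues from the largest
# prefix found, and directories without a numeric prefix are skipped.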
================================================
FILE: dnnlib/tflib/__init__.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
from . import autosummary
from . import network
from . import optimizer
from . import tfutil
from .tfutil import *
from .network import Network
from .optimizer import Optimizer
================================================
FILE: dnnlib/tflib/autosummary.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Helper for adding automatically tracked values to TensorBoard.

Autosummary creates an identity op that internally keeps track of the input
values and automatically shows up in TensorBoard. The reported value
represents an average over input components. The average is accumulated
constantly over time and flushed when save_summaries() is called.

Notes:
- The output tensor must be used as an input for something else in the
  graph. Otherwise, the autosummary op will not get executed, and the average
  value will not get accumulated.
- It is perfectly fine to include autosummaries with the same name in
  several places throughout the graph, even if they are executed concurrently.
- It is ok to also pass in a python scalar or numpy array. In this case, it
  is added to the average immediately.
"""
from collections import OrderedDict
import numpy as np
import tensorflow as tf
from tensorboard import summary as summary_lib
from tensorboard.plugins.custom_scalar import layout_pb2
from . import tfutil
from .tfutil import TfExpression
from .tfutil import TfExpressionEx
_dtype = tf.float64
_vars = OrderedDict() # name => [var, ...]
_immediate = OrderedDict() # name => update_op, update_value
_finalized = False
_merge_op = None
def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
    """Internal helper for creating autosummary accumulators."""
    assert not _finalized
    name_id = name.replace("/", "_")
    v = tf.cast(value_expr, _dtype)

    if v.shape.is_fully_defined():
        size = np.prod(tfutil.shape_to_list(v.shape))
        size_expr = tf.constant(size, dtype=_dtype)
    else:
        size = None
        size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))

    if size == 1:
        if v.shape.ndims != 0:
            v = tf.reshape(v, [])
        v = [size_expr, v, tf.square(v)]
    else:
        v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
    v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))

    with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
        var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False)  # [sum(1), sum(x), sum(x**2)]
    update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))

    if name in _vars:
        _vars[name].append(var)
    else:
        _vars[name] = [var]
    return update_op
def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx:
    """Create a new autosummary.

    Args:
        name: Name to use in TensorBoard
        value: TensorFlow expression or python value to track
        passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.

    Example use of the passthru mechanism:

    n = autosummary('l2loss', loss, passthru=n)

    This is a shorthand for the following code:

    with tf.control_dependencies([autosummary('l2loss', loss)]):
        n = tf.identity(n)
    """
    tfutil.assert_tf_initialized()
    name_id = name.replace("/", "_")

    if tfutil.is_tf_expression(value):
        with tf.name_scope("summary_" + name_id), tf.device(value.device):
            update_op = _create_var(name, value)
            with tf.control_dependencies([update_op]):
                return tf.identity(value if passthru is None else passthru)

    else:  # python scalar or numpy array
        if name not in _immediate:
            with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
                update_value = tf.placeholder(_dtype)
                update_op = _create_var(name, update_value)
                _immediate[name] = update_op, update_value

        update_op, update_value = _immediate[name]
        tfutil.run(update_op, {update_value: value})
        return value if passthru is None else passthru
def finalize_autosummaries() -> None:
    """Create the necessary ops to include autosummaries in TensorBoard report.
    Note: This should be done only once per graph.
    """
    global _finalized
    tfutil.assert_tf_initialized()

    if _finalized:
        return None

    _finalized = True
    tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])

    # Create summary ops.
    with tf.device(None), tf.control_dependencies(None):
        for name, vars_list in _vars.items():
            name_id = name.replace("/", "_")
            with tfutil.absolute_name_scope("Autosummary/" + name_id):
                moments = tf.add_n(vars_list)
                moments /= moments[0]
                with tf.control_dependencies([moments]):  # read before resetting
                    reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
                    with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting
                        mean = moments[1]
                        std = tf.sqrt(moments[2] - tf.square(moments[1]))
                        tf.summary.scalar(name, mean)
                        tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
                        tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)

    # Group by category and chart name.
    cat_dict = OrderedDict()
    for series_name in sorted(_vars.keys()):
        p = series_name.split("/")
        cat = p[0] if len(p) >= 2 else ""
        chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
        if cat not in cat_dict:
            cat_dict[cat] = OrderedDict()
        if chart not in cat_dict[cat]:
            cat_dict[cat][chart] = []
        cat_dict[cat][chart].append(series_name)

    # Setup custom_scalar layout.
    categories = []
    for cat_name, chart_dict in cat_dict.items():
        charts = []
        for chart_name, series_names in chart_dict.items():
            series = []
            for series_name in series_names:
                series.append(layout_pb2.MarginChartContent.Series(
                    value=series_name,
                    lower="xCustomScalars/" + series_name + "/margin_lo",
                    upper="xCustomScalars/" + series_name + "/margin_hi"))
            margin = layout_pb2.MarginChartContent(series=series)
            charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
        categories.append(layout_pb2.Category(title=cat_name, chart=charts))
    layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
    return layout
def save_summaries(file_writer, global_step=None):
    """Call FileWriter.add_summary() with all summaries in the default graph,
    automatically finalizing and merging them on the first call.
    """
    global _merge_op
    tfutil.assert_tf_initialized()

    if _merge_op is None:
        layout = finalize_autosummaries()
        if layout is not None:
            file_writer.add_summary(layout)
        with tf.device(None), tf.control_dependencies(None):
            _merge_op = tf.summary.merge_all()

    file_writer.add_summary(_merge_op.eval(), global_step)
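# Illustrative sketch, not part of the original file: each autosummary keeps three
# running moments, [sum(1), sum(x), sum(x**2)], and the reported mean and standard
# deviation are derived from them when the summaries are flushed. The plain-Python
# example below shows the same arithmetic without TensorFlow. The names accumulate
# and report are hypothetical; skipping a batch whose sum is not finite mirrors the
# tf.is_finite guard in _create_var.

```python
import math


def accumulate(moments, values):
    """Add one batch of values to the running [count, sum, sum of squares]."""
    values = list(values)
    batch = [float(len(values)), sum(values), sum(v * v for v in values)]
    if not math.isfinite(batch[1]):
        return moments  # non-finite batches contribute nothing
    return [m + b for m, b in zip(moments, batch)]


def report(moments):
    """Turn the accumulated moments into (mean, std)."""
    count, total, total_sq = moments
    mean = total / count
    std = math.sqrt(total_sq / count - mean * mean)
    return mean, std


acc = [0.0, 0.0, 0.0]
acc = accumulate(acc, [1.0, 2.0])
acc = accumulate(acc, [3.0, 4.0])
acc = accumulate(acc, [float("nan")])  # ignored because its sum is not finite
mean, std = report(acc)
```

# Storing moments rather than raw values keeps the accumulator at a fixed size of
# three numbers no matter how many values are reported between flushes.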
================================================
FILE: dnnlib/tflib/network.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Helper for managing networks."""
import types
import inspect
import re
import uuid
import sys
import numpy as np
import tensorflow as tf
from collections import OrderedDict
from typing import Any, List, Tuple, Union
from . import tfutil
from .. import util
from .tfutil import TfExpression, TfExpressionEx
_import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
_import_module_src = dict() # Source code for temporary modules created during pickle import.
def import_handler(handler_func):
    """Function decorator for declaring custom import handlers."""
    _import_handlers.append(handler_func)
    return handler_func
class Network:
    """Generic network abstraction.

    Acts as a convenience wrapper for a parameterized network construction
    function, providing several utility methods and convenient access to
    the inputs/outputs/weights.

    Network objects can be safely pickled and unpickled for long-term
    archival purposes. The pickling works reliably as long as the underlying
    network construction function is defined in a standalone Python module
    that has no side effects or application-specific imports.

    Args:
        name: Network name. Used to select TensorFlow name and variable scopes.
        func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
        static_kwargs: Keyword arguments to be passed in to the network construction function.

    Attributes:
        name: User-specified name, defaults to build func name if None.
        scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
        static_kwargs: Arguments passed to the user-supplied build func.
        components: Container for sub-networks. Passed to the build func, and retained between calls.
        num_inputs: Number of input tensors.
        num_outputs: Number of output tensors.
        input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
        output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
        input_shape: Short-hand for input_shapes[0].
        output_shape: Short-hand for output_shapes[0].
        input_templates: Input placeholders in the template graph.
        output_templates: Output tensors in the template graph.
        input_names: Name string for each input.
        output_names: Name string for each output.
        own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
        vars: All variables (local_name => var).
        trainables: All trainable variables (local_name => var).
        var_global_to_local: Mapping from variable global names to local names.
    """
    def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
        tfutil.assert_tf_initialized()
        assert isinstance(name, str) or name is None
        assert func_name is not None
        assert isinstance(func_name, str) or util.is_top_level_function(func_name)
        assert util.is_pickleable(static_kwargs)

        self._init_fields()
        self.name = name
        self.static_kwargs = util.EasyDict(static_kwargs)

        # Locate the user-specified network build function.
        if util.is_top_level_function(func_name):
            func_name = util.get_top_level_function_name(func_name)
        module, self._build_func_name = util.get_module_from_obj_name(func_name)
        self._build_func = util.get_obj_from_module(module, self._build_func_name)
        assert callable(self._build_func)

        # Dig up source code for the module containing the build function.
        self._build_module_src = _import_module_src.get(module, None)
        if self._build_module_src is None:
            self._build_module_src = inspect.getsource(module)

        # Init TensorFlow graph.
        self._init_graph()
        self.reset_own_vars()
    def _init_fields(self) -> None:
        self.name = None
        self.scope = None
        self.static_kwargs = util.EasyDict()
        self.components = util.EasyDict()
        self.num_inputs = 0
        self.num_outputs = 0
        self.input_shapes = [[]]
        self.output_shapes = [[]]
        self.input_shape = []
        self.output_shape = []
        self.input_templates = []
        self.output_templates = []
        self.input_names = []
        self.output_names = []
        self.own_vars = OrderedDict()
        self.vars = OrderedDict()
        self.trainables = OrderedDict()
        self.var_global_to_local = OrderedDict()

        self._build_func = None  # User-supplied build function that constructs the network.
        self._build_func_name = None  # Name of the build function.
        self._build_module_src = None  # Full source code of the module containing the build function.
        self._run_cache = dict()  # Cached graph data for Network.run().
    def _init_graph(self) -> None:
        # Collect inputs.
        self.input_names = []

        for param in inspect.signature(self._build_func).parameters.values():
            if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
                self.input_names.append(param.name)

        self.num_inputs = len(self.input_names)
        assert self.num_inputs >= 1

        # Choose name and scope.
        if self.name is None:
            self.name = self._build_func_name
        assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
        with tf.name_scope(None):
            self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs["is_template_graph"] = True
        build_kwargs["components"] = self.components

        # Build template graph.
        with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope):  # ignore surrounding scopes
            assert tf.get_variable_scope().name == self.scope
            assert tf.get_default_graph().get_name_scope() == self.scope
            with tf.control_dependencies(None):  # ignore surrounding control dependencies
                self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                out_expr = self._build_func(*self.input_templates, **build_kwargs)

        # Collect outputs.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        self.num_outputs = len(self.output_templates)
        assert self.num_outputs >= 1
        assert all(tfutil.is_tf_expression(t) for t in self.output_templates)

        # Perform sanity checks.
        if any(t.shape.ndims is None for t in self.input_templates):
            raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
        if any(t.shape.ndims is None for t in self.output_templates):
            raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
        if any(not isinstance(comp, Network) for comp in self.components.values()):
            raise ValueError("Components of a Network must be Networks themselves.")
        if len(self.components) != len(set(comp.name for comp in self.components.values())):
            raise ValueError("Components of a Network must have unique names.")

        # List inputs and outputs.
        self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]
        self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]
        self.input_shape = self.input_shapes[0]
        self.output_shape = self.output_shapes[0]
        self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]

        # List variables.
        self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
        self.vars = OrderedDict(self.own_vars)
        self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
        self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
        self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
def reset_own_vars(self) -> None:
"""Re-initialize all variables of this network, excluding sub-networks."""
tfutil.run([var.initializer for var in self.own_vars.values()])
def reset_vars(self) -> None:
"""Re-initialize all variables of this network, including sub-networks."""
tfutil.run([var.initializer for var in self.vars.values()])
def reset_trainables(self) -> None:
"""Re-initialize all trainable variables of this network, including sub-networks."""
tfutil.run([var.initializer for var in self.trainables.values()])
def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
"""Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s)."""
assert len(in_expr) == self.num_inputs
assert not all(expr is None for expr in in_expr)
# Finalize build func kwargs.
build_kwargs = dict(self.static_kwargs)
build_kwargs.update(dynamic_kwargs)
build_kwargs["is_template_graph"] = False
build_kwargs["components"] = self.components
# Build TensorFlow graph to evaluate the network.
with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
assert tf.get_variable_scope().name == self.scope
valid_inputs = [expr for expr in in_expr if expr is not None]
final_inputs = []
for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
if expr is not None:
expr = tf.identity(expr, name=name)
else:
expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
final_inputs.append(expr)
out_expr = self._build_func(*final_inputs, **build_kwargs)
# Propagate input shapes back to the user-specified expressions.
for expr, final in zip(in_expr, final_inputs):
if isinstance(expr, tf.Tensor):
expr.set_shape(final.shape)
# Express outputs in the desired format.
assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
if return_as_list:
out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
return out_expr
def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
"""Get the local name of a given variable, without any surrounding name scopes."""
assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
return self.var_global_to_local[global_name.split(":")[0]] # keys are stored without the ":0" output suffix
def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
"""Find a variable by its local name; a variable passed in directly is returned as-is."""
assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
"""Get the value of a given variable as NumPy array.
Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
return self.find_var(var_or_local_name).eval()
def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
"""Set the value of a given variable based on the given NumPy array.
Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
tfutil.set_vars({self.find_var(var_or_local_name): new_value})
def __getstate__(self) -> dict:
"""Pickle export."""
state = dict()
state["version"] = 3
state["name"] = self.name
state["static_kwargs"] = dict(self.static_kwargs)
state["components"] = dict(self.components)
state["build_module_src"] = self._build_module_src
state["build_func_name"] = self._build_func_name
state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
return state
def __setstate__(self, state: dict) -> None:
"""Pickle import."""
# pylint: disable=attribute-defined-outside-init
tfutil.assert_tf_initialized()
self._init_fields()
# Execute custom import handlers.
for handler in _import_handlers:
state = handler(state)
# Set basic fields.
assert state["version"] in [2, 3]
self.name = state["name"]
self.static_kwargs = util.EasyDict(state["static_kwargs"])
self.components = util.EasyDict(state.get("components", {}))
self._build_module_src = state["build_module_src"]
self._build_func_name = state["build_func_name"]
# Create temporary module from the imported source code.
module_name = "_tflib_network_import_" + uuid.uuid4().hex
module = types.ModuleType(module_name)
sys.modules[module_name] = module
_import_module_src[module] = self._build_module_src
exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used
# Locate network build function in the temporary module.
self._build_func = util.get_obj_from_module(module, self._build_func_name)
assert callable(self._build_func)
# Init TensorFlow graph.
self._init_graph()
self.reset_own_vars()
tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
def clone(self, name: str = None, **new_static_kwargs) -> "Network":
"""Create a clone of this network with its own copy of the variables."""
# pylint: disable=protected-access
net = object.__new__(Network)
net._init_fields()
net.name = name if name is not None else self.name
net.static_kwargs = util.EasyDict(self.static_kwargs)
net.static_kwargs.update(new_static_kwargs)
net._build_module_src = self._build_module_src
net._build_func_name = self._build_func_name
net._build_func = self._build_func
net._init_graph()
net.copy_vars_from(self)
return net
def copy_own_vars_from(self, src_net: "Network") -> None:
"""Copy the values of all variables from the given network, excluding sub-networks."""
names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
def copy_vars_from(self, src_net: "Network") -> None:
"""Copy the values of all variables from the given network, including sub-networks."""
names = [name for name in self.vars.keys() if name in src_net.vars]
tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
def copy_trainables_from(self, src_net: "Network") -> None:
"""Copy the values of all trainable variables from the given network, including sub-networks."""
names = [name for name in self.trainables.keys() if name in src_net.trainables]
tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
"""Create new network with the given parameters, and copy all variables from this network."""
if new_name is None:
new_name = self.name
static_kwargs = dict(self.static_kwargs)
static_kwargs.update(new_static_kwargs)
net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
net.copy_vars_from(self)
return net
def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
"""Construct a TensorFlow op that updates the variables of this network
to be slightly closer to those of the given network."""
with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
ops = []
for name, var in self.vars.items():
if name in src_net.vars:
cur_beta = beta if name in self.trainables else beta_nontrainable
new_value = tfutil.lerp(src_net.vars[name], var, cur_beta)
ops.append(var.assign(new_value))
return tf.group(*ops)
def run(self,
*in_arrays: Tuple[Union[np.ndarray, None], ...],
input_transform: dict = None,
output_transform: dict = None,
return_as_list: bool = False,
print_progress: bool = False,
minibatch_size: int = None,
num_gpus: int = 1,
assume_frozen: bool = False,
**dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
"""Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
Args:
input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
The dict must contain a 'func' field that points to a top-level function. The function is called with the input
TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
The dict must contain a 'func' field that points to a top-level function. The function is called with the output
TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
print_progress: Print progress to the console? Useful for very large input arrays.
minibatch_size: Maximum minibatch size to use, None = disable batching.
num_gpus: Number of GPUs to use.
assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
"""
assert len(in_arrays) == self.num_inputs
assert not all(arr is None for arr in in_arrays)
assert input_transform is None or util.is_top_level_function(input_transform["func"])
assert output_transform is None or util.is_top_level_function(output_transform["func"])
output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
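num_items = next(arr for arr in in_arrays if arr is not None).shape[0] # the first input may be None; the assert above guarantees one is not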
if minibatch_size is None:
minibatch_size = num_items
# Construct unique hash key from all arguments that affect the TensorFlow graph.
key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
def unwind_key(obj):
if isinstance(obj, dict):
return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
if callable(obj):
return util.get_top_level_function_name(obj)
return obj
key = repr(unwind_key(key))
# Build graph.
if key not in self._run_cache:
with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
with tf.device("/cpu:0"):
in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
out_split = []
for gpu in range(num_gpus):
with tf.device("/gpu:%d" % gpu):
net_gpu = self.clone() if assume_frozen else self
in_gpu = in_split[gpu]
if input_transform is not None:
in_kwargs = dict(input_transform)
in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
assert len(in_gpu) == self.num_inputs
out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
if output_transform is not None:
out_kwargs = dict(output_transform)
out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
assert len(out_gpu) == self.num_outputs
out_split.append(out_gpu)
with tf.device("/cpu:0"):
out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
self._run_cache[key] = in_expr, out_expr
# Run minibatches.
in_expr, out_expr = self._run_cache[key]
out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]
for mb_begin in range(0, num_items, minibatch_size):
if print_progress:
print("\r%d / %d" % (mb_begin, num_items), end="")
mb_end = min(mb_begin + minibatch_size, num_items)
mb_num = mb_end - mb_begin
mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
for dst, src in zip(out_arrays, mb_out):
dst[mb_begin: mb_end] = src
# Done.
if print_progress:
print("\r%d / %d" % (num_items, num_items))
if not return_as_list:
out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
return out_arrays
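def list_ops(self) -> List[TfExpression]:
"""List the graph ops under this network's scope, excluding internal sub-scopes whose names start with an underscore."""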
include_prefix = self.scope + "/"
exclude_prefix = include_prefix + "_"
ops = tf.get_default_graph().get_operations()
ops = [op for op in ops if op.name.startswith(include_prefix)]
ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
return ops
def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
"""Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
individual layers of the network. Mainly intended to be used for reporting."""
layers = []
def recurse(scope, parent_ops, parent_vars, level):
# Ignore specific patterns.
if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
return
# Filter ops and vars by scope.
global_prefix = scope + "/"
local_prefix = global_prefix[len(self.scope) + 1:]
cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
if not cur_ops and not cur_vars:
return
# Filter out all ops related to variables.
for var in [op for op in cur_ops if op.type.startswith("Variable")]:
var_prefix = var.name + "/"
cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
# Scope does not contain ops as immediate children => recurse deeper.
contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops)
if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
visited = set()
for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
token = rel_name.split("/")[0]
if token not in visited:
recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
visited.add(token)
return
# Report layer.
layer_name = scope[len(self.scope) + 1:]
layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
layer_trainables = [var for _name, var in cur_vars if var.trainable]
layers.append((layer_name, layer_output, layer_trainables))
recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
return layers
def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
"""Print a summary table of the network structure."""
rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
rows += [["---"] * 4]
total_params = 0
for layer_name, layer_output, layer_trainables in self.list_layers():
num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables)
weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
weights.sort(key=lambda x: len(x.name))
if len(weights) == 0 and len(layer_trainables) == 1:
weights = layer_trainables
total_params += num_params
if not hide_layers_with_no_params or num_params != 0:
num_params_str = str(num_params) if num_params > 0 else "-"
output_shape_str = str(layer_output.shape)
weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
rows += [["---"] * 4]
rows += [["Total", str(total_params), "", ""]]
widths = [max(len(cell) for cell in column) for column in zip(*rows)]
print()
for row in rows:
print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
print()
def setup_weight_histograms(self, title: str = None) -> None:
"""Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
if title is None:
title = self.name
with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
for local_name, var in self.trainables.items():
if "/" in local_name:
p = local_name.split("/")
name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
else:
name = title + "_toplevel/" + local_name
tf.summary.histogram(name, var)
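# Illustrative sketch, not part of the original module: how the cache key built in
# Network.run() might behave. It is plain Python with no TensorFlow dependency.
# `function_name` is a hypothetical stand-in for util.get_top_level_function_name(),
# whose exact output format is not shown in this file, and `to_uint8` is a dummy
# transform used only so that the key has a callable in it.

```python
def function_name(obj):
    # Hypothetical stand-in: identify a top-level function by module and name.
    return "%s.%s" % (obj.__module__, obj.__name__)

def unwind_key(obj):
    # Dicts become sorted (key, value) lists, so insertion order does not matter.
    if isinstance(obj, dict):
        return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
    # Callables are replaced by a stable name rather than a memory address.
    if callable(obj):
        return function_name(obj)
    return obj

def to_uint8(images, shrink=1):
    # Dummy transform; only its identity matters for the key.
    return images

key_a = repr(unwind_key(dict(output_transform=dict(func=to_uint8, shrink=2), num_gpus=1)))
key_b = repr(unwind_key(dict(num_gpus=1, output_transform=dict(shrink=2, func=to_uint8))))
key_c = repr(unwind_key(dict(num_gpus=2, output_transform=dict(shrink=2, func=to_uint8))))

print(key_a == key_b)  # same arguments in a different order map to one cached graph
print(key_a == key_c)  # a different num_gpus maps to a separate graph
```

# Under these assumptions, only arguments that change the graph structure produce a
# new key, which is presumably why the minibatch size is not part of it.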
#----------------------------------------------------------------------------
# Backwards-compatible emulation of legacy output transformation in Network.run().
_print_legacy_warning = True
def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
global _print_legacy_warning
legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
return output_transform, dynamic_kwargs
if _print_legacy_warning:
_print_legacy_warning = False
print()
print("WARNING: Old-style output transformations in Network.run() are deprecated.")
print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
print()
assert output_transform is None
new_kwargs = dict(dynamic_kwargs)
new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
new_transform["func"] = _legacy_output_transform_func
return new_transform, new_kwargs
def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
if out_mul != 1.0:
expr = [x * out_mul for x in expr]
if out_add != 0.0:
expr = [x + out_add for x in expr]
if out_shrink > 1:
ksize = [1, 1, out_shrink, out_shrink]
expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
if out_dtype is not None:
if tf.as_dtype(out_dtype).is_integer:
expr = [tf.round(x) for x in expr]
expr = [tf.saturate_cast(x, out_dtype) for x in expr]
return expr
================================================
FILE: dnnlib/tflib/optimizer.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Helper wrapper for a Tensorflow optimizer."""
import numpy as np
import tensorflow as tf
from collections import OrderedDict
from typing import List, Union
from . import autosummary
from . import tfutil
from .. import util
from .tfutil import TfExpression, TfExpressionEx
try:
# TensorFlow 1.13
from tensorflow.python.ops import nccl_ops
except ImportError:
# Older TensorFlow versions
import tensorflow.contrib.nccl as nccl_ops
class Optimizer:
"""A wrapper for tf.train.Optimizer.
Automatically takes care of:
- Gradient averaging for multi-GPU training.
- Dynamic loss scaling and typecasts for FP16 training.
- Ignoring corrupted gradients that contain NaNs/Infs.
- Reporting statistics.
- Well-chosen default settings.
"""
def __init__(self,
name: str = "Train",
tf_optimizer: str = "tf.train.AdamOptimizer",
learning_rate: TfExpressionEx = 0.001,
use_loss_scaling: bool = False,
loss_scaling_init: float = 64.0,
loss_scaling_inc: float = 0.0005,
loss_scaling_dec: float = 1.0,
**kwargs):
# Init fields.
self.name = name
self.learning_rate = tf.convert_to_tensor(learning_rate)
self.id = self.name.replace("/", ".")
self.scope = tf.get_default_graph().unique_name(self.id)
self.optimizer_class = util.get_obj_by_name(tf_optimizer)
self.optimizer_kwargs = dict(kwargs)
self.use_loss_scaling = use_loss_scaling
self.loss_scaling_init = loss_scaling_init
self.loss_scaling_inc = loss_scaling_inc
self.loss_scaling_dec = loss_scaling_dec
self._grad_shapes = None # [shape, ...]
self._dev_opt = OrderedDict() # device => optimizer
self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...]
self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor)
self._updates_applied = False
def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
"""Register the gradients of the given loss function with respect to the given variables.
Intended to be called once per GPU."""
assert not self._updates_applied
# Validate arguments.
if isinstance(trainable_vars, dict):
trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
if self._grad_shapes is None:
self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]
assert len(trainable_vars) == len(self._grad_shapes)
assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes))
dev = loss.device
assert all(var.device == dev for var in trainable_vars)
# Register device and compute gradients.
with tf.name_scope(self.id + "_grad"), tf.device(dev):
if dev not in self._dev_opt:
opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt)
assert callable(self.optimizer_class)
self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
self._dev_grads[dev] = []
loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage
grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros
self._dev_grads[dev].append(grads)
def apply_updates(self) -> tf.Operation:
"""Construct training op to update the registered variables based on their gradients."""
tfutil.assert_tf_initialized()
assert not self._updates_applied
self._updates_applied = True
devices = list(self._dev_grads.keys())
total_grads = sum(len(grads) for grads in self._dev_grads.values())
assert len(devices) >= 1 and total_grads >= 1
ops = []
with tfutil.absolute_name_scope(self.scope):
# Cast gradients to FP32 and calculate partial sum within each device.
dev_grads = OrderedDict() # device => [(grad, var), ...]
for dev_idx, dev in enumerate(devices):
with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev):
sums = []
for gv in zip(*self._dev_grads[dev]):
assert all(v is gv[0][1] for g, v in gv)
g = [tf.cast(g, tf.float32) for g, v in gv]
g = g[0] if len(g) == 1 else tf.add_n(g)
sums.append((g, gv[0][1]))
dev_grads[dev] = sums
# Sum gradients across devices.
if len(devices) > 1:
with tf.name_scope("SumAcrossGPUs"), tf.device(None):
for var_idx, grad_shape in enumerate(self._grad_shapes):
g = [dev_grads[dev][var_idx][0] for dev in devices]
if np.prod(grad_shape): # nccl does not support zero-sized tensors
g = nccl_ops.all_sum(g)
for dev, gg in zip(devices, g):
dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1])
# Apply updates separately on each device.
for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev):
# Scale gradients as needed.
if self.use_loss_scaling or total_grads > 1:
with tf.name_scope("Scale"):
coef = tf.constant(np.float32(1.0 / total_grads), name="coef")
coef = self.undo_loss_scaling(coef)
grads = [(g * coef, v) for g, v in grads]
# Check for overflows.
with tf.name_scope("CheckOverflow"):
grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads]))
# Update weights and adjust loss scaling.
with tf.name_scope("UpdateWeights"):
# pylint: disable=cell-var-from-loop
opt = self._dev_opt[dev]
ls_var = self.get_loss_scaling_var(dev)
if not self.use_loss_scaling:
ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op))
else:
ops.append(tf.cond(grad_ok,
lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),
lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec))))
# Report statistics on the last device.
if dev == devices[-1]:
with tf.name_scope("Statistics"):
ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1)))
if self.use_loss_scaling:
ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var))
# Initialize variables and group everything into a single op.
self.reset_optimizer_state()
tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))
return tf.group(*ops, name="TrainingOp")
def reset_optimizer_state(self) -> None:
"""Reset internal state of the underlying optimizer."""
tfutil.assert_tf_initialized()
tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()])
def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
"""Get or create variable representing log2 of the current dynamic loss scaling factor."""
if not self.use_loss_scaling:
return None
if device not in self._dev_ls_var:
with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None):
self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var")
return self._dev_ls_var[device]
def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
"""Apply dynamic loss scaling for the given expression."""
assert tfutil.is_tf_expression(value)
if not self.use_loss_scaling:
return value
return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
"""Undo the effect of dynamic loss scaling for the given expression."""
assert tfutil.is_tf_expression(value)
if not self.use_loss_scaling:
return value
return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
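# Illustrative sketch, not part of the original module: a plain-Python model of the
# dynamic loss scaling bookkeeping in Optimizer.apply_updates(). The class name
# `LossScaler` and its methods are hypothetical; the real code keeps the log2 scale
# in a per-device tf.Variable and applies the update inside tf.cond.

```python
import math

class LossScaler:
    """Hypothetical helper that tracks log2 of the loss scaling factor."""

    def __init__(self, init_log2=64.0, inc=0.0005, dec=1.0):
        self.log2_scale = init_log2
        self.inc = inc
        self.dec = dec

    def scale_loss(self, loss):
        # Counterpart of apply_loss_scaling(): multiply by 2**log2_scale.
        return loss * 2.0 ** self.log2_scale

    def unscale_grad(self, grad):
        # Counterpart of undo_loss_scaling(): multiply by 2**(-log2_scale).
        return grad * 2.0 ** -self.log2_scale

    def step(self, scaled_grads):
        """Return the unscaled gradients to apply, or None if the step is skipped."""
        grads = [self.unscale_grad(g) for g in scaled_grads]
        if all(math.isfinite(g) for g in grads):
            self.log2_scale += self.inc  # slow growth while gradients stay finite
            return grads
        self.log2_scale -= self.dec      # overflow: drop the update, halve the scale
        return None

scaler = LossScaler(init_log2=8.0)
ok = scaler.step([scaler.scale_loss(0.5)])  # finite gradient: applied
bad = scaler.step([float("inf")])           # overflow: skipped
print(ok, bad, scaler.log2_scale)
```

# With the defaults above, one overflow (dec = 1.0) takes 2000 clean steps
# (inc = 0.0005) to recover, so the scale settles just below the overflow threshold.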
================================================
FILE: dnnlib/tflib/tfutil.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Miscellaneous helper utils for Tensorflow."""
import os
import numpy as np
import tensorflow as tf
from typing import Any, Iterable, List, Union
TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
"""A type that represents a valid Tensorflow expression."""
TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
"""A type that can be converted to a valid Tensorflow expression."""
def run(*args, **kwargs) -> Any:
"""Run the specified ops in the default session."""
assert_tf_initialized()
return tf.get_default_session().run(*args, **kwargs)
def is_tf_expression(x: Any) -> bool:
"""Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
"""Convert a TensorFlow shape to a list of ints, with None for unknown dimensions."""
return [dim.value for dim in shape]
def flatten(x: TfExpressionEx) -> TfExpression:
"""Shortcut function for flattening a tensor."""
with tf.name_scope("Flatten"):
return tf.reshape(x, [-1])
def log2(x: TfExpressionEx) -> TfExpression:
"""Logarithm in base 2."""
with tf.name_scope("Log2"):
return tf.log(x) * np.float32(1.0 / np.log(2.0))
def exp2(x: TfExpressionEx) -> TfExpression:
"""Exponential in base 2, i.e. 2**x."""
with tf.name_scope("Exp2"):
return tf.exp(x * np.float32(np.log(2.0)))
def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
"""Linear interpolation."""
with tf.name_scope("Lerp"):
return a + (b - a) * t
def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
"""Linear interpolation with clip."""
with tf.name_scope("LerpClip"):
return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
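# Illustrative sketch, not part of the original module: the lerp formula above on
# plain floats, and how Network.setup_as_moving_average_of() in network.py uses it
# with the argument order lerp(src, avg, beta). `ema_update` is a hypothetical name.

```python
def lerp(a, b, t):
    # Same formula as tfutil.lerp: t = 0 gives a, t = 1 gives b.
    return a + (b - a) * t

def ema_update(avg_value, src_value, beta=0.99):
    # lerp(src, avg, beta) == beta * avg + (1 - beta) * src, so beta = 1 keeps the
    # running average unchanged and beta = 0 copies the source value outright.
    return lerp(src_value, avg_value, beta)

avg = 0.0
for _ in range(3):
    avg = ema_update(avg, 1.0, beta=0.5)  # moves halfway towards 1.0 each step
print(avg)
```

# This also explains the default beta_nontrainable = 0.0 in network.py: non-trainable
# variables are copied from the source network instead of being averaged.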
def absolute_name_scope(scope: str) -> tf.name_scope:
"""Forcefully enter the specified name scope, ignoring any surrounding scopes."""
return tf.name_scope(scope + "/")
def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
"""Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
def _sanitize_tf_config(config_dict: dict = None) -> dict:
# Defaults.
cfg = dict()
cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
# User overrides.
if config_dict is not None:
cfg.update(config_dict)
return cfg
def init_tf(config_dict: dict = None) -> None:
"""Initialize TensorFlow session using good default settings."""
# Skip if already initialized.
if tf.get_default_session() is not None:
return
# Setup config dict and random seeds.
cfg = _sanitize_tf_config(config_dict)
np_random_seed = cfg["rnd.np_random_seed"]
if np_random_seed is not None:
np.random.seed(np_random_seed)
tf_random_seed = cfg["rnd.tf_random_seed"]
if tf_random_seed == "auto":
tf_random_seed = np.random.randint(1 << 31)
if tf_random_seed is not None:
tf.set_random_seed(tf_random_seed)
# Setup environment variables.
for key, value in list(cfg.items()):
fields = key.split(".")
if fields[0] == "env":
assert len(fields) == 2
os.environ[fields[1]] = str(value)
# Create default TensorFlow session.
create_session(cfg, force_as_default=True)
def assert_tf_initialized():
"""Check that TensorFlow session has been initialized."""
if tf.get_default_session() is None:
raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
"""Create tf.Session based on config dict."""
# Setup TensorFlow config proto.
cfg = _sanitize_tf_config(config_dict)
config_proto = tf.ConfigProto()
for key, value in cfg.items():
fields = key.split(".")
if fields[0] not in ["rnd", "env"]:
obj = config_proto
for field in fields[:-1]:
obj = getattr(obj, field)
setattr(obj, fields[-1], value)
# Create session.
session = tf.Session(config=config_proto)
if force_as_default:
# pylint: disable=protected-access
session._default_session = session.as_default()
session._default_session.enforce_nesting = False
session._default_session.__enter__() # pylint: disable=no-member
return session
def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
"""Initialize all tf.Variables that have not already been initialized.
Equivalent to the following, but more efficient and does not bloat the tf graph:
tf.variables_initializer(tf.report_uninitialized_variables()).run()
"""
assert_tf_initialized()
if target_vars is None:
target_vars = tf.global_variables()
test_vars = []
test_ops = []
with tf.control_dependencies(None): # ignore surrounding control_dependencies
for var in target_vars:
assert is_tf_expression(var)
try:
tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
except KeyError:
# Op does not exist => variable may be uninitialized.
test_vars.append(var)
with absolute_name_scope(var.name.split(":")[0]):
test_ops.append(tf.is_variable_initialized(var))
init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
run([var.initializer for var in init_vars])
def set_vars(var_to_value_dict: dict) -> None:
"""Set the values of given tf.Variables.
Equivalent to the following, but more efficient and does not bloat the tf graph:
tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
"""
assert_tf_initialized()
ops = []
feed_dict = {}
for var, value in var_to_value_dict.items():
assert is_tf_expression(var)
try:
setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
except KeyError:
with absolute_name_scope(var.name.split(":")[0]):
with tf.control_dependencies(None): # ignore surrounding control_dependencies
setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
ops.append(setter)
feed_dict[setter.op.inputs[1]] = value
run(ops, feed_dict)
def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
"""Create tf.Variable with large initial value without bloating the tf graph."""
assert_tf_initialized()
assert isinstance(initial_value, np.ndarray)
zeros = tf.zeros(initial_value.shape, initial_value.dtype)
var = tf.Variable(zeros, *args, **kwargs)
set_vars({var: initial_value})
return var
def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
    """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
    Can be used as an input transformation for Network.run().
    """
    images = tf.cast(images, tf.float32)
    if nhwc_to_nchw:
        images = tf.transpose(images, [0, 3, 1, 2])
    # Map [0, 255] linearly onto [drange[0], drange[1]]: scale first, then shift.
    return images * ((drange[1] - drange[0]) / 255) + drange[0]
def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
"""Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
Can be used as an output transformation for Network.run().
"""
images = tf.cast(images, tf.float32)
if shrink > 1:
ksize = [1, 1, shrink, shrink]
images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
if nchw_to_nhwc:
images = tf.transpose(images, [0, 2, 3, 1])
scale = 255 / (drange[1] - drange[0])
images = images * scale + (0.5 - drange[0] * scale)
return tf.saturate_cast(images, tf.uint8)
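# Illustrative sketch, not part of the library: a NumPy analogue of the affine
# mapping between uint8 pixels and a float dynamic range, showing that the two
# directions are meant to round-trip. The helper names from_uint8_np and
# to_uint8_np are hypothetical, and np.clip followed by astype is assumed to be
# an adequate stand-in for a saturating cast.

```python
import numpy as np

def from_uint8_np(images, drange=(-1.0, 1.0)):
    """Map uint8 pixel values [0, 255] linearly onto [drange[0], drange[1]]."""
    images = np.asarray(images, dtype=np.float32)
    return images * ((drange[1] - drange[0]) / 255) + drange[0]

def to_uint8_np(images, drange=(-1.0, 1.0)):
    """Map [drange[0], drange[1]] back onto [0, 255], rounding to nearest."""
    scale = 255 / (drange[1] - drange[0])
    images = np.asarray(images, dtype=np.float32) * scale + (0.5 - drange[0] * scale)
    # Clip, then truncate; the +0.5 offset above turns truncation into rounding.
    return np.clip(images, 0, 255).astype(np.uint8)

pixels = np.arange(256, dtype=np.uint8)
restored = to_uint8_np(from_uint8_np(pixels))
```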
================================================
FILE: dnnlib/util.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Miscellaneous utility classes and functions."""
import ctypes
import fnmatch
import importlib
import inspect
import numpy as np
import os
import shutil
import sys
import types
import io
import pickle
import re
import requests
import html
import hashlib
import glob
import uuid
from distutils.util import strtobool
from typing import Any, List, Tuple, Union
# Util classes
# ------------------------------------------------------------------------------------------
class EasyDict(dict):
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
def __getattr__(self, name: str) -> Any:
try:
return self[name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name: str, value: Any) -> None:
self[name] = value
def __delattr__(self, name: str) -> None:
del self[name]
class Logger(object):
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
self.file = None
if file_name is not None:
self.file = open(file_name, file_mode)
self.should_flush = should_flush
self.stdout = sys.stdout
self.stderr = sys.stderr
sys.stdout = self
sys.stderr = self
def __enter__(self) -> "Logger":
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.close()
def write(self, text: str) -> None:
"""Write text to stdout (and a file) and optionally flush."""
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
return
if self.file is not None:
self.file.write(text)
self.stdout.write(text)
if self.should_flush:
self.flush()
def flush(self) -> None:
"""Flush written text to both stdout and a file, if open."""
if self.file is not None:
self.file.flush()
self.stdout.flush()
def close(self) -> None:
"""Flush, close possible files, and remove stdout/stderr mirroring."""
self.flush()
# if using multiple loggers, prevent closing in wrong order
if sys.stdout is self:
sys.stdout = self.stdout
if sys.stderr is self:
sys.stderr = self.stderr
if self.file is not None:
self.file.close()
# Small util functions
# ------------------------------------------------------------------------------------------
def format_time(seconds: Union[int, float]) -> str:
    """Convert seconds to a human-readable string with days, hours, minutes, and seconds."""
    s = int(np.rint(seconds))
    if s < 60:
        return "{0}s".format(s)
    elif s < 60 * 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    elif s < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
    else:
        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
def ask_yes_no(question: str) -> bool:
"""Ask the user the question until the user inputs a valid answer."""
while True:
try:
print("{0} [y/n]".format(question))
return strtobool(input().lower())
except ValueError:
pass
def tuple_product(t: Tuple) -> Any:
"""Calculate the product of the tuple elements."""
result = 1
for v in t:
result *= v
return result
_str_to_ctype = {
"uint8": ctypes.c_ubyte,
"uint16": ctypes.c_uint16,
"uint32": ctypes.c_uint32,
"uint64": ctypes.c_uint64,
"int8": ctypes.c_byte,
"int16": ctypes.c_int16,
"int32": ctypes.c_int32,
"int64": ctypes.c_int64,
"float32": ctypes.c_float,
"float64": ctypes.c_double
}
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
type_str = None
if isinstance(type_obj, str):
type_str = type_obj
elif hasattr(type_obj, "__name__"):
type_str = type_obj.__name__
elif hasattr(type_obj, "name"):
type_str = type_obj.name
else:
raise RuntimeError("Cannot infer type name from input")
assert type_str in _str_to_ctype.keys()
my_dtype = np.dtype(type_str)
my_ctype = _str_to_ctype[type_str]
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
return my_dtype, my_ctype
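def is_pickleable(obj: Any) -> bool:
    """Return True if the given object can be serialized with pickle."""
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception: # pickling can fail with many different exception types
        return False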
# Functionality to import modules/objects by name, and call functions by name
# ------------------------------------------------------------------------------------------
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
"""Searches for the underlying module behind the name to some python object.
Returns the module and the object name (original name with module part removed)."""
# allow convenience shorthands, substitute them by full names
obj_name = re.sub(r"^np\.", "numpy.", obj_name) # escape the dot so that only the literal prefix matches
obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)
# list alternatives for (module_name, local_obj_name)
parts = obj_name.split(".")
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
# try each alternative in turn
for module_name, local_obj_name in name_pairs:
try:
module = importlib.import_module(module_name) # may raise ImportError
get_obj_from_module(module, local_obj_name) # may raise AttributeError
return module, local_obj_name
except:
pass
# maybe some of the modules themselves contain errors?
for module_name, _local_obj_name in name_pairs:
try:
importlib.import_module(module_name) # may raise ImportError
except ImportError:
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
raise
# maybe the requested attribute is missing?
for module_name, local_obj_name in name_pairs:
try:
module = importlib.import_module(module_name) # may raise ImportError
get_obj_from_module(module, local_obj_name) # may raise AttributeError
except ImportError:
pass
# we are out of luck, but we have no idea why
raise ImportError(obj_name)
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
"""Traverses the object name and returns the last (rightmost) python object."""
if obj_name == '':
return module
obj = module
for part in obj_name.split("."):
obj = getattr(obj, part)
return obj
def get_obj_by_name(name: str) -> Any:
"""Finds the python object with the given name."""
module, obj_name = get_module_from_obj_name(name)
return get_obj_from_module(module, obj_name)
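# Illustrative sketch, not part of the library: the core idea of the name lookup
# above is to try the longest dotted prefix as a module first and treat the
# remainder as an attribute path. The helper name resolve_by_name is
# hypothetical, and this version omits the shorthand substitution and the
# detailed error diagnosis.

```python
import importlib
import os.path

def resolve_by_name(name):
    """Return the object named by a dotted path, trying the longest module prefix first."""
    parts = name.split(".")
    for i in range(len(parts), 0, -1):
        module_name, attr_path = ".".join(parts[:i]), parts[i:]
        try:
            obj = importlib.import_module(module_name)
        except ImportError:
            continue  # prefix is not an importable module; try a shorter one
        try:
            for attr in attr_path:
                obj = getattr(obj, attr)
        except AttributeError:
            continue  # module exists but lacks the attribute path
        return obj
    raise ImportError(name)

join_func = resolve_by_name("os.path.join")
```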
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
"""Finds the python object with the given name and calls it as a function."""
assert func_name is not None
func_obj = get_obj_by_name(func_name)
assert callable(func_obj)
return func_obj(*args, **kwargs)
def get_module_dir_by_obj_name(obj_name: str) -> str:
"""Get the directory path of the module containing the given object name."""
module, _ = get_module_from_obj_name(obj_name)
return os.path.dirname(inspect.getfile(module))
def is_top_level_function(obj: Any) -> bool:
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
def get_top_level_function_name(obj: Any) -> str:
"""Return the fully-qualified name of a top-level function."""
assert is_top_level_function(obj)
return obj.__module__ + "." + obj.__name__
# File system helpers
# ------------------------------------------------------------------------------------------
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
"""List all files recursively in a given directory while ignoring given file and directory names.
Returns list of tuples containing both absolute and relative paths."""
assert os.path.isdir(dir_path)
base_name = os.path.basename(os.path.normpath(dir_path))
if ignores is None:
ignores = []
result = []
for root, dirs, files in os.walk(dir_path, topdown=True):
for ignore_ in ignores:
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
# dirs need to be edited in-place
for d in dirs_to_remove:
dirs.remove(d)
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
absolute_paths = [os.path.join(root, f) for f in files]
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
if add_base_to_relative:
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
assert len(absolute_paths) == len(relative_paths)
result += zip(absolute_paths, relative_paths)
return result
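# Illustrative sketch, not part of the library: the listing above relies on
# editing the dirs list in place so that os.walk never descends into ignored
# directories. The helper name walk_with_ignore and the demo file names are
# hypothetical; slice assignment is used here in place of repeated remove()
# calls.

```python
import fnmatch
import os
import tempfile

def walk_with_ignore(dir_path, ignores):
    """Return sorted relative paths of files whose names and parent dirs match no ignore pattern."""
    found = []
    for root, dirs, files in os.walk(dir_path, topdown=True):
        # Assigning through dirs[:] edits the list os.walk iterates over, pruning the traversal.
        dirs[:] = [d for d in dirs if not any(fnmatch.fnmatch(d, pat) for pat in ignores)]
        for f in files:
            if not any(fnmatch.fnmatch(f, pat) for pat in ignores):
                found.append(os.path.relpath(os.path.join(root, f), dir_path))
    return sorted(found)

with tempfile.TemporaryDirectory() as demo_dir:
    for rel in ["keep.py", "skip.pyc", os.path.join("__pycache__", "x.py"), os.path.join("sub", "a.py")]:
        full = os.path.join(demo_dir, rel)
        os.makedirs(os.path.dirname(full), exist_ok=True)
        open(full, "w").close()
    listing = walk_with_ignore(demo_dir, ["__pycache__", "*.pyc"])
```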
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
"""Takes in a list of tuples of (src, dst) paths and copies files.
Will create all necessary directories."""
for file in files:
target_dir_name = os.path.dirname(file[1])
# will create all intermediate-level directories
if not os.path.exists(target_dir_name):
os.makedirs(target_dir_name)
shutil.copyfile(file[0], file[1])
# URL helpers
# ------------------------------------------------------------------------------------------
def is_url(obj: Any) -> bool:
"""Determine whether the given object is a valid URL string."""
if not isinstance(obj, str) or not "://" in obj:
return False
try:
res = requests.compat.urlparse(obj)
if not res.scheme or not res.netloc or not "." in res.netloc:
return False
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
if not res.scheme or not res.netloc or not "." in res.netloc:
return False
except:
return False
return True
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any:
"""Download the given URL and return a binary-mode file object to access the data."""
assert is_url(url)
assert num_attempts >= 1
# Lookup from cache.
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
if cache_dir is not None:
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
if len(cache_files) == 1:
return open(cache_files[0], "rb")
# Download.
url_name = None
url_data = None
with requests.Session() as session:
if verbose:
print("Downloading %s ..." % url, end="", flush=True)
for attempts_left in reversed(range(num_attempts)):
try:
with session.get(url) as res:
res.raise_for_status()
if len(res.content) == 0:
raise IOError("No data received")
if len(res.content) < 8192:
content_str = res.content.decode("utf-8")
if "download_warning" in res.headers.get("Set-Cookie", ""):
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
if len(links) == 1:
url = requests.compat.urljoin(url, links[0])
raise IOError("Google Drive virus checker nag")
if "Google Drive - Quota exceeded" in content_str:
raise IOError("Google Drive quota exceeded")
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
url_name = match[1] if match else url
url_data = res.content
if verbose:
print(" done")
break
except:
if not attempts_left:
if verbose:
print(" failed")
raise
if verbose:
print(".", end="", flush=True)
# Save to cache.
if cache_dir is not None:
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
os.makedirs(cache_dir, exist_ok=True)
with open(temp_file, "wb") as f:
f.write(url_data)
os.replace(temp_file, cache_file) # atomic
# Return data as file object.
return io.BytesIO(url_data)
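# Illustrative sketch, not part of the library: the caching scheme above names
# each entry '<md5 of URL>_<sanitized file name>' and publishes it with an
# atomic rename. The helper name write_cache_entry, the example URL, and the
# file name are hypothetical.

```python
import hashlib
import os
import re
import tempfile
import uuid

def write_cache_entry(cache_dir, url, file_name, data):
    """Store data under '<md5 of url>_<sanitized name>' via a temp file and an atomic rename."""
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    safe_name = re.sub(r"[^0-9a-zA-Z._-]", "_", file_name)
    cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
    temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
    os.makedirs(cache_dir, exist_ok=True)
    with open(temp_file, "wb") as f:
        f.write(data)
    os.replace(temp_file, cache_file)  # readers never observe a partially written file
    return cache_file

with tempfile.TemporaryDirectory() as demo_dir:
    path = write_cache_entry(demo_dir, "https://example.com/model.pkl", "my model (v1).pkl", b"payload")
    cached_name = os.path.basename(path)
    with open(path, "rb") as f:
        cached_data = f.read()
```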
================================================
FILE: generate_figures.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Minimal script for reproducing the figures of the StyleGAN paper using pre-trained generators."""
import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config
#----------------------------------------------------------------------------
# Helpers for loading and using pre-trained generators.
url_ffhq = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
url_celebahq = 'https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf' # karras2019stylegan-celebahq-1024x1024.pkl
url_bedrooms = 'https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF' # karras2019stylegan-bedrooms-256x256.pkl
url_cars = 'https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3' # karras2019stylegan-cars-512x384.pkl
url_cats = 'https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ' # karras2019stylegan-cats-256x256.pkl
synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8)
_Gs_cache = dict()
def load_Gs(url):
if url not in _Gs_cache:
with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
_G, _D, Gs = pickle.load(f)
_Gs_cache[url] = Gs
return _Gs_cache[url]
#----------------------------------------------------------------------------
# Figures 2, 10, 11, 12: Multi-resolution grid of uncurated result images.
def draw_uncurated_result_figure(png, Gs, cx, cy, cw, ch, rows, lods, seed):
print(png)
latents = np.random.RandomState(seed).randn(sum(rows * 2**lod for lod in lods), Gs.input_shape[1])
images = Gs.run(latents, None, **synthesis_kwargs) # [seed, y, x, rgb]
canvas = PIL.Image.new('RGB', (sum(cw // 2**lod for lod in lods), ch * rows), 'white')
image_iter = iter(list(images))
for col, lod in enumerate(lods):
for row in range(rows * 2**lod):
image = PIL.Image.fromarray(next(image_iter), 'RGB')
image = image.crop((cx, cy, cx + cw, cy + ch))
image = image.resize((cw // 2**lod, ch // 2**lod), PIL.Image.LANCZOS) # LANCZOS is the same filter as the older ANTIALIAS alias
canvas.paste(image, (sum(cw // 2**lod for lod in lods[:col]), row * ch // 2**lod))
canvas.save(png)
#----------------------------------------------------------------------------
# Figure 3: Style mixing.
def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges):
print(png)
src_latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds]) # np.stack expects a sequence, not a generator
dst_latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds])
src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component]
src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)
canvas = PIL.Image.new('RGB', (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), 'white')
for col, src_image in enumerate(list(src_images)):
canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0))
for row, dst_image in enumerate(list(dst_images)):
canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h))
row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds))
row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]]
row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
for col, image in enumerate(list(row_images)):
canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h))
canvas.save(png)
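# Illustrative sketch, not part of the script: the style-mixing step above
# copies the destination dlatent once per source and then overwrites a range of
# layers with the source styles. Random arrays stand in for the mapping network
# output, and the sizes (5 sources, 18 layers, 512 components) are assumptions
# chosen for the demo.

```python
import numpy as np

rng = np.random.RandomState(0)
src_dlatents = rng.randn(5, 18, 512)   # [src seed, layer, component]
dst_dlatent = rng.randn(18, 512)       # one destination: [layer, component]
style_range = range(0, 4)              # layers whose styles come from the sources

# One copy of the destination per source, then swap in the selected layers.
row_dlatents = np.stack([dst_dlatent] * len(src_dlatents))
row_dlatents[:, style_range] = src_dlatents[:, style_range]
```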
#----------------------------------------------------------------------------
# Figure 4: Noise detail.
def draw_noise_detail_figure(png, Gs, w, h, num_samples, seeds):
print(png)
canvas = PIL.Image.new('RGB', (w * 3, h * len(seeds)), 'white')
for row, seed in enumerate(seeds):
latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1])] * num_samples)
images = Gs.run(latents, None, truncation_psi=1, **synthesis_kwargs)
canvas.paste(PIL.Image.fromarray(images[0], 'RGB'), (0, row * h))
for i in range(4):
crop = PIL.Image.fromarray(images[i + 1], 'RGB')
crop = crop.crop((650, 180, 906, 436))
crop = crop.resize((w//2, h//2), PIL.Image.NEAREST)
canvas.paste(crop, (w + (i%2) * w//2, row * h + (i//2) * h//2))
diff = np.std(np.mean(images, axis=3), axis=0) * 4
diff = np.clip(diff + 0.5, 0, 255).astype(np.uint8)
canvas.paste(PIL.Image.fromarray(diff, 'L'), (w * 2, row * h))
canvas.save(png)
#----------------------------------------------------------------------------
# Figure 5: Noise components.
def draw_noise_components_figure(png, Gs, w, h, seeds, noise_ranges, flips):
print(png)
Gsc = Gs.clone()
noise_vars = [var for name, var in Gsc.components.synthesis.vars.items() if name.startswith('noise')]
noise_pairs = list(zip(noise_vars, tflib.run(noise_vars))) # [(var, val), ...]
latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds]) # np.stack expects a sequence, not a generator
all_images = []
for noise_range in noise_ranges:
tflib.set_vars({var: val * (1 if i in noise_range else 0) for i, (var, val) in enumerate(noise_pairs)})
range_images = Gsc.run(latents, None, truncation_psi=1, randomize_noise=False, **synthesis_kwargs)
range_images[flips, :, :] = range_images[flips, :, ::-1]
all_images.append(list(range_images))
canvas = PIL.Image.new('RGB', (w * 2, h * 2), 'white')
for col, col_images in enumerate(zip(*all_images)):
canvas.paste(PIL.Image.fromarray(col_images[0], 'RGB').crop((0, 0, w//2, h)), (col * w, 0))
canvas.paste(PIL.Image.fromarray(col_images[1], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, 0))
canvas.paste(PIL.Image.fromarray(col_images[2], 'RGB').crop((0, 0, w//2, h)), (col * w, h))
canvas.paste(PIL.Image.fromarray(col_images[3], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, h))
canvas.save(png)
#----------------------------------------------------------------------------
# Figure 8: Truncation trick.
def draw_truncation_trick_figure(png, Gs, w, h, seeds, psis):
print(png)
latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds]) # np.stack expects a sequence, not a generator
dlatents = Gs.components.mapping.run(latents, None) # [seed, layer, component]
dlatent_avg = Gs.get_var('dlatent_avg') # [component]
canvas = PIL.Image.new('RGB', (w * len(psis), h * len(seeds)), 'white')
for row, dlatent in enumerate(list(dlatents)):
row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(psis, [-1, 1, 1]) + dlatent_avg
row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
for col, image in enumerate(list(row_images)):
canvas.paste(PIL.Image.fromarray(image, 'RGB'), (col * w, row * h))
canvas.save(png)
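# Illustrative sketch, not part of the script: the truncation step above
# interpolates each dlatent toward the average dlatent by a factor psi, using
# broadcasting to produce one row per psi value. The helper name
# truncate_dlatents and the array sizes are assumptions; random arrays stand in
# for real network outputs.

```python
import numpy as np

def truncate_dlatents(dlatent, dlatent_avg, psis):
    """Interpolate one dlatent toward the average; returns [len(psis), layer, component]."""
    psis = np.reshape(psis, [-1, 1, 1])
    return (dlatent[np.newaxis] - dlatent_avg) * psis + dlatent_avg

rng = np.random.RandomState(0)
dlatent = rng.randn(18, 512)   # [layer, component]
dlatent_avg = rng.randn(512)   # [component]
row = truncate_dlatents(dlatent, dlatent_avg, [1, 0.7, 0.5, 0, -0.5, -1])
```

# With psi = 1 the dlatent is unchanged, with psi = 0 every layer collapses to
# the average, and negative psi moves to the opposite side of the average.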
#----------------------------------------------------------------------------
# Main program.
def main():
tflib.init_tf()
os.makedirs(config.result_dir, exist_ok=True)
draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure02-uncurated-ffhq.png'), load_Gs(url_ffhq), cx=0, cy=0, cw=1024, ch=1024, rows=3, lods=[0,1,2,2,3,3], seed=5)
draw_style_mixing_figure(os.path.join(config.result_dir, 'figure03-style-mixing.png'), load_Gs(url_ffhq), w=1024, h=1024, src_seeds=[639,701,687,615,2268], dst_seeds=[888,829,1898,1733,1614,845], style_ranges=[range(0,4)]*3+[range(4,8)]*2+[range(8,18)])
draw_noise_detail_figure(os.path.join(config.result_dir, 'figure04-noise-detail.png'), load_Gs(url_ffhq), w=1024, h=1024, num_samples=100, seeds=[1157,1012])
draw_noise_components_figure(os.path.join(config.result_dir, 'figure05-noise-components.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[1967,1555], noise_ranges=[range(0, 18), range(0, 0), range(8, 18), range(0, 8)], flips=[1])
draw_truncation_trick_figure(os.path.join(config.result_dir, 'figure08-truncation-trick.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[91,388], psis=[1, 0.7, 0.5, 0, -0.5, -1])
draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure10-uncurated-bedrooms.png'), load_Gs(url_bedrooms), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=0)
draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure11-uncurated-cars.png'), load_Gs(url_cars), cx=0, cy=64, cw=512, ch=384, rows=4, lods=[0,1,2,2,3,3], seed=2)
draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure12-uncurated-cats.png'), load_Gs(url_cats), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=1)
#----------------------------------------------------------------------------
if __name__ == "__main__":
main()
#----------------------------------------------------------------------------
================================================
FILE: metrics/__init__.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
# empty
================================================
FILE: metrics/frechet_inception_distance.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Frechet Inception Distance (FID)."""
import os
import numpy as np
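import scipy.linalg # import the submodule explicitly; a bare "import scipy" is not guaranteed to load scipy.linalg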
import tensorflow as tf
import dnnlib.tflib as tflib
from metrics import metric_base
from training import misc
#----------------------------------------------------------------------------
class FID(metric_base.MetricBase):
def __init__(self, num_images, minibatch_per_gpu, **kwargs):
super().__init__(**kwargs)
self.num_images = num_images
self.minibatch_per_gpu = minibatch_per_gpu
def _evaluate(self, Gs, num_gpus):
minibatch_size = num_gpus * self.minibatch_per_gpu
inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl
activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)
# Calculate statistics for reals.
cache_file = self._get_cache_file_for_reals(num_images=self.num_images)
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
if os.path.isfile(cache_file):
mu_real, sigma_real = misc.load_pkl(cache_file)
else:
for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)):
begin = idx * minibatch_size
end = min(begin + minibatch_size, self.num_images)
activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True)
if end == self.num_images:
break
mu_real = np.mean(activations, axis=0)
sigma_real = np.cov(activations, rowvar=False)
misc.save_pkl((mu_real, sigma_real), cache_file)
# Construct TensorFlow graph.
result_expr = []
for gpu_idx in range(num_gpus):
with tf.device('/gpu:%d' % gpu_idx):
Gs_clone = Gs.clone()
inception_clone = inception.clone()
latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True)
images = tflib.convert_images_to_uint8(images)
result_expr.append(inception_clone.get_output_for(images))
# Calculate statistics for fakes.
for begin in range(0, self.num_images, minibatch_size):
end = min(begin + minibatch_size, self.num_images)
activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]
mu_fake = np.mean(activations, axis=0)
sigma_fake = np.cov(activations, rowvar=False)
# Calculate FID.
m = np.square(mu_fake - mu_real).sum()
s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member
dist = m + np.trace(sigma_fake + sigma_real - 2*s)
self._report_result(np.real(dist))
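# Illustrative sketch, not part of the metric: the final distance is
# |mu_a - mu_b|^2 + Tr(sigma_a + sigma_b - 2 * sqrtm(sigma_a @ sigma_b)).
# This NumPy-only version swaps the matrix square root for an eigenvalue
# computation, assuming both inputs are valid covariance matrices so that the
# eigenvalues of their product are real and non-negative. The helper name
# frechet_distance is hypothetical.

```python
import numpy as np

def frechet_distance(mu_a, sigma_a, mu_b, sigma_b):
    """Frechet distance between N(mu_a, sigma_a) and N(mu_b, sigma_b)."""
    m = np.square(mu_a - mu_b).sum()
    # Tr(sqrtm(A @ B)) equals the sum of the square roots of the eigenvalues of A @ B.
    eigvals = np.linalg.eigvals(np.dot(sigma_a, sigma_b))
    trace_sqrt = np.sqrt(np.clip(np.real(eigvals), 0, None)).sum()
    return float(m + np.trace(sigma_a) + np.trace(sigma_b) - 2 * trace_sqrt)

identity = np.eye(4)
same = frechet_distance(np.zeros(4), identity, np.zeros(4), identity)
shifted = frechet_distance(np.zeros(4), identity, np.full(4, 0.5), identity)
```

# Identical statistics give a distance of zero; shifting only the mean adds the
# squared length of the shift.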
#----------------------------------------------------------------------------
================================================
FILE: metrics/linear_separability.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Linear Separability (LS)."""
from collections import defaultdict
import numpy as np
import sklearn.svm
import tensorflow as tf
import dnnlib.tflib as tflib
from metrics import metric_base
from training import misc
#----------------------------------------------------------------------------
classifier_urls = [
'https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX', # celebahq-classifier-00-male.pkl
'https://drive.google.com/uc?id=1Q5c6HE__ReW2W8qYAXpao68V1ryuisGo', # celebahq-classifier-01-smiling.pkl
'https://drive.google.com/uc?id=1Q7738mgWTljPOJQrZtSMLxzShEhrvVsU', # celebahq-classifier-02-attractive.pkl
'https://drive.google.com/uc?id=1QBv2Mxe7ZLvOv1YBTLq-T4DS3HjmXV0o', # celebahq-classifier-03-wavy-hair.pkl
'https://drive.google.com/uc?id=1QIvKTrkYpUrdA45nf7pspwAqXDwWOLhV', # celebahq-classifier-04-young.pkl
'https://drive.google.com/uc?id=1QJPH5rW7MbIjFUdZT7vRYfyUjNYDl4_L', # celebahq-classifier-05-5-o-clock-shadow.pkl
'https://drive.google.com/uc?id=1QPZXSYf6cptQnApWS_T83sqFMun3rULY', # celebahq-classifier-06-arched-eyebrows.pkl
'https://drive.google.com/uc?id=1QPgoAZRqINXk_PFoQ6NwMmiJfxc5d2Pg', # celebahq-classifier-07-bags-under-eyes.pkl
'https://drive.google.com/uc?id=1QQPQgxgI6wrMWNyxFyTLSgMVZmRr1oO7', # celebahq-classifier-08-bald.pkl
'https://drive.google.com/uc?id=1QcSphAmV62UrCIqhMGgcIlZfoe8hfWaF', # celebahq-classifier-09-bangs.pkl
'https://drive.google.com/uc?id=1QdWTVwljClTFrrrcZnPuPOR4mEuz7jGh', # celebahq-classifier-10-big-lips.pkl
'https://drive.google.com/uc?id=1QgvEWEtr2mS4yj1b_Y3WKe6cLWL3LYmK', # celebahq-classifier-11-big-nose.pkl
'https://drive.google.com/uc?id=1QidfMk9FOKgmUUIziTCeo8t-kTGwcT18', # celebahq-classifier-12-black-hair.pkl
'https://drive.google.com/uc?id=1QthrJt-wY31GPtV8SbnZQZ0_UEdhasHO', # celebahq-classifier-13-blond-hair.pkl
'https://drive.google.com/uc?id=1QvCAkXxdYT4sIwCzYDnCL9Nb5TDYUxGW', # celebahq-classifier-14-blurry.pkl
'https://drive.google.com/uc?id=1QvLWuwSuWI9Ln8cpxSGHIciUsnmaw8L0', # celebahq-classifier-15-brown-hair.pkl
'https://drive.google.com/uc?id=1QxW6THPI2fqDoiFEMaV6pWWHhKI_OoA7', # celebahq-classifier-16-bushy-eyebrows.pkl
'https://drive.google.com/uc?id=1R71xKw8oTW2IHyqmRDChhTBkW9wq4N9v', # celebahq-classifier-17-chubby.pkl
'https://drive.google.com/uc?id=1RDn_fiLfEGbTc7JjazRXuAxJpr-4Pl67', # celebahq-classifier-18-double-chin.pkl
'https://drive.google.com/uc?id=1RGBuwXbaz5052bM4VFvaSJaqNvVM4_cI', # celebahq-classifier-19-eyeglasses.pkl
'https://drive.google.com/uc?id=1RIxOiWxDpUwhB-9HzDkbkLegkd7euRU9', # celebahq-classifier-20-goatee.pkl
'https://drive.google.com/uc?id=1RPaNiEnJODdr-fwXhUFdoSQLFFZC7rC-', # celebahq-classifier-21-gray-hair.pkl
'https://drive.google.com/uc?id=1RQH8lPSwOI2K_9XQCZ2Ktz7xm46o80ep', # celebahq-classifier-22-heavy-makeup.pkl
'https://drive.google.com/uc?id=1RXZM61xCzlwUZKq-X7QhxOg0D2telPow', # celebahq-classifier-23-high-cheekbones.pkl
'https://drive.google.com/uc?id=1RgASVHW8EWMyOCiRb5fsUijFu-HfxONM', # celebahq-classifier-24-mouth-slightly-open.pkl
'https://drive.google.com/uc?id=1RkC8JLqLosWMaRne3DARRgolhbtg_wnr', # celebahq-classifier-25-mustache.pkl
'https://drive.google.com/uc?id=1RqtbtFT2EuwpGTqsTYJDyXdnDsFCPtLO', # celebahq-classifier-26-narrow-eyes.pkl
'https://drive.google.com/uc?id=1Rs7hU-re8bBMeRHR-fKgMbjPh-RIbrsh', # celebahq-classifier-27-no-beard.pkl
'https://drive.google.com/uc?id=1RynDJQWdGOAGffmkPVCrLJqy_fciPF9E', # celebahq-classifier-28-oval-face.pkl
'https://drive.google.com/uc?id=1S0TZ_Hdv5cb06NDaCD8NqVfKy7MuXZsN', # celebahq-classifier-29-pale-skin.pkl
'https://drive.google.com/uc?id=1S3JPhZH2B4gVZZYCWkxoRP11q09PjCkA', # celebahq-classifier-30-pointy-nose.pkl
'https://drive.google.com/uc?id=1S3pQuUz-Jiywq_euhsfezWfGkfzLZ87W', # celebahq-classifier-31-receding-hairline.pkl
'https://drive.google.com/uc?id=1S6nyIl_SEI3M4l748xEdTV2vymB_-lrY', # celebahq-classifier-32-rosy-cheeks.pkl
'https://drive.google.com/uc?id=1S9P5WCi3GYIBPVYiPTWygrYIUSIKGxbU', # celebahq-classifier-33-sideburns.pkl
'https://drive.google.com/uc?id=1SANviG-pp08n7AFpE9wrARzozPIlbfCH', # celebahq-classifier-34-straight-hair.pkl
'https://drive.google.com/uc?id=1SArgyMl6_z7P7coAuArqUC2zbmckecEY', # celebahq-classifier-35-wearing-earrings.pkl
'https://drive.google.com/uc?id=1SC5JjS5J-J4zXFO9Vk2ZU2DT82TZUza_', # celebahq-classifier-36-wearing-hat.pkl
'https://drive.google.com/uc?id=1SDAQWz03HGiu0MSOKyn7gvrp3wdIGoj-', # celebahq-classifier-37-wearing-lipstick.pkl
'https://drive.google.com/uc?id=1SEtrVK-TQUC0XeGkBE9y7L8VXfbchyKX', # celebahq-classifier-38-wearing-necklace.pkl
'https://drive.google.com/uc?id=1SF_mJIdyGINXoV-I6IAxHB_k5dxiF6M-', # celebahq-classifier-39-wearing-necktie.pkl
]
#----------------------------------------------------------------------------
def prob_normalize(p):
p = np.asarray(p).astype(np.float32)
assert len(p.shape) == 2
return p / np.sum(p)
def mutual_information(p):
p = prob_normalize(p)
px = np.sum(p, axis=1)
py = np.sum(p, axis=0)
result = 0.0
for x in range(p.shape[0]):
p_x = px[x]
for y in range(p.shape[1]):
p_xy = p[x][y]
p_y = py[y]
if p_xy > 0.0:
result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output
return result
def entropy(p):
p = prob_normalize(p)
result = 0.0
for x in range(p.shape[0]):
for y in range(p.shape[1]):
p_xy = p[x][y]
if p_xy > 0.0:
result -= p_xy * np.log2(p_xy)
return result
def conditional_entropy(p):
    # H(Y|X) where X corresponds to axis 0, Y to axis 1
    # i.e., how many bits of additional information are needed to know where we are on axis 1 if we know where we are on axis 0?
    p = prob_normalize(p)
    y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y)
    return max(0.0, entropy(y) - mutual_information(p)) # H(Y|X) = H(Y) - I(X;Y); can slip just below 0 due to FP inaccuracies, clean those up.
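# The helpers above compute H(Y|X) as H(Y) - I(X;Y) with explicit loops. The
# sketch below is a vectorized NumPy version of the same quantity, evaluated
# directly as -sum p(x,y) * log2(p(x,y) / p(x)). The name
# `conditional_entropy_bits` is illustrative and not part of this repository.
```python
import numpy as np

def conditional_entropy_bits(joint):
    """H(Y|X) in bits for a joint table with X on axis 0 and Y on axis 1."""
    p = np.asarray(joint, dtype=np.float64)
    p = p / p.sum()                        # normalize counts to probabilities
    px = p.sum(axis=1, keepdims=True)      # marginal P(X), shape (nx, 1)
    mask = p > 0.0                         # skip empty cells: 0 * log 0 is taken as 0
    ratio = p[mask] / np.broadcast_to(px, p.shape)[mask]
    h = -np.sum(p[mask] * np.log2(ratio))  # -sum p(x,y) * log2 p(y|x)
    return max(0.0, float(h))              # clamp tiny negative rounding errors

# X fully determines Y: no extra bits are needed.
h_deterministic = conditional_entropy_bits([[0.5, 0.0], [0.0, 0.5]])
# X and Y independent and uniform: one full bit is needed.
h_independent = conditional_entropy_bits([[0.25, 0.25], [0.25, 0.25]])
print(h_deterministic, h_independent)
```
# The two calls cover the extreme cases for a 2x2 table, which is the table
# shape the LS metric passes to conditional_entropy().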
#----------------------------------------------------------------------------
class LS(metric_base.MetricBase):
    def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs):
        assert num_keep <= num_samples
        super().__init__(**kwargs)
        self.num_samples = num_samples
        self.num_keep = num_keep
        self.attrib_indices = attrib_indices
        self.minibatch_per_gpu = minibatch_per_gpu
    def _evaluate(self, Gs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        # Construct TensorFlow graph for each GPU.
        result_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                # Generate images.
                latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
                dlatents = Gs_clone.components.mapping.get_output_for(latents, None, is_validation=True)
                images = Gs_clone.components.synthesis.get_output_for(dlatents, is_validation=True, randomize_noise=True)
                # Downsample to 256x256. The attribute classifiers were built for 256x256.
                if images.shape[2] > 256:
                    factor = images.shape[2] // 256
                    images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
                    images = tf.reduce_mean(images, axis=[3, 5])
                # Run classifier for each attribute.
                result_dict = dict(latents=latents, dlatents=dlatents[:,-1])
                for attrib_idx in self.attrib_indices:
                    classifier = misc.load_pkl(classifier_urls[attrib_idx])
                    logits = classifier.get_output_for(images, None)
                    predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1))
                    result_dict[attrib_idx] = predictions
                result_expr.append(result_dict)
        # Sampling loop.
        results = []
        for _ in range(0, self.num_samples, minibatch_size):
            results += tflib.run(result_expr)
        results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()}
        # Calculate conditional entropy for each attribute.
        conditional_entropies = defaultdict(list)
        for attrib_idx in self.attrib_indices:
            # Prune the least confident samples.
            pruned_indices = list(range(self.num_samples))
            pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i]))
            pruned_indices = pruned_indices[:self.num_keep]
            # Fit SVM to the remaining samples.
            svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1)
            for space in ['latents', 'dlatents']:
                svm_inputs = results[space][pruned_indices]
                try:
                    svm = sklearn.svm.LinearSVC()
                    svm.fit(svm_inputs, svm_targets)
                    svm.score(svm_inputs, svm_targets)
                    svm_outputs = svm.predict(svm_inputs)
                except Exception: # e.g. fitting fails when all targets fall into one class; do not swallow KeyboardInterrupt
                    svm_outputs = svm_targets # assume perfect prediction
                # Calculate conditional entropy.
                p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)]
                conditional_entropies[space].append(conditional_entropy(p))
        # Calculate separability scores.
        scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()}
        self._report_result(scores['latents'], suffix='_z')
        self._report_result(scores['dlatents'], suffix='_w')
#----------------------------------------------------------------------------
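# The last two steps of LS._evaluate() build a 2x2 joint table of
# (SVM output, classifier target) pairs and then report 2**sum(H) over all
# attributes. The sketch below shows both steps on toy data, assuming binary
# labels. `joint_table` and `separability_score` are illustrative names that
# do not exist in this repository.
```python
import numpy as np

def joint_table(outputs, targets, num_classes=2):
    """Fraction of samples per (output, target) pair; rows index the outputs."""
    outputs = np.asarray(outputs, dtype=np.int64)
    targets = np.asarray(targets, dtype=np.int64)
    flat = outputs * num_classes + targets              # encode each pair as one index
    counts = np.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes) / len(outputs)

def separability_score(conditional_entropies):
    """2 raised to the summed per-attribute conditional entropies (in bits)."""
    return float(2.0 ** np.sum(conditional_entropies))

table = joint_table([0, 1, 1, 0], [0, 1, 0, 0])
print(table)
print(separability_score([0.0, 0.0]))  # every attribute perfectly separable
```
# A score of 1.0 is the lower bound: it is reached only when every attribute
# has zero conditional entropy, so lower scores indicate better separability.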
================================================
FILE: metrics/metric_base.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Common definitions for GAN metrics."""
import os
import time
import hashlib
import numpy as np
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib
import config
from training import misc
from training import dataset
#----------------------------------------------------------------------------
# Standard metrics.
fid50k = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID', name='fid50k', num_images=50000, minibatch_per_gpu=8)
ppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zfull', num_samples=100000, epsilon=1e-4, space='z', sampling='full', minibatch_per_gpu=16)
ppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wfull', num_samples=100000, epsilon=1e-4, space='w', sampling='full', minibatch_per_gpu=16)
ppl_zend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zend', num_samples=100000, epsilon=1e-4, space='z', sampling='end', minibatch_per_gpu=16)
ppl_wend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wend', num_samples=100000, epsilon=1e-4, space='w', sampling='end', minibatch_per_gpu=16)
ls = dnnlib.EasyDict(func_name='metrics.linear_separability.LS', name='ls', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4)
dummy = dnnlib.EasyDict(func_name='metrics.metric_base.DummyMetric', name='dummy') # for debugging
#----------------------------------------------------------------------------
# Base class for metrics.
class MetricBase:
    def __init__(self, name):
        self.name = name
        self._network_pkl = None
        self._dataset_args = None
        self._mirror_augment = None
        self._results = []
        self._eval_time = None
    def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True):
        self._network_pkl = network_pkl
        self._dataset_args = dataset_args
        self._mirror_augment = mirror_augment
        self._results = []
        if (dataset_args is None or mirror_augment is None) and run_dir is not None:
            run_config = misc.parse_config_for_previous_run(run_dir)
            self._dataset_args = dict(run_config['dataset'])
            self._dataset_args['shuffle_mb'] = 0
            self._mirror_augment = run_config['train'].get('mirror_augment', False)
        time_begin = time.time()
        with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager
            _G, _D, Gs = misc.load_pkl(self._network_pkl)
            self._evaluate(Gs, num_gpus=num_gpus)
        self._eval_time = time.time() - time_begin
        if log_results:
            result_str = self.get_result_str()
            if run_dir is not None:
                log = os.path.join(run_dir, 'metric-%s.txt' % self.name)
                with dnnlib.util.Logger(log, 'a'):
                    print(result_str)
            else:
                print(result_str)
    def get_result_str(self):
        network_name = os.path.splitext(os.path.basename(self._network_pkl))[0]
        if len(network_name) > 29:
            network_name = '...' + network_name[-26:]
        result_str = '%-30s' % network_name
        result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time)
        for res in self._results:
            result_str += ' ' + self.name + res.suffix + ' '
            result_str += res.fmt % res.value
        return result_str
    def update_autosummaries(self):
        for res in self._results:
            tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value)
    def _evaluate(self, Gs, num_gpus):
        raise NotImplementedError # to be overridden by subclasses
    def _report_result(self, value, suffix='', fmt='%-10.4f'):
        self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]
    def _get_cache_file_for_reals(self, extension='pkl', **kwargs):
        all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment)
        all_args.update(self._dataset_args)
        all_args.update(kwargs)
        md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8'))
        dataset_name = self._dataset_args['tfrecord_dir'].replace('\\', '/').split('/')[-1]
        return os.path.join(config.cache_dir, '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension))
    def _iterate_reals(self, minibatch_size):
        dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **self._dataset_args)
        while True:
            images, _labels = dataset_obj.get_minibatch_np(minibatch_size)
            if self._mirror_augment:
                images = misc.apply_mirror_augment(images)
            yield images
    def _iterate_fakes(self, Gs, minibatch_size, num_gpus):
        while True:
            latents = np.random.randn(minibatch_size, *Gs.input_shape[1:])
            fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
            images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True)
            yield images
#----------------------------------------------------------------------------
# Group of multiple metrics.
class MetricGroup:
    def __init__(self, metric_kwarg_list):
        self.metrics = [dnnlib.util.call_func_by_name(**kwargs) for kwargs in metric_kwarg_list]
    def run(self, *args, **kwargs):
        for metric in self.metrics:
            metric.run(*args, **kwargs)
    def get_result_str(self):
        return ' '.join(metric.get_result_str() for metric in self.metrics)
    def update_autosummaries(self):
        for metric in self.metrics:
            metric.update_autosummaries()
#----------------------------------------------------------------------------
# Dummy metric for debugging purposes.
class DummyMetric(MetricBase):
    def _evaluate(self, Gs, num_gpus):
        _ = Gs, num_gpus
        self._report_result(0.0)
#----------------------------------------------------------------------------
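# MetricBase._get_cache_file_for_reals() derives a cache file name from an MD5
# hash over the sorted argument items, so the name does not depend on the order
# in which arguments were supplied. The standalone sketch below mirrors that
# logic with plain dicts in place of dnnlib.EasyDict and config.cache_dir.
# `cache_file_name`, the 'cache' directory and the 'datasets/ffhq' path are
# illustrative assumptions.
```python
import hashlib
import os

def cache_file_name(cache_dir, metric_name, dataset_args, mirror_augment, extension='pkl', **kwargs):
    all_args = dict(metric_name=metric_name, mirror_augment=mirror_augment)
    all_args.update(dataset_args)
    all_args.update(kwargs)
    # Sorting the items makes the hash independent of insertion order.
    digest = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8')).hexdigest()
    # Use the last path component of the dataset directory as a readable tag.
    dataset_name = dataset_args['tfrecord_dir'].replace('\\', '/').split('/')[-1]
    return os.path.join(cache_dir, '%s-%s-%s.%s' % (digest, metric_name, dataset_name, extension))

name_a = cache_file_name('cache', 'fid50k', dict(tfrecord_dir='datasets/ffhq', shuffle_mb=0), True, num_images=50000)
name_b = cache_file_name('cache', 'fid50k', dict(shuffle_mb=0, tfrecord_dir='datasets/ffhq'), True, num_images=50000)
name_c = cache_file_name('cache', 'fid50k', dict(tfrecord_dir='datasets/ffhq', shuffle_mb=0), True, num_images=10000)
print(name_a == name_b, name_a == name_c)
```
# Changing any argument, such as num_images, yields a different hash and hence
# a separate cache file, while reordering the same arguments does not.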
================================================
FILE: metrics/perceptual_path_length.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Perceptual Path Length (PPL)."""
import numpy as np
import tensorflow as tf
import dnnlib.tflib as tflib
from metrics import metric_base
from training import misc
#----------------------------------------------------------------------------
# Normalize batch of vectors.
def normalize(v):
    return v / tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True))
# Spherical interpolation of a batch of vectors.
def slerp(a, b, t):
    a = normalize(a)
    b = normalize(b)
    d = tf.reduce_sum(a * b, axis=-1, keepdims=True)
    p = t * tf.math.acos(d)
    c = normalize(b - d * a)
    d = a * tf.math.cos(p) + c * tf.math.sin(p)
    return normalize(d)
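# For checking the interpolation outside a TensorFlow session, the sketch below
# is a NumPy port of normalize() and slerp() above. The `_np` names are
# illustrative and not part of this repository. Like the TensorFlow version, it
# does not guard against parallel inputs, where b - d * a has zero length.
```python
import numpy as np

def normalize_np(v):
    return v / np.sqrt(np.sum(np.square(v), axis=-1, keepdims=True))

def slerp_np(a, b, t):
    a = normalize_np(a)
    b = normalize_np(b)
    d = np.sum(a * b, axis=-1, keepdims=True)  # cosine of the angle between a and b
    p = t * np.arccos(d)                       # angle travelled from a towards b
    c = normalize_np(b - d * a)                # unit vector orthogonal to a, in the a-b plane
    return normalize_np(a * np.cos(p) + c * np.sin(p))

a = np.array([[1.0, 0.0]])
b = np.array([[0.0, 2.0]])
start = slerp_np(a, b, np.array([[0.0]]))
mid = slerp_np(a, b, np.array([[0.5]]))
end = slerp_np(a, b, np.array([[1.0]]))
print(start, mid, end)
```
# The result always has unit length: t = 0 returns the normalized a, t = 1 the
# normalized b, and t = 0.5 the direction halfway along the arc between them.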
#----------------------------------------------------------------------------
class PPL(metric_base.MetricBase):
    def __init__(self, num_samples, epsilon, space, sampling, minibatch_per_gpu, **kwargs):
        assert space in ['z', 'w']
        assert sampling in ['full', 'end']
        super().__init__(**kwargs)
        self.num_samples = num_samples
        self.epsilon = epsilon
        self.space = space
        self.sampling = sampling
        self.minibatch_per_gpu = minibatch_per_gpu
    def _evaluate(self, Gs, num_gpus):
        minibatch_size = num_gpus * self.minibatch_per_gpu
        # Construct TensorFlow graph.
        distance_expr = []
        for gpu_idx in range(num_gpus):
            with tf.device('/gpu:%d' % gpu_idx):
                Gs_clone = Gs.clone()
                noise_vars = [var for name, var in Gs_clone.components.synthesis.vars.items() if name.startswith('noise')]
                # Generate random latents and interpolation t-values.
                # With sampling == 'end' the uniform range collapses to [0, 0], so t is always 0 (path endpoints only).
                lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] + Gs_clone.input_shape[1:])
                lerp_t = tf.random_uniform([self.minibatch_per_gpu], 0.0, 1.0 if self.sampling == 'full' else 0.0)
                # Interpolate in W or Z.
                if self.space == 'w':
                    dlat_t01 = Gs_clone.components.mapping.get_output_for(lat_t01, None, is_validation=True)
                    dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2]
                    dlat_e0 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis])
                    dlat_e1 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis] + self.epsilon)
                    dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1), dlat_t01.shape)
                else: # space == 'z'
                    lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2]
                    lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis])
                    lat_e1 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis] + self.epsilon)
                    lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1), lat_t01.shape)
                    dlat_e01 = Gs_clone.components.mapping.get_output_for(lat_e01, None, is_validation=True)
                # Synthesize images.
                with tf.control_dependencies([var.initializer for var in noise_vars]): # use same noise inputs for the entire minibatch
                    images = Gs_clone.components.synthesis.get_output_for(dlat_e01, is_validation=True, randomize_noise=False)
                # Crop only the face region.
                c = int(images.shape[2] // 8)
                images = images[:, :, c*3 : c*7, c*2 : c*6]
                # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
                if images.shape[2] > 256:
                    factor = images.shape[2] // 256
                    images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])
                    images = tf.reduce_mean(images, axis=[3,5])
                # Scale dynamic range from [-1,1] to [0,255] for VGG.
                images = (images + 1) * (255 / 2)
                # Evaluate perceptual distance.
                img_e0, img_e1 = images[0::2], images[1::2]
                distance_measure = misc.load_pkl('https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2') # vgg16_zhang_perceptual.pkl
                distance_expr.append(distance_measure.get_output_for(img_e0, img_e1) * (1 / self.epsilon**2))
        # Sampling loop.
        all_distances = []
        for _ in range(0, self.num_samples, minibatch_size):
            all_distances += tflib.run(distance_expr)
        all_distances = np.concatenate(all_distances, axis=0)
        # Reject outliers.
        lo = np.percentile(all_distances, 1, interpolation='lower')
        hi = np.percentile(all_distances, 99, interpolation='higher')
        filtered_distances = np.extract(np.logical_and(lo <= all_distances, all_distances <= hi), all_distances)
        self._report_result(np.mean(filtered_distances))
#----------------------------------------------------------------------------
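# The outlier rejection at the end of PPL._evaluate() relies on the
# `interpolation` keyword of np.percentile, which NumPy 1.22 deprecated in
# favor of `method`. The sketch below is a version-independent alternative
# that picks the bounds by index from the sorted data. It is intended to match
# the 'lower' and 'higher' choices at position q / 100 * (n - 1), and
# `trimmed_mean` is an illustrative name, not part of this repository.
```python
import numpy as np

def trimmed_mean(distances):
    d = np.sort(np.asarray(distances, dtype=np.float64))
    n = len(d)
    lo_idx = (1 * (n - 1)) // 100        # floor of the 1st-percentile position
    hi_idx = -((-99 * (n - 1)) // 100)   # ceil of the 99th-percentile position
    lo, hi = d[lo_idx], d[hi_idx]
    kept = d[(d >= lo) & (d <= hi)]      # keep values inside the inclusive bounds
    return float(np.mean(kept))

# 199 ordinary values plus one extreme outlier that should be rejected.
samples = np.concatenate([np.ones(199), [1e6]])
print(trimmed_mean(samples))
```
# Integer arithmetic is used for the indices so that the floor and ceil steps
# are not affected by floating-point rounding of q / 100 * (n - 1).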
================================================
FILE: pretrained_example.py
================================================
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Minimal script for generating an image using pre-trained StyleGAN generator."""
import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config
def main():
    # Initialize TensorFlow.
    tflib.init_tf()
    # Load pre-trained network.
    url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
    with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
        _G, _D, Gs = pickle.load(f)
        # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
        # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
        # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
    # Print network details.
    Gs.prin
SYMBOL INDEX (279 symbols across 23 files)
FILE: dataset_tool.py
function error (line 27) | def error(msg):
class TFRecordExporter (line 33) | class TFRecordExporter:
method __init__ (line 34) | def __init__(self, tfrecord_dir, expected_images, print_progress=True,...
method close (line 51) | def close(self):
method choose_shuffled_order (line 61) | def choose_shuffled_order(self): # Note: Images and labels must be add...
method add_image (line 66) | def add_image(self, img):
method add_labels (line 91) | def add_labels(self, labels):
method __enter__ (line 98) | def __enter__(self):
method __exit__ (line 101) | def __exit__(self, *args):
class ExceptionInfo (line 106) | class ExceptionInfo(object):
method __init__ (line 107) | def __init__(self):
class WorkerThread (line 113) | class WorkerThread(threading.Thread):
method __init__ (line 114) | def __init__(self, task_queue):
method run (line 118) | def run(self):
class ThreadPool (line 131) | class ThreadPool(object):
method __init__ (line 132) | def __init__(self, num_threads):
method add_task (line 142) | def add_task(self, func, args=()):
method get_result (line 148) | def get_result(self, func): # returns (result, args)
method finish (line 155) | def finish(self):
method __enter__ (line 159) | def __enter__(self): # for 'with' statement
method __exit__ (line 162) | def __exit__(self, *excinfo):
method process_items_concurrently (line 165) | def process_items_concurrently(self, item_iterator, process_func=lambd...
function display (line 193) | def display(tfrecord_dir):
function extract (line 219) | def extract(tfrecord_dir, output_dir):
function compare (line 246) | def compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels):
function create_mnist (line 289) | def create_mnist(tfrecord_dir, mnist_dir):
function create_mnistrgb (line 313) | def create_mnistrgb(tfrecord_dir, mnist_dir, num_images=1000000, random_...
function create_cifar10 (line 330) | def create_cifar10(tfrecord_dir, cifar10_dir):
function create_cifar100 (line 357) | def create_cifar100(tfrecord_dir, cifar100_dir):
function create_svhn (line 379) | def create_svhn(tfrecord_dir, svhn_dir):
function create_lsun (line 406) | def create_lsun(tfrecord_dir, lmdb_dir, resolution=256, max_images=None):
function create_lsun_wide (line 439) | def create_lsun_wide(tfrecord_dir, lmdb_dir, width=512, height=384, max_...
function create_celeba (line 484) | def create_celeba(tfrecord_dir, celeba_dir, cx=89, cy=121):
function create_from_images (line 503) | def create_from_images(tfrecord_dir, image_dir, shuffle):
function create_from_hdf5 (line 531) | def create_from_hdf5(tfrecord_dir, hdf5_filename, shuffle):
function execute_cmdline (line 546) | def execute_cmdline(argv):
FILE: dnnlib/submission/_internal/run.py
function main (line 22) | def main():
FILE: dnnlib/submission/run_context.py
class RunContext (line 22) | class RunContext(object):
method __init__ (line 35) | def __init__(self, submit_config: submit.SubmitConfig, config_module: ...
method __enter__ (line 55) | def __enter__(self) -> "RunContext":
method __exit__ (line 58) | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> N...
method update (line 61) | def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = N...
method should_stop (line 74) | def should_stop(self) -> bool:
method get_time_since_start (line 78) | def get_time_since_start(self) -> float:
method get_time_since_last_update (line 82) | def get_time_since_last_update(self) -> float:
method get_last_update_interval (line 86) | def get_last_update_interval(self) -> float:
method close (line 90) | def close(self) -> None:
FILE: dnnlib/submission/submit.py
class SubmitTarget (line 30) | class SubmitTarget(Enum):
class PathType (line 38) | class PathType(Enum):
class SubmitConfig (line 53) | class SubmitConfig(util.EasyDict):
method __init__ (line 75) | def __init__(self):
function get_path_from_template (line 101) | def get_path_from_template(path_template: str, path_type: PathType = Pat...
function get_template_from_path (line 123) | def get_template_from_path(path: str) -> str:
function convert_path (line 130) | def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
function set_user_name_override (line 137) | def set_user_name_override(name: str) -> None:
function get_user_name (line 143) | def get_user_name():
function _create_run_dir_local (line 159) | def _create_run_dir_local(submit_config: SubmitConfig) -> str:
function _get_next_run_id_local (line 180) | def _get_next_run_id_local(run_dir_root: str) -> int:
function _populate_run_dir (line 196) | def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None:
function run_wrapper (line 224) | def run_wrapper(submit_config: SubmitConfig) -> None:
function submit_run (line 263) | def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_fu...
FILE: dnnlib/tflib/autosummary.py
function _create_var (line 42) | def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
function autosummary (line 74) | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpression...
function finalize_autosummaries (line 112) | def finalize_autosummaries() -> None:
function save_summaries (line 170) | def save_summaries(file_writer, global_step=None):
FILE: dnnlib/tflib/network.py
function import_handler (line 30) | def import_handler(handler_func):
class Network (line 36) | class Network:
method __init__ (line 74) | def __init__(self, name: str = None, func_name: Any = None, **static_k...
method _init_fields (line 101) | def _init_fields(self) -> None:
method _init_graph (line 126) | def _init_graph(self) -> None:
method reset_own_vars (line 188) | def reset_own_vars(self) -> None:
method reset_vars (line 192) | def reset_vars(self) -> None:
method reset_trainables (line 196) | def reset_trainables(self) -> None:
method get_output_for (line 200) | def get_output_for(self, *in_expr: TfExpression, return_as_list: bool ...
method get_var_local_name (line 235) | def get_var_local_name(self, var_or_global_name: Union[TfExpression, s...
method find_var (line 241) | def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfE...
method get_var (line 246) | def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.n...
method set_var (line 251) | def set_var(self, var_or_local_name: Union[TfExpression, str], new_val...
method __getstate__ (line 256) | def __getstate__(self) -> dict:
method __setstate__ (line 268) | def __setstate__(self, state: dict) -> None:
method clone (line 302) | def clone(self, name: str = None, **new_static_kwargs) -> "Network":
method copy_own_vars_from (line 317) | def copy_own_vars_from(self, src_net: "Network") -> None:
method copy_vars_from (line 322) | def copy_vars_from(self, src_net: "Network") -> None:
method copy_trainables_from (line 327) | def copy_trainables_from(self, src_net: "Network") -> None:
method convert (line 332) | def convert(self, new_func_name: str, new_name: str = None, **new_stat...
method setup_as_moving_average_of (line 342) | def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpre...
method run (line 354) | def run(self,
method list_ops (line 456) | def list_ops(self) -> List[TfExpression]:
method list_layers (line 464) | def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpressi...
method print_layers (line 507) | def print_layers(self, title: str = None, hide_layers_with_no_params: ...
method setup_weight_histograms (line 536) | def setup_weight_histograms(self, title: str = None) -> None:
function _handle_legacy_output_transforms (line 556) | def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
function _legacy_output_transform_func (line 576) | def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_s...
FILE: dnnlib/tflib/optimizer.py
class Optimizer (line 29) | class Optimizer:
method __init__ (line 40) | def __init__(self,
method register_gradients (line 67) | def register_gradients(self, loss: TfExpression, trainable_vars: Union...
method apply_updates (line 102) | def apply_updates(self) -> tf.Operation:
method reset_optimizer_state (line 182) | def reset_optimizer_state(self) -> None:
method get_loss_scaling_var (line 187) | def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
method apply_loss_scaling (line 198) | def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
method undo_loss_scaling (line 207) | def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
FILE: dnnlib/tflib/tfutil.py
function run (line 23) | def run(*args, **kwargs) -> Any:
function is_tf_expression (line 29) | def is_tf_expression(x: Any) -> bool:
function shape_to_list (line 34) | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
function flatten (line 39) | def flatten(x: TfExpressionEx) -> TfExpression:
function log2 (line 45) | def log2(x: TfExpressionEx) -> TfExpression:
function exp2 (line 51) | def exp2(x: TfExpressionEx) -> TfExpression:
function lerp (line 57) | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfE...
function lerp_clip (line 63) | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -...
function absolute_name_scope (line 69) | def absolute_name_scope(scope: str) -> tf.name_scope:
function absolute_variable_scope (line 74) | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
function _sanitize_tf_config (line 79) | def _sanitize_tf_config(config_dict: dict = None) -> dict:
function init_tf (line 94) | def init_tf(config_dict: dict = None) -> None:
function assert_tf_initialized (line 122) | def assert_tf_initialized():
function create_session (line 128) | def create_session(config_dict: dict = None, force_as_default: bool = Fa...
function init_uninitialized_vars (line 152) | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
function set_vars (line 182) | def set_vars(var_to_value_dict: dict) -> None:
function create_var_with_large_initial_value (line 208) | def create_var_with_large_initial_value(initial_value: np.ndarray, *args...
function convert_images_from_uint8 (line 218) | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
function convert_images_to_uint8 (line 228) | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, s...
FILE: dnnlib/util.py
class EasyDict (line 36) | class EasyDict(dict):
method __getattr__ (line 39) | def __getattr__(self, name: str) -> Any:
method __setattr__ (line 45) | def __setattr__(self, name: str, value: Any) -> None:
method __delattr__ (line 48) | def __delattr__(self, name: str) -> None:
class Logger (line 52) | class Logger(object):
method __init__ (line 55) | def __init__(self, file_name: str = None, file_mode: str = "w", should...
method __enter__ (line 68) | def __enter__(self) -> "Logger":
method __exit__ (line 71) | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> N...
method write (line 74) | def write(self, text: str) -> None:
method flush (line 87) | def flush(self) -> None:
method close (line 94) | def close(self) -> None:
function format_time (line 112) | def format_time(seconds: Union[int, float]) -> str:
function ask_yes_no (line 126) | def ask_yes_no(question: str) -> bool:
function tuple_product (line 136) | def tuple_product(t: Tuple) -> Any:
function get_dtype_and_ctype (line 160) | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
function is_pickleable (line 183) | def is_pickleable(obj: Any) -> bool:
function get_module_from_obj_name (line 195) | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, s...
function get_obj_from_module (line 236) | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
function get_obj_by_name (line 246) | def get_obj_by_name(name: str) -> Any:
function call_func_by_name (line 252) | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
function get_module_dir_by_obj_name (line 260) | def get_module_dir_by_obj_name(obj_name: str) -> str:
function is_top_level_function (line 266) | def is_top_level_function(obj: Any) -> bool:
function get_top_level_function_name (line 271) | def get_top_level_function_name(obj: Any) -> str:
function list_dir_recursively_with_ignore (line 280) | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] =...
function copy_files_and_create_dirs (line 313) | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
function is_url (line 329) | def is_url(obj: Any) -> bool:
function open_url (line 345) | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, ve...
FILE: generate_figures.py
function load_Gs (line 31) | def load_Gs(url):
function draw_uncurated_result_figure (line 41) | def draw_uncurated_result_figure(png, Gs, cx, cy, cw, ch, rows, lods, se...
function draw_style_mixing_figure (line 59) | def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_...
function draw_noise_detail_figure (line 83) | def draw_noise_detail_figure(png, Gs, w, h, num_samples, seeds):
function draw_noise_components_figure (line 103) | def draw_noise_components_figure(png, Gs, w, h, seeds, noise_ranges, fli...
function draw_truncation_trick_figure (line 127) | def draw_truncation_trick_figure(png, Gs, w, h, seeds, psis):
function main (line 144) | def main():
FILE: metrics/frechet_inception_distance.py
class FID (line 21) | class FID(metric_base.MetricBase):
method __init__ (line 22) | def __init__(self, num_images, minibatch_per_gpu, **kwargs):
method _evaluate (line 27) | def _evaluate(self, Gs, num_gpus):
FILE: metrics/linear_separability.py
function prob_normalize (line 66) | def prob_normalize(p):
function mutual_information (line 71) | def mutual_information(p):
function entropy (line 85) | def entropy(p):
function conditional_entropy (line 95) | def conditional_entropy(p):
class LS (line 104) | class LS(metric_base.MetricBase):
method __init__ (line 105) | def __init__(self, num_samples, num_keep, attrib_indices, minibatch_pe...
method _evaluate (line 113) | def _evaluate(self, Gs, num_gpus):
FILE: metrics/metric_base.py
class MetricBase (line 36) | class MetricBase:
method __init__ (line 37) | def __init__(self, name):
method run (line 45) | def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_aug...
method get_result_str (line 72) | def get_result_str(self):
method update_autosummaries (line 83) | def update_autosummaries(self):
method _evaluate (line 87) | def _evaluate(self, Gs, num_gpus):
method _report_result (line 90) | def _report_result(self, value, suffix='', fmt='%-10.4f'):
method _get_cache_file_for_reals (line 93) | def _get_cache_file_for_reals(self, extension='pkl', **kwargs):
method _iterate_reals (line 101) | def _iterate_reals(self, minibatch_size):
method _iterate_fakes (line 109) | def _iterate_fakes(self, Gs, minibatch_size, num_gpus):
class MetricGroup (line 119) | class MetricGroup:
method __init__ (line 120) | def __init__(self, metric_kwarg_list):
method run (line 123) | def run(self, *args, **kwargs):
method get_result_str (line 127) | def get_result_str(self):
method update_autosummaries (line 130) | def update_autosummaries(self):
class DummyMetric (line 137) | class DummyMetric(MetricBase):
method _evaluate (line 138) | def _evaluate(self, Gs, num_gpus):
FILE: metrics/perceptual_path_length.py
function normalize (line 20) | def normalize(v):
function slerp (line 24) | def slerp(a, b, t):
class PPL (line 35) | class PPL(metric_base.MetricBase):
method __init__ (line 36) | def __init__(self, num_samples, epsilon, space, sampling, minibatch_pe...
method _evaluate (line 46) | def _evaluate(self, Gs, num_gpus):
FILE: pretrained_example.py
function main (line 18) | def main():
FILE: run_metrics.py
function run_pickle (line 20) | def run_pickle(submit_config, metric_args, network_pkl, dataset_args, mi...
function run_snapshot (line 32) | def run_snapshot(submit_config, metric_args, run_id, snapshot):
function run_all_snapshots (line 46) | def run_all_snapshots(submit_config, metric_args, run_id):
function main (line 62) | def main():
FILE: train.py
function main (line 177) | def main():
FILE: training/dataset.py
function parse_tfrecord_tf (line 20) | def parse_tfrecord_tf(record):
function parse_tfrecord_np (line 27) | def parse_tfrecord_np(record):
class TFRecordDataset (line 37) | class TFRecordDataset:
method __init__ (line 38) | def __init__(self,
method configure (line 136) | def configure(self, minibatch_size, lod=0):
method get_minibatch_tf (line 145) | def get_minibatch_tf(self): # => images, labels
method get_minibatch_np (line 149) | def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels
method get_random_labels_tf (line 156) | def get_random_labels_tf(self, minibatch_size): # => labels
method get_random_labels_np (line 163) | def get_random_labels_np(self, minibatch_size): # => labels
class SyntheticDataset (line 171) | class SyntheticDataset:
method __init__ (line 172) | def __init__(self, resolution=1024, num_channels=3, dtype='uint8', dyn...
method configure (line 190) | def configure(self, minibatch_size, lod=0):
method get_minibatch_tf (line 195) | def get_minibatch_tf(self): # => images, labels
method get_minibatch_np (line 203) | def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels
method get_random_labels_tf (line 209) | def get_random_labels_tf(self, minibatch_size): # => labels
method get_random_labels_np (line 213) | def get_random_labels_np(self, minibatch_size): # => labels
method _generate_images (line 219) | def _generate_images(self, minibatch, lod, shape): # to be overridden ...
method _generate_labels (line 222) | def _generate_labels(self, minibatch): # to be overridden by subclasses
function load_dataset (line 228) | def load_dataset(class_name='training.dataset.TFRecordDataset', data_dir...
FILE: training/loss.py
function fp32 (line 17) | def fp32(*values):
function G_wgan (line 26) | def G_wgan(G, D, opt, training_set, minibatch_size): # pylint: disable=u...
function D_wgan (line 34) | def D_wgan(G, D, opt, training_set, minibatch_size, reals, labels, # pyl...
function D_wgan_gp (line 50) | def D_wgan_gp(G, D, opt, training_set, minibatch_size, reals, labels, # ...
function D_hinge (line 83) | def D_hinge(G, D, opt, training_set, minibatch_size, reals, labels): # p...
function D_hinge_gp (line 93) | def D_hinge_gp(G, D, opt, training_set, minibatch_size, reals, labels, #...
function G_logistic_saturating (line 123) | def G_logistic_saturating(G, D, opt, training_set, minibatch_size): # py...
function G_logistic_nonsaturating (line 131) | def G_logistic_nonsaturating(G, D, opt, training_set, minibatch_size): #...
function D_logistic (line 139) | def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels): ...
function D_logistic_simplegp (line 150) | def D_logistic_simplegp(G, D, opt, training_set, minibatch_size, reals, ...
FILE: training/misc.py
function open_file_or_url (line 26) | def open_file_or_url(file_or_url):
function load_pkl (line 31) | def load_pkl(file_or_url):
function save_pkl (line 35) | def save_pkl(obj, filename):
function adjust_dynamic_range (line 42) | def adjust_dynamic_range(data, drange_in, drange_out):
function create_image_grid (line 49) | def create_image_grid(images, grid_size=None):
function convert_to_pil_image (line 66) | def convert_to_pil_image(image, drange=[0,1]):
function save_image (line 79) | def save_image(image, filename, drange=[0,1], quality=95):
function save_image_grid (line 86) | def save_image_grid(images, filename, drange=[0,1], grid_size=None):
function locate_run_dir (line 92) | def locate_run_dir(run_id_or_run_dir):
function list_network_pkls (line 113) | def list_network_pkls(run_id_or_run_dir, include_final=True):
function locate_network_pkl (line 122) | def locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_net...
function get_id_string_for_network_pkl (line 145) | def get_id_string_for_network_pkl(network_pkl):
function load_network_pkl (line 152) | def load_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_netwo...
function parse_config_for_previous_run (line 155) | def parse_config_for_previous_run(run_id):
function load_dataset_for_previous_run (line 180) | def load_dataset_for_previous_run(run_id, **kwargs): # => dataset_obj, m...
function apply_mirror_augment (line 187) | def apply_mirror_augment(minibatch):
function setup_snapshot_image_grid (line 197) | def setup_snapshot_image_grid(G, training_set,
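Note on `adjust_dynamic_range`: only the signature `(data, drange_in, drange_out)` appears in the index. One plausible reading is an affine remap from one value range to another, for example from `[0, 255]` pixel values to `[-1, 1]`. The sketch below assumes that reading; the body and the float64 conversion are assumptions, not code taken from `training/misc.py`.

```python
import numpy as np

def adjust_dynamic_range(data, drange_in, drange_out):
    # Assumed semantics: map drange_in = (lo, hi) linearly onto drange_out.
    lo_in, hi_in = drange_in
    lo_out, hi_out = drange_out
    scale = (hi_out - lo_out) / (hi_in - lo_in)
    bias = lo_out - lo_in * scale
    return np.asarray(data, dtype=np.float64) * scale + bias

pixels = [0.0, 127.5, 255.0]
net_range = adjust_dynamic_range(pixels, [0, 255], [-1, 1])
```

Applying the same function with the two ranges swapped inverts the mapping.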
FILE: training/networks_progan.py
function lerp (line 18) | def lerp(a, b, t): return a + (b - a) * t
function lerp_clip (line 19) | def lerp_clip(a, b, t): return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
function cset (line 20) | def cset(cur_lambda, new_cond, new_lambda): return lambda: tf.cond(new_c...
function get_weight (line 25) | def get_weight(shape, gain=np.sqrt(2), use_wscale=False):
function dense (line 38) | def dense(x, fmaps, gain=np.sqrt(2), use_wscale=False):
function conv2d (line 48) | def conv2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False):
function apply_bias (line 57) | def apply_bias(x):
function leaky_relu (line 67) | def leaky_relu(x, alpha=0.2):
function upscale2d (line 75) | def upscale2d(x, factor=2):
function upscale2d_conv2d (line 89) | def upscale2d_conv2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False):
function downscale2d (line 102) | def downscale2d(x, factor=2):
function conv2d_downscale2d (line 113) | def conv2d_downscale2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=Fal...
function pixel_norm (line 124) | def pixel_norm(x, epsilon=1e-8):
function minibatch_stddev_layer (line 131) | def minibatch_stddev_layer(x, group_size=4, num_new_features=1):
function G_paper (line 149) | def G_paper(
function D_paper (line 238) | def D_paper(
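Note on `lerp` and `lerp_clip`: the index above shows both one-line definitions in full. The sketch below restates them with NumPy so they can be run without TensorFlow. Substituting `np.clip` for `tf.clip_by_value` is an assumption made here for illustration only.

```python
import numpy as np

def lerp(a, b, t):
    # Linear interpolation; extrapolates when t is outside [0, 1].
    return a + (b - a) * t

def lerp_clip(a, b, t):
    # Same as lerp, but t is clamped to [0, 1] first (np.clip stands in
    # for tf.clip_by_value in this sketch).
    return a + (b - a) * np.clip(t, 0.0, 1.0)

quarter = lerp(0.0, 10.0, 0.25)
overshoot = lerp(0.0, 10.0, 1.5)
clamped = lerp_clip(0.0, 10.0, 1.5)
```

The clamped variant never leaves the interval between `a` and `b`, whereas plain `lerp` extrapolates past `b` for `t > 1`.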
FILE: training/networks_stylegan.py
function _blur2d (line 22) | def _blur2d(x, f=[1,2,1], normalize=True, flip=False, stride=1):
function _upscale2d (line 51) | def _upscale2d(x, factor=2, gain=1):
function _downscale2d (line 70) | def _downscale2d(x, factor=2, gain=1):
function blur2d (line 96) | def blur2d(x, f=[1,2,1], normalize=True):
function upscale2d (line 108) | def upscale2d(x, factor=2):
function downscale2d (line 120) | def downscale2d(x, factor=2):
function get_weight (line 135) | def get_weight(shape, gain=np.sqrt(2), use_wscale=False, lrmul=1):
function dense (line 154) | def dense(x, fmaps, **kwargs):
function conv2d (line 164) | def conv2d(x, fmaps, kernel, **kwargs):
function upscale2d_conv2d (line 174) | def upscale2d_conv2d(x, fmaps, kernel, fused_scale='auto', **kwargs):
function conv2d_downscale2d (line 193) | def conv2d_downscale2d(x, fmaps, kernel, fused_scale='auto', **kwargs):
function apply_bias (line 213) | def apply_bias(x, lrmul=1):
function leaky_relu (line 223) | def leaky_relu(x, alpha=0.2):
function pixel_norm (line 239) | def pixel_norm(x, epsilon=1e-8):
function instance_norm (line 247) | def instance_norm(x, epsilon=1e-8):
function style_mod (line 261) | def style_mod(x, dlatent, **kwargs):
function apply_noise (line 270) | def apply_noise(x, noise_var=None, randomize_noise=True):
function minibatch_stddev_layer (line 283) | def minibatch_stddev_layer(x, group_size=4, num_new_features=1):
function G_style (line 302) | def G_style(
function G_mapping (line 384) | def G_mapping(
function G_synthesis (line 440) | def G_synthesis(
function D_basic (line 564) | def D_basic(
FILE: training/training_loop.py
function process_reals (line 26) | def process_reals(x, lod, mirror_augment, drange_data, drange_net):
function training_schedule (line 55) | def training_schedule(
function training_loop (line 112) | def training_loop(
Condensed preview: 31 files, each showing path, character count, and a content snippet.
[
{
"path": "LICENSE.txt",
"chars": 19406,
"preview": "Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n\n\nAttribution-NonCommercial 4.0 International\n\n============"
},
{
"path": "README.md",
"chars": 20710,
"preview": "## StyleGAN — Official TensorFlow Implementation\n 2019, NVIDIA CORPORATION. All rights reserved.\r\n#\r\n# This work is licensed under the Creative Commons At"
},
{
"path": "dataset_tool.py",
"chars": 30421,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "dnnlib/__init__.py",
"chars": 797,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Att"
},
{
"path": "dnnlib/submission/__init__.py",
"chars": 390,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Att"
},
{
"path": "dnnlib/submission/_internal/run.py",
"chars": 1543,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Att"
},
{
"path": "dnnlib/submission/run_context.py",
"chars": 4347,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Att"
},
{
"path": "dnnlib/submission/submit.py",
"chars": 11185,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "dnnlib/tflib/__init__.py",
"chars": 522,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Att"
},
{
"path": "dnnlib/tflib/autosummary.py",
"chars": 7535,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Att"
},
{
"path": "dnnlib/tflib/network.py",
"chars": 30119,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Att"
},
{
"path": "dnnlib/tflib/optimizer.py",
"chars": 9981,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "dnnlib/tflib/tfutil.py",
"chars": 9265,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Att"
},
{
"path": "dnnlib/util.py",
"chars": 13756,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "generate_figures.py",
"chars": 9563,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "metrics/__init__.py",
"chars": 350,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "metrics/frechet_inception_distance.py",
"chars": 3335,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "metrics/linear_separability.py",
"chars": 10313,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "metrics/metric_base.py",
"chars": 6479,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "metrics/perceptual_path_length.py",
"chars": 5225,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "pretrained_example.py",
"chars": 1781,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "run_metrics.py",
"chars": 4374,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "train.py",
"chars": 16580,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "training/__init__.py",
"chars": 350,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "training/dataset.py",
"chars": 12220,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "training/loss.py",
"chars": 10416,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "training/misc.py",
"chars": 9998,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "training/networks_progan.py",
"chars": 17575,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
},
{
"path": "training/networks_stylegan.py",
"chars": 33389,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Att"
},
{
"path": "training/training_loop.py",
"chars": 15522,
"preview": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attr"
}
]
About this extraction
This page contains the full source code of the NVlabs/stylegan GitHub repository, extracted and formatted as plain text. The extraction includes 31 files (310.7 KB), approximately 79.8k tokens, and a symbol index with 279 extracted functions, classes, methods, constants, and types.