[
  {
    "path": "LICENSE.txt",
    "content": "Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n\n\nAttribution-NonCommercial 4.0 International\n\n=======================================================================\n\nCreative Commons Corporation (\"Creative Commons\") is not a law firm and\ndoes not provide legal services or legal advice. Distribution of\nCreative Commons public licenses does not create a lawyer-client or\nother relationship. Creative Commons makes its licenses and related\ninformation available on an \"as-is\" basis. Creative Commons gives no\nwarranties regarding its licenses, any material licensed under their\nterms and conditions, or any related information. Creative Commons\ndisclaims all liability for damages resulting from their use to the\nfullest extent possible.\n\nUsing Creative Commons Public Licenses\n\nCreative Commons public licenses provide a standard set of terms and\nconditions that creators and other rights holders may use to share\noriginal works of authorship and other material subject to copyright\nand certain other rights specified in the public license below. The\nfollowing considerations are for informational purposes only, are not\nexhaustive, and do not form part of our licenses.\n\n     Considerations for licensors: Our public licenses are\n     intended for use by those authorized to give the public\n     permission to use material in ways otherwise restricted by\n     copyright and certain other rights. Our licenses are\n     irrevocable. Licensors should read and understand the terms\n     and conditions of the license they choose before applying it.\n     Licensors should also secure all rights necessary before\n     applying our licenses so that the public can reuse the\n     material as expected. Licensors should clearly mark any\n     material not subject to the license. This includes other CC-\n     licensed material, or material used under an exception or\n     limitation to copyright. 
More considerations for licensors:\n    wiki.creativecommons.org/Considerations_for_licensors\n\n     Considerations for the public: By using one of our public\n     licenses, a licensor grants the public permission to use the\n     licensed material under specified terms and conditions. If\n     the licensor's permission is not necessary for any reason--for\n     example, because of any applicable exception or limitation to\n     copyright--then that use is not regulated by the license. Our\n     licenses grant only permissions under copyright and certain\n     other rights that a licensor has authority to grant. Use of\n     the licensed material may still be restricted for other\n     reasons, including because others have copyright or other\n     rights in the material. A licensor may make special requests,\n     such as asking that all changes be marked or described.\n     Although not required by our licenses, you are encouraged to\n     respect those requests where reasonable. More_considerations\n     for the public: \n    wiki.creativecommons.org/Considerations_for_licensees\n\n=======================================================================\n\nCreative Commons Attribution-NonCommercial 4.0 International Public\nLicense\n\nBy exercising the Licensed Rights (defined below), You accept and agree\nto be bound by the terms and conditions of this Creative Commons\nAttribution-NonCommercial 4.0 International Public License (\"Public\nLicense\"). To the extent this Public License may be interpreted as a\ncontract, You are granted the Licensed Rights in consideration of Your\nacceptance of these terms and conditions, and the Licensor grants You\nsuch rights in consideration of benefits the Licensor receives from\nmaking the Licensed Material available under these terms and\nconditions.\n\n\nSection 1 -- Definitions.\n\n  a. 
Adapted Material means material subject to Copyright and Similar\n     Rights that is derived from or based upon the Licensed Material\n     and in which the Licensed Material is translated, altered,\n     arranged, transformed, or otherwise modified in a manner requiring\n     permission under the Copyright and Similar Rights held by the\n     Licensor. For purposes of this Public License, where the Licensed\n     Material is a musical work, performance, or sound recording,\n     Adapted Material is always produced where the Licensed Material is\n     synched in timed relation with a moving image.\n\n  b. Adapter's License means the license You apply to Your Copyright\n     and Similar Rights in Your contributions to Adapted Material in\n     accordance with the terms and conditions of this Public License.\n\n  c. Copyright and Similar Rights means copyright and/or similar rights\n     closely related to copyright including, without limitation,\n     performance, broadcast, sound recording, and Sui Generis Database\n     Rights, without regard to how the rights are labeled or\n     categorized. For purposes of this Public License, the rights\n     specified in Section 2(b)(1)-(2) are not Copyright and Similar\n     Rights.\n  d. Effective Technological Measures means those measures that, in the\n     absence of proper authority, may not be circumvented under laws\n     fulfilling obligations under Article 11 of the WIPO Copyright\n     Treaty adopted on December 20, 1996, and/or similar international\n     agreements.\n\n  e. Exceptions and Limitations means fair use, fair dealing, and/or\n     any other exception or limitation to Copyright and Similar Rights\n     that applies to Your use of the Licensed Material.\n\n  f. Licensed Material means the artistic or literary work, database,\n     or other material to which the Licensor applied this Public\n     License.\n\n  g. 
Licensed Rights means the rights granted to You subject to the\n     terms and conditions of this Public License, which are limited to\n     all Copyright and Similar Rights that apply to Your use of the\n     Licensed Material and that the Licensor has authority to license.\n\n  h. Licensor means the individual(s) or entity(ies) granting rights\n     under this Public License.\n\n  i. NonCommercial means not primarily intended for or directed towards\n     commercial advantage or monetary compensation. For purposes of\n     this Public License, the exchange of the Licensed Material for\n     other material subject to Copyright and Similar Rights by digital\n     file-sharing or similar means is NonCommercial provided there is\n     no payment of monetary compensation in connection with the\n     exchange.\n\n  j. Share means to provide material to the public by any means or\n     process that requires permission under the Licensed Rights, such\n     as reproduction, public display, public performance, distribution,\n     dissemination, communication, or importation, and to make material\n     available to the public including in ways that members of the\n     public may access the material from a place and at a time\n     individually chosen by them.\n\n  k. Sui Generis Database Rights means rights other than copyright\n     resulting from Directive 96/9/EC of the European Parliament and of\n     the Council of 11 March 1996 on the legal protection of databases,\n     as amended and/or succeeded, as well as other essentially\n     equivalent rights anywhere in the world.\n\n  l. You means the individual or entity exercising the Licensed Rights\n     under this Public License. Your has a corresponding meaning.\n\n\nSection 2 -- Scope.\n\n  a. License grant.\n\n       1. 
Subject to the terms and conditions of this Public License,\n          the Licensor hereby grants You a worldwide, royalty-free,\n          non-sublicensable, non-exclusive, irrevocable license to\n          exercise the Licensed Rights in the Licensed Material to:\n\n            a. reproduce and Share the Licensed Material, in whole or\n               in part, for NonCommercial purposes only; and\n\n            b. produce, reproduce, and Share Adapted Material for\n               NonCommercial purposes only.\n\n       2. Exceptions and Limitations. For the avoidance of doubt, where\n          Exceptions and Limitations apply to Your use, this Public\n          License does not apply, and You do not need to comply with\n          its terms and conditions.\n\n       3. Term. The term of this Public License is specified in Section\n          6(a).\n\n       4. Media and formats; technical modifications allowed. The\n          Licensor authorizes You to exercise the Licensed Rights in\n          all media and formats whether now known or hereafter created,\n          and to make technical modifications necessary to do so. The\n          Licensor waives and/or agrees not to assert any right or\n          authority to forbid You from making technical modifications\n          necessary to exercise the Licensed Rights, including\n          technical modifications necessary to circumvent Effective\n          Technological Measures. For purposes of this Public License,\n          simply making modifications authorized by this Section 2(a)\n          (4) never produces Adapted Material.\n\n       5. Downstream recipients.\n\n            a. Offer from the Licensor -- Licensed Material. Every\n               recipient of the Licensed Material automatically\n               receives an offer from the Licensor to exercise the\n               Licensed Rights under the terms and conditions of this\n               Public License.\n\n            b. No downstream restrictions. 
You may not offer or impose\n               any additional or different terms or conditions on, or\n               apply any Effective Technological Measures to, the\n               Licensed Material if doing so restricts exercise of the\n               Licensed Rights by any recipient of the Licensed\n               Material.\n\n       6. No endorsement. Nothing in this Public License constitutes or\n          may be construed as permission to assert or imply that You\n          are, or that Your use of the Licensed Material is, connected\n          with, or sponsored, endorsed, or granted official status by,\n          the Licensor or others designated to receive attribution as\n          provided in Section 3(a)(1)(A)(i).\n\n  b. Other rights.\n\n       1. Moral rights, such as the right of integrity, are not\n          licensed under this Public License, nor are publicity,\n          privacy, and/or other similar personality rights; however, to\n          the extent possible, the Licensor waives and/or agrees not to\n          assert any such rights held by the Licensor to the limited\n          extent necessary to allow You to exercise the Licensed\n          Rights, but not otherwise.\n\n       2. Patent and trademark rights are not licensed under this\n          Public License.\n\n       3. To the extent possible, the Licensor waives any right to\n          collect royalties from You for the exercise of the Licensed\n          Rights, whether directly or through a collecting society\n          under any voluntary or waivable statutory or compulsory\n          licensing scheme. In all other cases the Licensor expressly\n          reserves any right to collect such royalties, including when\n          the Licensed Material is used other than for NonCommercial\n          purposes.\n\n\nSection 3 -- License Conditions.\n\nYour exercise of the Licensed Rights is expressly made subject to the\nfollowing conditions.\n\n  a. Attribution.\n\n       1. 
If You Share the Licensed Material (including in modified\n          form), You must:\n\n            a. retain the following if it is supplied by the Licensor\n               with the Licensed Material:\n\n                 i. identification of the creator(s) of the Licensed\n                    Material and any others designated to receive\n                    attribution, in any reasonable manner requested by\n                    the Licensor (including by pseudonym if\n                    designated);\n\n                ii. a copyright notice;\n\n               iii. a notice that refers to this Public License;\n\n                iv. a notice that refers to the disclaimer of\n                    warranties;\n\n                 v. a URI or hyperlink to the Licensed Material to the\n                    extent reasonably practicable;\n\n            b. indicate if You modified the Licensed Material and\n               retain an indication of any previous modifications; and\n\n            c. indicate the Licensed Material is licensed under this\n               Public License, and include the text of, or the URI or\n               hyperlink to, this Public License.\n\n       2. You may satisfy the conditions in Section 3(a)(1) in any\n          reasonable manner based on the medium, means, and context in\n          which You Share the Licensed Material. For example, it may be\n          reasonable to satisfy the conditions by providing a URI or\n          hyperlink to a resource that includes the required\n          information.\n\n       3. If requested by the Licensor, You must remove any of the\n          information required by Section 3(a)(1)(A) to the extent\n          reasonably practicable.\n\n       4. 
If You Share Adapted Material You produce, the Adapter's\n          License You apply must not prevent recipients of the Adapted\n          Material from complying with this Public License.\n\n\nSection 4 -- Sui Generis Database Rights.\n\nWhere the Licensed Rights include Sui Generis Database Rights that\napply to Your use of the Licensed Material:\n\n  a. for the avoidance of doubt, Section 2(a)(1) grants You the right\n     to extract, reuse, reproduce, and Share all or a substantial\n     portion of the contents of the database for NonCommercial purposes\n     only;\n\n  b. if You include all or a substantial portion of the database\n     contents in a database in which You have Sui Generis Database\n     Rights, then the database in which You have Sui Generis Database\n     Rights (but not its individual contents) is Adapted Material; and\n\n  c. You must comply with the conditions in Section 3(a) if You Share\n     all or a substantial portion of the contents of the database.\n\nFor the avoidance of doubt, this Section 4 supplements and does not\nreplace Your obligations under this Public License where the Licensed\nRights include other Copyright and Similar Rights.\n\n\nSection 5 -- Disclaimer of Warranties and Limitation of Liability.\n\n  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE\n     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS\n     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF\n     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,\n     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,\n     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR\n     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,\n     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT\n     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT\n     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.\n\n  b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE\n     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,\n     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,\n     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,\n     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR\n     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN\n     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR\n     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR\n     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.\n\n  c. The disclaimer of warranties and limitation of liability provided\n     above shall be interpreted in a manner that, to the extent\n     possible, most closely approximates an absolute disclaimer and\n     waiver of all liability.\n\n\nSection 6 -- Term and Termination.\n\n  a. This Public License applies for the term of the Copyright and\n     Similar Rights licensed here. However, if You fail to comply with\n     this Public License, then Your rights under this Public License\n     terminate automatically.\n\n  b. Where Your right to use the Licensed Material has terminated under\n     Section 6(a), it reinstates:\n\n       1. automatically as of the date the violation is cured, provided\n          it is cured within 30 days of Your discovery of the\n          violation; or\n\n       2. upon express reinstatement by the Licensor.\n\n     For the avoidance of doubt, this Section 6(b) does not affect any\n     right the Licensor may have to seek remedies for Your violations\n     of this Public License.\n\n  c. For the avoidance of doubt, the Licensor may also offer the\n     Licensed Material under separate terms or conditions or stop\n     distributing the Licensed Material at any time; however, doing so\n     will not terminate this Public License.\n\n  d. 
Sections 1, 5, 6, 7, and 8 survive termination of this Public\n     License.\n\n\nSection 7 -- Other Terms and Conditions.\n\n  a. The Licensor shall not be bound by any additional or different\n     terms or conditions communicated by You unless expressly agreed.\n\n  b. Any arrangements, understandings, or agreements regarding the\n     Licensed Material not stated herein are separate from and\n     independent of the terms and conditions of this Public License.\n\n\nSection 8 -- Interpretation.\n\n  a. For the avoidance of doubt, this Public License does not, and\n     shall not be interpreted to, reduce, limit, restrict, or impose\n     conditions on any use of the Licensed Material that could lawfully\n     be made without permission under this Public License.\n\n  b. To the extent possible, if any provision of this Public License is\n     deemed unenforceable, it shall be automatically reformed to the\n     minimum extent necessary to make it enforceable. If the provision\n     cannot be reformed, it shall be severed from this Public License\n     without affecting the enforceability of the remaining terms and\n     conditions.\n\n  c. No term or condition of this Public License will be waived and no\n     failure to comply consented to unless expressly agreed to by the\n     Licensor.\n\n  d. Nothing in this Public License constitutes or may be interpreted\n     as a limitation upon, or waiver of, any privileges and immunities\n     that apply to the Licensor or You, including from the legal\n     processes of any jurisdiction or authority.\n\n=======================================================================\n\nCreative Commons is not a party to its public\nlicenses. 
Notwithstanding, Creative Commons may elect to apply one of\nits public licenses to material it publishes and in those instances\nwill be considered the \"Licensor.\" The text of the Creative Commons\npublic licenses is dedicated to the public domain under the CC0 Public\nDomain Dedication. Except for the limited purpose of indicating that\nmaterial is shared under a Creative Commons public license or as\notherwise permitted by the Creative Commons policies published at\ncreativecommons.org/policies, Creative Commons does not authorize the\nuse of the trademark \"Creative Commons\" or any other trademark or logo\nof Creative Commons without its prior written consent including,\nwithout limitation, in connection with any unauthorized modifications\nto any of its public licenses or any other arrangements,\nunderstandings, or agreements concerning use of licensed material. For\nthe avoidance of doubt, this paragraph does not form part of the\npublic licenses.\n\nCreative Commons may be contacted at creativecommons.org.\n"
  },
  {
    "path": "README.md",
    "content": "## StyleGAN &mdash; Official TensorFlow Implementation\n![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg?style=plastic)\n![TensorFlow 1.10](https://img.shields.io/badge/tensorflow-1.10-green.svg?style=plastic)\n![cuDNN 7.3.1](https://img.shields.io/badge/cudnn-7.3.1-green.svg?style=plastic)\n![License CC BY-NC](https://img.shields.io/badge/license-CC_BY--NC-green.svg?style=plastic)\n\n![Teaser image](./stylegan-teaser.png)\n**Picture:** *These people are not real &ndash; they were produced by our generator that allows control over different aspects of the image.*\n\nThis repository contains the official TensorFlow implementation of the following paper:\n\n> **A Style-Based Generator Architecture for Generative Adversarial Networks**<br>\n> Tero Karras (NVIDIA), Samuli Laine (NVIDIA), Timo Aila (NVIDIA)<br>\n> https://arxiv.org/abs/1812.04948\n>\n> **Abstract:** *We propose an alternative generator architecture for generative adversarial networks, borrowing from style transfer literature. The new architecture leads to an automatically learned, unsupervised separation of high-level attributes (e.g., pose and identity when trained on human faces) and stochastic variation in the generated images (e.g., freckles, hair), and it enables intuitive, scale-specific control of the synthesis. The new generator improves the state-of-the-art in terms of traditional distribution quality metrics, leads to demonstrably better interpolation properties, and also better disentangles the latent factors of variation. To quantify interpolation quality and disentanglement, we propose two new, automated methods that are applicable to any generator architecture. 
Finally, we introduce a new, highly varied and high-quality dataset of human faces.*\n\nFor business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)\n\n**&#9733;&#9733;&#9733; NEW: [StyleGAN2-ADA-PyTorch](https://github.com/NVlabs/stylegan2-ada-pytorch) is now available; see the full list of versions [here](https://nvlabs.github.io/stylegan2/versions.html) &#9733;&#9733;&#9733;**\n\n## Resources\n\nMaterial related to our paper is available via the following links:\n\n- Paper: https://arxiv.org/abs/1812.04948\n- Video: https://youtu.be/kSLJriaOumA\n- Code: https://github.com/NVlabs/stylegan\n- FFHQ: https://github.com/NVlabs/ffhq-dataset\n\nAdditional material can be found on Google Drive:\n\n| Path | Description\n| :--- | :----------\n| [StyleGAN](https://drive.google.com/open?id=1uka3a1noXHAydRPRbknqwKVGODvnmUBX) | Main folder.\n| &boxvr;&nbsp; [stylegan-paper.pdf](https://drive.google.com/open?id=1v-HkF3Ehrpon7wVIx4r5DLcko_U_V6Lt) | High-quality version of the paper PDF.\n| &boxvr;&nbsp; [stylegan-video.mp4](https://drive.google.com/open?id=1uzwkZHQX_9pYg1i0d1Nbe3D9xPO8-qBf) | High-quality version of the result video.\n| &boxvr;&nbsp; [images](https://drive.google.com/open?id=1-l46akONUWF6LCpDoeq63H53rD7MeiTd) | Example images produced using our generator.\n| &boxv;&nbsp; &boxvr;&nbsp; [representative-images](https://drive.google.com/open?id=1ToY5P4Vvf5_c3TyUizQ8fckFFoFtBvD8) | High-quality images to be used in articles, blog posts, etc.\n| &boxv;&nbsp; &boxur;&nbsp; [100k-generated-images](https://drive.google.com/open?id=100DJ0QXyG89HZzB4w2Cbyf4xjNK54cQ1) | 100,000 generated images for different amounts of truncation.\n| &boxv;&nbsp; &ensp;&ensp; &boxvr;&nbsp; [ffhq-1024x1024](https://drive.google.com/open?id=14lm8VRN1pr4g_KVe6_LvyDX1PObst6d4) | Generated using Flickr-Faces-HQ dataset at 1024&times;1024.\n| &boxv;&nbsp; &ensp;&ensp; &boxvr;&nbsp; 
[bedrooms-256x256](https://drive.google.com/open?id=1Vxz9fksw4kgjiHrvHkX4Hze4dyThFW6t) | Generated using LSUN Bedroom dataset at 256&times;256.\n| &boxv;&nbsp; &ensp;&ensp; &boxvr;&nbsp; [cars-512x384](https://drive.google.com/open?id=1MFCvOMdLE2_mpeLPTiDw5dxc2CRuKkzS) | Generated using LSUN Car dataset at 512&times;384.\n| &boxv;&nbsp; &ensp;&ensp; &boxur;&nbsp; [cats-256x256](https://drive.google.com/open?id=1gq-Gj3GRFiyghTPKhp8uDMA9HV_0ZFWQ) | Generated using LSUN Cat dataset at 256&times;256.\n| &boxvr;&nbsp; [videos](https://drive.google.com/open?id=1N8pOd_Bf8v89NGUaROdbD8-ayLPgyRRo) | Example videos produced using our generator.\n| &boxv;&nbsp; &boxur;&nbsp; [high-quality-video-clips](https://drive.google.com/open?id=1NFO7_vH0t98J13ckJYFd7kuaTkyeRJ86) | Individual segments of the result video as high-quality MP4.\n| &boxvr;&nbsp; [ffhq-dataset](https://drive.google.com/open?id=1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP) | Raw data for the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset).\n| &boxur;&nbsp; [networks](https://drive.google.com/open?id=1MASQyN5m0voPcx7-9K0r5gObhvvPups7) | Pre-trained networks as pickled instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py).\n| &ensp;&ensp; &boxvr;&nbsp; [stylegan-ffhq-1024x1024.pkl](https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ) | StyleGAN trained with Flickr-Faces-HQ dataset at 1024&times;1024.\n| &ensp;&ensp; &boxvr;&nbsp; [stylegan-celebahq-1024x1024.pkl](https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf) | StyleGAN trained with CelebA-HQ dataset at 1024&times;1024.\n| &ensp;&ensp; &boxvr;&nbsp; [stylegan-bedrooms-256x256.pkl](https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF) | StyleGAN trained with LSUN Bedroom dataset at 256&times;256.\n| &ensp;&ensp; &boxvr;&nbsp; [stylegan-cars-512x384.pkl](https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3) | StyleGAN trained with LSUN Car dataset at 512&times;384.\n| &ensp;&ensp; 
&boxvr;&nbsp; [stylegan-cats-256x256.pkl](https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ) | StyleGAN trained with LSUN Cat dataset at 256&times;256.\n| &ensp;&ensp; &boxur;&nbsp; [metrics](https://drive.google.com/open?id=1MvYdWCBuMfnoYGptRH-AgKLbPTsIQLhl) | Auxiliary networks for the quality and disentanglement metrics.\n| &ensp;&ensp; &ensp;&ensp; &boxvr;&nbsp; [inception_v3_features.pkl](https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn) | Standard [Inception-v3](https://arxiv.org/abs/1512.00567) classifier that outputs a raw feature vector.\n| &ensp;&ensp; &ensp;&ensp; &boxvr;&nbsp; [vgg16_zhang_perceptual.pkl](https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2) | Standard [LPIPS](https://arxiv.org/abs/1801.03924) metric to estimate perceptual similarity.\n| &ensp;&ensp; &ensp;&ensp; &boxvr;&nbsp; [celebahq-classifier-00-male.pkl](https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX) | Binary classifier trained to detect a single attribute of CelebA-HQ.\n| &ensp;&ensp; &ensp;&ensp; &boxur;&nbsp;&#x22ef; | Please see the file listing for remaining networks.\n\n## Licenses\n\nAll material, excluding the Flickr-Faces-HQ dataset, is made available under [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license by NVIDIA Corporation. You can **use, redistribute, and adapt** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our paper** and **indicating any changes** that you've made.\n\nFor license information regarding the FFHQ dataset, please refer to the [Flickr-Faces-HQ repository](https://github.com/NVlabs/ffhq-dataset).\n\n`inception_v3_features.pkl` and `inception_v3_softmax.pkl` are derived from the pre-trained [Inception-v3](https://arxiv.org/abs/1512.00567) network by Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, and Zbigniew Wojna. 
The network was originally shared under [Apache 2.0](https://github.com/tensorflow/models/blob/master/LICENSE) license on the [TensorFlow Models](https://github.com/tensorflow/models) repository.\n\n`vgg16.pkl` and `vgg16_zhang_perceptual.pkl` are derived from the pre-trained [VGG-16](https://arxiv.org/abs/1409.1556) network by Karen Simonyan and Andrew Zisserman. The network was originally shared under [Creative Commons BY 4.0](https://creativecommons.org/licenses/by/4.0/) license on the [Very Deep Convolutional Networks for Large-Scale Visual Recognition](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) project page.\n\n`vgg16_zhang_perceptual.pkl` is further derived from the pre-trained [LPIPS](https://arxiv.org/abs/1801.03924) weights by Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, and Oliver Wang. The weights were originally shared under [BSD 2-Clause \"Simplified\" License](https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE) on the [PerceptualSimilarity](https://github.com/richzhang/PerceptualSimilarity) repository.\n\n## System requirements\n\n* Both Linux and Windows are supported, but we strongly recommend Linux for performance and compatibility reasons.\n* 64-bit Python 3.6 installation. We recommend Anaconda3 with numpy 1.14.3 or newer.\n* TensorFlow 1.10.0 or newer with GPU support.\n* One or more high-end NVIDIA GPUs with at least 11GB of DRAM. We recommend NVIDIA DGX-1 with 8 Tesla V100 GPUs.\n* NVIDIA driver 391.35 or newer, CUDA toolkit 9.0 or newer, cuDNN 7.3.1 or newer.\n\n## Using pre-trained networks\n\nA minimal example of using a pre-trained StyleGAN generator is given in [pretrained_example.py](./pretrained_example.py). When executed, the script downloads a pre-trained StyleGAN generator from Google Drive and uses it to generate an image:\n\n```\n> python pretrained_example.py\nDownloading https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ .... 
done\n\nGs                              Params    OutputShape          WeightShape\n---                             ---       ---                  ---\nlatents_in                      -         (?, 512)             -\n...\nimages_out                      -         (?, 3, 1024, 1024)   -\n---                             ---       ---                  ---\nTotal                           26219627\n\n> ls results\nexample.png # https://drive.google.com/uc?id=1UDLT_zb-rof9kKH0GwiJW_bS9MoZi8oP\n```\n\nA more advanced example is given in [generate_figures.py](./generate_figures.py). The script reproduces the figures from our paper in order to illustrate style mixing, noise inputs, and truncation:\n```\n> python generate_figures.py\nresults/figure02-uncurated-ffhq.png     # https://drive.google.com/uc?id=1U3r1xgcD7o-Fd0SBRpq8PXYajm7_30cu\nresults/figure03-style-mixing.png       # https://drive.google.com/uc?id=1U-nlMDtpnf1RcYkaFQtbh5oxnhA97hy6\nresults/figure04-noise-detail.png       # https://drive.google.com/uc?id=1UX3m39u_DTU6eLnEW6MqGzbwPFt2R9cG\nresults/figure05-noise-components.png   # https://drive.google.com/uc?id=1UQKPcvYVeWMRccGMbs2pPD9PVv1QDyp_\nresults/figure08-truncation-trick.png   # https://drive.google.com/uc?id=1ULea0C12zGlxdDQFNLXOWZCHi3QNfk_v\nresults/figure10-uncurated-bedrooms.png # https://drive.google.com/uc?id=1UEBnms1XMfj78OHj3_cx80mUf_m9DUJr\nresults/figure11-uncurated-cars.png     # https://drive.google.com/uc?id=1UO-4JtAs64Kun5vIj10UXqAJ1d5Ir1Ke\nresults/figure12-uncurated-cats.png     # https://drive.google.com/uc?id=1USnJc14prlu3QAYxstrtlfXC9sDWPA-W\n```\n\nThe pre-trained networks are stored as standard pickle files on Google Drive:\n\n```\n# Load pre-trained network.\nurl = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl\nwith dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:\n    _G, _D, Gs = pickle.load(f)\n    # _G = Instantaneous snapshot of the generator. 
Mainly useful for resuming a previous training run.\n    # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.\n    # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.\n```\n\nThe above code downloads the file and unpickles it to yield three instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py). To generate images, you will typically want to use `Gs` &ndash; the other two networks are provided for completeness. In order for `pickle.load()` to work, you will need to have the `dnnlib` source directory in your PYTHONPATH and a `tf.Session` set as default. The session can be initialized by calling `dnnlib.tflib.init_tf()`.\n\nThere are three ways to use the pre-trained generator:\n\n1. Use `Gs.run()` for immediate-mode operation where the inputs and outputs are numpy arrays:\n   ```\n   # Pick latent vector.\n   rnd = np.random.RandomState(5)\n   latents = rnd.randn(1, Gs.input_shape[1])\n\n   # Generate image.\n   fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)\n   images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)\n   ```\n   The first argument is a batch of latent vectors of shape `[num, 512]`. The second argument is reserved for class labels (not used by StyleGAN). The remaining keyword arguments are optional and can be used to further modify the operation (see below). The output is a batch of images, whose format is dictated by the `output_transform` argument.\n\n2. 
Use `Gs.get_output_for()` to incorporate the generator as part of a larger TensorFlow expression:\n   ```\n   latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])\n   images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True)\n   images = tflib.convert_images_to_uint8(images)\n   result_expr.append(inception_clone.get_output_for(images))\n   ```\n   The above code is from [metrics/frechet_inception_distance.py](./metrics/frechet_inception_distance.py). It generates a batch of random images and feeds them directly to the [Inception-v3](https://arxiv.org/abs/1512.00567) network without having to convert the data to numpy arrays in between.\n\n3. Look up `Gs.components.mapping` and `Gs.components.synthesis` to access individual sub-networks of the generator. Similar to `Gs`, the sub-networks are represented as independent instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py):\n   ```\n   src_latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds])\n   src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]\n   src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)\n   ```\n   The above code is adapted from [generate_figures.py](./generate_figures.py). It first transforms a batch of latent vectors into the intermediate *W* space using the mapping network and then turns these vectors into a batch of images using the synthesis network. The `dlatents` array stores a separate copy of the same *w* vector for each layer of the synthesis network to facilitate style mixing.\n\nThe exact details of the generator are defined in [training/networks_stylegan.py](./training/networks_stylegan.py) (see `G_style`, `G_mapping`, and `G_synthesis`). 
The following keyword arguments can be specified to modify the behavior when calling `run()` and `get_output_for()`:\n\n* `truncation_psi` and `truncation_cutoff` control the truncation trick that is performed by default when using `Gs` (&psi;=0.7, cutoff=8). It can be disabled by setting `truncation_psi=1` or `is_validation=True`, and the image quality can be further improved at the cost of variation by setting, e.g., `truncation_psi=0.5`. Note that truncation is always disabled when using the sub-networks directly. The average *w* needed to manually perform the truncation trick can be looked up using `Gs.get_var('dlatent_avg')`.\n\n* `randomize_noise` determines whether to re-randomize the noise inputs for each generated image (`True`, default) or to use specific noise values for the entire minibatch (`False`). The specific values can be accessed via the `tf.Variable` instances that are found using `[var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]`.\n\n* When using the mapping network directly, you can specify `dlatent_broadcast=None` to disable the automatic duplication of `dlatents` over the layers of the synthesis network.\n\n* Runtime performance can be fine-tuned via `structure='fixed'` and `dtype='float16'`. The former disables support for progressive growing, which is not needed for a fully trained generator, and the latter performs all computation using half-precision floating-point arithmetic.\n\n## Preparing datasets for training\n\nThe training and evaluation scripts operate on datasets stored as multi-resolution TFRecords. Each dataset is represented by a directory containing the same image data in several resolutions to enable efficient streaming. There is a separate `*.tfrecords` file for each resolution, and if the dataset contains labels, they are stored in a separate file as well. By default, the scripts expect to find the datasets at `datasets/<NAME>/<NAME>-r<LOG2_RESOLUTION>.tfrecords` (e.g., `datasets/ffhq/ffhq-r10.tfrecords` for 1024&times;1024). 
The directory can be changed by editing [config.py](./config.py):\n\n```\nresult_dir = 'results'\ndata_dir = 'datasets'\ncache_dir = 'cache'\n```\n\nTo obtain the FFHQ dataset (`datasets/ffhq`), please refer to the [Flickr-Faces-HQ repository](https://github.com/NVlabs/ffhq-dataset).\n\nTo obtain the CelebA-HQ dataset (`datasets/celebahq`), please refer to the [Progressive GAN repository](https://github.com/tkarras/progressive_growing_of_gans).\n\nTo obtain other datasets, including LSUN, please consult their corresponding project pages. The datasets can be converted to multi-resolution TFRecords using the provided [dataset_tool.py](./dataset_tool.py):\n\n```\n> python dataset_tool.py create_lsun datasets/lsun-bedroom-full ~/lsun/bedroom_lmdb --resolution 256\n> python dataset_tool.py create_lsun_wide datasets/lsun-car-512x384 ~/lsun/car_lmdb --width 512 --height 384\n> python dataset_tool.py create_lsun datasets/lsun-cat-full ~/lsun/cat_lmdb --resolution 256\n> python dataset_tool.py create_cifar10 datasets/cifar10 ~/cifar10\n> python dataset_tool.py create_from_images datasets/custom-dataset ~/custom-images\n```\n\n## Training networks\n\nOnce the datasets are set up, you can train your own StyleGAN networks as follows:\n\n1. Edit [train.py](./train.py) to specify the dataset and training configuration by uncommenting or editing specific lines.\n2. Run the training script with `python train.py`.\n3. The results are written to a newly created directory `results/<ID>-<DESCRIPTION>`.\n4. The training may take several days (or weeks) to complete, depending on the configuration.\n\nBy default, `train.py` is configured to train the highest-quality StyleGAN (configuration F in Table 1) for the FFHQ dataset at 1024&times;1024 resolution using 8 GPUs. Please note that we have used 8 GPUs in all of our experiments. 
Training with fewer GPUs may not produce identical results &ndash; if you wish to compare against our technique, we strongly recommend using the same number of GPUs.\n\nExpected training times for the default configuration using Tesla V100 GPUs:\n\n| GPUs | 1024&times;1024  | 512&times;512    | 256&times;256    |\n| :--- | :--------------  | :------------    | :------------    |\n| 1    | 41 days 4 hours  | 24 days 21 hours | 14 days 22 hours |\n| 2    | 21 days 22 hours | 13 days 7 hours  | 9 days 5 hours   |\n| 4    | 11 days 8 hours  | 7 days 0 hours   | 4 days 21 hours  |\n| 8    | 6 days 14 hours  | 4 days 10 hours  | 3 days 8 hours   |\n\n## Evaluating quality and disentanglement\n\nThe quality and disentanglement metrics used in our paper can be evaluated using [run_metrics.py](./run_metrics.py). By default, the script will evaluate the Fr&eacute;chet Inception Distance (`fid50k`) for the pre-trained FFHQ generator and write the results into a newly created directory under `results`. 
The exact behavior can be changed by uncommenting or editing specific lines in [run_metrics.py](./run_metrics.py).\n\nExpected evaluation time and results for the pre-trained FFHQ generator using one Tesla V100 GPU:\n\n| Metric    | Time      | Result   | Description\n| :-----    | :---      | :-----   | :----------\n| fid50k    | 16 min    | 4.4159   | Fr&eacute;chet Inception Distance using 50,000 images.\n| ppl_zfull | 55 min    | 664.8854 | Perceptual Path Length for full paths in *Z*.\n| ppl_wfull | 55 min    | 233.3059 | Perceptual Path Length for full paths in *W*.\n| ppl_zend  | 55 min    | 666.1057 | Perceptual Path Length for path endpoints in *Z*.\n| ppl_wend  | 55 min    | 197.2266 | Perceptual Path Length for path endpoints in *W*.\n| ls        | 10 hours  | z: 165.0106<br>w: 3.7447 | Linear Separability in *Z* and *W*.\n\nPlease note that the exact results may vary from run to run due to the non-deterministic nature of TensorFlow.\n\n## Acknowledgements\n\nWe thank Jaakko Lehtinen, David Luebke, and Tuomas Kynk&auml;&auml;nniemi for in-depth discussions and helpful comments; Janne Hellsten, Tero Kuosmanen, and Pekka J&auml;nis for compute infrastructure and help with the code release.\n"
  },
  {
    "path": "config.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\r\n#\r\n# This work is licensed under the Creative Commons Attribution-NonCommercial\r\n# 4.0 International License. To view a copy of this license, visit\r\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\r\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\r\n\r\n\"\"\"Global configuration.\"\"\"\r\n\r\n#----------------------------------------------------------------------------\r\n# Paths.\r\n\r\nresult_dir = 'results'\r\ndata_dir = 'datasets'\r\ncache_dir = 'cache'\r\nrun_dir_ignore = ['results', 'datasets', 'cache']\r\n\r\n#----------------------------------------------------------------------------\r\n"
  },
  {
    "path": "dataset_tool.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN.\"\"\"\n\n# pylint: disable=too-many-lines\nimport os\nimport sys\nimport glob\nimport argparse\nimport threading\nimport six.moves.queue as Queue # pylint: disable=import-error\nimport traceback\nimport numpy as np\nimport tensorflow as tf\nimport PIL.Image\nimport dnnlib.tflib as tflib\n\nfrom training import dataset\n\n#----------------------------------------------------------------------------\n\ndef error(msg):\n    print('Error: ' + msg)\n    exit(1)\n\n#----------------------------------------------------------------------------\n\nclass TFRecordExporter:\n    def __init__(self, tfrecord_dir, expected_images, print_progress=True, progress_interval=10):\n        self.tfrecord_dir       = tfrecord_dir\n        self.tfr_prefix         = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir))\n        self.expected_images    = expected_images\n        self.cur_images         = 0\n        self.shape              = None\n        self.resolution_log2    = None\n        self.tfr_writers        = []\n        self.print_progress     = print_progress\n        self.progress_interval  = progress_interval\n\n        if self.print_progress:\n            print('Creating dataset \"%s\"' % tfrecord_dir)\n        if not os.path.isdir(self.tfrecord_dir):\n            os.makedirs(self.tfrecord_dir)\n        assert os.path.isdir(self.tfrecord_dir)\n\n    def close(self):\n        if self.print_progress:\n            print('%-40s\\r' % 'Flushing data...', end='', flush=True)\n        for tfr_writer in self.tfr_writers:\n       
     tfr_writer.close()\n        self.tfr_writers = []\n        if self.print_progress:\n            print('%-40s\\r' % '', end='', flush=True)\n            print('Added %d images.' % self.cur_images)\n\n    def choose_shuffled_order(self): # Note: Images and labels must be added in shuffled order.\n        order = np.arange(self.expected_images)\n        np.random.RandomState(123).shuffle(order)\n        return order\n\n    def add_image(self, img):\n        if self.print_progress and self.cur_images % self.progress_interval == 0:\n            print('%d / %d\\r' % (self.cur_images, self.expected_images), end='', flush=True)\n        if self.shape is None:\n            self.shape = img.shape\n            self.resolution_log2 = int(np.log2(self.shape[1]))\n            assert self.shape[0] in [1, 3]\n            assert self.shape[1] == self.shape[2]\n            assert self.shape[1] == 2**self.resolution_log2\n            tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)\n            for lod in range(self.resolution_log2 - 1):\n                tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod)\n                self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))\n        assert img.shape == self.shape\n        for lod, tfr_writer in enumerate(self.tfr_writers):\n            if lod:\n                img = img.astype(np.float32)\n                img = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] + img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25\n            quant = np.rint(img).clip(0, 255).astype(np.uint8)\n            ex = tf.train.Example(features=tf.train.Features(feature={\n                'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=quant.shape)),\n                'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tostring()]))}))\n            tfr_writer.write(ex.SerializeToString())\n        self.cur_images += 1\n\n    def add_labels(self, labels):\n      
  if self.print_progress:\n            print('%-40s\\r' % 'Saving labels...', end='', flush=True)\n        assert labels.shape[0] == self.cur_images\n        with open(self.tfr_prefix + '-rxx.labels', 'wb') as f:\n            np.save(f, labels.astype(np.float32))\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, *args):\n        self.close()\n\n#----------------------------------------------------------------------------\n\nclass ExceptionInfo(object):\n    def __init__(self):\n        self.value = sys.exc_info()[1]\n        self.traceback = traceback.format_exc()\n\n#----------------------------------------------------------------------------\n\nclass WorkerThread(threading.Thread):\n    def __init__(self, task_queue):\n        threading.Thread.__init__(self)\n        self.task_queue = task_queue\n\n    def run(self):\n        while True:\n            func, args, result_queue = self.task_queue.get()\n            if func is None:\n                break\n            try:\n                result = func(*args)\n            except:\n                result = ExceptionInfo()\n            result_queue.put((result, args))\n\n#----------------------------------------------------------------------------\n\nclass ThreadPool(object):\n    def __init__(self, num_threads):\n        assert num_threads >= 1\n        self.task_queue = Queue.Queue()\n        self.result_queues = dict()\n        self.num_threads = num_threads\n        for _idx in range(self.num_threads):\n            thread = WorkerThread(self.task_queue)\n            thread.daemon = True\n            thread.start()\n\n    def add_task(self, func, args=()):\n        assert hasattr(func, '__call__') # must be a function\n        if func not in self.result_queues:\n            self.result_queues[func] = Queue.Queue()\n        self.task_queue.put((func, args, self.result_queues[func]))\n\n    def get_result(self, func): # returns (result, args)\n        result, args = 
self.result_queues[func].get()\n        if isinstance(result, ExceptionInfo):\n            print('\\n\\nWorker thread caught an exception:\\n' + result.traceback)\n            raise result.value\n        return result, args\n\n    def finish(self):\n        for _idx in range(self.num_threads):\n            self.task_queue.put((None, (), None))\n\n    def __enter__(self): # for 'with' statement\n        return self\n\n    def __exit__(self, *excinfo):\n        self.finish()\n\n    def process_items_concurrently(self, item_iterator, process_func=lambda x: x, pre_func=lambda x: x, post_func=lambda x: x, max_items_in_flight=None):\n        if max_items_in_flight is None: max_items_in_flight = self.num_threads * 4\n        assert max_items_in_flight >= 1\n        results = []\n        retire_idx = [0]\n\n        def task_func(prepared, _idx):\n            return process_func(prepared)\n\n        def retire_result():\n            processed, (_prepared, idx) = self.get_result(task_func)\n            results[idx] = processed\n            while retire_idx[0] < len(results) and results[retire_idx[0]] is not None:\n                yield post_func(results[retire_idx[0]])\n                results[retire_idx[0]] = None\n                retire_idx[0] += 1\n\n        for idx, item in enumerate(item_iterator):\n            prepared = pre_func(item)\n            results.append(None)\n            self.add_task(func=task_func, args=(prepared, idx))\n            while retire_idx[0] < idx - max_items_in_flight + 2:\n                for res in retire_result(): yield res\n        while retire_idx[0] < len(results):\n            for res in retire_result(): yield res\n\n#----------------------------------------------------------------------------\n\ndef display(tfrecord_dir):\n    print('Loading dataset \"%s\"' % tfrecord_dir)\n    tflib.init_tf({'gpu_options.allow_growth': True})\n    dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size='full', repeat=False, shuffle_mb=0)\n    
tflib.init_uninitialized_vars()\n    import cv2  # pip install opencv-python\n\n    idx = 0\n    while True:\n        try:\n            images, labels = dset.get_minibatch_np(1)\n        except tf.errors.OutOfRangeError:\n            break\n        if idx == 0:\n            print('Displaying images')\n            cv2.namedWindow('dataset_tool')\n            print('Press SPACE or ENTER to advance, ESC to exit')\n        print('\\nidx = %-8d\\nlabel = %s' % (idx, labels[0].tolist()))\n        cv2.imshow('dataset_tool', images[0].transpose(1, 2, 0)[:, :, ::-1]) # CHW => HWC, RGB => BGR\n        idx += 1\n        if cv2.waitKey() == 27:\n            break\n    print('\\nDisplayed %d images.' % idx)\n\n#----------------------------------------------------------------------------\n\ndef extract(tfrecord_dir, output_dir):\n    print('Loading dataset \"%s\"' % tfrecord_dir)\n    tflib.init_tf({'gpu_options.allow_growth': True})\n    dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size=0, repeat=False, shuffle_mb=0)\n    tflib.init_uninitialized_vars()\n\n    print('Extracting images to \"%s\"' % output_dir)\n    if not os.path.isdir(output_dir):\n        os.makedirs(output_dir)\n    idx = 0\n    while True:\n        if idx % 10 == 0:\n            print('%d\\r' % idx, end='', flush=True)\n        try:\n            images, _labels = dset.get_minibatch_np(1)\n        except tf.errors.OutOfRangeError:\n            break\n        if images.shape[1] == 1:\n            img = PIL.Image.fromarray(images[0][0], 'L')\n        else:\n            img = PIL.Image.fromarray(images[0].transpose(1, 2, 0), 'RGB')\n        img.save(os.path.join(output_dir, 'img%08d.png' % idx))\n        idx += 1\n    print('Extracted %d images.' 
% idx)\n\n#----------------------------------------------------------------------------\n\ndef compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels):\n    max_label_size = 0 if ignore_labels else 'full'\n    print('Loading dataset \"%s\"' % tfrecord_dir_a)\n    tflib.init_tf({'gpu_options.allow_growth': True})\n    dset_a = dataset.TFRecordDataset(tfrecord_dir_a, max_label_size=max_label_size, repeat=False, shuffle_mb=0)\n    print('Loading dataset \"%s\"' % tfrecord_dir_b)\n    dset_b = dataset.TFRecordDataset(tfrecord_dir_b, max_label_size=max_label_size, repeat=False, shuffle_mb=0)\n    tflib.init_uninitialized_vars()\n\n    print('Comparing datasets')\n    idx = 0\n    identical_images = 0\n    identical_labels = 0\n    while True:\n        if idx % 100 == 0:\n            print('%d\\r' % idx, end='', flush=True)\n        try:\n            images_a, labels_a = dset_a.get_minibatch_np(1)\n        except tf.errors.OutOfRangeError:\n            images_a, labels_a = None, None\n        try:\n            images_b, labels_b = dset_b.get_minibatch_np(1)\n        except tf.errors.OutOfRangeError:\n            images_b, labels_b = None, None\n        if images_a is None or images_b is None:\n            if images_a is not None or images_b is not None:\n                print('Datasets contain different number of images')\n            break\n        if images_a.shape == images_b.shape and np.all(images_a == images_b):\n            identical_images += 1\n        else:\n            print('Image %d is different' % idx)\n        if labels_a.shape == labels_b.shape and np.all(labels_a == labels_b):\n            identical_labels += 1\n        else:\n            print('Label %d is different' % idx)\n        idx += 1\n    print('Identical images: %d / %d' % (identical_images, idx))\n    if not ignore_labels:\n        print('Identical labels: %d / %d' % (identical_labels, idx))\n\n#----------------------------------------------------------------------------\n\ndef 
create_mnist(tfrecord_dir, mnist_dir):\n    print('Loading MNIST from \"%s\"' % mnist_dir)\n    import gzip\n    with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:\n        images = np.frombuffer(file.read(), np.uint8, offset=16)\n    with gzip.open(os.path.join(mnist_dir, 'train-labels-idx1-ubyte.gz'), 'rb') as file:\n        labels = np.frombuffer(file.read(), np.uint8, offset=8)\n    images = images.reshape(-1, 1, 28, 28)\n    images = np.pad(images, [(0,0), (0,0), (2,2), (2,2)], 'constant', constant_values=0)\n    assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8\n    assert labels.shape == (60000,) and labels.dtype == np.uint8\n    assert np.min(images) == 0 and np.max(images) == 255\n    assert np.min(labels) == 0 and np.max(labels) == 9\n    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)\n    onehot[np.arange(labels.size), labels] = 1.0\n\n    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:\n        order = tfr.choose_shuffled_order()\n        for idx in range(order.size):\n            tfr.add_image(images[order[idx]])\n        tfr.add_labels(onehot[order])\n\n#----------------------------------------------------------------------------\n\ndef create_mnistrgb(tfrecord_dir, mnist_dir, num_images=1000000, random_seed=123):\n    print('Loading MNIST from \"%s\"' % mnist_dir)\n    import gzip\n    with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:\n        images = np.frombuffer(file.read(), np.uint8, offset=16)\n    images = images.reshape(-1, 28, 28)\n    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)\n    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8\n    assert np.min(images) == 0 and np.max(images) == 255\n\n    with TFRecordExporter(tfrecord_dir, num_images) as tfr:\n        rnd = np.random.RandomState(random_seed)\n        for _idx in range(num_images):\n            
tfr.add_image(images[rnd.randint(images.shape[0], size=3)])\n\n#----------------------------------------------------------------------------\n\ndef create_cifar10(tfrecord_dir, cifar10_dir):\n    print('Loading CIFAR-10 from \"%s\"' % cifar10_dir)\n    import pickle\n    images = []\n    labels = []\n    for batch in range(1, 6):\n        with open(os.path.join(cifar10_dir, 'data_batch_%d' % batch), 'rb') as file:\n            data = pickle.load(file, encoding='latin1')\n        images.append(data['data'].reshape(-1, 3, 32, 32))\n        labels.append(data['labels'])\n    images = np.concatenate(images)\n    labels = np.concatenate(labels)\n    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8\n    assert labels.shape == (50000,) and labels.dtype == np.int32\n    assert np.min(images) == 0 and np.max(images) == 255\n    assert np.min(labels) == 0 and np.max(labels) == 9\n    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)\n    onehot[np.arange(labels.size), labels] = 1.0\n\n    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:\n        order = tfr.choose_shuffled_order()\n        for idx in range(order.size):\n            tfr.add_image(images[order[idx]])\n        tfr.add_labels(onehot[order])\n\n#----------------------------------------------------------------------------\n\ndef create_cifar100(tfrecord_dir, cifar100_dir):\n    print('Loading CIFAR-100 from \"%s\"' % cifar100_dir)\n    import pickle\n    with open(os.path.join(cifar100_dir, 'train'), 'rb') as file:\n        data = pickle.load(file, encoding='latin1')\n    images = data['data'].reshape(-1, 3, 32, 32)\n    labels = np.array(data['fine_labels'])\n    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8\n    assert labels.shape == (50000,) and labels.dtype == np.int32\n    assert np.min(images) == 0 and np.max(images) == 255\n    assert np.min(labels) == 0 and np.max(labels) == 99\n    onehot = np.zeros((labels.size, 
np.max(labels) + 1), dtype=np.float32)\n    onehot[np.arange(labels.size), labels] = 1.0\n\n    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:\n        order = tfr.choose_shuffled_order()\n        for idx in range(order.size):\n            tfr.add_image(images[order[idx]])\n        tfr.add_labels(onehot[order])\n\n#----------------------------------------------------------------------------\n\ndef create_svhn(tfrecord_dir, svhn_dir):\n    print('Loading SVHN from \"%s\"' % svhn_dir)\n    import pickle\n    images = []\n    labels = []\n    for batch in range(1, 4):\n        with open(os.path.join(svhn_dir, 'train_%d.pkl' % batch), 'rb') as file:\n            data = pickle.load(file, encoding='latin1')\n        images.append(data[0])\n        labels.append(data[1])\n    images = np.concatenate(images)\n    labels = np.concatenate(labels)\n    assert images.shape == (73257, 3, 32, 32) and images.dtype == np.uint8\n    assert labels.shape == (73257,) and labels.dtype == np.uint8\n    assert np.min(images) == 0 and np.max(images) == 255\n    assert np.min(labels) == 0 and np.max(labels) == 9\n    onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)\n    onehot[np.arange(labels.size), labels] = 1.0\n\n    with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:\n        order = tfr.choose_shuffled_order()\n        for idx in range(order.size):\n            tfr.add_image(images[order[idx]])\n        tfr.add_labels(onehot[order])\n\n#----------------------------------------------------------------------------\n\ndef create_lsun(tfrecord_dir, lmdb_dir, resolution=256, max_images=None):\n    print('Loading LSUN dataset from \"%s\"' % lmdb_dir)\n    import lmdb # pip install lmdb # pylint: disable=import-error\n    import cv2 # pip install opencv-python\n    import io\n    with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:\n        total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter\n        if 
max_images is None:\n            max_images = total_images\n        with TFRecordExporter(tfrecord_dir, max_images) as tfr:\n            for _idx, (_key, value) in enumerate(txn.cursor()):\n                try:\n                    try:\n                        img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1)\n                        if img is None:\n                            raise IOError('cv2.imdecode failed')\n                        img = img[:, :, ::-1] # BGR => RGB\n                    except IOError:\n                        img = np.asarray(PIL.Image.open(io.BytesIO(value)))\n                    crop = np.min(img.shape[:2])\n                    img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]\n                    img = PIL.Image.fromarray(img, 'RGB')\n                    img = img.resize((resolution, resolution), PIL.Image.ANTIALIAS)\n                    img = np.asarray(img)\n                    img = img.transpose([2, 0, 1]) # HWC => CHW\n                    tfr.add_image(img)\n                except:\n                    print(sys.exc_info()[1])\n                if tfr.cur_images == max_images:\n                    break\n\n#----------------------------------------------------------------------------\n\ndef create_lsun_wide(tfrecord_dir, lmdb_dir, width=512, height=384, max_images=None):\n    assert width == 2 ** int(np.round(np.log2(width)))\n    assert height <= width\n    print('Loading LSUN dataset from \"%s\"' % lmdb_dir)\n    import lmdb # pip install lmdb # pylint: disable=import-error\n    import cv2 # pip install opencv-python\n    import io\n    with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:\n        total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter\n        if max_images is None:\n            max_images = total_images\n        with TFRecordExporter(tfrecord_dir, max_images, print_progress=False) as tfr:\n      
      for idx, (_key, value) in enumerate(txn.cursor()):\n                try:\n                    try:\n                        img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1)\n                        if img is None:\n                            raise IOError('cv2.imdecode failed')\n                        img = img[:, :, ::-1] # BGR => RGB\n                    except IOError:\n                        img = np.asarray(PIL.Image.open(io.BytesIO(value)))\n\n                    ch = int(np.round(width * img.shape[0] / img.shape[1]))\n                    if img.shape[1] < width or ch < height:\n                        continue\n\n                    img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]\n                    img = PIL.Image.fromarray(img, 'RGB')\n                    img = img.resize((width, height), PIL.Image.ANTIALIAS)\n                    img = np.asarray(img)\n                    img = img.transpose([2, 0, 1]) # HWC => CHW\n\n                    canvas = np.zeros([3, width, width], dtype=np.uint8)\n                    canvas[:, (width - height) // 2 : (width + height) // 2] = img\n                    tfr.add_image(canvas)\n                    print('\\r%d / %d => %d ' % (idx + 1, total_images, tfr.cur_images), end='')\n\n                except:\n                    print(sys.exc_info()[1])\n                if tfr.cur_images == max_images:\n                    break\n    print()\n\n#----------------------------------------------------------------------------\n\ndef create_celeba(tfrecord_dir, celeba_dir, cx=89, cy=121):\n    print('Loading CelebA from \"%s\"' % celeba_dir)\n    glob_pattern = os.path.join(celeba_dir, 'img_align_celeba_png', '*.png')\n    image_filenames = sorted(glob.glob(glob_pattern))\n    expected_images = 202599\n    if len(image_filenames) != expected_images:\n        error('Expected to find %d images' % expected_images)\n\n    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:\n        order 
= tfr.choose_shuffled_order()\n        for idx in range(order.size):\n            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))\n            assert img.shape == (218, 178, 3)\n            img = img[cy - 64 : cy + 64, cx - 64 : cx + 64]\n            img = img.transpose(2, 0, 1) # HWC => CHW\n            tfr.add_image(img)\n\n#----------------------------------------------------------------------------\n\ndef create_from_images(tfrecord_dir, image_dir, shuffle):\n    print('Loading images from \"%s\"' % image_dir)\n    image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))\n    if len(image_filenames) == 0:\n        error('No input images found')\n\n    img = np.asarray(PIL.Image.open(image_filenames[0]))\n    resolution = img.shape[0]\n    channels = img.shape[2] if img.ndim == 3 else 1\n    if img.shape[1] != resolution:\n        error('Input images must have the same width and height')\n    if resolution != 2 ** int(np.floor(np.log2(resolution))):\n        error('Input image resolution must be a power-of-two')\n    if channels not in [1, 3]:\n        error('Input images must be stored as RGB or grayscale')\n\n    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:\n        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))\n        for idx in range(order.size):\n            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))\n            if channels == 1:\n                img = img[np.newaxis, :, :] # HW => CHW\n            else:\n                img = img.transpose([2, 0, 1]) # HWC => CHW\n            tfr.add_image(img)\n\n#----------------------------------------------------------------------------\n\ndef create_from_hdf5(tfrecord_dir, hdf5_filename, shuffle):\n    print('Loading HDF5 archive from \"%s\"' % hdf5_filename)\n    import h5py # conda install h5py\n    with h5py.File(hdf5_filename, 'r') as hdf5_file:\n        hdf5_data = max([value for key, value in hdf5_file.items() 
if key.startswith('data')], key=lambda lod: lod.shape[3])\n        with TFRecordExporter(tfrecord_dir, hdf5_data.shape[0]) as tfr:\n            order = tfr.choose_shuffled_order() if shuffle else np.arange(hdf5_data.shape[0])\n            for idx in range(order.size):\n                tfr.add_image(hdf5_data[order[idx]])\n            npy_filename = os.path.splitext(hdf5_filename)[0] + '-labels.npy'\n            if os.path.isfile(npy_filename):\n                tfr.add_labels(np.load(npy_filename)[order])\n\n#----------------------------------------------------------------------------\n\ndef execute_cmdline(argv):\n    prog = argv[0]\n    parser = argparse.ArgumentParser(\n        prog        = prog,\n        description = 'Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN.',\n        epilog      = 'Type \"%s <command> -h\" for more information.' % prog)\n\n    subparsers = parser.add_subparsers(dest='command')\n    subparsers.required = True\n    def add_command(cmd, desc, example=None):\n        epilog = 'Example: %s %s' % (prog, example) if example is not None else None\n        return subparsers.add_parser(cmd, description=desc, help=desc, epilog=epilog)\n\n    p = add_command(    'display',          'Display images in dataset.',\n                                            'display datasets/mnist')\n    p.add_argument(     'tfrecord_dir',     help='Directory containing dataset')\n\n    p = add_command(    'extract',          'Extract images from dataset.',\n                                            'extract datasets/mnist mnist-images')\n    p.add_argument(     'tfrecord_dir',     help='Directory containing dataset')\n    p.add_argument(     'output_dir',       help='Directory to extract the images into')\n\n    p = add_command(    'compare',          'Compare two datasets.',\n                                            'compare datasets/mydataset datasets/mnist')\n    p.add_argument(     'tfrecord_dir_a',   help='Directory 
containing first dataset')\n    p.add_argument(     'tfrecord_dir_b',   help='Directory containing second dataset')\n    p.add_argument(     '--ignore_labels',  help='Ignore labels (default: 0)', type=int, default=0)\n\n    p = add_command(    'create_mnist',     'Create dataset for MNIST.',\n                                            'create_mnist datasets/mnist ~/downloads/mnist')\n    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')\n    p.add_argument(     'mnist_dir',        help='Directory containing MNIST')\n\n    p = add_command(    'create_mnistrgb',  'Create dataset for MNIST-RGB.',\n                                            'create_mnistrgb datasets/mnistrgb ~/downloads/mnist')\n    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')\n    p.add_argument(     'mnist_dir',        help='Directory containing MNIST')\n    p.add_argument(     '--num_images',     help='Number of composite images to create (default: 1000000)', type=int, default=1000000)\n    p.add_argument(     '--random_seed',    help='Random seed (default: 123)', type=int, default=123)\n\n    p = add_command(    'create_cifar10',   'Create dataset for CIFAR-10.',\n                                            'create_cifar10 datasets/cifar10 ~/downloads/cifar10')\n    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')\n    p.add_argument(     'cifar10_dir',      help='Directory containing CIFAR-10')\n\n    p = add_command(    'create_cifar100',  'Create dataset for CIFAR-100.',\n                                            'create_cifar100 datasets/cifar100 ~/downloads/cifar100')\n    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')\n    p.add_argument(     'cifar100_dir',     help='Directory containing CIFAR-100')\n\n    p = add_command(    'create_svhn',      'Create dataset for SVHN.',\n                                            'create_svhn datasets/svhn 
~/downloads/svhn')\n    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')\n    p.add_argument(     'svhn_dir',         help='Directory containing SVHN')\n\n    p = add_command(    'create_lsun',      'Create dataset for single LSUN category.',\n                                            'create_lsun datasets/lsun-car-100k ~/downloads/lsun/car_lmdb --resolution 256 --max_images 100000')\n    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')\n    p.add_argument(     'lmdb_dir',         help='Directory containing LMDB database')\n    p.add_argument(     '--resolution',     help='Output resolution (default: 256)', type=int, default=256)\n    p.add_argument(     '--max_images',     help='Maximum number of images (default: none)', type=int, default=None)\n\n    p = add_command(    'create_lsun_wide', 'Create LSUN dataset with non-square aspect ratio.',\n                                            'create_lsun_wide datasets/lsun-car-512x384 ~/downloads/lsun/car_lmdb --width 512 --height 384')\n    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')\n    p.add_argument(     'lmdb_dir',         help='Directory containing LMDB database')\n    p.add_argument(     '--width',          help='Output width (default: 512)', type=int, default=512)\n    p.add_argument(     '--height',         help='Output height (default: 384)', type=int, default=384)\n    p.add_argument(     '--max_images',     help='Maximum number of images (default: none)', type=int, default=None)\n\n    p = add_command(    'create_celeba',    'Create dataset for CelebA.',\n                                            'create_celeba datasets/celeba ~/downloads/celeba')\n    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')\n    p.add_argument(     'celeba_dir',       help='Directory containing CelebA')\n    p.add_argument(     '--cx',             help='Center X coordinate (default: 
89)', type=int, default=89)\n    p.add_argument(     '--cy',             help='Center Y coordinate (default: 121)', type=int, default=121)\n\n    p = add_command(    'create_from_images', 'Create dataset from a directory full of images.',\n                                            'create_from_images datasets/mydataset myimagedir')\n    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')\n    p.add_argument(     'image_dir',        help='Directory containing the images')\n    p.add_argument(     '--shuffle',        help='Randomize image order (default: 1)', type=int, default=1)\n\n    p = add_command(    'create_from_hdf5', 'Create dataset from legacy HDF5 archive.',\n                                            'create_from_hdf5 datasets/celebahq ~/downloads/celeba-hq-1024x1024.h5')\n    p.add_argument(     'tfrecord_dir',     help='New dataset directory to be created')\n    p.add_argument(     'hdf5_filename',    help='HDF5 archive containing the images')\n    p.add_argument(     '--shuffle',        help='Randomize image order (default: 1)', type=int, default=1)\n\n    args = parser.parse_args(argv[1:] if len(argv) > 1 else ['-h'])\n    func = globals()[args.command]\n    del args.command\n    func(**vars(args))\n\n#----------------------------------------------------------------------------\n\nif __name__ == \"__main__\":\n    execute_cmdline(sys.argv)\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "dnnlib/__init__.py",
    "content": "﻿# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\nfrom . import submission\n\nfrom .submission.run_context import RunContext\n\nfrom .submission.submit import SubmitTarget\nfrom .submission.submit import PathType\nfrom .submission.submit import SubmitConfig\nfrom .submission.submit import get_path_from_template\nfrom .submission.submit import submit_run\n\nfrom .util import EasyDict\n\nsubmit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.\n"
  },
  {
    "path": "dnnlib/submission/__init__.py",
    "content": "﻿# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\nfrom . import run_context\nfrom . import submit\n"
  },
  {
    "path": "dnnlib/submission/_internal/run.py",
    "content": "﻿# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Helper for launching run functions in computing clusters.\n\nDuring the submit process, this file is copied to the appropriate run dir.\nWhen the job is launched in the cluster, this module is the first thing that\nis run inside the docker container.\n\"\"\"\n\nimport os\nimport pickle\nimport sys\n\n# PYTHONPATH should have been set so that the run_dir/src is in it\nimport dnnlib\n\ndef main():\n    if not len(sys.argv) >= 4:\n        raise RuntimeError(\"This script needs three arguments: run_dir, task_name and host_name!\")\n\n    run_dir = str(sys.argv[1])\n    task_name = str(sys.argv[2])\n    host_name = str(sys.argv[3])\n\n    submit_config_path = os.path.join(run_dir, \"submit_config.pkl\")\n\n    # SubmitConfig should have been pickled to the run dir\n    if not os.path.exists(submit_config_path):\n        raise RuntimeError(\"SubmitConfig pickle file does not exist!\")\n\n    submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, \"rb\"))\n    dnnlib.submission.submit.set_user_name_override(submit_config.user_name)\n\n    submit_config.task_name = task_name\n    submit_config.host_name = host_name\n\n    dnnlib.submission.submit.run_wrapper(submit_config)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "dnnlib/submission/run_context.py",
    "content": "﻿# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Helpers for managing the run/training loop.\"\"\"\n\nimport datetime\nimport json\nimport os\nimport pprint\nimport time\nimport types\n\nfrom typing import Any\n\nfrom . import submit\n\n\nclass RunContext(object):\n    \"\"\"Helper class for managing the run/training loop.\n\n    The context hides the implementation details of a basic run/training loop.\n    It sets things up properly, tells whether the run should be stopped, and then cleans up.\n    The user should call update periodically and use should_stop to determine whether the run should be stopped.\n\n    Args:\n        submit_config: The SubmitConfig that is used for the current run.\n        config_module: The whole config module that is used for the current run.\n        max_epoch: Optional cached value for the max_epoch variable used in update.\n    \"\"\"\n\n    def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None):\n        self.submit_config = submit_config\n        self.should_stop_flag = False\n        self.has_closed = False\n        self.start_time = time.time()\n        self.last_update_time = time.time()\n        self.last_update_interval = 0.0\n        self.max_epoch = max_epoch\n\n        # pretty print all the relevant content of the config module to a text file\n        if config_module is not None:\n            with open(os.path.join(submit_config.run_dir, \"config.txt\"), \"w\") as f:\n                filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith(\"_\") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, 
submit.SubmitConfig, type))}\n                pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)\n\n        # write out details about the run to a text file\n        self.run_txt_data = {\"task_name\": submit_config.task_name, \"host_name\": submit_config.host_name, \"start_time\": datetime.datetime.now().isoformat(sep=\" \")}\n        with open(os.path.join(submit_config.run_dir, \"run.txt\"), \"w\") as f:\n            pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)\n\n    def __enter__(self) -> \"RunContext\":\n        return self\n\n    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:\n        self.close()\n\n    def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:\n        \"\"\"Do general housekeeping and keep the state of the context up-to-date.\n        Should be called often enough but not in a tight loop.\"\"\"\n        assert not self.has_closed\n\n        self.last_update_interval = time.time() - self.last_update_time\n        self.last_update_time = time.time()\n\n        if os.path.exists(os.path.join(self.submit_config.run_dir, \"abort.txt\")):\n            self.should_stop_flag = True\n\n        max_epoch_val = self.max_epoch if max_epoch is None else max_epoch\n\n    def should_stop(self) -> bool:\n        \"\"\"Tell whether a stopping condition has been triggered one way or another.\"\"\"\n        return self.should_stop_flag\n\n    def get_time_since_start(self) -> float:\n        \"\"\"How much time has passed since the creation of the context.\"\"\"\n        return time.time() - self.start_time\n\n    def get_time_since_last_update(self) -> float:\n        \"\"\"How much time has passed since the last call to update.\"\"\"\n        return time.time() - self.last_update_time\n\n    def get_last_update_interval(self) -> float:\n        \"\"\"How much time passed between the previous two calls to update.\"\"\"\n        return 
self.last_update_interval\n\n    def close(self) -> None:\n        \"\"\"Close the context and clean up.\n        Should only be called once.\"\"\"\n        if not self.has_closed:\n            # update the run.txt with stopping time\n            self.run_txt_data[\"stop_time\"] = datetime.datetime.now().isoformat(sep=\" \")\n            with open(os.path.join(self.submit_config.run_dir, \"run.txt\"), \"w\") as f:\n                pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)\n\n            self.has_closed = True\n"
  },
  {
    "path": "dnnlib/submission/submit.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Submit a function to be run either locally or in a computing cluster.\"\"\"\n\nimport copy\nimport io\nimport os\nimport pathlib\nimport pickle\nimport platform\nimport pprint\nimport re\nimport shutil\nimport time\nimport traceback\n\nimport zipfile\n\nfrom enum import Enum\n\nfrom .. import util\nfrom ..util import EasyDict\n\n\nclass SubmitTarget(Enum):\n    \"\"\"The target where the function should be run.\n\n    LOCAL: Run it locally.\n    \"\"\"\n    LOCAL = 1\n\n\nclass PathType(Enum):\n    \"\"\"Determines the format in which a path should be formatted.\n\n    WINDOWS: Format with Windows style.\n    LINUX: Format with Linux/Posix style.\n    AUTO: Use the current OS type to select either WINDOWS or LINUX.\n    \"\"\"\n    WINDOWS = 1\n    LINUX = 2\n    AUTO = 3\n\n\n_user_name_override = None\n\n\nclass SubmitConfig(util.EasyDict):\n    \"\"\"Strongly typed config dict needed to submit runs.\n\n    Attributes:\n        run_dir_root: Path to the run dir root. Can optionally be templated with tags. Must always be passed through get_path_from_template.\n        run_desc: Description of the run. Will be used in the run dir and task name.\n        run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.\n        run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.\n        submit_target: Submit target enum value. 
Used to select where the run is actually launched.\n        num_gpus: Number of GPUs used/requested for the run.\n        print_info: Whether to print debug information when submitting.\n        ask_confirmation: Whether to ask a confirmation before submitting.\n        run_id: Automatically populated value during submit.\n        run_name: Automatically populated value during submit.\n        run_dir: Automatically populated value during submit.\n        run_func_name: Automatically populated value during submit.\n        run_func_kwargs: Automatically populated value during submit.\n        user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value.\n        task_name: Automatically populated value during submit.\n        host_name: Automatically populated value during submit.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n\n        # run (set these)\n        self.run_dir_root = \"\"  # should always be passed through get_path_from_template\n        self.run_desc = \"\"\n        self.run_dir_ignore = [\"__pycache__\", \"*.pyproj\", \"*.sln\", \"*.suo\", \".cache\", \".idea\", \".vs\", \".vscode\"]\n        self.run_dir_extra_files = None\n\n        # submit (set these)\n        self.submit_target = SubmitTarget.LOCAL\n        self.num_gpus = 1\n        self.print_info = False\n        self.ask_confirmation = False\n\n        # (automatically populated)\n        self.run_id = None\n        self.run_name = None\n        self.run_dir = None\n        self.run_func_name = None\n        self.run_func_kwargs = None\n        self.user_name = None\n        self.task_name = None\n        self.host_name = \"localhost\"\n\n\ndef get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:\n    \"\"\"Replace tags in the given path template and return either Windows or Linux formatted path.\"\"\"\n    # automatically select path type depending on running OS\n    if 
path_type == PathType.AUTO:\n        if platform.system() == \"Windows\":\n            path_type = PathType.WINDOWS\n        elif platform.system() == \"Linux\":\n            path_type = PathType.LINUX\n        else:\n            raise RuntimeError(\"Unknown platform\")\n\n    path_template = path_template.replace(\"<USERNAME>\", get_user_name())\n\n    # return correctly formatted path\n    if path_type == PathType.WINDOWS:\n        return str(pathlib.PureWindowsPath(path_template))\n    elif path_type == PathType.LINUX:\n        return str(pathlib.PurePosixPath(path_template))\n    else:\n        raise RuntimeError(\"Unknown platform\")\n\n\ndef get_template_from_path(path: str) -> str:\n    \"\"\"Convert a normal path back to its template representation.\"\"\"\n    # normalize Windows-style separators to forward slashes\n    path = path.replace(\"\\\\\", \"/\")\n    return path\n\n\ndef convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:\n    \"\"\"Convert a normal path to a template and then convert it back to a normal path with the given path type.\"\"\"\n    path_template = get_template_from_path(path)\n    path = get_path_from_template(path_template, path_type)\n    return path\n\n\ndef set_user_name_override(name: str) -> None:\n    \"\"\"Set the global username override value.\"\"\"\n    global _user_name_override\n    _user_name_override = name\n\n\ndef get_user_name():\n    \"\"\"Get the current user name.\"\"\"\n    if _user_name_override is not None:\n        return _user_name_override\n    elif platform.system() == \"Windows\":\n        return os.getlogin()\n    elif platform.system() == \"Linux\":\n        try:\n            import pwd # pylint: disable=import-error\n            return pwd.getpwuid(os.geteuid()).pw_name # pylint: disable=no-member\n        except:\n            return \"unknown\"\n    else:\n        raise RuntimeError(\"Unknown platform\")\n\n\ndef _create_run_dir_local(submit_config: SubmitConfig) -> str:\n    \"\"\"Create a new run 
dir with increasing ID number at the start.\"\"\"\n    run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)\n\n    if not os.path.exists(run_dir_root):\n        print(\"Creating the run dir root: {}\".format(run_dir_root))\n        os.makedirs(run_dir_root)\n\n    submit_config.run_id = _get_next_run_id_local(run_dir_root)\n    submit_config.run_name = \"{0:05d}-{1}\".format(submit_config.run_id, submit_config.run_desc)\n    run_dir = os.path.join(run_dir_root, submit_config.run_name)\n\n    if os.path.exists(run_dir):\n        raise RuntimeError(\"The run dir already exists! ({0})\".format(run_dir))\n\n    print(\"Creating the run dir: {}\".format(run_dir))\n    os.makedirs(run_dir)\n\n    return run_dir\n\n\ndef _get_next_run_id_local(run_dir_root: str) -> int:\n    \"\"\"Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names.\"\"\"\n    dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))]\n    r = re.compile(\"^\\\\d+\")  # match one or more digits at the start of the string\n    run_id = 0\n\n    for dir_name in dir_names:\n        m = r.match(dir_name)\n\n        if m is not None:\n            i = int(m.group())\n            run_id = max(run_id, i + 1)\n\n    return run_id\n\n\ndef _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None:\n    \"\"\"Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable.\"\"\"\n    print(\"Copying files to the run dir\")\n    files = []\n\n    run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)\n    assert '.' 
in submit_config.run_func_name\n    for _idx in range(submit_config.run_func_name.count('.') - 1):\n        run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)\n    files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)\n\n    dnnlib_module_dir_path = util.get_module_dir_by_obj_name(\"dnnlib\")\n    files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)\n\n    if submit_config.run_dir_extra_files is not None:\n        files += submit_config.run_dir_extra_files\n\n    files = [(f[0], os.path.join(run_dir, \"src\", f[1])) for f in files]\n    files += [(os.path.join(dnnlib_module_dir_path, \"submission\", \"_internal\", \"run.py\"), os.path.join(run_dir, \"run.py\"))]\n\n    util.copy_files_and_create_dirs(files)\n\n    pickle.dump(submit_config, open(os.path.join(run_dir, \"submit_config.pkl\"), \"wb\"))\n\n    with open(os.path.join(run_dir, \"submit_config.txt\"), \"w\") as f:\n        pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)\n\n\ndef run_wrapper(submit_config: SubmitConfig) -> None:\n    \"\"\"Wrap the actual run function call for handling logging, exceptions, typing, etc.\"\"\"\n    is_local = submit_config.submit_target == SubmitTarget.LOCAL\n\n    checker = None\n\n    # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing\n    if is_local:\n        logger = util.Logger(file_name=os.path.join(submit_config.run_dir, \"log.txt\"), file_mode=\"w\", should_flush=True)\n    else:  # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh)\n        logger = util.Logger(file_name=None, should_flush=True)\n\n    import dnnlib\n    dnnlib.submit_config = submit_config\n\n    try:\n        print(\"dnnlib: Running {0}() on {1}...\".format(submit_config.run_func_name, 
submit_config.host_name))\n        start_time = time.time()\n        util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)\n        print(\"dnnlib: Finished {0}() in {1}.\".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))\n    except:\n        if is_local:\n            raise\n        else:\n            traceback.print_exc()\n\n            log_src = os.path.join(submit_config.run_dir, \"log.txt\")\n            log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), \"{0}-error.txt\".format(submit_config.run_name))\n            shutil.copyfile(log_src, log_dst)\n    finally:\n        open(os.path.join(submit_config.run_dir, \"_finished.txt\"), \"w\").close()\n\n    dnnlib.submit_config = None\n    logger.close()\n\n    if checker is not None:\n        checker.stop()\n\n\ndef submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:\n    \"\"\"Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place.\"\"\"\n    submit_config = copy.copy(submit_config)\n\n    if submit_config.user_name is None:\n        submit_config.user_name = get_user_name()\n\n    submit_config.run_func_name = run_func_name\n    submit_config.run_func_kwargs = run_func_kwargs\n\n    assert submit_config.submit_target == SubmitTarget.LOCAL\n    if submit_config.submit_target in {SubmitTarget.LOCAL}:\n        run_dir = _create_run_dir_local(submit_config)\n\n        submit_config.task_name = \"{0}-{1:05d}-{2}\".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)\n        submit_config.run_dir = run_dir\n        _populate_run_dir(run_dir, submit_config)\n\n    if submit_config.print_info:\n        print(\"\\nSubmit config:\\n\")\n        pprint.pprint(submit_config, indent=4, width=200, compact=False)\n        print()\n\n    if submit_config.ask_confirmation:\n        
if not util.ask_yes_no(\"Continue submitting the job?\"):\n            return\n\n    run_wrapper(submit_config)\n"
  },
  {
    "path": "dnnlib/tflib/__init__.py",
    "content": "﻿# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\nfrom . import autosummary\nfrom . import network\nfrom . import optimizer\nfrom . import tfutil\n\nfrom .tfutil import *\nfrom .network import Network\n\nfrom .optimizer import Optimizer\n"
  },
  {
    "path": "dnnlib/tflib/autosummary.py",
    "content": "﻿# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Helper for adding automatically tracked values to Tensorboard.\n\nAutosummary creates an identity op that internally keeps track of the input\nvalues and automatically shows up in TensorBoard. The reported value\nrepresents an average over input components. The average is accumulated\nconstantly over time and flushed when save_summaries() is called.\n\nNotes:\n- The output tensor must be used as an input for something else in the\n  graph. Otherwise, the autosummary op will not get executed, and the average\n  value will not get accumulated.\n- It is perfectly fine to include autosummaries with the same name in\n  several places throughout the graph, even if they are executed concurrently.\n- It is ok to also pass in a python scalar or numpy array. In this case, it\n  is added to the average immediately.\n\"\"\"\n\nfrom collections import OrderedDict\nimport numpy as np\nimport tensorflow as tf\nfrom tensorboard import summary as summary_lib\nfrom tensorboard.plugins.custom_scalar import layout_pb2\n\nfrom . 
import tfutil\nfrom .tfutil import TfExpression\nfrom .tfutil import TfExpressionEx\n\n_dtype = tf.float64\n_vars = OrderedDict()  # name => [var, ...]\n_immediate = OrderedDict()  # name => update_op, update_value\n_finalized = False\n_merge_op = None\n\n\ndef _create_var(name: str, value_expr: TfExpression) -> TfExpression:\n    \"\"\"Internal helper for creating autosummary accumulators.\"\"\"\n    assert not _finalized\n    name_id = name.replace(\"/\", \"_\")\n    v = tf.cast(value_expr, _dtype)\n\n    if v.shape.is_fully_defined():\n        size = np.prod(tfutil.shape_to_list(v.shape))\n        size_expr = tf.constant(size, dtype=_dtype)\n    else:\n        size = None\n        size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))\n\n    if size == 1:\n        if v.shape.ndims != 0:\n            v = tf.reshape(v, [])\n        v = [size_expr, v, tf.square(v)]\n    else:\n        v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]\n    v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))\n\n    with tfutil.absolute_name_scope(\"Autosummary/\" + name_id), tf.control_dependencies(None):\n        var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False)  # [sum(1), sum(x), sum(x**2)]\n    update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))\n\n    if name in _vars:\n        _vars[name].append(var)\n    else:\n        _vars[name] = [var]\n    return update_op\n\n\ndef autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx:\n    \"\"\"Create a new autosummary.\n\n    Args:\n        name:     Name to use in TensorBoard\n        value:    TensorFlow expression or python value to track\n        passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.\n\n    Example use of the passthru mechanism:\n\n    n = autosummary('l2loss', loss, passthru=n)\n\n    This 
is a shorthand for the following code:\n\n    with tf.control_dependencies([autosummary('l2loss', loss)]):\n        n = tf.identity(n)\n    \"\"\"\n    tfutil.assert_tf_initialized()\n    name_id = name.replace(\"/\", \"_\")\n\n    if tfutil.is_tf_expression(value):\n        with tf.name_scope(\"summary_\" + name_id), tf.device(value.device):\n            update_op = _create_var(name, value)\n            with tf.control_dependencies([update_op]):\n                return tf.identity(value if passthru is None else passthru)\n\n    else:  # python scalar or numpy array\n        if name not in _immediate:\n            with tfutil.absolute_name_scope(\"Autosummary/\" + name_id), tf.device(None), tf.control_dependencies(None):\n                update_value = tf.placeholder(_dtype)\n                update_op = _create_var(name, update_value)\n                _immediate[name] = update_op, update_value\n\n        update_op, update_value = _immediate[name]\n        tfutil.run(update_op, {update_value: value})\n        return value if passthru is None else passthru\n\n\ndef finalize_autosummaries() -> None:\n    \"\"\"Create the necessary ops to include autosummaries in TensorBoard report.\n    Note: This should be done only once per graph.\n    \"\"\"\n    global _finalized\n    tfutil.assert_tf_initialized()\n\n    if _finalized:\n        return None\n\n    _finalized = True\n    tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])\n\n    # Create summary ops.\n    with tf.device(None), tf.control_dependencies(None):\n        for name, vars_list in _vars.items():\n            name_id = name.replace(\"/\", \"_\")\n            with tfutil.absolute_name_scope(\"Autosummary/\" + name_id):\n                moments = tf.add_n(vars_list)\n                moments /= moments[0]\n                with tf.control_dependencies([moments]):  # read before resetting\n                    reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in 
vars_list]\n                    with tf.name_scope(None), tf.control_dependencies(reset_ops):  # reset before reporting\n                        mean = moments[1]\n                        std = tf.sqrt(moments[2] - tf.square(moments[1]))\n                        tf.summary.scalar(name, mean)\n                        tf.summary.scalar(\"xCustomScalars/\" + name + \"/margin_lo\", mean - std)\n                        tf.summary.scalar(\"xCustomScalars/\" + name + \"/margin_hi\", mean + std)\n\n    # Group by category and chart name.\n    cat_dict = OrderedDict()\n    for series_name in sorted(_vars.keys()):\n        p = series_name.split(\"/\")\n        cat = p[0] if len(p) >= 2 else \"\"\n        chart = \"/\".join(p[1:-1]) if len(p) >= 3 else p[-1]\n        if cat not in cat_dict:\n            cat_dict[cat] = OrderedDict()\n        if chart not in cat_dict[cat]:\n            cat_dict[cat][chart] = []\n        cat_dict[cat][chart].append(series_name)\n\n    # Set up custom_scalar layout.\n    categories = []\n    for cat_name, chart_dict in cat_dict.items():\n        charts = []\n        for chart_name, series_names in chart_dict.items():\n            series = []\n            for series_name in series_names:\n                series.append(layout_pb2.MarginChartContent.Series(\n                    value=series_name,\n                    lower=\"xCustomScalars/\" + series_name + \"/margin_lo\",\n                    upper=\"xCustomScalars/\" + series_name + \"/margin_hi\"))\n            margin = layout_pb2.MarginChartContent(series=series)\n            charts.append(layout_pb2.Chart(title=chart_name, margin=margin))\n        categories.append(layout_pb2.Category(title=cat_name, chart=charts))\n    layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))\n    return layout\n\ndef save_summaries(file_writer, global_step=None):\n    \"\"\"Call FileWriter.add_summary() with all summaries in the default graph,\n    automatically finalizing and merging 
them on the first call.\n    \"\"\"\n    global _merge_op\n    tfutil.assert_tf_initialized()\n\n    if _merge_op is None:\n        layout = finalize_autosummaries()\n        if layout is not None:\n            file_writer.add_summary(layout)\n        with tf.device(None), tf.control_dependencies(None):\n            _merge_op = tf.summary.merge_all()\n\n    file_writer.add_summary(_merge_op.eval(), global_step)\n"
  },
  {
    "path": "dnnlib/tflib/network.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Helper for managing networks.\"\"\"\n\nimport types\nimport inspect\nimport re\nimport uuid\nimport sys\nimport numpy as np\nimport tensorflow as tf\n\nfrom collections import OrderedDict\nfrom typing import Any, List, Tuple, Union\n\nfrom . import tfutil\nfrom .. import util\n\nfrom .tfutil import TfExpression, TfExpressionEx\n\n_import_handlers = []  # Custom import handlers for dealing with legacy data in pickle import.\n_import_module_src = dict()  # Source code for temporary modules created during pickle import.\n\n\ndef import_handler(handler_func):\n    \"\"\"Function decorator for declaring custom import handlers.\"\"\"\n    _import_handlers.append(handler_func)\n    return handler_func\n\n\nclass Network:\n    \"\"\"Generic network abstraction.\n\n    Acts as a convenience wrapper for a parameterized network construction\n    function, providing several utility methods and convenient access to\n    the inputs/outputs/weights.\n\n    Network objects can be safely pickled and unpickled for long-term\n    archival purposes. The pickling works reliably as long as the underlying\n    network construction function is defined in a standalone Python module\n    that has no side effects or application-specific imports.\n\n    Args:\n        name: Network name. 
Used to select TensorFlow name and variable scopes.\n        func_name: Fully qualified name of the underlying network construction function, or a top-level function object.\n        static_kwargs: Keyword arguments to be passed in to the network construction function.\n\n    Attributes:\n        name: User-specified name, defaults to build func name if None.\n        scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.\n        static_kwargs: Arguments passed to the user-supplied build func.\n        components: Container for sub-networks. Passed to the build func, and retained between calls.\n        num_inputs: Number of input tensors.\n        num_outputs: Number of output tensors.\n        input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.\n        output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.\n        input_shape: Short-hand for input_shapes[0].\n        output_shape: Short-hand for output_shapes[0].\n        input_templates: Input placeholders in the template graph.\n        output_templates: Output tensors in the template graph.\n        input_names: Name string for each input.\n        output_names: Name string for each output.\n        own_vars: Variables defined by this network (local_name => var), excluding sub-networks.\n        vars: All variables (local_name => var).\n        trainables: All trainable variables (local_name => var).\n        var_global_to_local: Mapping from variable global names to local names.\n    \"\"\"\n\n    def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):\n        tfutil.assert_tf_initialized()\n        assert isinstance(name, str) or name is None\n        assert func_name is not None\n        assert isinstance(func_name, str) or util.is_top_level_function(func_name)\n        assert util.is_pickleable(static_kwargs)\n\n        self._init_fields()\n        self.name = name\n        
self.static_kwargs = util.EasyDict(static_kwargs)\n\n        # Locate the user-specified network build function.\n        if util.is_top_level_function(func_name):\n            func_name = util.get_top_level_function_name(func_name)\n        module, self._build_func_name = util.get_module_from_obj_name(func_name)\n        self._build_func = util.get_obj_from_module(module, self._build_func_name)\n        assert callable(self._build_func)\n\n        # Dig up source code for the module containing the build function.\n        self._build_module_src = _import_module_src.get(module, None)\n        if self._build_module_src is None:\n            self._build_module_src = inspect.getsource(module)\n\n        # Init TensorFlow graph.\n        self._init_graph()\n        self.reset_own_vars()\n\n    def _init_fields(self) -> None:\n        self.name = None\n        self.scope = None\n        self.static_kwargs = util.EasyDict()\n        self.components = util.EasyDict()\n        self.num_inputs = 0\n        self.num_outputs = 0\n        self.input_shapes = [[]]\n        self.output_shapes = [[]]\n        self.input_shape = []\n        self.output_shape = []\n        self.input_templates = []\n        self.output_templates = []\n        self.input_names = []\n        self.output_names = []\n        self.own_vars = OrderedDict()\n        self.vars = OrderedDict()\n        self.trainables = OrderedDict()\n        self.var_global_to_local = OrderedDict()\n\n        self._build_func = None  # User-supplied build function that constructs the network.\n        self._build_func_name = None  # Name of the build function.\n        self._build_module_src = None  # Full source code of the module containing the build function.\n        self._run_cache = dict()  # Cached graph data for Network.run().\n\n    def _init_graph(self) -> None:\n        # Collect inputs.\n        self.input_names = []\n\n        for param in inspect.signature(self._build_func).parameters.values():\n            
if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:\n                self.input_names.append(param.name)\n\n        self.num_inputs = len(self.input_names)\n        assert self.num_inputs >= 1\n\n        # Choose name and scope.\n        if self.name is None:\n            self.name = self._build_func_name\n        assert re.match(\"^[A-Za-z0-9_.\\\\-]*$\", self.name)\n        with tf.name_scope(None):\n            self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)\n\n        # Finalize build func kwargs.\n        build_kwargs = dict(self.static_kwargs)\n        build_kwargs[\"is_template_graph\"] = True\n        build_kwargs[\"components\"] = self.components\n\n        # Build template graph.\n        with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope):  # ignore surrounding scopes\n            assert tf.get_variable_scope().name == self.scope\n            assert tf.get_default_graph().get_name_scope() == self.scope\n            with tf.control_dependencies(None):  # ignore surrounding control dependencies\n                self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]\n                out_expr = self._build_func(*self.input_templates, **build_kwargs)\n\n        # Collect outputs.\n        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)\n        self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)\n        self.num_outputs = len(self.output_templates)\n        assert self.num_outputs >= 1\n        assert all(tfutil.is_tf_expression(t) for t in self.output_templates)\n\n        # Perform sanity checks.\n        if any(t.shape.ndims is None for t in self.input_templates):\n            raise ValueError(\"Network input shapes not defined. 
Please call x.set_shape() for each input.\")\n        if any(t.shape.ndims is None for t in self.output_templates):\n            raise ValueError(\"Network output shapes not defined. Please call x.set_shape() where applicable.\")\n        if any(not isinstance(comp, Network) for comp in self.components.values()):\n            raise ValueError(\"Components of a Network must be Networks themselves.\")\n        if len(self.components) != len(set(comp.name for comp in self.components.values())):\n            raise ValueError(\"Components of a Network must have unique names.\")\n\n        # List inputs and outputs.\n        self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]\n        self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]\n        self.input_shape = self.input_shapes[0]\n        self.output_shape = self.output_shapes[0]\n        self.output_names = [t.name.split(\"/\")[-1].split(\":\")[0] for t in self.output_templates]\n\n        # List variables.\n        self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(\":\")[0], var) for var in tf.global_variables(self.scope + \"/\"))\n        self.vars = OrderedDict(self.own_vars)\n        self.vars.update((comp.name + \"/\" + name, var) for comp in self.components.values() for name, var in comp.vars.items())\n        self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)\n        self.var_global_to_local = OrderedDict((var.name.split(\":\")[0], name) for name, var in self.vars.items())\n\n    def reset_own_vars(self) -> None:\n        \"\"\"Re-initialize all variables of this network, excluding sub-networks.\"\"\"\n        tfutil.run([var.initializer for var in self.own_vars.values()])\n\n    def reset_vars(self) -> None:\n        \"\"\"Re-initialize all variables of this network, including sub-networks.\"\"\"\n        tfutil.run([var.initializer for var in self.vars.values()])\n\n    def 
reset_trainables(self) -> None:\n        \"\"\"Re-initialize all trainable variables of this network, including sub-networks.\"\"\"\n        tfutil.run([var.initializer for var in self.trainables.values()])\n\n    def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:\n        \"\"\"Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).\"\"\"\n        assert len(in_expr) == self.num_inputs\n        assert not all(expr is None for expr in in_expr)\n\n        # Finalize build func kwargs.\n        build_kwargs = dict(self.static_kwargs)\n        build_kwargs.update(dynamic_kwargs)\n        build_kwargs[\"is_template_graph\"] = False\n        build_kwargs[\"components\"] = self.components\n\n        # Build TensorFlow graph to evaluate the network.\n        with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):\n            assert tf.get_variable_scope().name == self.scope\n            valid_inputs = [expr for expr in in_expr if expr is not None]\n            final_inputs = []\n            for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):\n                if expr is not None:\n                    expr = tf.identity(expr, name=name)\n                else:\n                    expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)\n                final_inputs.append(expr)\n            out_expr = self._build_func(*final_inputs, **build_kwargs)\n\n        # Propagate input shapes back to the user-specified expressions.\n        for expr, final in zip(in_expr, final_inputs):\n            if isinstance(expr, tf.Tensor):\n                expr.set_shape(final.shape)\n\n        # Express outputs in the desired format.\n        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)\n        if return_as_list:\n            out_expr = [out_expr] if 
tfutil.is_tf_expression(out_expr) else list(out_expr)\n        return out_expr\n\n    def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:\n        \"\"\"Get the local name of a given variable, without any surrounding name scopes.\"\"\"\n        assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)\n        global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name\n        return self.var_global_to_local[global_name]\n\n    def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:\n        \"\"\"Find variable by local or global name.\"\"\"\n        assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)\n        return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name\n\n    def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:\n        \"\"\"Get the value of a given variable as NumPy array.\n        Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible.\"\"\"\n        return self.find_var(var_or_local_name).eval()\n\n    def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:\n        \"\"\"Set the value of a given variable based on the given NumPy array.\n        Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible.\"\"\"\n        tfutil.set_vars({self.find_var(var_or_local_name): new_value})\n\n    def __getstate__(self) -> dict:\n        \"\"\"Pickle export.\"\"\"\n        state = dict()\n        state[\"version\"]            = 3\n        state[\"name\"]               = self.name\n        state[\"static_kwargs\"]      = dict(self.static_kwargs)\n        state[\"components\"]         = dict(self.components)\n        state[\"build_module_src\"]   = self._build_module_src\n        
state[\"build_func_name\"]    = self._build_func_name\n        state[\"variables\"]          = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))\n        return state\n\n    def __setstate__(self, state: dict) -> None:\n        \"\"\"Pickle import.\"\"\"\n        # pylint: disable=attribute-defined-outside-init\n        tfutil.assert_tf_initialized()\n        self._init_fields()\n\n        # Execute custom import handlers.\n        for handler in _import_handlers:\n            state = handler(state)\n\n        # Set basic fields.\n        assert state[\"version\"] in [2, 3]\n        self.name = state[\"name\"]\n        self.static_kwargs = util.EasyDict(state[\"static_kwargs\"])\n        self.components = util.EasyDict(state.get(\"components\", {}))\n        self._build_module_src = state[\"build_module_src\"]\n        self._build_func_name = state[\"build_func_name\"]\n\n        # Create temporary module from the imported source code.\n        module_name = \"_tflib_network_import_\" + uuid.uuid4().hex\n        module = types.ModuleType(module_name)\n        sys.modules[module_name] = module\n        _import_module_src[module] = self._build_module_src\n        exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used\n\n        # Locate network build function in the temporary module.\n        self._build_func = util.get_obj_from_module(module, self._build_func_name)\n        assert callable(self._build_func)\n\n        # Init TensorFlow graph.\n        self._init_graph()\n        self.reset_own_vars()\n        tfutil.set_vars({self.find_var(name): value for name, value in state[\"variables\"]})\n\n    def clone(self, name: str = None, **new_static_kwargs) -> \"Network\":\n        \"\"\"Create a clone of this network with its own copy of the variables.\"\"\"\n        # pylint: disable=protected-access\n        net = object.__new__(Network)\n        net._init_fields()\n        net.name = name if name is not None else 
self.name\n        net.static_kwargs = util.EasyDict(self.static_kwargs)\n        net.static_kwargs.update(new_static_kwargs)\n        net._build_module_src = self._build_module_src\n        net._build_func_name = self._build_func_name\n        net._build_func = self._build_func\n        net._init_graph()\n        net.copy_vars_from(self)\n        return net\n\n    def copy_own_vars_from(self, src_net: \"Network\") -> None:\n        \"\"\"Copy the values of all variables from the given network, excluding sub-networks.\"\"\"\n        names = [name for name in self.own_vars.keys() if name in src_net.own_vars]\n        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))\n\n    def copy_vars_from(self, src_net: \"Network\") -> None:\n        \"\"\"Copy the values of all variables from the given network, including sub-networks.\"\"\"\n        names = [name for name in self.vars.keys() if name in src_net.vars]\n        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))\n\n    def copy_trainables_from(self, src_net: \"Network\") -> None:\n        \"\"\"Copy the values of all trainable variables from the given network, including sub-networks.\"\"\"\n        names = [name for name in self.trainables.keys() if name in src_net.trainables]\n        tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))\n\n    def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> \"Network\":\n        \"\"\"Create new network with the given parameters, and copy all variables from this network.\"\"\"\n        if new_name is None:\n            new_name = self.name\n        static_kwargs = dict(self.static_kwargs)\n        static_kwargs.update(new_static_kwargs)\n        net = Network(name=new_name, func_name=new_func_name, **static_kwargs)\n        net.copy_vars_from(self)\n        return net\n\n    def setup_as_moving_average_of(self, src_net: \"Network\", beta: 
TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:\n        \"\"\"Construct a TensorFlow op that updates the variables of this network\n        to be slightly closer to those of the given network.\"\"\"\n        with tfutil.absolute_name_scope(self.scope + \"/_MovingAvg\"):\n            ops = []\n            for name, var in self.vars.items():\n                if name in src_net.vars:\n                    cur_beta = beta if name in self.trainables else beta_nontrainable\n                    new_value = tfutil.lerp(src_net.vars[name], var, cur_beta)\n                    ops.append(var.assign(new_value))\n            return tf.group(*ops)\n\n    def run(self,\n            *in_arrays: Tuple[Union[np.ndarray, None], ...],\n            input_transform: dict = None,\n            output_transform: dict = None,\n            return_as_list: bool = False,\n            print_progress: bool = False,\n            minibatch_size: int = None,\n            num_gpus: int = 1,\n            assume_frozen: bool = False,\n            **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:\n        \"\"\"Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).\n\n        Args:\n            input_transform:    A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.\n                                The dict must contain a 'func' field that points to a top-level function. The function is called with the input\n                                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.\n            output_transform:   A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.\n                                The dict must contain a 'func' field that points to a top-level function. 
The function is called with the output\n                                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.\n            return_as_list:     True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.\n            print_progress:     Print progress to the console? Useful for very large input arrays.\n            minibatch_size:     Maximum minibatch size to use, None = disable batching.\n            num_gpus:           Number of GPUs to use.\n            assume_frozen:      Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.\n            dynamic_kwargs:     Additional keyword arguments to be passed into the network build function.\n        \"\"\"\n        assert len(in_arrays) == self.num_inputs\n        assert not all(arr is None for arr in in_arrays)\n        assert input_transform is None or util.is_top_level_function(input_transform[\"func\"])\n        assert output_transform is None or util.is_top_level_function(output_transform[\"func\"])\n        output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)\n        num_items = next(arr for arr in in_arrays if arr is not None).shape[0]  # first input may be None\n        if minibatch_size is None:\n            minibatch_size = num_items\n\n        # Construct unique hash key from all arguments that affect the TensorFlow graph.\n        key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)\n        def unwind_key(obj):\n            if isinstance(obj, dict):\n                return [(key, unwind_key(value)) for key, value in sorted(obj.items())]\n            if callable(obj):\n                return util.get_top_level_function_name(obj)\n            return obj\n        key = repr(unwind_key(key))\n\n        # Build graph.\n        if key not in 
self._run_cache:\n            with tfutil.absolute_name_scope(self.scope + \"/_Run\"), tf.control_dependencies(None):\n                with tf.device(\"/cpu:0\"):\n                    in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]\n                    in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))\n\n                out_split = []\n                for gpu in range(num_gpus):\n                    with tf.device(\"/gpu:%d\" % gpu):\n                        net_gpu = self.clone() if assume_frozen else self\n                        in_gpu = in_split[gpu]\n\n                        if input_transform is not None:\n                            in_kwargs = dict(input_transform)\n                            in_gpu = in_kwargs.pop(\"func\")(*in_gpu, **in_kwargs)\n                            in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)\n\n                        assert len(in_gpu) == self.num_inputs\n                        out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)\n\n                        if output_transform is not None:\n                            out_kwargs = dict(output_transform)\n                            out_gpu = out_kwargs.pop(\"func\")(*out_gpu, **out_kwargs)\n                            out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)\n\n                        assert len(out_gpu) == self.num_outputs\n                        out_split.append(out_gpu)\n\n                with tf.device(\"/cpu:0\"):\n                    out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]\n                    self._run_cache[key] = in_expr, out_expr\n\n        # Run minibatches.\n        in_expr, out_expr = self._run_cache[key]\n        out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]\n\n        for mb_begin in range(0, num_items, minibatch_size):\n            
if print_progress:\n                print(\"\\r%d / %d\" % (mb_begin, num_items), end=\"\")\n\n            mb_end = min(mb_begin + minibatch_size, num_items)\n            mb_num = mb_end - mb_begin\n            mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]\n            mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))\n\n            for dst, src in zip(out_arrays, mb_out):\n                dst[mb_begin: mb_end] = src\n\n        # Done.\n        if print_progress:\n            print(\"\\r%d / %d\" % (num_items, num_items))\n\n        if not return_as_list:\n            out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)\n        return out_arrays\n\n    def list_ops(self) -> List[TfExpression]:\n        include_prefix = self.scope + \"/\"\n        exclude_prefix = include_prefix + \"_\"\n        ops = tf.get_default_graph().get_operations()\n        ops = [op for op in ops if op.name.startswith(include_prefix)]\n        ops = [op for op in ops if not op.name.startswith(exclude_prefix)]\n        return ops\n\n    def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:\n        \"\"\"Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to\n        individual layers of the network. 
Mainly intended to be used for reporting.\"\"\"\n        layers = []\n\n        def recurse(scope, parent_ops, parent_vars, level):\n            # Ignore specific patterns.\n            if any(p in scope for p in [\"/Shape\", \"/strided_slice\", \"/Cast\", \"/concat\", \"/Assign\"]):\n                return\n\n            # Filter ops and vars by scope.\n            global_prefix = scope + \"/\"\n            local_prefix = global_prefix[len(self.scope) + 1:]\n            cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]\n            cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]\n            if not cur_ops and not cur_vars:\n                return\n\n            # Filter out all ops related to variables.\n            for var in [op for op in cur_ops if op.type.startswith(\"Variable\")]:\n                var_prefix = var.name + \"/\"\n                cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]\n\n            # Scope does not contain ops as immediate children => recurse deeper.\n            contains_direct_ops = any(\"/\" not in op.name[len(global_prefix):] and op.type != \"Identity\" for op in cur_ops)\n            if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:\n                visited = set()\n                for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:\n                    token = rel_name.split(\"/\")[0]\n                    if token not in visited:\n                        recurse(global_prefix + token, cur_ops, cur_vars, level + 1)\n                        visited.add(token)\n                return\n\n            # Report layer.\n            layer_name = scope[len(self.scope) + 1:]\n            layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]\n            layer_trainables = [var 
for _name, var in cur_vars if var.trainable]\n            layers.append((layer_name, layer_output, layer_trainables))\n\n        recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)\n        return layers\n\n    def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:\n        \"\"\"Print a summary table of the network structure.\"\"\"\n        rows = [[title if title is not None else self.name, \"Params\", \"OutputShape\", \"WeightShape\"]]\n        rows += [[\"---\"] * 4]\n        total_params = 0\n\n        for layer_name, layer_output, layer_trainables in self.list_layers():\n            num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables)\n            weights = [var for var in layer_trainables if var.name.endswith(\"/weight:0\")]\n            weights.sort(key=lambda x: len(x.name))\n            if len(weights) == 0 and len(layer_trainables) == 1:\n                weights = layer_trainables\n            total_params += num_params\n\n            if not hide_layers_with_no_params or num_params != 0:\n                num_params_str = str(num_params) if num_params > 0 else \"-\"\n                output_shape_str = str(layer_output.shape)\n                weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else \"-\"\n                rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]\n\n        rows += [[\"---\"] * 4]\n        rows += [[\"Total\", str(total_params), \"\", \"\"]]\n\n        widths = [max(len(cell) for cell in column) for column in zip(*rows)]\n        print()\n        for row in rows:\n            print(\"  \".join(cell + \" \" * (width - len(cell)) for cell, width in zip(row, widths)))\n        print()\n\n    def setup_weight_histograms(self, title: str = None) -> None:\n        \"\"\"Construct summary ops to include histograms of all trainable parameters in TensorBoard.\"\"\"\n        if title is None:\n            title = 
self.name\n\n        with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):\n            for local_name, var in self.trainables.items():\n                if \"/\" in local_name:\n                    p = local_name.split(\"/\")\n                    name = title + \"_\" + p[-1] + \"/\" + \"_\".join(p[:-1])\n                else:\n                    name = title + \"_toplevel/\" + local_name\n\n                tf.summary.histogram(name, var)\n\n#----------------------------------------------------------------------------\n# Backwards-compatible emulation of legacy output transformation in Network.run().\n\n_print_legacy_warning = True\n\ndef _handle_legacy_output_transforms(output_transform, dynamic_kwargs):\n    global _print_legacy_warning\n    legacy_kwargs = [\"out_mul\", \"out_add\", \"out_shrink\", \"out_dtype\"]\n    if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):\n        return output_transform, dynamic_kwargs\n\n    if _print_legacy_warning:\n        _print_legacy_warning = False\n        print()\n        print(\"WARNING: Old-style output transformations in Network.run() are deprecated.\")\n        print(\"Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'\")\n        print(\"instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.\")\n        print()\n    assert output_transform is None\n\n    new_kwargs = dict(dynamic_kwargs)\n    new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}\n    new_transform[\"func\"] = _legacy_output_transform_func\n    return new_transform, new_kwargs\n\ndef _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):\n    if out_mul != 1.0:\n        expr = [x * out_mul for x in expr]\n\n    if out_add != 0.0:\n        expr = [x + out_add for x in expr]\n\n    if out_shrink > 1:\n        ksize = [1, 1, out_shrink, out_shrink]\n        expr = [tf.nn.avg_pool(x, ksize=ksize, 
strides=ksize, padding=\"VALID\", data_format=\"NCHW\") for x in expr]\n\n    if out_dtype is not None:\n        if tf.as_dtype(out_dtype).is_integer:\n            expr = [tf.round(x) for x in expr]\n        expr = [tf.saturate_cast(x, out_dtype) for x in expr]\n    return expr\n"
  },
  {
    "path": "dnnlib/tflib/optimizer.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Helper wrapper for a Tensorflow optimizer.\"\"\"\n\nimport numpy as np\nimport tensorflow as tf\n\nfrom collections import OrderedDict\nfrom typing import List, Union\n\nfrom . import autosummary\nfrom . import tfutil\nfrom .. import util\n\nfrom .tfutil import TfExpression, TfExpressionEx\n\ntry:\n    # TensorFlow 1.13\n    from tensorflow.python.ops import nccl_ops\nexcept:\n    # Older TensorFlow versions\n    import tensorflow.contrib.nccl as nccl_ops\n\nclass Optimizer:\n    \"\"\"A Wrapper for tf.train.Optimizer.\n\n    Automatically takes care of:\n    - Gradient averaging for multi-GPU training.\n    - Dynamic loss scaling and typecasts for FP16 training.\n    - Ignoring corrupted gradients that contain NaNs/Infs.\n    - Reporting statistics.\n    - Well-chosen default settings.\n    \"\"\"\n\n    def __init__(self,\n                 name: str = \"Train\",\n                 tf_optimizer: str = \"tf.train.AdamOptimizer\",\n                 learning_rate: TfExpressionEx = 0.001,\n                 use_loss_scaling: bool = False,\n                 loss_scaling_init: float = 64.0,\n                 loss_scaling_inc: float = 0.0005,\n                 loss_scaling_dec: float = 1.0,\n                 **kwargs):\n\n        # Init fields.\n        self.name = name\n        self.learning_rate = tf.convert_to_tensor(learning_rate)\n        self.id = self.name.replace(\"/\", \".\")\n        self.scope = tf.get_default_graph().unique_name(self.id)\n        self.optimizer_class = util.get_obj_by_name(tf_optimizer)\n        self.optimizer_kwargs = dict(kwargs)\n        self.use_loss_scaling = 
use_loss_scaling\n        self.loss_scaling_init = loss_scaling_init\n        self.loss_scaling_inc = loss_scaling_inc\n        self.loss_scaling_dec = loss_scaling_dec\n        self._grad_shapes = None  # [shape, ...]\n        self._dev_opt = OrderedDict()  # device => optimizer\n        self._dev_grads = OrderedDict()  # device => [[(grad, var), ...], ...]\n        self._dev_ls_var = OrderedDict()  # device => variable (log2 of loss scaling factor)\n        self._updates_applied = False\n\n    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:\n        \"\"\"Register the gradients of the given loss function with respect to the given variables.\n        Intended to be called once per GPU.\"\"\"\n        assert not self._updates_applied\n\n        # Validate arguments.\n        if isinstance(trainable_vars, dict):\n            trainable_vars = list(trainable_vars.values())  # allow passing in Network.trainables as vars\n\n        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1\n        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])\n\n        if self._grad_shapes is None:\n            self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]\n\n        assert len(trainable_vars) == len(self._grad_shapes)\n        assert all(tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes))\n\n        dev = loss.device\n\n        assert all(var.device == dev for var in trainable_vars)\n\n        # Register device and compute gradients.\n        with tf.name_scope(self.id + \"_grad\"), tf.device(dev):\n            if dev not in self._dev_opt:\n                opt_name = self.scope.replace(\"/\", \"_\") + \"_opt%d\" % len(self._dev_opt)\n                assert callable(self.optimizer_class)\n                self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, 
**self.optimizer_kwargs)\n                self._dev_grads[dev] = []\n\n            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))\n            grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE)  # disable gating to reduce memory usage\n            grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads]  # replace disconnected gradients with zeros\n            self._dev_grads[dev].append(grads)\n\n    def apply_updates(self) -> tf.Operation:\n        \"\"\"Construct training op to update the registered variables based on their gradients.\"\"\"\n        tfutil.assert_tf_initialized()\n        assert not self._updates_applied\n        self._updates_applied = True\n        devices = list(self._dev_grads.keys())\n        total_grads = sum(len(grads) for grads in self._dev_grads.values())\n        assert len(devices) >= 1 and total_grads >= 1\n        ops = []\n\n        with tfutil.absolute_name_scope(self.scope):\n            # Cast gradients to FP32 and calculate partial sum within each device.\n            dev_grads = OrderedDict()  # device => [(grad, var), ...]\n\n            for dev_idx, dev in enumerate(devices):\n                with tf.name_scope(\"ProcessGrads%d\" % dev_idx), tf.device(dev):\n                    sums = []\n\n                    for gv in zip(*self._dev_grads[dev]):\n                        assert all(v is gv[0][1] for g, v in gv)\n                        g = [tf.cast(g, tf.float32) for g, v in gv]\n                        g = g[0] if len(g) == 1 else tf.add_n(g)\n                        sums.append((g, gv[0][1]))\n\n                    dev_grads[dev] = sums\n\n            # Sum gradients across devices.\n            if len(devices) > 1:\n                with tf.name_scope(\"SumAcrossGPUs\"), tf.device(None):\n                    for var_idx, grad_shape in enumerate(self._grad_shapes):\n                        g = [dev_grads[dev][var_idx][0] for dev 
in devices]\n\n                        if np.prod(grad_shape):  # nccl does not support zero-sized tensors\n                            g = nccl_ops.all_sum(g)\n\n                        for dev, gg in zip(devices, g):\n                            dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1])\n\n            # Apply updates separately on each device.\n            for dev_idx, (dev, grads) in enumerate(dev_grads.items()):\n                with tf.name_scope(\"ApplyGrads%d\" % dev_idx), tf.device(dev):\n                    # Scale gradients as needed.\n                    if self.use_loss_scaling or total_grads > 1:\n                        with tf.name_scope(\"Scale\"):\n                            coef = tf.constant(np.float32(1.0 / total_grads), name=\"coef\")\n                            coef = self.undo_loss_scaling(coef)\n                            grads = [(g * coef, v) for g, v in grads]\n\n                    # Check for overflows.\n                    with tf.name_scope(\"CheckOverflow\"):\n                        grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads]))\n\n                    # Update weights and adjust loss scaling.\n                    with tf.name_scope(\"UpdateWeights\"):\n                        # pylint: disable=cell-var-from-loop\n                        opt = self._dev_opt[dev]\n                        ls_var = self.get_loss_scaling_var(dev)\n\n                        if not self.use_loss_scaling:\n                            ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op))\n                        else:\n                            ops.append(tf.cond(grad_ok,\n                                               lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),\n                                               lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec))))\n\n                    # Report statistics on the last 
device.\n                    if dev == devices[-1]:\n                        with tf.name_scope(\"Statistics\"):\n                            ops.append(autosummary.autosummary(self.id + \"/learning_rate\", self.learning_rate))\n                            ops.append(autosummary.autosummary(self.id + \"/overflow_frequency\", tf.where(grad_ok, 0, 1)))\n\n                            if self.use_loss_scaling:\n                                ops.append(autosummary.autosummary(self.id + \"/loss_scaling_log2\", ls_var))\n\n            # Initialize variables and group everything into a single op.\n            self.reset_optimizer_state()\n            tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))\n\n            return tf.group(*ops, name=\"TrainingOp\")\n\n    def reset_optimizer_state(self) -> None:\n        \"\"\"Reset internal state of the underlying optimizer.\"\"\"\n        tfutil.assert_tf_initialized()\n        tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()])\n\n    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:\n        \"\"\"Get or create variable representing log2 of the current dynamic loss scaling factor.\"\"\"\n        if not self.use_loss_scaling:\n            return None\n\n        if device not in self._dev_ls_var:\n            with tfutil.absolute_name_scope(self.scope + \"/LossScalingVars\"), tf.control_dependencies(None):\n                self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name=\"loss_scaling_var\")\n\n        return self._dev_ls_var[device]\n\n    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:\n        \"\"\"Apply dynamic loss scaling for the given expression.\"\"\"\n        assert tfutil.is_tf_expression(value)\n\n        if not self.use_loss_scaling:\n            return value\n\n        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))\n\n    def undo_loss_scaling(self, value: 
TfExpression) -> TfExpression:\n        \"\"\"Undo the effect of dynamic loss scaling for the given expression.\"\"\"\n        assert tfutil.is_tf_expression(value)\n\n        if not self.use_loss_scaling:\n            return value\n\n        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type\n"
  },
  {
    "path": "dnnlib/tflib/tfutil.py",
    "content": "﻿# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Miscellaneous helper utils for Tensorflow.\"\"\"\n\nimport os\nimport numpy as np\nimport tensorflow as tf\n\nfrom typing import Any, Iterable, List, Union\n\nTfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]\n\"\"\"A type that represents a valid Tensorflow expression.\"\"\"\n\nTfExpressionEx = Union[TfExpression, int, float, np.ndarray]\n\"\"\"A type that can be converted to a valid Tensorflow expression.\"\"\"\n\n\ndef run(*args, **kwargs) -> Any:\n    \"\"\"Run the specified ops in the default session.\"\"\"\n    assert_tf_initialized()\n    return tf.get_default_session().run(*args, **kwargs)\n\n\ndef is_tf_expression(x: Any) -> bool:\n    \"\"\"Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation.\"\"\"\n    return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))\n\n\ndef shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:\n    \"\"\"Convert a Tensorflow shape to a list of ints.\"\"\"\n    return [dim.value for dim in shape]\n\n\ndef flatten(x: TfExpressionEx) -> TfExpression:\n    \"\"\"Shortcut function for flattening a tensor.\"\"\"\n    with tf.name_scope(\"Flatten\"):\n        return tf.reshape(x, [-1])\n\n\ndef log2(x: TfExpressionEx) -> TfExpression:\n    \"\"\"Logarithm in base 2.\"\"\"\n    with tf.name_scope(\"Log2\"):\n        return tf.log(x) * np.float32(1.0 / np.log(2.0))\n\n\ndef exp2(x: TfExpressionEx) -> TfExpression:\n    \"\"\"Exponent in base 2.\"\"\"\n    with tf.name_scope(\"Exp2\"):\n        return tf.exp(x * np.float32(np.log(2.0)))\n\n\ndef lerp(a: TfExpressionEx, b: TfExpressionEx, t: 
TfExpressionEx) -> TfExpressionEx:\n    \"\"\"Linear interpolation.\"\"\"\n    with tf.name_scope(\"Lerp\"):\n        return a + (b - a) * t\n\n\ndef lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:\n    \"\"\"Linear interpolation with clip.\"\"\"\n    with tf.name_scope(\"LerpClip\"):\n        return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)\n\n\ndef absolute_name_scope(scope: str) -> tf.name_scope:\n    \"\"\"Forcefully enter the specified name scope, ignoring any surrounding scopes.\"\"\"\n    return tf.name_scope(scope + \"/\")\n\n\ndef absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:\n    \"\"\"Forcefully enter the specified variable scope, ignoring any surrounding scopes.\"\"\"\n    return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)\n\n\ndef _sanitize_tf_config(config_dict: dict = None) -> dict:\n    # Defaults.\n    cfg = dict()\n    cfg[\"rnd.np_random_seed\"]               = None      # Random seed for NumPy. None = keep as is.\n    cfg[\"rnd.tf_random_seed\"]               = \"auto\"    # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.\n    cfg[\"env.TF_CPP_MIN_LOG_LEVEL\"]         = \"1\"       # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.\n    cfg[\"graph_options.place_pruned_graph\"] = True      # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.\n    cfg[\"gpu_options.allow_growth\"]         = True      # False = Allocate all GPU memory at the beginning. 
True = Allocate only as much GPU memory as needed.\n\n    # User overrides.\n    if config_dict is not None:\n        cfg.update(config_dict)\n    return cfg\n\n\ndef init_tf(config_dict: dict = None) -> None:\n    \"\"\"Initialize TensorFlow session using good default settings.\"\"\"\n    # Skip if already initialized.\n    if tf.get_default_session() is not None:\n        return\n\n    # Setup config dict and random seeds.\n    cfg = _sanitize_tf_config(config_dict)\n    np_random_seed = cfg[\"rnd.np_random_seed\"]\n    if np_random_seed is not None:\n        np.random.seed(np_random_seed)\n    tf_random_seed = cfg[\"rnd.tf_random_seed\"]\n    if tf_random_seed == \"auto\":\n        tf_random_seed = np.random.randint(1 << 31)\n    if tf_random_seed is not None:\n        tf.set_random_seed(tf_random_seed)\n\n    # Setup environment variables.\n    for key, value in list(cfg.items()):\n        fields = key.split(\".\")\n        if fields[0] == \"env\":\n            assert len(fields) == 2\n            os.environ[fields[1]] = str(value)\n\n    # Create default TensorFlow session.\n    create_session(cfg, force_as_default=True)\n\n\ndef assert_tf_initialized():\n    \"\"\"Check that TensorFlow session has been initialized.\"\"\"\n    if tf.get_default_session() is None:\n        raise RuntimeError(\"No default TensorFlow session found. 
Please call dnnlib.tflib.init_tf().\")\n\n\ndef create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:\n    \"\"\"Create tf.Session based on config dict.\"\"\"\n    # Setup TensorFlow config proto.\n    cfg = _sanitize_tf_config(config_dict)\n    config_proto = tf.ConfigProto()\n    for key, value in cfg.items():\n        fields = key.split(\".\")\n        if fields[0] not in [\"rnd\", \"env\"]:\n            obj = config_proto\n            for field in fields[:-1]:\n                obj = getattr(obj, field)\n            setattr(obj, fields[-1], value)\n\n    # Create session.\n    session = tf.Session(config=config_proto)\n    if force_as_default:\n        # pylint: disable=protected-access\n        session._default_session = session.as_default()\n        session._default_session.enforce_nesting = False\n        session._default_session.__enter__() # pylint: disable=no-member\n\n    return session\n\n\ndef init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:\n    \"\"\"Initialize all tf.Variables that have not already been initialized.\n\n    Equivalent to the following, but more efficient and does not bloat the tf graph:\n    tf.variables_initializer(tf.report_uninitialized_variables()).run()\n    \"\"\"\n    assert_tf_initialized()\n    if target_vars is None:\n        target_vars = tf.global_variables()\n\n    test_vars = []\n    test_ops = []\n\n    with tf.control_dependencies(None):  # ignore surrounding control_dependencies\n        for var in target_vars:\n            assert is_tf_expression(var)\n\n            try:\n                tf.get_default_graph().get_tensor_by_name(var.name.replace(\":0\", \"/IsVariableInitialized:0\"))\n            except KeyError:\n                # Op does not exist => variable may be uninitialized.\n                test_vars.append(var)\n\n                with absolute_name_scope(var.name.split(\":\")[0]):\n                    
test_ops.append(tf.is_variable_initialized(var))\n\n    init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]\n    run([var.initializer for var in init_vars])\n\n\ndef set_vars(var_to_value_dict: dict) -> None:\n    \"\"\"Set the values of given tf.Variables.\n\n    Equivalent to the following, but more efficient and does not bloat the tf graph:\n    tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])\n    \"\"\"\n    assert_tf_initialized()\n    ops = []\n    feed_dict = {}\n\n    for var, value in var_to_value_dict.items():\n        assert is_tf_expression(var)\n\n        try:\n            setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(\":0\", \"/setter:0\"))  # look for existing op\n        except KeyError:\n            with absolute_name_scope(var.name.split(\":\")[0]):\n                with tf.control_dependencies(None):  # ignore surrounding control_dependencies\n                    setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, \"new_value\"), name=\"setter\")  # create new setter\n\n        ops.append(setter)\n        feed_dict[setter.op.inputs[1]] = value\n\n    run(ops, feed_dict)\n\n\ndef create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):\n    \"\"\"Create tf.Variable with large initial value without bloating the tf graph.\"\"\"\n    assert_tf_initialized()\n    assert isinstance(initial_value, np.ndarray)\n    zeros = tf.zeros(initial_value.shape, initial_value.dtype)\n    var = tf.Variable(zeros, *args, **kwargs)\n    set_vars({var: initial_value})\n    return var\n\n\ndef convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):\n    \"\"\"Convert a minibatch of images from uint8 to float32 with configurable dynamic range.\n    Can be used as an input transformation for Network.run().\n    \"\"\"\n    images = tf.cast(images, tf.float32)\n    if nhwc_to_nchw:\n        images = tf.transpose(images, [0, 3, 1, 2])\n
    return images * ((drange[1] - drange[0]) / 255) + drange[0]\n\n\ndef convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):\n    \"\"\"Convert a minibatch of images from float32 to uint8 with configurable dynamic range.\n    Can be used as an output transformation for Network.run().\n    \"\"\"\n    images = tf.cast(images, tf.float32)\n    if shrink > 1:\n        ksize = [1, 1, shrink, shrink]\n        images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding=\"VALID\", data_format=\"NCHW\")\n    if nchw_to_nhwc:\n        images = tf.transpose(images, [0, 2, 3, 1])\n    scale = 255 / (drange[1] - drange[0])\n    images = images * scale + (0.5 - drange[0] * scale)\n    return tf.saturate_cast(images, tf.uint8)\n"
  },
  {
    "path": "dnnlib/util.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Miscellaneous utility classes and functions.\"\"\"\n\nimport ctypes\nimport fnmatch\nimport importlib\nimport inspect\nimport numpy as np\nimport os\nimport shutil\nimport sys\nimport types\nimport io\nimport pickle\nimport re\nimport requests\nimport html\nimport hashlib\nimport glob\nimport uuid\n\nfrom distutils.util import strtobool\nfrom typing import Any, List, Tuple, Union\n\n\n# Util classes\n# ------------------------------------------------------------------------------------------\n\n\nclass EasyDict(dict):\n    \"\"\"Convenience class that behaves like a dict but allows access with the attribute syntax.\"\"\"\n\n    def __getattr__(self, name: str) -> Any:\n        try:\n            return self[name]\n        except KeyError:\n            raise AttributeError(name)\n\n    def __setattr__(self, name: str, value: Any) -> None:\n        self[name] = value\n\n    def __delattr__(self, name: str) -> None:\n        del self[name]\n\n\nclass Logger(object):\n    \"\"\"Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.\"\"\"\n\n    def __init__(self, file_name: str = None, file_mode: str = \"w\", should_flush: bool = True):\n        self.file = None\n\n        if file_name is not None:\n            self.file = open(file_name, file_mode)\n\n        self.should_flush = should_flush\n        self.stdout = sys.stdout\n        self.stderr = sys.stderr\n\n        sys.stdout = self\n        sys.stderr = self\n\n    def __enter__(self) -> \"Logger\":\n        return self\n\n    def __exit__(self, exc_type: Any, exc_value: Any, 
traceback: Any) -> None:\n        self.close()\n\n    def write(self, text: str) -> None:\n        \"\"\"Write text to stdout (and a file) and optionally flush.\"\"\"\n        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash\n            return\n\n        if self.file is not None:\n            self.file.write(text)\n\n        self.stdout.write(text)\n\n        if self.should_flush:\n            self.flush()\n\n    def flush(self) -> None:\n        \"\"\"Flush written text to both stdout and a file, if open.\"\"\"\n        if self.file is not None:\n            self.file.flush()\n\n        self.stdout.flush()\n\n    def close(self) -> None:\n        \"\"\"Flush, close possible files, and remove stdout/stderr mirroring.\"\"\"\n        self.flush()\n\n        # if using multiple loggers, prevent closing in wrong order\n        if sys.stdout is self:\n            sys.stdout = self.stdout\n        if sys.stderr is self:\n            sys.stderr = self.stderr\n\n        if self.file is not None:\n            self.file.close()\n\n\n# Small util functions\n# ------------------------------------------------------------------------------------------\n\n\ndef format_time(seconds: Union[int, float]) -> str:\n    \"\"\"Convert the seconds to human readable string with days, hours, minutes and seconds.\"\"\"\n    s = int(np.rint(seconds))\n\n    if s < 60:\n        return \"{0}s\".format(s)\n    elif s < 60 * 60:\n        return \"{0}m {1:02}s\".format(s // 60, s % 60)\n    elif s < 24 * 60 * 60:\n        return \"{0}h {1:02}m {2:02}s\".format(s // (60 * 60), (s // 60) % 60, s % 60)\n    else:\n        return \"{0}d {1:02}h {2:02}m\".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)\n\n\ndef ask_yes_no(question: str) -> bool:\n    \"\"\"Ask the user the question until the user inputs a valid answer.\"\"\"\n    while True:\n        try:\n            print(\"{0} [y/n]\".format(question))\n            
return strtobool(input().lower())\n        except ValueError:\n            pass\n\n\ndef tuple_product(t: Tuple) -> Any:\n    \"\"\"Calculate the product of the tuple elements.\"\"\"\n    result = 1\n\n    for v in t:\n        result *= v\n\n    return result\n\n\n_str_to_ctype = {\n    \"uint8\": ctypes.c_ubyte,\n    \"uint16\": ctypes.c_uint16,\n    \"uint32\": ctypes.c_uint32,\n    \"uint64\": ctypes.c_uint64,\n    \"int8\": ctypes.c_byte,\n    \"int16\": ctypes.c_int16,\n    \"int32\": ctypes.c_int32,\n    \"int64\": ctypes.c_int64,\n    \"float32\": ctypes.c_float,\n    \"float64\": ctypes.c_double\n}\n\n\ndef get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:\n    \"\"\"Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.\"\"\"\n    type_str = None\n\n    if isinstance(type_obj, str):\n        type_str = type_obj\n    elif hasattr(type_obj, \"__name__\"):\n        type_str = type_obj.__name__\n    elif hasattr(type_obj, \"name\"):\n        type_str = type_obj.name\n    else:\n        raise RuntimeError(\"Cannot infer type name from input\")\n\n    assert type_str in _str_to_ctype.keys()\n\n    my_dtype = np.dtype(type_str)\n    my_ctype = _str_to_ctype[type_str]\n\n    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)\n\n    return my_dtype, my_ctype\n\n\ndef is_pickleable(obj: Any) -> bool:\n    try:\n        with io.BytesIO() as stream:\n            pickle.dump(obj, stream)\n        return True\n    except:\n        return False\n\n\n# Functionality to import modules/objects by name, and call functions by name\n# ------------------------------------------------------------------------------------------\n\ndef get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:\n    \"\"\"Searches for the underlying module behind the name to some python object.\n    Returns the module and the object name (original name with module part removed).\"\"\"\n\n  
  # allow convenience shorthands, substitute them by full names\n    obj_name = re.sub(\"^np.\", \"numpy.\", obj_name)\n    obj_name = re.sub(\"^tf.\", \"tensorflow.\", obj_name)\n\n    # list alternatives for (module_name, local_obj_name)\n    parts = obj_name.split(\".\")\n    name_pairs = [(\".\".join(parts[:i]), \".\".join(parts[i:])) for i in range(len(parts), 0, -1)]\n\n    # try each alternative in turn\n    for module_name, local_obj_name in name_pairs:\n        try:\n            module = importlib.import_module(module_name) # may raise ImportError\n            get_obj_from_module(module, local_obj_name) # may raise AttributeError\n            return module, local_obj_name\n        except:\n            pass\n\n    # maybe some of the modules themselves contain errors?\n    for module_name, _local_obj_name in name_pairs:\n        try:\n            importlib.import_module(module_name) # may raise ImportError\n        except ImportError:\n            if not str(sys.exc_info()[1]).startswith(\"No module named '\" + module_name + \"'\"):\n                raise\n\n    # maybe the requested attribute is missing?\n    for module_name, local_obj_name in name_pairs:\n        try:\n            module = importlib.import_module(module_name) # may raise ImportError\n            get_obj_from_module(module, local_obj_name) # may raise AttributeError\n        except ImportError:\n            pass\n\n    # we are out of luck, but we have no idea why\n    raise ImportError(obj_name)\n\n\ndef get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:\n    \"\"\"Traverses the object name and returns the last (rightmost) python object.\"\"\"\n    if obj_name == '':\n        return module\n    obj = module\n    for part in obj_name.split(\".\"):\n        obj = getattr(obj, part)\n    return obj\n\n\ndef get_obj_by_name(name: str) -> Any:\n    \"\"\"Finds the python object with the given name.\"\"\"\n    module, obj_name = get_module_from_obj_name(name)\n    return 
get_obj_from_module(module, obj_name)\n\n\ndef call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:\n    \"\"\"Finds the python object with the given name and calls it as a function.\"\"\"\n    assert func_name is not None\n    func_obj = get_obj_by_name(func_name)\n    assert callable(func_obj)\n    return func_obj(*args, **kwargs)\n\n\ndef get_module_dir_by_obj_name(obj_name: str) -> str:\n    \"\"\"Get the directory path of the module containing the given object name.\"\"\"\n    module, _ = get_module_from_obj_name(obj_name)\n    return os.path.dirname(inspect.getfile(module))\n\n\ndef is_top_level_function(obj: Any) -> bool:\n    \"\"\"Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.\"\"\"\n    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__\n\n\ndef get_top_level_function_name(obj: Any) -> str:\n    \"\"\"Return the fully-qualified name of a top-level function.\"\"\"\n    assert is_top_level_function(obj)\n    return obj.__module__ + \".\" + obj.__name__\n\n\n# File system helpers\n# ------------------------------------------------------------------------------------------\n\ndef list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:\n    \"\"\"List all files recursively in a given directory while ignoring given file and directory names.\n    Returns list of tuples containing both absolute and relative paths.\"\"\"\n    assert os.path.isdir(dir_path)\n    base_name = os.path.basename(os.path.normpath(dir_path))\n\n    if ignores is None:\n        ignores = []\n\n    result = []\n\n    for root, dirs, files in os.walk(dir_path, topdown=True):\n        for ignore_ in ignores:\n            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]\n\n            # dirs need to be edited in-place\n            for d in dirs_to_remove:\n                dirs.remove(d)\n\n            
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]\n\n        absolute_paths = [os.path.join(root, f) for f in files]\n        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]\n\n        if add_base_to_relative:\n            relative_paths = [os.path.join(base_name, p) for p in relative_paths]\n\n        assert len(absolute_paths) == len(relative_paths)\n        result += zip(absolute_paths, relative_paths)\n\n    return result\n\n\ndef copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:\n    \"\"\"Takes in a list of tuples of (src, dst) paths and copies files.\n    Will create all necessary directories.\"\"\"\n    for file in files:\n        target_dir_name = os.path.dirname(file[1])\n\n        # will create all intermediate-level directories\n        if not os.path.exists(target_dir_name):\n            os.makedirs(target_dir_name)\n\n        shutil.copyfile(file[0], file[1])\n\n\n# URL helpers\n# ------------------------------------------------------------------------------------------\n\ndef is_url(obj: Any) -> bool:\n    \"\"\"Determine whether the given object is a valid URL string.\"\"\"\n    if not isinstance(obj, str) or not \"://\" in obj:\n        return False\n    try:\n        res = requests.compat.urlparse(obj)\n        if not res.scheme or not res.netloc or not \".\" in res.netloc:\n            return False\n        res = requests.compat.urlparse(requests.compat.urljoin(obj, \"/\"))\n        if not res.scheme or not res.netloc or not \".\" in res.netloc:\n            return False\n    except:\n        return False\n    return True\n\n\ndef open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any:\n    \"\"\"Download the given URL and return a binary-mode file object to access the data.\"\"\"\n    assert is_url(url)\n    assert num_attempts >= 1\n\n    # Lookup from cache.\n    url_md5 = hashlib.md5(url.encode(\"utf-8\")).hexdigest()\n    if cache_dir is not None:\n   
     cache_files = glob.glob(os.path.join(cache_dir, url_md5 + \"_*\"))\n        if len(cache_files) == 1:\n            return open(cache_files[0], \"rb\")\n\n    # Download.\n    url_name = None\n    url_data = None\n    with requests.Session() as session:\n        if verbose:\n            print(\"Downloading %s ...\" % url, end=\"\", flush=True)\n        for attempts_left in reversed(range(num_attempts)):\n            try:\n                with session.get(url) as res:\n                    res.raise_for_status()\n                    if len(res.content) == 0:\n                        raise IOError(\"No data received\")\n\n                    if len(res.content) < 8192:\n                        content_str = res.content.decode(\"utf-8\")\n                        if \"download_warning\" in res.headers.get(\"Set-Cookie\", \"\"):\n                            links = [html.unescape(link) for link in content_str.split('\"') if \"export=download\" in link]\n                            if len(links) == 1:\n                                url = requests.compat.urljoin(url, links[0])\n                                raise IOError(\"Google Drive virus checker nag\")\n                        if \"Google Drive - Quota exceeded\" in content_str:\n                            raise IOError(\"Google Drive quota exceeded\")\n\n                    match = re.search(r'filename=\"([^\"]*)\"', res.headers.get(\"Content-Disposition\", \"\"))\n                    url_name = match[1] if match else url\n                    url_data = res.content\n                    if verbose:\n                        print(\" done\")\n                    break\n            except:\n                if not attempts_left:\n                    if verbose:\n                        print(\" failed\")\n                    raise\n                if verbose:\n                    print(\".\", end=\"\", flush=True)\n\n    # Save to cache.\n    if cache_dir is not None:\n        safe_name = 
re.sub(r\"[^0-9a-zA-Z-._]\", \"_\", url_name)\n        cache_file = os.path.join(cache_dir, url_md5 + \"_\" + safe_name)\n        temp_file = os.path.join(cache_dir, \"tmp_\" + uuid.uuid4().hex + \"_\" + url_md5 + \"_\" + safe_name)\n        os.makedirs(cache_dir, exist_ok=True)\n        with open(temp_file, \"wb\") as f:\n            f.write(url_data)\n        os.replace(temp_file, cache_file) # atomic\n\n    # Return data as file object.\n    return io.BytesIO(url_data)\n"
  },
  {
    "path": "generate_figures.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Minimal script for reproducing the figures of the StyleGAN paper using pre-trained generators.\"\"\"\n\nimport os\nimport pickle\nimport numpy as np\nimport PIL.Image\nimport dnnlib\nimport dnnlib.tflib as tflib\nimport config\n\n#----------------------------------------------------------------------------\n# Helpers for loading and using pre-trained generators.\n\nurl_ffhq        = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl\nurl_celebahq    = 'https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf' # karras2019stylegan-celebahq-1024x1024.pkl\nurl_bedrooms    = 'https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF' # karras2019stylegan-bedrooms-256x256.pkl\nurl_cars        = 'https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3' # karras2019stylegan-cars-512x384.pkl\nurl_cats        = 'https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ' # karras2019stylegan-cats-256x256.pkl\n\nsynthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8)\n\n_Gs_cache = dict()\n\ndef load_Gs(url):\n    if url not in _Gs_cache:\n        with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:\n            _G, _D, Gs = pickle.load(f)\n        _Gs_cache[url] = Gs\n    return _Gs_cache[url]\n\n#----------------------------------------------------------------------------\n# Figures 2, 3, 10, 11, 12: Multi-resolution grid of uncurated result images.\n\ndef draw_uncurated_result_figure(png, Gs, cx, cy, cw, ch, rows, lods, seed):\n    print(png)\n    
latents = np.random.RandomState(seed).randn(sum(rows * 2**lod for lod in lods), Gs.input_shape[1])\n    images = Gs.run(latents, None, **synthesis_kwargs) # [seed, y, x, rgb]\n\n    canvas = PIL.Image.new('RGB', (sum(cw // 2**lod for lod in lods), ch * rows), 'white')\n    image_iter = iter(list(images))\n    for col, lod in enumerate(lods):\n        for row in range(rows * 2**lod):\n            image = PIL.Image.fromarray(next(image_iter), 'RGB')\n            image = image.crop((cx, cy, cx + cw, cy + ch))\n            image = image.resize((cw // 2**lod, ch // 2**lod), PIL.Image.ANTIALIAS)\n            canvas.paste(image, (sum(cw // 2**lod for lod in lods[:col]), row * ch // 2**lod))\n    canvas.save(png)\n\n#----------------------------------------------------------------------------\n# Figure 3: Style mixing.\n\ndef draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges):\n    print(png)\n    src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds)\n    dst_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds)\n    src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]\n    dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component]\n    src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)\n    dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)\n\n    canvas = PIL.Image.new('RGB', (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), 'white')\n    for col, src_image in enumerate(list(src_images)):\n        canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0))\n    for row, dst_image in enumerate(list(dst_images)):\n        canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h))\n        row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds))\n        row_dlatents[:, 
style_ranges[row]] = src_dlatents[:, style_ranges[row]]\n        row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)\n        for col, image in enumerate(list(row_images)):\n            canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h))\n    canvas.save(png)\n\n#----------------------------------------------------------------------------\n# Figure 4: Noise detail.\n\ndef draw_noise_detail_figure(png, Gs, w, h, num_samples, seeds):\n    print(png)\n    canvas = PIL.Image.new('RGB', (w * 3, h * len(seeds)), 'white')\n    for row, seed in enumerate(seeds):\n        latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1])] * num_samples)\n        images = Gs.run(latents, None, truncation_psi=1, **synthesis_kwargs)\n        canvas.paste(PIL.Image.fromarray(images[0], 'RGB'), (0, row * h))\n        for i in range(4):\n            crop = PIL.Image.fromarray(images[i + 1], 'RGB')\n            crop = crop.crop((650, 180, 906, 436))\n            crop = crop.resize((w//2, h//2), PIL.Image.NEAREST)\n            canvas.paste(crop, (w + (i%2) * w//2, row * h + (i//2) * h//2))\n        diff = np.std(np.mean(images, axis=3), axis=0) * 4\n        diff = np.clip(diff + 0.5, 0, 255).astype(np.uint8)\n        canvas.paste(PIL.Image.fromarray(diff, 'L'), (w * 2, row * h))\n    canvas.save(png)\n\n#----------------------------------------------------------------------------\n# Figure 5: Noise components.\n\ndef draw_noise_components_figure(png, Gs, w, h, seeds, noise_ranges, flips):\n    print(png)\n    Gsc = Gs.clone()\n    noise_vars = [var for name, var in Gsc.components.synthesis.vars.items() if name.startswith('noise')]\n    noise_pairs = list(zip(noise_vars, tflib.run(noise_vars))) # [(var, val), ...]\n    latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds)\n    all_images = []\n    for noise_range in noise_ranges:\n        tflib.set_vars({var: val * (1 
if i in noise_range else 0) for i, (var, val) in enumerate(noise_pairs)})\n        range_images = Gsc.run(latents, None, truncation_psi=1, randomize_noise=False, **synthesis_kwargs)\n        range_images[flips, :, :] = range_images[flips, :, ::-1]\n        all_images.append(list(range_images))\n\n    canvas = PIL.Image.new('RGB', (w * 2, h * 2), 'white')\n    for col, col_images in enumerate(zip(*all_images)):\n        canvas.paste(PIL.Image.fromarray(col_images[0], 'RGB').crop((0, 0, w//2, h)), (col * w, 0))\n        canvas.paste(PIL.Image.fromarray(col_images[1], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, 0))\n        canvas.paste(PIL.Image.fromarray(col_images[2], 'RGB').crop((0, 0, w//2, h)), (col * w, h))\n        canvas.paste(PIL.Image.fromarray(col_images[3], 'RGB').crop((w//2, 0, w, h)), (col * w + w//2, h))\n    canvas.save(png)\n\n#----------------------------------------------------------------------------\n# Figure 8: Truncation trick.\n\ndef draw_truncation_trick_figure(png, Gs, w, h, seeds, psis):\n    print(png)\n    latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds)\n    dlatents = Gs.components.mapping.run(latents, None) # [seed, layer, component]\n    dlatent_avg = Gs.get_var('dlatent_avg') # [component]\n\n    canvas = PIL.Image.new('RGB', (w * len(psis), h * len(seeds)), 'white')\n    for row, dlatent in enumerate(list(dlatents)):\n        row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(psis, [-1, 1, 1]) + dlatent_avg\n        row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)\n        for col, image in enumerate(list(row_images)):\n            canvas.paste(PIL.Image.fromarray(image, 'RGB'), (col * w, row * h))\n    canvas.save(png)\n\n#----------------------------------------------------------------------------\n# Main program.\n\ndef main():\n    tflib.init_tf()\n    os.makedirs(config.result_dir, exist_ok=True)\n    
draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure02-uncurated-ffhq.png'), load_Gs(url_ffhq), cx=0, cy=0, cw=1024, ch=1024, rows=3, lods=[0,1,2,2,3,3], seed=5)\n    draw_style_mixing_figure(os.path.join(config.result_dir, 'figure03-style-mixing.png'), load_Gs(url_ffhq), w=1024, h=1024, src_seeds=[639,701,687,615,2268], dst_seeds=[888,829,1898,1733,1614,845], style_ranges=[range(0,4)]*3+[range(4,8)]*2+[range(8,18)])\n    draw_noise_detail_figure(os.path.join(config.result_dir, 'figure04-noise-detail.png'), load_Gs(url_ffhq), w=1024, h=1024, num_samples=100, seeds=[1157,1012])\n    draw_noise_components_figure(os.path.join(config.result_dir, 'figure05-noise-components.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[1967,1555], noise_ranges=[range(0, 18), range(0, 0), range(8, 18), range(0, 8)], flips=[1])\n    draw_truncation_trick_figure(os.path.join(config.result_dir, 'figure08-truncation-trick.png'), load_Gs(url_ffhq), w=1024, h=1024, seeds=[91,388], psis=[1, 0.7, 0.5, 0, -0.5, -1])\n    draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure10-uncurated-bedrooms.png'), load_Gs(url_bedrooms), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=0)\n    draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure11-uncurated-cars.png'), load_Gs(url_cars), cx=0, cy=64, cw=512, ch=384, rows=4, lods=[0,1,2,2,3,3], seed=2)\n    draw_uncurated_result_figure(os.path.join(config.result_dir, 'figure12-uncurated-cats.png'), load_Gs(url_cats), cx=0, cy=0, cw=256, ch=256, rows=5, lods=[0,0,1,1,2,2,2], seed=1)\n\n#----------------------------------------------------------------------------\n\nif __name__ == \"__main__\":\n    main()\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "metrics/__init__.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n# empty\n"
  },
  {
    "path": "metrics/frechet_inception_distance.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Frechet Inception Distance (FID).\"\"\"\n\nimport os\nimport numpy as np\nimport scipy\nimport tensorflow as tf\nimport dnnlib.tflib as tflib\n\nfrom metrics import metric_base\nfrom training import misc\n\n#----------------------------------------------------------------------------\n\nclass FID(metric_base.MetricBase):\n    def __init__(self, num_images, minibatch_per_gpu, **kwargs):\n        super().__init__(**kwargs)\n        self.num_images = num_images\n        self.minibatch_per_gpu = minibatch_per_gpu\n\n    def _evaluate(self, Gs, num_gpus):\n        minibatch_size = num_gpus * self.minibatch_per_gpu\n        inception = misc.load_pkl('https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn') # inception_v3_features.pkl\n        activations = np.empty([self.num_images, inception.output_shape[1]], dtype=np.float32)\n\n        # Calculate statistics for reals.\n        cache_file = self._get_cache_file_for_reals(num_images=self.num_images)\n        os.makedirs(os.path.dirname(cache_file), exist_ok=True)\n        if os.path.isfile(cache_file):\n            mu_real, sigma_real = misc.load_pkl(cache_file)\n        else:\n            for idx, images in enumerate(self._iterate_reals(minibatch_size=minibatch_size)):\n                begin = idx * minibatch_size\n                end = min(begin + minibatch_size, self.num_images)\n                activations[begin:end] = inception.run(images[:end-begin], num_gpus=num_gpus, assume_frozen=True)\n                if end == self.num_images:\n                    break\n            mu_real = np.mean(activations, axis=0)\n            sigma_real = 
np.cov(activations, rowvar=False)\n            misc.save_pkl((mu_real, sigma_real), cache_file)\n\n        # Construct TensorFlow graph.\n        result_expr = []\n        for gpu_idx in range(num_gpus):\n            with tf.device('/gpu:%d' % gpu_idx):\n                Gs_clone = Gs.clone()\n                inception_clone = inception.clone()\n                latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])\n                images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True)\n                images = tflib.convert_images_to_uint8(images)\n                result_expr.append(inception_clone.get_output_for(images))\n\n        # Calculate statistics for fakes.\n        for begin in range(0, self.num_images, minibatch_size):\n            end = min(begin + minibatch_size, self.num_images)\n            activations[begin:end] = np.concatenate(tflib.run(result_expr), axis=0)[:end-begin]\n        mu_fake = np.mean(activations, axis=0)\n        sigma_fake = np.cov(activations, rowvar=False)\n\n        # Calculate FID.\n        m = np.square(mu_fake - mu_real).sum()\n        s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False) # pylint: disable=no-member\n        dist = m + np.trace(sigma_fake + sigma_real - 2*s)\n        self._report_result(np.real(dist))\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "metrics/linear_separability.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Linear Separability (LS).\"\"\"\n\nfrom collections import defaultdict\nimport numpy as np\nimport sklearn.svm\nimport tensorflow as tf\nimport dnnlib.tflib as tflib\n\nfrom metrics import metric_base\nfrom training import misc\n\n#----------------------------------------------------------------------------\n\nclassifier_urls = [\n    'https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX', # celebahq-classifier-00-male.pkl\n    'https://drive.google.com/uc?id=1Q5c6HE__ReW2W8qYAXpao68V1ryuisGo', # celebahq-classifier-01-smiling.pkl\n    'https://drive.google.com/uc?id=1Q7738mgWTljPOJQrZtSMLxzShEhrvVsU', # celebahq-classifier-02-attractive.pkl\n    'https://drive.google.com/uc?id=1QBv2Mxe7ZLvOv1YBTLq-T4DS3HjmXV0o', # celebahq-classifier-03-wavy-hair.pkl\n    'https://drive.google.com/uc?id=1QIvKTrkYpUrdA45nf7pspwAqXDwWOLhV', # celebahq-classifier-04-young.pkl\n    'https://drive.google.com/uc?id=1QJPH5rW7MbIjFUdZT7vRYfyUjNYDl4_L', # celebahq-classifier-05-5-o-clock-shadow.pkl\n    'https://drive.google.com/uc?id=1QPZXSYf6cptQnApWS_T83sqFMun3rULY', # celebahq-classifier-06-arched-eyebrows.pkl\n    'https://drive.google.com/uc?id=1QPgoAZRqINXk_PFoQ6NwMmiJfxc5d2Pg', # celebahq-classifier-07-bags-under-eyes.pkl\n    'https://drive.google.com/uc?id=1QQPQgxgI6wrMWNyxFyTLSgMVZmRr1oO7', # celebahq-classifier-08-bald.pkl\n    'https://drive.google.com/uc?id=1QcSphAmV62UrCIqhMGgcIlZfoe8hfWaF', # celebahq-classifier-09-bangs.pkl\n    'https://drive.google.com/uc?id=1QdWTVwljClTFrrrcZnPuPOR4mEuz7jGh', # celebahq-classifier-10-big-lips.pkl\n    
'https://drive.google.com/uc?id=1QgvEWEtr2mS4yj1b_Y3WKe6cLWL3LYmK', # celebahq-classifier-11-big-nose.pkl\n    'https://drive.google.com/uc?id=1QidfMk9FOKgmUUIziTCeo8t-kTGwcT18', # celebahq-classifier-12-black-hair.pkl\n    'https://drive.google.com/uc?id=1QthrJt-wY31GPtV8SbnZQZ0_UEdhasHO', # celebahq-classifier-13-blond-hair.pkl\n    'https://drive.google.com/uc?id=1QvCAkXxdYT4sIwCzYDnCL9Nb5TDYUxGW', # celebahq-classifier-14-blurry.pkl\n    'https://drive.google.com/uc?id=1QvLWuwSuWI9Ln8cpxSGHIciUsnmaw8L0', # celebahq-classifier-15-brown-hair.pkl\n    'https://drive.google.com/uc?id=1QxW6THPI2fqDoiFEMaV6pWWHhKI_OoA7', # celebahq-classifier-16-bushy-eyebrows.pkl\n    'https://drive.google.com/uc?id=1R71xKw8oTW2IHyqmRDChhTBkW9wq4N9v', # celebahq-classifier-17-chubby.pkl\n    'https://drive.google.com/uc?id=1RDn_fiLfEGbTc7JjazRXuAxJpr-4Pl67', # celebahq-classifier-18-double-chin.pkl\n    'https://drive.google.com/uc?id=1RGBuwXbaz5052bM4VFvaSJaqNvVM4_cI', # celebahq-classifier-19-eyeglasses.pkl\n    'https://drive.google.com/uc?id=1RIxOiWxDpUwhB-9HzDkbkLegkd7euRU9', # celebahq-classifier-20-goatee.pkl\n    'https://drive.google.com/uc?id=1RPaNiEnJODdr-fwXhUFdoSQLFFZC7rC-', # celebahq-classifier-21-gray-hair.pkl\n    'https://drive.google.com/uc?id=1RQH8lPSwOI2K_9XQCZ2Ktz7xm46o80ep', # celebahq-classifier-22-heavy-makeup.pkl\n    'https://drive.google.com/uc?id=1RXZM61xCzlwUZKq-X7QhxOg0D2telPow', # celebahq-classifier-23-high-cheekbones.pkl\n    'https://drive.google.com/uc?id=1RgASVHW8EWMyOCiRb5fsUijFu-HfxONM', # celebahq-classifier-24-mouth-slightly-open.pkl\n    'https://drive.google.com/uc?id=1RkC8JLqLosWMaRne3DARRgolhbtg_wnr', # celebahq-classifier-25-mustache.pkl\n    'https://drive.google.com/uc?id=1RqtbtFT2EuwpGTqsTYJDyXdnDsFCPtLO', # celebahq-classifier-26-narrow-eyes.pkl\n    'https://drive.google.com/uc?id=1Rs7hU-re8bBMeRHR-fKgMbjPh-RIbrsh', # celebahq-classifier-27-no-beard.pkl\n    'https://drive.google.com/uc?id=1RynDJQWdGOAGffmkPVCrLJqy_fciPF9E', # 
celebahq-classifier-28-oval-face.pkl\n    'https://drive.google.com/uc?id=1S0TZ_Hdv5cb06NDaCD8NqVfKy7MuXZsN', # celebahq-classifier-29-pale-skin.pkl\n    'https://drive.google.com/uc?id=1S3JPhZH2B4gVZZYCWkxoRP11q09PjCkA', # celebahq-classifier-30-pointy-nose.pkl\n    'https://drive.google.com/uc?id=1S3pQuUz-Jiywq_euhsfezWfGkfzLZ87W', # celebahq-classifier-31-receding-hairline.pkl\n    'https://drive.google.com/uc?id=1S6nyIl_SEI3M4l748xEdTV2vymB_-lrY', # celebahq-classifier-32-rosy-cheeks.pkl\n    'https://drive.google.com/uc?id=1S9P5WCi3GYIBPVYiPTWygrYIUSIKGxbU', # celebahq-classifier-33-sideburns.pkl\n    'https://drive.google.com/uc?id=1SANviG-pp08n7AFpE9wrARzozPIlbfCH', # celebahq-classifier-34-straight-hair.pkl\n    'https://drive.google.com/uc?id=1SArgyMl6_z7P7coAuArqUC2zbmckecEY', # celebahq-classifier-35-wearing-earrings.pkl\n    'https://drive.google.com/uc?id=1SC5JjS5J-J4zXFO9Vk2ZU2DT82TZUza_', # celebahq-classifier-36-wearing-hat.pkl\n    'https://drive.google.com/uc?id=1SDAQWz03HGiu0MSOKyn7gvrp3wdIGoj-', # celebahq-classifier-37-wearing-lipstick.pkl\n    'https://drive.google.com/uc?id=1SEtrVK-TQUC0XeGkBE9y7L8VXfbchyKX', # celebahq-classifier-38-wearing-necklace.pkl\n    'https://drive.google.com/uc?id=1SF_mJIdyGINXoV-I6IAxHB_k5dxiF6M-', # celebahq-classifier-39-wearing-necktie.pkl\n]\n\n#----------------------------------------------------------------------------\n\ndef prob_normalize(p):\n    p = np.asarray(p).astype(np.float32)\n    assert len(p.shape) == 2\n    return p / np.sum(p)\n\ndef mutual_information(p):\n    p = prob_normalize(p)\n    px = np.sum(p, axis=1)\n    py = np.sum(p, axis=0)\n    result = 0.0\n    for x in range(p.shape[0]):\n        p_x = px[x]\n        for y in range(p.shape[1]):\n            p_xy = p[x][y]\n            p_y = py[y]\n            if p_xy > 0.0:\n                result += p_xy * np.log2(p_xy / (p_x * p_y)) # get bits as output\n    return result\n\ndef entropy(p):\n    p = prob_normalize(p)\n    result = 0.0\n    for 
x in range(p.shape[0]):\n        for y in range(p.shape[1]):\n            p_xy = p[x][y]\n            if p_xy > 0.0:\n                result -= p_xy * np.log2(p_xy)\n    return result\n\ndef conditional_entropy(p):\n    # H(Y|X) where X corresponds to axis 0, Y to axis 1\n    # i.e., How many bits of additional information are needed to where we are on axis 1 if we know where we are on axis 0?\n    p = prob_normalize(p)\n    y = np.sum(p, axis=0, keepdims=True) # marginalize to calculate H(Y)\n    return max(0.0, entropy(y) - mutual_information(p)) # can slip just below 0 due to FP inaccuracies, clean those up.\n\n#----------------------------------------------------------------------------\n\nclass LS(metric_base.MetricBase):\n    def __init__(self, num_samples, num_keep, attrib_indices, minibatch_per_gpu, **kwargs):\n        assert num_keep <= num_samples\n        super().__init__(**kwargs)\n        self.num_samples = num_samples\n        self.num_keep = num_keep\n        self.attrib_indices = attrib_indices\n        self.minibatch_per_gpu = minibatch_per_gpu\n\n    def _evaluate(self, Gs, num_gpus):\n        minibatch_size = num_gpus * self.minibatch_per_gpu\n\n        # Construct TensorFlow graph for each GPU.\n        result_expr = []\n        for gpu_idx in range(num_gpus):\n            with tf.device('/gpu:%d' % gpu_idx):\n                Gs_clone = Gs.clone()\n\n                # Generate images.\n                latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])\n                dlatents = Gs_clone.components.mapping.get_output_for(latents, None, is_validation=True)\n                images = Gs_clone.components.synthesis.get_output_for(dlatents, is_validation=True, randomize_noise=True)\n\n                # Downsample to 256x256. 
The attribute classifiers were built for 256x256.\n                if images.shape[2] > 256:\n                    factor = images.shape[2] // 256\n                    images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])\n                    images = tf.reduce_mean(images, axis=[3, 5])\n\n                # Run classifier for each attribute.\n                result_dict = dict(latents=latents, dlatents=dlatents[:,-1])\n                for attrib_idx in self.attrib_indices:\n                    classifier = misc.load_pkl(classifier_urls[attrib_idx])\n                    logits = classifier.get_output_for(images, None)\n                    predictions = tf.nn.softmax(tf.concat([logits, -logits], axis=1))\n                    result_dict[attrib_idx] = predictions\n                result_expr.append(result_dict)\n\n        # Sampling loop.\n        results = []\n        for _ in range(0, self.num_samples, minibatch_size):\n            results += tflib.run(result_expr)\n        results = {key: np.concatenate([value[key] for value in results], axis=0) for key in results[0].keys()}\n\n        # Calculate conditional entropy for each attribute.\n        conditional_entropies = defaultdict(list)\n        for attrib_idx in self.attrib_indices:\n            # Prune the least confident samples.\n            pruned_indices = list(range(self.num_samples))\n            pruned_indices = sorted(pruned_indices, key=lambda i: -np.max(results[attrib_idx][i]))\n            pruned_indices = pruned_indices[:self.num_keep]\n\n            # Fit SVM to the remaining samples.\n            svm_targets = np.argmax(results[attrib_idx][pruned_indices], axis=1)\n            for space in ['latents', 'dlatents']:\n                svm_inputs = results[space][pruned_indices]\n                try:\n                    svm = sklearn.svm.LinearSVC()\n                    svm.fit(svm_inputs, svm_targets)\n                    
svm.score(svm_inputs, svm_targets)\n                    svm_outputs = svm.predict(svm_inputs)\n                except:\n                    svm_outputs = svm_targets # assume perfect prediction\n\n                # Calculate conditional entropy.\n                p = [[np.mean([case == (row, col) for case in zip(svm_outputs, svm_targets)]) for col in (0, 1)] for row in (0, 1)]\n                conditional_entropies[space].append(conditional_entropy(p))\n\n        # Calculate separability scores.\n        scores = {key: 2**np.sum(values) for key, values in conditional_entropies.items()}\n        self._report_result(scores['latents'], suffix='_z')\n        self._report_result(scores['dlatents'], suffix='_w')\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "metrics/metric_base.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Common definitions for GAN metrics.\"\"\"\n\nimport os\nimport time\nimport hashlib\nimport numpy as np\nimport tensorflow as tf\nimport dnnlib\nimport dnnlib.tflib as tflib\n\nimport config\nfrom training import misc\nfrom training import dataset\n\n#----------------------------------------------------------------------------\n# Standard metrics.\n\nfid50k = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID', name='fid50k', num_images=50000, minibatch_per_gpu=8)\nppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zfull', num_samples=100000, epsilon=1e-4, space='z', sampling='full', minibatch_per_gpu=16)\nppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wfull', num_samples=100000, epsilon=1e-4, space='w', sampling='full', minibatch_per_gpu=16)\nppl_zend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zend', num_samples=100000, epsilon=1e-4, space='z', sampling='end', minibatch_per_gpu=16)\nppl_wend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wend', num_samples=100000, epsilon=1e-4, space='w', sampling='end', minibatch_per_gpu=16)\nls = dnnlib.EasyDict(func_name='metrics.linear_separability.LS', name='ls', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4)\ndummy = dnnlib.EasyDict(func_name='metrics.metric_base.DummyMetric', name='dummy') # for debugging\n\n#----------------------------------------------------------------------------\n# Base class for metrics.\n\nclass MetricBase:\n    def __init__(self, name):\n        
self.name = name\n        self._network_pkl = None\n        self._dataset_args = None\n        self._mirror_augment = None\n        self._results = []\n        self._eval_time = None\n\n    def run(self, network_pkl, run_dir=None, dataset_args=None, mirror_augment=None, num_gpus=1, tf_config=None, log_results=True):\n        self._network_pkl = network_pkl\n        self._dataset_args = dataset_args\n        self._mirror_augment = mirror_augment\n        self._results = []\n\n        if (dataset_args is None or mirror_augment is None) and run_dir is not None:\n            run_config = misc.parse_config_for_previous_run(run_dir)\n            self._dataset_args = dict(run_config['dataset'])\n            self._dataset_args['shuffle_mb'] = 0\n            self._mirror_augment = run_config['train'].get('mirror_augment', False)\n\n        time_begin = time.time()\n        with tf.Graph().as_default(), tflib.create_session(tf_config).as_default(): # pylint: disable=not-context-manager\n            _G, _D, Gs = misc.load_pkl(self._network_pkl)\n            self._evaluate(Gs, num_gpus=num_gpus)\n        self._eval_time = time.time() - time_begin\n\n        if log_results:\n            result_str = self.get_result_str()\n            if run_dir is not None:\n                log = os.path.join(run_dir, 'metric-%s.txt' % self.name)\n                with dnnlib.util.Logger(log, 'a'):\n                    print(result_str)\n            else:\n                print(result_str)\n\n    def get_result_str(self):\n        network_name = os.path.splitext(os.path.basename(self._network_pkl))[0]\n        if len(network_name) > 29:\n            network_name = '...' 
+ network_name[-26:]\n        result_str = '%-30s' % network_name\n        result_str += ' time %-12s' % dnnlib.util.format_time(self._eval_time)\n        for res in self._results:\n            result_str += ' ' + self.name + res.suffix + ' '\n            result_str += res.fmt % res.value\n        return result_str\n\n    def update_autosummaries(self):\n        for res in self._results:\n            tflib.autosummary.autosummary('Metrics/' + self.name + res.suffix, res.value)\n\n    def _evaluate(self, Gs, num_gpus):\n        raise NotImplementedError # to be overridden by subclasses\n\n    def _report_result(self, value, suffix='', fmt='%-10.4f'):\n        self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]\n\n    def _get_cache_file_for_reals(self, extension='pkl', **kwargs):\n        all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment)\n        all_args.update(self._dataset_args)\n        all_args.update(kwargs)\n        md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8'))\n        dataset_name = self._dataset_args['tfrecord_dir'].replace('\\\\', '/').split('/')[-1]\n        return os.path.join(config.cache_dir, '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension))\n\n    def _iterate_reals(self, minibatch_size):\n        dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **self._dataset_args)\n        while True:\n            images, _labels = dataset_obj.get_minibatch_np(minibatch_size)\n            if self._mirror_augment:\n                images = misc.apply_mirror_augment(images)\n            yield images\n\n    def _iterate_fakes(self, Gs, minibatch_size, num_gpus):\n        while True:\n            latents = np.random.randn(minibatch_size, *Gs.input_shape[1:])\n            fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)\n            images = Gs.run(latents, None, output_transform=fmt, is_validation=True, num_gpus=num_gpus, assume_frozen=True)\n 
           yield images\n\n#----------------------------------------------------------------------------\n# Group of multiple metrics.\n\nclass MetricGroup:\n    def __init__(self, metric_kwarg_list):\n        self.metrics = [dnnlib.util.call_func_by_name(**kwargs) for kwargs in metric_kwarg_list]\n\n    def run(self, *args, **kwargs):\n        for metric in self.metrics:\n            metric.run(*args, **kwargs)\n\n    def get_result_str(self):\n        return ' '.join(metric.get_result_str() for metric in self.metrics)\n\n    def update_autosummaries(self):\n        for metric in self.metrics:\n            metric.update_autosummaries()\n\n#----------------------------------------------------------------------------\n# Dummy metric for debugging purposes.\n\nclass DummyMetric(MetricBase):\n    def _evaluate(self, Gs, num_gpus):\n        _ = Gs, num_gpus\n        self._report_result(0.0)\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "metrics/perceptual_path_length.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Perceptual Path Length (PPL).\"\"\"\n\nimport numpy as np\nimport tensorflow as tf\nimport dnnlib.tflib as tflib\n\nfrom metrics import metric_base\nfrom training import misc\n\n#----------------------------------------------------------------------------\n\n# Normalize batch of vectors.\ndef normalize(v):\n    return v / tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True))\n\n# Spherical interpolation of a batch of vectors.\ndef slerp(a, b, t):\n    a = normalize(a)\n    b = normalize(b)\n    d = tf.reduce_sum(a * b, axis=-1, keepdims=True)\n    p = t * tf.math.acos(d)\n    c = normalize(b - d * a)\n    d = a * tf.math.cos(p) + c * tf.math.sin(p)\n    return normalize(d)\n\n#----------------------------------------------------------------------------\n\nclass PPL(metric_base.MetricBase):\n    def __init__(self, num_samples, epsilon, space, sampling, minibatch_per_gpu, **kwargs):\n        assert space in ['z', 'w']\n        assert sampling in ['full', 'end']\n        super().__init__(**kwargs)\n        self.num_samples = num_samples\n        self.epsilon = epsilon\n        self.space = space\n        self.sampling = sampling\n        self.minibatch_per_gpu = minibatch_per_gpu\n\n    def _evaluate(self, Gs, num_gpus):\n        minibatch_size = num_gpus * self.minibatch_per_gpu\n\n        # Construct TensorFlow graph.\n        distance_expr = []\n        for gpu_idx in range(num_gpus):\n            with tf.device('/gpu:%d' % gpu_idx):\n                Gs_clone = Gs.clone()\n                noise_vars = [var for name, var in Gs_clone.components.synthesis.vars.items() if 
name.startswith('noise')]\n\n                # Generate random latents and interpolation t-values.\n                lat_t01 = tf.random_normal([self.minibatch_per_gpu * 2] + Gs_clone.input_shape[1:])\n                lerp_t = tf.random_uniform([self.minibatch_per_gpu], 0.0, 1.0 if self.sampling == 'full' else 0.0)\n\n                # Interpolate in W or Z.\n                if self.space == 'w':\n                    dlat_t01 = Gs_clone.components.mapping.get_output_for(lat_t01, None, is_validation=True)\n                    dlat_t0, dlat_t1 = dlat_t01[0::2], dlat_t01[1::2]\n                    dlat_e0 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis])\n                    dlat_e1 = tflib.lerp(dlat_t0, dlat_t1, lerp_t[:, np.newaxis, np.newaxis] + self.epsilon)\n                    dlat_e01 = tf.reshape(tf.stack([dlat_e0, dlat_e1], axis=1), dlat_t01.shape)\n                else: # space == 'z'\n                    lat_t0, lat_t1 = lat_t01[0::2], lat_t01[1::2]\n                    lat_e0 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis])\n                    lat_e1 = slerp(lat_t0, lat_t1, lerp_t[:, np.newaxis] + self.epsilon)\n                    lat_e01 = tf.reshape(tf.stack([lat_e0, lat_e1], axis=1), lat_t01.shape)\n                    dlat_e01 = Gs_clone.components.mapping.get_output_for(lat_e01, None, is_validation=True)\n\n                # Synthesize images.\n                with tf.control_dependencies([var.initializer for var in noise_vars]): # use same noise inputs for the entire minibatch\n                    images = Gs_clone.components.synthesis.get_output_for(dlat_e01, is_validation=True, randomize_noise=False)\n\n                # Crop only the face region.\n                c = int(images.shape[2] // 8)\n                images = images[:, :, c*3 : c*7, c*2 : c*6]\n\n                # Downsample image to 256x256 if it's larger than that. 
VGG was built for 224x224 images.\n                if images.shape[2] > 256:\n                    factor = images.shape[2] // 256\n                    images = tf.reshape(images, [-1, images.shape[1], images.shape[2] // factor, factor, images.shape[3] // factor, factor])\n                    images = tf.reduce_mean(images, axis=[3,5])\n\n                # Scale dynamic range from [-1,1] to [0,255] for VGG.\n                images = (images + 1) * (255 / 2)\n\n                # Evaluate perceptual distance.\n                img_e0, img_e1 = images[0::2], images[1::2]\n                distance_measure = misc.load_pkl('https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2') # vgg16_zhang_perceptual.pkl\n                distance_expr.append(distance_measure.get_output_for(img_e0, img_e1) * (1 / self.epsilon**2))\n\n        # Sampling loop.\n        all_distances = []\n        for _ in range(0, self.num_samples, minibatch_size):\n            all_distances += tflib.run(distance_expr)\n        all_distances = np.concatenate(all_distances, axis=0)\n\n        # Reject outliers.\n        lo = np.percentile(all_distances, 1, interpolation='lower')\n        hi = np.percentile(all_distances, 99, interpolation='higher')\n        filtered_distances = np.extract(np.logical_and(lo <= all_distances, all_distances <= hi), all_distances)\n        self._report_result(np.mean(filtered_distances))\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "pretrained_example.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Minimal script for generating an image using a pre-trained StyleGAN generator.\"\"\"\n\nimport os\nimport pickle\nimport numpy as np\nimport PIL.Image\nimport dnnlib\nimport dnnlib.tflib as tflib\nimport config\n\ndef main():\n    # Initialize TensorFlow.\n    tflib.init_tf()\n\n    # Load pre-trained network.\n    url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl\n    with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:\n        _G, _D, Gs = pickle.load(f)\n        # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.\n        # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.\n        # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.\n\n    # Print network details.\n    Gs.print_layers()\n\n    # Pick latent vector.\n    rnd = np.random.RandomState(5)\n    latents = rnd.randn(1, Gs.input_shape[1])\n\n    # Generate image.\n    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)\n    images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)\n\n    # Save image.\n    os.makedirs(config.result_dir, exist_ok=True)\n    png_filename = os.path.join(config.result_dir, 'example.png')\n    PIL.Image.fromarray(images[0], 'RGB').save(png_filename)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "run_metrics.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Main entry point for evaluating metrics on StyleGAN and ProGAN networks.\"\"\"\n\nimport dnnlib\nfrom dnnlib import EasyDict\nimport dnnlib.tflib as tflib\n\nimport config\nfrom metrics import metric_base\nfrom training import misc\n\n#----------------------------------------------------------------------------\n\ndef run_pickle(submit_config, metric_args, network_pkl, dataset_args, mirror_augment):\n    ctx = dnnlib.RunContext(submit_config)\n    tflib.init_tf()\n    print('Evaluating %s metric on network_pkl \"%s\"...' % (metric_args.name, network_pkl))\n    metric = dnnlib.util.call_func_by_name(**metric_args)\n    print()\n    metric.run(network_pkl, dataset_args=dataset_args, mirror_augment=mirror_augment, num_gpus=submit_config.num_gpus)\n    print()\n    ctx.close()\n\n#----------------------------------------------------------------------------\n\ndef run_snapshot(submit_config, metric_args, run_id, snapshot):\n    ctx = dnnlib.RunContext(submit_config)\n    tflib.init_tf()\n    print('Evaluating %s metric on run_id %s, snapshot %s...' 
% (metric_args.name, run_id, snapshot))\n    run_dir = misc.locate_run_dir(run_id)\n    network_pkl = misc.locate_network_pkl(run_dir, snapshot)\n    metric = dnnlib.util.call_func_by_name(**metric_args)\n    print()\n    metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)\n    print()\n    ctx.close()\n\n#----------------------------------------------------------------------------\n\ndef run_all_snapshots(submit_config, metric_args, run_id):\n    ctx = dnnlib.RunContext(submit_config)\n    tflib.init_tf()\n    print('Evaluating %s metric on all snapshots of run_id %s...' % (metric_args.name, run_id))\n    run_dir = misc.locate_run_dir(run_id)\n    network_pkls = misc.list_network_pkls(run_dir)\n    metric = dnnlib.util.call_func_by_name(**metric_args)\n    print()\n    for idx, network_pkl in enumerate(network_pkls):\n        ctx.update('', idx, len(network_pkls))\n        metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)\n    print()\n    ctx.close()\n\n#----------------------------------------------------------------------------\n\ndef main():\n    submit_config = dnnlib.SubmitConfig()\n\n    # Which metrics to evaluate?\n    metrics = []\n    metrics += [metric_base.fid50k]\n    #metrics += [metric_base.ppl_zfull]\n    #metrics += [metric_base.ppl_wfull]\n    #metrics += [metric_base.ppl_zend]\n    #metrics += [metric_base.ppl_wend]\n    #metrics += [metric_base.ls]\n    #metrics += [metric_base.dummy]\n\n    # Which networks to evaluate them on?\n    tasks = []\n    tasks += [EasyDict(run_func_name='run_metrics.run_pickle', network_pkl='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ', dataset_args=EasyDict(tfrecord_dir='ffhq', shuffle_mb=0), mirror_augment=True)] # karras2019stylegan-ffhq-1024x1024.pkl\n    #tasks += [EasyDict(run_func_name='run_metrics.run_snapshot', run_id=100, snapshot=25000)]\n    #tasks += [EasyDict(run_func_name='run_metrics.run_all_snapshots', run_id=100)]\n\n    # How many 
GPUs to use?\n    submit_config.num_gpus = 1\n    #submit_config.num_gpus = 2\n    #submit_config.num_gpus = 4\n    #submit_config.num_gpus = 8\n\n    # Execute.\n    submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir)\n    submit_config.run_dir_ignore += config.run_dir_ignore\n    for task in tasks:\n        for metric in metrics:\n            submit_config.run_desc = '%s-%s' % (task.run_func_name, metric.name)\n            if task.run_func_name.endswith('run_snapshot'):\n                submit_config.run_desc += '-%s-%s' % (task.run_id, task.snapshot)\n            if task.run_func_name.endswith('run_all_snapshots'):\n                submit_config.run_desc += '-%s' % task.run_id\n            submit_config.run_desc += '-%dgpu' % submit_config.num_gpus\n            dnnlib.submit_run(submit_config, metric_args=metric, **task)\n\n#----------------------------------------------------------------------------\n\nif __name__ == \"__main__\":\n    main()\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "train.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Main entry point for training StyleGAN and ProGAN networks.\"\"\"\n\nimport copy\nimport dnnlib\nfrom dnnlib import EasyDict\n\nimport config\nfrom metrics import metric_base\n\n#----------------------------------------------------------------------------\n# Official training configs for StyleGAN, targeted mainly for FFHQ.\n\nif 1:\n    desc          = 'sgan'                                                                 # Description string included in result subdir name.\n    train         = EasyDict(run_func_name='training.training_loop.training_loop')         # Options for training loop.\n    G             = EasyDict(func_name='training.networks_stylegan.G_style')               # Options for generator network.\n    D             = EasyDict(func_name='training.networks_stylegan.D_basic')               # Options for discriminator network.\n    G_opt         = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                          # Options for generator optimizer.\n    D_opt         = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                          # Options for discriminator optimizer.\n    G_loss        = EasyDict(func_name='training.loss.G_logistic_nonsaturating')           # Options for generator loss.\n    D_loss        = EasyDict(func_name='training.loss.D_logistic_simplegp', r1_gamma=10.0) # Options for discriminator loss.\n    dataset       = EasyDict()                                                             # Options for load_dataset().\n    sched         = EasyDict()                                                             # Options for TrainingSchedule.\n    grid          = 
EasyDict(size='4k', layout='random')                                   # Options for setup_snapshot_image_grid().\n    metrics       = [metric_base.fid50k]                                                   # Options for MetricGroup.\n    submit_config = dnnlib.SubmitConfig()                                                  # Options for dnnlib.submit_run().\n    tf_config     = {'rnd.np_random_seed': 1000}                                           # Options for tflib.init_tf().\n\n    # Dataset.\n    desc += '-ffhq';     dataset = EasyDict(tfrecord_dir='ffhq');                 train.mirror_augment = True\n    #desc += '-ffhq512';  dataset = EasyDict(tfrecord_dir='ffhq', resolution=512); train.mirror_augment = True\n    #desc += '-ffhq256';  dataset = EasyDict(tfrecord_dir='ffhq', resolution=256); train.mirror_augment = True\n    #desc += '-celebahq'; dataset = EasyDict(tfrecord_dir='celebahq');             train.mirror_augment = True\n    #desc += '-bedroom';  dataset = EasyDict(tfrecord_dir='lsun-bedroom-full');    train.mirror_augment = False\n    #desc += '-car';      dataset = EasyDict(tfrecord_dir='lsun-car-512x384');     train.mirror_augment = False\n    #desc += '-cat';      dataset = EasyDict(tfrecord_dir='lsun-cat-full');        train.mirror_augment = False\n\n    # Number of GPUs.\n    #desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}\n    #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}\n    #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}\n    desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}\n\n    # Default options.\n    train.total_kimg = 25000\n    
sched.lod_initial_resolution = 8\n    sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}\n    sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)\n\n    # WGAN-GP loss for CelebA-HQ.\n    #desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)\n\n    # Table 1.\n    #desc += '-tuned-baseline'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False\n    #desc += '-add-mapping-and-styles'; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False\n    #desc += '-remove-traditional-input'; G.style_mixing_prob = 0.0; G.use_noise = False\n    #desc += '-add-noise-inputs'; G.style_mixing_prob = 0.0\n    #desc += '-mixing-regularization' # default\n\n    # Table 2.\n    #desc += '-mix0'; G.style_mixing_prob = 0.0\n    #desc += '-mix50'; G.style_mixing_prob = 0.5\n    #desc += '-mix90'; G.style_mixing_prob = 0.9 # default\n    #desc += '-mix100'; G.style_mixing_prob = 1.0\n\n    # Table 4.\n    #desc += '-traditional-0'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False\n    #desc += '-traditional-8'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 8; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False\n    #desc += '-stylebased-0'; G.mapping_layers = 0\n    #desc += '-stylebased-1'; G.mapping_layers = 1\n    #desc += '-stylebased-2'; G.mapping_layers = 2\n    #desc += '-stylebased-8'; G.mapping_layers = 8 # 
default\n\n#----------------------------------------------------------------------------\n# Official training configs for Progressive GAN, targeted mainly for CelebA-HQ.\n\nif 0:\n    desc          = 'pgan'                                                         # Description string included in result subdir name.\n    train         = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop.\n    G             = EasyDict(func_name='training.networks_progan.G_paper')         # Options for generator network.\n    D             = EasyDict(func_name='training.networks_progan.D_paper')         # Options for discriminator network.\n    G_opt         = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                  # Options for generator optimizer.\n    D_opt         = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                  # Options for discriminator optimizer.\n    G_loss        = EasyDict(func_name='training.loss.G_wgan')                     # Options for generator loss.\n    D_loss        = EasyDict(func_name='training.loss.D_wgan_gp')                  # Options for discriminator loss.\n    dataset       = EasyDict()                                                     # Options for load_dataset().\n    sched         = EasyDict()                                                     # Options for TrainingSchedule.\n    grid          = EasyDict(size='1080p', layout='random')                        # Options for setup_snapshot_image_grid().\n    metrics       = [metric_base.fid50k]                                           # Options for MetricGroup.\n    submit_config = dnnlib.SubmitConfig()                                          # Options for dnnlib.submit_run().\n    tf_config     = {'rnd.np_random_seed': 1000}                                   # Options for tflib.init_tf().\n\n    # Dataset (choose one).\n    desc += '-celebahq';            dataset = EasyDict(tfrecord_dir='celebahq'); train.mirror_augment = True\n    #desc += 
'-celeba';              dataset = EasyDict(tfrecord_dir='celeba'); train.mirror_augment = True\n    #desc += '-cifar10';             dataset = EasyDict(tfrecord_dir='cifar10')\n    #desc += '-cifar100';            dataset = EasyDict(tfrecord_dir='cifar100')\n    #desc += '-svhn';                dataset = EasyDict(tfrecord_dir='svhn')\n    #desc += '-mnist';               dataset = EasyDict(tfrecord_dir='mnist')\n    #desc += '-mnistrgb';            dataset = EasyDict(tfrecord_dir='mnistrgb')\n    #desc += '-syn1024rgb';          dataset = EasyDict(class_name='training.dataset.SyntheticDataset', resolution=1024, num_channels=3)\n    #desc += '-lsun-airplane';       dataset = EasyDict(tfrecord_dir='lsun-airplane-100k');       train.mirror_augment = True\n    #desc += '-lsun-bedroom';        dataset = EasyDict(tfrecord_dir='lsun-bedroom-100k');        train.mirror_augment = True\n    #desc += '-lsun-bicycle';        dataset = EasyDict(tfrecord_dir='lsun-bicycle-100k');        train.mirror_augment = True\n    #desc += '-lsun-bird';           dataset = EasyDict(tfrecord_dir='lsun-bird-100k');           train.mirror_augment = True\n    #desc += '-lsun-boat';           dataset = EasyDict(tfrecord_dir='lsun-boat-100k');           train.mirror_augment = True\n    #desc += '-lsun-bottle';         dataset = EasyDict(tfrecord_dir='lsun-bottle-100k');         train.mirror_augment = True\n    #desc += '-lsun-bridge';         dataset = EasyDict(tfrecord_dir='lsun-bridge-100k');         train.mirror_augment = True\n    #desc += '-lsun-bus';            dataset = EasyDict(tfrecord_dir='lsun-bus-100k');            train.mirror_augment = True\n    #desc += '-lsun-car';            dataset = EasyDict(tfrecord_dir='lsun-car-100k');            train.mirror_augment = True\n    #desc += '-lsun-cat';            dataset = EasyDict(tfrecord_dir='lsun-cat-100k');            train.mirror_augment = True\n    #desc += '-lsun-chair';          dataset = EasyDict(tfrecord_dir='lsun-chair-100k');      
    train.mirror_augment = True\n    #desc += '-lsun-churchoutdoor';  dataset = EasyDict(tfrecord_dir='lsun-churchoutdoor-100k');  train.mirror_augment = True\n    #desc += '-lsun-classroom';      dataset = EasyDict(tfrecord_dir='lsun-classroom-100k');      train.mirror_augment = True\n    #desc += '-lsun-conferenceroom'; dataset = EasyDict(tfrecord_dir='lsun-conferenceroom-100k'); train.mirror_augment = True\n    #desc += '-lsun-cow';            dataset = EasyDict(tfrecord_dir='lsun-cow-100k');            train.mirror_augment = True\n    #desc += '-lsun-diningroom';     dataset = EasyDict(tfrecord_dir='lsun-diningroom-100k');     train.mirror_augment = True\n    #desc += '-lsun-diningtable';    dataset = EasyDict(tfrecord_dir='lsun-diningtable-100k');    train.mirror_augment = True\n    #desc += '-lsun-dog';            dataset = EasyDict(tfrecord_dir='lsun-dog-100k');            train.mirror_augment = True\n    #desc += '-lsun-horse';          dataset = EasyDict(tfrecord_dir='lsun-horse-100k');          train.mirror_augment = True\n    #desc += '-lsun-kitchen';        dataset = EasyDict(tfrecord_dir='lsun-kitchen-100k');        train.mirror_augment = True\n    #desc += '-lsun-livingroom';     dataset = EasyDict(tfrecord_dir='lsun-livingroom-100k');     train.mirror_augment = True\n    #desc += '-lsun-motorbike';      dataset = EasyDict(tfrecord_dir='lsun-motorbike-100k');      train.mirror_augment = True\n    #desc += '-lsun-person';         dataset = EasyDict(tfrecord_dir='lsun-person-100k');         train.mirror_augment = True\n    #desc += '-lsun-pottedplant';    dataset = EasyDict(tfrecord_dir='lsun-pottedplant-100k');    train.mirror_augment = True\n    #desc += '-lsun-restaurant';     dataset = EasyDict(tfrecord_dir='lsun-restaurant-100k');     train.mirror_augment = True\n    #desc += '-lsun-sheep';          dataset = EasyDict(tfrecord_dir='lsun-sheep-100k');          train.mirror_augment = True\n    #desc += '-lsun-sofa';           dataset = 
EasyDict(tfrecord_dir='lsun-sofa-100k');           train.mirror_augment = True\n    #desc += '-lsun-tower';          dataset = EasyDict(tfrecord_dir='lsun-tower-100k');          train.mirror_augment = True\n    #desc += '-lsun-train';          dataset = EasyDict(tfrecord_dir='lsun-train-100k');          train.mirror_augment = True\n    #desc += '-lsun-tvmonitor';      dataset = EasyDict(tfrecord_dir='lsun-tvmonitor-100k');      train.mirror_augment = True\n\n    # Conditioning & snapshot options.\n    #desc += '-cond'; dataset.max_label_size = 'full' # conditioned on full label\n    #desc += '-cond1'; dataset.max_label_size = 1 # conditioned on first component of the label\n    #desc += '-g4k'; grid.size = '4k'\n    #desc += '-grpc'; grid.layout = 'row_per_class'\n\n    # Config presets (choose one).\n    #desc += '-preset-v1-1gpu'; submit_config.num_gpus = 1; D.mbstd_group_size = 16; sched.minibatch_base = 16; sched.minibatch_dict = {256: 14, 512: 6, 1024: 3}; sched.lod_training_kimg = 800; sched.lod_transition_kimg = 800; train.total_kimg = 19000\n    desc += '-preset-v2-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}; sched.G_lrate_dict = {1024: 0.0015}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000\n    #desc += '-preset-v2-2gpus'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}; sched.G_lrate_dict = {512: 0.0015, 1024: 0.002}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000\n    #desc += '-preset-v2-4gpus'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}; sched.G_lrate_dict = {256: 0.0015, 512: 0.002, 1024: 0.003}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000\n    #desc += '-preset-v2-8gpus'; 
submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}; sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000\n\n    # Numerical precision (choose one).\n    desc += '-fp32'; sched.max_minibatch_per_gpu = {256: 16, 512: 8, 1024: 4}\n    #desc += '-fp16'; G.dtype = 'float16'; D.dtype = 'float16'; G.pixelnorm_epsilon=1e-4; G_opt.use_loss_scaling = True; D_opt.use_loss_scaling = True; sched.max_minibatch_per_gpu = {512: 16, 1024: 8}\n\n    # Disable individual features.\n    #desc += '-nogrowing'; sched.lod_initial_resolution = 1024; sched.lod_training_kimg = 0; sched.lod_transition_kimg = 0; train.total_kimg = 10000\n    #desc += '-nopixelnorm'; G.use_pixelnorm = False\n    #desc += '-nowscale'; G.use_wscale = False; D.use_wscale = False\n    #desc += '-noleakyrelu'; G.use_leakyrelu = False\n    #desc += '-nosmoothing'; train.G_smoothing_kimg = 0.0\n    #desc += '-norepeat'; train.minibatch_repeats = 1\n    #desc += '-noreset'; train.reset_opt_for_new_lod = False\n\n    # Special modes.\n    #desc += '-BENCHMARK'; sched.lod_initial_resolution = 4; sched.lod_training_kimg = 3; sched.lod_transition_kimg = 3; train.total_kimg = (8*2+1)*3; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1000; train.network_snapshot_ticks = 1000\n    #desc += '-BENCHMARK0'; sched.lod_initial_resolution = 1024; train.total_kimg = 10; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1000; train.network_snapshot_ticks = 1000\n    #desc += '-VERBOSE'; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1; train.network_snapshot_ticks = 100\n    #desc += '-GRAPH'; train.save_tf_graph = True\n    #desc += '-HIST'; train.save_weight_histograms = True\n\n#----------------------------------------------------------------------------\n# Main entry 
point for training.\n# Calls the function indicated by 'train' using the selected options.\n\ndef main():\n    kwargs = EasyDict(train)\n    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)\n    kwargs.update(dataset_args=dataset, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)\n    kwargs.submit_config = copy.deepcopy(submit_config)\n    kwargs.submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir)\n    kwargs.submit_config.run_dir_ignore += config.run_dir_ignore\n    kwargs.submit_config.run_desc = desc\n    dnnlib.submit_run(**kwargs)\n\n#----------------------------------------------------------------------------\n\nif __name__ == \"__main__\":\n    main()\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "training/__init__.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n# empty\n"
  },
  {
    "path": "training/dataset.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Multi-resolution input data pipeline.\"\"\"\n\nimport os\nimport glob\nimport numpy as np\nimport tensorflow as tf\nimport dnnlib\nimport dnnlib.tflib as tflib\n\n#----------------------------------------------------------------------------\n# Parse individual image from a tfrecords file.\n\ndef parse_tfrecord_tf(record):\n    features = tf.parse_single_example(record, features={\n        'shape': tf.FixedLenFeature([3], tf.int64),\n        'data': tf.FixedLenFeature([], tf.string)})\n    data = tf.decode_raw(features['data'], tf.uint8)\n    return tf.reshape(data, features['shape'])\n\ndef parse_tfrecord_np(record):\n    ex = tf.train.Example()\n    ex.ParseFromString(record)\n    shape = ex.features.feature['shape'].int64_list.value # temporary pylint workaround # pylint: disable=no-member\n    data = ex.features.feature['data'].bytes_list.value[0] # temporary pylint workaround # pylint: disable=no-member\n    return np.fromstring(data, np.uint8).reshape(shape)\n\n#----------------------------------------------------------------------------\n# Dataset class that loads data from tfrecords files.\n\nclass TFRecordDataset:\n    def __init__(self,\n        tfrecord_dir,               # Directory containing a collection of tfrecords files.\n        resolution      = None,     # Dataset resolution, None = autodetect.\n        label_file      = None,     # Relative path of the labels file, None = autodetect.\n        max_label_size  = 0,        # 0 = no labels, 'full' = full labels, <int> = N first label components.\n        repeat          = True,     # Repeat dataset indefinitely.\n        shuffle_mb     
 = 4096,     # Shuffle data within specified window (megabytes), 0 = disable shuffling.\n        prefetch_mb     = 2048,     # Amount of data to prefetch (megabytes), 0 = disable prefetching.\n        buffer_mb       = 256,      # Read buffer size (megabytes).\n        num_threads     = 2):       # Number of concurrent threads.\n\n        self.tfrecord_dir       = tfrecord_dir\n        self.resolution         = None\n        self.resolution_log2    = None\n        self.shape              = []        # [channel, height, width]\n        self.dtype              = 'uint8'\n        self.dynamic_range      = [0, 255]\n        self.label_file         = label_file\n        self.label_size         = None      # [component]\n        self.label_dtype        = None\n        self._np_labels         = None\n        self._tf_minibatch_in   = None\n        self._tf_labels_var     = None\n        self._tf_labels_dataset = None\n        self._tf_datasets       = dict()\n        self._tf_iterator       = None\n        self._tf_init_ops       = dict()\n        self._tf_minibatch_np   = None\n        self._cur_minibatch     = -1\n        self._cur_lod           = -1\n\n        # List tfrecords files and inspect their shapes.\n        assert os.path.isdir(self.tfrecord_dir)\n        tfr_files = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.tfrecords')))\n        assert len(tfr_files) >= 1\n        tfr_shapes = []\n        for tfr_file in tfr_files:\n            tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)\n            for record in tf.python_io.tf_record_iterator(tfr_file, tfr_opt):\n                tfr_shapes.append(parse_tfrecord_np(record).shape)\n                break\n\n        # Autodetect label filename.\n        if self.label_file is None:\n            guess = sorted(glob.glob(os.path.join(self.tfrecord_dir, '*.labels')))\n            if len(guess):\n                self.label_file = guess[0]\n        elif not 
os.path.isfile(self.label_file):\n            guess = os.path.join(self.tfrecord_dir, self.label_file)\n            if os.path.isfile(guess):\n                self.label_file = guess\n\n        # Determine shape and resolution.\n        max_shape = max(tfr_shapes, key=np.prod)\n        self.resolution = resolution if resolution is not None else max_shape[1]\n        self.resolution_log2 = int(np.log2(self.resolution))\n        self.shape = [max_shape[0], self.resolution, self.resolution]\n        tfr_lods = [self.resolution_log2 - int(np.log2(shape[1])) for shape in tfr_shapes]\n        assert all(shape[0] == max_shape[0] for shape in tfr_shapes)\n        assert all(shape[1] == shape[2] for shape in tfr_shapes)\n        assert all(shape[1] == self.resolution // (2**lod) for shape, lod in zip(tfr_shapes, tfr_lods))\n        assert all(lod in tfr_lods for lod in range(self.resolution_log2 - 1))\n\n        # Load labels.\n        assert max_label_size == 'full' or max_label_size >= 0\n        self._np_labels = np.zeros([1<<20, 0], dtype=np.float32)\n        if self.label_file is not None and max_label_size != 0:\n            self._np_labels = np.load(self.label_file)\n            assert self._np_labels.ndim == 2\n        if max_label_size != 'full' and self._np_labels.shape[1] > max_label_size:\n            self._np_labels = self._np_labels[:, :max_label_size]\n        self.label_size = self._np_labels.shape[1]\n        self.label_dtype = self._np_labels.dtype.name\n\n        # Build TF expressions.\n        with tf.name_scope('Dataset'), tf.device('/cpu:0'):\n            self._tf_minibatch_in = tf.placeholder(tf.int64, name='minibatch_in', shape=[])\n            self._tf_labels_var = tflib.create_var_with_large_initial_value(self._np_labels, name='labels_var')\n            self._tf_labels_dataset = tf.data.Dataset.from_tensor_slices(self._tf_labels_var)\n            for tfr_file, tfr_shape, tfr_lod in zip(tfr_files, tfr_shapes, tfr_lods):\n                if tfr_lod 
< 0:\n                    continue\n                dset = tf.data.TFRecordDataset(tfr_file, compression_type='', buffer_size=buffer_mb<<20)\n                dset = dset.map(parse_tfrecord_tf, num_parallel_calls=num_threads)\n                dset = tf.data.Dataset.zip((dset, self._tf_labels_dataset))\n                bytes_per_item = np.prod(tfr_shape) * np.dtype(self.dtype).itemsize\n                if shuffle_mb > 0:\n                    dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1)\n                if repeat:\n                    dset = dset.repeat()\n                if prefetch_mb > 0:\n                    dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1)\n                dset = dset.batch(self._tf_minibatch_in)\n                self._tf_datasets[tfr_lod] = dset\n            self._tf_iterator = tf.data.Iterator.from_structure(self._tf_datasets[0].output_types, self._tf_datasets[0].output_shapes)\n            self._tf_init_ops = {lod: self._tf_iterator.make_initializer(dset) for lod, dset in self._tf_datasets.items()}\n\n    # Use the given minibatch size and level-of-detail for the data returned by get_minibatch_tf().\n    def configure(self, minibatch_size, lod=0):\n        lod = int(np.floor(lod))\n        assert minibatch_size >= 1 and lod in self._tf_datasets\n        if self._cur_minibatch != minibatch_size or self._cur_lod != lod:\n            self._tf_init_ops[lod].run({self._tf_minibatch_in: minibatch_size})\n            self._cur_minibatch = minibatch_size\n            self._cur_lod = lod\n\n    # Get next minibatch as TensorFlow expressions.\n    def get_minibatch_tf(self): # => images, labels\n        return self._tf_iterator.get_next()\n\n    # Get next minibatch as NumPy arrays.\n    def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels\n        self.configure(minibatch_size, lod)\n        if self._tf_minibatch_np is None:\n            self._tf_minibatch_np = self.get_minibatch_tf()\n       
 return tflib.run(self._tf_minibatch_np)\n\n    # Get random labels as TensorFlow expression.\n    def get_random_labels_tf(self, minibatch_size): # => labels\n        if self.label_size > 0:\n            with tf.device('/cpu:0'):\n                return tf.gather(self._tf_labels_var, tf.random_uniform([minibatch_size], 0, self._np_labels.shape[0], dtype=tf.int32))\n        return tf.zeros([minibatch_size, 0], self.label_dtype)\n\n    # Get random labels as NumPy array.\n    def get_random_labels_np(self, minibatch_size): # => labels\n        if self.label_size > 0:\n            return self._np_labels[np.random.randint(self._np_labels.shape[0], size=[minibatch_size])]\n        return np.zeros([minibatch_size, 0], self.label_dtype)\n\n#----------------------------------------------------------------------------\n# Base class for datasets that are generated on the fly.\n\nclass SyntheticDataset:\n    def __init__(self, resolution=1024, num_channels=3, dtype='uint8', dynamic_range=[0,255], label_size=0, label_dtype='float32'):\n        self.resolution         = resolution\n        self.resolution_log2    = int(np.log2(resolution))\n        self.shape              = [num_channels, resolution, resolution]\n        self.dtype              = dtype\n        self.dynamic_range      = dynamic_range\n        self.label_size         = label_size\n        self.label_dtype        = label_dtype\n        self._tf_minibatch_var  = None\n        self._tf_lod_var        = None\n        self._tf_minibatch_np   = None\n        self._tf_labels_np      = None\n\n        assert self.resolution == 2 ** self.resolution_log2\n        with tf.name_scope('Dataset'):\n            self._tf_minibatch_var = tf.Variable(np.int32(0), name='minibatch_var')\n            self._tf_lod_var = tf.Variable(np.int32(0), name='lod_var')\n\n    def configure(self, minibatch_size, lod=0):\n        lod = int(np.floor(lod))\n        assert minibatch_size >= 1 and 0 <= lod <= self.resolution_log2\n        
tflib.set_vars({self._tf_minibatch_var: minibatch_size, self._tf_lod_var: lod})\n\n    def get_minibatch_tf(self): # => images, labels\n        with tf.name_scope('SyntheticDataset'):\n            shrink = tf.cast(2.0 ** tf.cast(self._tf_lod_var, tf.float32), tf.int32)\n            shape = [self.shape[0], self.shape[1] // shrink, self.shape[2] // shrink]\n            images = self._generate_images(self._tf_minibatch_var, self._tf_lod_var, shape)\n            labels = self._generate_labels(self._tf_minibatch_var)\n            return images, labels\n\n    def get_minibatch_np(self, minibatch_size, lod=0): # => images, labels\n        self.configure(minibatch_size, lod)\n        if self._tf_minibatch_np is None:\n            self._tf_minibatch_np = self.get_minibatch_tf()\n        return tflib.run(self._tf_minibatch_np)\n\n    def get_random_labels_tf(self, minibatch_size): # => labels\n        with tf.name_scope('SyntheticDataset'):\n            return self._generate_labels(minibatch_size)\n\n    def get_random_labels_np(self, minibatch_size): # => labels\n        self.configure(minibatch_size)\n        if self._tf_labels_np is None:\n            self._tf_labels_np = self.get_random_labels_tf(minibatch_size)\n        return tflib.run(self._tf_labels_np)\n\n    def _generate_images(self, minibatch, lod, shape): # to be overridden by subclasses # pylint: disable=unused-argument\n        return tf.zeros([minibatch] + shape, self.dtype)\n\n    def _generate_labels(self, minibatch): # to be overridden by subclasses\n        return tf.zeros([minibatch, self.label_size], self.label_dtype)\n\n#----------------------------------------------------------------------------\n# Helper func for constructing a dataset object using the given options.\n\ndef load_dataset(class_name='training.dataset.TFRecordDataset', data_dir=None, verbose=False, **kwargs):\n    adjusted_kwargs = dict(kwargs)\n    if 'tfrecord_dir' in adjusted_kwargs and data_dir is not None:\n        
adjusted_kwargs['tfrecord_dir'] = os.path.join(data_dir, adjusted_kwargs['tfrecord_dir'])\n    if verbose:\n        print('Streaming data using %s...' % class_name)\n    dataset = dnnlib.util.get_obj_by_name(class_name)(**adjusted_kwargs)\n    if verbose:\n        print('Dataset shape =', np.int32(dataset.shape).tolist())\n        print('Dynamic range =', dataset.dynamic_range)\n        print('Label size    =', dataset.label_size)\n    return dataset\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "training/loss.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Loss functions.\"\"\"\n\nimport tensorflow as tf\nimport dnnlib.tflib as tflib\nfrom dnnlib.tflib.autosummary import autosummary\n\n#----------------------------------------------------------------------------\n# Convenience func that casts all of its arguments to tf.float32.\n\ndef fp32(*values):\n    if len(values) == 1 and isinstance(values[0], tuple):\n        values = values[0]\n    values = tuple(tf.cast(v, tf.float32) for v in values)\n    return values if len(values) >= 2 else values[0]\n\n#----------------------------------------------------------------------------\n# WGAN & WGAN-GP loss functions.\n\ndef G_wgan(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument\n    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])\n    labels = training_set.get_random_labels_tf(minibatch_size)\n    fake_images_out = G.get_output_for(latents, labels, is_training=True)\n    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))\n    loss = -fake_scores_out\n    return loss\n\ndef D_wgan(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument\n    wgan_epsilon = 0.001): # Weight for the epsilon term, \\epsilon_{drift}.\n\n    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])\n    fake_images_out = G.get_output_for(latents, labels, is_training=True)\n    real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True))\n    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))\n    real_scores_out = autosummary('Loss/scores/real', real_scores_out)\n  
  fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)\n    loss = fake_scores_out - real_scores_out\n\n    with tf.name_scope('EpsilonPenalty'):\n        epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out))\n    loss += epsilon_penalty * wgan_epsilon\n    return loss\n\ndef D_wgan_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument\n    wgan_lambda     = 10.0,     # Weight for the gradient penalty term.\n    wgan_epsilon    = 0.001,    # Weight for the epsilon term, \\epsilon_{drift}.\n    wgan_target     = 1.0):     # Target value for gradient magnitudes.\n\n    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])\n    fake_images_out = G.get_output_for(latents, labels, is_training=True)\n    real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True))\n    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))\n    real_scores_out = autosummary('Loss/scores/real', real_scores_out)\n    fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)\n    loss = fake_scores_out - real_scores_out\n\n    with tf.name_scope('GradientPenalty'):\n        mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fake_images_out.dtype)\n        mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors)\n        mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True))\n        mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out)\n        mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out))\n        mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))\n        mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3]))\n        mixed_norms = autosummary('Loss/mixed_norms', mixed_norms)\n        gradient_penalty = tf.square(mixed_norms - wgan_target)\n    
loss += gradient_penalty * (wgan_lambda / (wgan_target**2))\n\n    with tf.name_scope('EpsilonPenalty'):\n        epsilon_penalty = autosummary('Loss/epsilon_penalty', tf.square(real_scores_out))\n    loss += epsilon_penalty * wgan_epsilon\n    return loss\n\n#----------------------------------------------------------------------------\n# Hinge loss functions. (Use G_wgan with these)\n\ndef D_hinge(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument\n    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])\n    fake_images_out = G.get_output_for(latents, labels, is_training=True)\n    real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True))\n    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))\n    real_scores_out = autosummary('Loss/scores/real', real_scores_out)\n    fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)\n    loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out)\n    return loss\n\ndef D_hinge_gp(G, D, opt, training_set, minibatch_size, reals, labels, # pylint: disable=unused-argument\n    wgan_lambda     = 10.0,     # Weight for the gradient penalty term.\n    wgan_target     = 1.0):     # Target value for gradient magnitudes.\n\n    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])\n    fake_images_out = G.get_output_for(latents, labels, is_training=True)\n    real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True))\n    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))\n    real_scores_out = autosummary('Loss/scores/real', real_scores_out)\n    fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)\n    loss = tf.maximum(0., 1.+fake_scores_out) + tf.maximum(0., 1.-real_scores_out)\n\n    with tf.name_scope('GradientPenalty'):\n        mixing_factors = tf.random_uniform([minibatch_size, 1, 1, 1], 0.0, 1.0, 
dtype=fake_images_out.dtype)\n        mixed_images_out = tflib.lerp(tf.cast(reals, fake_images_out.dtype), fake_images_out, mixing_factors)\n        mixed_scores_out = fp32(D.get_output_for(mixed_images_out, labels, is_training=True))\n        mixed_scores_out = autosummary('Loss/scores/mixed', mixed_scores_out)\n        mixed_loss = opt.apply_loss_scaling(tf.reduce_sum(mixed_scores_out))\n        mixed_grads = opt.undo_loss_scaling(fp32(tf.gradients(mixed_loss, [mixed_images_out])[0]))\n        mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1,2,3]))\n        mixed_norms = autosummary('Loss/mixed_norms', mixed_norms)\n        gradient_penalty = tf.square(mixed_norms - wgan_target)\n    loss += gradient_penalty * (wgan_lambda / (wgan_target**2))\n    return loss\n\n\n#----------------------------------------------------------------------------\n# Loss functions advocated by the paper\n# \"Which Training Methods for GANs do actually Converge?\"\n\ndef G_logistic_saturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument\n    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])\n    labels = training_set.get_random_labels_tf(minibatch_size)\n    fake_images_out = G.get_output_for(latents, labels, is_training=True)\n    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))\n    loss = -tf.nn.softplus(fake_scores_out)  # log(1 - logistic(fake_scores_out))\n    return loss\n\ndef G_logistic_nonsaturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument\n    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])\n    labels = training_set.get_random_labels_tf(minibatch_size)\n    fake_images_out = G.get_output_for(latents, labels, is_training=True)\n    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))\n    loss = tf.nn.softplus(-fake_scores_out)  # -log(logistic(fake_scores_out))\n    return loss\n\ndef 
D_logistic(G, D, opt, training_set, minibatch_size, reals, labels): # pylint: disable=unused-argument\n    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])\n    fake_images_out = G.get_output_for(latents, labels, is_training=True)\n    real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True))\n    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))\n    real_scores_out = autosummary('Loss/scores/real', real_scores_out)\n    fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)\n    loss = tf.nn.softplus(fake_scores_out)  # -log(1 - logistic(fake_scores_out))\n    loss += tf.nn.softplus(-real_scores_out)  # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type\n    return loss\n\ndef D_logistic_simplegp(G, D, opt, training_set, minibatch_size, reals, labels, r1_gamma=10.0, r2_gamma=0.0): # pylint: disable=unused-argument\n    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])\n    fake_images_out = G.get_output_for(latents, labels, is_training=True)\n    real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True))\n    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))\n    real_scores_out = autosummary('Loss/scores/real', real_scores_out)\n    fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)\n    loss = tf.nn.softplus(fake_scores_out)  # -log(1 - logistic(fake_scores_out))\n    loss += tf.nn.softplus(-real_scores_out)  # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type\n\n    if r1_gamma != 0.0:\n        with tf.name_scope('R1Penalty'):\n            real_loss = opt.apply_loss_scaling(tf.reduce_sum(real_scores_out))\n            real_grads = opt.undo_loss_scaling(fp32(tf.gradients(real_loss, [reals])[0]))\n            r1_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1,2,3])\n            
r1_penalty = autosummary('Loss/r1_penalty', r1_penalty)\n        loss += r1_penalty * (r1_gamma * 0.5)\n\n    if r2_gamma != 0.0:\n        with tf.name_scope('R2Penalty'):\n            fake_loss = opt.apply_loss_scaling(tf.reduce_sum(fake_scores_out))\n            fake_grads = opt.undo_loss_scaling(fp32(tf.gradients(fake_loss, [fake_images_out])[0]))\n            r2_penalty = tf.reduce_sum(tf.square(fake_grads), axis=[1,2,3])\n            r2_penalty = autosummary('Loss/r2_penalty', r2_penalty)\n        loss += r2_penalty * (r2_gamma * 0.5)\n    return loss\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "training/misc.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Miscellaneous utility functions.\"\"\"\n\nimport os\nimport glob\nimport pickle\nimport re\nimport numpy as np\nfrom collections import defaultdict\nimport PIL.Image\nimport dnnlib\n\nimport config\nfrom training import dataset\n\n#----------------------------------------------------------------------------\n# Convenience wrappers for pickle that are able to load data produced by\n# older versions of the code, and from external URLs.\n\ndef open_file_or_url(file_or_url):\n    if dnnlib.util.is_url(file_or_url):\n        return dnnlib.util.open_url(file_or_url, cache_dir=config.cache_dir)\n    return open(file_or_url, 'rb')\n\ndef load_pkl(file_or_url):\n    with open_file_or_url(file_or_url) as file:\n        return pickle.load(file, encoding='latin1')\n\ndef save_pkl(obj, filename):\n    with open(filename, 'wb') as file:\n        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)\n\n#----------------------------------------------------------------------------\n# Image utils.\n\ndef adjust_dynamic_range(data, drange_in, drange_out):\n    if drange_in != drange_out:\n        scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0]))\n        bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)\n        data = data * scale + bias\n    return data\n\ndef create_image_grid(images, grid_size=None):\n    assert images.ndim == 3 or images.ndim == 4\n    num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2]\n\n    if grid_size is not None:\n        grid_w, grid_h = tuple(grid_size)\n    else:\n        grid_w 
= max(int(np.ceil(np.sqrt(num))), 1)\n        grid_h = max((num - 1) // grid_w + 1, 1)\n\n    grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype)\n    for idx in range(num):\n        x = (idx % grid_w) * img_w\n        y = (idx // grid_w) * img_h\n        grid[..., y : y + img_h, x : x + img_w] = images[idx]\n    return grid\n\ndef convert_to_pil_image(image, drange=[0,1]):\n    assert image.ndim == 2 or image.ndim == 3\n    if image.ndim == 3:\n        if image.shape[0] == 1:\n            image = image[0] # grayscale CHW => HW\n        else:\n            image = image.transpose(1, 2, 0) # CHW -> HWC\n\n    image = adjust_dynamic_range(image, drange, [0,255])\n    image = np.rint(image).clip(0, 255).astype(np.uint8)\n    fmt = 'RGB' if image.ndim == 3 else 'L'\n    return PIL.Image.fromarray(image, fmt)\n\ndef save_image(image, filename, drange=[0,1], quality=95):\n    img = convert_to_pil_image(image, drange)\n    if '.jpg' in filename:\n        img.save(filename,\"JPEG\", quality=quality, optimize=True)\n    else:\n        img.save(filename)\n\ndef save_image_grid(images, filename, drange=[0,1], grid_size=None):\n    convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename)\n\n#----------------------------------------------------------------------------\n# Locating results.\n\ndef locate_run_dir(run_id_or_run_dir):\n    if isinstance(run_id_or_run_dir, str):\n        if os.path.isdir(run_id_or_run_dir):\n            return run_id_or_run_dir\n        converted = dnnlib.submission.submit.convert_path(run_id_or_run_dir)\n        if os.path.isdir(converted):\n            return converted\n\n    run_dir_pattern = re.compile('^0*%s-' % str(run_id_or_run_dir))\n    for search_dir in ['']:\n        full_search_dir = config.result_dir if search_dir == '' else os.path.normpath(os.path.join(config.result_dir, search_dir))\n        run_dir = os.path.join(full_search_dir, str(run_id_or_run_dir))\n        if 
os.path.isdir(run_dir):\n            return run_dir\n        run_dirs = sorted(glob.glob(os.path.join(full_search_dir, '*')))\n        run_dirs = [run_dir for run_dir in run_dirs if run_dir_pattern.match(os.path.basename(run_dir))]\n        run_dirs = [run_dir for run_dir in run_dirs if os.path.isdir(run_dir)]\n        if len(run_dirs) == 1:\n            return run_dirs[0]\n    raise IOError('Cannot locate result subdir for run', run_id_or_run_dir)\n\ndef list_network_pkls(run_id_or_run_dir, include_final=True):\n    run_dir = locate_run_dir(run_id_or_run_dir)\n    pkls = sorted(glob.glob(os.path.join(run_dir, 'network-*.pkl')))\n    if len(pkls) >= 1 and os.path.basename(pkls[0]) == 'network-final.pkl':\n        if include_final:\n            pkls.append(pkls[0])\n        del pkls[0]\n    return pkls\n\ndef locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None):\n    for candidate in [snapshot_or_network_pkl, run_id_or_run_dir_or_network_pkl]:\n        if isinstance(candidate, str):\n            if os.path.isfile(candidate):\n                return candidate\n            converted = dnnlib.submission.submit.convert_path(candidate)\n            if os.path.isfile(converted):\n                return converted\n\n    pkls = list_network_pkls(run_id_or_run_dir_or_network_pkl)\n    if len(pkls) >= 1 and snapshot_or_network_pkl is None:\n        return pkls[-1]\n\n    for pkl in pkls:\n        try:\n            name = os.path.splitext(os.path.basename(pkl))[0]\n            number = int(name.split('-')[-1])\n            if number == snapshot_or_network_pkl:\n                return pkl\n        except ValueError: pass\n        except IndexError: pass\n    raise IOError('Cannot locate network pkl for snapshot', snapshot_or_network_pkl)\n\ndef get_id_string_for_network_pkl(network_pkl):\n    p = network_pkl.replace('.pkl', '').replace('\\\\', '/').split('/')\n    return '-'.join(p[max(len(p) - 2, 
0):])\n\n#----------------------------------------------------------------------------\n# Loading data from previous training runs.\n\ndef load_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl=None):\n    return load_pkl(locate_network_pkl(run_id_or_run_dir_or_network_pkl, snapshot_or_network_pkl))\n\ndef parse_config_for_previous_run(run_id):\n    run_dir = locate_run_dir(run_id)\n\n    # Parse config.txt.\n    cfg = defaultdict(dict)\n    with open(os.path.join(run_dir, 'config.txt'), 'rt') as f:\n        for line in f:\n            line = re.sub(r\"^{?\\s*'(\\w+)':\\s*{(.*)(},|}})$\", r\"\\1 = {\\2}\", line.strip())\n            if line.startswith('dataset =') or line.startswith('train ='):\n                exec(line, cfg, cfg) # pylint: disable=exec-used\n\n    # Handle legacy options.\n    if 'file_pattern' in cfg['dataset']:\n        cfg['dataset']['tfrecord_dir'] = cfg['dataset'].pop('file_pattern').replace('-r??.tfrecords', '')\n    if 'mirror_augment' in cfg['dataset']:\n        cfg['train']['mirror_augment'] = cfg['dataset'].pop('mirror_augment')\n    if 'max_labels' in cfg['dataset']:\n        v = cfg['dataset'].pop('max_labels')\n        if v is None: v = 0\n        if v == 'all': v = 'full'\n        cfg['dataset']['max_label_size'] = v\n    if 'max_images' in cfg['dataset']:\n        cfg['dataset'].pop('max_images')\n    return cfg\n\ndef load_dataset_for_previous_run(run_id, **kwargs): # => dataset_obj, mirror_augment\n    cfg = parse_config_for_previous_run(run_id)\n    cfg['dataset'].update(kwargs)\n    dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **cfg['dataset'])\n    mirror_augment = cfg['train'].get('mirror_augment', False)\n    return dataset_obj, mirror_augment\n\ndef apply_mirror_augment(minibatch):\n    mask = np.random.rand(minibatch.shape[0]) < 0.5\n    minibatch = np.array(minibatch)\n    minibatch[mask] = minibatch[mask, :, :, ::-1]\n    return 
minibatch\n\n#----------------------------------------------------------------------------\n# Size and contents of the image snapshot grids that are exported\n# periodically during training.\n\ndef setup_snapshot_image_grid(G, training_set,\n    size    = '1080p',      # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display.\n    layout  = 'random'):    # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label.\n\n    # Select size.\n    gw = 1; gh = 1\n    if size == '1080p':\n        gw = np.clip(1920 // G.output_shape[3], 3, 32)\n        gh = np.clip(1080 // G.output_shape[2], 2, 32)\n    if size == '4k':\n        gw = np.clip(3840 // G.output_shape[3], 7, 32)\n        gh = np.clip(2160 // G.output_shape[2], 4, 32)\n\n    # Initialize data arrays.\n    reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype)\n    labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype)\n    latents = np.random.randn(gw * gh, *G.input_shape[1:])\n\n    # Random layout.\n    if layout == 'random':\n        reals[:], labels[:] = training_set.get_minibatch_np(gw * gh)\n\n    # Class-conditional layouts.\n    class_layouts = dict(row_per_class=[gw,1], col_per_class=[1,gh], class4x4=[4,4])\n    if layout in class_layouts:\n        bw, bh = class_layouts[layout]\n        nw = (gw - 1) // bw + 1\n        nh = (gh - 1) // bh + 1\n        blocks = [[] for _i in range(nw * nh)]\n        for _iter in range(1000000):\n            real, label = training_set.get_minibatch_np(1)\n            idx = np.argmax(label[0])\n            while idx < len(blocks) and len(blocks[idx]) >= bw * bh:\n                idx += training_set.label_size\n            if idx < len(blocks):\n                blocks[idx].append((real, label))\n                if all(len(block) >= bw * bh for block in blocks):\n                    break\n        for i, block in enumerate(blocks):\n            for j, 
(real, label) in enumerate(block):\n                x = (i %  nw) * bw + j %  bw\n                y = (i // nw) * bh + j // bw\n                if x < gw and y < gh:\n                    reals[x + y * gw] = real[0]\n                    labels[x + y * gw] = label[0]\n\n    return (gw, gh), reals, labels, latents\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "training/networks_progan.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Network architectures used in the ProGAN paper.\"\"\"\n\nimport numpy as np\nimport tensorflow as tf\n\n# NOTE: Do not import any application-specific modules here!\n# Specify all network parameters as kwargs.\n\n#----------------------------------------------------------------------------\n\ndef lerp(a, b, t): return a + (b - a) * t\ndef lerp_clip(a, b, t): return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)\ndef cset(cur_lambda, new_cond, new_lambda): return lambda: tf.cond(new_cond, new_lambda, cur_lambda)\n\n#----------------------------------------------------------------------------\n# Get/create weight tensor for a convolutional or fully-connected layer.\n\ndef get_weight(shape, gain=np.sqrt(2), use_wscale=False):\n    fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out]\n    std = gain / np.sqrt(fan_in) # He init\n    if use_wscale:\n        wscale = tf.constant(np.float32(std), name='wscale')\n        w = tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal()) * wscale\n    else:\n        w = tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0, std))\n    return w\n\n#----------------------------------------------------------------------------\n# Fully-connected layer.\n\ndef dense(x, fmaps, gain=np.sqrt(2), use_wscale=False):\n    if len(x.shape) > 2:\n        x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])])\n    w = get_weight([x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale)\n    w = tf.cast(w, x.dtype)\n    return tf.matmul(x, 
w)\n\n#----------------------------------------------------------------------------\n# Convolutional layer.\n\ndef conv2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False):\n    assert kernel >= 1 and kernel % 2 == 1\n    w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale)\n    w = tf.cast(w, x.dtype)\n    return tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='SAME', data_format='NCHW')\n\n#----------------------------------------------------------------------------\n# Apply bias to the given activation tensor.\n\ndef apply_bias(x):\n    b = tf.get_variable('bias', shape=[x.shape[1]], initializer=tf.initializers.zeros())\n    b = tf.cast(b, x.dtype)\n    if len(x.shape) == 2:\n        return x + b\n    return x + tf.reshape(b, [1, -1, 1, 1])\n\n#----------------------------------------------------------------------------\n# Leaky ReLU activation. Same as tf.nn.leaky_relu, but supports FP16.\n\ndef leaky_relu(x, alpha=0.2):\n    with tf.name_scope('LeakyRelu'):\n        alpha = tf.constant(alpha, dtype=x.dtype, name='alpha')\n        return tf.maximum(x * alpha, x)\n\n#----------------------------------------------------------------------------\n# Nearest-neighbor upscaling layer.\n\ndef upscale2d(x, factor=2):\n    assert isinstance(factor, int) and factor >= 1\n    if factor == 1: return x\n    with tf.variable_scope('Upscale2D'):\n        s = x.shape\n        x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1])\n        x = tf.tile(x, [1, 1, 1, factor, 1, factor])\n        x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor])\n        return x\n\n#----------------------------------------------------------------------------\n# Fused upscale2d + conv2d.\n# Faster and uses less memory than performing the operations separately.\n\ndef upscale2d_conv2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False):\n    assert kernel >= 1 and kernel % 2 == 1\n    w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, 
use_wscale=use_wscale)\n    w = tf.transpose(w, [0, 1, 3, 2]) # [kernel, kernel, fmaps_out, fmaps_in]\n    w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT')\n    w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]])\n    w = tf.cast(w, x.dtype)\n    os = [tf.shape(x)[0], fmaps, x.shape[2] * 2, x.shape[3] * 2]\n    return tf.nn.conv2d_transpose(x, w, os, strides=[1,1,2,2], padding='SAME', data_format='NCHW')\n\n#----------------------------------------------------------------------------\n# Box filter downscaling layer.\n\ndef downscale2d(x, factor=2):\n    assert isinstance(factor, int) and factor >= 1\n    if factor == 1: return x\n    with tf.variable_scope('Downscale2D'):\n        ksize = [1, 1, factor, factor]\n        return tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW') # NOTE: requires tf_config['graph_options.place_pruned_graph'] = True\n\n#----------------------------------------------------------------------------\n# Fused conv2d + downscale2d.\n# Faster and uses less memory than performing the operations separately.\n\ndef conv2d_downscale2d(x, fmaps, kernel, gain=np.sqrt(2), use_wscale=False):\n    assert kernel >= 1 and kernel % 2 == 1\n    w = get_weight([kernel, kernel, x.shape[1].value, fmaps], gain=gain, use_wscale=use_wscale)\n    w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT')\n    w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25\n    w = tf.cast(w, x.dtype)\n    return tf.nn.conv2d(x, w, strides=[1,1,2,2], padding='SAME', data_format='NCHW')\n\n#----------------------------------------------------------------------------\n# Pixelwise feature vector normalization.\n\ndef pixel_norm(x, epsilon=1e-8):\n    with tf.variable_scope('PixelNorm'):\n        return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon)\n\n#----------------------------------------------------------------------------\n# Minibatch standard deviation.\n\ndef 
minibatch_stddev_layer(x, group_size=4, num_new_features=1):\n    with tf.variable_scope('MinibatchStddev'):\n        group_size = tf.minimum(group_size, tf.shape(x)[0])     # Minibatch must be divisible by (or smaller than) group_size.\n        s = x.shape                                             # [NCHW]  Input shape.\n        y = tf.reshape(x, [group_size, -1, num_new_features, s[1]//num_new_features, s[2], s[3]])   # [GMncHW] Split minibatch into M groups of size G. Split channels into n channel groups c.\n        y = tf.cast(y, tf.float32)                              # [GMncHW] Cast to FP32.\n        y -= tf.reduce_mean(y, axis=0, keepdims=True)           # [GMncHW] Subtract mean over group.\n        y = tf.reduce_mean(tf.square(y), axis=0)                # [MncHW]  Calc variance over group.\n        y = tf.sqrt(y + 1e-8)                                   # [MncHW]  Calc stddev over group.\n        y = tf.reduce_mean(y, axis=[2,3,4], keepdims=True)      # [Mn111]  Take average over fmaps and pixels.\n        y = tf.reduce_mean(y, axis=[2])                         # [Mn11]  Drop the averaged channel-group axis.\n        y = tf.cast(y, x.dtype)                                 # [Mn11]  Cast back to original data type.\n        y = tf.tile(y, [group_size, 1, s[2], s[3]])             # [NnHW]  Replicate over group and pixels.\n        return tf.concat([x, y], axis=1)                        # [NCHW]  Append as new fmap.\n\n#----------------------------------------------------------------------------\n# Networks used in the ProgressiveGAN paper.\n\ndef G_paper(\n    latents_in,                         # First input: Latent vectors [minibatch, latent_size].\n    labels_in,                          # Second input: Labels [minibatch, label_size].\n    num_channels        = 1,            # Number of output color channels. Overridden based on dataset.\n    resolution          = 32,           # Output resolution. 
Overridden based on dataset.\n    label_size          = 0,            # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.\n    fmap_base           = 8192,         # Overall multiplier for the number of feature maps.\n    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.\n    fmap_max            = 512,          # Maximum number of feature maps in any layer.\n    latent_size         = None,         # Dimensionality of the latent vectors. None = min(fmap_base, fmap_max).\n    normalize_latents   = True,         # Normalize latent vectors before feeding them to the network?\n    use_wscale          = True,         # Enable equalized learning rate?\n    use_pixelnorm       = True,         # Enable pixelwise feature vector normalization?\n    pixelnorm_epsilon   = 1e-8,         # Constant epsilon for pixelwise feature vector normalization.\n    use_leakyrelu       = True,         # True = leaky ReLU, False = ReLU.\n    dtype               = 'float32',    # Data type to use for activations and outputs.\n    fused_scale         = True,         # True = use fused upscale2d + conv2d, False = separate upscale2d layers.\n    structure           = None,         # 'linear' = human-readable, 'recursive' = efficient, None = select automatically.\n    is_template_graph   = False,        # True = template graph constructed by the Network class, False = actual evaluation.\n    **_kwargs):                         # Ignore unrecognized keyword args.\n\n    resolution_log2 = int(np.log2(resolution))\n    assert resolution == 2**resolution_log2 and resolution >= 4\n    def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)\n    def PN(x): return pixel_norm(x, epsilon=pixelnorm_epsilon) if use_pixelnorm else x\n    if latent_size is None: latent_size = nf(0)\n    if structure is None: structure = 'linear' if is_template_graph else 'recursive'\n    act = leaky_relu if use_leakyrelu else 
tf.nn.relu\n\n    latents_in.set_shape([None, latent_size])\n    labels_in.set_shape([None, label_size])\n    combo_in = tf.cast(tf.concat([latents_in, labels_in], axis=1), dtype)\n    lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype)\n    images_out = None\n\n    # Building blocks.\n    def block(x, res): # res = 2..resolution_log2\n        with tf.variable_scope('%dx%d' % (2**res, 2**res)):\n            if res == 2: # 4x4\n                if normalize_latents: x = pixel_norm(x, epsilon=pixelnorm_epsilon)\n                with tf.variable_scope('Dense'):\n                    x = dense(x, fmaps=nf(res-1)*16, gain=np.sqrt(2)/4, use_wscale=use_wscale) # override gain to match the original Theano implementation\n                    x = tf.reshape(x, [-1, nf(res-1), 4, 4])\n                    x = PN(act(apply_bias(x)))\n                with tf.variable_scope('Conv'):\n                    x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))))\n            else: # 8x8 and up\n                if fused_scale:\n                    with tf.variable_scope('Conv0_up'):\n                        x = PN(act(apply_bias(upscale2d_conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))))\n                else:\n                    x = upscale2d(x)\n                    with tf.variable_scope('Conv0'):\n                        x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))))\n                with tf.variable_scope('Conv1'):\n                    x = PN(act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale))))\n            return x\n    def torgb(x, res): # res = 2..resolution_log2\n        lod = resolution_log2 - res\n        with tf.variable_scope('ToRGB_lod%d' % lod):\n            return apply_bias(conv2d(x, fmaps=num_channels, kernel=1, gain=1, use_wscale=use_wscale))\n\n    # Linear structure: simple but inefficient.\n    if structure == 
'linear':\n        x = block(combo_in, 2)\n        images_out = torgb(x, 2)\n        for res in range(3, resolution_log2 + 1):\n            lod = resolution_log2 - res\n            x = block(x, res)\n            img = torgb(x, res)\n            images_out = upscale2d(images_out)\n            with tf.variable_scope('Grow_lod%d' % lod):\n                images_out = lerp_clip(img, images_out, lod_in - lod)\n\n    # Recursive structure: complex but efficient.\n    if structure == 'recursive':\n        def grow(x, res, lod):\n            y = block(x, res)\n            img = lambda: upscale2d(torgb(y, res), 2**lod)\n            if res > 2: img = cset(img, (lod_in > lod), lambda: upscale2d(lerp(torgb(y, res), upscale2d(torgb(x, res - 1)), lod_in - lod), 2**lod))\n            if lod > 0: img = cset(img, (lod_in < lod), lambda: grow(y, res + 1, lod - 1))\n            return img()\n        images_out = grow(combo_in, 2, resolution_log2 - 2)\n\n    assert images_out.dtype == tf.as_dtype(dtype)\n    images_out = tf.identity(images_out, name='images_out')\n    return images_out\n\n\ndef D_paper(\n    images_in,                          # First input: Images [minibatch, channel, height, width].\n    labels_in,                          # Second input: Labels [minibatch, label_size].\n    num_channels        = 1,            # Number of input color channels. Overridden based on dataset.\n    resolution          = 32,           # Input resolution. Overridden based on dataset.\n    label_size          = 0,            # Dimensionality of the labels, 0 if no labels. 
Overridden based on dataset.\n    fmap_base           = 8192,         # Overall multiplier for the number of feature maps.\n    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.\n    fmap_max            = 512,          # Maximum number of feature maps in any layer.\n    use_wscale          = True,         # Enable equalized learning rate?\n    mbstd_group_size    = 4,            # Group size for the minibatch standard deviation layer, 0 = disable.\n    dtype               = 'float32',    # Data type to use for activations and outputs.\n    fused_scale         = True,         # True = use fused conv2d + downscale2d, False = separate downscale2d layers.\n    structure           = None,         # 'linear' = human-readable, 'recursive' = efficient, None = select automatically\n    is_template_graph   = False,        # True = template graph constructed by the Network class, False = actual evaluation.\n    **_kwargs):                         # Ignore unrecognized keyword args.\n\n    resolution_log2 = int(np.log2(resolution))\n    assert resolution == 2**resolution_log2 and resolution >= 4\n    def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)\n    if structure is None: structure = 'linear' if is_template_graph else 'recursive'\n    act = leaky_relu\n\n    images_in.set_shape([None, num_channels, resolution, resolution])\n    labels_in.set_shape([None, label_size])\n    images_in = tf.cast(images_in, dtype)\n    labels_in = tf.cast(labels_in, dtype)\n    lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype)\n    scores_out = None\n\n    # Building blocks.\n    def fromrgb(x, res): # res = 2..resolution_log2\n        with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)):\n            return act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=1, use_wscale=use_wscale)))\n    def block(x, res): # res = 2..resolution_log2\n        with 
tf.variable_scope('%dx%d' % (2**res, 2**res)):\n            if res >= 3: # 8x8 and up\n                with tf.variable_scope('Conv0'):\n                    x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))\n                if fused_scale:\n                    with tf.variable_scope('Conv1_down'):\n                        x = act(apply_bias(conv2d_downscale2d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale)))\n                else:\n                    with tf.variable_scope('Conv1'):\n                        x = act(apply_bias(conv2d(x, fmaps=nf(res-2), kernel=3, use_wscale=use_wscale)))\n                    x = downscale2d(x)\n            else: # 4x4\n                if mbstd_group_size > 1:\n                    x = minibatch_stddev_layer(x, mbstd_group_size)\n                with tf.variable_scope('Conv'):\n                    x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, use_wscale=use_wscale)))\n                with tf.variable_scope('Dense0'):\n                    x = act(apply_bias(dense(x, fmaps=nf(res-2), use_wscale=use_wscale)))\n                with tf.variable_scope('Dense1'):\n                    x = apply_bias(dense(x, fmaps=1, gain=1, use_wscale=use_wscale))\n            return x\n\n    # Linear structure: simple but inefficient.\n    if structure == 'linear':\n        img = images_in\n        x = fromrgb(img, resolution_log2)\n        for res in range(resolution_log2, 2, -1):\n            lod = resolution_log2 - res\n            x = block(x, res)\n            img = downscale2d(img)\n            y = fromrgb(img, res - 1)\n            with tf.variable_scope('Grow_lod%d' % lod):\n                x = lerp_clip(x, y, lod_in - lod)\n        scores_out = block(x, 2)\n\n    # Recursive structure: complex but efficient.\n    if structure == 'recursive':\n        def grow(res, lod):\n            x = lambda: fromrgb(downscale2d(images_in, 2**lod), res)\n            if lod > 0: x = cset(x, (lod_in < lod), lambda: 
grow(res + 1, lod - 1))\n            x = block(x(), res); y = lambda: x\n            if res > 2: y = cset(y, (lod_in > lod), lambda: lerp(x, fromrgb(downscale2d(images_in, 2**(lod+1)), res - 1), lod_in - lod))\n            return y()\n        scores_out = grow(2, resolution_log2 - 2)\n\n    assert scores_out.dtype == tf.as_dtype(dtype)\n    scores_out = tf.identity(scores_out, name='scores_out')\n    return scores_out\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "training/networks_stylegan.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Network architectures used in the StyleGAN paper.\"\"\"\n\nimport numpy as np\nimport tensorflow as tf\nimport dnnlib\nimport dnnlib.tflib as tflib\n\n# NOTE: Do not import any application-specific modules here!\n# Specify all network parameters as kwargs.\n\n#----------------------------------------------------------------------------\n# Primitive ops for manipulating 4D activation tensors.\n# The gradients of these are not necessarily efficient or even meaningful.\n\ndef _blur2d(x, f=[1,2,1], normalize=True, flip=False, stride=1):\n    assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:])\n    assert isinstance(stride, int) and stride >= 1\n\n    # Finalize filter kernel.\n    f = np.array(f, dtype=np.float32)\n    if f.ndim == 1:\n        f = f[:, np.newaxis] * f[np.newaxis, :]\n    assert f.ndim == 2\n    if normalize:\n        f /= np.sum(f)\n    if flip:\n        f = f[::-1, ::-1]\n    f = f[:, :, np.newaxis, np.newaxis]\n    f = np.tile(f, [1, 1, int(x.shape[1]), 1])\n\n    # No-op => early exit.\n    if f.shape == (1, 1) and f[0,0] == 1:\n        return x\n\n    # Convolve using depthwise_conv2d.\n    orig_dtype = x.dtype\n    x = tf.cast(x, tf.float32)  # tf.nn.depthwise_conv2d() doesn't support fp16\n    f = tf.constant(f, dtype=x.dtype, name='filter')\n    strides = [1, 1, stride, stride]\n    x = tf.nn.depthwise_conv2d(x, f, strides=strides, padding='SAME', data_format='NCHW')\n    x = tf.cast(x, orig_dtype)\n    return x\n\ndef _upscale2d(x, factor=2, gain=1):\n    assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:])\n    assert 
isinstance(factor, int) and factor >= 1\n\n    # Apply gain.\n    if gain != 1:\n        x *= gain\n\n    # No-op => early exit.\n    if factor == 1:\n        return x\n\n    # Upscale using tf.tile().\n    s = x.shape\n    x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1])\n    x = tf.tile(x, [1, 1, 1, factor, 1, factor])\n    x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor])\n    return x\n\ndef _downscale2d(x, factor=2, gain=1):\n    assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:])\n    assert isinstance(factor, int) and factor >= 1\n\n    # 2x2, float32 => downscale using _blur2d().\n    if factor == 2 and x.dtype == tf.float32:\n        f = [np.sqrt(gain) / factor] * factor\n        return _blur2d(x, f=f, normalize=False, stride=factor)\n\n    # Apply gain.\n    if gain != 1:\n        x *= gain\n\n    # No-op => early exit.\n    if factor == 1:\n        return x\n\n    # Large factor => downscale using tf.nn.avg_pool().\n    # NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work.\n    ksize = [1, 1, factor, factor]\n    return tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW')\n\n#----------------------------------------------------------------------------\n# High-level ops for manipulating 4D activation tensors.\n# The gradients of these are meant to be as efficient as possible.\n\ndef blur2d(x, f=[1,2,1], normalize=True):\n    with tf.variable_scope('Blur2D'):\n        @tf.custom_gradient\n        def func(x):\n            y = _blur2d(x, f, normalize)\n            @tf.custom_gradient\n            def grad(dy):\n                dx = _blur2d(dy, f, normalize, flip=True)\n                return dx, lambda ddx: _blur2d(ddx, f, normalize)\n            return y, grad\n        return func(x)\n\ndef upscale2d(x, factor=2):\n    with tf.variable_scope('Upscale2D'):\n        @tf.custom_gradient\n        def func(x):\n            y = _upscale2d(x, factor)\n            
@tf.custom_gradient\n            def grad(dy):\n                dx = _downscale2d(dy, factor, gain=factor**2)\n                return dx, lambda ddx: _upscale2d(ddx, factor)\n            return y, grad\n        return func(x)\n\ndef downscale2d(x, factor=2):\n    with tf.variable_scope('Downscale2D'):\n        @tf.custom_gradient\n        def func(x):\n            y = _downscale2d(x, factor)\n            @tf.custom_gradient\n            def grad(dy):\n                dx = _upscale2d(dy, factor, gain=1/factor**2)\n                return dx, lambda ddx: _downscale2d(ddx, factor)\n            return y, grad\n        return func(x)\n\n#----------------------------------------------------------------------------\n# Get/create weight tensor for a convolutional or fully-connected layer.\n\ndef get_weight(shape, gain=np.sqrt(2), use_wscale=False, lrmul=1):\n    fan_in = np.prod(shape[:-1]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out]\n    he_std = gain / np.sqrt(fan_in) # He init\n\n    # Equalized learning rate and custom learning rate multiplier.\n    if use_wscale:\n        init_std = 1.0 / lrmul\n        runtime_coef = he_std * lrmul\n    else:\n        init_std = he_std / lrmul\n        runtime_coef = lrmul\n\n    # Create variable.\n    init = tf.initializers.random_normal(0, init_std)\n    return tf.get_variable('weight', shape=shape, initializer=init) * runtime_coef\n\n#----------------------------------------------------------------------------\n# Fully-connected layer.\n\ndef dense(x, fmaps, **kwargs):\n    if len(x.shape) > 2:\n        x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])])\n    w = get_weight([x.shape[1].value, fmaps], **kwargs)\n    w = tf.cast(w, x.dtype)\n    return tf.matmul(x, w)\n\n#----------------------------------------------------------------------------\n# Convolutional layer.\n\ndef conv2d(x, fmaps, kernel, **kwargs):\n    assert kernel >= 1 and kernel % 2 == 1\n    w = get_weight([kernel, kernel, x.shape[1].value, 
fmaps], **kwargs)\n    w = tf.cast(w, x.dtype)\n    return tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='SAME', data_format='NCHW')\n\n#----------------------------------------------------------------------------\n# Fused convolution + scaling.\n# Faster and uses less memory than performing the operations separately.\n\ndef upscale2d_conv2d(x, fmaps, kernel, fused_scale='auto', **kwargs):\n    assert kernel >= 1 and kernel % 2 == 1\n    assert fused_scale in [True, False, 'auto']\n    if fused_scale == 'auto':\n        fused_scale = min(x.shape[2:]) * 2 >= 128\n\n    # Not fused => call the individual ops directly.\n    if not fused_scale:\n        return conv2d(upscale2d(x), fmaps, kernel, **kwargs)\n\n    # Fused => perform both ops simultaneously using tf.nn.conv2d_transpose().\n    w = get_weight([kernel, kernel, x.shape[1].value, fmaps], **kwargs)\n    w = tf.transpose(w, [0, 1, 3, 2]) # [kernel, kernel, fmaps_out, fmaps_in]\n    w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT')\n    w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]])\n    w = tf.cast(w, x.dtype)\n    os = [tf.shape(x)[0], fmaps, x.shape[2] * 2, x.shape[3] * 2]\n    return tf.nn.conv2d_transpose(x, w, os, strides=[1,1,2,2], padding='SAME', data_format='NCHW')\n\ndef conv2d_downscale2d(x, fmaps, kernel, fused_scale='auto', **kwargs):\n    assert kernel >= 1 and kernel % 2 == 1\n    assert fused_scale in [True, False, 'auto']\n    if fused_scale == 'auto':\n        fused_scale = min(x.shape[2:]) >= 128\n\n    # Not fused => call the individual ops directly.\n    if not fused_scale:\n        return downscale2d(conv2d(x, fmaps, kernel, **kwargs))\n\n    # Fused => perform both ops simultaneously using tf.nn.conv2d().\n    w = get_weight([kernel, kernel, x.shape[1].value, fmaps], **kwargs)\n    w = tf.pad(w, [[1,1], [1,1], [0,0], [0,0]], mode='CONSTANT')\n    w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25\n    w = tf.cast(w, x.dtype)\n    return 
tf.nn.conv2d(x, w, strides=[1,1,2,2], padding='SAME', data_format='NCHW')\n\n#----------------------------------------------------------------------------\n# Apply bias to the given activation tensor.\n\ndef apply_bias(x, lrmul=1):\n    b = tf.get_variable('bias', shape=[x.shape[1]], initializer=tf.initializers.zeros()) * lrmul\n    b = tf.cast(b, x.dtype)\n    if len(x.shape) == 2:\n        return x + b\n    return x + tf.reshape(b, [1, -1, 1, 1])\n\n#----------------------------------------------------------------------------\n# Leaky ReLU activation. More efficient than tf.nn.leaky_relu() and supports FP16.\n\ndef leaky_relu(x, alpha=0.2):\n    with tf.variable_scope('LeakyReLU'):\n        alpha = tf.constant(alpha, dtype=x.dtype, name='alpha')\n        @tf.custom_gradient\n        def func(x):\n            y = tf.maximum(x, x * alpha)\n            @tf.custom_gradient\n            def grad(dy):\n                dx = tf.where(y >= 0, dy, dy * alpha)\n                return dx, lambda ddx: tf.where(y >= 0, ddx, ddx * alpha)\n            return y, grad\n        return func(x)\n\n#----------------------------------------------------------------------------\n# Pixelwise feature vector normalization.\n\ndef pixel_norm(x, epsilon=1e-8):\n    with tf.variable_scope('PixelNorm'):\n        epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon')\n        return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True) + epsilon)\n\n#----------------------------------------------------------------------------\n# Instance normalization.\n\ndef instance_norm(x, epsilon=1e-8):\n    assert len(x.shape) == 4 # NCHW\n    with tf.variable_scope('InstanceNorm'):\n        orig_dtype = x.dtype\n        x = tf.cast(x, tf.float32)\n        x -= tf.reduce_mean(x, axis=[2,3], keepdims=True)\n        epsilon = tf.constant(epsilon, dtype=x.dtype, name='epsilon')\n        x *= tf.rsqrt(tf.reduce_mean(tf.square(x), axis=[2,3], keepdims=True) + epsilon)\n        x = tf.cast(x, 
orig_dtype)\n        return x\n\n#----------------------------------------------------------------------------\n# Style modulation.\n\ndef style_mod(x, dlatent, **kwargs):\n    with tf.variable_scope('StyleMod'):\n        style = apply_bias(dense(dlatent, fmaps=x.shape[1]*2, gain=1, **kwargs))\n        style = tf.reshape(style, [-1, 2, x.shape[1]] + [1] * (len(x.shape) - 2))\n        return x * (style[:,0] + 1) + style[:,1]\n\n#----------------------------------------------------------------------------\n# Noise input.\n\ndef apply_noise(x, noise_var=None, randomize_noise=True):\n    assert len(x.shape) == 4 # NCHW\n    with tf.variable_scope('Noise'):\n        if noise_var is None or randomize_noise:\n            noise = tf.random_normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)\n        else:\n            noise = tf.cast(noise_var, x.dtype)\n        weight = tf.get_variable('weight', shape=[x.shape[1].value], initializer=tf.initializers.zeros())\n        return x + noise * tf.reshape(tf.cast(weight, x.dtype), [1, -1, 1, 1])\n\n#----------------------------------------------------------------------------\n# Minibatch standard deviation.\n\ndef minibatch_stddev_layer(x, group_size=4, num_new_features=1):\n    with tf.variable_scope('MinibatchStddev'):\n        group_size = tf.minimum(group_size, tf.shape(x)[0])     # Minibatch must be divisible by (or smaller than) group_size.\n        s = x.shape                                             # [NCHW]  Input shape.\n        y = tf.reshape(x, [group_size, -1, num_new_features, s[1]//num_new_features, s[2], s[3]])   # [GMncHW] Split minibatch into M groups of size G. 
Split channels into n channel groups c.\n        y = tf.cast(y, tf.float32)                              # [GMncHW] Cast to FP32.\n        y -= tf.reduce_mean(y, axis=0, keepdims=True)           # [GMncHW] Subtract mean over group.\n        y = tf.reduce_mean(tf.square(y), axis=0)                # [MncHW]  Calc variance over group.\n        y = tf.sqrt(y + 1e-8)                                   # [MncHW]  Calc stddev over group.\n        y = tf.reduce_mean(y, axis=[2,3,4], keepdims=True)      # [Mn111]  Take average over fmaps and pixels.\n        y = tf.reduce_mean(y, axis=[2])                         # [Mn11]  Drop the averaged channel-group axis.\n        y = tf.cast(y, x.dtype)                                 # [Mn11]  Cast back to original data type.\n        y = tf.tile(y, [group_size, 1, s[2], s[3]])             # [NnHW]  Replicate over group and pixels.\n        return tf.concat([x, y], axis=1)                        # [NCHW]  Append as new fmap.\n\n#----------------------------------------------------------------------------\n# Style-based generator used in the StyleGAN paper.\n# Composed of two sub-networks (G_mapping and G_synthesis) that are defined below.\n\ndef G_style(\n    latents_in,                                     # First input: Latent vectors (Z) [minibatch, latent_size].\n    labels_in,                                      # Second input: Conditioning labels [minibatch, label_size].\n    truncation_psi          = 0.7,                  # Style strength multiplier for the truncation trick. None = disable.\n    truncation_cutoff       = 8,                    # Number of layers for which to apply the truncation trick. 
None = disable.\n    truncation_psi_val      = None,                 # Value for truncation_psi to use during validation.\n    truncation_cutoff_val   = None,                 # Value for truncation_cutoff to use during validation.\n    dlatent_avg_beta        = 0.995,                # Decay for tracking the moving average of W during training. None = disable.\n    style_mixing_prob       = 0.9,                  # Probability of mixing styles during training. None = disable.\n    is_training             = False,                # Network is under training? Enables and disables specific features.\n    is_validation           = False,                # Network is under validation? Chooses which value to use for truncation_psi.\n    is_template_graph       = False,                # True = template graph constructed by the Network class, False = actual evaluation.\n    components              = dnnlib.EasyDict(),    # Container for sub-networks. Retained between calls.\n    **kwargs):                                      # Arguments for sub-networks (G_mapping and G_synthesis).\n\n    # Validate arguments.\n    assert not is_training or not is_validation\n    assert isinstance(components, dnnlib.EasyDict)\n    if is_validation:\n        truncation_psi = truncation_psi_val\n        truncation_cutoff = truncation_cutoff_val\n    if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):\n        truncation_psi = None\n    if is_training or (truncation_cutoff is not None and not tflib.is_tf_expression(truncation_cutoff) and truncation_cutoff <= 0):\n        truncation_cutoff = None\n    if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):\n        dlatent_avg_beta = None\n    if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):\n        style_mixing_prob = 
None\n\n    # Setup components.\n    if 'synthesis' not in components:\n        components.synthesis = tflib.Network('G_synthesis', func_name=G_synthesis, **kwargs)\n    num_layers = components.synthesis.input_shape[1]\n    dlatent_size = components.synthesis.input_shape[2]\n    if 'mapping' not in components:\n        components.mapping = tflib.Network('G_mapping', func_name=G_mapping, dlatent_broadcast=num_layers, **kwargs)\n\n    # Setup variables.\n    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)\n    dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)\n\n    # Evaluate mapping network.\n    dlatents = components.mapping.get_output_for(latents_in, labels_in, **kwargs)\n\n    # Update moving average of W.\n    if dlatent_avg_beta is not None:\n        with tf.variable_scope('DlatentAvg'):\n            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)\n            update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))\n            with tf.control_dependencies([update_op]):\n                dlatents = tf.identity(dlatents)\n\n    # Perform style mixing regularization.\n    if style_mixing_prob is not None:\n        with tf.name_scope('StyleMix'):\n            latents2 = tf.random_normal(tf.shape(latents_in))\n            dlatents2 = components.mapping.get_output_for(latents2, labels_in, **kwargs)\n            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]\n            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2\n            mixing_cutoff = tf.cond(\n                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,\n                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),\n                lambda: cur_layers)\n            dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)\n\n    # Apply truncation trick.\n    if truncation_psi is 
not None and truncation_cutoff is not None:\n        with tf.variable_scope('Truncation'):\n            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]\n            ones = np.ones(layer_idx.shape, dtype=np.float32)\n            coefs = tf.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)\n            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)\n\n    # Evaluate synthesis network.\n    with tf.control_dependencies([tf.assign(components.synthesis.find_var('lod'), lod_in)]):\n        images_out = components.synthesis.get_output_for(dlatents, force_clean_graph=is_template_graph, **kwargs)\n    return tf.identity(images_out, name='images_out')\n\n#----------------------------------------------------------------------------\n# Mapping network used in the StyleGAN paper.\n\ndef G_mapping(\n    latents_in,                             # First input: Latent vectors (Z) [minibatch, latent_size].\n    labels_in,                              # Second input: Conditioning labels [minibatch, label_size].\n    latent_size             = 512,          # Latent vector (Z) dimensionality.\n    label_size              = 0,            # Label dimensionality, 0 if no labels.\n    dlatent_size            = 512,          # Disentangled latent (W) dimensionality.\n    dlatent_broadcast       = None,         # Output disentangled latent (W) as [minibatch, dlatent_size] or [minibatch, dlatent_broadcast, dlatent_size].\n    mapping_layers          = 8,            # Number of mapping layers.\n    mapping_fmaps           = 512,          # Number of activations in the mapping layers.\n    mapping_lrmul           = 0.01,         # Learning rate multiplier for the mapping layers.\n    mapping_nonlinearity    = 'lrelu',      # Activation function: 'relu', 'lrelu'.\n    use_wscale              = True,         # Enable equalized learning rate?\n    normalize_latents       = True,         # Normalize latent vectors (Z) before feeding them to the mapping layers?\n    
dtype                   = 'float32',    # Data type to use for activations and outputs.\n    **_kwargs):                             # Ignore unrecognized keyword args.\n\n    act, gain = {'relu': (tf.nn.relu, np.sqrt(2)), 'lrelu': (leaky_relu, np.sqrt(2))}[mapping_nonlinearity]\n\n    # Inputs.\n    latents_in.set_shape([None, latent_size])\n    labels_in.set_shape([None, label_size])\n    latents_in = tf.cast(latents_in, dtype)\n    labels_in = tf.cast(labels_in, dtype)\n    x = latents_in\n\n    # Embed labels and concatenate them with latents.\n    if label_size:\n        with tf.variable_scope('LabelConcat'):\n            w = tf.get_variable('weight', shape=[label_size, latent_size], initializer=tf.initializers.random_normal())\n            y = tf.matmul(labels_in, tf.cast(w, dtype))\n            x = tf.concat([x, y], axis=1)\n\n    # Normalize latents.\n    if normalize_latents:\n        x = pixel_norm(x)\n\n    # Mapping layers.\n    for layer_idx in range(mapping_layers):\n        with tf.variable_scope('Dense%d' % layer_idx):\n            fmaps = dlatent_size if layer_idx == mapping_layers - 1 else mapping_fmaps\n            x = dense(x, fmaps=fmaps, gain=gain, use_wscale=use_wscale, lrmul=mapping_lrmul)\n            x = apply_bias(x, lrmul=mapping_lrmul)\n            x = act(x)\n\n    # Broadcast.\n    if dlatent_broadcast is not None:\n        with tf.variable_scope('Broadcast'):\n            x = tf.tile(x[:, np.newaxis], [1, dlatent_broadcast, 1])\n\n    # Output.\n    assert x.dtype == tf.as_dtype(dtype)\n    return tf.identity(x, name='dlatents_out')\n\n#----------------------------------------------------------------------------\n# Synthesis network used in the StyleGAN paper.\n\ndef G_synthesis(\n    dlatents_in,                        # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].\n    dlatent_size        = 512,          # Disentangled latent (W) dimensionality.\n    num_channels        = 3,            # Number of output 
color channels.\n    resolution          = 1024,         # Output resolution.\n    fmap_base           = 8192,         # Overall multiplier for the number of feature maps.\n    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.\n    fmap_max            = 512,          # Maximum number of feature maps in any layer.\n    use_styles          = True,         # Enable style inputs?\n    const_input_layer   = True,         # First layer is a learned constant?\n    use_noise           = True,         # Enable noise inputs?\n    randomize_noise     = True,         # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.\n    nonlinearity        = 'lrelu',      # Activation function: 'relu', 'lrelu'\n    use_wscale          = True,         # Enable equalized learning rate?\n    use_pixel_norm      = False,        # Enable pixelwise feature vector normalization?\n    use_instance_norm   = True,         # Enable instance normalization?\n    dtype               = 'float32',    # Data type to use for activations and outputs.\n    fused_scale         = 'auto',       # True = fused convolution + scaling, False = separate ops, 'auto' = decide automatically.\n    blur_filter         = [1,2,1],      # Low-pass filter to apply when resampling activations. 
None = no filtering.\n    structure           = 'auto',       # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically.\n    is_template_graph   = False,        # True = template graph constructed by the Network class, False = actual evaluation.\n    force_clean_graph   = False,        # True = construct a clean graph that looks nice in TensorBoard, False = default behavior.\n    **_kwargs):                         # Ignore unrecognized keyword args.\n\n    resolution_log2 = int(np.log2(resolution))\n    assert resolution == 2**resolution_log2 and resolution >= 4\n    def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)\n    def blur(x): return blur2d(x, blur_filter) if blur_filter else x\n    if is_template_graph: force_clean_graph = True\n    if force_clean_graph: randomize_noise = False\n    if structure == 'auto': structure = 'linear' if force_clean_graph else 'recursive'\n    act, gain = {'relu': (tf.nn.relu, np.sqrt(2)), 'lrelu': (leaky_relu, np.sqrt(2))}[nonlinearity]\n    num_layers = resolution_log2 * 2 - 2\n    num_styles = num_layers if use_styles else 1\n    images_out = None\n\n    # Primary inputs.\n    dlatents_in.set_shape([None, num_styles, dlatent_size])\n    dlatents_in = tf.cast(dlatents_in, dtype)\n    lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype)\n\n    # Noise inputs.\n    noise_inputs = []\n    if use_noise:\n        for layer_idx in range(num_layers):\n            res = layer_idx // 2 + 2\n            shape = [1, use_noise, 2**res, 2**res]\n            noise_inputs.append(tf.get_variable('noise%d' % layer_idx, shape=shape, initializer=tf.initializers.random_normal(), trainable=False))\n\n    # Things to do at the end of each layer.\n    def layer_epilogue(x, layer_idx):\n        if use_noise:\n            x = apply_noise(x, noise_inputs[layer_idx], randomize_noise=randomize_noise)\n        x = 
apply_bias(x)\n        x = act(x)\n        if use_pixel_norm:\n            x = pixel_norm(x)\n        if use_instance_norm:\n            x = instance_norm(x)\n        if use_styles:\n            x = style_mod(x, dlatents_in[:, layer_idx], use_wscale=use_wscale)\n        return x\n\n    # Early layers.\n    with tf.variable_scope('4x4'):\n        if const_input_layer:\n            with tf.variable_scope('Const'):\n                x = tf.get_variable('const', shape=[1, nf(1), 4, 4], initializer=tf.initializers.ones())\n                x = layer_epilogue(tf.tile(tf.cast(x, dtype), [tf.shape(dlatents_in)[0], 1, 1, 1]), 0)\n        else:\n            with tf.variable_scope('Dense'):\n                x = dense(dlatents_in[:, 0], fmaps=nf(1)*16, gain=gain/4, use_wscale=use_wscale) # tweak gain to match the official implementation of Progressive GAN\n                x = layer_epilogue(tf.reshape(x, [-1, nf(1), 4, 4]), 0)\n        with tf.variable_scope('Conv'):\n            x = layer_epilogue(conv2d(x, fmaps=nf(1), kernel=3, gain=gain, use_wscale=use_wscale), 1)\n\n    # Building blocks for remaining layers.\n    def block(res, x): # res = 3..resolution_log2\n        with tf.variable_scope('%dx%d' % (2**res, 2**res)):\n            with tf.variable_scope('Conv0_up'):\n                x = layer_epilogue(blur(upscale2d_conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale, fused_scale=fused_scale)), res*2-4)\n            with tf.variable_scope('Conv1'):\n                x = layer_epilogue(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale), res*2-3)\n            return x\n    def torgb(res, x): # res = 2..resolution_log2\n        lod = resolution_log2 - res\n        with tf.variable_scope('ToRGB_lod%d' % lod):\n            return apply_bias(conv2d(x, fmaps=num_channels, kernel=1, gain=1, use_wscale=use_wscale))\n\n    # Fixed structure: simple and efficient, but does not support progressive growing.\n    if structure == 'fixed':\n        for 
res in range(3, resolution_log2 + 1):\n            x = block(res, x)\n        images_out = torgb(resolution_log2, x)\n\n    # Linear structure: simple but inefficient.\n    if structure == 'linear':\n        images_out = torgb(2, x)\n        for res in range(3, resolution_log2 + 1):\n            lod = resolution_log2 - res\n            x = block(res, x)\n            img = torgb(res, x)\n            images_out = upscale2d(images_out)\n            with tf.variable_scope('Grow_lod%d' % lod):\n                images_out = tflib.lerp_clip(img, images_out, lod_in - lod)\n\n    # Recursive structure: complex but efficient.\n    if structure == 'recursive':\n        def cset(cur_lambda, new_cond, new_lambda):\n            return lambda: tf.cond(new_cond, new_lambda, cur_lambda)\n        def grow(x, res, lod):\n            y = block(res, x)\n            img = lambda: upscale2d(torgb(res, y), 2**lod)\n            img = cset(img, (lod_in > lod), lambda: upscale2d(tflib.lerp(torgb(res, y), upscale2d(torgb(res - 1, x)), lod_in - lod), 2**lod))\n            if lod > 0: img = cset(img, (lod_in < lod), lambda: grow(y, res + 1, lod - 1))\n            return img()\n        images_out = grow(x, 3, resolution_log2 - 3)\n\n    assert images_out.dtype == tf.as_dtype(dtype)\n    return tf.identity(images_out, name='images_out')\n\n#----------------------------------------------------------------------------\n# Discriminator used in the StyleGAN paper.\n\ndef D_basic(\n    images_in,                          # First input: Images [minibatch, channel, height, width].\n    labels_in,                          # Second input: Labels [minibatch, label_size].\n    num_channels        = 1,            # Number of input color channels. Overridden based on dataset.\n    resolution          = 32,           # Input resolution. Overridden based on dataset.\n    label_size          = 0,            # Dimensionality of the labels, 0 if no labels. 
Overridden based on dataset.\n    fmap_base           = 8192,         # Overall multiplier for the number of feature maps.\n    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.\n    fmap_max            = 512,          # Maximum number of feature maps in any layer.\n    nonlinearity        = 'lrelu',      # Activation function: 'relu', 'lrelu'\n    use_wscale          = True,         # Enable equalized learning rate?\n    mbstd_group_size    = 4,            # Group size for the minibatch standard deviation layer, 0 = disable.\n    mbstd_num_features  = 1,            # Number of features for the minibatch standard deviation layer.\n    dtype               = 'float32',    # Data type to use for activations and outputs.\n    fused_scale         = 'auto',       # True = fused convolution + scaling, False = separate ops, 'auto' = decide automatically.\n    blur_filter         = [1,2,1],      # Low-pass filter to apply when resampling activations. 
None = no filtering.\n    structure           = 'auto',       # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically.\n    is_template_graph   = False,        # True = template graph constructed by the Network class, False = actual evaluation.\n    **_kwargs):                         # Ignore unrecognized keyword args.\n\n    resolution_log2 = int(np.log2(resolution))\n    assert resolution == 2**resolution_log2 and resolution >= 4\n    def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)\n    def blur(x): return blur2d(x, blur_filter) if blur_filter else x\n    if structure == 'auto': structure = 'linear' if is_template_graph else 'recursive'\n    act, gain = {'relu': (tf.nn.relu, np.sqrt(2)), 'lrelu': (leaky_relu, np.sqrt(2))}[nonlinearity]\n\n    images_in.set_shape([None, num_channels, resolution, resolution])\n    labels_in.set_shape([None, label_size])\n    images_in = tf.cast(images_in, dtype)\n    labels_in = tf.cast(labels_in, dtype)\n    lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype)\n    scores_out = None\n\n    # Building blocks.\n    def fromrgb(x, res): # res = 2..resolution_log2\n        with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)):\n            return act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=1, gain=gain, use_wscale=use_wscale)))\n    def block(x, res): # res = 2..resolution_log2\n        with tf.variable_scope('%dx%d' % (2**res, 2**res)):\n            if res >= 3: # 8x8 and up\n                with tf.variable_scope('Conv0'):\n                    x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale)))\n                with tf.variable_scope('Conv1_down'):\n                    x = act(apply_bias(conv2d_downscale2d(blur(x), fmaps=nf(res-2), kernel=3, gain=gain, use_wscale=use_wscale, fused_scale=fused_scale)))\n            else: # 4x4\n             
   if mbstd_group_size > 1:\n                    x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)\n                with tf.variable_scope('Conv'):\n                    x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale)))\n                with tf.variable_scope('Dense0'):\n                    x = act(apply_bias(dense(x, fmaps=nf(res-2), gain=gain, use_wscale=use_wscale)))\n                with tf.variable_scope('Dense1'):\n                    x = apply_bias(dense(x, fmaps=max(label_size, 1), gain=1, use_wscale=use_wscale))\n            return x\n\n    # Fixed structure: simple and efficient, but does not support progressive growing.\n    if structure == 'fixed':\n        x = fromrgb(images_in, resolution_log2)\n        for res in range(resolution_log2, 2, -1):\n            x = block(x, res)\n        scores_out = block(x, 2)\n\n    # Linear structure: simple but inefficient.\n    if structure == 'linear':\n        img = images_in\n        x = fromrgb(img, resolution_log2)\n        for res in range(resolution_log2, 2, -1):\n            lod = resolution_log2 - res\n            x = block(x, res)\n            img = downscale2d(img)\n            y = fromrgb(img, res - 1)\n            with tf.variable_scope('Grow_lod%d' % lod):\n                x = tflib.lerp_clip(x, y, lod_in - lod)\n        scores_out = block(x, 2)\n\n    # Recursive structure: complex but efficient.\n    if structure == 'recursive':\n        def cset(cur_lambda, new_cond, new_lambda):\n            return lambda: tf.cond(new_cond, new_lambda, cur_lambda)\n        def grow(res, lod):\n            x = lambda: fromrgb(downscale2d(images_in, 2**lod), res)\n            if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1))\n            x = block(x(), res); y = lambda: x\n            if res > 2: y = cset(y, (lod_in > lod), lambda: tflib.lerp(x, fromrgb(downscale2d(images_in, 2**(lod+1)), res - 1), lod_in - lod))\n            return 
y()\n        scores_out = grow(2, resolution_log2 - 2)\n\n    # Label conditioning from \"Which Training Methods for GANs do actually Converge?\"\n    if label_size:\n        with tf.variable_scope('LabelSwitch'):\n            scores_out = tf.reduce_sum(scores_out * labels_in, axis=1, keepdims=True)\n\n    assert scores_out.dtype == tf.as_dtype(dtype)\n    scores_out = tf.identity(scores_out, name='scores_out')\n    return scores_out\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "training/training_loop.py",
    "content": "# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# This work is licensed under the Creative Commons Attribution-NonCommercial\n# 4.0 International License. To view a copy of this license, visit\n# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to\n# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.\n\n\"\"\"Main training script.\"\"\"\n\nimport os\nimport numpy as np\nimport tensorflow as tf\nimport dnnlib\nimport dnnlib.tflib as tflib\nfrom dnnlib.tflib.autosummary import autosummary\n\nimport config\nimport train\nfrom training import dataset\nfrom training import misc\nfrom metrics import metric_base\n\n#----------------------------------------------------------------------------\n# Just-in-time processing of training images before feeding them to the networks.\n\ndef process_reals(x, lod, mirror_augment, drange_data, drange_net):\n    with tf.name_scope('ProcessReals'):\n        with tf.name_scope('DynamicRange'):\n            x = tf.cast(x, tf.float32)\n            x = misc.adjust_dynamic_range(x, drange_data, drange_net)\n        if mirror_augment:\n            with tf.name_scope('MirrorAugment'):\n                s = tf.shape(x)\n                mask = tf.random_uniform([s[0], 1, 1, 1], 0.0, 1.0)\n                mask = tf.tile(mask, [1, s[1], s[2], s[3]])\n                x = tf.where(mask < 0.5, x, tf.reverse(x, axis=[3]))\n        with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail.\n            s = tf.shape(x)\n            y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2])\n            y = tf.reduce_mean(y, axis=[3, 5], keepdims=True)\n            y = tf.tile(y, [1, 1, 1, 2, 1, 2])\n            y = tf.reshape(y, [-1, s[1], s[2], s[3]])\n            x = tflib.lerp(x, y, lod - tf.floor(lod))\n        with tf.name_scope('UpscaleLOD'): # Upscale to match the expected input/output size of the networks.\n            s = tf.shape(x)\n            factor = 
tf.cast(2 ** tf.floor(lod), tf.int32)\n            x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1])\n            x = tf.tile(x, [1, 1, 1, factor, 1, factor])\n            x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor])\n        return x\n\n#----------------------------------------------------------------------------\n# Evaluate time-varying training parameters.\n\ndef training_schedule(\n    cur_nimg,\n    training_set,\n    num_gpus,\n    lod_initial_resolution  = 4,        # Image resolution used at the beginning.\n    lod_training_kimg       = 600,      # Thousands of real images to show before doubling the resolution.\n    lod_transition_kimg     = 600,      # Thousands of real images to show when fading in new layers.\n    minibatch_base          = 16,       # Maximum minibatch size, divided evenly among GPUs.\n    minibatch_dict          = {},       # Resolution-specific overrides.\n    max_minibatch_per_gpu   = {},       # Resolution-specific maximum minibatch size per GPU.\n    G_lrate_base            = 0.001,    # Learning rate for the generator.\n    G_lrate_dict            = {},       # Resolution-specific overrides.\n    D_lrate_base            = 0.001,    # Learning rate for the discriminator.\n    D_lrate_dict            = {},       # Resolution-specific overrides.\n    lrate_rampup_kimg       = 0,        # Duration of learning rate ramp-up.\n    tick_kimg_base          = 160,      # Default interval of progress snapshots.\n    tick_kimg_dict          = {4: 160, 8:140, 16:120, 32:100, 64:80, 128:60, 256:40, 512:30, 1024:20}): # Resolution-specific overrides.\n\n    # Initialize result dict.\n    s = dnnlib.EasyDict()\n    s.kimg = cur_nimg / 1000.0\n\n    # Training phase.\n    phase_dur = lod_training_kimg + lod_transition_kimg\n    phase_idx = int(np.floor(s.kimg / phase_dur)) if phase_dur > 0 else 0\n    phase_kimg = s.kimg - phase_idx * phase_dur\n\n    # Level-of-detail and resolution.\n    s.lod = training_set.resolution_log2\n    
s.lod -= np.floor(np.log2(lod_initial_resolution))\n    s.lod -= phase_idx\n    if lod_transition_kimg > 0:\n        s.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg\n    s.lod = max(s.lod, 0.0)\n    s.resolution = 2 ** (training_set.resolution_log2 - int(np.floor(s.lod)))\n\n    # Minibatch size.\n    s.minibatch = minibatch_dict.get(s.resolution, minibatch_base)\n    s.minibatch -= s.minibatch % num_gpus\n    if s.resolution in max_minibatch_per_gpu:\n        s.minibatch = min(s.minibatch, max_minibatch_per_gpu[s.resolution] * num_gpus)\n\n    # Learning rate.\n    s.G_lrate = G_lrate_dict.get(s.resolution, G_lrate_base)\n    s.D_lrate = D_lrate_dict.get(s.resolution, D_lrate_base)\n    if lrate_rampup_kimg > 0:\n        rampup = min(s.kimg / lrate_rampup_kimg, 1.0)\n        s.G_lrate *= rampup\n        s.D_lrate *= rampup\n\n    # Other parameters.\n    s.tick_kimg = tick_kimg_dict.get(s.resolution, tick_kimg_base)\n    return s\n\n#----------------------------------------------------------------------------\n# Main training script.\n\ndef training_loop(\n    submit_config,\n    G_args                  = {},       # Options for generator network.\n    D_args                  = {},       # Options for discriminator network.\n    G_opt_args              = {},       # Options for generator optimizer.\n    D_opt_args              = {},       # Options for discriminator optimizer.\n    G_loss_args             = {},       # Options for generator loss.\n    D_loss_args             = {},       # Options for discriminator loss.\n    dataset_args            = {},       # Options for dataset.load_dataset().\n    sched_args              = {},       # Options for training_schedule().\n    grid_args               = {},       # Options for misc.setup_snapshot_image_grid().\n    metric_arg_list         = [],       # Options for MetricGroup.\n    tf_config               = {},       # Options for tflib.init_tf().\n    G_smoothing_kimg        = 10.0,     
# Half-life of the running average of generator weights.\n    D_repeats               = 1,        # How many times the discriminator is trained per G iteration.\n    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.\n    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?\n    total_kimg              = 15000,    # Total length of the training, measured in thousands of real images.\n    mirror_augment          = False,    # Enable mirror augment?\n    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.\n    image_snapshot_ticks    = 1,        # How often to export image snapshots?\n    network_snapshot_ticks  = 10,       # How often to export network snapshots?\n    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?\n    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?\n    resume_run_id           = None,     # Run ID or network pkl to resume training from, None = start from scratch.\n    resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.\n    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.\n    resume_time             = 0.0):     # Assumed wallclock time at the beginning. Affects reporting.\n\n    # Initialize dnnlib and TensorFlow.\n    ctx = dnnlib.RunContext(submit_config, train)\n    tflib.init_tf(tf_config)\n\n    # Load training set.\n    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)\n\n    # Construct networks.\n    with tf.device('/gpu:0'):\n        if resume_run_id is not None:\n            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)\n            print('Loading networks from \"%s\"...' 
% network_pkl)\n            G, D, Gs = misc.load_pkl(network_pkl)\n        else:\n            print('Constructing networks...')\n            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)\n            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)\n            Gs = G.clone('Gs')\n    G.print_layers(); D.print_layers()\n\n    print('Building TensorFlow graph...')\n    with tf.name_scope('Inputs'), tf.device('/cpu:0'):\n        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])\n        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])\n        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])\n        minibatch_split = minibatch_in // submit_config.num_gpus\n        Gs_beta         = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0\n\n    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)\n    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)\n    for gpu in range(submit_config.num_gpus):\n        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):\n            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')\n            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')\n            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]\n            reals, labels = training_set.get_minibatch_tf()\n            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)\n            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):\n                G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args)\n 
           with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):\n                D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args)\n            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)\n            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)\n    G_train_op = G_opt.apply_updates()\n    D_train_op = D_opt.apply_updates()\n\n    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)\n    with tf.device('/gpu:0'):\n        try:\n            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()\n        except tf.errors.NotFoundError:\n            peak_gpu_mem_op = tf.constant(0)\n\n    print('Setting up snapshot image grid...')\n    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)\n    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)\n    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)\n\n    print('Setting up run dir...')\n    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)\n    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)\n    summary_log = tf.summary.FileWriter(submit_config.run_dir)\n    if save_tf_graph:\n        summary_log.add_graph(tf.get_default_graph())\n    if save_weight_histograms:\n        G.setup_weight_histograms(); D.setup_weight_histograms()\n    metrics = metric_base.MetricGroup(metric_arg_list)\n\n    print('Training...\\n')\n    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)\n    maintenance_time = ctx.get_last_update_interval()\n    cur_nimg = 
int(resume_kimg * 1000)\n    cur_tick = 0\n    tick_start_nimg = cur_nimg\n    prev_lod = -1.0\n    while cur_nimg < total_kimg * 1000:\n        if ctx.should_stop(): break\n\n        # Choose training parameters and configure training ops.\n        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)\n        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)\n        if reset_opt_for_new_lod:\n            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):\n                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()\n        prev_lod = sched.lod\n\n        # Run training ops.\n        for _mb_repeat in range(minibatch_repeats):\n            for _D_repeat in range(D_repeats):\n                tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})\n                cur_nimg += sched.minibatch\n            tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})\n\n        # Perform maintenance tasks once per tick.\n        done = (cur_nimg >= total_kimg * 1000)\n        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:\n            cur_tick += 1\n            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0\n            tick_start_nimg = cur_nimg\n            tick_time = ctx.get_time_since_last_update()\n            total_time = ctx.get_time_since_start() + resume_time\n\n            # Report progress.\n            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (\n                autosummary('Progress/tick', cur_tick),\n                autosummary('Progress/kimg', cur_nimg / 1000.0),\n                autosummary('Progress/lod', sched.lod),\n                autosummary('Progress/minibatch', sched.minibatch),\n             
   dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),\n                autosummary('Timing/sec_per_tick', tick_time),\n                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),\n                autosummary('Timing/maintenance_sec', maintenance_time),\n                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))\n            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))\n            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))\n\n            # Save snapshots.\n            if cur_tick % image_snapshot_ticks == 0 or done:\n                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)\n                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)\n            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:\n                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))\n                misc.save_pkl((G, D, Gs), pkl)\n                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)\n\n            # Update summaries and RunContext.\n            metrics.update_autosummaries()\n            tflib.autosummary.save_summaries(summary_log, cur_nimg)\n            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)\n            maintenance_time = ctx.get_last_update_interval() - tick_time\n\n    # Write final results.\n    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))\n    summary_log.close()\n\n    ctx.close()\n\n#----------------------------------------------------------------------------\n"
  }
]