Repository: AntixK/PyTorch-VAE Branch: master Commit: a6896b944c91 Files: 82 Total size: 236.1 KB Directory structure: gitextract_vk1fb46d/ ├── .gitignore ├── .idea/ │ ├── .gitignore │ ├── PyTorch-VAE.iml │ ├── inspectionProfiles/ │ │ └── profiles_settings.xml │ ├── misc.xml │ ├── modules.xml │ └── vcs.xml ├── LICENSE.md ├── README.md ├── configs/ │ ├── bbvae.yaml │ ├── betatc_vae.yaml │ ├── bhvae.yaml │ ├── cat_vae.yaml │ ├── cvae.yaml │ ├── dfc_vae.yaml │ ├── dip_vae.yaml │ ├── factorvae.yaml │ ├── gammavae.yaml │ ├── hvae.yaml │ ├── infovae.yaml │ ├── iwae.yaml │ ├── joint_vae.yaml │ ├── logcosh_vae.yaml │ ├── lvae.yaml │ ├── miwae.yaml │ ├── mssim_vae.yaml │ ├── swae.yaml │ ├── vae.yaml │ ├── vampvae.yaml │ ├── vq_vae.yaml │ ├── wae_mmd_imq.yaml │ └── wae_mmd_rbf.yaml ├── dataset.py ├── experiment.py ├── models/ │ ├── __init__.py │ ├── base.py │ ├── beta_vae.py │ ├── betatc_vae.py │ ├── cat_vae.py │ ├── cvae.py │ ├── dfcvae.py │ ├── dip_vae.py │ ├── fvae.py │ ├── gamma_vae.py │ ├── hvae.py │ ├── info_vae.py │ ├── iwae.py │ ├── joint_vae.py │ ├── logcosh_vae.py │ ├── lvae.py │ ├── miwae.py │ ├── mssim_vae.py │ ├── swae.py │ ├── twostage_vae.py │ ├── types_.py │ ├── vampvae.py │ ├── vanilla_vae.py │ ├── vq_vae.py │ └── wae_mmd.py ├── requirements.txt ├── run.py ├── tests/ │ ├── bvae.py │ ├── test_betatcvae.py │ ├── test_cat_vae.py │ ├── test_dfc.py │ ├── test_dipvae.py │ ├── test_fvae.py │ ├── test_gvae.py │ ├── test_hvae.py │ ├── test_iwae.py │ ├── test_joint_Vae.py │ ├── test_logcosh.py │ ├── test_lvae.py │ ├── test_miwae.py │ ├── test_mssimvae.py │ ├── test_swae.py │ ├── test_vae.py │ ├── test_vq_vae.py │ ├── test_wae.py │ ├── text_cvae.py │ └── text_vamp.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ Data/ logs/ VanillaVAE/version_0/ __pycache__/ .ipynb_checkpoints/ 
Run.ipynb ================================================ FILE: .idea/.gitignore ================================================ # Default ignored files /workspace.xml ================================================ FILE: .idea/PyTorch-VAE.iml ================================================ ================================================ FILE: .idea/inspectionProfiles/profiles_settings.xml ================================================ ================================================ FILE: .idea/misc.xml ================================================ ================================================ FILE: .idea/modules.xml ================================================ ================================================ FILE: .idea/vcs.xml ================================================ ================================================ FILE: LICENSE.md ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ Copyright Anand Krishnamoorthy Subramanian 2020 anandkrish894@gmail.com TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright {yyyy} {name of copyright owner} Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================

PyTorch VAE

**Update 22/12/2021:** Added support for PyTorch Lightning version 1.5.6 and cleaned up the code. A collection of Variational AutoEncoders (VAEs) implemented in PyTorch with a focus on reproducibility. The aim of this project is to provide a quick and simple working example for many of the cool VAE models out there. All the models are trained on the [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) for consistency and comparison. The architectures of all the models are kept as similar as possible with the same layers, except for cases where the original paper necessitates a radically different architecture (e.g., VQ VAE uses Residual layers and no Batch-Norm, unlike other models). Here are the [results](https://github.com/AntixK/PyTorch-VAE/blob/master/README.md#--results) of each model. ### Requirements - Python >= 3.5 - PyTorch >= 1.3 - PyTorch Lightning >= 0.6.0 ([GitHub Repo](https://github.com/PyTorchLightning/pytorch-lightning/tree/deb1581e26b7547baf876b7a94361e60bb200d32)) - CUDA-enabled computing device ### Installation ``` $ git clone https://github.com/AntixK/PyTorch-VAE $ cd PyTorch-VAE $ pip install -r requirements.txt ``` ### Usage ``` $ cd PyTorch-VAE $ python run.py -c configs/CONFIG_NAME.yaml ``` **Config file template** ```yaml model_params: name: "" in_channels: 3 latent_dim: . # Other parameters required by the model . . data_params: data_path: "" train_batch_size: 64 # Better to have a square number val_batch_size: 64 patch_size: 64 # Models are designed to work for this size num_workers: 4 exp_params: manual_seed: 1265 LR: 0.005 weight_decay: . # Other arguments required for training, like scheduler etc. . . trainer_params: gpus: 1 max_epochs: 100 gradient_clip_val: 1.5 . . . logging_params: save_dir: "logs/" name: "" ``` **View TensorBoard Logs** ``` $ cd logs/EXPERIMENT_NAME/version_N $ tensorboard --logdir . ``` **Note:** The default dataset is CelebA. 
However, there have been many issues with downloading the dataset from Google Drive (owing to some file structure changes). So, the recommendation is to download the [file](https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing) from Google Drive directly and extract it to the path of your choice. The default path assumed in the config files is `Data/celeba/img_align_celeba`. But you can change it according to your preference. ----

Results

| Model | Paper |Reconstruction | Samples | |------------------------------------------------------------------------|--------------------------------------------------|---------------|---------| | VAE ([Code][vae_code], [Config][vae_config]) |[Link](https://arxiv.org/abs/1312.6114) | ![][2] | ![][1] | | Conditional VAE ([Code][cvae_code], [Config][cvae_config]) |[Link](https://openreview.net/forum?id=rJWXGDWd-H)| ![][16] | ![][15] | | WAE - MMD (RBF Kernel) ([Code][wae_code], [Config][wae_rbf_config]) |[Link](https://arxiv.org/abs/1711.01558) | ![][4] | ![][3] | | WAE - MMD (IMQ Kernel) ([Code][wae_code], [Config][wae_imq_config]) |[Link](https://arxiv.org/abs/1711.01558) | ![][6] | ![][5] | | Beta-VAE ([Code][bvae_code], [Config][bbvae_config]) |[Link](https://openreview.net/forum?id=Sy2fzU9gl) | ![][8] | ![][7] | | Disentangled Beta-VAE ([Code][bvae_code], [Config][bhvae_config]) |[Link](https://arxiv.org/abs/1804.03599) | ![][22] | ![][21] | | Beta-TC-VAE ([Code][btcvae_code], [Config][btcvae_config]) |[Link](https://arxiv.org/abs/1802.04942) | ![][34] | ![][33] | | IWAE (*K = 5*) ([Code][iwae_code], [Config][iwae_config]) |[Link](https://arxiv.org/abs/1509.00519) | ![][10] | ![][9] | | MIWAE (*K = 5, M = 3*) ([Code][miwae_code], [Config][miwae_config]) |[Link](https://arxiv.org/abs/1802.04537) | ![][30] | ![][29] | | DFCVAE ([Code][dfcvae_code], [Config][dfcvae_config]) |[Link](https://arxiv.org/abs/1610.00291) | ![][12] | ![][11] | | MSSIM VAE ([Code][mssimvae_code], [Config][mssimvae_config]) |[Link](https://arxiv.org/abs/1511.06409) | ![][14] | ![][13] | | Categorical VAE ([Code][catvae_code], [Config][catvae_config]) |[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] | | Joint VAE ([Code][jointvae_code], [Config][jointvae_config]) |[Link](https://arxiv.org/abs/1804.00104) | ![][20] | ![][19] | | Info VAE ([Code][infovae_code], [Config][infovae_config]) |[Link](https://arxiv.org/abs/1706.02262) | ![][24] | ![][23] | | LogCosh VAE 
([Code][logcoshvae_code], [Config][logcoshvae_config]) |[Link](https://openreview.net/forum?id=rkglvsC9Ym)| ![][26] | ![][25] | | SWAE (200 Projections) ([Code][swae_code], [Config][swae_config]) |[Link](https://arxiv.org/abs/1804.01947) | ![][28] | ![][27] | | VQ-VAE (*K = 512, D = 64*) ([Code][vqvae_code], [Config][vqvae_config])|[Link](https://arxiv.org/abs/1711.00937) | ![][31] | **N/A** | | DIP VAE ([Code][dipvae_code], [Config][dipvae_config]) |[Link](https://arxiv.org/abs/1711.00848) | ![][36] | ![][35] | ### Contributing If you have trained a better model, using these implementations, by fine-tuning the hyper-params in the config file, I would be happy to include your result (along with your config file) in this repo, citing your name 😊. Additionally, if you would like to contribute some models, please submit a PR. ### License **Apache License 2.0** | Permissions | Limitations | Conditions | |------------------|-------------------|----------------------------------| | ✔️ Commercial use | ❌ Trademark use | ⓘ License and copyright notice | | ✔️ Modification | ❌ Liability | ⓘ State changes | | ✔️ Distribution | ❌ Warranty | | | ✔️ Patent use | | | | ✔️ Private use | | | ### Citation ``` @misc{Subramanian2020, author = {Subramanian, A.K}, title = {PyTorch-VAE}, year = {2020}, publisher = {GitHub}, journal = {GitHub repository}, howpublished = {\url{https://github.com/AntixK/PyTorch-VAE}} } ``` ----------- [vae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py [cvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/cvae.py [bvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py [btcvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/betatc_vae.py [wae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/wae_mmd.py [iwae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/iwae.py [miwae_code]: 
https://github.com/AntixK/PyTorch-VAE/blob/master/models/miwae.py [swae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/swae.py [jointvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/joint_vae.py [dfcvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/dfcvae.py [mssimvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/mssim_vae.py [logcoshvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/logcosh_vae.py [catvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/cat_vae.py [infovae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/info_vae.py [vqvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py [dipvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/dip_vae.py [vae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vae.yaml [cvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cvae.yaml [bbvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/bbvae.yaml [bhvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/bhvae.yaml [btcvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/betatc_vae.yaml [wae_rbf_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_rbf.yaml [wae_imq_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_imq.yaml [iwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/iwae.yaml [miwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/miwae.yaml [swae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/swae.yaml [jointvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/joint_vae.yaml [dfcvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dfc_vae.yaml [mssimvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/mssim_vae.yaml [logcoshvae_config]: 
https://github.com/AntixK/PyTorch-VAE/blob/master/configs/logcosh_vae.yaml [catvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cat_vae.yaml [infovae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/infovae.yaml [vqvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vq_vae.yaml [dipvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dip_vae.yaml [1]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/Vanilla%20VAE_25.png [2]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_Vanilla%20VAE_25.png [3]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/WAE_RBF_18.png [4]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_WAE_RBF_19.png [5]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/WAE_IMQ_15.png [6]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_WAE_IMQ_15.png [7]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaVAE_H_20.png [8]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaVAE_H_20.png [9]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/IWAE_19.png [10]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_IWAE_19.png [11]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/DFCVAE_49.png [12]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_DFCVAE_49.png [13]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MSSIMVAE_29.png [14]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MSSIMVAE_29.png [15]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/ConditionalVAE_20.png [16]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_ConditionalVAE_20.png [17]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/CategoricalVAE_49.png [18]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_CategoricalVAE_49.png [19]: 
https://github.com/AntixK/PyTorch-VAE/blob/master/assets/JointVAE_49.png [20]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_JointVAE_49.png [21]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaVAE_B_35.png [22]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaVAE_B_35.png [23]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/InfoVAE_31.png [24]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_InfoVAE_31.png [25]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/LogCoshVAE_49.png [26]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_LogCoshVAE_49.png [27]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/SWAE_49.png [28]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_SWAE_49.png [29]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MIWAE_29.png [30]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MIWAE_29.png [31]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_VQVAE_29.png [33]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaTCVAE_49.png [34]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaTCVAE_49.png [35]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/DIPVAE_83.png [36]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_DIPVAE_83.png [python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg [python-url]: https://www.python.org/ [pytorch-image]: https://img.shields.io/badge/PyTorch-1.3-2BAF2B.svg [pytorch-url]: https://pytorch.org/ [twitter-image]:https://img.shields.io/twitter/url/https/shields.io.svg?style=social [twitter-url]:https://twitter.com/intent/tweet?text=Neural%20Blocks-Easy%20to%20use%20neural%20net%20blocks%20for%20fast%20prototyping.&url=https://github.com/AntixK/NeuralBlocks [license-image]:https://img.shields.io/badge/license-Apache2.0-blue.svg 
[license-url]:https://github.com/AntixK/PyTorch-VAE/blob/master/LICENSE.md ================================================ FILE: configs/bbvae.yaml ================================================ model_params: name: 'BetaVAE' in_channels: 3 latent_dim: 128 loss_type: 'B' gamma: 10.0 max_capacity: 25 Capacity_max_iter: 10000 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" manual_seed: 1265 name: 'BetaVAE' ================================================ FILE: configs/betatc_vae.yaml ================================================ model_params: name: 'BetaTCVAE' in_channels: 3 latent_dim: 10 anneal_steps: 10000 alpha: 1. beta: 6. gamma: 1. data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: 'BetaTCVAE' ================================================ FILE: configs/bhvae.yaml ================================================ model_params: name: 'BetaVAE' in_channels: 3 latent_dim: 128 loss_type: 'H' beta: 10. 
data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: 'BetaVAE' ================================================ FILE: configs/cat_vae.yaml ================================================ model_params: name: 'CategoricalVAE' in_channels: 3 latent_dim: 512 categorical_dim: 40 temperature: 0.5 anneal_rate: 0.00003 anneal_interval: 100 alpha: 1.0 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "CategoricalVAE" ================================================ FILE: configs/cvae.yaml ================================================ model_params: name: 'ConditionalVAE' in_channels: 3 num_classes: 40 latent_dim: 128 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "ConditionalVAE" ================================================ FILE: configs/dfc_vae.yaml ================================================ model_params: name: 'DFCVAE' in_channels: 3 latent_dim: 128 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "DFCVAE" ================================================ FILE: configs/dip_vae.yaml ================================================ model_params: 
name: 'DIPVAE' in_channels: 3 latent_dim: 128 lambda_diag: 0.05 lambda_offdiag: 0.1 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.001 weight_decay: 0.0 scheduler_gamma: 0.97 kld_weight: 1 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "DIPVAE" manual_seed: 1265 ================================================ FILE: configs/factorvae.yaml ================================================ model_params: name: 'FactorVAE' in_channels: 3 latent_dim: 128 gamma: 6.4 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: submodel: 'discriminator' retain_first_backpass: True LR: 0.005 weight_decay: 0.0 LR_2: 0.005 scheduler_gamma_2: 0.95 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "FactorVAE" ================================================ FILE: configs/gammavae.yaml ================================================ model_params: name: 'GammaVAE' in_channels: 3 latent_dim: 128 gamma_shape: 8. prior_shape: 2. prior_rate: 1. 
data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.003 weight_decay: 0.00005 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 gradient_clip_val: 0.8 logging_params: save_dir: "logs/" name: "GammaVAE" ================================================ FILE: configs/hvae.yaml ================================================ model_params: name: 'HVAE' in_channels: 3 latent1_dim: 64 latent2_dim: 64 pseudo_input_size: 128 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "VampVAE" ================================================ FILE: configs/infovae.yaml ================================================ model_params: name: 'InfoVAE' in_channels: 3 latent_dim: 128 reg_weight: 110 # MMD weight kernel_type: 'imq' alpha: -9.0 # KLD weight beta: 10.5 # Reconstruction weight data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 gradient_clip_val: 0.8 logging_params: save_dir: "logs/" name: "InfoVAE" manual_seed: 1265 ================================================ FILE: configs/iwae.yaml ================================================ model_params: name: 'IWAE' in_channels: 3 latent_dim: 128 num_samples: 5 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.007 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "IWAE" ================================================ 
FILE: configs/joint_vae.yaml ================================================ model_params: name: 'JointVAE' in_channels: 3 latent_dim: 512 categorical_dim: 40 latent_min_capacity: 0.0 latent_max_capacity: 20.0 latent_gamma: 10. latent_num_iter: 25000 categorical_min_capacity: 0.0 categorical_max_capacity: 20.0 categorical_gamma: 10. categorical_num_iter: 25000 temperature: 0.5 anneal_rate: 0.00003 anneal_interval: 100 alpha: 10.0 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "JointVAE" ================================================ FILE: configs/logcosh_vae.yaml ================================================ model_params: name: 'LogCoshVAE' in_channels: 3 latent_dim: 128 alpha: 10.0 beta: 1.0 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.97 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "LogCoshVAE" ================================================ FILE: configs/lvae.yaml ================================================ model_params: name: 'LVAE' in_channels: 3 latent_dims: [4,8,16,32,128] hidden_dims: [32, 64,128, 256, 512] data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "LVAE" ================================================ FILE: configs/miwae.yaml ================================================ model_params: name: 'MIWAE' in_channels: 3 latent_dim: 128 num_samples: 5 num_estimates: 3 data_params: data_path: 
"Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "MIWAE" ================================================ FILE: configs/mssim_vae.yaml ================================================ model_params: name: 'MSSIMVAE' in_channels: 3 latent_dim: 128 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "MSSIMVAE" ================================================ FILE: configs/swae.yaml ================================================ model_params: name: 'SWAE' in_channels: 3 latent_dim: 128 reg_weight: 100 wasserstein_deg: 2.0 num_projections: 200 projection_dist: "normal" #"cauchy" data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "SWAE" ================================================ FILE: configs/vae.yaml ================================================ model_params: name: 'VanillaVAE' in_channels: 3 latent_dim: 128 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 100 logging_params: save_dir: "logs/" name: "VanillaVAE" ================================================ FILE: configs/vampvae.yaml ================================================ model_params: name: 'VampVAE' in_channels: 3 latent_dim: 128 exp_params: 
dataset: celeba data_path: "../../shared/Data/" img_size: 64 batch_size: 144 # Better to have a square number LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 trainer_params: gpus: 1 max_nb_epochs: 50 max_epochs: 50 logging_params: save_dir: "logs/" name: "VampVAE" manual_seed: 1265 ================================================ FILE: configs/vq_vae.yaml ================================================ model_params: name: 'VQVAE' in_channels: 3 embedding_dim: 64 num_embeddings: 512 img_size: 64 beta: 0.25 data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.0 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: 'VQVAE' ================================================ FILE: configs/wae_mmd_imq.yaml ================================================ model_params: name: 'WAE_MMD' in_channels: 3 latent_dim: 128 reg_weight: 100 kernel_type: 'imq' data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "WassersteinVAE_IMQ" ================================================ FILE: configs/wae_mmd_rbf.yaml ================================================ model_params: name: 'WAE_MMD' in_channels: 3 latent_dim: 128 reg_weight: 5000 kernel_type: 'rbf' data_params: data_path: "Data/" train_batch_size: 64 val_batch_size: 64 patch_size: 64 num_workers: 4 exp_params: LR: 0.005 weight_decay: 0.0 scheduler_gamma: 0.95 kld_weight: 0.00025 manual_seed: 1265 trainer_params: gpus: [1] max_epochs: 10 logging_params: save_dir: "logs/" name: "WassersteinVAE_RBF" ================================================ FILE: dataset.py ================================================ import os import 
class OxfordPets(Dataset):
    """
    Image dataset read from the ``OxfordPets`` folder below ``data_path``.

    URL = https://www.robots.ox.ac.uk/~vgg/data/pets/

    Only files whose suffix is exactly ``.jpg`` are used. They are sorted and
    the first 75% form the ``"train"`` split; any other ``split`` value
    selects the remaining 25%.
    """

    def __init__(self,
                 data_path: str,
                 split: str,
                 transform: Callable,
                 **kwargs):
        self.data_dir = Path(data_path) / "OxfordPets"
        self.transforms = transform

        jpg_files = sorted(entry for entry in self.data_dir.iterdir()
                           if entry.suffix == '.jpg')
        # Index separating the training files from the held-out files.
        cut = int(len(jpg_files) * 0.75)
        if split == "train":
            self.imgs = jpg_files[:cut]
        else:
            self.imgs = jpg_files[cut:]

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = default_loader(self.imgs[idx])
        if self.transforms is not None:
            image = self.transforms(image)
        # The second element is dummy data so that (input, label) unpacking
        # in the training code keeps working.
        return image, 0.0
""" def __init__( self, data_path: str, train_batch_size: int = 8, val_batch_size: int = 8, patch_size: Union[int, Sequence[int]] = (256, 256), num_workers: int = 0, pin_memory: bool = False, **kwargs, ): super().__init__() self.data_dir = data_path self.train_batch_size = train_batch_size self.val_batch_size = val_batch_size self.patch_size = patch_size self.num_workers = num_workers self.pin_memory = pin_memory def setup(self, stage: Optional[str] = None) -> None: # ========================= OxfordPets Dataset ========================= # train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(), # transforms.CenterCrop(self.patch_size), # # transforms.Resize(self.patch_size), # transforms.ToTensor(), # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) # val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(), # transforms.CenterCrop(self.patch_size), # # transforms.Resize(self.patch_size), # transforms.ToTensor(), # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) # self.train_dataset = OxfordPets( # self.data_dir, # split='train', # transform=train_transforms, # ) # self.val_dataset = OxfordPets( # self.data_dir, # split='val', # transform=val_transforms, # ) # ========================= CelebA Dataset ========================= train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.CenterCrop(148), transforms.Resize(self.patch_size), transforms.ToTensor(),]) val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.CenterCrop(148), transforms.Resize(self.patch_size), transforms.ToTensor(),]) self.train_dataset = MyCelebA( self.data_dir, split='train', transform=train_transforms, download=False, ) # Replace CelebA with your dataset self.val_dataset = MyCelebA( self.data_dir, split='test', transform=val_transforms, download=False, ) # =============================================================== def train_dataloader(self) -> DataLoader: 
def __init__(self,
             vae_model: BaseVAE,
             params: dict) -> None:
    """
    Wraps a VAE model for training.

    :param vae_model: model stored as ``self.model``
    :param params: experiment settings stored as ``self.params``; the
                   optional key ``'retain_first_backpass'`` sets
                   ``self.hold_graph``
    """
    super(VAEXperiment, self).__init__()

    self.model = vae_model
    self.params = params
    self.curr_device = None
    # 'retain_first_backpass' is optional. Only a missing key (or params not
    # being subscriptable) falls back to False; a bare ``except`` would also
    # hide unrelated failures such as KeyboardInterrupt.
    try:
        self.hold_graph = self.params['retain_first_backpass']
    except (KeyError, TypeError):
        self.hold_graph = False
def configure_optimizers(self):
    """
    Builds the Adam optimizer(s) and optional exponential LR scheduler(s).

    Required keys in ``self.params``: ``'LR'`` and ``'weight_decay'``.
    Optional keys (a missing key and a ``None`` value are treated alike):

    * ``'LR_2'`` + ``'submodel'``: adds a second Adam optimizer over the
      parameters of ``getattr(self.model, params['submodel'])`` (used for
      adversarial training). It is skipped when either key is absent or the
      model has no such attribute.
    * ``'scheduler_gamma'``: adds an ``ExponentialLR`` for the first optimizer.
    * ``'scheduler_gamma_2'``: adds an ``ExponentialLR`` for the second
      optimizer, only if that optimizer exists.

    :return: ``(optimizers, schedulers)`` when ``'scheduler_gamma'`` is set,
             otherwise just the list of optimizers.
    """
    params = self.params

    optims = [optim.Adam(self.model.parameters(),
                         lr=params['LR'],
                         weight_decay=params['weight_decay'])]

    # Check if more than 1 optimizer is required (Used for adversarial training)
    if params.get('LR_2') is not None and params.get('submodel') is not None:
        submodel = getattr(self.model, params['submodel'], None)
        if submodel is not None:
            optims.append(optim.Adam(submodel.parameters(),
                                     lr=params['LR_2']))

    # Without a scheduler, always hand back the optimizers. Previously an
    # explicit ``scheduler_gamma: None`` fell through and returned None.
    if params.get('scheduler_gamma') is None:
        return optims

    scheds = [optim.lr_scheduler.ExponentialLR(optims[0],
                                               gamma=params['scheduler_gamma'])]

    # Check if another scheduler is required for the second optimizer
    if params.get('scheduler_gamma_2') is not None and len(optims) > 1:
        scheds.append(optim.lr_scheduler.ExponentialLR(optims[1],
                                                       gamma=params['scheduler_gamma_2']))
    return optims, scheds
def __init__(self,
             in_channels: int,
             latent_dim: int,
             hidden_dims: List = None,
             beta: int = 4,
             gamma:float = 1000.,
             max_capacity: int = 25,
             Capacity_max_iter: int = 1e5,
             loss_type:str = 'B',
             **kwargs) -> None:
    """
    Builds the convolutional encoder/decoder and stores the loss settings.

    :param in_channels: (Int) Number of channels of the input images
    :param latent_dim: (Int) Size of the latent code
    :param hidden_dims: (List) Channel widths of the encoder, one stride-2
                        conv block per entry. Defaults to
                        [32, 64, 128, 256, 512]. The list is copied, so the
                        caller's object is never modified.
    :param beta: KLD weight used by the 'H' loss
    :param gamma: Weight of the capacity term used by the 'B' loss
    :param max_capacity: Upper bound of the capacity, stored as ``C_max``
    :param Capacity_max_iter: Stored as ``C_stop_iter``; divides ``C_max`` in
                              the capacity schedule of the 'B' loss
    :param loss_type: (Str) 'H' or 'B', selects the loss in ``loss_function``
    """
    super(BetaVAE, self).__init__()

    self.latent_dim = latent_dim
    self.beta = beta
    self.gamma = gamma
    self.loss_type = loss_type
    self.C_max = torch.Tensor([max_capacity])
    self.C_stop_iter = Capacity_max_iter

    # Work on a private copy: the list is reversed below for the decoder,
    # and reversing the caller's list in place would corrupt it for reuse.
    if hidden_dims is None:
        hidden_dims = [32, 64, 128, 256, 512]
    else:
        hidden_dims = list(hidden_dims)

    # Build Encoder
    modules = []
    for h_dim in hidden_dims:
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels=h_dim,
                          kernel_size= 3, stride= 2, padding = 1),
                nn.BatchNorm2d(h_dim),
                nn.LeakyReLU())
        )
        in_channels = h_dim

    self.encoder = nn.Sequential(*modules)
    # The factor 4 means the encoder output is expected to be a 2 x 2 map.
    self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
    self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)

    # Build Decoder
    modules = []

    self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

    hidden_dims.reverse()

    for i in range(len(hidden_dims) - 1):
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hidden_dims[i],
                                   hidden_dims[i + 1],
                                   kernel_size=3,
                                   stride = 2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(hidden_dims[i + 1]),
                nn.LeakyReLU())
        )

    self.decoder = nn.Sequential(*modules)

    self.final_layer = nn.Sequential(
                        nn.ConvTranspose2d(hidden_dims[-1],
                                           hidden_dims[-1],
                                           kernel_size=3,
                                           stride=2,
                                           padding=1,
                                           output_padding=1),
                        nn.BatchNorm2d(hidden_dims[-1]),
                        nn.LeakyReLU(),
                        nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                  kernel_size= 3, padding= 1),
                        nn.Tanh())
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    """
    Computes the Beta-VAE loss.

    :param args: ``[recons, input, mu, log_var]`` as returned by ``forward``
    :param kwargs: must contain ``'M_N'``, the weight applied to the KLD term
    :return: (dict) with keys ``'loss'``, ``'Reconstruction_Loss'`` and
             ``'KLD'``
    :raises ValueError: if ``self.loss_type`` is neither 'H' nor 'B'
    """
    # The capacity schedule of the 'B' loss is driven by the number of
    # training iterations, so evaluation calls must not advance the counter.
    if self.training:
        self.num_iter += 1

    recons = args[0]
    input = args[1]
    mu = args[2]
    log_var = args[3]
    kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

    recons_loss = F.mse_loss(recons, input)

    # Analytic KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions
    # and averaged over the batch.
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

    if self.loss_type == 'H': # https://openreview.net/forum?id=Sy2fzU9gl
        loss = recons_loss + self.beta * kld_weight * kld_loss
    elif self.loss_type == 'B': # https://arxiv.org/pdf/1804.03599.pdf
        self.C_max = self.C_max.to(input.device)
        # Capacity grows linearly with the iteration count, capped at C_max.
        C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
        loss = recons_loss + self.gamma * kld_weight* (kld_loss - C).abs()
    else:
        raise ValueError('Undefined loss type.')

    return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}
def generate(self, x: Tensor, **kwargs) -> Tensor:
    """
    Runs the full forward pass on x and keeps only the reconstruction,
    i.e. the first element of the forward output. Extra keyword
    arguments are accepted but not used.
    :param x: (Tensor) [B x C x H x W]
    :return: (Tensor) [B x C x H x W]
    """
    outputs = self.forward(x)
    reconstruction = outputs[0]
    return reconstruction
def encode(self, input: Tensor) -> List[Tensor]:
    """
    Runs the encoder, flattens everything but the batch dimension, applies
    the shared fully connected layer and then the two output heads.
    :param input: (Tensor) Input tensor to encoder [N x C x H x W]
    :return: (List[Tensor]) [mu, log_var] of the latent Gaussian
    """
    features = torch.flatten(self.encoder(input), start_dim=1)
    hidden = self.fc(features)

    # Two separate heads read the same hidden representation.
    mean = self.fc_mu(hidden)
    log_variance = self.fc_var(hidden)
    return [mean, log_variance]
def log_density_gaussian(self, x: Tensor, mu: Tensor, logvar: Tensor):
    """
    Element-wise log pdf of a Gaussian with mean mu and log-variance
    logvar, evaluated at x. Arguments follow the usual broadcasting rules.
    :param x: (Tensor) Point at which the Gaussian PDF is to be evaluated
    :param mu: (Tensor) Mean of the Gaussian distribution
    :param logvar: (Tensor) Log variance of the Gaussian distribution
    :return: (Tensor) Log density, one value per broadcast element
    """
    # -0.5 * (log(2 * pi) + log(sigma^2)): the normalisation constant.
    log_normalizer = - 0.5 * (math.log(2 * math.pi) + logvar)
    squared_error = (x - mu) ** 2
    inverse_variance = torch.exp(-logvar)
    return log_normalizer - 0.5 * (squared_error * inverse_variance)
def sample(self,
           num_samples:int,
           current_device: int, **kwargs) -> Tensor:
    """
    Draws standard-normal latent codes and decodes them into image space.
    :param num_samples: (Int) Number of samples
    :param current_device: (Int) Device to run the model
    :return: (Tensor) Output of the decoder for the sampled codes
    """
    # The codes are drawn first and then moved to the target device.
    latent_codes = torch.randn(num_samples, self.latent_dim).to(current_device)
    return self.decode(latent_codes)
def reparameterize(self, z: Tensor, eps:float = 1e-7) -> Tensor:
    """
    Gumbel-softmax trick to sample from a Categorical distribution.
    The softmax is taken over the last dimension at temperature
    self.temp and the result is flattened per sample.
    :param z: (Tensor) Latent Codes [B x D x Q]
    :param eps: (Float) Small constant keeping both logarithms finite
    :return: (Tensor) [B x (D * Q)]
    """
    # Gumbel(0, 1) noise obtained from uniform noise.
    uniform_noise = torch.rand_like(z)
    gumbel_noise = - torch.log(- torch.log(uniform_noise + eps) + eps)

    # Temperature-scaled softmax over the perturbed logits.
    perturbed = (z + gumbel_noise) / self.temp
    relaxed = F.softmax(perturbed, dim=-1)
    return relaxed.view(-1, self.latent_dim * self.categorical_dim)
def sample(self,
           num_samples:int,
           current_device: int, **kwargs) -> Tensor:
    """
    Draws uniformly random one-hot latent codes and decodes them into
    image space.
    :param num_samples: (Int) Number of samples
    :param current_device: (Int) Device to run the model
    :return: (Tensor) Output of the decoder for the sampled codes
    """
    # One category index per latent variable of every sample.
    total = num_samples * self.latent_dim
    picks = np.random.choice(self.categorical_dim, total)

    # Rows of the identity matrix are the one-hot vectors. [S x D x Q]
    one_hot = np.eye(self.categorical_dim, dtype=np.float32)[picks]
    one_hot = one_hot.reshape(total // self.latent_dim,
                              self.latent_dim,
                              self.categorical_dim)

    codes = torch.from_numpy(one_hot)
    codes = codes.view(num_samples, self.latent_dim * self.categorical_dim)
    return self.decode(codes.to(current_device))
def decode(self, z: Tensor) -> Tensor:
    """
    Maps the given latent codes (with the class labels already
    concatenated) onto the image space.

    :param z: (Tensor) [B x (latent_dim + num_classes)]
    :return: (Tensor) Output of the final decoder layer
    """
    result = self.decoder_input(z)
    # decoder_input produces C * 2 * 2 features, where C is the last encoder
    # width. Derive C from the layer instead of hard-coding 512 so that a
    # custom `hidden_dims` does not break the reshape.
    channels = self.decoder_input.out_features // 4
    result = result.view(-1, channels, 2, 2)
    result = self.decoder(result)
    result = self.final_layer(result)
    return result
def sample(self,
           num_samples: int,
           current_device: int,
           **kwargs) -> Tensor:
    """
    Samples from the latent space, conditions the codes on the given class
    labels and returns the corresponding image space map.

    :param num_samples: (Int) Number of samples
    :param current_device: (Int) Device to run the model
    :param kwargs: must contain ``labels``; it is concatenated to the latent
                   codes along dim 1, so it needs ``num_samples`` rows
    :return: (Tensor)
    """
    # The labels arrive on whatever device the caller holds them on; move
    # them next to z, otherwise torch.cat fails with a device mismatch.
    y = kwargs['labels'].float().to(current_device)

    z = torch.randn(num_samples, self.latent_dim)
    z = z.to(current_device)

    z = torch.cat([z, y], dim=1)
    samples = self.decode(z)
    return samples
def extract_features(self,
                     input: Tensor,
                     feature_layers: List = None) -> List[Tensor]:
    """
    Extracts the features from the pretrained model
    at the layers indicated by feature_layers.

    :param input: (Tensor) [B x C x H x W]
    :param feature_layers: List of string of IDs; defaults to
                           ['14', '24', '34', '43']
    :return: List of the extracted features, in network order
    """
    if feature_layers is None:
        feature_layers = ['14', '24', '34', '43']

    # Number of distinct layers we are waiting for; used to stop early.
    num_wanted = len(set(feature_layers))

    features = []
    result = input
    for (key, module) in self.feature_network.features._modules.items():
        result = module(result)
        if key in feature_layers:
            features.append(result)
            # Everything requested has been collected: the remaining layers
            # of the feature network would only burn compute.
            if len(features) == num_wanted:
                break

    return features
def generate(self, x: Tensor, **kwargs) -> Tensor:
    """
    Runs the input through the full model and returns only the
    reconstruction, i.e. the first entry of the forward output.

    :param x: (Tensor) [B x C x H x W]
    :return: (Tensor) [B x C x H x W]
    """
    outputs = self.forward(x)
    reconstruction = outputs[0]
    return reconstruction
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    r"""
    Computes the DIP-VAE loss: summed reconstruction error, weighted KL
    divergence and the DIP covariance regulariser.

    KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}

    The regulariser pushes the covariance of the aggregate posterior,
    Cov[z] = Cov_x[mu(x)] + E_x[diag(sigma^2(x))] (DIP-VAE II), towards the
    identity matrix.

    :param args: recons, input, mu [B x D], log_var [B x D]
    :param kwargs: must contain ``M_N``, the weight of the KL term
    :return: dict with 'loss', 'Reconstruction_Loss', 'KLD' and 'DIP_Loss'
    """
    recons = args[0]
    input = args[1]
    mu = args[2]
    log_var = args[3]

    kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
    recons_loss = F.mse_loss(recons, input, reduction='sum')

    kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

    # DIP Loss
    batch_size = mu.size(0)
    # Centre every latent dimension over the *batch* (dim 0); centring over
    # dim 1 would subtract each sample's own mean across dimensions.
    centered_mu = mu - mu.mean(dim=0, keepdim=True)  # [B x D]
    # Empirical covariance of the means, normalised by the batch size.
    cov_mu = centered_mu.t().matmul(centered_mu) / batch_size  # [D x D]

    # Add Variance for DIP Loss II: the expected encoder variance
    # (sigma^2 = exp(log_var)) contributes to the diagonal only.
    cov_z = cov_mu + torch.diag(log_var.exp().mean(dim=0))  # [D x D]
    # For DIP Loss I
    # cov_z = cov_mu

    cov_diag = torch.diag(cov_z)  # [D]
    cov_offdiag = cov_z - torch.diag(cov_diag)  # [D x D]
    dip_loss = self.lambda_offdiag * torch.sum(cov_offdiag ** 2) + \
               self.lambda_diag * torch.sum((cov_diag - 1) ** 2)

    loss = recons_loss + kld_weight * kld_loss + dip_loss
    return {'loss': loss,
            'Reconstruction_Loss': recons_loss,
            'KLD': -kld_loss,
            'DIP_Loss': dip_loss}
def __init__(self,
             in_channels: int,
             latent_dim: int,
             hidden_dims: List = None,
             gamma: float = 40.,
             **kwargs) -> None:
    """
    Builds the convolutional encoder/decoder pair and the discriminator
    used for the total-correlation term.

    :param in_channels: (Int) Number of channels of the input images
    :param latent_dim: (Int) Dimensionality of the latent code
    :param hidden_dims: (List) Encoder channel widths; defaults to
                        [32, 64, 128, 256, 512]. The list is not modified.
    :param gamma: (Float) Weight of the total-correlation loss
    """
    super(FactorVAE, self).__init__()

    self.latent_dim = latent_dim
    self.gamma = gamma

    # Work on a private copy: the decoder needs the widths in reverse order
    # and reversing the caller's list in place would leak back into the
    # object (e.g. a config entry) it was passed from.
    if hidden_dims is None:
        hidden_dims = [32, 64, 128, 256, 512]
    else:
        hidden_dims = list(hidden_dims)

    # Build Encoder
    modules = []
    for h_dim in hidden_dims:
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels=h_dim,
                          kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(h_dim),
                nn.LeakyReLU())
        )
        in_channels = h_dim

    self.encoder = nn.Sequential(*modules)
    self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
    self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

    # Build Decoder
    modules = []

    self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

    decoder_dims = hidden_dims[::-1]

    for width_in, width_out in zip(decoder_dims[:-1], decoder_dims[1:]):
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(width_in,
                                   width_out,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(width_out),
                nn.LeakyReLU())
        )

    self.decoder = nn.Sequential(*modules)

    self.final_layer = nn.Sequential(
        nn.ConvTranspose2d(decoder_dims[-1],
                           decoder_dims[-1],
                           kernel_size=3,
                           stride=2,
                           padding=1,
                           output_padding=1),
        nn.BatchNorm2d(decoder_dims[-1]),
        nn.LeakyReLU(),
        nn.Conv2d(decoder_dims[-1], out_channels=3,
                  kernel_size=3, padding=1),
        nn.Tanh())

    # Discriminator network for the Total Correlation (TC) loss
    self.discriminator = nn.Sequential(nn.Linear(self.latent_dim, 1000),
                                       nn.BatchNorm1d(1000),
                                       nn.LeakyReLU(0.2),
                                       nn.Linear(1000, 1000),
                                       nn.BatchNorm1d(1000),
                                       nn.LeakyReLU(0.2),
                                       nn.Linear(1000, 1000),
                                       nn.BatchNorm1d(1000),
                                       nn.LeakyReLU(0.2),
                                       nn.Linear(1000, 2))
    # Discriminator logits of the latest VAE step (set in loss_function).
    self.D_z_reserve = None
def permute_latent(self, z: Tensor) -> Tensor:
    """
    Shuffles every latent dimension independently across the batch, so
    that each column of the result is a permutation of the same column
    of z. This yields samples from the product of the marginals of q(z),
    which the discriminator contrasts with the joint samples.

    :param z: [B x D]
    :return: [B x D]
    """
    B, D = z.size()

    # One independent permutation of the batch indices per dimension.
    inds = torch.stack([torch.randperm(B, device=z.device) for _ in range(D)],
                       dim=1)  # [B x D]
    # out[i, j] = z[inds[i, j], j]
    return torch.gather(z, 0, inds)


def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    r"""
    Computes the FactorVAE losses. Which loss is returned depends on
    ``optimizer_idx``: 0 updates the VAE, 1 updates the discriminator.

    KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}

    :param args: recons, input, mu, log_var, z
    :param kwargs: must contain ``M_N`` (weight of the KL term) and
                   ``optimizer_idx``
    :return: dict holding 'loss' plus the individual terms of the branch
    :raises ValueError: if ``optimizer_idx`` is neither 0 nor 1
    """
    recons = args[0]
    input = args[1]
    mu = args[2]
    log_var = args[3]
    z = args[4]

    kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
    optimizer_idx = kwargs['optimizer_idx']

    # Update the VAE
    if optimizer_idx == 0:
        recons_loss = F.mse_loss(recons, input)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        self.D_z_reserve = self.discriminator(z)
        vae_tc_loss = (self.D_z_reserve[:, 0] - self.D_z_reserve[:, 1]).mean()

        loss = recons_loss + kld_weight * kld_loss + self.gamma * vae_tc_loss

        return {'loss': loss,
                'Reconstruction_Loss': recons_loss,
                'KLD': -kld_loss,
                'VAE_TC_Loss': vae_tc_loss}

    # Update the Discriminator
    if optimizer_idx == 1:
        device = input.device
        true_labels = torch.ones(input.size(0), dtype=torch.long,
                                 requires_grad=False).to(device)
        false_labels = torch.zeros(input.size(0), dtype=torch.long,
                                   requires_grad=False).to(device)

        z = z.detach()  # Detach so that VAE is not trained again
        # Score the joint samples afresh. Reusing self.D_z_reserve would
        # backpropagate through the graph of the VAE step, which has
        # already been consumed and is still attached to the encoder.
        D_z = self.discriminator(z)
        z_perm = self.permute_latent(z)
        D_z_perm = self.discriminator(z_perm)

        D_tc_loss = 0.5 * (F.cross_entropy(D_z, false_labels) +
                           F.cross_entropy(D_z_perm, true_labels))
        return {'loss': D_tc_loss,
                'D_TC_Loss': D_tc_loss}

    raise ValueError('optimizer_idx must be 0 (VAE) or 1 (discriminator), '
                     'got {!r}.'.format(optimizer_idx))
def weight_init(self):
    """
    Applies ``init_`` to every layer of the model.

    ``nn.Module.apply`` walks the module tree recursively. A one-level walk
    over the top-level blocks is not enough here: the encoder and decoder
    are Sequentials of Sequentials, so their convolution and batch-norm
    layers sit one level deeper and would never be handed to ``init_``.
    """
    self.apply(init_)
def h_func(self, alpha: Tensor, eps: Tensor) -> Tensor:
    """
    Transforms eps into z = (alpha - 1/3) * (1 + eps / sqrt(9 * alpha - 3))^3,
    the map used to reparameterize a sample eps ~ N(0, 1) so that
    h(z) ~ Gamma(alpha, 1).

    :param alpha: (Tensor) Shape parameter
    :param eps: (Tensor) Random sample to reparameterize
    :return: (Tensor)
    """
    shifted_shape = alpha - 1. / 3.
    scale = torch.sqrt(9. * alpha - 3.)
    return shifted_shape * (1 + eps / scale) ** 3
""" a = 1 / a c = 1 / c losses = self.I_function(c, d, c, d) - self.I_function(a, b, c, d) return torch.sum(losses, dim=1) def loss_function(self, *args, **kwargs) -> dict: recons = args[0] input = args[1] alpha = args[2] beta = args[3] curr_device = input.device kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset recons_loss = torch.mean(F.mse_loss(recons, input, reduction = 'none'), dim = (1,2,3)) # https://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions # alpha = 1./ alpha self.prior_alpha = self.prior_alpha.to(curr_device) self.prior_beta = self.prior_beta.to(curr_device) # kld_loss = - self.I_function(alpha, beta, self.prior_alpha, self.prior_beta) kld_loss = self.vae_gamma_kl_loss(alpha, beta, self.prior_alpha, self.prior_beta) # kld_loss = torch.sum(kld_loss, dim=1) loss = recons_loss + kld_loss loss = torch.mean(loss, dim = 0) # print(loss, recons_loss, kld_loss) return {'loss': loss} #, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss} def sample(self, num_samples:int, current_device: int, **kwargs) -> Tensor: """ Samples from the latent space and return the corresponding image space map. 
def init_(m):
    """
    Initialises a single layer in place.

    Linear and Conv2d layers get orthogonal weights, BatchNorm1d/2d layers
    get unit weights; in both cases an existing bias is zeroed. Any other
    module is left untouched.

    :param m: (nn.Module) Layer to initialise
    """
    is_linear_like = isinstance(m, (nn.Linear, nn.Conv2d))
    is_batch_norm = isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d))
    if not (is_linear_like or is_batch_norm):
        return

    if is_linear_like:
        init.orthogonal_(m.weight)
    else:
        m.weight.data.fill_(1)

    if m.bias is not None:
        m.bias.data.fill_(0)
[] channels = in_channels + 1 # One more channel for the latent code for h_dim in hidden_dims: modules.append( nn.Sequential( nn.Conv2d(channels, out_channels=h_dim, kernel_size= 3, stride= 2, padding = 1), nn.BatchNorm2d(h_dim), nn.LeakyReLU()) ) channels = h_dim self.encoder_z1_layers = nn.Sequential(*modules) self.fc_z1_mu = nn.Linear(hidden_dims[-1]*4, latent1_dim) self.fc_z1_var = nn.Linear(hidden_dims[-1]*4, latent1_dim) #========================================================================# # Build z2 Decoder self.recons_z1_mu = nn.Linear(latent2_dim, latent1_dim) self.recons_z1_log_var = nn.Linear(latent2_dim, latent1_dim) # ========================================================================# # Build z1 Decoder self.debed_z1_code = nn.Linear(latent1_dim, 1024) self.debed_z2_code = nn.Linear(latent2_dim, 1024) modules = [] hidden_dims.reverse() for i in range(len(hidden_dims) - 1): modules.append( nn.Sequential( nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride = 2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[i + 1]), nn.LeakyReLU()) ) self.decoder = nn.Sequential(*modules) self.final_layer = nn.Sequential( nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[-1]), nn.LeakyReLU(), nn.Conv2d(hidden_dims[-1], out_channels= 3, kernel_size= 3, padding= 1), nn.Tanh()) # ========================================================================# # Pesudo Input for the Vamp-Prior # self.pseudo_input = torch.eye(pseudo_input_size, # requires_grad=False).view(1, 1, pseudo_input_size, -1) # # # self.pseudo_layer = nn.Conv2d(1, out_channels=in_channels, # kernel_size=3, stride=2, padding=1) def encode_z2(self, input: Tensor) -> List[Tensor]: """ Encodes the input by passing through the encoder network and returns the latent codes. 
def encode_z1(self, input: Tensor, z2: Tensor) -> List[Tensor]:
    """
    Encodes the input together with the z2 code into the parameters of
    the z1 posterior, z1 ~ q(z1|x, z2).

    :param input: (Tensor) Input images
    :param z2: (Tensor) z2 latent codes
    :return: (List[Tensor]) [z1_mu, z1_log_var]
    """
    embedded_x = self.embed_data(input)

    # Project the z2 code onto an img_size x img_size plane and attach it
    # to the embedded images as one additional channel.
    side = self.img_size
    code_plane = self.embed_z2_code(z2).view(-1, side, side).unsqueeze(1)
    stacked = torch.cat((embedded_x, code_plane), dim=1)

    hidden = torch.flatten(self.encoder_z1_layers(stacked), start_dim=1)
    return [self.fc_z1_mu(hidden), self.fc_z1_var(hidden)]
def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
    """
    Samples z2 from the standard normal prior, draws z1 ~ p(z1|z2) and
    decodes both codes into the image space.

    :param batch_size: (Int) Number of samples
    :param current_device: (Int) Device to run the model
    :return: (Tensor)
    """
    z2 = torch.randn(batch_size, self.latent2_dim)
    # `.to` accepts CPU as well as CUDA targets; `.cuda(...)` would make
    # sampling impossible on a machine without a GPU.
    z2 = z2.to(current_device)

    # z1 ~ p(z1|z2)
    z1_mu = self.recons_z1_mu(z2)
    z1_log_var = self.recons_z1_log_var(z2)
    z1 = self.reparameterize(z1_mu, z1_log_var)

    # x ~ p(x|z1, z2)
    debedded_z1 = self.debed_z1_code(z1)
    debedded_z2 = self.debed_z2_code(z2)

    result = torch.cat([debedded_z1, debedded_z2], dim=1)
    result = result.view(-1, 512, 2, 2)
    samples = self.decode(result)

    return samples
def decode(self, z: Tensor) -> Tensor:
    """
    Maps the given latent codes onto the image space.
    :param z: (Tensor) [B x D]
    :return: (Tensor) [B x C x H x W]
    """
    result = self.decoder_input(z)
    # decoder_input projects to (channels * 2 * 2) features, so recover the
    # channel count from the layer itself instead of hard-coding 512; this
    # keeps decode() working for any hidden_dims passed to the constructor.
    channels = self.decoder_input.out_features // 4
    result = result.view(-1, channels, 2, 2)
    result = self.decoder(result)
    result = self.final_layer(result)
    return result
def compute_kernel(self,
                   x1: Tensor,
                   x2: Tensor) -> Tensor:
    """
    Evaluates the configured kernel on every pair (x1[i], x2[j]).
    :param x1: (Tensor) [N1 x D]
    :param x2: (Tensor) [N2 x D]
    :return: (Tensor) Whatever the selected kernel function returns for the
             two broadcast [N1 x N2 x D] tensors
    :raises ValueError: if self.kernel_type is neither 'rbf' nor 'imq'
    """
    # Convert the tensors into row and column vectors
    D = x1.size(1)
    N1 = x1.size(0)
    N2 = x2.size(0)

    x1 = x1.unsqueeze(-2) # Make it into a column tensor [N1 x 1 x D]
    x2 = x2.unsqueeze(-3) # Make it into a row tensor    [1 x N2 x D]

    # Expand both to the full pairwise grid. Using each input's own batch
    # size means inputs with different sizes along dim 0 work too; for
    # equal sizes this is identical to expanding to [N x N x D].
    x1 = x1.expand(N1, N2, D)
    x2 = x2.expand(N1, N2, D)

    if self.kernel_type == 'rbf':
        result = self.compute_rbf(x1, x2)
    elif self.kernel_type == 'imq':
        result = self.compute_inv_mult_quad(x1, x2)
    else:
        raise ValueError('Undefined kernel type.')

    return result
def compute_inv_mult_quad(self,
                          x1: Tensor,
                          x2: Tensor,
                          eps: float = 1e-7) -> Tensor:
    r"""
    Computes the Inverse Multi-Quadratics Kernel between x1 and x2,
    given by

            k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2 \|^2}

    with C = 2 * z_dim * self.z_var. The diagonal entries of the pairwise
    kernel matrix are excluded from the sum.
    :param x1: (Tensor) First operand, already broadcast against x2
    :param x2: (Tensor) Second operand; its last dimension is z_dim
    :param eps: (Float) Small constant added to the denominator
    :return: (Tensor) Scalar sum of the off-diagonal kernel values
    """
    # NOTE: the docstring is a raw string on purpose; otherwise "\f" in
    # "\frac" is read as a form-feed and "\s", "\|" are invalid escapes.
    z_dim = x2.size(-1)
    C = 2 * z_dim * self.z_var
    kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim = -1))

    # Exclude diagonal elements
    result = kernel.sum() - kernel.diag().sum()

    return result
def generate(self, x: Tensor, **kwargs) -> Tensor:
    """
    Reconstructs the input image x by running a full forward pass and
    keeping only the first entry of its output list.
    :param x: (Tensor) [B x C x H x W]
    :return: (Tensor) [B x C x H x W]
    """
    outputs = self.forward(x)
    reconstruction = outputs[0]
    return reconstruction
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    """
    Encodes the input, draws self.num_samples latent samples per input and
    decodes all of them.
    :param input: (Tensor) [B x C x H x W]
    :return: (List) [reconstructions, input, mu, log_var, z, eps] where
             mu, log_var, z and eps are [B x S x D]
    """
    mu, log_var = self.encode(input)
    # Replicate the posterior parameters once per importance sample
    mu = mu.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D]
    log_var = log_var.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D]
    z = self.reparameterize(mu, log_var) # [B x S x D]
    # Recover the standard-normal noise behind z = mu + std * eps. The
    # divisor must be the standard deviation exp(0.5 * log_var), not the
    # log-variance itself.
    std = torch.exp(0.5 * log_var)
    eps = (z - mu) / std # Prior samples
    return [self.decode(z), input, mu, log_var, z, eps]
def sample(self,
           num_samples:int,
           current_device: int, **kwargs) -> Tensor:
    """
    Samples from the latent space and return the corresponding
    image space map.
    :param num_samples: (Int) Number of samples
    :param current_device: (Int) Device to run the model
    :return: (Tensor) [num_samples x C x H x W]
    """
    # One latent sample per image: [num_samples x 1 x D]
    z = torch.randn(num_samples, 1,
                    self.latent_dim)

    z = z.to(current_device)

    # decode() returns [B x S x C x H x W] with S == 1 here. Drop only that
    # axis; a bare squeeze() would also remove the batch axis when
    # num_samples == 1.
    samples = self.decode(z).squeeze(1)
    return samples
def __init__(self,
             in_channels: int,
             latent_dim: int,
             categorical_dim: int,
             latent_min_capacity: float =0.,
             latent_max_capacity: float = 25.,
             latent_gamma: float = 30.,
             latent_num_iter: int = 25000,
             categorical_min_capacity: float =0.,
             categorical_max_capacity: float = 25.,
             categorical_gamma: float = 30.,
             categorical_num_iter: int = 25000,
             hidden_dims: List = None,
             temperature: float = 0.5,
             anneal_rate: float = 3e-5,
             anneal_interval: int = 100, # every 100 batches
             alpha: float = 30.,
             **kwargs) -> None:
    """
    Builds the convolutional encoder, the continuous (mu / log-variance) and
    categorical heads, and the transposed-convolution decoder.
    :param in_channels: (Int) Number of channels of the input images
    :param latent_dim: (Int) Size of the continuous latent code
    :param categorical_dim: (Int) Size of the categorical latent code
    :param latent_min_capacity: (Float) Stored as cont_min
    :param latent_max_capacity: (Float) Stored as cont_max
    :param latent_gamma: (Float) Stored as cont_gamma
    :param latent_num_iter: (Int) Stored as cont_iter
    :param categorical_min_capacity: (Float) Stored as disc_min
    :param categorical_max_capacity: (Float) Stored as disc_max
    :param categorical_gamma: (Float) Stored as disc_gamma
    :param categorical_num_iter: (Int) Stored as disc_iter
    :param hidden_dims: (List) Encoder channel widths; defaults to
                        [32, 64, 128, 256, 512]. The list is not modified.
    :param temperature: (Float) Initial Gumbel-softmax temperature
    :param anneal_rate: (Float) Temperature annealing rate
    :param anneal_interval: (Int) Number of batches between annealing steps
    :param alpha: (Float) Weight of the reconstruction loss
    """
    super(JointVAE, self).__init__()

    self.latent_dim = latent_dim
    self.categorical_dim = categorical_dim
    self.temp = temperature
    # NOTE(review): min_temp equals the initial temperature, so the
    # np.maximum(...) annealing in loss_function cannot lower temp below
    # its starting value -- confirm this is intended.
    self.min_temp = temperature
    self.anneal_rate = anneal_rate
    self.anneal_interval = anneal_interval
    self.alpha = alpha

    self.cont_min = latent_min_capacity
    self.cont_max = latent_max_capacity

    self.disc_min = categorical_min_capacity
    self.disc_max = categorical_max_capacity

    self.cont_gamma = latent_gamma
    self.disc_gamma = categorical_gamma

    self.cont_iter = latent_num_iter
    self.disc_iter = categorical_num_iter

    modules = []
    if hidden_dims is None:
        hidden_dims = [32, 64, 128, 256, 512]
    else:
        # Work on a copy: the decoder needs the widths in reverse order and
        # reversing the caller's list in place would corrupt it for reuse.
        hidden_dims = list(hidden_dims)

    # Build Encoder
    for h_dim in hidden_dims:
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels=h_dim,
                          kernel_size= 3, stride= 2, padding  = 1),
                nn.BatchNorm2d(h_dim),
                nn.LeakyReLU())
        )
        in_channels = h_dim

    self.encoder = nn.Sequential(*modules)
    # The heads expect a flattened [hidden_dims[-1] x 2 x 2] feature map
    self.fc_mu = nn.Linear(hidden_dims[-1]*4, self.latent_dim)
    self.fc_var = nn.Linear(hidden_dims[-1]*4, self.latent_dim)
    self.fc_z = nn.Linear(hidden_dims[-1]*4, self.categorical_dim)

    # Build Decoder
    modules = []

    self.decoder_input = nn.Linear(self.latent_dim + self.categorical_dim,
                                   hidden_dims[-1] * 4)

    decoder_dims = hidden_dims[::-1]

    for i in range(len(decoder_dims) - 1):
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(decoder_dims[i],
                                   decoder_dims[i + 1],
                                   kernel_size=3,
                                   stride = 2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(decoder_dims[i + 1]),
                nn.LeakyReLU())
        )

    self.decoder = nn.Sequential(*modules)

    self.final_layer = nn.Sequential(
                        nn.ConvTranspose2d(decoder_dims[-1],
                                           decoder_dims[-1],
                                           kernel_size=3,
                                           stride=2,
                                           padding=1,
                                           output_padding=1),
                        nn.BatchNorm2d(decoder_dims[-1]),
                        nn.LeakyReLU(),
                        nn.Conv2d(decoder_dims[-1], out_channels= 3,
                                  kernel_size= 3, padding= 1),
                        nn.Tanh())
    self.sampling_dist = torch.distributions.OneHotCategorical(1. / categorical_dim * torch.ones((self.categorical_dim, 1)))
def reparameterize(self,
                   mu: Tensor,
                   log_var: Tensor,
                   q: Tensor,
                   eps:float = 1e-7) -> Tensor:
    """
    Draws the joint latent code: a Gaussian sample for the continuous part
    and a Gumbel-softmax sample for the categorical part.
    :param mu: (Tensor) mean of the latent Gaussian  [B x D]
    :param log_var: (Tensor) Log variance of the latent Gaussian [B x D]
    :param q: (Tensor) Categorical latent Codes [B x Q]
    :param eps: (Float) Small constant guarding both logarithms
    :return: (Tensor) [B x (D + Q)]
    """
    # Continuous code via the usual reparameterization trick
    sigma = torch.exp(0.5 * log_var)
    gauss_noise = torch.randn_like(sigma)
    continuous = gauss_noise * sigma + mu

    # Gumbel noise from a uniform draw
    uniform = torch.rand_like(q)
    gumbel = - torch.log(- torch.log(uniform + eps) + eps)

    # Temperature-scaled softmax over the perturbed logits
    relaxed = F.softmax((q + gumbel) / self.temp, dim=-1)
    relaxed = relaxed.view(-1, self.categorical_dim)

    return torch.cat([continuous, relaxed], dim=1)
def sample(self,
           num_samples:int,
           current_device: int, **kwargs) -> Tensor:
    """
    Draws latent codes from the prior and maps them to image space. The
    continuous part is standard normal; the categorical part is a one-hot
    vector with a uniformly chosen category.
    :param num_samples: (Int) Number of samples
    :param current_device: (Int) Device to run the model
    :return: (Tensor)
    """
    # Continuous part of the code, [S x D]
    cont_code = torch.randn(num_samples, self.latent_dim)

    # Categorical part, [S x Q]: one randomly picked category per sample
    one_hot = np.zeros((num_samples, self.categorical_dim), dtype=np.float32)
    picks = np.random.choice(self.categorical_dim, num_samples)
    one_hot[np.arange(num_samples), picks] = 1
    disc_code = torch.from_numpy(one_hot)

    latent = torch.cat([cont_code, disc_code], dim = 1).to(current_device)
    return self.decode(latent)
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    r"""
    Computes the VAE loss with a log-cosh reconstruction term,
    (1 / alpha) * mean(log(cosh(alpha * (recons - input)))), plus the
    weighted KL term
    KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
    :param args: recons, input, mu, log_var (in this order)
    :param kwargs: must contain 'M_N', the KL weight for the minibatch
    :return: (dict) with keys 'loss', 'Reconstruction_Loss' and 'KLD'
    """
    import math

    recons = args[0]
    input = args[1]
    mu = args[2]
    log_var = args[3]

    kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
    t = recons - input

    # log(cosh(a)) = a + log(1 + exp(-2a)) - log(2). The middle term is a
    # softplus; F.softplus evaluates it without overflowing, whereas a
    # literal exp(-2a) becomes inf in float32 once -2a exceeds ~88 (easily
    # reached with the default alpha = 100) and turns the loss into inf.
    scaled = self.alpha * t
    recons_loss = scaled + F.softplus(-2. * scaled) - math.log(2.0)
    recons_loss = (1. / self.alpha) * recons_loss.mean()

    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

    loss = recons_loss + self.beta * kld_weight * kld_loss
    return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}
class LadderBlock(nn.Module):
    """
    One rung of the top-down path: a linear layer followed by batch
    normalisation, then two linear heads producing the mean and the
    log-variance of a Gaussian of size latent_dim.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int):
        """
        :param in_channels: (Int) Size of the incoming latent code
        :param latent_dim: (Int) Size of the produced mean / log-variance
        """
        super(LadderBlock, self).__init__()

        # Shared trunk
        trunk_layers = [nn.Linear(in_channels, latent_dim),
                        nn.BatchNorm1d(latent_dim)]
        self.decode = nn.Sequential(*trunk_layers)

        # Output heads
        self.fc_mu = nn.Linear(latent_dim, latent_dim)
        self.fc_var = nn.Linear(latent_dim, latent_dim)

    def forward(self, z: Tensor) -> Tensor:
        """
        :param z: (Tensor) Latent code whose last dimension is in_channels
        :return: (List) [mu, log_var]
        """
        hidden = self.decode(z)
        return [self.fc_mu(hidden), self.fc_var(hidden)]
def decode(self, z: Tensor, post_params: List) -> Tuple:
    """
    Runs the top-down pass: at every ladder block the top-down estimate is
    merged with the matching bottom-up posterior parameters, a new latent
    is sampled, and the KL term is accumulated. The final latent is mapped
    onto the image space.
    :param z: (Tensor) Latent code fed to the first ladder block
    :param post_params: (List) (mu, log_var) pairs from the encoder blocks,
                        in encoder order. The list is left untouched.
    :return: (Tuple) (reconstruction, accumulated KL divergence)
    """
    kl_div = 0
    # The ladder is traversed in the opposite order to the encoder. Use a
    # reversed copy so the caller's list is not mutated as a side effect.
    ordered_params = post_params[::-1]
    for i, ladder_block in enumerate(self.ladders):
        mu_e, log_var_e = ordered_params[i]
        mu_t, log_var_t = ladder_block(z)
        mu, log_var = self.merge_gauss(mu_e, mu_t,
                                       log_var_e, log_var_t)
        z = self.reparameterize(mu, log_var)
        kl_div = kl_div + self.compute_kl_divergence(z, (mu, log_var), (mu_e, log_var_e))

    result = self.decoder_input(z)
    result = result.view(-1, self.hidden_dims[-1], 2, 2)
    result = self.decoder(result)
    return self.final_layer(result), kl_div
def compute_kl_divergence(self, z: Tensor, q_params: Tuple, p_params: Tuple):
    r"""
    Closed-form KL divergence KL(q || p) between two diagonal Gaussians,
    summed over the last dimension:

        KL = \frac{1}{2}(\log\sigma_p^2 - \log\sigma_q^2)
             + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2 \sigma_p^2} - \frac{1}{2}

    :param z: (Tensor) Sample from q; not used by the closed form
    :param q_params: (Tuple) (mu_q, log_var_q)
    :param p_params: (Tuple) (mu_p, log_var_p)
    :return: (Tensor) KL divergence with the last dimension reduced
    """
    mu_q, log_var_q = q_params
    mu_p, log_var_p = p_params

    # log(sigma_p / sigma_q) is half the difference of the log-variances;
    # without the 0.5 the expression is not a KL divergence and can even
    # become negative.
    kl = 0.5 * (log_var_p - log_var_q) \
         + (log_var_q.exp() + (mu_q - mu_p)**2)/(2 * log_var_p.exp()) \
         - 0.5
    kl = torch.sum(kl, dim = -1)
    return kl
def generate(self, x: Tensor, **kwargs) -> Tensor:
    """
    Returns the reconstruction of the input image x, i.e. the first element
    of the list produced by forward().
    :param x: (Tensor) [B x C x H x W]
    :return: (Tensor) [B x C x H x W]
    """
    forward_outputs = self.forward(x)
    recons = forward_outputs[0]
    return recons
def decode(self, z: Tensor) -> Tensor:
    """
    Maps the given latent codes of M estimates with S samples each
    onto the image space.
    :param z: (Tensor) [B x M x S x D]
    :return: (Tensor) [B x M x S x C x H x W]
    """
    B, M, S, D = z.size()
    z = z.contiguous().view(-1, self.latent_dim) #[BMS x D]
    result = self.decoder_input(z)
    # decoder_input projects to (channels * 2 * 2) features; read the
    # channel count from the layer instead of hard-coding 512 so that
    # non-default hidden_dims work.
    channels = self.decoder_input.out_features // 4
    result = result.view(-1, channels, 2, 2)
    result = self.decoder(result)
    result = self.final_layer(result) #[BMS x C x H x W ]
    result = result.view([B, M, S,
                          result.size(-3),
                          result.size(-2),
                          result.size(-1)]) #[B x M x S x C x H x W]
    return result
D] eps = (z - mu) / log_var # Prior samples return [self.decode(z), input, mu, log_var, z, eps] def loss_function(self, *args, **kwargs) -> dict: """ KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} :param args: :param kwargs: :return: """ recons = args[0] input = args[1] mu = args[2] log_var = args[3] z = args[4] eps = args[5] input = input.repeat(self.num_estimates, self.num_samples, 1, 1, 1, 1).permute(2, 0, 1, 3, 4, 5) #[B x M x S x C x H x W] kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset log_p_x_z = ((recons - input) ** 2).flatten(3).mean(-1) # Reconstruction Loss # [B x M x S] kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=3) # [B x M x S] # Get importance weights log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data # Rescale the weights (along the sample dim) to lie in [0, 1] and sum to 1 weight = F.softmax(log_weight, dim = -1) # [B x M x S] loss = torch.mean(torch.mean(torch.sum(weight * log_weight, dim=-1), dim = -2), dim = 0) return {'loss': loss, 'Reconstruction_Loss':log_p_x_z.mean(), 'KLD':-kld_loss.mean()} def sample(self, num_samples:int, current_device: int, **kwargs) -> Tensor: """ Samples from the latent space and return the corresponding image space map. :param num_samples: (Int) Number of samples :param current_device: (Int) Device to run the model :return: (Tensor) """ z = torch.randn(num_samples, 1, 1, self.latent_dim) z = z.to(current_device) samples = self.decode(z).squeeze() return samples def generate(self, x: Tensor, **kwargs) -> Tensor: """ Given an input image x, returns the reconstructed image. 
Returns only the first reconstructed sample :param x: (Tensor) [B x C x H x W] :return: (Tensor) [B x C x H x W] """ return self.forward(x)[0][:, 0, 0, :] ================================================ FILE: models/mssim_vae.py ================================================ import torch from models import BaseVAE from torch import nn from torch.nn import functional as F from .types_ import * from math import exp class MSSIMVAE(BaseVAE): def __init__(self, in_channels: int, latent_dim: int, hidden_dims: List = None, window_size: int = 11, size_average: bool = True, **kwargs) -> None: super(MSSIMVAE, self).__init__() self.latent_dim = latent_dim self.in_channels = in_channels modules = [] if hidden_dims is None: hidden_dims = [32, 64, 128, 256, 512] # Build Encoder for h_dim in hidden_dims: modules.append( nn.Sequential( nn.Conv2d(in_channels, out_channels=h_dim, kernel_size= 3, stride= 2, padding = 1), nn.BatchNorm2d(h_dim), nn.LeakyReLU()) ) in_channels = h_dim self.encoder = nn.Sequential(*modules) self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) # Build Decoder modules = [] self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) hidden_dims.reverse() for i in range(len(hidden_dims) - 1): modules.append( nn.Sequential( nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride = 2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[i + 1]), nn.LeakyReLU()) ) self.decoder = nn.Sequential(*modules) self.final_layer = nn.Sequential( nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[-1]), nn.LeakyReLU(), nn.Conv2d(hidden_dims[-1], out_channels= 3, kernel_size= 3, padding= 1), nn.Tanh()) self.mssim_loss = MSSIM(self.in_channels, window_size, size_average) def encode(self, input: Tensor) -> List[Tensor]: """ Encodes the input by passing through the encoder network and returns the latent 
def decode(self, z: Tensor) -> Tensor:
    """
    Maps the given latent codes
    onto the image space.
    :param z: (Tensor) [B x D]
    :return: (Tensor) [B x C x H x W]
    """
    result = self.decoder_input(z)
    # The projection layer emits (channels * 2 * 2) features. Derive the
    # channel count from it instead of hard-coding 512, which is only
    # correct for the default hidden_dims.
    channels = self.decoder_input.out_features // 4
    result = result.view(-1, channels, 2, 2)
    result = self.decoder(result)
    result = self.final_layer(result)
    return result
class MSSIM(nn.Module):

    def __init__(self,
                 in_channels: int = 3,
                 window_size: int = 11,
                 size_average: bool = True) -> None:
        """
        Computes the differentiable MS-SSIM loss
        Reference:
        [1] https://github.com/jorge-pessoa/pytorch-msssim/blob/dev/pytorch_msssim/__init__.py
            (MIT License)

        :param in_channels: (Int) Number of image channels
        :param window_size: (Int) Side length of the Gaussian window
        :param size_average: (Bool) Average the SSIM map over the whole batch
                             (True) or per sample (False)
        """
        super(MSSIM, self).__init__()
        self.in_channels = in_channels
        self.window_size = window_size
        self.size_average = size_average

    def gaussian_window(self, window_size: int, sigma: float) -> Tensor:
        """
        Returns a normalised 1D Gaussian kernel centred on the window.
        :param window_size: (Int) Number of taps
        :param sigma: (Float) Standard deviation of the Gaussian
        :return: (Tensor) [window_size], sums to 1
        """
        center = window_size // 2
        # The exponent must be negative: without the minus sign the kernel
        # grows towards the borders instead of peaking at the centre.
        kernel = torch.tensor([exp(-(x - center) ** 2 / (2 * sigma ** 2))
                               for x in range(window_size)])
        return kernel / kernel.sum()

    def create_window(self, window_size, in_channels):
        """
        Builds the depthwise 2D Gaussian window (sigma = 1.5).
        :param window_size: (Int) Side length of the window
        :param in_channels: (Int) Number of channels to replicate it for
        :return: (Tensor) [in_channels x 1 x window_size x window_size]
        """
        _1D_window = self.gaussian_window(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(in_channels, 1, window_size, window_size).contiguous()
        return window

    def ssim(self,
             img1: Tensor,
             img2: Tensor,
             window_size: int,
             in_channel: int,
             size_average: bool) -> Tuple[Tensor, Tensor]:
        """
        Single-scale SSIM between two image batches.
        :param img1: (Tensor) [B x C x H x W]
        :param img2: (Tensor) [B x C x H x W]
        :param window_size: (Int) Side length of the Gaussian window
        :param in_channel: (Int) Number of channels C
        :param size_average: (Bool) Average over the batch or per sample
        :return: (ssim, cs) - the SSIM value(s) and the mean
                 contrast-structure term
        """
        device = img1.device
        window = self.create_window(window_size, in_channel).to(device)

        # Local means (depthwise Gaussian filtering)
        mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=in_channel)
        mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=in_channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Local (co)variances: E[xy] - E[x]E[y]
        sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=in_channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=in_channel) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=in_channel) - mu1_mu2

        img_range = 1.0 # Dynamic range is fixed (not derived from the images)
        C1 = (0.01 * img_range) ** 2
        C2 = (0.03 * img_range) ** 2

        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2) # contrast sensitivity

        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

        if size_average:
            ret = ssim_map.mean()
        else:
            ret = ssim_map.mean(1).mean(1).mean(1)
        return ret, cs

    def forward(self, img1: Tensor, img2: Tensor) -> Tensor:
        """
        Multi-scale SSIM loss over five dyadic scales.
        :param img1: (Tensor) [B x C x H x W]
        :param img2: (Tensor) [B x C x H x W]
        :return: (Tensor) 1 - MS-SSIM (scalar, or [B] if size_average is False)
        """
        device = img1.device
        weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], device=device)
        levels = weights.size()[0]
        mssim = []
        mcs = []

        for _ in range(levels):
            sim, cs = self.ssim(img1, img2,
                                self.window_size,
                                self.in_channels,
                                self.size_average)
            mssim.append(sim)
            mcs.append(cs)

            # Halve the resolution for the next scale
            img1 = F.avg_pool2d(img1, (2, 2))
            img2 = F.avg_pool2d(img2, (2, 2))

        mssim = torch.stack(mssim)
        mcs = torch.stack(mcs)

        # NOTE: negative SSIM / cs values raised to a fractional power give
        # NaN; no normalisation to [0, 1] is applied here.
        pow1 = mcs ** weights
        # Reshape so the per-level weights broadcast over per-sample values
        # when size_average is False (mssim is then [levels x B]).
        level_weights = weights.reshape(-1, *([1] * (mssim.dim() - 1)))
        pow2 = mssim ** level_weights

        # MS-SSIM: product of the contrast-structure terms of the first
        # levels - 1 scales, times the full SSIM of the coarsest scale.
        output = torch.prod(pow1[:-1]) * pow2[-1]
        return 1 - output
Build Encoder for h_dim in hidden_dims: modules.append( nn.Sequential( nn.Conv2d(in_channels, out_channels=h_dim, kernel_size= 3, stride= 2, padding = 1), nn.BatchNorm2d(h_dim), nn.LeakyReLU()) ) in_channels = h_dim self.encoder = nn.Sequential(*modules) self.fc_z = nn.Linear(hidden_dims[-1]*4, latent_dim) # Build Decoder modules = [] self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) hidden_dims.reverse() for i in range(len(hidden_dims) - 1): modules.append( nn.Sequential( nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride = 2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[i + 1]), nn.LeakyReLU()) ) self.decoder = nn.Sequential(*modules) self.final_layer = nn.Sequential( nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[-1]), nn.LeakyReLU(), nn.Conv2d(hidden_dims[-1], out_channels= 3, kernel_size= 3, padding= 1), nn.Tanh()) def encode(self, input: Tensor) -> Tensor: """ Encodes the input by passing through the encoder network and returns the latent codes. 
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    """
    Combines the L2 + L1 reconstruction error with the sliced
    Wasserstein regulariser.
    :param args: (recons, input, z) as returned by forward()
    :param kwargs: unused
    :return: (dict) with keys 'loss', 'Reconstruction_Loss' and 'SWD'
    """
    reconstruction, target, latent = args[0], args[1], args[2]

    # Scale the regularisation weight by N * (N - 1) for a batch of size N
    n = target.size(0)
    scaled_weight = self.reg_weight / (n * (n - 1))

    l2_term = F.mse_loss(reconstruction, target)
    l1_term = F.l1_loss(reconstruction, target)
    swd_term = self.compute_swd(latent, self.p, scaled_weight)

    total = l2_term + l1_term + swd_term
    return {'loss': total, 'Reconstruction_Loss': (l2_term + l1_term), 'SWD': swd_term}
def compute_swd(self,
                z: Tensor,
                p: float,
                reg_weight: float) -> Tensor:
    """
    Computes the Sliced Wasserstein Distance (SWD) - which consists of
    randomly projecting the encoded and prior vectors and computing
    their Wasserstein distance along those projections.

    :param z: Latent samples # [N x D]
    :param p: Value for the p^th Wasserstein distance
    :param reg_weight: (Float) Multiplier applied to the mean distance
    :return: (Tensor) Scalar weighted SWD
    """
    prior_z = torch.randn_like(z) # [N x D]
    device = z.device

    proj_matrix = self.get_random_projections(self.latent_dim,
                                              num_samples=self.num_projections).transpose(0, 1).to(device)

    latent_projections = z.matmul(proj_matrix) # [N x S]
    prior_projections = prior_z.matmul(proj_matrix) # [N x S]

    # The Wasserstein distance is computed by sorting the two projections
    # across the batches and computing their element-wise distance
    w_dist = torch.sort(latent_projections.t(), dim=1)[0] - \
             torch.sort(prior_projections.t(), dim=1)[0]
    # Take |.|^p: raising the signed difference to p cancels terms for odd p
    # and yields NaN for fractional p. Identical to the old result for p = 2.
    w_dist = w_dist.abs().pow(p)
    return reg_weight * w_dist.mean()
:param num_samples: (Int) Number of samples :param current_device: (Int) Device to run the model :return: (Tensor) """ z = torch.randn(num_samples, self.latent_dim) z = z.to(current_device) samples = self.decode(z) return samples def generate(self, x: Tensor, **kwargs) -> Tensor: """ Given an input image x, returns the reconstructed image :param x: (Tensor) [B x C x H x W] :return: (Tensor) [B x C x H x W] """ return self.forward(x)[0] ================================================ FILE: models/twostage_vae.py ================================================ import torch from models import BaseVAE from torch import nn from torch.nn import functional as F from .types_ import * class TwoStageVAE(BaseVAE): def __init__(self, in_channels: int, latent_dim: int, hidden_dims: List = None, hidden_dims2: List = None, **kwargs) -> None: super(TwoStageVAE, self).__init__() self.latent_dim = latent_dim modules = [] if hidden_dims is None: hidden_dims = [32, 64, 128, 256, 512] if hidden_dims2 is None: hidden_dims2 = [1024, 1024] # Build Encoder for h_dim in hidden_dims: modules.append( nn.Sequential( nn.Conv2d(in_channels, out_channels=h_dim, kernel_size= 3, stride= 2, padding = 1), nn.BatchNorm2d(h_dim), nn.LeakyReLU()) ) in_channels = h_dim self.encoder = nn.Sequential(*modules) self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) # Build Decoder modules = [] self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) hidden_dims.reverse() for i in range(len(hidden_dims) - 1): modules.append( nn.Sequential( nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride = 2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[i + 1]), nn.LeakyReLU()) ) self.decoder = nn.Sequential(*modules) self.final_layer = nn.Sequential( nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[-1]), nn.LeakyReLU(), 
def decode(self, z: Tensor) -> Tensor:
    """
    Maps the given latent codes
    onto the image space.
    :param z: (Tensor) [B x D]
    :return: (Tensor) [B x C x H x W]
    """
    result = self.decoder_input(z)
    # decoder_input produces (channels * 2 * 2) features; compute the channel
    # count from the layer rather than assuming the default 512 so that
    # non-default hidden_dims are reshaped correctly.
    channels = self.decoder_input.out_features // 4
    result = result.view(-1, channels, 2, 2)
    result = self.decoder(result)
    result = self.final_layer(result)
    return result
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    r"""
    VAE objective: MSE reconstruction error plus the weighted KL divergence
    between the approximate posterior and a standard normal prior.
    KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
    :param args: (recons, input, mu, log_var) as returned by forward()
    :param kwargs: must contain 'M_N', the weight of the KL term
    :return: (dict) with keys 'loss', 'Reconstruction_Loss' and 'KLD'
    """
    recons, target = args[0], args[1]
    mu, log_var = args[2], args[3]

    # Account for the minibatch samples from the dataset
    kld_weight = kwargs['M_N']

    recons_loss = F.mse_loss(recons, target)

    # Closed-form KL per latent dimension, summed over D, averaged over B
    kl_terms = 1 + log_var - mu ** 2 - log_var.exp()
    kl_per_sample = -0.5 * torch.sum(kl_terms, dim=1)
    kld_loss = torch.mean(kl_per_sample, dim=0)

    total = recons_loss + kld_weight * kld_loss
    return {'loss': total, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}
:param num_samples: (Int) Number of samples :param current_device: (Int) Device to run the model :return: (Tensor) """ z = torch.randn(num_samples, self.latent_dim) z = z.to(current_device) samples = self.decode(z) return samples def generate(self, x: Tensor, **kwargs) -> Tensor: """ Given an input image x, returns the reconstructed image :param x: (Tensor) [B x C x H x W] :return: (Tensor) [B x C x H x W] """ return self.forward(x)[0] ================================================ FILE: models/types_.py ================================================ from typing import List, Callable, Union, Any, TypeVar, Tuple # from torch import tensor as Tensor Tensor = TypeVar('torch.tensor') ================================================ FILE: models/vampvae.py ================================================ import torch from models import BaseVAE from torch import nn from torch.nn import functional as F from .types_ import * class VampVAE(BaseVAE): def __init__(self, in_channels: int, latent_dim: int, hidden_dims: List = None, num_components: int = 50, **kwargs) -> None: super(VampVAE, self).__init__() self.latent_dim = latent_dim self.num_components = num_components modules = [] if hidden_dims is None: hidden_dims = [32, 64, 128, 256, 512] # Build Encoder for h_dim in hidden_dims: modules.append( nn.Sequential( nn.Conv2d(in_channels, out_channels=h_dim, kernel_size= 3, stride= 2, padding = 1), nn.BatchNorm2d(h_dim), nn.LeakyReLU()) ) in_channels = h_dim self.encoder = nn.Sequential(*modules) self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) # Build Decoder modules = [] self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) hidden_dims.reverse() for i in range(len(hidden_dims) - 1): modules.append( nn.Sequential( nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride = 2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[i + 1]), nn.LeakyReLU()) ) self.decoder = 
def encode(self, input: Tensor) -> List[Tensor]:
    """
    Runs the encoder network on the input and projects the flattened
    features to the parameters of the latent Gaussian.
    :param input: (Tensor) Input tensor to encoder [N x C x H x W]
    :return: (List[Tensor]) [mu, log_var]
    """
    # Flatten everything but the batch dimension
    features = torch.flatten(self.encoder(input), start_dim=1)

    # Two separate heads give the mean and the log-variance
    return [self.fc_mu(features), self.fc_var(features)]
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    """
    VAE loss with a mixture prior whose components are obtained by
    encoding the learned pseudo-inputs.
    :param args: (recons, input, mu, log_var, z) as returned by forward()
    :param kwargs: must contain 'M_N', the weight of the KL term
    :return: (dict) with keys 'loss', 'Reconstruction_Loss' and 'KLD'
    """
    recons = args[0]
    input = args[1]
    mu = args[2]
    log_var = args[3]
    z = args[4]

    def log_normal_diag(x: Tensor, mean: Tensor, logvar: Tensor) -> Tensor:
        # Diagonal Gaussian log-density up to an additive constant. Only the
        # squared error is divided by the variance - not the log-variance.
        return -0.5 * (logvar + (x - mean) ** 2 / logvar.exp())

    kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
    recons_loss = F.mse_loss(recons, input)

    E_log_q_z = torch.mean(torch.sum(log_normal_diag(z, mu, log_var),
                                     dim=1),
                           dim=0)

    # Vamp Prior
    _, C, H, W = input.size()
    curr_device = input.device
    # .to() works for CPU and GPU alike (.cuda() fails for a CPU device)
    self.pseudo_input = self.pseudo_input.to(curr_device)
    x = self.embed_pseudo(self.pseudo_input)
    x = x.view(-1, C, H, W)
    prior_mu, prior_log_var = self.encode(x)

    z_expand = z.unsqueeze(1)                   # [B x 1 x D]
    prior_mu = prior_mu.unsqueeze(0)            # [1 x K x D]
    prior_log_var = prior_log_var.unsqueeze(0)  # [1 x K x D]

    # log of the uniform mixture weight 1 / K
    log_num_components = torch.log(torch.tensor(float(self.num_components),
                                                device=curr_device))
    E_log_p_z = torch.sum(log_normal_diag(z_expand, prior_mu, prior_log_var),
                          dim=2) - log_num_components # [B x K]

    E_log_p_z = torch.logsumexp(E_log_p_z, dim=1) # mixture log-density [B]
    E_log_p_z = torch.mean(E_log_p_z, dim=0)

    # KLD = E_q log q - E_q log p
    kld_loss = -(E_log_p_z - E_log_q_z)

    loss = recons_loss + kld_weight * kld_loss
    return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}
:param num_samples: (Int) Number of samples :param current_device: (Int) Device to run the model :return: (Tensor) """ z = torch.randn(num_samples, self.latent_dim) z = z.cuda(current_device) samples = self.decode(z) return samples def generate(self, x: Tensor, **kwargs) -> Tensor: """ Given an input image x, returns the reconstructed image :param x: (Tensor) [B x C x H x W] :return: (Tensor) [B x C x H x W] """ return self.forward(x)[0] ================================================ FILE: models/vanilla_vae.py ================================================ import torch from models import BaseVAE from torch import nn from torch.nn import functional as F from .types_ import * class VanillaVAE(BaseVAE): def __init__(self, in_channels: int, latent_dim: int, hidden_dims: List = None, **kwargs) -> None: super(VanillaVAE, self).__init__() self.latent_dim = latent_dim modules = [] if hidden_dims is None: hidden_dims = [32, 64, 128, 256, 512] # Build Encoder for h_dim in hidden_dims: modules.append( nn.Sequential( nn.Conv2d(in_channels, out_channels=h_dim, kernel_size= 3, stride= 2, padding = 1), nn.BatchNorm2d(h_dim), nn.LeakyReLU()) ) in_channels = h_dim self.encoder = nn.Sequential(*modules) self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) # Build Decoder modules = [] self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) hidden_dims.reverse() for i in range(len(hidden_dims) - 1): modules.append( nn.Sequential( nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride = 2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[i + 1]), nn.LeakyReLU()) ) self.decoder = nn.Sequential(*modules) self.final_layer = nn.Sequential( nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[-1]), nn.LeakyReLU(), nn.Conv2d(hidden_dims[-1], out_channels= 3, kernel_size= 3, padding= 1), nn.Tanh()) def encode(self, 
def decode(self, z: Tensor) -> Tensor:
    """
    Maps the given latent codes
    onto the image space.
    :param z: (Tensor) [B x D]
    :return: (Tensor) [B x C x H x W]
    """
    result = self.decoder_input(z)
    # Reshape to a 2 x 2 feature map. The channel count follows from the
    # projection layer, so custom hidden_dims work (512 was hard-coded and
    # is only valid for the default configuration).
    channels = self.decoder_input.out_features // 4
    result = result.view(-1, channels, 2, 2)
    result = self.decoder(result)
    result = self.final_layer(result)
    return result
def sample(self,
           num_samples: int,
           current_device: int, **kwargs) -> Tensor:
    """
    Draws latent codes from a standard normal distribution and decodes
    them into image space.
    :param num_samples: (Int) Number of samples
    :param current_device: (Int) Device to run the model
    :return: (Tensor) Decoded samples
    """
    latent_shape = (num_samples, self.latent_dim)
    noise = torch.randn(*latent_shape).to(current_device)
    return self.decode(noise)
class ResidualLayer(nn.Module):
    """
    Residual block: a 3x3 convolution, ReLU and 1x1 convolution whose
    output is added to the block's input.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int):
        """
        :param in_channels: (Int) Channels of the incoming feature map
        :param out_channels: (Int) Channels produced by both convolutions
        """
        super().__init__()
        # Both convolutions are bias-free and preserve the spatial size
        layers = [
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(out_channels, out_channels,
                      kernel_size=1, bias=False),
        ]
        self.resblock = nn.Sequential(*layers)

    def forward(self, input: Tensor) -> Tensor:
        """
        :param input: (Tensor) [B x C x H x W]
        :return: (Tensor) input plus the residual branch output
        """
        residual = self.resblock(input)
        return input + residual
def encode(self, input: Tensor) -> List[Tensor]:
    """
    Passes the input through the encoder network.
    :param input: (Tensor) Input tensor to encoder [N x C x H x W]
    :return: (List[Tensor]) Single-element list holding the encoder output
    """
    # Wrapped in a list to match the return type of the other encode methods
    return [self.encoder(input)]
:param z: (Tensor) [B x D x H x W] :return: (Tensor) [B x C x H x W] """ result = self.decoder(z) return result def forward(self, input: Tensor, **kwargs) -> List[Tensor]: encoding = self.encode(input)[0] quantized_inputs, vq_loss = self.vq_layer(encoding) return [self.decode(quantized_inputs), input, vq_loss] def loss_function(self, *args, **kwargs) -> dict: """ :param args: :param kwargs: :return: """ recons = args[0] input = args[1] vq_loss = args[2] recons_loss = F.mse_loss(recons, input) loss = recons_loss + vq_loss return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'VQ_Loss':vq_loss} def sample(self, num_samples: int, current_device: Union[int, str], **kwargs) -> Tensor: raise Warning('VQVAE sampler is not implemented.') def generate(self, x: Tensor, **kwargs) -> Tensor: """ Given an input image x, returns the reconstructed image :param x: (Tensor) [B x C x H x W] :return: (Tensor) [B x C x H x W] """ return self.forward(x)[0] ================================================ FILE: models/wae_mmd.py ================================================ import torch from models import BaseVAE from torch import nn from torch.nn import functional as F from .types_ import * class WAE_MMD(BaseVAE): def __init__(self, in_channels: int, latent_dim: int, hidden_dims: List = None, reg_weight: int = 100, kernel_type: str = 'imq', latent_var: float = 2., **kwargs) -> None: super(WAE_MMD, self).__init__() self.latent_dim = latent_dim self.reg_weight = reg_weight self.kernel_type = kernel_type self.z_var = latent_var modules = [] if hidden_dims is None: hidden_dims = [32, 64, 128, 256, 512] # Build Encoder for h_dim in hidden_dims: modules.append( nn.Sequential( nn.Conv2d(in_channels, out_channels=h_dim, kernel_size= 3, stride= 2, padding = 1), nn.BatchNorm2d(h_dim), nn.LeakyReLU()) ) in_channels = h_dim self.encoder = nn.Sequential(*modules) self.fc_z = nn.Linear(hidden_dims[-1]*4, latent_dim) # Build Decoder modules = [] self.decoder_input = nn.Linear(latent_dim, 
hidden_dims[-1] * 4) hidden_dims.reverse() for i in range(len(hidden_dims) - 1): modules.append( nn.Sequential( nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride = 2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[i + 1]), nn.LeakyReLU()) ) self.decoder = nn.Sequential(*modules) self.final_layer = nn.Sequential( nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[-1]), nn.LeakyReLU(), nn.Conv2d(hidden_dims[-1], out_channels= 3, kernel_size= 3, padding= 1), nn.Tanh()) def encode(self, input: Tensor) -> Tensor: """ Encodes the input by passing through the encoder network and returns the latent codes. :param input: (Tensor) Input tensor to encoder [N x C x H x W] :return: (Tensor) List of latent codes """ result = self.encoder(input) result = torch.flatten(result, start_dim=1) # Split the result into mu and var components # of the latent Gaussian distribution z = self.fc_z(result) return z def decode(self, z: Tensor) -> Tensor: result = self.decoder_input(z) result = result.view(-1, 512, 2, 2) result = self.decoder(result) result = self.final_layer(result) return result def forward(self, input: Tensor, **kwargs) -> List[Tensor]: z = self.encode(input) return [self.decode(z), input, z] def loss_function(self, *args, **kwargs) -> dict: recons = args[0] input = args[1] z = args[2] batch_size = input.size(0) bias_corr = batch_size * (batch_size - 1) reg_weight = self.reg_weight / bias_corr recons_loss =F.mse_loss(recons, input) mmd_loss = self.compute_mmd(z, reg_weight) loss = recons_loss + mmd_loss return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'MMD': mmd_loss} def compute_kernel(self, x1: Tensor, x2: Tensor) -> Tensor: # Convert the tensors into row and column vectors D = x1.size(1) N = x1.size(0) x1 = x1.unsqueeze(-2) # Make it into a column tensor x2 = x2.unsqueeze(-3) # Make it into a row tensor """ Usually the below lines are not 
required, especially in our case, but this is useful when x1 and x2 have different sizes along the 0th dimension. """ x1 = x1.expand(N, N, D) x2 = x2.expand(N, N, D) if self.kernel_type == 'rbf': result = self.compute_rbf(x1, x2) elif self.kernel_type == 'imq': result = self.compute_inv_mult_quad(x1, x2) else: raise ValueError('Undefined kernel type.') return result def compute_rbf(self, x1: Tensor, x2: Tensor, eps: float = 1e-7) -> Tensor: """ Computes the RBF Kernel between x1 and x2. :param x1: (Tensor) :param x2: (Tensor) :param eps: (Float) :return: """ z_dim = x2.size(-1) sigma = 2. * z_dim * self.z_var result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma)) return result def compute_inv_mult_quad(self, x1: Tensor, x2: Tensor, eps: float = 1e-7) -> Tensor: """ Computes the Inverse Multi-Quadratics Kernel between x1 and x2, given by k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2 \|^2} :param x1: (Tensor) :param x2: (Tensor) :param eps: (Float) :return: """ z_dim = x2.size(-1) C = 2 * z_dim * self.z_var kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim = -1)) # Exclude diagonal elements result = kernel.sum() - kernel.diag().sum() return result def compute_mmd(self, z: Tensor, reg_weight: float) -> Tensor: # Sample from prior (Gaussian) distribution prior_z = torch.randn_like(z) prior_z__kernel = self.compute_kernel(prior_z, prior_z) z__kernel = self.compute_kernel(z, z) priorz_z__kernel = self.compute_kernel(prior_z, z) mmd = reg_weight * prior_z__kernel.mean() + \ reg_weight * z__kernel.mean() - \ 2 * reg_weight * priorz_z__kernel.mean() return mmd def sample(self, num_samples:int, current_device: int, **kwargs) -> Tensor: """ Samples from the latent space and return the corresponding image space map. 
# Command line: a single optional argument pointing at the YAML config.
parser = argparse.ArgumentParser(description='Generic runner for VAE models')
parser.add_argument('--config',  '-c',
                    dest="filename",
                    metavar='FILE',
                    help =  'path to the config file',
                    default='configs/vae.yaml')

args = parser.parse_args()
with open(args.filename, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        # Nothing below can run without a parsed config; stop here with a
        # clear message instead of failing later with a NameError.
        raise SystemExit(
            f"Could not parse config file '{args.filename}': {exc}") from exc


tb_logger =  TensorBoardLogger(save_dir=config['logging_params']['save_dir'],
                               name=config['model_params']['name'],)

# For reproducibility
seed_everything(config['exp_params']['manual_seed'], True)

model = vae_models[config['model_params']['name']](**config['model_params'])
experiment = VAEXperiment(model,
                          config['exp_params'])

# 'gpus' may be a list, a string, an int or absent; truthiness covers all
# of these, whereas len() fails for an int or None.
gpus = config['trainer_params'].get('gpus')
data = VAEDataset(**config["data_params"], pin_memory=bool(gpus))

data.setup()
runner = Trainer(logger=tb_logger,
                 callbacks=[
                     LearningRateMonitor(),
                     ModelCheckpoint(save_top_k=2,
                                     dirpath =os.path.join(tb_logger.log_dir , "checkpoints"),
                                     monitor= "val_loss",
                                     save_last= True),
                 ],
                 strategy=DDPPlugin(find_unused_parameters=False),
                 **config['trainer_params'])


# Output folders inside the logger's run directory
Path(f"{tb_logger.log_dir}/Samples").mkdir(exist_ok=True, parents=True)
Path(f"{tb_logger.log_dir}/Reconstructions").mkdir(exist_ok=True, parents=True)


print(f"======= Training {config['model_params']['name']} =======")
runner.fit(experiment, datamodule=data)
class TestVAE(unittest.TestCase):
    """
    Smoke tests for GumbelVAE: each test only runs a call and prints
    the result. The CUDA-only test is skipped when no CUDA device is
    available instead of erroring out.
    """

    def setUp(self) -> None:
        self.model = GumbelVAE(3, 10)

    def test_summary(self):
        print(summary(self.model, (3, 64, 64), device='cpu'))

    def test_forward(self):
        x = torch.randn(16, 3, 64, 64)
        y = self.model(x)
        print("Model Output size:", y[0].size())

    def test_loss(self):
        x = torch.randn(128, 3, 64, 64)

        result = self.model(x)
        loss = self.model.loss_function(*result, M_N = 0.005, batch_idx=5)
        print(loss)

    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
    def test_sample(self):
        # Moves the model to the GPU, hence the skip guard above.
        self.model.cuda()
        y = self.model.sample(144, 0)
        print(y.shape)
class TestDIPVAE(unittest.TestCase):
    """
    Smoke tests for DIPVAE: each test only runs a call and prints the
    result. The CUDA-only test is skipped when no CUDA device is
    available instead of erroring out.
    """

    def setUp(self) -> None:
        self.model = DIPVAE(3, 64)

    def test_summary(self):
        print(summary(self.model, (3, 64, 64), device='cpu'))

    def test_forward(self):
        # Number of trainable parameters
        print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))
        x = torch.randn(16, 3, 64, 64)
        y = self.model(x)
        print("Model Output size:", y[0].size())

    def test_loss(self):
        x = torch.randn(16, 3, 64, 64)

        result = self.model(x)
        loss = self.model.loss_function(*result, M_N = 0.005)
        print(loss)

    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
    def test_sample(self):
        # Moves the model to the GPU, hence the skip guard above.
        self.model.cuda()
        y = self.model.sample(8, 'cuda')
        print(y.shape)

    def test_generate(self):
        x = torch.randn(16, 3, 64, 64)
        y = self.model.generate(x)
        print(y.shape)
class TestGammaVAE(unittest.TestCase):
    """
    Smoke tests for GammaVAE: each test only runs a call and prints the
    result. The CUDA-only test is skipped when no CUDA device is
    available instead of erroring out.
    """

    def setUp(self) -> None:
        self.model = GammaVAE(3, 10)

    def test_summary(self):
        print(summary(self.model, (3, 64, 64), device='cpu'))

    def test_forward(self):
        x = torch.randn(16, 3, 64, 64)
        y = self.model(x)
        print("Model Output size:", y[0].size())

    def test_loss(self):
        x = torch.randn(16, 3, 64, 64)

        result = self.model(x)
        loss = self.model.loss_function(*result, M_N = 0.005)
        print(loss)

    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
    def test_sample(self):
        # Moves the model to the GPU, hence the skip guard above.
        self.model.cuda()
        y = self.model.sample(144, 0)
class TestIWAE(unittest.TestCase):
    """
    Smoke tests for IWAE: each test only runs a call and prints the
    result. The CUDA-only test is skipped when no CUDA device is
    available instead of erroring out.
    """

    def setUp(self) -> None:
        self.model = IWAE(3, 10)

    def test_summary(self):
        print(summary(self.model, (3, 64, 64), device='cpu'))

    def test_forward(self):
        x = torch.randn(16, 3, 64, 64)
        y = self.model(x)
        print("Model Output size:", y[0].size())

    def test_loss(self):
        x = torch.randn(16, 3, 64, 64)

        result = self.model(x)
        loss = self.model.loss_function(*result, M_N = 0.005)
        print(loss)

    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
    def test_sample(self):
        # Moves the model to the GPU, hence the skip guard above.
        self.model.cuda()
        y = self.model.sample(144, 0)
class TestVAE(unittest.TestCase):
    """Smoke tests for LogCoshVAE; every test runs one call and prints the outcome."""

    def setUp(self) -> None:
        # A fresh model is built before each test.
        self.model = LogCoshVAE(3, 10, alpha=10)

    def test_summary(self):
        # Print the torchsummary report for a 3 x 64 x 64 input.
        report = summary(self.model, (3, 64, 64), device='cpu')
        print(report)

    def test_forward(self):
        # Forward pass on a normally distributed batch.
        batch = torch.randn(16, 3, 64, 64)
        outputs = self.model(batch)
        print("Model Output size:", outputs[0].size())

    def test_loss(self):
        # Uniformly distributed batch (torch.rand) for the loss check.
        batch = torch.rand(16, 3, 64, 64)
        outputs = self.model(batch)
        losses = self.model.loss_function(*outputs, M_N=0.005)
        print(losses)
class TestMIWAE(unittest.TestCase):
    """
    Smoke tests for MIWAE: each test only runs a call and prints the
    result. The CUDA-only test is skipped when no CUDA device is
    available instead of erroring out.
    """

    def setUp(self) -> None:
        self.model = MIWAE(3, 10)

    def test_summary(self):
        print(summary(self.model, (3, 64, 64), device='cpu'))

    def test_forward(self):
        x = torch.randn(16, 3, 64, 64)
        y = self.model(x)
        print("Model Output size:", y[0].size())

    def test_loss(self):
        x = torch.randn(16, 3, 64, 64)

        result = self.model(x)
        loss = self.model.loss_function(*result, M_N = 0.005)
        print(loss)

    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
    def test_sample(self):
        # Moves the model to the GPU, hence the skip guard above.
        self.model.cuda()
        y = self.model.sample(144, 0)
        print(y.shape)

    def test_generate(self):
        x = torch.randn(16, 3, 64, 64)
        y = self.model.generate(x)
        print(y.shape)
class TestSWAE(unittest.TestCase):
    """Smoke tests for SWAE; every test runs one call and prints the outcome."""

    def setUp(self) -> None:
        # A fresh model is built before each test.
        self.model = SWAE(3, 10, reg_weight=100)

    def test_summary(self):
        # Print the torchsummary report for a 3 x 64 x 64 input.
        report = summary(self.model, (3, 64, 64), device='cpu')
        print(report)

    def test_forward(self):
        batch = torch.randn(16, 3, 64, 64)
        outputs = self.model(batch)
        print("Model Output size:", outputs[0].size())

    def test_loss(self):
        # The forward outputs are passed straight to the loss function.
        batch = torch.randn(16, 3, 64, 64)
        outputs = self.model(batch)
        losses = self.model.loss_function(*outputs)
        print(losses)
class TestWAE(unittest.TestCase):
    """Smoke tests for WAE_MMD; every test runs one call and prints the outcome."""

    def setUp(self) -> None:
        # A fresh model is built before each test.
        self.model = WAE_MMD(3, 10, reg_weight=100)

    def test_summary(self):
        # Print the torchsummary report for a 3 x 64 x 64 input.
        report = summary(self.model, (3, 64, 64), device='cpu')
        print(report)

    def test_forward(self):
        batch = torch.randn(16, 3, 64, 64)
        outputs = self.model(batch)
        print("Model Output size:", outputs[0].size())

    def test_loss(self):
        # The forward outputs are passed straight to the loss function.
        batch = torch.randn(16, 3, 64, 64)
        outputs = self.model(batch)
        losses = self.model.loss_function(*outputs)
        print(losses)
def data_loader(fn):
    """
    Decorator to handle the deprecation of data_loader from 0.7
    :param fn: User defined data loader function
    :return: A wrapper for the data_loader function; it keeps the name
             and docstring of fn
    """
    # Local import so the block does not depend on the file's import list.
    from functools import wraps

    @wraps(fn)
    def func_wrapper(self):
        # Look the legacy decorator up explicitly rather than catching every
        # exception: a blanket handler would also swallow errors raised by
        # fn itself and then call fn a second time.
        legacy_decorator = getattr(pl, "data_loader", None)
        if legacy_decorator is None:
            # Works for versions that no longer expose pl.data_loader
            return fn(self)
        # Works for version 0.6.0
        return legacy_decorator(fn)(self)

    return func_wrapper