[
  {
    "path": ".gitignore",
    "content": "\nData/\nlogs/\n\nVanillaVAE/version_0/\n\n__pycache__/\n.ipynb_checkpoints/\n\nRun.ipynb\n"
  },
  {
    "path": ".idea/.gitignore",
    "content": "# Default ignored files\n/workspace.xml\n"
  },
  {
    "path": ".idea/PyTorch-VAE.iml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n  <component name=\"NewModuleRootManager\">\n    <content url=\"file://$MODULE_DIR$\" />\n    <orderEntry type=\"jdk\" jdkName=\"Python 3.7 (main)\" jdkType=\"Python SDK\" />\n    <orderEntry type=\"sourceFolder\" forTests=\"false\" />\n    <orderEntry type=\"module\" module-name=\"somic_research\" />\n  </component>\n  <component name=\"ReSTService\">\n    <option name=\"DOC_DIR\" value=\"$MODULE_DIR$/../Project_S/somic_research/docs\" />\n  </component>\n</module>"
  },
  {
    "path": ".idea/inspectionProfiles/profiles_settings.xml",
    "content": "<component name=\"InspectionProjectProfileManager\">\n  <settings>\n    <option name=\"USE_PROJECT_PROFILE\" value=\"false\" />\n    <version value=\"1.0\" />\n  </settings>\n</component>"
  },
  {
    "path": ".idea/misc.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectRootManager\" version=\"2\" project-jdk-name=\"Python 3.7 (main)\" project-jdk-type=\"Python SDK\" />\n  <component name=\"PyCharmProfessionalAdvertiser\">\n    <option name=\"shown\" value=\"true\" />\n  </component>\n</project>"
  },
  {
    "path": ".idea/modules.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectModuleManager\">\n    <modules>\n      <module fileurl=\"file://$PROJECT_DIR$/.idea/PyTorch-VAE.iml\" filepath=\"$PROJECT_DIR$/.idea/PyTorch-VAE.iml\" />\n      <module fileurl=\"file://$PROJECT_DIR$/../Project_S/somic_research/.idea/somic_research.iml\" filepath=\"$PROJECT_DIR$/../Project_S/somic_research/.idea/somic_research.iml\" />\n    </modules>\n  </component>\n</project>"
  },
  {
    "path": ".idea/vcs.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"VcsDirectoryMappings\">\n    <mapping directory=\"\" vcs=\"Git\" />\n    <mapping directory=\"$PROJECT_DIR$/../Project_S/somic_research\" vcs=\"Git\" />\n  </component>\n</project>"
  },
  {
    "path": "LICENSE.md",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\t\t\t    Copyright Anand Krishnamoorthy Subramanian 2020\n\t\t\t               anandkrish894@gmail.com\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the 
Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"{}\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright {yyyy} {name of copyright owner}\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<h1 align=\"center\">\n  <b>PyTorch VAE</b><br>\n</h1>\n\n<p align=\"center\">\n      <a href=\"https://www.python.org/\">\n        <img src=\"https://img.shields.io/badge/Python-3.5-ff69b4.svg\" /></a>\n       <a href= \"https://pytorch.org/\">\n        <img src=\"https://img.shields.io/badge/PyTorch-1.3-2BAF2B.svg\" /></a>\n       <a href= \"https://github.com/AntixK/PyTorch-VAE/blob/master/LICENSE.md\">\n        <img src=\"https://img.shields.io/badge/license-Apache2.0-blue.svg\" /></a>\n         <a href= \"https://twitter.com/intent/tweet?text=PyTorch-VAE:%20Collection%20of%20VAE%20models%20in%20PyTorch.&url=https://github.com/AntixK/PyTorch-VAE\">\n        <img src=\"https://img.shields.io/twitter/url/https/shields.io.svg?style=social\" /></a>\n\n</p>\n\n**Update 22/12/2021:** Added support for PyTorch Lightning 1.5.6 version and cleaned up the code.\n\nA collection of Variational AutoEncoders (VAEs) implemented in pytorch with focus on reproducibility. The aim of this project is to provide\na quick and simple working example for many of the cool VAE models out there. All the models are trained on the [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)\nfor consistency and comparison. The architecture of all the models are kept as similar as possible with the same layers, except for cases where the original paper necessitates \na radically different architecture (Ex. 
VQ VAE uses Residual layers and no Batch-Norm, unlike other models).\nHere are the [results](https://github.com/AntixK/PyTorch-VAE/blob/master/README.md#--results) of each model.\n\n### Requirements\n- Python >= 3.5\n- PyTorch >= 1.3\n- PyTorch Lightning >= 0.6.0 ([GitHub Repo](https://github.com/PyTorchLightning/pytorch-lightning/tree/deb1581e26b7547baf876b7a94361e60bb200d32))\n- CUDA-enabled computing device\n\n### Installation\n```\n$ git clone https://github.com/AntixK/PyTorch-VAE\n$ cd PyTorch-VAE\n$ pip install -r requirements.txt\n```\n\n### Usage\n```\n$ cd PyTorch-VAE\n$ python run.py -c configs/<config-file-name.yaml>\n```\n**Config file template**\n\n```yaml\nmodel_params:\n  name: \"<name of VAE model>\"\n  in_channels: 3\n  latent_dim: \n    .         # Other parameters required by the model\n    .\n    .\n\ndata_params:\n  data_path: \"<path to the celebA dataset>\"\n  train_batch_size: 64 # Better to have a square number\n  val_batch_size:  64\n  patch_size: 64  # Models are designed to work for this size\n  num_workers: 4\n  \nexp_params:\n  manual_seed: 1265\n  LR: 0.005\n  weight_decay:\n    .         # Other arguments required for training, like scheduler etc.\n    .\n    .\n\ntrainer_params:\n  gpus: 1         \n  max_epochs: 100\n  gradient_clip_val: 1.5\n    .\n    .\n    .\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"<experiment name>\"\n```\n\n**View TensorBoard Logs**\n```\n$ cd logs/<experiment name>/version_<the version you want>\n$ tensorboard --logdir .\n```\n\n**Note:** The default dataset is CelebA. However, there have been many issues with downloading the dataset from Google Drive (owing to some file structure changes). So, the recommendation is to download the [file](https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing) from Google Drive directly and extract it to the path of your choice. The default path assumed in the config files is `Data/celeba/img_align_celeba`. 
But you can change it according to your preference.\n\n\n----\n<h2 align=\"center\">\n  <b>Results</b><br>\n</h2>\n\n\n| Model                                                                  | Paper                                            |Reconstruction | Samples |\n|------------------------------------------------------------------------|--------------------------------------------------|---------------|---------|\n| VAE ([Code][vae_code], [Config][vae_config])                           |[Link](https://arxiv.org/abs/1312.6114)           |    ![][2]     | ![][1]  |\n| Conditional VAE ([Code][cvae_code], [Config][cvae_config])             |[Link](https://openreview.net/forum?id=rJWXGDWd-H)|    ![][16]    | ![][15] |\n| WAE - MMD (RBF Kernel) ([Code][wae_code], [Config][wae_rbf_config])    |[Link](https://arxiv.org/abs/1711.01558)          |    ![][4]     | ![][3]  |\n| WAE - MMD (IMQ Kernel) ([Code][wae_code], [Config][wae_imq_config])    |[Link](https://arxiv.org/abs/1711.01558)          |    ![][6]     | ![][5]  |\n| Beta-VAE ([Code][bvae_code], [Config][bhvae_config])                   |[Link](https://openreview.net/forum?id=Sy2fzU9gl) |    ![][8]     | ![][7]  |\n| Disentangled Beta-VAE ([Code][bvae_code], [Config][bbvae_config])      |[Link](https://arxiv.org/abs/1804.03599)          |    ![][22]    | ![][21] |\n| Beta-TC-VAE ([Code][btcvae_code], [Config][btcvae_config])             |[Link](https://arxiv.org/abs/1802.04942)          |    ![][34]    | ![][33] |\n| IWAE (*K = 5*) ([Code][iwae_code], [Config][iwae_config])              |[Link](https://arxiv.org/abs/1509.00519)          |    ![][10]    | ![][9]  |\n| MIWAE (*K = 5, M = 3*) ([Code][miwae_code], [Config][miwae_config])    |[Link](https://arxiv.org/abs/1802.04537)          |    ![][30]    | ![][29] |\n| DFCVAE   ([Code][dfcvae_code], [Config][dfcvae_config])                |[Link](https://arxiv.org/abs/1610.00291)          |    ![][12]    | ![][11] |\n| MSSIM VAE    ([Code][mssimvae_code], 
[Config][mssimvae_config])        |[Link](https://arxiv.org/abs/1511.06409)          |    ![][14]    | ![][13] |\n| Categorical VAE   ([Code][catvae_code], [Config][catvae_config])       |[Link](https://arxiv.org/abs/1611.01144)          |    ![][18]    | ![][17] |\n| Joint VAE ([Code][jointvae_code], [Config][jointvae_config])           |[Link](https://arxiv.org/abs/1804.00104)          |    ![][20]    | ![][19] |\n| Info VAE   ([Code][infovae_code], [Config][infovae_config])            |[Link](https://arxiv.org/abs/1706.02262)          |    ![][24]    | ![][23] |\n| LogCosh VAE   ([Code][logcoshvae_code], [Config][logcoshvae_config])   |[Link](https://openreview.net/forum?id=rkglvsC9Ym)|    ![][26]    | ![][25] |\n| SWAE (200 Projections) ([Code][swae_code], [Config][swae_config])      |[Link](https://arxiv.org/abs/1804.01947)          |    ![][28]    | ![][27] |\n| VQ-VAE (*K = 512, D = 64*) ([Code][vqvae_code], [Config][vqvae_config])|[Link](https://arxiv.org/abs/1711.00937)          |    ![][31]    | **N/A** |\n| DIP VAE ([Code][dipvae_code], [Config][dipvae_config])                 |[Link](https://arxiv.org/abs/1711.00848)          |    ![][36]    | ![][35] |\n\n\n<!-- | Gamma VAE             |[Link](https://arxiv.org/abs/1610.05683)          |    ![][16]    | ![][15] |-->\n\n<!--\n### TODO\n- [x] VanillaVAE\n- [x] Beta VAE\n- [x] DFC VAE\n- [x] MSSIM VAE\n- [x] IWAE\n- [x] MIWAE\n- [x] WAE-MMD\n- [x] Conditional VAE- [ ] PixelVAE\n- [x] Categorical VAE (Gumbel-Softmax VAE)\n- [x] Joint VAE\n- [x] Disentangled beta-VAE\n- [x] InfoVAE\n- [x] LogCosh VAE\n- [x] SWAE\n- [x] VQVAE\n- [x] Beta TC-VAE\n- [x] DIP VAE\n- [ ] Ladder VAE (Doesn't work well)\n- [ ] Gamma VAE (Doesn't work well) \n- [ ] Vamp VAE (Doesn't work well)\n-->\n\n### Contributing\nIf you have trained a better model, using these implementations, by fine-tuning the hyper-params in the config file,\nI would be happy to include your result (along with your config file) in this repo, citing your 
name 😊.\n\nAdditionally, if you would like to contribute some models, please submit a PR.\n\n### License\n**Apache License 2.0**\n\n| Permissions      | Limitations       | Conditions                       |\n|------------------|-------------------|----------------------------------|\n| ✔️ Commercial use |  ❌  Trademark use |  ⓘ License and copyright notice | \n| ✔️ Modification   |  ❌  Liability     |  ⓘ State changes                |\n| ✔️ Distribution   |  ❌  Warranty      |                                  |\n| ✔️ Patent use     |                   |                                  |\n| ✔️ Private use    |                   |                                  |\n\n\n### Citation\n```\n@misc{Subramanian2020,\n  author = {Subramanian, A.K},\n  title = {PyTorch-VAE},\n  year = {2020},\n  publisher = {GitHub},\n  journal = {GitHub repository},\n  howpublished = {\\url{https://github.com/AntixK/PyTorch-VAE}}\n}\n```\n-----------\n\n[vae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py\n[cvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/cvae.py\n[bvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py\n[btcvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/betatc_vae.py\n[wae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/wae_mmd.py\n[iwae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/iwae.py\n[miwae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/miwae.py\n[swae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/swae.py\n[jointvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/joint_vae.py\n[dfcvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/dfcvae.py\n[mssimvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/mssim_vae.py\n[logcoshvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/logcosh_vae.py\n[catvae_code]: 
https://github.com/AntixK/PyTorch-VAE/blob/master/models/cat_vae.py\n[infovae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/info_vae.py\n[vqvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py\n[dipvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/dip_vae.py\n\n[vae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vae.yaml\n[cvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cvae.yaml\n[bbvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/bbvae.yaml\n[bhvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/bhvae.yaml\n[btcvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/betatc_vae.yaml\n[wae_rbf_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_rbf.yaml\n[wae_imq_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_imq.yaml\n[iwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/iwae.yaml\n[miwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/miwae.yaml\n[swae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/swae.yaml\n[jointvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/joint_vae.yaml\n[dfcvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dfc_vae.yaml\n[mssimvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/mssim_vae.yaml\n[logcoshvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/logcosh_vae.yaml\n[catvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cat_vae.yaml\n[infovae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/infovae.yaml\n[vqvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vq_vae.yaml\n[dipvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dip_vae.yaml\n\n[1]: 
https://github.com/AntixK/PyTorch-VAE/blob/master/assets/Vanilla%20VAE_25.png\n[2]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_Vanilla%20VAE_25.png\n[3]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/WAE_RBF_18.png\n[4]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_WAE_RBF_19.png\n[5]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/WAE_IMQ_15.png\n[6]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_WAE_IMQ_15.png\n[7]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaVAE_H_20.png\n[8]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaVAE_H_20.png\n[9]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/IWAE_19.png\n[10]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_IWAE_19.png\n[11]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/DFCVAE_49.png\n[12]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_DFCVAE_49.png\n[13]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MSSIMVAE_29.png\n[14]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MSSIMVAE_29.png\n[15]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/ConditionalVAE_20.png\n[16]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_ConditionalVAE_20.png\n[17]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/CategoricalVAE_49.png\n[18]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_CategoricalVAE_49.png\n[19]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/JointVAE_49.png\n[20]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_JointVAE_49.png\n[21]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaVAE_B_35.png\n[22]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaVAE_B_35.png\n[23]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/InfoVAE_31.png\n[24]: 
https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_InfoVAE_31.png\n[25]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/LogCoshVAE_49.png\n[26]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_LogCoshVAE_49.png\n[27]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/SWAE_49.png\n[28]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_SWAE_49.png\n[29]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MIWAE_29.png\n[30]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MIWAE_29.png\n[31]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_VQVAE_29.png\n[33]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaTCVAE_49.png\n[34]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaTCVAE_49.png\n[35]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/DIPVAE_83.png\n[36]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_DIPVAE_83.png\n\n[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg\n[python-url]: https://www.python.org/\n\n[pytorch-image]: https://img.shields.io/badge/PyTorch-1.3-2BAF2B.svg\n[pytorch-url]: https://pytorch.org/\n\n[twitter-image]:https://img.shields.io/twitter/url/https/shields.io.svg?style=social\n[twitter-url]:https://twitter.com/intent/tweet?text=Neural%20Blocks-Easy%20to%20use%20neural%20net%20blocks%20for%20fast%20prototyping.&url=https://github.com/AntixK/NeuralBlocks\n\n\n[license-image]:https://img.shields.io/badge/license-Apache2.0-blue.svg\n[license-url]:https://github.com/AntixK/PyTorch-VAE/blob/master/LICENSE.md\n"
  },
  {
    "path": "configs/bbvae.yaml",
    "content": "model_params:\n  name: 'BetaVAE'\n  in_channels: 3\n  latent_dim: 128\n  loss_type: 'B'\n  gamma: 10.0\n  max_capacity: 25\n  Capacity_max_iter: 10000\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n  \nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  manual_seed: 1265\n  name: 'BetaVAE'\n"
  },
  {
    "path": "configs/betatc_vae.yaml",
    "content": "model_params:\n  name: 'BetaTCVAE'\n  in_channels: 3\n  latent_dim: 10\n  anneal_steps: 10000\n  alpha: 1.\n  beta:  6.\n  gamma: 1.\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: 'BetaTCVAE'\n"
  },
  {
    "path": "configs/bhvae.yaml",
    "content": "model_params:\n  name: 'BetaVAE'\n  in_channels: 3\n  latent_dim: 128\n  loss_type: 'H'\n  beta: 10.\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: 'BetaVAE'\n"
  },
  {
    "path": "configs/cat_vae.yaml",
    "content": "model_params:\n  name: 'CategoricalVAE'\n  in_channels: 3\n  latent_dim: 512\n  categorical_dim: 40\n  temperature: 0.5\n  anneal_rate: 0.00003\n  anneal_interval: 100\n  alpha: 1.0\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"CategoricalVAE\"\n"
  },
  {
    "path": "configs/cvae.yaml",
    "content": "model_params:\n  name: 'ConditionalVAE'\n  in_channels: 3\n  num_classes: 40\n  latent_dim: 128\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"ConditionalVAE\""
  },
  {
    "path": "configs/dfc_vae.yaml",
    "content": "model_params:\n  name: 'DFCVAE'\n  in_channels: 3\n  latent_dim: 128\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"DFCVAE\"\n"
  },
  {
    "path": "configs/dip_vae.yaml",
    "content": "model_params:\n  name: 'DIPVAE'\n  in_channels: 3\n  latent_dim: 128\n  lambda_diag: 0.05\n  lambda_offdiag: 0.1\n\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.001\n  weight_decay: 0.0\n  scheduler_gamma: 0.97\n  kld_weight: 1\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"DIPVAE\"\n  manual_seed: 1265\n"
  },
  {
    "path": "configs/factorvae.yaml",
    "content": "model_params:\n  name: 'FactorVAE'\n  in_channels: 3\n  latent_dim: 128\n  gamma: 6.4\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  submodel: 'discriminator'\n  retain_first_backpass: True\n  LR: 0.005\n  weight_decay: 0.0\n  LR_2: 0.005\n  scheduler_gamma_2: 0.95\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"FactorVAE\"  \n  \n\n"
  },
  {
    "path": "configs/gammavae.yaml",
    "content": "model_params:\n  name: 'GammaVAE'\n  in_channels: 3\n  latent_dim: 128\n  gamma_shape: 8.\n  prior_shape: 2.\n  prior_rate: 1.\n\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.003\n  weight_decay: 0.00005\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n  gradient_clip_val: 0.8\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"GammaVAE\"\n"
  },
  {
    "path": "configs/hvae.yaml",
    "content": "model_params:\n  name: 'HVAE'\n  in_channels: 3\n  latent1_dim: 64\n  latent2_dim: 64\n  pseudo_input_size: 128\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"HVAE\"\n"
  },
  {
    "path": "configs/infovae.yaml",
    "content": "model_params:\n  name: 'InfoVAE'\n  in_channels: 3\n  latent_dim: 128\n  reg_weight: 110  # MMD weight\n  kernel_type: 'imq'\n  alpha: -9.0     # KLD weight\n  beta: 10.5      # Reconstruction weight\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n  gradient_clip_val: 0.8\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"InfoVAE\"\n  manual_seed: 1265\n\n\n\n\n"
  },
  {
    "path": "configs/iwae.yaml",
    "content": "model_params:\n  name: 'IWAE'\n  in_channels: 3\n  latent_dim: 128\n  num_samples: 5\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.007\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"IWAE\"\n"
  },
  {
    "path": "configs/joint_vae.yaml",
    "content": "model_params:\n  name: 'JointVAE'\n  in_channels: 3\n  latent_dim: 512\n  categorical_dim: 40\n  latent_min_capacity: 0.0\n  latent_max_capacity: 20.0\n  latent_gamma: 10.\n  latent_num_iter: 25000\n  categorical_min_capacity: 0.0\n  categorical_max_capacity: 20.0\n  categorical_gamma: 10.\n  categorical_num_iter: 25000\n  temperature: 0.5\n  anneal_rate: 0.00003\n  anneal_interval: 100\n  alpha: 10.0\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"JointVAE\"\n\n"
  },
  {
    "path": "configs/logcosh_vae.yaml",
    "content": "model_params:\n  name: 'LogCoshVAE'\n  in_channels: 3\n  latent_dim: 128\n  alpha: 10.0\n  beta: 1.0\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.97\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"LogCoshVAE\"\n\n"
  },
  {
    "path": "configs/lvae.yaml",
    "content": "model_params:\n  name: 'LVAE'\n  in_channels: 3\n  latent_dims: [4,8,16,32,128]\n  hidden_dims: [32, 64,128, 256, 512]\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"LVAE\"\n"
  },
  {
    "path": "configs/miwae.yaml",
    "content": "model_params:\n  name: 'MIWAE'\n  in_channels: 3\n  latent_dim: 128\n  num_samples: 5\n  num_estimates: 3\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"MIWAE\"\n\n"
  },
  {
    "path": "configs/mssim_vae.yaml",
    "content": "model_params:\n  name: 'MSSIMVAE'\n  in_channels: 3\n  latent_dim: 128\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"MSSIMVAE\"\n"
  },
  {
    "path": "configs/swae.yaml",
    "content": "model_params:\n  name: 'SWAE'\n  in_channels: 3\n  latent_dim: 128\n  reg_weight: 100\n  wasserstein_deg: 2.0\n  num_projections: 200\n  projection_dist: \"normal\" #\"cauchy\"\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"SWAE\"\n\n\n\n\n\n"
  },
  {
    "path": "configs/vae.yaml",
    "content": "model_params:\n  name: 'VanillaVAE'\n  in_channels: 3\n  latent_dim: 128\n\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 100\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"VanillaVAE\"\n  \n"
  },
  {
    "path": "configs/vampvae.yaml",
    "content": "model_params:\n  name: 'VampVAE'\n  in_channels: 3\n  latent_dim: 128\n\nexp_params:\n  dataset: celeba\n  data_path: \"../../shared/Data/\"\n  img_size: 64\n  batch_size: 144 # Better to have a square number\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n\ntrainer_params:\n  gpus: 1\n  max_nb_epochs: 50\n  max_epochs: 50\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"VampVAE\"\n  manual_seed: 1265\n"
  },
  {
    "path": "configs/vq_vae.yaml",
    "content": "model_params:\n  name: 'VQVAE'\n  in_channels: 3\n  embedding_dim: 64\n  num_embeddings: 512\n  img_size: 64\n  beta: 0.25\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95 # ExponentialLR multiplies LR by this each epoch; 0.0 would zero the LR\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: 'VQVAE'\n"
  },
  {
    "path": "configs/wae_mmd_imq.yaml",
    "content": "model_params:\n  name: 'WAE_MMD'\n  in_channels: 3\n  latent_dim: 128\n  reg_weight: 100\n  kernel_type: 'imq'\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"WassersteinVAE_IMQ\"\n\n\n\n\n\n"
  },
  {
    "path": "configs/wae_mmd_rbf.yaml",
    "content": "model_params:\n  name: 'WAE_MMD'\n  in_channels: 3\n  latent_dim: 128\n  reg_weight: 5000\n  kernel_type: 'rbf'\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size: 64\n  val_batch_size:  64\n  patch_size: 64\n  num_workers: 4\n\n\nexp_params:\n  LR: 0.005\n  weight_decay: 0.0\n  scheduler_gamma: 0.95\n  kld_weight: 0.00025\n  manual_seed: 1265\n\ntrainer_params:\n  gpus: [1]\n  max_epochs: 10\n\nlogging_params:\n  save_dir: \"logs/\"\n  name: \"WassersteinVAE_RBF\"\n\n\n\n\n\n"
  },
  {
    "path": "dataset.py",
    "content": "import os\nimport torch\nfrom torch import Tensor\nfrom pathlib import Path\nfrom typing import List, Optional, Sequence, Union, Any, Callable\nfrom torchvision.datasets.folder import default_loader\nfrom pytorch_lightning import LightningDataModule\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision import transforms\nfrom torchvision.datasets import CelebA\nimport zipfile\n\n\n# Add your custom dataset class here\nclass MyDataset(Dataset):\n    def __init__(self):\n        pass\n    \n    \n    def __len__(self):\n        pass\n    \n    def __getitem__(self, idx):\n        pass\n\n\nclass MyCelebA(CelebA):\n    \"\"\"\n    A work-around to address issues with pytorch's celebA dataset class.\n    \n    Download and Extract\n    URL : https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing\n    \"\"\"\n    \n    def _check_integrity(self) -> bool:\n        return True\n    \n    \n\nclass OxfordPets(Dataset):\n    \"\"\"\n    URL = https://www.robots.ox.ac.uk/~vgg/data/pets/\n    \"\"\"\n    def __init__(self, \n                 data_path: str, \n                 split: str,\n                 transform: Callable,\n                **kwargs):\n        self.data_dir = Path(data_path) / \"OxfordPets\"        \n        self.transforms = transform\n        imgs = sorted([f for f in self.data_dir.iterdir() if f.suffix == '.jpg'])\n        \n        self.imgs = imgs[:int(len(imgs) * 0.75)] if split == \"train\" else imgs[int(len(imgs) * 0.75):]\n    \n    def __len__(self):\n        return len(self.imgs)\n    \n    def __getitem__(self, idx):\n        img = default_loader(self.imgs[idx])\n        \n        if self.transforms is not None:\n            img = self.transforms(img)\n        \n        return img, 0.0 # dummy datat to prevent breaking \n\nclass VAEDataset(LightningDataModule):\n    \"\"\"\n    PyTorch Lightning data module \n\n    Args:\n        data_dir: root directory of your dataset.\n        
train_batch_size: the batch size to use during training.\n        val_batch_size: the batch size to use during validation.\n        patch_size: the size of the crop to take from the original images.\n        num_workers: the number of parallel workers to create to load data\n            items (see PyTorch's Dataloader documentation for more details).\n        pin_memory: whether prepared items should be loaded into pinned memory\n            or not. This can improve performance on GPUs.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_path: str,\n        train_batch_size: int = 8,\n        val_batch_size: int = 8,\n        patch_size: Union[int, Sequence[int]] = (256, 256),\n        num_workers: int = 0,\n        pin_memory: bool = False,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.data_dir = data_path\n        self.train_batch_size = train_batch_size\n        self.val_batch_size = val_batch_size\n        self.patch_size = patch_size\n        self.num_workers = num_workers\n        self.pin_memory = pin_memory\n\n    def setup(self, stage: Optional[str] = None) -> None:\n#       =========================  OxfordPets Dataset  =========================\n            \n#         train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),\n#                                               transforms.CenterCrop(self.patch_size),\n# #                                               transforms.Resize(self.patch_size),\n#                                               transforms.ToTensor(),\n#                                                 transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])\n        \n#         val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),\n#                                             transforms.CenterCrop(self.patch_size),\n# #                                             transforms.Resize(self.patch_size),\n#                                             
transforms.ToTensor(),\n#                                               transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])\n\n#         self.train_dataset = OxfordPets(\n#             self.data_dir,\n#             split='train',\n#             transform=train_transforms,\n#         )\n        \n#         self.val_dataset = OxfordPets(\n#             self.data_dir,\n#             split='val',\n#             transform=val_transforms,\n#         )\n        \n#       =========================  CelebA Dataset  =========================\n    \n        train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),\n                                              transforms.CenterCrop(148),\n                                              transforms.Resize(self.patch_size),\n                                              transforms.ToTensor(),])\n        \n        val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),\n                                            transforms.CenterCrop(148),\n                                            transforms.Resize(self.patch_size),\n                                            transforms.ToTensor(),])\n        \n        self.train_dataset = MyCelebA(\n            self.data_dir,\n            split='train',\n            transform=train_transforms,\n            download=False,\n        )\n        \n        # Replace CelebA with your dataset\n        self.val_dataset = MyCelebA(\n            self.data_dir,\n            split='test',\n            transform=val_transforms,\n            download=False,\n        )\n#       ===============================================================\n        \n    def train_dataloader(self) -> DataLoader:\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.train_batch_size,\n            num_workers=self.num_workers,\n            shuffle=True,\n            pin_memory=self.pin_memory,\n        )\n\n    def val_dataloader(self) -> 
Union[DataLoader, List[DataLoader]]:\n        return DataLoader(\n            self.val_dataset,\n            batch_size=self.val_batch_size,\n            num_workers=self.num_workers,\n            shuffle=False,\n            pin_memory=self.pin_memory,\n        )\n    \n    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:\n        return DataLoader(\n            self.val_dataset,\n            batch_size=144,\n            num_workers=self.num_workers,\n            shuffle=True,\n            pin_memory=self.pin_memory,\n        )\n     "
  },
  {
    "path": "experiment.py",
    "content": "import os\nimport math\nimport torch\nfrom torch import optim\nfrom models import BaseVAE\nfrom models.types_ import *\nfrom utils import data_loader\nimport pytorch_lightning as pl\nfrom torchvision import transforms\nimport torchvision.utils as vutils\nfrom torchvision.datasets import CelebA\nfrom torch.utils.data import DataLoader\n\n\nclass VAEXperiment(pl.LightningModule):\n\n    def __init__(self,\n                 vae_model: BaseVAE,\n                 params: dict) -> None:\n        super(VAEXperiment, self).__init__()\n\n        self.model = vae_model\n        self.params = params\n        self.curr_device = None\n        self.hold_graph = False\n        try:\n            self.hold_graph = self.params['retain_first_backpass']\n        except:\n            pass\n\n    def forward(self, input: Tensor, **kwargs) -> Tensor:\n        return self.model(input, **kwargs)\n\n    def training_step(self, batch, batch_idx, optimizer_idx = 0):\n        real_img, labels = batch\n        self.curr_device = real_img.device\n\n        results = self.forward(real_img, labels = labels)\n        train_loss = self.model.loss_function(*results,\n                                              M_N = self.params['kld_weight'], #al_img.shape[0]/ self.num_train_imgs,\n                                              optimizer_idx=optimizer_idx,\n                                              batch_idx = batch_idx)\n\n        self.log_dict({key: val.item() for key, val in train_loss.items()}, sync_dist=True)\n\n        return train_loss['loss']\n\n    def validation_step(self, batch, batch_idx, optimizer_idx = 0):\n        real_img, labels = batch\n        self.curr_device = real_img.device\n\n        results = self.forward(real_img, labels = labels)\n        val_loss = self.model.loss_function(*results,\n                                            M_N = 1.0, #real_img.shape[0]/ self.num_val_imgs,\n                                            optimizer_idx = optimizer_idx,\n 
                                           batch_idx = batch_idx)\n\n        self.log_dict({f\"val_{key}\": val.item() for key, val in val_loss.items()}, sync_dist=True)\n\n        \n    def on_validation_end(self) -> None:\n        self.sample_images()\n        \n    def sample_images(self):\n        # Get sample reconstruction image            \n        test_input, test_label = next(iter(self.trainer.datamodule.test_dataloader()))\n        test_input = test_input.to(self.curr_device)\n        test_label = test_label.to(self.curr_device)\n\n#         test_input, test_label = batch\n        recons = self.model.generate(test_input, labels = test_label)\n        vutils.save_image(recons.data,\n                          os.path.join(self.logger.log_dir , \n                                       \"Reconstructions\", \n                                       f\"recons_{self.logger.name}_Epoch_{self.current_epoch}.png\"),\n                          normalize=True,\n                          nrow=12)\n\n        try:\n            samples = self.model.sample(144,\n                                        self.curr_device,\n                                        labels = test_label)\n            vutils.save_image(samples.cpu().data,\n                              os.path.join(self.logger.log_dir , \n                                           \"Samples\",      \n                                           f\"{self.logger.name}_Epoch_{self.current_epoch}.png\"),\n                              normalize=True,\n                              nrow=12)\n        except Warning:\n            pass\n\n    def configure_optimizers(self):\n\n        optims = []\n        scheds = []\n\n        optimizer = optim.Adam(self.model.parameters(),\n                               lr=self.params['LR'],\n                               weight_decay=self.params['weight_decay'])\n        optims.append(optimizer)\n        # Check if more than 1 optimizer is required (Used for adversarial training)\n      
  try:\n            if self.params['LR_2'] is not None:\n                optimizer2 = optim.Adam(getattr(self.model,self.params['submodel']).parameters(),\n                                        lr=self.params['LR_2'])\n                optims.append(optimizer2)\n        except:\n            pass\n\n        try:\n            if self.params['scheduler_gamma'] is not None:\n                scheduler = optim.lr_scheduler.ExponentialLR(optims[0],\n                                                             gamma = self.params['scheduler_gamma'])\n                scheds.append(scheduler)\n\n                # Check if another scheduler is required for the second optimizer\n                try:\n                    if self.params['scheduler_gamma_2'] is not None:\n                        scheduler2 = optim.lr_scheduler.ExponentialLR(optims[1],\n                                                                      gamma = self.params['scheduler_gamma_2'])\n                        scheds.append(scheduler2)\n                except:\n                    pass\n                return optims, scheds\n        except:\n            return optims\n"
  },
  {
    "path": "models/__init__.py",
    "content": "from .base import *\nfrom .vanilla_vae import *\nfrom .gamma_vae import *\nfrom .beta_vae import *\nfrom .wae_mmd import *\nfrom .cvae import *\nfrom .hvae import *\nfrom .vampvae import *\nfrom .iwae import *\nfrom .dfcvae import *\nfrom .mssim_vae import MSSIMVAE\nfrom .fvae import *\nfrom .cat_vae import *\nfrom .joint_vae import *\nfrom .info_vae import *\n# from .twostage_vae import *\nfrom .lvae import LVAE\nfrom .logcosh_vae import *\nfrom .swae import *\nfrom .miwae import *\nfrom .vq_vae import *\nfrom .betatc_vae import *\nfrom .dip_vae import *\n\n\n# Aliases\nVAE = VanillaVAE\nGaussianVAE = VanillaVAE\nCVAE = ConditionalVAE\nGumbelVAE = CategoricalVAE\n\nvae_models = {'HVAE':HVAE,\n              'LVAE':LVAE,\n              'IWAE':IWAE,\n              'SWAE':SWAE,\n              'MIWAE':MIWAE,\n              'VQVAE':VQVAE,\n              'DFCVAE':DFCVAE,\n              'DIPVAE':DIPVAE,\n              'BetaVAE':BetaVAE,\n              'InfoVAE':InfoVAE,\n              'WAE_MMD':WAE_MMD,\n              'VampVAE': VampVAE,\n              'GammaVAE':GammaVAE,\n              'MSSIMVAE':MSSIMVAE,\n              'JointVAE':JointVAE,\n              'BetaTCVAE':BetaTCVAE,\n              'FactorVAE':FactorVAE,\n              'LogCoshVAE':LogCoshVAE,\n              'VanillaVAE':VanillaVAE,\n              'ConditionalVAE':ConditionalVAE,\n              'CategoricalVAE':CategoricalVAE}\n"
  },
  {
    "path": "models/base.py",
    "content": "from .types_ import *\nfrom torch import nn\nfrom abc import abstractmethod\n\nclass BaseVAE(nn.Module):\n    \n    def __init__(self) -> None:\n        super(BaseVAE, self).__init__()\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        raise NotImplementedError\n\n    def decode(self, input: Tensor) -> Any:\n        raise NotImplementedError\n\n    def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:\n        raise NotImplementedError\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        raise NotImplementedError\n\n    @abstractmethod\n    def forward(self, *inputs: Tensor) -> Tensor:\n        pass\n\n    @abstractmethod\n    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:\n        pass\n\n\n\n"
  },
  {
    "path": "models/beta_vae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass BetaVAE(BaseVAE):\n\n    num_iter = 0 # Global static variable to keep track of iterations\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 beta: int = 4,\n                 gamma:float = 1000.,\n                 max_capacity: int = 25,\n                 Capacity_max_iter: int = 1e5,\n                 loss_type:str = 'B',\n                 **kwargs) -> None:\n        super(BetaVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.beta = beta\n        self.gamma = gamma\n        self.loss_type = loss_type\n        self.C_max = torch.Tensor([max_capacity])\n        self.C_stop_iter = Capacity_max_iter\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       
kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Will 
a single z be enough ti compute the expectation\n        for the loss??\n        :param mu: (Tensor) Mean of the latent Gaussian\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian\n        :return:\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> Tensor:\n        mu, log_var = self.encode(input)\n        z = self.reparameterize(mu, log_var)\n        return  [self.decode(z), input, mu, log_var]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        self.num_iter += 1\n        recons = args[0]\n        input = args[1]\n        mu = args[2]\n        log_var = args[3]\n        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset\n\n        recons_loss =F.mse_loss(recons, input)\n\n        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n\n        if self.loss_type == 'H': # https://openreview.net/forum?id=Sy2fzU9gl\n            loss = recons_loss + self.beta * kld_weight * kld_loss\n        elif self.loss_type == 'B': # https://arxiv.org/pdf/1804.03599.pdf\n            self.C_max = self.C_max.to(input.device)\n            C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])\n            loss = recons_loss + self.gamma * kld_weight* (kld_loss - C).abs()\n        else:\n            raise ValueError('Undefined loss type.')\n\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: 
(Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/betatc_vae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\nimport math\n\n\nclass BetaTCVAE(BaseVAE):\n    num_iter = 0 # Global static variable to keep track of iterations\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 anneal_steps: int = 200,\n                 alpha: float = 1.,\n                 beta: float =  6.,\n                 gamma: float = 1.,\n                 **kwargs) -> None:\n        super(BetaTCVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.anneal_steps = anneal_steps\n\n        self.alpha = alpha\n        self.beta = beta\n        self.gamma = gamma\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 32, 32, 32]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 4, stride= 2, padding  = 1),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n\n        self.fc = nn.Linear(hidden_dims[-1]*16, 256)\n        self.fc_mu = nn.Linear(256, latent_dim)\n        self.fc_var = nn.Linear(256, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, 256 *  2)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                    
                   output_padding=1),\n                    nn.LeakyReLU())\n            )\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n\n        result = torch.flatten(result, start_dim=1)\n        result = self.fc(result)\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes\n        onto the image space.\n        :param z: (Tensor) [B x D]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n        result = self.decoder_input(z)\n        result = result.view(-1, 32, 4, 4)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Reparameterization trick to sample from N(mu, var) 
from\n        N(0,1).\n        :param mu: (Tensor) Mean of the latent Gaussian [B x D]\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]\n        :return: (Tensor) [B x D]\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        mu, log_var = self.encode(input)\n        z = self.reparameterize(mu, log_var)\n        return  [self.decode(z), input, mu, log_var, z]\n\n    def log_density_gaussian(self, x: Tensor, mu: Tensor, logvar: Tensor):\n        \"\"\"\n        Computes the log pdf of the Gaussian with parameters mu and logvar at x\n        :param x: (Tensor) Point at whichGaussian PDF is to be evaluated\n        :param mu: (Tensor) Mean of the Gaussian distribution\n        :param logvar: (Tensor) Log variance of the Gaussian distribution\n        :return:\n        \"\"\"\n        norm = - 0.5 * (math.log(2 * math.pi) + logvar)\n        log_density = norm - 0.5 * ((x - mu) ** 2 * torch.exp(-logvar))\n        return log_density\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        Computes the VAE loss function.\n        KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n        :param args:\n        :param kwargs:\n        :return:\n        \"\"\"\n            \n        recons = args[0]\n        input = args[1]\n        mu = args[2]\n        log_var = args[3]\n        z = args[4]\n\n        weight = 1 #kwargs['M_N']  # Account for the minibatch samples from the dataset\n\n        recons_loss =F.mse_loss(recons, input, reduction='sum')\n\n        log_q_zx = self.log_density_gaussian(z, mu, log_var).sum(dim = 1)\n\n        zeros = torch.zeros_like(z)\n        log_p_z = self.log_density_gaussian(z, zeros, zeros).sum(dim = 1)\n\n        batch_size, latent_dim = 
z.shape\n        mat_log_q_z = self.log_density_gaussian(z.view(batch_size, 1, latent_dim),\n                                                mu.view(1, batch_size, latent_dim),\n                                                log_var.view(1, batch_size, latent_dim))\n\n        # Reference\n        # [1] https://github.com/YannDubs/disentangling-vae/blob/535bbd2e9aeb5a200663a4f82f1d34e084c4ba8d/disvae/utils/math.py#L54\n        dataset_size = (1 / kwargs['M_N']) * batch_size # dataset size\n        strat_weight = (dataset_size - batch_size + 1) / (dataset_size * (batch_size - 1))\n        importance_weights = torch.Tensor(batch_size, batch_size).fill_(1 / (batch_size -1)).to(input.device)\n        importance_weights.view(-1)[::batch_size] = 1 / dataset_size\n        importance_weights.view(-1)[1::batch_size] = strat_weight\n        importance_weights[batch_size - 2, 0] = strat_weight\n        log_importance_weights = importance_weights.log()\n\n        mat_log_q_z += log_importance_weights.view(batch_size, batch_size, 1)\n\n        log_q_z = torch.logsumexp(mat_log_q_z.sum(2), dim=1, keepdim=False)\n        log_prod_q_z = torch.logsumexp(mat_log_q_z, dim=1, keepdim=False).sum(1)\n\n        mi_loss  = (log_q_zx - log_q_z).mean()\n        tc_loss = (log_q_z - log_prod_q_z).mean()\n        kld_loss = (log_prod_q_z - log_p_z).mean()\n\n        # kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n\n        if self.training:\n            self.num_iter += 1\n            anneal_rate = min(0 + 1 * self.num_iter / self.anneal_steps, 1)\n        else:\n            anneal_rate = 1.\n\n        loss = recons_loss/batch_size + \\\n               self.alpha * mi_loss + \\\n               weight * (self.beta * tc_loss +\n                         anneal_rate * self.gamma * kld_loss)\n        \n        return {'loss': loss,\n                'Reconstruction_Loss':recons_loss,\n                'KLD':kld_loss,\n                
'TC_Loss':tc_loss,\n                'MI_Loss':mi_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/cat_vae.py",
    "content": "import torch\nimport numpy as np\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass CategoricalVAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 categorical_dim: int = 40, # Num classes\n                 hidden_dims: List = None,\n                 temperature: float = 0.5,\n                 anneal_rate: float = 3e-5,\n                 anneal_interval: int = 100, # every 100 batches\n                 alpha: float = 30.,\n                 **kwargs) -> None:\n        super(CategoricalVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.categorical_dim = categorical_dim\n        self.temp = temperature\n        self.min_temp = temperature\n        self.anneal_rate = anneal_rate\n        self.anneal_interval = anneal_interval\n        self.alpha = alpha\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_z = nn.Linear(hidden_dims[-1]*4,\n                               self.latent_dim * self.categorical_dim)\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(self.latent_dim * self.categorical_dim\n                                       , hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    
nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n        self.sampling_dist = torch.distributions.OneHotCategorical(1. 
/ categorical_dim * torch.ones((self.categorical_dim, 1)))\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [B x C x H x W]\n        :return: (Tensor) Latent code [B x D x Q]\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        z = self.fc_z(result)\n        z = z.view(-1, self.latent_dim, self.categorical_dim)\n        return [z]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes\n        onto the image space.\n        :param z: (Tensor) [B x D x Q]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, z: Tensor, eps:float = 1e-7) -> Tensor:\n        \"\"\"\n        Gumbel-softmax trick to sample from Categorical Distribution\n        :param z: (Tensor) Latent Codes [B x D x Q]\n        :return: (Tensor) [B x D]\n        \"\"\"\n        # Sample from Gumbel\n        u = torch.rand_like(z)\n        g = - torch.log(- torch.log(u + eps) + eps)\n\n        # Gumbel-Softmax sample\n        s = F.softmax((z + g) / self.temp, dim=-1)\n        s = s.view(-1, self.latent_dim * self.categorical_dim)\n        return s\n\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        q = self.encode(input)[0]\n        z = self.reparameterize(q)\n        return  [self.decode(z), input, q]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        Computes the VAE loss function.\n      
  KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n        :param args:\n        :param kwargs:\n        :return:\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        q = args[2]\n\n        q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n        batch_idx = kwargs['batch_idx']\n\n        # Anneal the temperature at regular intervals\n        if batch_idx % self.anneal_interval == 0 and self.training:\n            self.temp = np.maximum(self.temp * np.exp(- self.anneal_rate * batch_idx),\n                                   self.min_temp)\n\n        recons_loss =F.mse_loss(recons, input, reduction='mean')\n\n        # KL divergence between gumbel-softmax distribution\n        eps = 1e-7\n\n        # Entropy of the logits\n        h1 = q_p * torch.log(q_p + eps)\n\n        # Cross entropy with the categorical distribution\n        h2 = q_p * np.log(1. 
/ self.categorical_dim + eps)\n        kld_loss = torch.mean(torch.sum(h1 - h2, dim =(1,2)), dim=0)\n\n        # kld_weight = 1.2\n        loss = self.alpha * recons_loss + kld_weight * kld_loss\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        # [S x D x Q]\n\n        M = num_samples * self.latent_dim\n        np_y = np.zeros((M, self.categorical_dim), dtype=np.float32)\n        np_y[range(M), np.random.choice(self.categorical_dim, M)] = 1\n        np_y = np.reshape(np_y, [M // self.latent_dim, self.latent_dim, self.categorical_dim])\n        z = torch.from_numpy(np_y)\n\n        # z = self.sampling_dist.sample((num_samples * self.latent_dim, ))\n        z = z.view(num_samples, self.latent_dim * self.categorical_dim).to(current_device)\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/cvae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass ConditionalVAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 num_classes: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 img_size:int = 64,\n                 **kwargs) -> None:\n        super(ConditionalVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.img_size = img_size\n\n        self.embed_class = nn.Linear(num_classes, img_size * img_size)\n        self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1)\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        in_channels += 1 # To account for the extra label channel\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim + num_classes, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       
padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Will a single z be enough ti compute the expectation\n        for the loss??\n        :param mu: (Tensor) Mean 
of the latent Gaussian\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian\n        :return:\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        y = kwargs['labels'].float()\n        embedded_class = self.embed_class(y)\n        embedded_class = embedded_class.view(-1, self.img_size, self.img_size).unsqueeze(1)\n        embedded_input = self.embed_data(input)\n\n        x = torch.cat([embedded_input, embedded_class], dim = 1)\n        mu, log_var = self.encode(x)\n\n        z = self.reparameterize(mu, log_var)\n\n        z = torch.cat([z, y], dim = 1)\n        return  [self.decode(z), input, mu, log_var]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        recons = args[0]\n        input = args[1]\n        mu = args[2]\n        log_var = args[3]\n\n        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset\n        recons_loss =F.mse_loss(recons, input)\n\n        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n\n        loss = recons_loss + kld_weight * kld_loss\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int,\n               **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        y = kwargs['labels'].float()\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        z = torch.cat([z, y], dim=1)\n        samples = self.decode(z)\n  
      return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x, **kwargs)[0]"
  },
  {
    "path": "models/dfcvae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torchvision.models import vgg19_bn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass DFCVAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 alpha:float = 1,\n                 beta:float = 0.5,\n                 **kwargs) -> None:\n        super(DFCVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.alpha = alpha\n        self.beta = beta\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n\n\n        
self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n        self.feature_network = vgg19_bn(pretrained=True)\n\n        # Freeze the pretrained feature network\n        for param in self.feature_network.parameters():\n            param.requires_grad = False\n\n        self.feature_network.eval()\n\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes\n        onto the image space.\n        :param z: (Tensor) [B x D]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = 
self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Reparameterization trick to sample from N(mu, var) from\n        N(0,1).\n        :param mu: (Tensor) Mean of the latent Gaussian [B x D]\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]\n        :return: (Tensor) [B x D]\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        mu, log_var = self.encode(input)\n        z = self.reparameterize(mu, log_var)\n        recons = self.decode(z)\n\n        recons_features = self.extract_features(recons)\n        input_features = self.extract_features(input)\n\n        return  [recons, input, recons_features, input_features, mu, log_var]\n\n    def extract_features(self,\n                         input: Tensor,\n                         feature_layers: List = None) -> List[Tensor]:\n        \"\"\"\n        Extracts the features from the pretrained model\n        at the layers indicated by feature_layers.\n        :param input: (Tensor) [B x C x H x W]\n        :param feature_layers: List of string of IDs\n        :return: List of the extracted features\n        \"\"\"\n        if feature_layers is None:\n            feature_layers = ['14', '24', '34', '43']\n        features = []\n        result = input\n        for (key, module) in self.feature_network.features._modules.items():\n            result = module(result)\n            if(key in feature_layers):\n                features.append(result)\n\n        return features\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        Computes the VAE loss function.\n        KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n        :param 
args:\n        :param kwargs:\n        :return:\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        recons_features = args[2]\n        input_features = args[3]\n        mu = args[4]\n        log_var = args[5]\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n        recons_loss =F.mse_loss(recons, input)\n\n        feature_loss = 0.0\n        for (r, i) in zip(recons_features, input_features):\n            feature_loss += F.mse_loss(r, i)\n\n        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n\n        loss = self.beta * (recons_loss + feature_loss) + self.alpha * kld_weight * kld_loss\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/dip_vae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass DIPVAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 lambda_diag: float = 10.,\n                 lambda_offdiag: float = 5.,\n                 **kwargs) -> None:\n        super(DIPVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.lambda_diag = lambda_diag\n        self.lambda_offdiag = lambda_offdiag\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n        
self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes\n        onto the image space.\n        :param z: (Tensor) [B x D]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Reparameterization trick to sample from N(mu, var) from\n        N(0,1).\n        :param mu: (Tensor) Mean of the latent 
Gaussian [B x D]\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]\n        :return: (Tensor) [B x D]\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        mu, log_var = self.encode(input)\n        z = self.reparameterize(mu, log_var)\n        return  [self.decode(z), input, mu, log_var]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        Computes the VAE loss function.\n        KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n        :param args:\n        :param kwargs:\n        :return:\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        mu = args[2]\n        log_var = args[3]\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n        recons_loss =F.mse_loss(recons, input, reduction='sum')\n\n\n        kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n\n        # DIP Loss\n        centered_mu = mu - mu.mean(dim=1, keepdim = True) # [B x D]\n        cov_mu = centered_mu.t().matmul(centered_mu).squeeze() # [D X D]\n\n        # Add Variance for DIP Loss II\n        cov_z = cov_mu + torch.mean(torch.diagonal((2. 
* log_var).exp(), dim1 = 0), dim = 0) # [D x D]\n        # For DIp Loss I\n        # cov_z = cov_mu\n\n        cov_diag = torch.diag(cov_z) # [D]\n        cov_offdiag = cov_z - torch.diag(cov_diag) # [D x D]\n        dip_loss = self.lambda_offdiag * torch.sum(cov_offdiag ** 2) + \\\n                   self.lambda_diag * torch.sum((cov_diag - 1) ** 2)\n\n        loss = recons_loss + kld_weight * kld_loss + dip_loss\n        return {'loss': loss,\n                'Reconstruction_Loss':recons_loss,\n                'KLD':-kld_loss,\n                'DIP_Loss':dip_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/fvae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass FactorVAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 gamma: float = 40.,\n                 **kwargs) -> None:\n        super(FactorVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.gamma = gamma\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                 
           nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n        # Discriminator network for the Total Correlation (TC) loss\n        self.discriminator = nn.Sequential(nn.Linear(self.latent_dim, 1000),\n                                          nn.BatchNorm1d(1000),\n                                          nn.LeakyReLU(0.2),\n                                          nn.Linear(1000, 1000),\n                                          nn.BatchNorm1d(1000),\n                                          nn.LeakyReLU(0.2),\n                                          nn.Linear(1000, 1000),\n                                          nn.BatchNorm1d(1000),\n                                          nn.LeakyReLU(0.2),\n                                          nn.Linear(1000, 2))\n        self.D_z_reserve = None\n\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        
return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes\n        onto the image space.\n        :param z: (Tensor) [B x D]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Reparameterization trick to sample from N(mu, var) from\n        N(0,1).\n        :param mu: (Tensor) Mean of the latent Gaussian [B x D]\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]\n        :return: (Tensor) [B x D]\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        mu, log_var = self.encode(input)\n        z = self.reparameterize(mu, log_var)\n        return  [self.decode(z), input, mu, log_var, z]\n\n    def permute_latent(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Permutes each of the latent codes in the batch\n        :param z: [B x D]\n        :return: [B x D]\n        \"\"\"\n        B, D = z.size()\n\n        # Returns a shuffled inds for each latent code in the batch\n        inds = torch.cat([(D *i) + torch.randperm(D) for i in range(B)])\n        return z.view(-1)[inds].view(B, D)\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        Computes the VAE loss function.\n        KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n        :param args:\n        :param kwargs:\n        :return:\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        mu = args[2]\n        log_var = args[3]\n        z = 
args[4]\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n        optimizer_idx = kwargs['optimizer_idx']\n\n        # Update the VAE\n        if optimizer_idx == 0:\n            recons_loss =F.mse_loss(recons, input)\n            kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n\n            self.D_z_reserve = self.discriminator(z)\n            vae_tc_loss = (self.D_z_reserve[:, 0] - self.D_z_reserve[:, 1]).mean()\n\n            loss = recons_loss + kld_weight * kld_loss + self.gamma * vae_tc_loss\n\n            # print(f' recons: {recons_loss}, kld: {kld_loss}, VAE_TC_loss: {vae_tc_loss}')\n            return {'loss': loss,\n                    'Reconstruction_Loss':recons_loss,\n                    'KLD':-kld_loss,\n                    'VAE_TC_Loss': vae_tc_loss}\n\n        # Update the Discriminator\n        elif optimizer_idx == 1:\n            device = input.device\n            true_labels = torch.ones(input.size(0), dtype= torch.long,\n                                     requires_grad=False).to(device)\n            false_labels = torch.zeros(input.size(0), dtype= torch.long,\n                                       requires_grad=False).to(device)\n\n            z = z.detach() # Detach so that VAE is not trained again\n            z_perm = self.permute_latent(z)\n            D_z_perm = self.discriminator(z_perm)\n            D_tc_loss = 0.5 * (F.cross_entropy(self.D_z_reserve, false_labels) +\n                               F.cross_entropy(D_z_perm, true_labels))\n            # print(f'D_TC: {D_tc_loss}')\n            return {'loss': D_tc_loss,\n                    'D_TC_Loss':D_tc_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        
:param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/gamma_vae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.distributions import Gamma\nfrom torch.nn import functional as F\nfrom .types_ import *\nimport torch.nn.init as init\n\n\nclass GammaVAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 gamma_shape: float = 8.,\n                 prior_shape: float = 2.0,\n                 prior_rate: float = 1.,\n                 **kwargs) -> None:\n        super(GammaVAE, self).__init__()\n        self.latent_dim = latent_dim\n        self.B = gamma_shape\n\n        self.prior_alpha = torch.tensor([prior_shape])\n        self.prior_beta = torch.tensor([prior_rate])\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size=3, stride=2, padding=1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Sequential(nn.Linear(hidden_dims[-1] * 4, latent_dim),\n                                   nn.Softmax())\n        self.fc_var = nn.Sequential(nn.Linear(hidden_dims[-1] * 4, latent_dim),\n                                    nn.Softmax())\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Sequential(nn.Linear(latent_dim, hidden_dims[-1] * 4))\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                             
          kernel_size=3,\n                                       stride=2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n            nn.ConvTranspose2d(hidden_dims[-1],\n                               hidden_dims[-1],\n                               kernel_size=3,\n                               stride=2,\n                               padding=1,\n                               output_padding=1),\n            nn.BatchNorm2d(hidden_dims[-1]),\n            nn.LeakyReLU(),\n            nn.Conv2d(hidden_dims[-1], out_channels=3,\n                      kernel_size=3, padding=1),\n            nn.Sigmoid())\n\n        self.weight_init()\n\n    def weight_init(self):\n\n        # print(self._modules)\n        for block in self._modules:\n            for m in self._modules[block]:\n                init_(m)\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        alpha = self.fc_mu(result)\n        beta = self.fc_var(result)\n\n        return [alpha, beta]\n\n    def decode(self, z: Tensor) -> Tensor:\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, alpha: Tensor, beta: Tensor) -> 
Tensor:\n        \"\"\"\n        Reparameterize the Gamma distribution by the shape augmentation trick.\n        Reference:\n        [1] https://arxiv.org/pdf/1610.05683.pdf\n\n        :param alpha: (Tensor) Shape parameter of the latent Gamma\n        :param beta: (Tensor) Rate parameter of the latent Gamma\n        :return:\n        \"\"\"\n        # Sample from Gamma to guarantee acceptance\n        alpha_ = alpha.clone().detach()\n        z_hat = Gamma(alpha_ + self.B, torch.ones_like(alpha_)).sample()\n\n        # Compute the eps ~ N(0,1) that produces z_hat\n        eps = self.inv_h_func(alpha + self.B , z_hat)\n        z = self.h_func(alpha + self.B, eps)\n\n        # When beta != 1, scale by beta\n        return z / beta\n\n    def h_func(self, alpha: Tensor, eps: Tensor) -> Tensor:\n        \"\"\"\n        Reparameterize a sample eps ~ N(0, 1) so that h(z) ~ Gamma(alpha, 1)\n        :param alpha: (Tensor) Shape parameter\n        :param eps: (Tensor) Random sample to reparameterize\n        :return: (Tensor)\n        \"\"\"\n\n        z = (alpha - 1./3.) * (1 + eps / torch.sqrt(9. * alpha - 3.))**3\n        return z\n\n    def inv_h_func(self, alpha: Tensor, z: Tensor) -> Tensor:\n        \"\"\"\n        Inverse reparameterize the given z into eps.\n        :param alpha: (Tensor)\n        :param z: (Tensor)\n        :return: (Tensor)\n        \"\"\"\n        eps = torch.sqrt(9. * alpha - 3.) * ((z / (alpha - 1./3.))**(1. / 3.) 
- 1.)\n        return eps\n\n    def forward(self, input: Tensor, **kwargs) -> Tensor:\n        alpha, beta = self.encode(input)\n        z = self.reparameterize(alpha, beta)\n        return [self.decode(z), input, alpha, beta]\n\n    # def I_function(self, alpha_p, beta_p, alpha_q, beta_q):\n    #     return - (alpha_q * beta_q) / alpha_p - \\\n    #            beta_p * torch.log(alpha_p) - torch.lgamma(beta_p) + \\\n    #            (beta_p - 1) * torch.digamma(beta_q) + \\\n    #            (beta_p - 1) * torch.log(alpha_q)\n    def I_function(self, a, b, c, d):\n        return - c * d / a - b * torch.log(a) - torch.lgamma(b) + (b - 1) * (torch.digamma(d) + torch.log(c))\n\n    def vae_gamma_kl_loss(self, a, b, c, d):\n        \"\"\"\n        https://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions\n        b and d are Gamma shape parameters and\n        a and c are scale parameters.\n        (All, therefore, must be positive.)\n        \"\"\"\n\n        a = 1 / a\n        c = 1 / c\n        losses = self.I_function(c, d, c, d) - self.I_function(a, b, c, d)\n        return torch.sum(losses, dim=1)\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        recons = args[0]\n        input = args[1]\n        alpha = args[2]\n        beta = args[3]\n\n        curr_device = input.device\n        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset\n        recons_loss = torch.mean(F.mse_loss(recons, input, reduction = 'none'), dim = (1,2,3))\n\n        # https://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions\n        # alpha = 1./ alpha\n\n\n        self.prior_alpha = self.prior_alpha.to(curr_device)\n        self.prior_beta = self.prior_beta.to(curr_device)\n\n        # kld_loss = - self.I_function(alpha, beta, self.prior_alpha, self.prior_beta)\n\n        kld_loss = 
self.vae_gamma_kl_loss(alpha, beta, self.prior_alpha, self.prior_beta)\n\n        # kld_loss = torch.sum(kld_loss, dim=1)\n\n        loss = recons_loss + kld_loss\n        loss = torch.mean(loss, dim = 0)\n        # print(loss, recons_loss, kld_loss)\n        return {'loss': loss} #, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the modelSay\n        :return: (Tensor)\n        \"\"\"\n        z = Gamma(self.prior_alpha, self.prior_beta).sample((num_samples, self.latent_dim))\n        z = z.squeeze().to(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]\n\ndef init_(m):\n    if isinstance(m, (nn.Linear, nn.Conv2d)):\n        init.orthogonal_(m.weight)\n        if m.bias is not None:\n            m.bias.data.fill_(0)\n    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):\n        m.weight.data.fill_(1)\n        if m.bias is not None:\n            m.bias.data.fill_(0)\n"
  },
  {
    "path": "models/hvae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass HVAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent1_dim: int,\n                 latent2_dim: int,\n                 hidden_dims: List = None,\n                 img_size:int = 64,\n                 pseudo_input_size: int = 128,\n                 **kwargs) -> None:\n        super(HVAE, self).__init__()\n\n        self.latent1_dim = latent1_dim\n        self.latent2_dim = latent2_dim\n        self.img_size = img_size\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n        channels = in_channels\n\n        # Build z2 Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            channels = h_dim\n\n        self.encoder_z2_layers = nn.Sequential(*modules)\n        self.fc_z2_mu = nn.Linear(hidden_dims[-1]*4, latent2_dim)\n        self.fc_z2_var = nn.Linear(hidden_dims[-1]*4, latent2_dim)\n        # ========================================================================#\n        # Build z1 Encoder\n        self.embed_z2_code = nn.Linear(latent2_dim, img_size * img_size)\n        self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1)\n\n        modules = []\n        channels = in_channels + 1 # One more channel for the latent code\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    
nn.LeakyReLU())\n            )\n            channels = h_dim\n\n        self.encoder_z1_layers = nn.Sequential(*modules)\n        self.fc_z1_mu = nn.Linear(hidden_dims[-1]*4, latent1_dim)\n        self.fc_z1_var = nn.Linear(hidden_dims[-1]*4, latent1_dim)\n\n        #========================================================================#\n        # Build z2 Decoder\n        self.recons_z1_mu = nn.Linear(latent2_dim, latent1_dim)\n        self.recons_z1_log_var = nn.Linear(latent2_dim, latent1_dim)\n\n        # ========================================================================#\n        # Build z1 Decoder\n        self.debed_z1_code = nn.Linear(latent1_dim, 1024)\n        self.debed_z2_code = nn.Linear(latent2_dim, 1024)\n        modules = []\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 
3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n        # ========================================================================#\n        # Pesudo Input for the Vamp-Prior\n        # self.pseudo_input =  torch.eye(pseudo_input_size,\n        #                                requires_grad=False).view(1, 1, pseudo_input_size, -1)\n        #\n        #\n        # self.pseudo_layer = nn.Conv2d(1, out_channels=in_channels,\n        #                              kernel_size=3, stride=2, padding=1)\n\n    def encode_z2(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder_z2_layers(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        z2_mu = self.fc_z2_mu(result)\n        z2_log_var = self.fc_z2_var(result)\n\n        return [z2_mu, z2_log_var]\n\n    def encode_z1(self, input: Tensor, z2: Tensor) -> List[Tensor]:\n        x = self.embed_data(input)\n        z2 = self.embed_z2_code(z2)\n        z2 = z2.view(-1, self.img_size, self.img_size).unsqueeze(1)\n        result = torch.cat([x, z2], dim=1)\n\n        result = self.encoder_z1_layers(result)\n        result = torch.flatten(result, start_dim=1)\n        z1_mu = self.fc_z1_mu(result)\n        z1_log_var = self.fc_z1_var(result)\n\n        return [z1_mu, z1_log_var]\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        z2_mu, z2_log_var = self.encode_z2(input)\n        z2 = self.reparameterize(z2_mu, z2_log_var)\n\n        # z1 ~ q(z1|x, z2)\n        z1_mu, z1_log_var = self.encode_z1(input, z2)\n        return [z1_mu, z1_log_var, z2_mu, z2_log_var, 
z2]\n\n    def decode(self, input: Tensor) -> Tensor:\n        result = self.decoder(input)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Will a single z be enough ti compute the expectation\n        for the loss??\n        :param mu: (Tensor) Mean of the latent Gaussian\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian\n        :return:\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n\n        # Encode the input into the latent codes z1 and z2\n        # z2 ~q(z2 | x)\n        # z1 ~ q(z1|x, z2)\n        z1_mu, z1_log_var, z2_mu, z2_log_var, z2 = self.encode(input)\n        z1 = self.reparameterize(z1_mu, z1_log_var)\n\n        # Reconstruct the image using both the latent codes\n        # x ~ p(x|z1, z2)\n        debedded_z1 = self.debed_z1_code(z1)\n        debedded_z2 = self.debed_z2_code(z2)\n        result = torch.cat([debedded_z1, debedded_z2], dim=1)\n        result = result.view(-1, 512, 2, 2)\n        recons = self.decode(result)\n\n        return  [recons,\n                 input,\n                 z1_mu, z1_log_var,\n                 z2_mu, z2_log_var,\n                 z1, z2]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        recons = args[0]\n        input = args[1]\n\n        z1_mu = args[2]\n        z1_log_var = args[3]\n\n        z2_mu = args[4]\n        z2_log_var = args[5]\n\n        z1= args[6]\n        z2 = args[7]\n\n        # Reconstruct (decode) z2 into z1\n        # z1 ~ p(z1|z2) [This for the loss calculation]\n        z1_p_mu = self.recons_z1_mu(z2)\n        z1_p_log_var = self.recons_z1_log_var(z2)\n\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n        
recons_loss =F.mse_loss(recons, input)\n\n        z1_kld = torch.mean(-0.5 * torch.sum(1 + z1_log_var - z1_mu ** 2 - z1_log_var.exp(), dim = 1),\n                            dim = 0)\n        z2_kld = torch.mean(-0.5 * torch.sum(1 + z2_log_var - z2_mu ** 2 - z2_log_var.exp(), dim = 1),\n                            dim = 0)\n\n        z1_p_kld = torch.mean(-0.5 * torch.sum(1 + z1_p_log_var - (z1 - z1_p_mu) ** 2 - z1_p_log_var.exp(),\n                                               dim = 1),\n                            dim = 0)\n\n        z2_p_kld = torch.mean(-0.5*(z2**2), dim = 0)\n\n        kld_loss = -(z1_p_kld - z1_kld - z2_kld)\n        loss = recons_loss + kld_weight * kld_loss\n        # print(z2_p_kld)\n\n        return {'loss': loss, 'Reconstruction Loss':recons_loss, 'KLD':-kld_loss}\n\n    def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:\n        z2 = torch.randn(batch_size,\n                         self.latent2_dim)\n\n        z2 = z2.cuda(current_device)\n\n        z1_mu = self.recons_z1_mu(z2)\n        z1_log_var = self.recons_z1_log_var(z2)\n        z1 = self.reparameterize(z1_mu, z1_log_var)\n\n        debedded_z1 = self.debed_z1_code(z1)\n        debedded_z2 = self.debed_z2_code(z2)\n\n        result = torch.cat([debedded_z1, debedded_z2], dim=1)\n        result = result.view(-1, 512, 2, 2)\n        samples = self.decode(result)\n\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]\n"
  },
  {
    "path": "models/info_vae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass InfoVAE(BaseVAE):\n    \"\"\"\n    InfoVAE: a VAE whose objective combines a beta-weighted reconstruction\n    loss, a (1 - alpha)-weighted KL term and a Maximum Mean Discrepancy (MMD)\n    term between the sampled latent codes and samples from N(0, I).\n\n    The linear layers are sized hidden_dims[-1] * 4, i.e. they assume the\n    encoder ends on a 2 x 2 feature map (64 x 64 inputs with the five\n    default stride-2 layers).\n    \"\"\"\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 alpha: float = -0.5,\n                 beta: float = 5.0,\n                 reg_weight: int = 100,\n                 kernel_type: str = 'imq',\n                 latent_var: float = 2.,\n                 **kwargs) -> None:\n        \"\"\"\n        :param in_channels: (Int) Number of channels of the input images\n        :param latent_dim: (Int) Dimensionality of the latent space\n        :param hidden_dims: (List) Channel widths of the encoder layers (not modified)\n        :param alpha: (Float) Must be <= 0; the KL term is weighted by (1 - alpha)\n        :param beta: (Float) Weight of the reconstruction loss\n        :param reg_weight: (Int) MMD weight; the MMD term is scaled by (alpha + reg_weight - 1)\n        :param kernel_type: (Str) MMD kernel, either 'imq' or 'rbf'\n        :param latent_var: (Float) Variance scale used in the kernel bandwidths\n        \"\"\"\n        super(InfoVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.reg_weight = reg_weight\n        self.kernel_type = kernel_type\n        self.z_var = latent_var\n\n        assert alpha <= 0, 'alpha must be negative or zero.'\n\n        self.alpha = alpha\n        self.beta = beta\n\n        # Work on a copy: the list is reversed in place further down\n        hidden_dims = [32, 64, 128, 256, 512] if hidden_dims is None else list(hidden_dims)\n        # Channel count of the 2 x 2 bottleneck feature map, used by decode()\n        self.last_hidden_dim = hidden_dims[-1]\n\n        # Build Encoder\n        modules = []\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size=3, stride=2, padding=1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)\n\n        # Build Decoder (mirror image of the encoder)\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride=2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n            nn.ConvTranspose2d(hidden_dims[-1],\n                               hidden_dims[-1],\n                               kernel_size=3,\n                               stride=2,\n                               padding=1,\n                               output_padding=1),\n            nn.BatchNorm2d(hidden_dims[-1]),\n            nn.LeakyReLU(),\n            nn.Conv2d(hidden_dims[-1], out_channels=3,\n                      kernel_size=3, padding=1),\n            nn.Tanh())\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (List[Tensor]) [mu, log_var], each [N x D]\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes onto the image space.\n        :param z: (Tensor) [B x D]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n        result = self.decoder_input(z)\n        result = result.view(-1, self.last_hidden_dim, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Reparameterization trick to sample from N(mu, var) from\n        N(0,1).\n        :param mu: (Tensor) Mean of the latent Gaussian [B x D]\n        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]\n        :return: (Tensor) [B x D]\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        \"\"\"\n        :param input: (Tensor) [B x C x H x W]\n        :return: (List[Tensor]) [reconstruction, input, z, mu, log_var]\n        \"\"\"\n        mu, log_var = self.encode(input)\n        z = self.reparameterize(mu, log_var)\n        return [self.decode(z), input, z, mu, log_var]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        Computes the InfoVAE loss\n            beta * recons + (1 - alpha) * M_N * KL + (alpha + reg_weight - 1) * MMD\n        where the MMD kernel sums are divided by the N * (N - 1) off-diagonal pairs.\n        :param args: recons, input, z, mu, log_var (the output of forward)\n        :param kwargs: M_N, the KL weight for the minibatch\n        :return: (dict) loss, Reconstruction_Loss, MMD and KLD (negated KL)\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        z = args[2]\n        mu = args[3]\n        log_var = args[4]\n\n        batch_size = input.size(0)\n        # Number of off-diagonal kernel entries the MMD sums over; clamped to\n        # 1 so a single-sample batch does not divide by zero\n        bias_corr = max(batch_size * (batch_size - 1), 1)\n        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset\n\n        recons_loss = F.mse_loss(recons, input)\n        mmd_loss = self.compute_mmd(z)\n        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)\n\n        loss = (self.beta * recons_loss +\n                (1. - self.alpha) * kld_weight * kld_loss +\n                (self.alpha + self.reg_weight - 1.) / bias_corr * mmd_loss)\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'MMD': mmd_loss, 'KLD':-kld_loss}\n\n    def compute_kernel(self,\n                       x1: Tensor,\n                       x2: Tensor) -> Tensor:\n        \"\"\"\n        Evaluates the configured kernel between every row of x1 and every row of x2.\n        :param x1: (Tensor) [N1 x D]\n        :param x2: (Tensor) [N2 x D]\n        :return: (Tensor) Sum of the off-diagonal kernel entries (scalar)\n        \"\"\"\n        N1, D = x1.size()\n        N2 = x2.size(0)\n\n        # Broadcast both inputs to [N1 x N2 x D] so that entry (i, j) pairs\n        # x1[i] with x2[j]; N1 and N2 may differ\n        x1 = x1.unsqueeze(-2).expand(N1, N2, D)\n        x2 = x2.unsqueeze(-3).expand(N1, N2, D)\n\n        if self.kernel_type == 'rbf':\n            result = self.compute_rbf(x1, x2)\n        elif self.kernel_type == 'imq':\n            result = self.compute_inv_mult_quad(x1, x2)\n        else:\n            raise ValueError('Undefined kernel type.')\n\n        return result\n\n    def compute_rbf(self,\n                    x1: Tensor,\n                    x2: Tensor,\n                    eps: float = 1e-7) -> Tensor:\n        \"\"\"\n        Computes the RBF kernel exp(-mean((x1 - x2)^2, dim=-1) / sigma), with\n        sigma = 2 * z_dim * z_var, and sums its off-diagonal entries.\n        :param x1: (Tensor) [N1 x N2 x D]\n        :param x2: (Tensor) [N1 x N2 x D]\n        :param eps: (Float) Unused; kept for signature compatibility\n        :return: (Tensor) Sum of the off-diagonal kernel entries (scalar)\n        \"\"\"\n        z_dim = x2.size(-1)\n        sigma = 2. * z_dim * self.z_var\n\n        kernel = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))\n\n        # Exclude diagonal elements, exactly like compute_inv_mult_quad, so both\n        # kernels return an off-diagonal sum that loss_function normalises by\n        # N * (N - 1). Returning the full matrix would be averaged by\n        # compute_mmd and then divided by N * (N - 1) a second time.\n        result = kernel.sum() - kernel.diag().sum()\n        return result\n\n    def compute_inv_mult_quad(self,\n                               x1: Tensor,\n                               x2: Tensor,\n                               eps: float = 1e-7) -> Tensor:\n        \"\"\"\n        Computes the Inverse Multi-Quadratics kernel\n            k(x_1, x_2) = C / (C + ||x_1 - x_2||^2),  C = 2 * z_dim * z_var\n        and sums its off-diagonal entries.\n        :param x1: (Tensor) [N1 x N2 x D]\n        :param x2: (Tensor) [N1 x N2 x D]\n        :param eps: (Float) Added to the denominator for numerical stability\n        :return: (Tensor) Sum of the off-diagonal kernel entries (scalar)\n        \"\"\"\n        z_dim = x2.size(-1)\n        C = 2 * z_dim * self.z_var\n        kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim = -1))\n\n        # Exclude diagonal elements\n        result = kernel.sum() - kernel.diag().sum()\n\n        return result\n\n    def compute_mmd(self, z: Tensor) -> Tensor:\n        \"\"\"\n        MMD between the latent codes z and an equally sized sample from N(0, I).\n        :param z: (Tensor) [B x D]\n        :return: (Tensor) MMD built from off-diagonal kernel sums (not yet normalised)\n        \"\"\"\n        # Sample from prior (Gaussian) distribution\n        prior_z = torch.randn_like(z)\n\n        prior_z__kernel = self.compute_kernel(prior_z, prior_z)\n        z__kernel = self.compute_kernel(z, z)\n        priorz_z__kernel = self.compute_kernel(prior_z, z)\n\n        mmd = (prior_z__kernel.mean() +\n               z__kernel.mean() -\n               2 * priorz_z__kernel.mean())\n        return mmd\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/iwae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass IWAE(BaseVAE):\n    \"\"\"\n    Importance Weighted Autoencoder: each input is encoded once, num_samples\n    latent codes are drawn from the resulting Gaussian, and the per-sample\n    losses are combined with normalised importance weights.\n\n    The linear layers are sized hidden_dims[-1] * 4, i.e. they assume the\n    encoder ends on a 2 x 2 feature map (64 x 64 inputs with the five\n    default stride-2 layers).\n    \"\"\"\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 num_samples: int = 5,\n                 **kwargs) -> None:\n        \"\"\"\n        :param in_channels: (Int) Number of channels of the input images\n        :param latent_dim: (Int) Dimensionality of the latent space\n        :param hidden_dims: (List) Channel widths of the encoder layers (not modified)\n        :param num_samples: (Int) Number S of latent samples drawn per input\n        \"\"\"\n        super(IWAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.num_samples = num_samples\n\n        # Work on a copy: the list is reversed in place further down\n        hidden_dims = [32, 64, 128, 256, 512] if hidden_dims is None else list(hidden_dims)\n        # Channel count of the 2 x 2 bottleneck feature map, used by decode()\n        self.last_hidden_dim = hidden_dims[-1]\n\n        # Build Encoder\n        modules = []\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size=3, stride=2, padding=1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)\n\n        # Build Decoder (mirror image of the encoder)\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride=2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n            nn.ConvTranspose2d(hidden_dims[-1],\n                               hidden_dims[-1],\n                               kernel_size=3,\n                               stride=2,\n                               padding=1,\n                               output_padding=1),\n            nn.BatchNorm2d(hidden_dims[-1]),\n            nn.LeakyReLU(),\n            nn.Conv2d(hidden_dims[-1], out_channels=3,\n                      kernel_size=3, padding=1),\n            nn.Tanh())\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (List[Tensor]) [mu, log_var], each [N x D]\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes of S samples\n        onto the image space.\n        :param z: (Tensor) [B x S x D]\n        :return: (Tensor) [B x S x C x H x W]\n        \"\"\"\n        B, _, _ = z.size()\n        z = z.view(-1, self.latent_dim) #[BS x D]\n        result = self.decoder_input(z)\n        result = result.view(-1, self.last_hidden_dim, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result) #[BS x C x H x W ]\n        result = result.view([B, -1, result.size(1), result.size(2), result.size(3)]) #[B x S x C x H x W]\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Reparameterization trick to sample from N(mu, var) from N(0, 1).\n        :param mu: (Tensor) Mean of the latent Gaussian\n        :param logvar: (Tensor) Log variance of the latent Gaussian\n        :return: (Tensor) Sample with the same shape as mu\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        \"\"\"\n        Encodes each input once and draws num_samples latent codes from it.\n        :param input: (Tensor) [B x C x H x W]\n        :return: (List[Tensor]) [recons [B x S x C x H x W], input,\n                 mu, log_var, z, eps], the last four of shape [B x S x D]\n        \"\"\"\n        mu, log_var = self.encode(input)\n        mu = mu.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D]\n        log_var = log_var.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D]\n        z = self.reparameterize(mu, log_var) # [B x S x D]\n        # Recover the standard-normal noise behind each sample: z = mu + eps * std\n        eps = (z - mu) / torch.exp(0.5 * log_var)\n        return [self.decode(z), input, mu, log_var, z, eps]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        Importance-weighted loss over the S samples drawn per input.\n\n        Each sample's loss is its reconstruction MSE plus M_N times\n        KL(N(mu, sigma), N(0, 1)) = log(1 / sigma) + (sigma^2 + mu^2) / 2 - 1 / 2.\n        That loss is treated as -log w (up to a constant), so the normalised\n        importance weights are softmax(-loss) over the sample dimension.\n        :param args: recons, input, mu, log_var, z, eps (the output of forward)\n        :param kwargs: M_N, the KL weight for the minibatch\n        :return: (dict) loss, Reconstruction_Loss and KLD (negated KL)\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        mu = args[2]\n        log_var = args[3]\n\n        input = input.repeat(self.num_samples, 1, 1, 1, 1).permute(1, 0, 2, 3, 4) #[B x S x C x H x W]\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n\n        recons_loss = ((recons - input) ** 2).flatten(2).mean(-1) # Reconstruction Loss [B x S]\n        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=2) # [B x S]\n        sample_loss = recons_loss + kld_weight * kld_loss # [B x S]\n\n        # Lower-loss samples must receive more weight, hence the negation.\n        # The weights are detached so the gradient is the IWAE estimator\n        # sum_s w_s * grad(sample_loss_s) instead of also differentiating\n        # through the softmax.\n        weight = F.softmax(-sample_loss, dim=-1).detach()\n\n        loss = torch.mean(torch.sum(weight * sample_loss, dim=-1), dim=0)\n\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss.mean(), 'KLD':-kld_loss.mean()}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor) [num_samples x C x H x W]\n        \"\"\"\n        z = torch.randn(num_samples, 1,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        # decode() returns [num_samples x 1 x C x H x W]; squeeze only the\n        # sample dim so that num_samples == 1 still yields a 4-D batch\n        samples = self.decode(z).squeeze(1)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image.\n        Returns only the first reconstructed sample\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0][:, 0, :]\n"
  },
  {
    "path": "models/joint_vae.py",
    "content": "import torch\nimport numpy as np\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass JointVAE(BaseVAE):\n    \"\"\"\n    JointVAE: a continuous Gaussian latent of size latent_dim concatenated with\n    a Gumbel-softmax relaxed categorical latent of size categorical_dim. Both\n    KL terms are pulled towards a capacity that grows linearly with num_iter.\n\n    The linear layers are sized hidden_dims[-1] * 4, i.e. they assume the\n    encoder ends on a 2 x 2 feature map (64 x 64 inputs with the five\n    default stride-2 layers).\n    \"\"\"\n    num_iter = 1  # Training-step counter driving the capacity schedules\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 categorical_dim: int,\n                 latent_min_capacity: float =0.,\n                 latent_max_capacity: float = 25.,\n                 latent_gamma: float = 30.,\n                 latent_num_iter: int = 25000,\n                 categorical_min_capacity: float =0.,\n                 categorical_max_capacity: float = 25.,\n                 categorical_gamma: float = 30.,\n                 categorical_num_iter: int = 25000,\n                 hidden_dims: List = None,\n                 temperature: float = 0.5,\n                 anneal_rate: float = 3e-5,\n                 anneal_interval: int = 100, # every 100 batches\n                 alpha: float = 30.,\n                 min_temperature: float = None,\n                 **kwargs) -> None:\n        \"\"\"\n        :param in_channels: (Int) Number of channels of the input images\n        :param latent_dim: (Int) Size D of the continuous latent\n        :param categorical_dim: (Int) Number of categories Q of the discrete latent\n        :param latent_*: (min_capacity, max_capacity, gamma, num_iter) of the\n            linear capacity schedule for the continuous KL term\n        :param categorical_*: the same schedule for the categorical KL term\n        :param hidden_dims: (List) Channel widths of the encoder layers (not modified)\n        :param temperature: (Float) Initial Gumbel-softmax temperature\n        :param anneal_rate: (Float) Rate of the exponential temperature decay\n        :param anneal_interval: (Int) Anneal the temperature every this many batches\n        :param alpha: (Float) Weight of the reconstruction loss\n        :param min_temperature: (Float) Floor of the annealed temperature. Defaults\n            to temperature, in which case annealing leaves the temperature unchanged\n        \"\"\"\n        super(JointVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.categorical_dim = categorical_dim\n        self.temp = temperature\n        # Annealing can only lower the temperature down to this floor\n        self.min_temp = temperature if min_temperature is None else min_temperature\n        self.anneal_rate = anneal_rate\n        self.anneal_interval = anneal_interval\n        self.alpha = alpha\n\n        self.cont_min = latent_min_capacity\n        self.cont_max = latent_max_capacity\n\n        self.disc_min = categorical_min_capacity\n        self.disc_max = categorical_max_capacity\n\n        self.cont_gamma = latent_gamma\n        self.disc_gamma = categorical_gamma\n\n        self.cont_iter = latent_num_iter\n        self.disc_iter = categorical_num_iter\n\n        # Work on a copy: the list is reversed in place further down\n        hidden_dims = [32, 64, 128, 256, 512] if hidden_dims is None else list(hidden_dims)\n        # Channel count of the 2 x 2 bottleneck feature map, used by decode()\n        self.last_hidden_dim = hidden_dims[-1]\n\n        # Build Encoder\n        modules = []\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size=3, stride=2, padding=1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, self.latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1] * 4, self.latent_dim)\n        self.fc_z = nn.Linear(hidden_dims[-1] * 4, self.categorical_dim)\n\n        # Build Decoder (mirror image of the encoder)\n        modules = []\n\n        self.decoder_input = nn.Linear(self.latent_dim + self.categorical_dim,\n                                       hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride=2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n            nn.ConvTranspose2d(hidden_dims[-1],\n                               hidden_dims[-1],\n                               kernel_size=3,\n                               stride=2,\n                               padding=1,\n                               output_padding=1),\n            nn.BatchNorm2d(hidden_dims[-1]),\n            nn.LeakyReLU(),\n            nn.Conv2d(hidden_dims[-1], out_channels=3,\n                      kernel_size=3, padding=1),\n            nn.Tanh())\n        # Not used by the methods of this class; kept as a public attribute\n        self.sampling_dist = torch.distributions.OneHotCategorical(1. / categorical_dim * torch.ones((self.categorical_dim, 1)))\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [B x C x H x W]\n        :return: (List[Tensor]) [mu [B x D], log_var [B x D], q [B x Q]],\n                 where q holds the unnormalised categorical logits\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n        z = self.fc_z(result)\n        z = z.view(-1, self.categorical_dim)\n        return [mu, log_var, z]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes\n        onto the image space.\n        :param z: (Tensor) Concatenated continuous and categorical codes [B x (D + Q)]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n        result = self.decoder_input(z)\n        result = result.view(-1, self.last_hidden_dim, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self,\n                       mu: Tensor,\n                       log_var: Tensor,\n                       q: Tensor,\n                       eps:float = 1e-7) -> Tensor:\n        \"\"\"\n        Samples the continuous latent with the reparameterization trick and\n        the categorical latent with the Gumbel-softmax trick.\n        :param mu: (Tensor) mean of the latent Gaussian  [B x D]\n        :param log_var: (Tensor) Log variance of the latent Gaussian [B x D]\n        :param q: (Tensor) Categorical latent Codes [B x Q]\n        :param eps: (Float) Keeps the logs of the Gumbel noise finite\n        :return: (Tensor) [B x (D + Q)]\n        \"\"\"\n\n        std = torch.exp(0.5 * log_var)\n        e = torch.randn_like(std)\n        z = e * std + mu\n\n        # Sample from Gumbel\n        u = torch.rand_like(q)\n        g = - torch.log(- torch.log(u + eps) + eps)\n\n        # Gumbel-Softmax sample\n        s = F.softmax((q + g) / self.temp, dim=-1)\n        s = s.view(-1, self.categorical_dim)\n\n        return torch.cat([z, s], dim=1)\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        \"\"\"\n        :param input: (Tensor) [B x C x H x W]\n        :return: (List[Tensor]) [reconstruction, input, q, mu, log_var]\n        \"\"\"\n        mu, log_var, q = self.encode(input)\n        z = self.reparameterize(mu, log_var, q)\n        return [self.decode(z), input, q, mu, log_var]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        Computes the JointVAE loss\n            alpha * recons + M_N * (disc_gamma * |C_disc - KL_disc| + cont_gamma * |C_cont - KL_cont|)\n        where the capacities C grow linearly with num_iter up to their maxima.\n        Continuous KL(N(mu, sigma), N(0, 1)) = log(1 / sigma) + (sigma^2 + mu^2) / 2 - 1 / 2.\n        :param args: recons, input, q, mu, log_var (the output of forward)\n        :param kwargs: M_N (KL weight for the minibatch) and batch_idx\n        :return: (dict) loss, Reconstruction_Loss and Capacity_Loss\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        q = args[2]\n        mu = args[3]\n        log_var = args[4]\n\n        q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n        batch_idx = kwargs['batch_idx']\n\n        # Anneal the temperature at regular intervals, never below min_temp\n        if batch_idx % self.anneal_interval == 0 and self.training:\n            self.temp = np.maximum(self.temp * np.exp(- self.anneal_rate * batch_idx),\n                                   self.min_temp)\n\n        recons_loss = F.mse_loss(recons, input, reduction='mean')\n\n        # Adaptively increase the discrete capacity, capped both by disc_max\n        # (as the continuous capacity is capped by cont_max) and by log(Q),\n        # the largest possible KL of a Q-way categorical from the uniform prior\n        disc_curr = ((self.disc_max - self.disc_min) *\n                     self.num_iter / float(self.disc_iter) + self.disc_min)\n        disc_curr = min(disc_curr, self.disc_max, np.log(self.categorical_dim))\n\n        # KL divergence of the categorical posterior from the uniform prior:\n        # sum_k q_k * (log q_k - log(1 / Q)); eps keeps both logs finite\n        eps = 1e-7\n        h1 = q_p * torch.log(q_p + eps)\n        h2 = q_p * np.log(1. / self.categorical_dim + eps)\n        kld_disc_loss = torch.mean(torch.sum(h1 - h2, dim=1), dim=0)\n\n        # Compute Continuous loss\n        # Adaptively increase the continuous capacity\n        cont_curr = ((self.cont_max - self.cont_min) *\n                     self.num_iter / float(self.cont_iter) + self.cont_min)\n        cont_curr = min(cont_curr, self.cont_max)\n\n        kld_cont_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(),\n                                                    dim=1),\n                                   dim=0)\n        capacity_loss = (self.disc_gamma * torch.abs(disc_curr - kld_disc_loss) +\n                         self.cont_gamma * torch.abs(cont_curr - kld_cont_loss))\n        loss = self.alpha * recons_loss + kld_weight * capacity_loss\n\n        if self.training:\n            self.num_iter += 1\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'Capacity_Loss':capacity_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map. The continuous part is drawn from N(0, I) and the\n        categorical part is a uniformly random one-hot vector.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        # [S x D]\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        M = num_samples\n        np_y = np.zeros((M, self.categorical_dim), dtype=np.float32)\n        np_y[range(M), np.random.choice(self.categorical_dim, M)] = 1\n        np_y = np.reshape(np_y, [M , self.categorical_dim])\n        q = torch.from_numpy(np_y)\n\n        z = torch.cat([z, q], dim = 1).to(current_device)\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/logcosh_vae.py",
    "content": "import math\nimport torch\nimport torch.nn.functional as F\nfrom models import BaseVAE\nfrom torch import nn\nfrom .types_ import *\n\n\nclass LogCoshVAE(BaseVAE):\n    \"\"\"\n    VAE trained with a log-cosh reconstruction loss and a beta-weighted KL term.\n\n    The linear layers are sized hidden_dims[-1] * 4, i.e. they assume the\n    encoder ends on a 2 x 2 feature map (64 x 64 inputs with the five\n    default stride-2 layers).\n    \"\"\"\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 alpha: float = 100.,\n                 beta: float = 10.,\n                 **kwargs) -> None:\n        \"\"\"\n        :param in_channels: (Int) Number of channels of the input images\n        :param latent_dim: (Int) Dimensionality of the latent space\n        :param hidden_dims: (List) Channel widths of the encoder layers (not modified)\n        :param alpha: (Float) Sharpness of the log-cosh reconstruction loss\n        :param beta: (Float) Weight of the KL term\n        \"\"\"\n        super(LogCoshVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.alpha = alpha\n        self.beta = beta\n\n        # Work on a copy: the list is reversed in place further down\n        hidden_dims = [32, 64, 128, 256, 512] if hidden_dims is None else list(hidden_dims)\n        # Channel count of the 2 x 2 bottleneck feature map, used by decode()\n        self.last_hidden_dim = hidden_dims[-1]\n\n        # Build Encoder\n        modules = []\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size=3, stride=2, padding=1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)\n\n        # Build Decoder (mirror image of the encoder)\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride=2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n            nn.ConvTranspose2d(hidden_dims[-1],\n                               hidden_dims[-1],\n                               kernel_size=3,\n                               stride=2,\n                               padding=1,\n                               output_padding=1),\n            nn.BatchNorm2d(hidden_dims[-1]),\n            nn.LeakyReLU(),\n            nn.Conv2d(hidden_dims[-1], out_channels=3,\n                      kernel_size=3, padding=1),\n            nn.Tanh())\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (List[Tensor]) [mu, log_var], each [N x D]\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes\n        onto the image space.\n        :param z: (Tensor) [B x D]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n        result = self.decoder_input(z)\n        result = result.view(-1, self.last_hidden_dim, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Reparameterization trick to sample from N(mu, var) from\n        N(0,1).\n        :param mu: (Tensor) Mean of the latent Gaussian [B x D]\n        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]\n        :return: (Tensor) [B x D]\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        \"\"\"\n        :param input: (Tensor) [B x C x H x W]\n        :return: (List[Tensor]) [reconstruction, input, mu, log_var]\n        \"\"\"\n        mu, log_var = self.encode(input)\n        z = self.reparameterize(mu, log_var)\n        return [self.decode(z), input, mu, log_var]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        Computes the log-cosh VAE loss\n            mean(log(cosh(alpha * t))) / alpha + beta * M_N * KL,  t = recons - input\n        with KL(N(mu, sigma), N(0, 1)) = log(1 / sigma) + (sigma^2 + mu^2) / 2 - 1 / 2.\n        :param args: recons, input, mu, log_var (the output of forward)\n        :param kwargs: M_N, the KL weight for the minibatch\n        :return: (dict) loss, Reconstruction_Loss and KLD (negated KL)\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        mu = args[2]\n        log_var = args[3]\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n        t = recons - input\n\n        # log(cosh(x)) = x + log(1 + exp(-2x)) - log(2). softplus evaluates\n        # log(1 + exp(y)) without overflow, whereas exp(-2x) itself overflows\n        # to inf in float32 once -2x exceeds ~88 (t < -0.44 with alpha = 100)\n        recons_loss = (self.alpha * t +\n                       F.softplus(-2. * self.alpha * t) -\n                       math.log(2.))\n        recons_loss = (1. / self.alpha) * recons_loss.mean()\n\n        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n\n        loss = recons_loss + self.beta * kld_weight * kld_loss\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/lvae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\nfrom math import floor, pi, log\n\n\ndef conv_out_shape(img_size):\n    return floor((img_size + 2 - 3) / 2.) + 1\n\nclass EncoderBlock(nn.Module):\n    def __init__(self,\n                 in_channels: int,\n                 out_channels: int,\n                 latent_dim: int,\n                 img_size: int):\n        super(EncoderBlock, self).__init__()\n\n        # Build Encoder\n        self.encoder = nn.Sequential(\n                            nn.Conv2d(in_channels,\n                                      out_channels,\n                                      kernel_size=3, stride=2, padding=1),\n                            nn.BatchNorm2d(out_channels),\n                            nn.LeakyReLU())\n\n        out_size = conv_out_shape(img_size)\n        self.encoder_mu = nn.Linear(out_channels * out_size ** 2 , latent_dim)\n        self.encoder_var = nn.Linear(out_channels * out_size ** 2, latent_dim)\n\n    def forward(self, input: Tensor) -> Tensor:\n        result = self.encoder(input)\n        h = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.encoder_mu(h)\n        log_var = self.encoder_var(h)\n\n        return [result, mu, log_var]\n\nclass LadderBlock(nn.Module):\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int):\n        super(LadderBlock, self).__init__()\n\n        # Build Decoder\n        self.decode = nn.Sequential(nn.Linear(in_channels, latent_dim),\n                                    nn.BatchNorm1d(latent_dim))\n        self.fc_mu = nn.Linear(latent_dim, latent_dim)\n        self.fc_var = nn.Linear(latent_dim, latent_dim)\n\n    def forward(self, z: Tensor) -> Tensor:\n        z = self.decode(z)\n        mu = self.fc_mu(z)\n        log_var = 
self.fc_var(z)\n\n        return [mu, log_var]\n\nclass LVAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dims: List,\n                 hidden_dims: List,\n                 **kwargs) -> None:\n        super(LVAE, self).__init__()\n\n        self.latent_dims = latent_dims\n        self.hidden_dims = hidden_dims\n        self.num_rungs = len(latent_dims)\n\n        assert len(latent_dims) == len(hidden_dims), \"Length of the latent\" \\\n                                                     \"and hidden dims must be the same\"\n\n        # Build Encoder\n        modules = []\n        img_size = 64\n        for i, h_dim in enumerate(hidden_dims):\n            modules.append(EncoderBlock(in_channels,\n                                        h_dim,\n                                        latent_dims[i],\n                                        img_size))\n\n            img_size = conv_out_shape(img_size)\n            in_channels = h_dim\n\n        self.encoders = nn.Sequential(*modules)\n        # ====================================================================== #\n        # Build Decoder\n        modules = []\n\n        for i in range(self.num_rungs -1, 0, -1):\n            modules.append(LadderBlock(latent_dims[i],\n                                       latent_dims[i-1]))\n\n        self.ladders = nn.Sequential(*modules)\n\n        self.decoder_input = nn.Linear(latent_dims[0], hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n        modules = []\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    
nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n        hidden_dims.reverse()\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        h = input\n\n        # Posterior Parameters\n        post_params = []\n        for encoder_block in self.encoders:\n            h, mu, log_var = encoder_block(h)\n            post_params.append((mu, log_var))\n\n        return post_params\n\n    def decode(self, z: Tensor, post_params: List) -> Tuple:\n        \"\"\"\n        Maps the given latent codes\n        onto the image space.\n        :param z: (Tensor) [B x D]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n        kl_div = 0\n        post_params.reverse()\n        for i, ladder_block in enumerate(self.ladders):\n            mu_e, log_var_e = post_params[i]\n            mu_t, log_var_t = ladder_block(z)\n            mu, log_var = self.merge_gauss(mu_e, mu_t,\n                                           log_var_e, 
log_var_t)\n            z = self.reparameterize(mu, log_var)\n            kl_div += self.compute_kl_divergence(z, (mu, log_var), (mu_e, log_var_e))\n\n        result = self.decoder_input(z)\n        result = result.view(-1, self.hidden_dims[-1], 2, 2)\n        result = self.decoder(result)\n        return self.final_layer(result), kl_div\n\n    def merge_gauss(self,\n                    mu_1: Tensor,\n                    mu_2: Tensor,\n                    log_var_1: Tensor,\n                    log_var_2: Tensor) -> List:\n\n        p_1 = 1. / (log_var_1.exp() + 1e-7)\n        p_2 = 1. / (log_var_2.exp() + 1e-7)\n\n        mu = (mu_1 * p_1 + mu_2 * p_2)/(p_1 + p_2)\n        log_var = torch.log(1./(p_1 + p_2))\n        return [mu, log_var]\n\n    def compute_kl_divergence(self, z: Tensor, q_params: Tuple, p_params: Tuple):\n        mu_q, log_var_q = q_params\n        mu_p, log_var_p = p_params\n        #\n        # qz = -0.5 * torch.sum(1 + log_var_q + (z - mu_q) ** 2 / (2 * log_var_q.exp() + 1e-8), dim=1)\n        # pz = -0.5 * torch.sum(1 + log_var_p + (z - mu_p) ** 2 / (2 * log_var_p.exp() + 1e-8), dim=1)\n\n        kl = (log_var_p - log_var_q) + (log_var_q.exp() + (mu_q - mu_p)**2)/(2 * log_var_p.exp()) - 0.5\n        kl = torch.sum(kl, dim = -1)\n        return kl\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Reparameterization trick to sample from N(mu, var) from\n        N(0,1).\n        :param mu: (Tensor) Mean of the latent Gaussian [B x D]\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]\n        :return: (Tensor) [B x D]\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        post_params = self.encode(input)\n        mu, log_var = post_params.pop()\n        z = self.reparameterize(mu, log_var)\n        recons, kl_div = 
self.decode(z, post_params)\n\n        #kl_div += -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1)\n        return [recons, input, kl_div]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        Computes the VAE loss function.\n        KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n        :param args:\n        :param kwargs:\n        :return:\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        kl_div = args[2]\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n        recons_loss =F.mse_loss(recons, input)\n\n        kld_loss = torch.mean(kl_div, dim = 0)\n        loss = recons_loss + kld_weight * kld_loss\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss }\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dims[-1])\n\n        z = z.to(current_device)\n\n        for ladder_block in self.ladders:\n            mu, log_var = ladder_block(z)\n            z = self.reparameterize(mu, log_var)\n\n        result = self.decoder_input(z)\n        result = result.view(-1, self.hidden_dims[-1], 2, 2)\n        result = self.decoder(result)\n        samples = self.final_layer(result)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        
\"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/miwae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\nfrom torch.distributions import Normal\n\n\nclass MIWAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 num_samples: int = 5,\n                 num_estimates: int = 5,\n                 **kwargs) -> None:\n        super(MIWAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.num_samples = num_samples # K\n        self.num_estimates = num_estimates # M\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    
nn.LeakyReLU())\n            )\n\n\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes of S samples\n        onto the image space.\n        :param z: (Tensor) [B x S x D]\n        :return: (Tensor) [B x S x C x H x W]\n        \"\"\"\n        B, M,S, D = z.size()\n        z = z.contiguous().view(-1, self.latent_dim) #[BMS x D]\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result) #[BMS x C x H x W ]\n        result = result.view([B, M, S,result.size(-3), 
result.size(-2), result.size(-1)]) #[B x M x S x C x H x W]\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        :param mu: (Tensor) Mean of the latent Gaussian\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian\n        :return:\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        mu, log_var = self.encode(input)\n        mu = mu.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]\n        log_var = log_var.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]\n        z = self.reparameterize(mu, log_var) # [B x M x S x D]\n        eps = (z - mu) / log_var # Prior samples\n        return  [self.decode(z), input, mu, log_var, z, eps]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n        :param args:\n        :param kwargs:\n        :return:\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        mu = args[2]\n        log_var = args[3]\n        z = args[4]\n        eps = args[5]\n\n        input = input.repeat(self.num_estimates,\n                             self.num_samples, 1, 1, 1, 1).permute(2, 0, 1, 3, 4, 5) #[B x M x S x C x H x W]\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n\n        log_p_x_z = ((recons - input) ** 2).flatten(3).mean(-1) # Reconstruction Loss # [B x M x S]\n\n        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=3) # [B x M x S]\n        # Get importance weights\n        log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data\n\n        # Rescale the 
weights (along the sample dim) to lie in [0, 1] and sum to 1\n        weight = F.softmax(log_weight, dim = -1)  # [B x M x S]\n\n        loss = torch.mean(torch.mean(torch.sum(weight * log_weight, dim=-1), dim = -2), dim = 0)\n\n        return {'loss': loss, 'Reconstruction_Loss':log_p_x_z.mean(), 'KLD':-kld_loss.mean()}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples, 1, 1,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        samples = self.decode(z).squeeze()\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image.\n        Returns only the first reconstructed sample\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0][:, 0, 0, :]\n"
  },
  {
    "path": "models/mssim_vae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\nfrom math import exp\n\n\nclass MSSIMVAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 window_size: int = 11,\n                 size_average: bool = True,\n                 **kwargs) -> None:\n        super(MSSIMVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.in_channels = in_channels\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n\n\n        self.decoder = 
nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n        self.mssim_loss = MSSIM(self.in_channels,\n                                window_size,\n                                size_average)\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes\n        onto the image space.\n        :param z: (Tensor) [B x D]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n     
   Reparameterization trick to sample from N(mu, var) from\n        N(0,1).\n        :param mu: (Tensor) Mean of the latent Gaussian [B x D]\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]\n        :return: (Tensor) [B x D]\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        mu, log_var = self.encode(input)\n        z = self.reparameterize(mu, log_var)\n        return  [self.decode(z), input, mu, log_var]\n\n    def loss_function(self,\n                      *args: Any,\n                      **kwargs) -> dict:\n        \"\"\"\n        Computes the VAE loss function.\n        KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n        :param args:\n        :param kwargs:\n        :return:\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        mu = args[2]\n        log_var = args[3]\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n        recons_loss = self.mssim_loss(recons, input)\n\n\n        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n\n        loss = recons_loss + kld_weight * kld_loss\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.cuda(current_device)\n\n        samples = self.decode(z)\n        return 
samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]\n\nclass MSSIM(nn.Module):\n\n    def __init__(self,\n                 in_channels: int = 3,\n                 window_size: int=11,\n                 size_average:bool = True) -> None:\n        \"\"\"\n        Computes the differentiable MS-SSIM loss\n        Reference:\n        [1] https://github.com/jorge-pessoa/pytorch-msssim/blob/dev/pytorch_msssim/__init__.py\n            (MIT License)\n\n        :param in_channels: (Int)\n        :param window_size: (Int)\n        :param size_average: (Bool)\n        \"\"\"\n        super(MSSIM, self).__init__()\n        self.in_channels = in_channels\n        self.window_size = window_size\n        self.size_average = size_average\n\n    def gaussian_window(self, window_size:int, sigma: float) -> Tensor:\n        kernel = torch.tensor([exp((x - window_size // 2)**2/(2 * sigma ** 2))\n                               for x in range(window_size)])\n        return kernel/kernel.sum()\n\n    def create_window(self, window_size, in_channels):\n        _1D_window = self.gaussian_window(window_size, 1.5).unsqueeze(1)\n        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n        window = _2D_window.expand(in_channels, 1, window_size, window_size).contiguous()\n        return window\n\n    def ssim(self,\n             img1: Tensor,\n             img2: Tensor,\n             window_size: int,\n             in_channel: int,\n             size_average: bool) -> Tensor:\n\n        device = img1.device\n        window = self.create_window(window_size, in_channel).to(device)\n        mu1 = F.conv2d(img1, window, padding= window_size//2, groups=in_channel)\n        mu2 = F.conv2d(img2, window, padding= window_size//2, 
groups=in_channel)\n\n        mu1_sq = mu1.pow(2)\n        mu2_sq = mu2.pow(2)\n        mu1_mu2 = mu1 * mu2\n\n        sigma1_sq = F.conv2d(img1 * img1, window, padding = window_size//2, groups=in_channel) - mu1_sq\n        sigma2_sq = F.conv2d(img2 * img2, window, padding = window_size//2, groups=in_channel) - mu2_sq\n        sigma12   = F.conv2d(img1 * img2, window, padding = window_size//2, groups=in_channel) - mu1_mu2\n\n        img_range = 1.0 #img1.max() - img1.min() # Dynamic range\n        C1 = (0.01 * img_range) ** 2\n        C2 = (0.03 * img_range) ** 2\n\n        v1 = 2.0 * sigma12 + C2\n        v2 = sigma1_sq + sigma2_sq + C2\n        cs = torch.mean(v1 / v2)  # contrast sensitivity\n\n        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)\n\n        if size_average:\n            ret = ssim_map.mean()\n        else:\n            ret = ssim_map.mean(1).mean(1).mean(1)\n        return ret, cs\n\n    def forward(self, img1: Tensor, img2: Tensor) -> Tensor:\n        device = img1.device\n        weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)\n        levels = weights.size()[0]\n        mssim = []\n        mcs = []\n\n        for _ in range(levels):\n            sim, cs = self.ssim(img1, img2,\n                                self.window_size,\n                                self.in_channels,\n                                self.size_average)\n            mssim.append(sim)\n            mcs.append(cs)\n\n            img1 = F.avg_pool2d(img1, (2, 2))\n            img2 = F.avg_pool2d(img2, (2, 2))\n\n        mssim = torch.stack(mssim)\n        mcs = torch.stack(mcs)\n\n        # # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)\n        # if normalize:\n        #     mssim = (mssim + 1) / 2\n        #     mcs = (mcs + 1) / 2\n\n        pow1 = mcs ** weights\n        pow2 = mssim ** weights\n\n        output = torch.prod(pow1[:-1] * pow2[-1])\n        
return 1 - output\n\n\n"
  },
  {
    "path": "models/swae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch import distributions as dist\nfrom .types_ import *\n\n\nclass SWAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 reg_weight: int = 100,\n                 wasserstein_deg: float= 2.,\n                 num_projections: int = 50,\n                 projection_dist: str = 'normal',\n                    **kwargs) -> None:\n        super(SWAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.reg_weight = reg_weight\n        self.p = wasserstein_deg\n        self.num_projections = num_projections\n        self.proj_dist = projection_dist\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_z = nn.Linear(hidden_dims[-1]*4, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                    
   output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n    def encode(self, input: Tensor) -> Tensor:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        z = self.fc_z(result)\n        return z\n\n    def decode(self, z: Tensor) -> Tensor:\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        z = self.encode(input)\n        return  [self.decode(z), input, z]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        recons = args[0]\n        input = args[1]\n        z = 
args[2]\n\n        batch_size = input.size(0)\n        bias_corr = batch_size *  (batch_size - 1)\n        reg_weight = self.reg_weight / bias_corr\n\n        recons_loss_l2 = F.mse_loss(recons, input)\n        recons_loss_l1 = F.l1_loss(recons, input)\n\n        swd_loss = self.compute_swd(z, self.p, reg_weight)\n\n        loss = recons_loss_l2 + recons_loss_l1 + swd_loss\n        return {'loss': loss, 'Reconstruction_Loss':(recons_loss_l2 + recons_loss_l1), 'SWD': swd_loss}\n\n    def get_random_projections(self, latent_dim: int, num_samples: int) -> Tensor:\n        \"\"\"\n        Returns random samples from latent distribution's (Gaussian)\n        unit sphere for projecting the encoded samples and the\n        distribution samples.\n\n        :param latent_dim: (Int) Dimensionality of the latent space (D)\n        :param num_samples: (Int) Number of samples required (S)\n        :return: Random projections from the latent unit sphere\n        \"\"\"\n        if self.proj_dist == 'normal':\n            rand_samples = torch.randn(num_samples, latent_dim)\n        elif self.proj_dist == 'cauchy':\n            rand_samples = dist.Cauchy(torch.tensor([0.0]),\n                                       torch.tensor([1.0])).sample((num_samples, latent_dim)).squeeze()\n        else:\n            raise ValueError('Unknown projection distribution.')\n\n        rand_proj = rand_samples / rand_samples.norm(dim=1).view(-1,1)\n        return rand_proj # [S x D]\n\n\n    def compute_swd(self,\n                    z: Tensor,\n                    p: float,\n                    reg_weight: float) -> Tensor:\n        \"\"\"\n        Computes the Sliced Wasserstein Distance (SWD) - which consists of\n        randomly projecting the encoded and prior vectors and computing\n        their Wasserstein distance along those projections.\n\n        :param z: Latent samples # [N  x D]\n        :param p: Value for the p^th Wasserstein distance\n        :param reg_weight:\n        :return:\n  
      \"\"\"\n        prior_z = torch.randn_like(z) # [N x D]\n        device = z.device\n\n        proj_matrix = self.get_random_projections(self.latent_dim,\n                                                  num_samples=self.num_projections).transpose(0,1).to(device)\n\n        latent_projections = z.matmul(proj_matrix) # [N x S]\n        prior_projections = prior_z.matmul(proj_matrix) # [N x S]\n\n        # The Wasserstein distance is computed by sorting the two projections\n        # across the batches and computing their element-wise l2 distance\n        w_dist = torch.sort(latent_projections.t(), dim=1)[0] - \\\n                 torch.sort(prior_projections.t(), dim=1)[0]\n        w_dist = w_dist.pow(p)\n        return reg_weight * w_dist.mean()\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]\n"
  },
  {
    "path": "models/twostage_vae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass TwoStageVAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 hidden_dims2: List = None,\n                 **kwargs) -> None:\n        super(TwoStageVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        if hidden_dims2 is None:\n            hidden_dims2 = [1024, 1024]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n        self.decoder = nn.Sequential(*modules)\n\n        
self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n        #---------------------- Second VAE ---------------------------#\n        encoder2 = []\n        in_channels = self.latent_dim\n        for h_dim in hidden_dims2:\n            encoder2.append(nn.Sequential(\n                                nn.Linear(in_channels, h_dim),\n                                nn.BatchNorm1d(h_dim),\n                                nn.LeakyReLU()))\n            in_channels = h_dim\n        self.encoder2 = nn.Sequential(*encoder2)\n        self.fc_mu2 = nn.Linear(hidden_dims2[-1], self.latent_dim)\n        self.fc_var2 = nn.Linear(hidden_dims2[-1], self.latent_dim)\n\n        decoder2 = []\n        hidden_dims2.reverse()\n\n        in_channels = self.latent_dim\n        for h_dim in hidden_dims2:\n            decoder2.append(nn.Sequential(\n                                nn.Linear(in_channels, h_dim),\n                                nn.BatchNorm1d(h_dim),\n                                nn.LeakyReLU()))\n            in_channels = h_dim\n        self.decoder2 = nn.Sequential(*decoder2)\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        
:return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes\n        onto the image space.\n        :param z: (Tensor) [B x D]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Reparameterization trick to sample from N(mu, var) from\n        N(0,1).\n        :param mu: (Tensor) Mean of the latent Gaussian [B x D]\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]\n        :return: (Tensor) [B x D]\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        mu, log_var = self.encode(input)\n        z = self.reparameterize(mu, log_var)\n\n        return  [self.decode(z), input, mu, log_var]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        Computes the VAE loss function.\n        KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n        :param args:\n        :param kwargs:\n        :return:\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        mu = args[2]\n        log_var = args[3]\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the 
dataset\n        recons_loss =F.mse_loss(recons, input)\n\n\n        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n\n        loss = recons_loss + kld_weight * kld_loss\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/types_.py",
    "content": "from typing import List, Callable, Union, Any, TypeVar, Tuple\n# from torch import tensor as Tensor\n\nTensor = TypeVar('torch.tensor')\n"
  },
  {
    "path": "models/vampvae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass VampVAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 num_components: int = 50,\n                 **kwargs) -> None:\n        super(VampVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.num_components = num_components\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = 
nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n        self.pseudo_input = torch.eye(self.num_components, requires_grad= False)\n        self.embed_pseudo = nn.Sequential(nn.Linear(self.num_components, 12288),\n                                          nn.Hardtanh(0.0, 1.0)) # 3x64x64 = 12288\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Will a single z be enough ti compute the expectation\n        for the loss??\n        :param mu: (Tensor) Mean of the latent 
Gaussian\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian\n        :return:\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        mu, log_var = self.encode(input)\n        z = self.reparameterize(mu, log_var)\n        return  [self.decode(z), input, mu, log_var, z]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        recons = args[0]\n        input = args[1]\n        mu = args[2]\n        log_var = args[3]\n        z = args[4]\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n        recons_loss =F.mse_loss(recons, input)\n\n        E_log_q_z = torch.mean(torch.sum(-0.5 * (log_var + (z - mu) ** 2)/ log_var.exp(),\n                                         dim = 1),\n                               dim = 0)\n\n        # Original Prior\n        # E_log_p_z = torch.mean(torch.sum(-0.5 * (z ** 2), dim = 1), dim = 0)\n\n        # Vamp Prior\n        M, C, H, W = input.size()\n        curr_device = input.device\n        self.pseudo_input = self.pseudo_input.cuda(curr_device)\n        x = self.embed_pseudo(self.pseudo_input)\n        x = x.view(-1, C, H, W)\n        prior_mu, prior_log_var = self.encode(x)\n\n        z_expand = z.unsqueeze(1)\n        prior_mu = prior_mu.unsqueeze(0)\n        prior_log_var = prior_log_var.unsqueeze(0)\n\n        E_log_p_z = torch.sum(-0.5 *\n                              (prior_log_var + (z_expand - prior_mu) ** 2)/ prior_log_var.exp(),\n                              dim = 2) - torch.log(torch.tensor(self.num_components).float())\n\n                               # dim = 0)\n        E_log_p_z = torch.logsumexp(E_log_p_z, dim = 1)\n        E_log_p_z = torch.mean(E_log_p_z, dim = 0)\n\n        # KLD = E_q log q - E_q log p\n        kld_loss = -(E_log_p_z - E_log_q_z)\n        
# print(E_log_p_z, E_log_q_z)\n\n\n        loss = recons_loss + kld_weight * kld_loss\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.cuda(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/vanilla_vae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass VanillaVAE(BaseVAE):\n\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 **kwargs) -> None:\n        super(VanillaVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)\n        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n              
                                 hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes\n        onto the image space.\n        :param z: (Tensor) [B x D]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Reparameterization trick to sample from N(mu, var) from\n        N(0,1).\n        :param mu: (Tensor) Mean of the latent Gaussian [B x D]\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]\n        :return: (Tensor) [B x D]\n        \"\"\"\n        std = 
torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        mu, log_var = self.encode(input)\n        z = self.reparameterize(mu, log_var)\n        return  [self.decode(z), input, mu, log_var]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        Computes the VAE loss function.\n        KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}\n        :param args:\n        :param kwargs:\n        :return:\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        mu = args[2]\n        log_var = args[3]\n\n        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset\n        recons_loss =F.mse_loss(recons, input)\n\n\n        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n\n        loss = recons_loss + kld_weight * kld_loss\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD':-kld_loss.detach()}\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/vq_vae.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\nclass VectorQuantizer(nn.Module):\n    \"\"\"\n    Reference:\n    [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py\n    \"\"\"\n    def __init__(self,\n                 num_embeddings: int,\n                 embedding_dim: int,\n                 beta: float = 0.25):\n        super(VectorQuantizer, self).__init__()\n        self.K = num_embeddings\n        self.D = embedding_dim\n        self.beta = beta\n\n        self.embedding = nn.Embedding(self.K, self.D)\n        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)\n\n    def forward(self, latents: Tensor) -> Tensor:\n        latents = latents.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]\n        latents_shape = latents.shape\n        flat_latents = latents.view(-1, self.D)  # [BHW x D]\n\n        # Compute L2 distance between latents and embedding weights\n        dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \\\n               torch.sum(self.embedding.weight ** 2, dim=1) - \\\n               2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHW x K]\n\n        # Get the encoding that has the min distance\n        encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]\n\n        # Convert to one-hot encodings\n        device = latents.device\n        encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)\n        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]\n\n        # Quantize the latents\n        quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW, D]\n        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]\n\n        # Compute the VQ Losses\n        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)\n        embedding_loss = F.mse_loss(quantized_latents, 
latents.detach())\n\n        vq_loss = commitment_loss * self.beta + embedding_loss\n\n        # Add the residue back to the latents\n        quantized_latents = latents + (quantized_latents - latents).detach()\n\n        return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss  # [B x D x H x W]\n\nclass ResidualLayer(nn.Module):\n\n    def __init__(self,\n                 in_channels: int,\n                 out_channels: int):\n        super(ResidualLayer, self).__init__()\n        self.resblock = nn.Sequential(nn.Conv2d(in_channels, out_channels,\n                                                kernel_size=3, padding=1, bias=False),\n                                      nn.ReLU(True),\n                                      nn.Conv2d(out_channels, out_channels,\n                                                kernel_size=1, bias=False))\n\n    def forward(self, input: Tensor) -> Tensor:\n        return input + self.resblock(input)\n\n\nclass VQVAE(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 embedding_dim: int,\n                 num_embeddings: int,\n                 hidden_dims: List = None,\n                 beta: float = 0.25,\n                 img_size: int = 64,\n                 **kwargs) -> None:\n        super(VQVAE, self).__init__()\n\n        self.embedding_dim = embedding_dim\n        self.num_embeddings = num_embeddings\n        self.img_size = img_size\n        self.beta = beta\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [128, 256]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size=4, stride=2, padding=1),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        modules.append(\n            nn.Sequential(\n                
nn.Conv2d(in_channels, in_channels,\n                          kernel_size=3, stride=1, padding=1),\n                nn.LeakyReLU())\n        )\n\n        for _ in range(6):\n            modules.append(ResidualLayer(in_channels, in_channels))\n        modules.append(nn.LeakyReLU())\n\n        modules.append(\n            nn.Sequential(\n                nn.Conv2d(in_channels, embedding_dim,\n                          kernel_size=1, stride=1),\n                nn.LeakyReLU())\n        )\n\n        self.encoder = nn.Sequential(*modules)\n\n        self.vq_layer = VectorQuantizer(num_embeddings,\n                                        embedding_dim,\n                                        self.beta)\n\n        # Build Decoder\n        modules = []\n        modules.append(\n            nn.Sequential(\n                nn.Conv2d(embedding_dim,\n                          hidden_dims[-1],\n                          kernel_size=3,\n                          stride=1,\n                          padding=1),\n                nn.LeakyReLU())\n        )\n\n        for _ in range(6):\n            modules.append(ResidualLayer(hidden_dims[-1], hidden_dims[-1]))\n\n        modules.append(nn.LeakyReLU())\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=4,\n                                       stride=2,\n                                       padding=1),\n                    nn.LeakyReLU())\n            )\n\n        modules.append(\n            nn.Sequential(\n                nn.ConvTranspose2d(hidden_dims[-1],\n                                   out_channels=3,\n                                   kernel_size=4,\n                                   stride=2, padding=1),\n                nn.Tanh()))\n\n        self.decoder = 
nn.Sequential(*modules)\n\n    def encode(self, input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        return [result]\n\n    def decode(self, z: Tensor) -> Tensor:\n        \"\"\"\n        Maps the given latent codes\n        onto the image space.\n        :param z: (Tensor) [B x D x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        result = self.decoder(z)\n        return result\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        encoding = self.encode(input)[0]\n        quantized_inputs, vq_loss = self.vq_layer(encoding)\n        return [self.decode(quantized_inputs), input, vq_loss]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        \"\"\"\n        :param args:\n        :param kwargs:\n        :return:\n        \"\"\"\n        recons = args[0]\n        input = args[1]\n        vq_loss = args[2]\n\n        recons_loss = F.mse_loss(recons, input)\n\n        loss = recons_loss + vq_loss\n        return {'loss': loss,\n                'Reconstruction_Loss': recons_loss,\n                'VQ_Loss':vq_loss}\n\n    def sample(self,\n               num_samples: int,\n               current_device: Union[int, str], **kwargs) -> Tensor:\n        raise Warning('VQVAE sampler is not implemented.')\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "models/wae_mmd.py",
    "content": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n\n\nclass WAE_MMD(BaseVAE):\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 reg_weight: int = 100,\n                 kernel_type: str = 'imq',\n                 latent_var: float = 2.,\n                 **kwargs) -> None:\n        super(WAE_MMD, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.reg_weight = reg_weight\n        self.kernel_type = kernel_type\n        self.z_var = latent_var\n\n        modules = []\n        if hidden_dims is None:\n            hidden_dims = [32, 64, 128, 256, 512]\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim,\n                              kernel_size= 3, stride= 2, padding  = 1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_z = nn.Linear(hidden_dims[-1]*4, latent_dim)\n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)\n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                       hidden_dims[i + 1],\n                                       kernel_size=3,\n                                       stride = 2,\n                                       padding=1,\n                                       output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n\n\n       
 self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= 3,\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n    def encode(self, input: Tensor) -> Tensor:\n        \"\"\"\n        Encodes the input by passing through the encoder network\n        and returns the latent codes.\n        :param input: (Tensor) Input tensor to encoder [N x C x H x W]\n        :return: (Tensor) List of latent codes\n        \"\"\"\n        result = self.encoder(input)\n        result = torch.flatten(result, start_dim=1)\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        z = self.fc_z(result)\n        return z\n\n    def decode(self, z: Tensor) -> Tensor:\n        result = self.decoder_input(z)\n        result = result.view(-1, 512, 2, 2)\n        result = self.decoder(result)\n        result = self.final_layer(result)\n        return result\n\n    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:\n        z = self.encode(input)\n        return  [self.decode(z), input, z]\n\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        recons = args[0]\n        input = args[1]\n        z = args[2]\n\n        batch_size = input.size(0)\n        bias_corr = batch_size *  (batch_size - 1)\n        reg_weight = self.reg_weight / bias_corr\n\n  
      recons_loss =F.mse_loss(recons, input)\n\n        mmd_loss = self.compute_mmd(z, reg_weight)\n\n        loss = recons_loss + mmd_loss\n        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'MMD': mmd_loss}\n\n    def compute_kernel(self,\n                       x1: Tensor,\n                       x2: Tensor) -> Tensor:\n        # Convert the tensors into row and column vectors\n        D = x1.size(1)\n        N = x1.size(0)\n\n        x1 = x1.unsqueeze(-2) # Make it into a column tensor\n        x2 = x2.unsqueeze(-3) # Make it into a row tensor\n\n        \"\"\"\n        Usually the below lines are not required, especially in our case,\n        but this is useful when x1 and x2 have different sizes\n        along the 0th dimension.\n        \"\"\"\n        x1 = x1.expand(N, N, D)\n        x2 = x2.expand(N, N, D)\n\n        if self.kernel_type == 'rbf':\n            result = self.compute_rbf(x1, x2)\n        elif self.kernel_type == 'imq':\n            result = self.compute_inv_mult_quad(x1, x2)\n        else:\n            raise ValueError('Undefined kernel type.')\n\n        return result\n\n\n    def compute_rbf(self,\n                    x1: Tensor,\n                    x2: Tensor,\n                    eps: float = 1e-7) -> Tensor:\n        \"\"\"\n        Computes the RBF Kernel between x1 and x2.\n        :param x1: (Tensor)\n        :param x2: (Tensor)\n        :param eps: (Float)\n        :return:\n        \"\"\"\n        z_dim = x2.size(-1)\n        sigma = 2. 
* z_dim * self.z_var\n\n        result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))\n        return result\n\n    def compute_inv_mult_quad(self,\n                               x1: Tensor,\n                               x2: Tensor,\n                               eps: float = 1e-7) -> Tensor:\n        \"\"\"\n        Computes the Inverse Multi-Quadratics Kernel between x1 and x2,\n        given by\n\n                k(x_1, x_2) = \\sum \\frac{C}{C + \\|x_1 - x_2 \\|^2}\n        :param x1: (Tensor)\n        :param x2: (Tensor)\n        :param eps: (Float)\n        :return:\n        \"\"\"\n        z_dim = x2.size(-1)\n        C = 2 * z_dim * self.z_var\n        kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim = -1))\n\n        # Exclude diagonal elements\n        result = kernel.sum() - kernel.diag().sum()\n\n        return result\n\n    def compute_mmd(self, z: Tensor, reg_weight: float) -> Tensor:\n        # Sample from prior (Gaussian) distribution\n        prior_z = torch.randn_like(z)\n\n        prior_z__kernel = self.compute_kernel(prior_z, prior_z)\n        z__kernel = self.compute_kernel(z, z)\n        priorz_z__kernel = self.compute_kernel(prior_z, z)\n\n        mmd = reg_weight * prior_z__kernel.mean() + \\\n              reg_weight * z__kernel.mean() - \\\n              2 * reg_weight * priorz_z__kernel.mean()\n        return mmd\n\n    def sample(self,\n               num_samples:int,\n               current_device: int, **kwargs) -> Tensor:\n        \"\"\"\n        Samples from the latent space and return the corresponding\n        image space map.\n        :param num_samples: (Int) Number of samples\n        :param current_device: (Int) Device to run the model\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples,\n                        self.latent_dim)\n\n        z = z.to(current_device)\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        
\"\"\"\n        Given an input image x, returns the reconstructed image\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]"
  },
  {
    "path": "requirements.txt",
    "content": "pytorch-lightning==1.5.6\nPyYAML==6.0\ntensorboard>=2.2.0\ntorch>=1.6.1\ntorchsummary==1.5.1\ntorchvision>=0.10.1"
  },
  {
    "path": "run.py",
    "content": "import os\nimport yaml\nimport argparse\nimport numpy as np\nfrom pathlib import Path\nfrom models import *\nfrom experiment import VAEXperiment\nimport torch.backends.cudnn as cudnn\nfrom pytorch_lightning import Trainer\nfrom pytorch_lightning.loggers import TensorBoardLogger\nfrom pytorch_lightning.utilities.seed import seed_everything\nfrom pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint\nfrom dataset import VAEDataset\nfrom pytorch_lightning.plugins import DDPPlugin\n\n\nparser = argparse.ArgumentParser(description='Generic runner for VAE models')\nparser.add_argument('--config',  '-c',\n                    dest=\"filename\",\n                    metavar='FILE',\n                    help =  'path to the config file',\n                    default='configs/vae.yaml')\n\nargs = parser.parse_args()\nwith open(args.filename, 'r') as file:\n    try:\n        config = yaml.safe_load(file)\n    except yaml.YAMLError as exc:\n        print(exc)\n\n\ntb_logger =  TensorBoardLogger(save_dir=config['logging_params']['save_dir'],\n                               name=config['model_params']['name'],)\n\n# For reproducibility\nseed_everything(config['exp_params']['manual_seed'], True)\n\nmodel = vae_models[config['model_params']['name']](**config['model_params'])\nexperiment = VAEXperiment(model,\n                          config['exp_params'])\n\ndata = VAEDataset(**config[\"data_params\"], pin_memory=len(config['trainer_params']['gpus']) != 0)\n\ndata.setup()\nrunner = Trainer(logger=tb_logger,\n                 callbacks=[\n                     LearningRateMonitor(),\n                     ModelCheckpoint(save_top_k=2, \n                                     dirpath =os.path.join(tb_logger.log_dir , \"checkpoints\"), \n                                     monitor= \"val_loss\",\n                                     save_last= True),\n                 ],\n                 strategy=DDPPlugin(find_unused_parameters=False),\n               
  **config['trainer_params'])\n\n\nPath(f\"{tb_logger.log_dir}/Samples\").mkdir(exist_ok=True, parents=True)\nPath(f\"{tb_logger.log_dir}/Reconstructions\").mkdir(exist_ok=True, parents=True)\n\n\nprint(f\"======= Training {config['model_params']['name']} =======\")\nrunner.fit(experiment, datamodule=data)"
  },
  {
    "path": "tests/bvae.py",
    "content": "import torch\nimport unittest\nfrom models import BetaVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = BetaVAE(3, 10, loss_type='H').cuda()\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64).cuda()\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_betatcvae.py",
    "content": "import torch\nimport unittest\nfrom models import BetaTCVAE\nfrom torchsummary import summary\n\n\nclass TestBetaTCVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = BetaTCVAE(3, 64, anneal_steps= 100)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n    def test_sample(self):\n        self.model.cuda()\n        y = self.model.sample(8, 'cuda')\n        print(y.shape)\n\n    def test_generate(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model.generate(x)\n        print(y.shape)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_cat_vae.py",
    "content": "import torch\nimport unittest\nfrom models import GumbelVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = GumbelVAE(3, 10)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(128, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005, batch_idx=5)\n        print(loss)\n\n\n    def test_sample(self):\n        self.model.cuda()\n        y = self.model.sample(144, 0)\n        print(y.shape)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_dfc.py",
    "content": "import torch\nimport unittest\nfrom models import DFCVAE\nfrom torchsummary import summary\n\n\nclass TestDFCVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = DFCVAE(3, 10)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n    def test_sample(self):\n        self.model.cuda()\n        y = self.model.sample(144, 0)\n\n\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_dipvae.py",
    "content": "import torch\nimport unittest\nfrom models import DIPVAE\nfrom torchsummary import summary\n\n\nclass TestDIPVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = DIPVAE(3, 64)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n    def test_sample(self):\n        self.model.cuda()\n        y = self.model.sample(8, 'cuda')\n        print(y.shape)\n\n    def test_generate(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model.generate(x)\n        print(y.shape)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_fvae.py",
    "content": "import torch\nimport unittest\nfrom models import FactorVAE\nfrom torchsummary import summary\n\n\nclass TestFAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = FactorVAE(3, 10)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        #\n        # print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))\n\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n        x2 = torch.randn(16,3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005, optimizer_idx=0, secondary_input=x2)\n        loss = self.model.loss_function(*result, M_N = 0.005, optimizer_idx=1, secondary_input=x2)\n        print(loss)\n\n    def test_optim(self):\n        optim1 = torch.optim.Adam(self.model.parameters(), lr = 0.001)\n        optim2 = torch.optim.Adam(self.model.discrminator.parameters(), lr = 0.001)\n\n    def test_sample(self):\n        self.model.cuda()\n        y = self.model.sample(144, 0)\n\n\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_gvae.py",
    "content": "import torch\nimport unittest\nfrom models import GammaVAE\nfrom torchsummary import summary\n\n\nclass TestGammaVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = GammaVAE(3, 10)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n    def test_sample(self):\n        self.model.cuda()\n        y = self.model.sample(144, 0)\n\n\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_hvae.py",
    "content": "import torch\nimport unittest\nfrom models import HVAE\nfrom torchsummary import summary\n\n\nclass TestHVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = HVAE(3, latent1_dim=10, latent2_dim=20)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_iwae.py",
    "content": "import torch\nimport unittest\nfrom models import IWAE\nfrom torchsummary import summary\n\n\nclass TestIWAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = IWAE(3, 10)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n    def test_sample(self):\n        self.model.cuda()\n        y = self.model.sample(144, 0)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_joint_Vae.py",
    "content": "import torch\nimport unittest\nfrom models import JointVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = JointVAE(3, 10, 40, 0.0)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(128, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005, batch_idx=5)\n        print(loss)\n\n\n    def test_sample(self):\n        self.model.cuda()\n        y = self.model.sample(144, 0)\n        print(y.shape)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_logcosh.py",
    "content": "import torch\nimport unittest\nfrom models import LogCoshVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = LogCoshVAE(3, 10, alpha=10)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.rand(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_lvae.py",
    "content": "import torch\nimport unittest\nfrom models import LVAE\nfrom torchsummary import summary\n\n\nclass TestLVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = LVAE(3, [4,8,16,32,128], hidden_dims=[32, 64,128, 256, 512])\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n    def test_sample(self):\n        self.model.cuda()\n        y = self.model.sample(144, 0)\n        print(y.shape)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_miwae.py",
    "content": "import torch\nimport unittest\nfrom models import MIWAE\nfrom torchsummary import summary\n\n\nclass TestMIWAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = MIWAE(3, 10)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n    def test_sample(self):\n        self.model.cuda()\n        y = self.model.sample(144, 0)\n        print(y.shape)\n\n    def test_generate(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model.generate(x)\n        print(y.shape)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_mssimvae.py",
    "content": "import torch\nimport unittest\nfrom models import MSSIMVAE\nfrom torchsummary import summary\n\n\nclass TestMSSIMVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = MSSIMVAE(3, 10)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n    def test_sample(self):\n        self.model.cuda()\n        y = self.model.sample(144, 0)\n\n\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_swae.py",
    "content": "import torch\nimport unittest\nfrom models import SWAE\nfrom torchsummary import summary\n\n\nclass TestSWAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        self.model = SWAE(3, 10, reg_weight = 100)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result)\n        print(loss)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_vae.py",
    "content": "import torch\nimport unittest\nfrom models import VanillaVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = VanillaVAE(3, 10)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_vq_vae.py",
    "content": "import torch\nimport unittest\nfrom models import VQVAE\nfrom torchsummary import summary\n\n\nclass TestVQVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = VQVAE(3, 64, 512)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        print(sum(p.numel() for p in self.model.parameters() if p.requires_grad))\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n    def test_sample(self):\n        self.model.cuda()\n        y = self.model.sample(8, 'cuda')\n        print(y.shape)\n\n    def test_generate(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model.generate(x)\n        print(y.shape)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_wae.py",
    "content": "import torch\nimport unittest\nfrom models import WAE_MMD\nfrom torchsummary import summary\n\n\nclass TestWAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        self.model = WAE_MMD(3, 10, reg_weight = 100)\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result)\n        print(loss)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/text_cvae.py",
    "content": "import torch\nimport unittest\nfrom models import CVAE\n\n\nclass TestCVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = CVAE(3, 40, 10)\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        c = torch.randn(16, 40)\n        y = self.model(x, c)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(16, 3, 64, 64)\n        c = torch.randn(16, 40)\n        result = self.model(x, labels = c)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/text_vamp.py",
    "content": "import torch\nimport unittest\nfrom models import VampVAE\nfrom torchsummary import summary\n\n\nclass TestVVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n        # self.model2 = VAE(3, 10)\n        self.model = VampVAE(3, latent_dim=10).cuda()\n\n    def test_summary(self):\n        print(summary(self.model, (3, 64, 64), device='cpu'))\n        # print(summary(self.model2, (3, 64, 64), device='cpu'))\n\n    def test_forward(self):\n        x = torch.randn(16, 3, 64, 64)\n        y = self.model(x)\n        print(\"Model Output size:\", y[0].size())\n        # print(\"Model2 Output size:\", self.model2(x)[0].size())\n\n    def test_loss(self):\n        x = torch.randn(144, 3, 64, 64).cuda()\n\n        result = self.model(x)\n        loss = self.model.loss_function(*result, M_N = 0.005)\n        print(loss)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "utils.py",
    "content": "import pytorch_lightning as pl\n\n\n## Utils to handle newer PyTorch Lightning changes from version 0.6\n## ==================================================================================================== ##\n\n\ndef data_loader(fn):\n    \"\"\"\n    Decorator to handle the deprecation of data_loader from 0.7\n    :param fn: User defined data loader function\n    :return: A wrapper for the data_loader function\n    \"\"\"\n\n    def func_wrapper(self):\n        try: # Works for version 0.6.0\n            return pl.data_loader(fn)(self)\n\n        except: # Works for version > 0.6.0\n            return fn(self)\n\n    return func_wrapper\n"
  }
]