Full Code of AntixK/PyTorch-VAE for AI

Repository: AntixK/PyTorch-VAE
Branch: master
Commit: a6896b944c91
Files: 82
Total size: 236.1 KB

Directory structure:
gitextract_vk1fb46d/

├── .gitignore
├── .idea/
│   ├── .gitignore
│   ├── PyTorch-VAE.iml
│   ├── inspectionProfiles/
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE.md
├── README.md
├── configs/
│   ├── bbvae.yaml
│   ├── betatc_vae.yaml
│   ├── bhvae.yaml
│   ├── cat_vae.yaml
│   ├── cvae.yaml
│   ├── dfc_vae.yaml
│   ├── dip_vae.yaml
│   ├── factorvae.yaml
│   ├── gammavae.yaml
│   ├── hvae.yaml
│   ├── infovae.yaml
│   ├── iwae.yaml
│   ├── joint_vae.yaml
│   ├── logcosh_vae.yaml
│   ├── lvae.yaml
│   ├── miwae.yaml
│   ├── mssim_vae.yaml
│   ├── swae.yaml
│   ├── vae.yaml
│   ├── vampvae.yaml
│   ├── vq_vae.yaml
│   ├── wae_mmd_imq.yaml
│   └── wae_mmd_rbf.yaml
├── dataset.py
├── experiment.py
├── models/
│   ├── __init__.py
│   ├── base.py
│   ├── beta_vae.py
│   ├── betatc_vae.py
│   ├── cat_vae.py
│   ├── cvae.py
│   ├── dfcvae.py
│   ├── dip_vae.py
│   ├── fvae.py
│   ├── gamma_vae.py
│   ├── hvae.py
│   ├── info_vae.py
│   ├── iwae.py
│   ├── joint_vae.py
│   ├── logcosh_vae.py
│   ├── lvae.py
│   ├── miwae.py
│   ├── mssim_vae.py
│   ├── swae.py
│   ├── twostage_vae.py
│   ├── types_.py
│   ├── vampvae.py
│   ├── vanilla_vae.py
│   ├── vq_vae.py
│   └── wae_mmd.py
├── requirements.txt
├── run.py
├── tests/
│   ├── bvae.py
│   ├── test_betatcvae.py
│   ├── test_cat_vae.py
│   ├── test_dfc.py
│   ├── test_dipvae.py
│   ├── test_fvae.py
│   ├── test_gvae.py
│   ├── test_hvae.py
│   ├── test_iwae.py
│   ├── test_joint_Vae.py
│   ├── test_logcosh.py
│   ├── test_lvae.py
│   ├── test_miwae.py
│   ├── test_mssimvae.py
│   ├── test_swae.py
│   ├── test_vae.py
│   ├── test_vq_vae.py
│   ├── test_wae.py
│   ├── text_cvae.py
│   └── text_vamp.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================

Data/
logs/

VanillaVAE/version_0/

__pycache__/
.ipynb_checkpoints/

Run.ipynb


================================================
FILE: .idea/.gitignore
================================================
# Default ignored files
/workspace.xml


================================================
FILE: .idea/PyTorch-VAE.iml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="Python 3.7 (main)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
    <orderEntry type="module" module-name="somic_research" />
  </component>
  <component name="ReSTService">
    <option name="DOC_DIR" value="$MODULE_DIR$/../Project_S/somic_research/docs" />
  </component>
</module>

================================================
FILE: .idea/inspectionProfiles/profiles_settings.xml
================================================
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>

================================================
FILE: .idea/misc.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (main)" project-jdk-type="Python SDK" />
  <component name="PyCharmProfessionalAdvertiser">
    <option name="shown" value="true" />
  </component>
</project>

================================================
FILE: .idea/modules.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/PyTorch-VAE.iml" filepath="$PROJECT_DIR$/.idea/PyTorch-VAE.iml" />
      <module fileurl="file://$PROJECT_DIR$/../Project_S/somic_research/.idea/somic_research.iml" filepath="$PROJECT_DIR$/../Project_S/somic_research/.idea/somic_research.iml" />
    </modules>
  </component>
</project>

================================================
FILE: .idea/vcs.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="" vcs="Git" />
    <mapping directory="$PROJECT_DIR$/../Project_S/somic_research" vcs="Git" />
  </component>
</project>

================================================
FILE: LICENSE.md
================================================

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/
			    Copyright Anand Krishnamoorthy Subramanian 2020
			               anandkrish894@gmail.com

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<h1 align="center">
  <b>PyTorch VAE</b><br>
</h1>

<p align="center">
      <a href="https://www.python.org/">
        <img src="https://img.shields.io/badge/Python-3.5-ff69b4.svg" /></a>
       <a href= "https://pytorch.org/">
        <img src="https://img.shields.io/badge/PyTorch-1.3-2BAF2B.svg" /></a>
       <a href= "https://github.com/AntixK/PyTorch-VAE/blob/master/LICENSE.md">
        <img src="https://img.shields.io/badge/license-Apache2.0-blue.svg" /></a>
         <a href= "https://twitter.com/intent/tweet?text=PyTorch-VAE:%20Collection%20of%20VAE%20models%20in%20PyTorch.&url=https://github.com/AntixK/PyTorch-VAE">
        <img src="https://img.shields.io/twitter/url/https/shields.io.svg?style=social" /></a>

</p>

**Update 22/12/2021:** Added support for PyTorch Lightning version 1.5.6 and cleaned up the code.

A collection of Variational AutoEncoders (VAEs) implemented in PyTorch with a focus on reproducibility. The aim of this project is to provide
a quick and simple working example for many of the cool VAE models out there. All the models are trained on the [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
for consistency and comparison. The architectures of all the models are kept as similar as possible, with the same layers, except for cases where the original paper necessitates
a radically different architecture (e.g., VQ VAE uses residual layers and no batch norm, unlike the other models).
Here are the [results](https://github.com/AntixK/PyTorch-VAE/blob/master/README.md#--results) of each model.

### Requirements
- Python >= 3.5
- PyTorch >= 1.3
- PyTorch Lightning >= 0.6.0 ([GitHub Repo](https://github.com/PyTorchLightning/pytorch-lightning/tree/deb1581e26b7547baf876b7a94361e60bb200d32))
- CUDA-enabled computing device

### Installation
```
$ git clone https://github.com/AntixK/PyTorch-VAE
$ cd PyTorch-VAE
$ pip install -r requirements.txt
```

### Usage
```
$ cd PyTorch-VAE
$ python run.py -c configs/<config-file-name.yaml>
```
**Config file template**

```yaml
model_params:
  name: "<name of VAE model>"
  in_channels: 3
  latent_dim: 
    .         # Other parameters required by the model
    .
    .

data_params:
  data_path: "<path to the celebA dataset>"
  train_batch_size: 64 # Better to have a square number
  val_batch_size:  64
  patch_size: 64  # Models are designed to work for this size
  num_workers: 4
  
exp_params:
  manual_seed: 1265
  LR: 0.005
  weight_decay:
    .         # Other arguments required for training, like scheduler etc.
    .
    .

trainer_params:
  gpus: 1         
  max_epochs: 100
  gradient_clip_val: 1.5
    .
    .
    .

logging_params:
  save_dir: "logs/"
  name: "<experiment name>"
```
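
The template above groups settings into five top-level sections. As a minimal sketch (not code from this repo), the check below reports which of those sections are missing from an already-parsed config. The function name `missing_sections` and the example values are hypothetical and only illustrate the layout.

```python
from typing import Any, Dict, List

# Top-level sections taken from the config template above.
REQUIRED_SECTIONS = (
    "model_params",
    "data_params",
    "exp_params",
    "trainer_params",
    "logging_params",
)


def missing_sections(config: Dict[str, Any]) -> List[str]:
    """Return the template sections that are absent or are not mappings."""
    return [
        name
        for name in REQUIRED_SECTIONS
        if not isinstance(config.get(name), dict)
    ]


# Hypothetical, already-parsed config (for example, the output of a YAML loader).
example = {
    "model_params": {"name": "VanillaVAE", "in_channels": 3, "latent_dim": 128},
    "data_params": {"data_path": "Data/", "patch_size": 64},
    "exp_params": {"LR": 0.005, "manual_seed": 1265},
    "trainer_params": {"max_epochs": 100},
}

print(missing_sections(example))  # ['logging_params']
```

Running a check like this before training can surface a misspelled or omitted section early, rather than as a `KeyError` partway through setup.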

**View TensorBoard Logs**
```
$ cd logs/<experiment name>/version_<the version you want>
$ tensorboard --logdir .
```
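
Each run is stored under a `version_<n>` directory, as the commands above show. The sketch below is one possible way to pick the most recent run automatically, assuming that a higher number means a newer run. The helper name `latest_version_dir` is hypothetical and not part of this repo.

```python
import re
from pathlib import Path
from typing import Optional

# Matches directory names of the form version_<n>, e.g. version_0.
_VERSION_DIR = re.compile(r"version_(\d+)")


def latest_version_dir(experiment_dir: Path) -> Optional[Path]:
    """Return the version_<n> subdirectory with the largest n, or None."""
    if not experiment_dir.is_dir():
        return None
    best: Optional[Path] = None
    best_n = -1
    for child in experiment_dir.iterdir():
        match = _VERSION_DIR.fullmatch(child.name)
        if child.is_dir() and match and int(match.group(1)) > best_n:
            best, best_n = child, int(match.group(1))
    return best


# "VanillaVAE" stands in for <experiment name>; prints None if no run exists yet.
print(latest_version_dir(Path("logs") / "VanillaVAE"))
```

The returned path can then be passed to `tensorboard --logdir`. Comparing the numeric suffix rather than the raw name avoids `version_10` sorting before `version_2`.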

**Note:** The default dataset is CelebA. However, there have been many issues with downloading the dataset from Google Drive (owing to some file structure changes). So, the recommendation is to download the [file](https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing) from Google Drive directly and extract it to the path of your choice. The default path assumed in the config files is `Data/celeba/img_align_celeba`, but you can change it according to your preference.
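
After extracting the archive manually, it can help to confirm that the images ended up where the config expects them. The following is a small sanity check, sketched under the assumption that the extracted images are `.jpg` files sitting directly in the image directory. The helper name `count_images` is hypothetical.

```python
from pathlib import Path


def count_images(image_dir: Path, suffix: str = ".jpg") -> int:
    """Count regular files in image_dir with the given suffix (case-insensitive)."""
    if not image_dir.is_dir():
        return 0
    return sum(
        1
        for entry in image_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() == suffix
    )


# Default location mentioned in the note above; adjust if you extracted elsewhere.
image_dir = Path("Data/celeba/img_align_celeba")
found = count_images(image_dir)
if found == 0:
    print(f"No images found in {image_dir}; check data_path in the config.")
else:
    print(f"Found {found} images in {image_dir}.")
```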


----
<h2 align="center">
  <b>Results</b><br>
</h2>


| Model                                                                  | Paper                                            |Reconstruction | Samples |
|------------------------------------------------------------------------|--------------------------------------------------|---------------|---------|
| VAE ([Code][vae_code], [Config][vae_config])                           |[Link](https://arxiv.org/abs/1312.6114)           |    ![][2]     | ![][1]  |
| Conditional VAE ([Code][cvae_code], [Config][cvae_config])             |[Link](https://openreview.net/forum?id=rJWXGDWd-H)|    ![][16]    | ![][15] |
| WAE - MMD (RBF Kernel) ([Code][wae_code], [Config][wae_rbf_config])    |[Link](https://arxiv.org/abs/1711.01558)          |    ![][4]     | ![][3]  |
| WAE - MMD (IMQ Kernel) ([Code][wae_code], [Config][wae_imq_config])    |[Link](https://arxiv.org/abs/1711.01558)          |    ![][6]     | ![][5]  |
| Beta-VAE ([Code][bvae_code], [Config][bbvae_config])                   |[Link](https://openreview.net/forum?id=Sy2fzU9gl) |    ![][8]     | ![][7]  |
| Disentangled Beta-VAE ([Code][bvae_code], [Config][bhvae_config])      |[Link](https://arxiv.org/abs/1804.03599)          |    ![][22]    | ![][21] |
| Beta-TC-VAE ([Code][btcvae_code], [Config][btcvae_config])             |[Link](https://arxiv.org/abs/1802.04942)          |    ![][34]    | ![][33] |
| IWAE (*K = 5*) ([Code][iwae_code], [Config][iwae_config])              |[Link](https://arxiv.org/abs/1509.00519)          |    ![][10]    | ![][9]  |
| MIWAE (*K = 5, M = 3*) ([Code][miwae_code], [Config][miwae_config])    |[Link](https://arxiv.org/abs/1802.04537)          |    ![][30]    | ![][29] |
| DFCVAE   ([Code][dfcvae_code], [Config][dfcvae_config])                |[Link](https://arxiv.org/abs/1610.00291)          |    ![][12]    | ![][11] |
| MSSIM VAE    ([Code][mssimvae_code], [Config][mssimvae_config])        |[Link](https://arxiv.org/abs/1511.06409)          |    ![][14]    | ![][13] |
| Categorical VAE   ([Code][catvae_code], [Config][catvae_config])       |[Link](https://arxiv.org/abs/1611.01144)          |    ![][18]    | ![][17] |
| Joint VAE ([Code][jointvae_code], [Config][jointvae_config])           |[Link](https://arxiv.org/abs/1804.00104)          |    ![][20]    | ![][19] |
| Info VAE   ([Code][infovae_code], [Config][infovae_config])            |[Link](https://arxiv.org/abs/1706.02262)          |    ![][24]    | ![][23] |
| LogCosh VAE   ([Code][logcoshvae_code], [Config][logcoshvae_config])   |[Link](https://openreview.net/forum?id=rkglvsC9Ym)|    ![][26]    | ![][25] |
| SWAE (200 Projections) ([Code][swae_code], [Config][swae_config])      |[Link](https://arxiv.org/abs/1804.01947)          |    ![][28]    | ![][27] |
| VQ-VAE (*K = 512, D = 64*) ([Code][vqvae_code], [Config][vqvae_config])|[Link](https://arxiv.org/abs/1711.00937)          |    ![][31]    | **N/A** |
| DIP VAE ([Code][dipvae_code], [Config][dipvae_config])                 |[Link](https://arxiv.org/abs/1711.00848)          |    ![][36]    | ![][35] |


<!-- | Gamma VAE             |[Link](https://arxiv.org/abs/1610.05683)          |    ![][16]    | ![][15] |-->

<!--
### TODO
- [x] VanillaVAE
- [x] Beta VAE
- [x] DFC VAE
- [x] MSSIM VAE
- [x] IWAE
- [x] MIWAE
- [x] WAE-MMD
- [x] Conditional VAE
- [ ] PixelVAE
- [x] Categorical VAE (Gumbel-Softmax VAE)
- [x] Joint VAE
- [x] Disentangled beta-VAE
- [x] InfoVAE
- [x] LogCosh VAE
- [x] SWAE
- [x] VQVAE
- [x] Beta TC-VAE
- [x] DIP VAE
- [ ] Ladder VAE (Doesn't work well)
- [ ] Gamma VAE (Doesn't work well) 
- [ ] Vamp VAE (Doesn't work well)
-->

### Contributing
If you have trained a better model using these implementations by fine-tuning the hyperparameters in the config file,
I would be happy to include your result (along with your config file) in this repo and credit you by name 😊.

Additionally, if you would like to contribute some models, please submit a PR.

### License
**Apache License 2.0**

| Permissions      | Limitations       | Conditions                       |
|------------------|-------------------|----------------------------------|
| ✔️ Commercial use |  ❌  Trademark use |  ⓘ License and copyright notice | 
| ✔️ Modification   |  ❌  Liability     |  ⓘ State changes                |
| ✔️ Distribution   |  ❌  Warranty      |                                  |
| ✔️ Patent use     |                   |                                  |
| ✔️ Private use    |                   |                                  |


### Citation
```
@misc{Subramanian2020,
  author = {Subramanian, A.K},
  title = {PyTorch-VAE},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/AntixK/PyTorch-VAE}}
}
```
-----------

[vae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py
[cvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/cvae.py
[bvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py
[btcvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/betatc_vae.py
[wae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/wae_mmd.py
[iwae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/iwae.py
[miwae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/miwae.py
[swae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/swae.py
[jointvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/joint_vae.py
[dfcvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/dfcvae.py
[mssimvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/mssim_vae.py
[logcoshvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/logcosh_vae.py
[catvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/cat_vae.py
[infovae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/info_vae.py
[vqvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
[dipvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/dip_vae.py

[vae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vae.yaml
[cvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cvae.yaml
[bbvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/bbvae.yaml
[bhvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/bhvae.yaml
[btcvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/betatc_vae.yaml
[wae_rbf_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_rbf.yaml
[wae_imq_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_imq.yaml
[iwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/iwae.yaml
[miwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/miwae.yaml
[swae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/swae.yaml
[jointvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/joint_vae.yaml
[dfcvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dfc_vae.yaml
[mssimvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/mssim_vae.yaml
[logcoshvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/logcosh_vae.yaml
[catvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cat_vae.yaml
[infovae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/infovae.yaml
[vqvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vq_vae.yaml
[dipvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dip_vae.yaml

[1]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/Vanilla%20VAE_25.png
[2]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_Vanilla%20VAE_25.png
[3]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/WAE_RBF_18.png
[4]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_WAE_RBF_19.png
[5]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/WAE_IMQ_15.png
[6]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_WAE_IMQ_15.png
[7]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaVAE_H_20.png
[8]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaVAE_H_20.png
[9]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/IWAE_19.png
[10]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_IWAE_19.png
[11]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/DFCVAE_49.png
[12]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_DFCVAE_49.png
[13]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MSSIMVAE_29.png
[14]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MSSIMVAE_29.png
[15]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/ConditionalVAE_20.png
[16]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_ConditionalVAE_20.png
[17]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/CategoricalVAE_49.png
[18]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_CategoricalVAE_49.png
[19]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/JointVAE_49.png
[20]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_JointVAE_49.png
[21]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaVAE_B_35.png
[22]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaVAE_B_35.png
[23]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/InfoVAE_31.png
[24]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_InfoVAE_31.png
[25]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/LogCoshVAE_49.png
[26]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_LogCoshVAE_49.png
[27]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/SWAE_49.png
[28]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_SWAE_49.png
[29]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MIWAE_29.png
[30]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MIWAE_29.png
[31]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_VQVAE_29.png
[33]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaTCVAE_49.png
[34]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaTCVAE_49.png
[35]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/DIPVAE_83.png
[36]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_DIPVAE_83.png

[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
[python-url]: https://www.python.org/

[pytorch-image]: https://img.shields.io/badge/PyTorch-1.3-2BAF2B.svg
[pytorch-url]: https://pytorch.org/

[twitter-image]:https://img.shields.io/twitter/url/https/shields.io.svg?style=social
[twitter-url]:https://twitter.com/intent/tweet?text=Neural%20Blocks-Easy%20to%20use%20neural%20net%20blocks%20for%20fast%20prototyping.&url=https://github.com/AntixK/NeuralBlocks


[license-image]:https://img.shields.io/badge/license-Apache2.0-blue.svg
[license-url]:https://github.com/AntixK/PyTorch-VAE/blob/master/LICENSE.md


================================================
FILE: configs/bbvae.yaml
================================================
model_params:
  name: 'BetaVAE'
  in_channels: 3
  latent_dim: 128
  loss_type: 'B'
  gamma: 10.0
  max_capacity: 25
  Capacity_max_iter: 10000

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4
  
exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  manual_seed: 1265
  name: 'BetaVAE'


================================================
FILE: configs/betatc_vae.yaml
================================================
model_params:
  name: 'BetaTCVAE'
  in_channels: 3
  latent_dim: 10
  anneal_steps: 10000
  alpha: 1.
  beta:  6.
  gamma: 1.

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: 'BetaTCVAE'


================================================
FILE: configs/bhvae.yaml
================================================
model_params:
  name: 'BetaVAE'
  in_channels: 3
  latent_dim: 128
  loss_type: 'H'
  beta: 10.

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: 'BetaVAE'


================================================
FILE: configs/cat_vae.yaml
================================================
model_params:
  name: 'CategoricalVAE'
  in_channels: 3
  latent_dim: 512
  categorical_dim: 40
  temperature: 0.5
  anneal_rate: 0.00003
  anneal_interval: 100
  alpha: 1.0

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "CategoricalVAE"


================================================
FILE: configs/cvae.yaml
================================================
model_params:
  name: 'ConditionalVAE'
  in_channels: 3
  num_classes: 40
  latent_dim: 128

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "ConditionalVAE"

================================================
FILE: configs/dfc_vae.yaml
================================================
model_params:
  name: 'DFCVAE'
  in_channels: 3
  latent_dim: 128

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "DFCVAE"


================================================
FILE: configs/dip_vae.yaml
================================================
model_params:
  name: 'DIPVAE'
  in_channels: 3
  latent_dim: 128
  lambda_diag: 0.05
  lambda_offdiag: 0.1


data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.001
  weight_decay: 0.0
  scheduler_gamma: 0.97
  kld_weight: 1
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "DIPVAE"
  manual_seed: 1265


================================================
FILE: configs/factorvae.yaml
================================================
model_params:
  name: 'FactorVAE'
  in_channels: 3
  latent_dim: 128
  gamma: 6.4

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  submodel: 'discriminator'
  retain_first_backpass: True
  LR: 0.005
  weight_decay: 0.0
  LR_2: 0.005
  scheduler_gamma_2: 0.95
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "FactorVAE"  
  



================================================
FILE: configs/gammavae.yaml
================================================
model_params:
  name: 'GammaVAE'
  in_channels: 3
  latent_dim: 128
  gamma_shape: 8.
  prior_shape: 2.
  prior_rate: 1.


data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.003
  weight_decay: 0.00005
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10
  gradient_clip_val: 0.8

logging_params:
  save_dir: "logs/"
  name: "GammaVAE"


================================================
FILE: configs/hvae.yaml
================================================
model_params:
  name: 'HVAE'
  in_channels: 3
  latent1_dim: 64
  latent2_dim: 64
  pseudo_input_size: 128

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "HVAE"


================================================
FILE: configs/infovae.yaml
================================================
model_params:
  name: 'InfoVAE'
  in_channels: 3
  latent_dim: 128
  reg_weight: 110  # MMD weight
  kernel_type: 'imq'
  alpha: -9.0     # KLD weight
  beta: 10.5      # Reconstruction weight

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10
  gradient_clip_val: 0.8

logging_params:
  save_dir: "logs/"
  name: "InfoVAE"
  manual_seed: 1265






================================================
FILE: configs/iwae.yaml
================================================
model_params:
  name: 'IWAE'
  in_channels: 3
  latent_dim: 128
  num_samples: 5

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.007
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "IWAE"


================================================
FILE: configs/joint_vae.yaml
================================================
model_params:
  name: 'JointVAE'
  in_channels: 3
  latent_dim: 512
  categorical_dim: 40
  latent_min_capacity: 0.0
  latent_max_capacity: 20.0
  latent_gamma: 10.
  latent_num_iter: 25000
  categorical_min_capacity: 0.0
  categorical_max_capacity: 20.0
  categorical_gamma: 10.
  categorical_num_iter: 25000
  temperature: 0.5
  anneal_rate: 0.00003
  anneal_interval: 100
  alpha: 10.0

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "JointVAE"



================================================
FILE: configs/logcosh_vae.yaml
================================================
model_params:
  name: 'LogCoshVAE'
  in_channels: 3
  latent_dim: 128
  alpha: 10.0
  beta: 1.0

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.97
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "LogCoshVAE"



================================================
FILE: configs/lvae.yaml
================================================
model_params:
  name: 'LVAE'
  in_channels: 3
  latent_dims: [4,8,16,32,128]
  hidden_dims: [32, 64,128, 256, 512]

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "LVAE"


================================================
FILE: configs/miwae.yaml
================================================
model_params:
  name: 'MIWAE'
  in_channels: 3
  latent_dim: 128
  num_samples: 5
  num_estimates: 3

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "MIWAE"



================================================
FILE: configs/mssim_vae.yaml
================================================
model_params:
  name: 'MSSIMVAE'
  in_channels: 3
  latent_dim: 128

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "MSSIMVAE"


================================================
FILE: configs/swae.yaml
================================================
model_params:
  name: 'SWAE'
  in_channels: 3
  latent_dim: 128
  reg_weight: 100
  wasserstein_deg: 2.0
  num_projections: 200
  projection_dist: "normal" #"cauchy"

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "SWAE"







================================================
FILE: configs/vae.yaml
================================================
model_params:
  name: 'VanillaVAE'
  in_channels: 3
  latent_dim: 128


data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 100

logging_params:
  save_dir: "logs/"
  name: "VanillaVAE"
  


================================================
FILE: configs/vampvae.yaml
================================================
model_params:
  name: 'VampVAE'
  in_channels: 3
  latent_dim: 128

exp_params:
  dataset: celeba
  data_path: "../../shared/Data/"
  img_size: 64
  batch_size: 144 # Better to have a square number
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95

trainer_params:
  gpus: 1
  max_nb_epochs: 50
  max_epochs: 50

logging_params:
  save_dir: "logs/"
  name: "VampVAE"
  manual_seed: 1265


================================================
FILE: configs/vq_vae.yaml
================================================
model_params:
  name: 'VQVAE'
  in_channels: 3
  embedding_dim: 64
  num_embeddings: 512
  img_size: 64
  beta: 0.25

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.0
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: 'VQVAE'


================================================
FILE: configs/wae_mmd_imq.yaml
================================================
model_params:
  name: 'WAE_MMD'
  in_channels: 3
  latent_dim: 128
  reg_weight: 100
  kernel_type: 'imq'

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "WassersteinVAE_IMQ"







================================================
FILE: configs/wae_mmd_rbf.yaml
================================================
model_params:
  name: 'WAE_MMD'
  in_channels: 3
  latent_dim: 128
  reg_weight: 5000
  kernel_type: 'rbf'

data_params:
  data_path: "Data/"
  train_batch_size: 64
  val_batch_size:  64
  patch_size: 64
  num_workers: 4


exp_params:
  LR: 0.005
  weight_decay: 0.0
  scheduler_gamma: 0.95
  kld_weight: 0.00025
  manual_seed: 1265

trainer_params:
  gpus: [1]
  max_epochs: 10

logging_params:
  save_dir: "logs/"
  name: "WassersteinVAE_RBF"







================================================
FILE: dataset.py
================================================
import os
import torch
from torch import Tensor
from pathlib import Path
from typing import List, Optional, Sequence, Union, Any, Callable
from torchvision.datasets.folder import default_loader
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import CelebA
import zipfile


# Add your custom dataset class here
class MyDataset(Dataset):
    def __init__(self):
        pass
    
    
    def __len__(self):
        pass
    
    def __getitem__(self, idx):
        pass


class MyCelebA(CelebA):
    """
    A workaround for issues with torchvision's CelebA dataset class.
    
    Download and Extract
    URL : https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing
    """
    
    def _check_integrity(self) -> bool:
        return True
    
    

class OxfordPets(Dataset):
    """
    URL = https://www.robots.ox.ac.uk/~vgg/data/pets/
    """
    def __init__(self, 
                 data_path: str, 
                 split: str,
                 transform: Callable,
                **kwargs):
        self.data_dir = Path(data_path) / "OxfordPets"        
        self.transforms = transform
        imgs = sorted([f for f in self.data_dir.iterdir() if f.suffix == '.jpg'])
        
        self.imgs = imgs[:int(len(imgs) * 0.75)] if split == "train" else imgs[int(len(imgs) * 0.75):]
    
    def __len__(self):
        return len(self.imgs)
    
    def __getitem__(self, idx):
        img = default_loader(self.imgs[idx])
        
        if self.transforms is not None:
            img = self.transforms(img)
        
        return img, 0.0  # dummy label so that (img, label) unpacking downstream does not break

class VAEDataset(LightningDataModule):
    """
    PyTorch Lightning data module 

    Args:
        data_path: root directory of your dataset.
        train_batch_size: the batch size to use during training.
        val_batch_size: the batch size to use during validation.
        patch_size: the size of the crop to take from the original images.
        num_workers: the number of parallel workers to create to load data
            items (see PyTorch's Dataloader documentation for more details).
        pin_memory: whether prepared items should be loaded into pinned memory
            or not. This can improve performance on GPUs.
    """

    def __init__(
        self,
        data_path: str,
        train_batch_size: int = 8,
        val_batch_size: int = 8,
        patch_size: Union[int, Sequence[int]] = (256, 256),
        num_workers: int = 0,
        pin_memory: bool = False,
        **kwargs,
    ):
        super().__init__()

        self.data_dir = data_path
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.patch_size = patch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def setup(self, stage: Optional[str] = None) -> None:
#       =========================  OxfordPets Dataset  =========================
            
#         train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
#                                               transforms.CenterCrop(self.patch_size),
# #                                               transforms.Resize(self.patch_size),
#                                               transforms.ToTensor(),
#                                                 transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
        
#         val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
#                                             transforms.CenterCrop(self.patch_size),
# #                                             transforms.Resize(self.patch_size),
#                                             transforms.ToTensor(),
#                                               transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

#         self.train_dataset = OxfordPets(
#             self.data_dir,
#             split='train',
#             transform=train_transforms,
#         )
        
#         self.val_dataset = OxfordPets(
#             self.data_dir,
#             split='val',
#             transform=val_transforms,
#         )
        
#       =========================  CelebA Dataset  =========================
    
        train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
                                              transforms.CenterCrop(148),
                                              transforms.Resize(self.patch_size),
                                              transforms.ToTensor(),])
        
        val_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
                                            transforms.CenterCrop(148),
                                            transforms.Resize(self.patch_size),
                                            transforms.ToTensor(),])
        
        self.train_dataset = MyCelebA(
            self.data_dir,
            split='train',
            transform=train_transforms,
            download=False,
        )
        
        # Replace CelebA with your dataset
        self.val_dataset = MyCelebA(
            self.data_dir,
            split='test',
            transform=val_transforms,
            download=False,
        )
#       ===============================================================
        
    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=self.pin_memory,
        )
    
    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        return DataLoader(
            self.val_dataset,
            batch_size=144,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=self.pin_memory,
        )
     

================================================
FILE: experiment.py
================================================
import os
import math
import torch
from torch import optim
from models import BaseVAE
from models.types_ import *
from utils import data_loader
import pytorch_lightning as pl
from torchvision import transforms
import torchvision.utils as vutils
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader


class VAEXperiment(pl.LightningModule):

    def __init__(self,
                 vae_model: BaseVAE,
                 params: dict) -> None:
        super(VAEXperiment, self).__init__()

        self.model = vae_model
        self.params = params
        self.curr_device = None
        # Optional flag; defaults to False when the key is absent from the config
        self.hold_graph = self.params.get('retain_first_backpass', False)

    def forward(self, input: Tensor, **kwargs) -> Tensor:
        return self.model(input, **kwargs)

    def training_step(self, batch, batch_idx, optimizer_idx = 0):
        real_img, labels = batch
        self.curr_device = real_img.device

        results = self.forward(real_img, labels = labels)
        train_loss = self.model.loss_function(*results,
                                              M_N = self.params['kld_weight'], # real_img.shape[0] / self.num_train_imgs,
                                              optimizer_idx=optimizer_idx,
                                              batch_idx = batch_idx)

        self.log_dict({key: val.item() for key, val in train_loss.items()}, sync_dist=True)

        return train_loss['loss']

    def validation_step(self, batch, batch_idx, optimizer_idx = 0):
        real_img, labels = batch
        self.curr_device = real_img.device

        results = self.forward(real_img, labels = labels)
        val_loss = self.model.loss_function(*results,
                                            M_N = 1.0, #real_img.shape[0]/ self.num_val_imgs,
                                            optimizer_idx = optimizer_idx,
                                            batch_idx = batch_idx)

        self.log_dict({f"val_{key}": val.item() for key, val in val_loss.items()}, sync_dist=True)

        
    def on_validation_end(self) -> None:
        self.sample_images()
        
    def sample_images(self):
        # Get sample reconstruction image            
        test_input, test_label = next(iter(self.trainer.datamodule.test_dataloader()))
        test_input = test_input.to(self.curr_device)
        test_label = test_label.to(self.curr_device)

#         test_input, test_label = batch
        recons = self.model.generate(test_input, labels = test_label)
        vutils.save_image(recons.data,
                          os.path.join(self.logger.log_dir , 
                                       "Reconstructions", 
                                       f"recons_{self.logger.name}_Epoch_{self.current_epoch}.png"),
                          normalize=True,
                          nrow=12)

        try:
            samples = self.model.sample(144,
                                        self.curr_device,
                                        labels = test_label)
            vutils.save_image(samples.cpu().data,
                              os.path.join(self.logger.log_dir , 
                                           "Samples",      
                                           f"{self.logger.name}_Epoch_{self.current_epoch}.png"),
                              normalize=True,
                              nrow=12)
        except Warning:
            pass

    def configure_optimizers(self):

        optims = []
        scheds = []

        optimizer = optim.Adam(self.model.parameters(),
                               lr=self.params['LR'],
                               weight_decay=self.params['weight_decay'])
        optims.append(optimizer)
        # Check if more than 1 optimizer is required (Used for adversarial training)
        if self.params.get('LR_2') is not None:
            optimizer2 = optim.Adam(getattr(self.model, self.params['submodel']).parameters(),
                                    lr=self.params['LR_2'])
            optims.append(optimizer2)

        # Without a scheduler gamma, return the optimizers on their own
        if self.params.get('scheduler_gamma') is None:
            return optims

        scheduler = optim.lr_scheduler.ExponentialLR(optims[0],
                                                     gamma=self.params['scheduler_gamma'])
        scheds.append(scheduler)

        # Check if another scheduler is required for the second optimizer
        if len(optims) > 1 and self.params.get('scheduler_gamma_2') is not None:
            scheduler2 = optim.lr_scheduler.ExponentialLR(optims[1],
                                                          gamma=self.params['scheduler_gamma_2'])
            scheds.append(scheduler2)
        return optims, scheds


================================================
FILE: models/__init__.py
================================================
from .base import *
from .vanilla_vae import *
from .gamma_vae import *
from .beta_vae import *
from .wae_mmd import *
from .cvae import *
from .hvae import *
from .vampvae import *
from .iwae import *
from .dfcvae import *
from .mssim_vae import MSSIMVAE
from .fvae import *
from .cat_vae import *
from .joint_vae import *
from .info_vae import *
# from .twostage_vae import *
from .lvae import LVAE
from .logcosh_vae import *
from .swae import *
from .miwae import *
from .vq_vae import *
from .betatc_vae import *
from .dip_vae import *


# Aliases
VAE = VanillaVAE
GaussianVAE = VanillaVAE
CVAE = ConditionalVAE
GumbelVAE = CategoricalVAE

vae_models = {'HVAE':HVAE,
              'LVAE':LVAE,
              'IWAE':IWAE,
              'SWAE':SWAE,
              'MIWAE':MIWAE,
              'VQVAE':VQVAE,
              'DFCVAE':DFCVAE,
              'DIPVAE':DIPVAE,
              'BetaVAE':BetaVAE,
              'InfoVAE':InfoVAE,
              'WAE_MMD':WAE_MMD,
              'VampVAE': VampVAE,
              'GammaVAE':GammaVAE,
              'MSSIMVAE':MSSIMVAE,
              'JointVAE':JointVAE,
              'BetaTCVAE':BetaTCVAE,
              'FactorVAE':FactorVAE,
              'LogCoshVAE':LogCoshVAE,
              'VanillaVAE':VanillaVAE,
              'ConditionalVAE':ConditionalVAE,
              'CategoricalVAE':CategoricalVAE}


================================================
FILE: models/base.py
================================================
from .types_ import *
from torch import nn
from abc import abstractmethod

class BaseVAE(nn.Module):
    
    def __init__(self) -> None:
        super(BaseVAE, self).__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        raise NotImplementedError

    def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
        raise NotImplementedError

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        pass

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        pass





================================================
FILE: models/beta_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class BetaVAE(BaseVAE):

    num_iter = 0 # Global static variable to keep track of iterations

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 beta: int = 4,
                 gamma:float = 1000.,
                 max_capacity: int = 25,
                 Capacity_max_iter: int = 100000,
                 loss_type:str = 'B',
                 **kwargs) -> None:
        super(BetaVAE, self).__init__()

        self.latent_dim = latent_dim
        self.beta = beta
        self.gamma = gamma
        self.loss_type = loss_type
        self.C_max = torch.Tensor([max_capacity])
        self.C_stop_iter = Capacity_max_iter

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) Latent mean and log-variance, [mu, log_var]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
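        """
        Reparameterization trick: draws z = mu + std * eps with eps ~ N(0, 1).
        Note: it is an open question whether a single z is enough to compute
        the expectation for the loss.
        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log-variance of the latent Gaussian
        :return: (Tensor) Sampled latent code, same shape as mu
        """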
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> Tensor:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        self.num_iter += 1
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

        recons_loss =F.mse_loss(recons, input)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        if self.loss_type == 'H': # https://openreview.net/forum?id=Sy2fzU9gl
            loss = recons_loss + self.beta * kld_weight * kld_loss
        elif self.loss_type == 'B': # https://arxiv.org/pdf/1804.03599.pdf
            self.C_max = self.C_max.to(input.device)
            C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
            loss = recons_loss + self.gamma * kld_weight* (kld_loss - C).abs()
        else:
            raise ValueError('Undefined loss type.')

        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
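

# --- Illustrative sketch, not part of the original model code ---
# The 'B' branch of loss_function above raises the KL capacity C linearly
# from 0 to C_max over C_stop_iter iterations and then holds it at C_max.
# The helper below reproduces that schedule with plain floats so it can be
# inspected in isolation. The function name and the example values are
# assumptions made for illustration only.
def capacity_schedule_sketch(num_iter: int, c_max: float, c_stop_iter: int) -> float:
    """Return the capacity C at iteration num_iter, clamped to [0, c_max]."""
    c = c_max / c_stop_iter * num_iter
    return max(0.0, min(c, c_max))

# Usage (hypothetical values): with c_max=25.0 and c_stop_iter=10000 the
# capacity is halfway to its maximum at iteration 5000 and stays at 25.0
# for every iteration from 10000 onwards.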

================================================
FILE: models/betatc_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
import math


class BetaTCVAE(BaseVAE):
    num_iter = 0 # Global static variable to keep track of iterations

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 anneal_steps: int = 200,
                 alpha: float = 1.,
                 beta: float =  6.,
                 gamma: float = 1.,
                 **kwargs) -> None:
        super(BetaTCVAE, self).__init__()

        self.latent_dim = latent_dim
        self.anneal_steps = anneal_steps

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 32, 32, 32]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 4, stride= 2, padding  = 1),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)

        self.fc = nn.Linear(hidden_dims[-1]*16, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_var = nn.Linear(256, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, 256 *  2)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)

        result = torch.flatten(result, start_dim=1)
        result = self.fc(result)
        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 32, 4, 4)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var, z]

    def log_density_gaussian(self, x: Tensor, mu: Tensor, logvar: Tensor):
        """
        Computes the log pdf of the Gaussian with parameters mu and logvar at x
        :param x: (Tensor) Point at which the Gaussian PDF is to be evaluated
        :param mu: (Tensor) Mean of the Gaussian distribution
        :param logvar: (Tensor) Log variance of the Gaussian distribution
        :return: (Tensor) Element-wise log density, broadcast over the inputs
        """
        norm = - 0.5 * (math.log(2 * math.pi) + logvar)
        log_density = norm - 0.5 * ((x - mu) ** 2 * torch.exp(-logvar))
        return log_density

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
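        """
        Computes the beta-TCVAE loss: the reconstruction term plus the
        decomposed KL terms (mutual information, total correlation and
        dimension-wise KL), estimated with minibatch stratified sampling.
        :param args: Outputs of forward(): [recons, input, mu, log_var, z]
        :param kwargs: Must contain 'M_N', the ratio batch_size / dataset_size
        :return: (dict) Total loss and its individual components
        """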
            
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        z = args[4]

        weight = 1 #kwargs['M_N']  # Account for the minibatch samples from the dataset

        recons_loss =F.mse_loss(recons, input, reduction='sum')

        log_q_zx = self.log_density_gaussian(z, mu, log_var).sum(dim = 1)

        zeros = torch.zeros_like(z)
        log_p_z = self.log_density_gaussian(z, zeros, zeros).sum(dim = 1)

        batch_size, latent_dim = z.shape
        mat_log_q_z = self.log_density_gaussian(z.view(batch_size, 1, latent_dim),
                                                mu.view(1, batch_size, latent_dim),
                                                log_var.view(1, batch_size, latent_dim))

        # Reference
        # [1] https://github.com/YannDubs/disentangling-vae/blob/535bbd2e9aeb5a200663a4f82f1d34e084c4ba8d/disvae/utils/math.py#L54
        dataset_size = (1 / kwargs['M_N']) * batch_size # dataset size
        strat_weight = (dataset_size - batch_size + 1) / (dataset_size * (batch_size - 1))
        importance_weights = torch.Tensor(batch_size, batch_size).fill_(1 / (batch_size -1)).to(input.device)
        importance_weights.view(-1)[::batch_size] = 1 / dataset_size
        importance_weights.view(-1)[1::batch_size] = strat_weight
        importance_weights[batch_size - 2, 0] = strat_weight
        log_importance_weights = importance_weights.log()

        mat_log_q_z += log_importance_weights.view(batch_size, batch_size, 1)

        log_q_z = torch.logsumexp(mat_log_q_z.sum(2), dim=1, keepdim=False)
        log_prod_q_z = torch.logsumexp(mat_log_q_z, dim=1, keepdim=False).sum(1)

        mi_loss  = (log_q_zx - log_q_z).mean()
        tc_loss = (log_q_z - log_prod_q_z).mean()
        kld_loss = (log_prod_q_z - log_p_z).mean()

        # kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        if self.training:
            self.num_iter += 1
            anneal_rate = min(0 + 1 * self.num_iter / self.anneal_steps, 1)
        else:
            anneal_rate = 1.

        loss = recons_loss/batch_size + \
               self.alpha * mi_loss + \
               weight * (self.beta * tc_loss +
                         anneal_rate * self.gamma * kld_loss)
        
        return {'loss': loss,
                'Reconstruction_Loss':recons_loss,
                'KLD':kld_loss,
                'TC_Loss':tc_loss,
                'MI_Loss':mi_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
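

# --- Illustrative sketch, not part of the original model code ---
# loss_function above builds a [B x B] matrix of importance weights for
# minibatch stratified sampling and takes its log. Because the strided
# writes go through a flattened view with step B, they fill column 0 and
# column 1 of the matrix rather than its diagonal. The pure-Python helper
# below mirrors those assignments with nested lists so the layout can be
# checked without tensors. The function name is an assumption made for
# illustration, and batch_size is assumed to be at least 2.
import math


def log_importance_weights_sketch(batch_size: int, dataset_size: float) -> list:
    """Return the [B x B] log importance weights as a list of rows."""
    strat_weight = (dataset_size - batch_size + 1) / (dataset_size * (batch_size - 1))
    rows = [[1.0 / (batch_size - 1)] * batch_size for _ in range(batch_size)]
    for row in rows:
        row[0] = 1.0 / dataset_size  # flat indices 0, B, 2B, ... (column 0)
        row[1] = strat_weight        # flat indices 1, B + 1, ... (column 1)
    rows[batch_size - 2][0] = strat_weight
    return [[math.log(w) for w in row] for row in rows]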

================================================
FILE: models/cat_vae.py
================================================
import torch
import numpy as np
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class CategoricalVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 categorical_dim: int = 40, # Num classes
                 hidden_dims: List = None,
                 temperature: float = 0.5,
                 anneal_rate: float = 3e-5,
                 anneal_interval: int = 100, # every 100 batches
                 alpha: float = 30.,
                 **kwargs) -> None:
        super(CategoricalVAE, self).__init__()

        self.latent_dim = latent_dim
        self.categorical_dim = categorical_dim
        self.temp = temperature
        self.min_temp = temperature
        self.anneal_rate = anneal_rate
        self.anneal_interval = anneal_interval
        self.alpha = alpha

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_z = nn.Linear(hidden_dims[-1]*4,
                               self.latent_dim * self.categorical_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(self.latent_dim * self.categorical_dim
                                       , hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())
        self.sampling_dist = torch.distributions.OneHotCategorical(1. / categorical_dim * torch.ones((self.categorical_dim, 1)))

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [B x C x H x W]
        :return: (Tensor) Latent code [B x D x Q]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Project the features onto the categorical logits and
        # reshape them to [B x D x Q]
        z = self.fc_z(result)
        z = z.view(-1, self.latent_dim, self.categorical_dim)
        return [z]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) Flattened latent codes [B x (D * Q)]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, z: Tensor, eps:float = 1e-7) -> Tensor:
        """
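        Gumbel-softmax trick to sample from a Categorical distribution
        :param z: (Tensor) Latent codes (logits) [B x D x Q]
        :param eps: (float) Small constant that keeps the logs finite
        :return: (Tensor) Flattened relaxed samples [B x (D * Q)]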
        """
        # Sample from Gumbel
        u = torch.rand_like(z)
        g = - torch.log(- torch.log(u + eps) + eps)

        # Gumbel-Softmax sample
        s = F.softmax((z + g) / self.temp, dim=-1)
        s = s.view(-1, self.latent_dim * self.categorical_dim)
        return s


    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        q = self.encode(input)[0]
        z = self.reparameterize(q)
        return  [self.decode(z), input, q]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
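        """
        Computes the Categorical VAE loss: the reconstruction error plus
        the KL divergence between the categorical posterior q and a
        uniform categorical prior.
        :param args: Outputs of forward(): [recons, input, q]
        :param kwargs: Must contain 'M_N' (KL weight) and 'batch_idx'
        :return: (dict) Total loss, reconstruction loss and KLD
        """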
        recons = args[0]
        input = args[1]
        q = args[2]

        q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        batch_idx = kwargs['batch_idx']

        # Anneal the temperature at regular intervals
        if batch_idx % self.anneal_interval == 0 and self.training:
            self.temp = np.maximum(self.temp * np.exp(- self.anneal_rate * batch_idx),
                                   self.min_temp)

        recons_loss =F.mse_loss(recons, input, reduction='mean')

        # KL divergence between the categorical posterior and a uniform prior
        eps = 1e-7

        # Negative entropy of the posterior probabilities
        h1 = q_p * torch.log(q_p + eps)

        # Negative cross entropy with the uniform categorical prior
        h2 = q_p * np.log(1. / self.categorical_dim + eps)
        kld_loss = torch.mean(torch.sum(h1 - h2, dim =(1,2)), dim=0)

        # kld_weight = 1.2
        loss = self.alpha * recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        # [S x D x Q]

        M = num_samples * self.latent_dim
        np_y = np.zeros((M, self.categorical_dim), dtype=np.float32)
        np_y[range(M), np.random.choice(self.categorical_dim, M)] = 1
        np_y = np.reshape(np_y, [M // self.latent_dim, self.latent_dim, self.categorical_dim])
        z = torch.from_numpy(np_y)

        # z = self.sampling_dist.sample((num_samples * self.latent_dim, ))
        z = z.view(num_samples, self.latent_dim * self.categorical_dim).to(current_device)
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
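

# --- Illustrative sketch, not part of the original model code ---
# reparameterize above draws Gumbel noise g = -log(-log(u)) with u uniform
# in [0, 1), adds it to the logits and applies a temperature-scaled softmax
# over the category axis. The pure-Python helper below does the same for a
# single latent variable (one list of Q logits), which may make the effect
# of the temperature easier to explore. The function name and its default
# arguments are assumptions made for illustration.
import math
import random


def gumbel_softmax_sample_sketch(logits, temperature=0.5, eps=1e-7, rng=random):
    """Return one relaxed one-hot sample (a list of Q probabilities)."""
    noisy = []
    for logit in logits:
        u = rng.random()
        g = -math.log(-math.log(u + eps) + eps)  # Gumbel(0, 1) noise
        noisy.append((logit + g) / temperature)
    # Numerically stable softmax over the categories
    m = max(noisy)
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]

# Lower temperatures push the sample towards a one-hot vector; higher
# temperatures push it towards the uniform distribution.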

================================================
FILE: models/cvae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class ConditionalVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 num_classes: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 img_size:int = 64,
                 **kwargs) -> None:
        super(ConditionalVAE, self).__init__()

        self.latent_dim = latent_dim
        self.img_size = img_size

        self.embed_class = nn.Linear(num_classes, img_size * img_size)
        self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        in_channels += 1 # To account for the extra label channel
        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim + num_classes, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
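        """
        Reparameterization trick to sample from N(mu, var) using
        noise drawn from N(0, 1).
        Note: only a single z is drawn per input. Is one sample enough
        to estimate the expectation in the loss?
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """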
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        y = kwargs['labels'].float()
        embedded_class = self.embed_class(y)
        embedded_class = embedded_class.view(-1, self.img_size, self.img_size).unsqueeze(1)
        embedded_input = self.embed_data(input)

        x = torch.cat([embedded_input, embedded_class], dim = 1)
        mu, log_var = self.encode(x)

        z = self.reparameterize(mu, log_var)

        z = torch.cat([z, y], dim = 1)
        return  [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss =F.mse_loss(recons, input)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int,
               **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        y = kwargs['labels'].float()
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        z = torch.cat([z, y], dim=1)
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x, **kwargs)[0]
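

# --- Illustrative sketch, not part of the original model code ---
# forward and sample above concatenate the label vector y to the latent
# code z, which is why decoder_input takes latent_dim + num_classes inputs.
# The pure-Python helpers below show that bookkeeping for one example,
# assuming the labels are one-hot encoded (the model itself only requires
# a float vector of length num_classes). Both function names are
# assumptions made for illustration.
def one_hot_sketch(label: int, num_classes: int) -> list:
    """Return a one-hot list of length num_classes for an integer label."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def condition_latent_sketch(z: list, label: int, num_classes: int) -> list:
    """Append the one-hot label to z, giving latent_dim + num_classes values."""
    return list(z) + one_hot_sketch(label, num_classes)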

================================================
FILE: models/dfcvae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torchvision.models import vgg19_bn
from torch.nn import functional as F
from .types_ import *


class DFCVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 alpha:float = 1,
                 beta:float = 0.5,
                 **kwargs) -> None:
        super(DFCVAE, self).__init__()

        self.latent_dim = latent_dim
        self.alpha = alpha
        self.beta = beta

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

        self.feature_network = vgg19_bn(pretrained=True)

        # Freeze the pretrained feature network
        for param in self.feature_network.parameters():
            param.requires_grad = False

        self.feature_network.eval()


    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        recons = self.decode(z)

        recons_features = self.extract_features(recons)
        input_features = self.extract_features(input)

        return  [recons, input, recons_features, input_features, mu, log_var]

    def extract_features(self,
                         input: Tensor,
                         feature_layers: List = None) -> List[Tensor]:
        """
        Extracts the features from the pretrained model
        at the layers indicated by feature_layers.
        :param input: (Tensor) [B x C x H x W]
        :param feature_layers: (List[str]) Keys of the layers in feature_network.features whose outputs are collected
        :return: List of the extracted features
        """
        if feature_layers is None:
            feature_layers = ['14', '24', '34', '43']
        features = []
        result = input
        for (key, module) in self.feature_network.features._modules.items():
            result = module(result)
            if key in feature_layers:
                features.append(result)

        return features

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
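        r"""
        Computes the DFC-VAE loss: pixel and deep-feature reconstruction
        errors plus the KL term.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: Outputs of forward(): [recons, input, recons_features, input_features, mu, log_var]
        :param kwargs: Must contain 'M_N', the weight of the KL term
        :return: (dict) Total loss, reconstruction loss and KLD
        """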
        recons = args[0]
        input = args[1]
        recons_features = args[2]
        input_features = args[3]
        mu = args[4]
        log_var = args[5]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss =F.mse_loss(recons, input)

        feature_loss = 0.0
        for (r, i) in zip(recons_features, input_features):
            feature_loss += F.mse_loss(r, i)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = self.beta * (recons_loss + feature_loss) + self.alpha * kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
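

# --- Illustrative sketch, not part of the original model code ---
# The kld_loss line in loss_function above is the closed-form KL divergence
# between a diagonal Gaussian N(mu, exp(log_var)) and the standard normal,
# summed over the latent dimensions. Substituting log_var = 2 * log(sigma)
# into the formula from the docstring gives
#   KL = -0.5 * sum(1 + log_var - mu^2 - exp(log_var)).
# The pure-Python helper below evaluates that expression for one example so
# it can be checked by hand. The function name is an assumption made for
# illustration.
import math


def gaussian_kl_sketch(mu: list, log_var: list) -> float:
    """Return KL(N(mu, exp(log_var)) || N(0, 1)) summed over dimensions."""
    return -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, log_var))

# The divergence is zero when the posterior equals the prior (mu = 0,
# log_var = 0) and grows as mu moves away from 0 or the variance from 1.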

================================================
FILE: models/dip_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class DIPVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 lambda_diag: float = 10.,
                 lambda_offdiag: float = 5.,
                 **kwargs) -> None:
        super(DIPVAE, self).__init__()

        self.latent_dim = latent_dim
        self.lambda_diag = lambda_diag
        self.lambda_offdiag = lambda_offdiag

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) Latent codes [mu, log_var]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
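        r"""
        Computes the DIP-VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: [recons, input, mu, log_var] as returned by forward()
        :param kwargs: must contain 'M_N', the weight applied to the KLD term
        :return: (dict) Total loss and its individual terms
        """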
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss =F.mse_loss(recons, input, reduction='sum')


        kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        # DIP Loss
        # Centre the means over the batch dimension so that cov_mu
        # estimates Cov_{p(x)}[mu(x)].
        centered_mu = mu - mu.mean(dim=0, keepdim = True) # [B x D]
        cov_mu = centered_mu.t().matmul(centered_mu) / mu.size(0) # [D x D]

        # Add Variance for DIP Loss II:
        # Cov[z] = Cov[mu(x)] + E[diag(sigma^2(x))], with sigma^2 = exp(log_var)
        cov_z = cov_mu + torch.diag(log_var.exp().mean(dim = 0)) # [D x D]
        # For DIP Loss I
        # cov_z = cov_mu

        cov_diag = torch.diag(cov_z) # [D]
        cov_offdiag = cov_z - torch.diag(cov_diag) # [D x D]
        dip_loss = self.lambda_offdiag * torch.sum(cov_offdiag ** 2) + \
                   self.lambda_diag * torch.sum((cov_diag - 1) ** 2)

        loss = recons_loss + kld_weight * kld_loss + dip_loss
        return {'loss': loss,
                'Reconstruction_Loss':recons_loss,
                'KLD':-kld_loss,
                'DIP_Loss':dip_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
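
# The covariance penalty used by the DIP loss can also be read in isolation.
# The following is a minimal sketch under stated assumptions, not the
# repository's own implementation: `dip_regulariser` is a hypothetical
# standalone helper, the means are centred over the batch, the covariance is
# normalized by the batch size, and the "ii" variant adds the batch-averaged
# encoder variance exp(log_var) to the diagonal. The default weights mirror
# lambda_diag = 10 and lambda_offdiag = 5 from the class above.

```python
import torch


def dip_regulariser(mu, log_var, lambda_diag=10.0, lambda_offdiag=5.0, variant="ii"):
    """Hypothetical helper: penalize the deviation of Cov[z] from the identity."""
    batch_size = mu.size(0)
    centered_mu = mu - mu.mean(dim=0, keepdim=True)        # [B x D]
    cov = centered_mu.t() @ centered_mu / batch_size       # [D x D], Cov[mu(x)]
    if variant == "ii":
        # DIP-VAE-II: add the expected encoder variance on the diagonal
        cov = cov + torch.diag(log_var.exp().mean(dim=0))
    diag = torch.diagonal(cov)                             # [D]
    offdiag = cov - torch.diag(diag)                       # [D x D]
    return lambda_offdiag * offdiag.pow(2).sum() + lambda_diag * (diag - 1).pow(2).sum()


torch.manual_seed(0)
mu = torch.randn(4096, 3)                 # means that already look like N(0, I)
log_var = torch.full((4096, 3), -20.0)    # near-zero encoder variance

# Mix the first two dimensions to create correlated latent means
mix = torch.tensor([[1.0, 0.9, 0.0],
                    [0.0, 1.0, 0.0],
                    [0.0, 0.0, 1.0]])

near_prior = dip_regulariser(mu, log_var)
correlated = dip_regulariser(mu @ mix, log_var)
print(near_prior.item(), correlated.item())
```

# The penalty stays close to zero when the aggregate covariance is already
# close to the identity and grows once latent dimensions become correlated.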

================================================
FILE: models/fvae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class FactorVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 gamma: float = 40.,
                 **kwargs) -> None:
        super(FactorVAE, self).__init__()

        self.latent_dim = latent_dim
        self.gamma = gamma

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

        # Discriminator network for the Total Correlation (TC) loss
        self.discriminator = nn.Sequential(nn.Linear(self.latent_dim, 1000),
                                          nn.BatchNorm1d(1000),
                                          nn.LeakyReLU(0.2),
                                          nn.Linear(1000, 1000),
                                          nn.BatchNorm1d(1000),
                                          nn.LeakyReLU(0.2),
                                          nn.Linear(1000, 1000),
                                          nn.BatchNorm1d(1000),
                                          nn.LeakyReLU(0.2),
                                          nn.Linear(1000, 2))
        self.D_z_reserve = None


    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) Latent codes [mu, log_var]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var, z]

    def permute_latent(self, z: Tensor) -> Tensor:
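        """
        Shuffles every latent dimension independently across the batch, so
        that the result approximates a sample from the product of the
        marginals prod_j q(z_j), as required by the TC discriminator.
        :param z: [B x D]
        :return: [B x D]
        """
        B, D = z.size()

        # One random permutation of the batch indices per latent dimension
        inds = torch.stack([torch.randperm(B, device=z.device) for _ in range(D)], dim=1) # [B x D]
        return torch.gather(z, 0, inds)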

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
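        r"""
        Computes the FactorVAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: [recons, input, mu, log_var, z] as returned by forward()
        :param kwargs: must contain 'M_N' (KLD weight) and 'optimizer_idx'
                       (0 updates the VAE, 1 updates the discriminator)
        :return: (dict) Loss terms for the selected optimizer
        """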
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        z = args[4]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        optimizer_idx = kwargs['optimizer_idx']

        # Update the VAE
        if optimizer_idx == 0:
            recons_loss =F.mse_loss(recons, input)
            kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

            self.D_z_reserve = self.discriminator(z)
            vae_tc_loss = (self.D_z_reserve[:, 0] - self.D_z_reserve[:, 1]).mean()

            loss = recons_loss + kld_weight * kld_loss + self.gamma * vae_tc_loss

            # print(f' recons: {recons_loss}, kld: {kld_loss}, VAE_TC_loss: {vae_tc_loss}')
            return {'loss': loss,
                    'Reconstruction_Loss':recons_loss,
                    'KLD':-kld_loss,
                    'VAE_TC_Loss': vae_tc_loss}

        # Update the Discriminator
        elif optimizer_idx == 1:
            device = input.device
            true_labels = torch.ones(input.size(0), dtype= torch.long,
                                     requires_grad=False).to(device)
            false_labels = torch.zeros(input.size(0), dtype= torch.long,
                                       requires_grad=False).to(device)

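            z = z.detach() # Detach so that VAE is not trained again
            z_perm = self.permute_latent(z)
            # Re-score the detached codes instead of reusing self.D_z_reserve,
            # whose graph belongs to the VAE update and may already be freed.
            D_z = self.discriminator(z)
            D_z_perm = self.discriminator(z_perm)
            D_tc_loss = 0.5 * (F.cross_entropy(D_z, false_labels) +
                               F.cross_entropy(D_z_perm, true_labels))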
            # print(f'D_TC: {D_tc_loss}')
            return {'loss': D_tc_loss,
                    'D_TC_Loss':D_tc_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
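
# The discriminator in FactorVAE is trained to tell codes drawn from q(z)
# apart from codes drawn from the product of its marginals. A minimal sketch
# of the batch-wise shuffle that produces the latter is given below;
# `permute_dims` is a hypothetical standalone helper (an assumption of this
# sketch), and the toy tensor only serves to make the effect visible.

```python
import torch


def permute_dims(z):
    """Hypothetical helper: shuffle each latent dimension independently across the batch."""
    B, D = z.size()
    # One permutation of the batch indices per latent dimension -> [B x D]
    inds = torch.stack([torch.randperm(B, device=z.device) for _ in range(D)], dim=1)
    # out[i, j] = z[inds[i, j], j]
    return torch.gather(z, 0, inds)


torch.manual_seed(0)
z = torch.arange(12, dtype=torch.float32).view(4, 3)   # 4 codes, 3 dimensions
z_perm = permute_dims(z)
print(z_perm)
```

# Each column of z_perm holds the same values as the matching column of z, so
# the per-dimension marginals are preserved while the dependence between
# dimensions within a code is broken.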

================================================
FILE: models/gamma_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.distributions import Gamma
from torch.nn import functional as F
from .types_ import *
import torch.nn.init as init


class GammaVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 gamma_shape: float = 8.,
                 prior_shape: float = 2.0,
                 prior_rate: float = 1.,
                 **kwargs) -> None:
        super(GammaVAE, self).__init__()
        self.latent_dim = latent_dim
        self.B = gamma_shape

        self.prior_alpha = torch.tensor([prior_shape])
        self.prior_beta = torch.tensor([prior_rate])

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
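        # dim=1 is what the implicit-dim nn.Softmax() picks for [B x D] input;
        # stating it explicitly avoids the deprecation warning.
        self.fc_mu = nn.Sequential(nn.Linear(hidden_dims[-1] * 4, latent_dim),
                                   nn.Softmax(dim=1))
        self.fc_var = nn.Sequential(nn.Linear(hidden_dims[-1] * 4, latent_dim),
                                    nn.Softmax(dim=1))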

        # Build Decoder
        modules = []

        self.decoder_input = nn.Sequential(nn.Linear(latent_dim, hidden_dims[-1] * 4))

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Sigmoid())

        self.weight_init()

    def weight_init(self):
        # self.apply visits every submodule recursively, so layers nested
        # inside the nn.Sequential blocks are initialized as well.
        self.apply(init_)

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) Latent Gamma parameters [alpha, beta]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Map the result to the shape (alpha) and rate (beta) parameters
        # of the latent Gamma distribution
        alpha = self.fc_mu(result)
        beta = self.fc_var(result)

        return [alpha, beta]

    def decode(self, z: Tensor) -> Tensor:
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, alpha: Tensor, beta: Tensor) -> Tensor:
        """
        Reparameterize the Gamma distribution by the shape augmentation trick.
        Reference:
        [1] https://arxiv.org/pdf/1610.05683.pdf

        :param alpha: (Tensor) Shape parameter of the latent Gamma
        :param beta: (Tensor) Rate parameter of the latent Gamma
        :return:
        """
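        # Sample from Gamma to guarantee acceptance
        alpha_ = alpha.clone().detach()
        z_hat = Gamma(alpha_ + self.B, torch.ones_like(alpha_)).sample()

        # Compute the eps ~ N(0,1) that produces z_hat. The detached alpha_ is
        # used here: with the attached alpha, h_func(inv_h_func(.)) is the
        # identity in z_hat and the gradient w.r.t. alpha cancels out.
        eps = self.inv_h_func(alpha_ + self.B, z_hat)
        z = self.h_func(alpha + self.B, eps)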

        # When beta != 1, scale by beta
        return z / beta

    def h_func(self, alpha: Tensor, eps: Tensor) -> Tensor:
        """
        Reparameterize a sample eps ~ N(0, 1) so that h(alpha, eps) is
        approximately distributed as Gamma(alpha, 1) (Marsaglia-Tsang transformation)
        :param alpha: (Tensor) Shape parameter
        :param eps: (Tensor) Random sample to reparameterize
        :return: (Tensor)
        """

        z = (alpha - 1./3.) * (1 + eps / torch.sqrt(9. * alpha - 3.))**3
        return z

    def inv_h_func(self, alpha: Tensor, z: Tensor) -> Tensor:
        """
        Inverse reparameterize the given z into eps.
        :param alpha: (Tensor)
        :param z: (Tensor)
        :return: (Tensor)
        """
        eps = torch.sqrt(9. * alpha - 3.) * ((z / (alpha - 1./3.))**(1. / 3.) - 1.)
        return eps

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        alpha, beta = self.encode(input)
        z = self.reparameterize(alpha, beta)
        return [self.decode(z), input, alpha, beta]

    # def I_function(self, alpha_p, beta_p, alpha_q, beta_q):
    #     return - (alpha_q * beta_q) / alpha_p - \
    #            beta_p * torch.log(alpha_p) - torch.lgamma(beta_p) + \
    #            (beta_p - 1) * torch.digamma(beta_q) + \
    #            (beta_p - 1) * torch.log(alpha_q)
    def I_function(self, a, b, c, d):
        return - c * d / a - b * torch.log(a) - torch.lgamma(b) + (b - 1) * (torch.digamma(d) + torch.log(c))

    def vae_gamma_kl_loss(self, a, b, c, d):
        """
        https://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions
        b and d are Gamma shape parameters and
        a and c are scale parameters.
        (All, therefore, must be positive.)
        """

        a = 1 / a
        c = 1 / c
        losses = self.I_function(c, d, c, d) - self.I_function(a, b, c, d)
        return torch.sum(losses, dim=1)

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        recons = args[0]
        input = args[1]
        alpha = args[2]
        beta = args[3]

        curr_device = input.device
        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = torch.mean(F.mse_loss(recons, input, reduction = 'none'), dim = (1,2,3))

        # https://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions
        # alpha = 1./ alpha


        self.prior_alpha = self.prior_alpha.to(curr_device)
        self.prior_beta = self.prior_beta.to(curr_device)

        # kld_loss = - self.I_function(alpha, beta, self.prior_alpha, self.prior_beta)

        kld_loss = self.vae_gamma_kl_loss(alpha, beta, self.prior_alpha, self.prior_beta)

        # kld_loss = torch.sum(kld_loss, dim=1)

        loss = recons_loss + kld_loss
        loss = torch.mean(loss, dim = 0)
        # print(loss, recons_loss, kld_loss)
        return {'loss': loss} #, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = Gamma(self.prior_alpha, self.prior_beta).sample((num_samples, self.latent_dim))
        # Drop only the trailing singleton dim so that the batch axis survives
        # when num_samples == 1
        z = z.squeeze(-1).to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

def init_(m):
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.orthogonal_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.fill_(0)
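
# The pair h_func / inv_h_func above is easier to trust after a round-trip
# check. The sketch below copies the two formulas into hypothetical free
# functions (an assumption made so that the example is self-contained) and
# shows two things: mapping a Gamma sample to eps and back recovers the
# sample, and the gradient reaches alpha when eps is computed from a detached
# alpha. The alpha values already include the shape augmentation B = 8.

```python
import torch
from torch.distributions import Gamma


def h_func(alpha, eps):
    # Marsaglia-Tsang transformation of eps ~ N(0, 1)
    return (alpha - 1. / 3.) * (1 + eps / torch.sqrt(9. * alpha - 3.)) ** 3


def inv_h_func(alpha, z):
    # Inverse of h_func with respect to eps
    return torch.sqrt(9. * alpha - 3.) * ((z / (alpha - 1. / 3.)) ** (1. / 3.) - 1.)


torch.manual_seed(0)
alpha = torch.tensor([8.5, 9.0, 10.0], requires_grad=True)

# Draw the sample and recover eps without tracking gradients
z_hat = Gamma(alpha.detach(), torch.ones(3)).sample()
eps = inv_h_func(alpha.detach(), z_hat)

# Re-create the sample as a differentiable function of alpha
z = h_func(alpha, eps)
z.sum().backward()

print(z.detach(), z_hat)   # the two agree up to floating-point error
print(alpha.grad)          # pathwise gradient of z with respect to alpha
```

# If eps were computed from the attached alpha instead, z would be the
# identity in z_hat and the gradient with respect to alpha would vanish.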


================================================
FILE: models/hvae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class HVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent1_dim: int,
                 latent2_dim: int,
                 hidden_dims: List = None,
                 img_size:int = 64,
                 pseudo_input_size: int = 128,
                 **kwargs) -> None:
        super(HVAE, self).__init__()

        self.latent1_dim = latent1_dim
        self.latent2_dim = latent2_dim
        self.img_size = img_size

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        channels = in_channels

        # Build z2 Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            channels = h_dim

        self.encoder_z2_layers = nn.Sequential(*modules)
        self.fc_z2_mu = nn.Linear(hidden_dims[-1]*4, latent2_dim)
        self.fc_z2_var = nn.Linear(hidden_dims[-1]*4, latent2_dim)
        # ========================================================================#
        # Build z1 Encoder
        self.embed_z2_code = nn.Linear(latent2_dim, img_size * img_size)
        self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        modules = []
        channels = in_channels + 1 # One more channel for the latent code
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            channels = h_dim

        self.encoder_z1_layers = nn.Sequential(*modules)
        self.fc_z1_mu = nn.Linear(hidden_dims[-1]*4, latent1_dim)
        self.fc_z1_var = nn.Linear(hidden_dims[-1]*4, latent1_dim)

        #========================================================================#
        # Build z2 Decoder
        self.recons_z1_mu = nn.Linear(latent2_dim, latent1_dim)
        self.recons_z1_log_var = nn.Linear(latent2_dim, latent1_dim)

        # ========================================================================#
        # Build z1 Decoder
        self.debed_z1_code = nn.Linear(latent1_dim, 1024)
        self.debed_z2_code = nn.Linear(latent2_dim, 1024)
        modules = []
        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

        # ========================================================================#
        # Pseudo Input for the Vamp-Prior
        # self.pseudo_input =  torch.eye(pseudo_input_size,
        #                                requires_grad=False).view(1, 1, pseudo_input_size, -1)
        #
        #
        # self.pseudo_layer = nn.Conv2d(1, out_channels=in_channels,
        #                              kernel_size=3, stride=2, padding=1)

    def encode_z2(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) Latent codes [z2_mu, z2_log_var]
        """
        result = self.encoder_z2_layers(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        z2_mu = self.fc_z2_mu(result)
        z2_log_var = self.fc_z2_var(result)

        return [z2_mu, z2_log_var]

    def encode_z1(self, input: Tensor, z2: Tensor) -> List[Tensor]:
        x = self.embed_data(input)
        z2 = self.embed_z2_code(z2)
        z2 = z2.view(-1, self.img_size, self.img_size).unsqueeze(1)
        result = torch.cat([x, z2], dim=1)

        result = self.encoder_z1_layers(result)
        result = torch.flatten(result, start_dim=1)
        z1_mu = self.fc_z1_mu(result)
        z1_log_var = self.fc_z1_var(result)

        return [z1_mu, z1_log_var]

    def encode(self, input: Tensor) -> List[Tensor]:
        z2_mu, z2_log_var = self.encode_z2(input)
        z2 = self.reparameterize(z2_mu, z2_log_var)

        # z1 ~ q(z1|x, z2)
        z1_mu, z1_log_var = self.encode_z1(input, z2)
        return [z1_mu, z1_log_var, z2_mu, z2_log_var, z2]

    def decode(self, input: Tensor) -> Tensor:
        result = self.decoder(input)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
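        """
        Reparameterization trick to sample from N(mu, var) using N(0,1).
        Open question: is a single z sample enough to estimate the
        expectation in the loss?
        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log variance of the latent Gaussian
        :return: (Tensor) Sampled latent code
        """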
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:

        # Encode the input into the latent codes z1 and z2
        # z2 ~q(z2 | x)
        # z1 ~ q(z1|x, z2)
        z1_mu, z1_log_var, z2_mu, z2_log_var, z2 = self.encode(input)
        z1 = self.reparameterize(z1_mu, z1_log_var)

        # Reconstruct the image using both the latent codes
        # x ~ p(x|z1, z2)
        debedded_z1 = self.debed_z1_code(z1)
        debedded_z2 = self.debed_z2_code(z2)
        result = torch.cat([debedded_z1, debedded_z2], dim=1)
        result = result.view(-1, 512, 2, 2)
        recons = self.decode(result)

        return  [recons,
                 input,
                 z1_mu, z1_log_var,
                 z2_mu, z2_log_var,
                 z1, z2]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        recons = args[0]
        input = args[1]

        z1_mu = args[2]
        z1_log_var = args[3]

        z2_mu = args[4]
        z2_log_var = args[5]

        z1= args[6]
        z2 = args[7]

        # Reconstruct (decode) z2 into z1
        # z1 ~ p(z1|z2) [This for the loss calculation]
        z1_p_mu = self.recons_z1_mu(z2)
        z1_p_log_var = self.recons_z1_log_var(z2)


        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss =F.mse_loss(recons, input)

        z1_kld = torch.mean(-0.5 * torch.sum(1 + z1_log_var - z1_mu ** 2 - z1_log_var.exp(), dim = 1),
                            dim = 0)
        z2_kld = torch.mean(-0.5 * torch.sum(1 + z2_log_var - z2_mu ** 2 - z2_log_var.exp(), dim = 1),
                            dim = 0)

        z1_p_kld = torch.mean(-0.5 * torch.sum(1 + z1_p_log_var - (z1 - z1_p_mu) ** 2 - z1_p_log_var.exp(),
                                               dim = 1),
                            dim = 0)

        z2_p_kld = torch.mean(-0.5*(z2**2), dim = 0)

        kld_loss = -(z1_p_kld - z1_kld - z2_kld)
        loss = recons_loss + kld_weight * kld_loss
        # print(z2_p_kld)

        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
        z2 = torch.randn(batch_size,
                         self.latent2_dim)

        z2 = z2.to(current_device)

        z1_mu = self.recons_z1_mu(z2)
        z1_log_var = self.recons_z1_log_var(z2)
        z1 = self.reparameterize(z1_mu, z1_log_var)

        debedded_z1 = self.debed_z1_code(z1)
        debedded_z2 = self.debed_z2_code(z2)

        result = torch.cat([debedded_z1, debedded_z2], dim=1)
        result = result.view(-1, 512, 2, 2)
        samples = self.decode(result)

        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]


================================================
FILE: models/info_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class InfoVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 alpha: float = -0.5,
                 beta: float = 5.0,
                 reg_weight: int = 100,
                 kernel_type: str = 'imq',
                 latent_var: float = 2.,
                 **kwargs) -> None:
        super(InfoVAE, self).__init__()

        self.latent_dim = latent_dim
        self.reg_weight = reg_weight
        self.kernel_type = kernel_type
        self.z_var = latent_var

        assert alpha <= 0, 'alpha must be negative or zero.'

        self.alpha = alpha
        self.beta = beta

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) Mean and log variance of the latent Gaussian, each [N x D]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: draws a sample from N(mu, var) using
        noise sampled from N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, z, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        recons = args[0]
        input = args[1]
        z = args[2]
        mu = args[3]
        log_var = args[4]

        batch_size = input.size(0)
        bias_corr = batch_size *  (batch_size - 1)
        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

        recons_loss = F.mse_loss(recons, input)
        mmd_loss = self.compute_mmd(z)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = self.beta * recons_loss + \
               (1. - self.alpha) * kld_weight * kld_loss + \
               (self.alpha + self.reg_weight - 1.)/bias_corr * mmd_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'MMD': mmd_loss, 'KLD':-kld_loss}

    def compute_kernel(self,
                       x1: Tensor,
                       x2: Tensor) -> Tensor:
        # Convert the tensors into row and column vectors
        D = x1.size(1)
        N = x1.size(0)

        x1 = x1.unsqueeze(-2) # Make it into a column tensor
        x2 = x2.unsqueeze(-3) # Make it into a row tensor

        # Usually the lines below are not required, especially in our case,
        # but they are useful when x1 and x2 have different sizes
        # along the 0th dimension.
        x1 = x1.expand(N, N, D)
        x2 = x2.expand(N, N, D)

        if self.kernel_type == 'rbf':
            result = self.compute_rbf(x1, x2)
        elif self.kernel_type == 'imq':
            result = self.compute_inv_mult_quad(x1, x2)
        else:
            raise ValueError('Undefined kernel type.')

        return result


    def compute_rbf(self,
                    x1: Tensor,
                    x2: Tensor,
                    eps: float = 1e-7) -> Tensor:
        """
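        Computes the RBF Kernel between x1 and x2.
        :param x1: (Tensor) [N x N x D]
        :param x2: (Tensor) [N x N x D]
        :param eps: (Float) Unused; kept so the signature matches compute_inv_mult_quad
        :return: (Tensor) Kernel matrix [N x N]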
        """
        z_dim = x2.size(-1)
        sigma = 2. * z_dim * self.z_var

        result = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))
        return result

    def compute_inv_mult_quad(self,
                               x1: Tensor,
                               x2: Tensor,
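                               eps: float = 1e-7) -> Tensor:
        r"""
        Computes the Inverse Multi-Quadratics Kernel between x1 and x2,
        given by

                k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2 \|^2}
        :param x1: (Tensor) [N x N x D]
        :param x2: (Tensor) [N x N x D]
        :param eps: (Float) Small constant added to the denominator for numerical stability
        :return: (Tensor) Scalar sum of the off-diagonal kernel entries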
        """
        z_dim = x2.size(-1)
        C = 2 * z_dim * self.z_var
        kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim = -1))

        # Exclude diagonal elements
        result = kernel.sum() - kernel.diag().sum()

        return result

    def compute_mmd(self, z: Tensor) -> Tensor:
        # Sample from prior (Gaussian) distribution
        prior_z = torch.randn_like(z)

        prior_z__kernel = self.compute_kernel(prior_z, prior_z)
        z__kernel = self.compute_kernel(z, z)
        priorz_z__kernel = self.compute_kernel(prior_z, z)

        mmd = prior_z__kernel.mean() + \
              z__kernel.mean() - \
              2 * priorz_z__kernel.mean()
        return mmd

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
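
A minimal, torch-free sketch of the MMD term used above, for illustration only. It is not the class's exact computation: it uses the plain biased estimator (mean over all pairs, diagonal included), whereas `compute_inv_mult_quad` sums the off-diagonal entries and `loss_function` rescales by `bias_corr`. The helper names `imq_kernel` and `mmd_squared` are hypothetical; the constant mirrors `C = 2 * z_dim * z_var` with `z_var = 2`.

```python
import random


def imq_kernel(x, y, c):
    """Inverse multiquadric kernel k(x, y) = c / (c + ||x - y||^2)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return c / (c + sq_dist)


def mmd_squared(xs, ys, c):
    """Biased estimate of MMD^2 between two equal-size samples."""
    n = len(xs)
    k_xx = sum(imq_kernel(a, b, c) for a in xs for b in xs) / n ** 2
    k_yy = sum(imq_kernel(a, b, c) for a in ys for b in ys) / n ** 2
    k_xy = sum(imq_kernel(a, b, c) for a in xs for b in ys) / n ** 2
    return k_xx + k_yy - 2.0 * k_xy


random.seed(0)
dim, n = 2, 64
c = 2.0 * dim * 2.0  # assumed to mirror C = 2 * z_dim * z_var with z_var = 2

# Samples standing in for the prior, a matching posterior, and a shifted posterior.
prior = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
close = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
far = [[random.gauss(3, 1) for _ in range(dim)] for _ in range(n)]

mmd_close = mmd_squared(prior, close, c)
mmd_far = mmd_squared(prior, far, c)
```

Latent codes that match the prior give a value near zero, while shifted codes give a clearly larger value. This is the behaviour the MMD penalty relies on.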

================================================
FILE: models/iwae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class IWAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 num_samples: int = 5,
                 **kwargs) -> None:
        super(IWAE, self).__init__()

        self.latent_dim = latent_dim
        self.num_samples = num_samples

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) Mean and log variance of the latent Gaussian, each [N x D]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes of S samples
        onto the image space.
        :param z: (Tensor) [B x S x D]
        :return: (Tensor) [B x S x C x H x W]
        """
        B, _, _ = z.size()
        z = z.view(-1, self.latent_dim) #[BS x D]
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result) #[BS x C x H x W ]
        result = result.view([B, -1, result.size(1), result.size(2), result.size(3)]) #[B x S x C x H x W]
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
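        Reparameterization trick: draws a sample from N(mu, var) using
        noise sampled from N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x S x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x S x D]
        :return: (Tensor) [B x S x D]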
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        mu = mu.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D]
        log_var = log_var.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D]
        z = self.reparameterize(mu, log_var) # [B x S x D]
        eps = (z - mu) / torch.exp(0.5 * log_var) # Recover the N(0, 1) noise used to draw z
        return  [self.decode(z), input, mu, log_var, z, eps]

    def loss_function(self,
                      *args,
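                      **kwargs) -> dict:
        r"""
        Computes the importance-weighted loss over the S samples.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: Outputs of forward(): [recons, input, mu, log_var, z, eps]
        :param kwargs: Must contain 'M_N', the weight of the KLD term
        :return: (dict) Loss terms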
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        z = args[4]
        eps = args[5]

        input = input.repeat(self.num_samples, 1, 1, 1, 1).permute(1, 0, 2, 3, 4) #[B x S x C x H x W]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset

        log_p_x_z = ((recons - input) ** 2).flatten(2).mean(-1) # Reconstruction Loss [B x S]
        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=2) ## [B x S]
        # Get importance weights
        log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data

        # Rescale the weights (along the sample dim) to lie in [0, 1] and sum to 1
        weight = F.softmax(log_weight, dim = -1)
        # kld_loss = torch.mean(kld_loss, dim = 0)

        loss = torch.mean(torch.sum(weight * log_weight, dim=-1), dim = 0)

        return {'loss': loss, 'Reconstruction_Loss':log_p_x_z.mean(), 'KLD':-kld_loss.mean()}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples, 1,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z).squeeze(1) # Drop only the sample dim so num_samples = 1 keeps its batch dim
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image.
        Returns only the first reconstructed sample
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0][:, 0, :]
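
For reference, a small torch-free sketch of the textbook importance-weighted bound, `log (1/S) sum_s w_s`, computed stably with the log-sum-exp trick. This is the standard estimator, not necessarily what `loss_function` above computes (that method weights per-sample losses with a softmax). The names `log_mean_exp` and `normalized_weights`, and the example log-weights, are hypothetical.

```python
import math


def log_mean_exp(log_w):
    """Numerically stable log(mean(exp(log_w)))."""
    m = max(log_w)
    return m + math.log(sum(math.exp(v - m) for v in log_w) / len(log_w))


def normalized_weights(log_w):
    """Self-normalized importance weights, i.e. a softmax over the log-weights."""
    m = max(log_w)
    unnorm = [math.exp(v - m) for v in log_w]
    total = sum(unnorm)
    return [u / total for u in unnorm]


# Made-up log importance weights log p(x, z_s) - log q(z_s | x) for S = 4 samples.
log_w = [-105.0, -100.0, -102.0, -110.0]

iwae_bound = log_mean_exp(log_w)   # log (1/S) sum_s w_s
elbo = sum(log_w) / len(log_w)     # single-sample ELBO averaged over the S draws
weights = normalized_weights(log_w)
```

By Jensen's inequality the importance-weighted bound is never below the averaged ELBO, and the normalized weights concentrate on the samples with the largest log-weight.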


================================================
FILE: models/joint_vae.py
================================================
import torch
import numpy as np
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class JointVAE(BaseVAE):
    num_iter = 1

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 categorical_dim: int,
                 latent_min_capacity: float =0.,
                 latent_max_capacity: float = 25.,
                 latent_gamma: float = 30.,
                 latent_num_iter: int = 25000,
                 categorical_min_capacity: float =0.,
                 categorical_max_capacity: float = 25.,
                 categorical_gamma: float = 30.,
                 categorical_num_iter: int = 25000,
                 hidden_dims: List = None,
                 temperature: float = 0.5,
                 anneal_rate: float = 3e-5,
                 anneal_interval: int = 100, # every 100 batches
                 alpha: float = 30.,
                 **kwargs) -> None:
        super(JointVAE, self).__init__()

        self.latent_dim = latent_dim
        self.categorical_dim = categorical_dim
        self.temp = temperature
        self.min_temp = temperature
        self.anneal_rate = anneal_rate
        self.anneal_interval = anneal_interval
        self.alpha = alpha

        self.cont_min = latent_min_capacity
        self.cont_max = latent_max_capacity

        self.disc_min = categorical_min_capacity
        self.disc_max = categorical_max_capacity

        self.cont_gamma = latent_gamma
        self.disc_gamma = categorical_gamma

        self.cont_iter = latent_num_iter
        self.disc_iter = categorical_num_iter

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, self.latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, self.latent_dim)
        self.fc_z = nn.Linear(hidden_dims[-1]*4, self.categorical_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(self.latent_dim + self.categorical_dim,
                                       hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())
        # Uniform prior over the categorical_dim classes
        self.sampling_dist = torch.distributions.OneHotCategorical(1. / categorical_dim * torch.ones(self.categorical_dim))

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [B x C x H x W]
        :return: (List[Tensor]) Mean [B x D], log variance [B x D] and categorical logits [B x Q]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        z = self.fc_z(result)
        z = z.view(-1, self.categorical_dim)
        return [mu, log_var, z]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x (D + Q)]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self,
                       mu: Tensor,
                       log_var: Tensor,
                       q: Tensor,
                       eps:float = 1e-7) -> Tensor:
        """
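        Samples the continuous code with the Gaussian reparameterization trick
        and the categorical code with the Gumbel-softmax trick.
        :param mu: (Tensor) mean of the latent Gaussian  [B x D]
        :param log_var: (Tensor) Log variance of the latent Gaussian [B x D]
        :param q: (Tensor) Categorical latent codes (logits) [B x Q]
        :param eps: (Float) Small constant to avoid log(0) when sampling Gumbel noise
        :return: (Tensor) [B x (D + Q)]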
        """

        std = torch.exp(0.5 * log_var)
        e = torch.randn_like(std)
        z = e * std + mu

        # Sample from Gumbel
        u = torch.rand_like(q)
        g = - torch.log(- torch.log(u + eps) + eps)

        # Gumbel-Softmax sample
        s = F.softmax((q + g) / self.temp, dim=-1)
        s = s.view(-1, self.categorical_dim)

        return torch.cat([z, s], dim=1)


    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var, q = self.encode(input)
        z = self.reparameterize(mu, log_var, q)
        return  [self.decode(z), input, q, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        r"""
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args:
        :param kwargs:
        :return:
        """
        recons = args[0]
        input = args[1]
        q = args[2]
        mu = args[3]
        log_var = args[4]

        q_p = F.softmax(q, dim=-1) # Convert the categorical codes into probabilities


        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        batch_idx = kwargs['batch_idx']

        # Anneal the temperature at regular intervals
        if batch_idx % self.anneal_interval == 0 and self.training:
            self.temp = np.maximum(self.temp * np.exp(- self.anneal_rate * batch_idx),
                                   self.min_temp)

        recons_loss = F.mse_loss(recons, input, reduction='mean')

        # Adaptively increase the discrete capacity
        disc_curr = (self.disc_max - self.disc_min) * \
                    self.num_iter/ float(self.disc_iter) + self.disc_min
        disc_curr = min(disc_curr, np.log(self.categorical_dim))

        # KL divergence between the categorical posterior and the uniform prior
        eps = 1e-7

        # Negative entropy of the categorical posterior
        h1 = q_p * torch.log(q_p + eps)
        # Negative cross entropy with the uniform categorical prior
        h2 = q_p * np.log(1. / self.categorical_dim + eps)
        kld_disc_loss = torch.mean(torch.sum(h1 - h2, dim =1), dim=0)

        # Compute Continuous loss
        # Adaptively increase the continuous capacity
        cont_curr = (self.cont_max - self.cont_min) * \
                    self.num_iter/ float(self.cont_iter) + self.cont_min
        cont_curr = min(cont_curr, self.cont_max)

        kld_cont_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(),
                                                    dim=1),
                                   dim=0)
        capacity_loss = self.disc_gamma * torch.abs(disc_curr - kld_disc_loss) + \
                        self.cont_gamma * torch.abs(cont_curr - kld_cont_loss)
        # kld_weight = 1.2
        loss = self.alpha * recons_loss + kld_weight * capacity_loss

        if self.training:
            self.num_iter += 1
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'Capacity_Loss':capacity_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        # [S x D]
        z = torch.randn(num_samples,
                        self.latent_dim)

        M = num_samples
        np_y = np.zeros((M, self.categorical_dim), dtype=np.float32)
        np_y[range(M), np.random.choice(self.categorical_dim, M)] = 1
        np_y = np.reshape(np_y, [M , self.categorical_dim])
        q = torch.from_numpy(np_y)

        # z = self.sampling_dist.sample((num_samples * self.latent_dim, ))
        z = torch.cat([z, q], dim = 1).to(current_device)
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
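
A minimal, torch-free sketch of the Gumbel-softmax sampling step used in `reparameterize`, to show how the temperature controls how close a sample is to one-hot. The function name `gumbel_softmax_sample`, the example logits, and the two temperatures are assumptions for illustration; the noise formula follows the `g = -log(-log(u + eps) + eps)` line above.

```python
import math
import random


def gumbel_softmax_sample(logits, temperature, rng, eps=1e-7):
    """Draw one relaxed one-hot sample from the categorical given by `logits`."""
    # Gumbel(0, 1) noise, one draw per category.
    g = [-math.log(-math.log(rng.random() + eps) + eps) for _ in logits]
    scaled = [(l + gi) / temperature for l, gi in zip(logits, g)]
    # Softmax with the usual max-subtraction for stability.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.0, 0.5, -0.5, 0.0]

# Re-seeding gives both calls the same Gumbel noise, so only the temperature differs.
sharp = gumbel_softmax_sample(logits, temperature=0.1, rng=random.Random(0))
smooth = gumbel_softmax_sample(logits, temperature=5.0, rng=random.Random(0))
```

With identical noise, the low-temperature sample puts at least as much mass on its largest entry as the high-temperature one.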

================================================
FILE: models/logcosh_vae.py
================================================
import torch
import torch.nn.functional as F
from models import BaseVAE
from torch import nn
from .types_ import *


class LogCoshVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 alpha: float = 100.,
                 beta: float = 10.,
                 **kwargs) -> None:
        super(LogCoshVAE, self).__init__()

        self.latent_dim = latent_dim
        self.alpha = alpha
        self.beta = beta

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) Mean and log variance of the latent Gaussian, each [N x D]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: draws a sample from N(mu, var) using
        noise sampled from N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        r"""
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args:
        :param kwargs:
        :return:
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        t = recons - input
        # recons_loss = F.mse_loss(recons, input)
        # cosh = torch.cosh(self.alpha * t)
        # recons_loss = (1./self.alpha * torch.log(cosh)).mean()

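        # log(cosh(a * t)) = a * t + softplus(-2 * a * t) - log(2); softplus avoids
        # the overflow of exp(-2 * a * t) when a * t is large and negative.
        recons_loss = self.alpha * t + \
                      F.softplus(-2. * self.alpha * t) - \
                      torch.log(torch.tensor(2.0))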
        # print(self.alpha* t.max(), self.alpha*t.min())
        recons_loss = (1. / self.alpha) * recons_loss.mean()

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss + self.beta * kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
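
A short, torch-free check of the identity behind the reconstruction term, `log(cosh(x)) = x + log(1 + exp(-2x)) - log(2)`, written with a stable softplus so it does not overflow for large `|x|`. The helper names `softplus` and `log_cosh` are hypothetical and only illustrate the arithmetic.

```python
import math


def softplus(x):
    """Stable log(1 + exp(x)); the exponent passed to exp is never positive."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def log_cosh(x):
    """Stable log(cosh(x)) = x + softplus(-2x) - log(2)."""
    return x + softplus(-2.0 * x) - math.log(2.0)


small = log_cosh(0.5)        # agrees with math.log(math.cosh(0.5))
large_neg = log_cosh(-1000.0)  # about 1000 - log(2); the naive exp(2000) would overflow
```

For small residuals the loss behaves like `x**2 / 2`, and for large residuals like `|x| - log(2)`.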

================================================
FILE: models/lvae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
from math import floor, pi, log


def conv_out_shape(img_size):
    return floor((img_size + 2 - 3) / 2.) + 1

class EncoderBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 latent_dim: int,
                 img_size: int):
        super(EncoderBlock, self).__init__()

        # Build Encoder
        self.encoder = nn.Sequential(
                            nn.Conv2d(in_channels,
                                      out_channels,
                                      kernel_size=3, stride=2, padding=1),
                            nn.BatchNorm2d(out_channels),
                            nn.LeakyReLU())

        out_size = conv_out_shape(img_size)
        self.encoder_mu = nn.Linear(out_channels * out_size ** 2 , latent_dim)
        self.encoder_var = nn.Linear(out_channels * out_size ** 2, latent_dim)

    def forward(self, input: Tensor) -> List[Tensor]:
        result = self.encoder(input)
        h = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.encoder_mu(h)
        log_var = self.encoder_var(h)

        return [result, mu, log_var]

class LadderBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 latent_dim: int):
        super(LadderBlock, self).__init__()

        # Build Decoder
        self.decode = nn.Sequential(nn.Linear(in_channels, latent_dim),
                                    nn.BatchNorm1d(latent_dim))
        self.fc_mu = nn.Linear(latent_dim, latent_dim)
        self.fc_var = nn.Linear(latent_dim, latent_dim)

    def forward(self, z: Tensor) -> List[Tensor]:
        z = self.decode(z)
        mu = self.fc_mu(z)
        log_var = self.fc_var(z)

        return [mu, log_var]

class LVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dims: List,
                 hidden_dims: List,
                 **kwargs) -> None:
        super(LVAE, self).__init__()

        self.latent_dims = latent_dims
        self.hidden_dims = hidden_dims
        self.num_rungs = len(latent_dims)

        assert len(latent_dims) == len(hidden_dims), "Length of the latent " \
                                                     "and hidden dims must be the same"

        # Build Encoder
        modules = []
        img_size = 64
        for i, h_dim in enumerate(hidden_dims):
            modules.append(EncoderBlock(in_channels,
                                        h_dim,
                                        latent_dims[i],
                                        img_size))

            img_size = conv_out_shape(img_size)
            in_channels = h_dim

        self.encoders = nn.Sequential(*modules)
        # ====================================================================== #
        # Build Decoder
        modules = []

        for i in range(self.num_rungs -1, 0, -1):
            modules.append(LadderBlock(latent_dims[i],
                                       latent_dims[i-1]))

        self.ladders = nn.Sequential(*modules)

        self.decoder_input = nn.Linear(latent_dims[0], hidden_dims[-1] * 4)

        hidden_dims.reverse()
        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())
        hidden_dims.reverse()

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List) Posterior parameters (mu, log_var), one tuple per rung
        """
        h = input

        # Posterior Parameters
        post_params = []
        for encoder_block in self.encoders:
            h, mu, log_var = encoder_block(h)
            post_params.append((mu, log_var))

        return post_params

    def decode(self, z: Tensor, post_params: List) -> Tuple:
        """
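        Runs the top-down pass through the ladder blocks and maps the
        final latent code onto the image space.
        :param z: (Tensor) Latent code sampled at the top rung [B x D]
        :param post_params: (List) (mu, log_var) tuples of the remaining rungs, bottom-up
        :return: (Tuple) Reconstruction [B x C x H x W] and accumulated KL divergence [B]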
        """
        kl_div = 0
        post_params.reverse()
        for i, ladder_block in enumerate(self.ladders):
            mu_e, log_var_e = post_params[i]
            mu_t, log_var_t = ladder_block(z)
            mu, log_var = self.merge_gauss(mu_e, mu_t,
                                           log_var_e, log_var_t)
            z = self.reparameterize(mu, log_var)
            kl_div += self.compute_kl_divergence(z, (mu, log_var), (mu_e, log_var_e))

        result = self.decoder_input(z)
        result = result.view(-1, self.hidden_dims[-1], 2, 2)
        result = self.decoder(result)
        return self.final_layer(result), kl_div

    def merge_gauss(self,
                    mu_1: Tensor,
                    mu_2: Tensor,
                    log_var_1: Tensor,
                    log_var_2: Tensor) -> List:

        p_1 = 1. / (log_var_1.exp() + 1e-7)
        p_2 = 1. / (log_var_2.exp() + 1e-7)

        mu = (mu_1 * p_1 + mu_2 * p_2)/(p_1 + p_2)
        log_var = torch.log(1./(p_1 + p_2))
        return [mu, log_var]

    def compute_kl_divergence(self, z: Tensor, q_params: Tuple, p_params: Tuple):
        mu_q, log_var_q = q_params
        mu_p, log_var_p = p_params
        #
        # qz = -0.5 * torch.sum(1 + log_var_q + (z - mu_q) ** 2 / (2 * log_var_q.exp() + 1e-8), dim=1)
        # pz = -0.5 * torch.sum(1 + log_var_p + (z - mu_p) ** 2 / (2 * log_var_p.exp() + 1e-8), dim=1)

        # KL(q || p) for diagonal Gaussians; log(sigma_p / sigma_q) = 0.5 * (log_var_p - log_var_q)
        kl = 0.5 * (log_var_p - log_var_q) + (log_var_q.exp() + (mu_q - mu_p)**2)/(2 * log_var_p.exp()) - 0.5
        kl = torch.sum(kl, dim = -1)
        return kl

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        post_params = self.encode(input)
        mu, log_var = post_params.pop()
        z = self.reparameterize(mu, log_var)
        recons, kl_div = self.decode(z, post_params)

        #kl_div += -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1)
        return [recons, input, kl_div]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
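        r"""
        Computes the LVAE loss: reconstruction error plus the weighted KL
        divergence accumulated over the ladder rungs in decode().
        KL(N(\mu_q, \sigma_q), N(\mu_p, \sigma_p)) = \log \frac{\sigma_p}{\sigma_q} + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2 \sigma_p^2} - \frac{1}{2}
        :param args: recons, input and kl_div, in that order
        :param kwargs: must contain 'M_N', the minibatch weight of the KL term
        :return: (dict) Total loss, reconstruction loss and KLD
        """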
        recons = args[0]
        input = args[1]
        kl_div = args[2]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        kld_loss = torch.mean(kl_div, dim = 0)
        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss }

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dims[-1])

        z = z.to(current_device)

        for ladder_block in self.ladders:
            mu, log_var = ladder_block(z)
            z = self.reparameterize(mu, log_var)

        result = self.decoder_input(z)
        result = result.view(-1, self.hidden_dims[-1], 2, 2)
        result = self.decoder(result)
        samples = self.final_layer(result)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
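
# Illustrative sketch, not part of the model: merge_gauss combines two
# Gaussians by precision weighting, and compute_kl_divergence evaluates a
# closed-form KL divergence between two Gaussians. The scalar helpers below
# are hypothetical names; they take variances rather than log-variance
# tensors and omit the 1e-7 stabiliser, purely to show the arithmetic.

```python
import math


def merge_gauss_scalar(mu_1: float, var_1: float, mu_2: float, var_2: float):
    # Precisions (inverse variances) add; the mean is the precision-weighted average.
    p_1, p_2 = 1.0 / var_1, 1.0 / var_2
    var = 1.0 / (p_1 + p_2)
    mu = (mu_1 * p_1 + mu_2 * p_2) * var
    return mu, var


def kl_gauss_scalar(mu_q: float, var_q: float, mu_p: float, var_p: float) -> float:
    # KL(q || p) = log(sigma_p / sigma_q) + (var_q + (mu_q - mu_p)^2) / (2 var_p) - 1/2,
    # where log(sigma_p / sigma_q) = 0.5 * log(var_p / var_q).
    return 0.5 * math.log(var_p / var_q) + (var_q + (mu_q - mu_p) ** 2) / (2.0 * var_p) - 0.5


if __name__ == "__main__":
    mu, var = merge_gauss_scalar(0.0, 1.0, 2.0, 1.0)
    print(mu, var, kl_gauss_scalar(mu, var, 0.0, 1.0))
```

# Two equally confident Gaussians merge to the midpoint of their means with
# half the variance, and the KL divergence of a Gaussian from itself is zero.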

================================================
FILE: models/miwae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
from torch.distributions import Normal


class MIWAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 num_samples: int = 5,
                 num_estimates: int = 5,
                 **kwargs) -> None:
        super(MIWAE, self).__init__()

        self.latent_dim = latent_dim
        self.num_samples = num_samples # K
        self.num_estimates = num_estimates # M

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes of M estimates with S samples each
        onto the image space.
        :param z: (Tensor) [B x M x S x D]
        :return: (Tensor) [B x M x S x C x H x W]
        """
        B, M, S, D = z.size()
        z = z.contiguous().view(-1, self.latent_dim) #[BMS x D]
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result) #[BMS x C x H x W ]
        result = result.view([B, M, S,result.size(-3), result.size(-2), result.size(-1)]) #[B x M x S x C x H x W]
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
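        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x M x S x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x M x S x D]
        :return: (Tensor) [B x M x S x D]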
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        mu = mu.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]
        log_var = log_var.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3) # [B x M x S x D]
        z = self.reparameterize(mu, log_var) # [B x M x S x D]
        eps = (z - mu) / torch.exp(0.5 * log_var) # Prior samples: invert z = eps * std + mu
        return  [self.decode(z), input, mu, log_var, z, eps]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
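        r"""
        Computes the MIWAE loss from M estimates of S importance samples each.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: recons, input, mu, log_var, z and eps, in that order
        :param kwargs: must contain 'M_N', the minibatch weight of the KL term
        :return: (dict) Total loss, reconstruction loss and KLD
        """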
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        z = args[4]
        eps = args[5]

        input = input.repeat(self.num_estimates,
                             self.num_samples, 1, 1, 1, 1).permute(2, 0, 1, 3, 4, 5) #[B x M x S x C x H x W]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset

        log_p_x_z = ((recons - input) ** 2).flatten(3).mean(-1) # Reconstruction Loss # [B x M x S]

        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=3) # [B x M x S]
        # Get importance weights
        log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data

        # Rescale the weights (along the sample dim) to lie in [0, 1] and sum to 1
        weight = F.softmax(log_weight, dim = -1)  # [B x M x S]

        loss = torch.mean(torch.mean(torch.sum(weight * log_weight, dim=-1), dim = -2), dim = 0)

        return {'loss': loss, 'Reconstruction_Loss':log_p_x_z.mean(), 'KLD':-kld_loss.mean()}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples, 1, 1,
                        self.latent_dim)

        z = z.to(current_device)

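        # Index the singleton M and S dims explicitly; squeeze() would also
        # drop the batch dim when num_samples == 1.
        samples = self.decode(z)[:, 0, 0]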
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image.
        Returns only the first reconstructed sample
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0][:, 0, 0, :]
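
# Illustrative sketch, not part of the model: loss_function normalises the
# per-sample terms with a softmax along the sample dimension and then takes
# the weighted sum of those same terms. The list-based helpers below are
# hypothetical names that mirror only that arithmetic for a single row of S
# values, using the standard library.

```python
import math
from typing import List


def softmax(values: List[float]) -> List[float]:
    # Subtract the maximum first so that exp() cannot overflow.
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_objective(log_weights: List[float]) -> float:
    # Weights lie in [0, 1] and sum to 1 along the sample dimension.
    weights = softmax(log_weights)
    return sum(w * lw for w, lw in zip(weights, log_weights))


if __name__ == "__main__":
    print(weighted_objective([0.2, 0.4, 1.3]))
```

# With equal terms the weights are uniform and the result equals the common
# value; otherwise the larger terms dominate the weighted sum.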


================================================
FILE: models/mssim_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
from math import exp


class MSSIMVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 window_size: int = 11,
                 size_average: bool = True,
                 **kwargs) -> None:
        super(MSSIMVAE, self).__init__()

        self.latent_dim = latent_dim
        self.in_channels = in_channels

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

        self.mssim_loss = MSSIM(self.in_channels,
                                window_size,
                                size_average)

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args: Any,
                      **kwargs) -> dict:
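        r"""
        Computes the VAE loss with an MS-SSIM reconstruction term.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: recons, input, mu and log_var, in that order
        :param kwargs: must contain 'M_N', the minibatch weight of the KL term
        :return: (dict) Total loss, reconstruction loss and KLD
        """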
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss = self.mssim_loss(recons, input)


        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

class MSSIM(nn.Module):

    def __init__(self,
                 in_channels: int = 3,
                 window_size: int=11,
                 size_average:bool = True) -> None:
        """
        Computes the differentiable MS-SSIM loss
        Reference:
        [1] https://github.com/jorge-pessoa/pytorch-msssim/blob/dev/pytorch_msssim/__init__.py
            (MIT License)

        :param in_channels: (Int)
        :param window_size: (Int)
        :param size_average: (Bool)
        """
        super(MSSIM, self).__init__()
        self.in_channels = in_channels
        self.window_size = window_size
        self.size_average = size_average

    def gaussian_window(self, window_size:int, sigma: float) -> Tensor:
        # Normalised 1D Gaussian kernel centred on the window; the exponent must be negative
        kernel = torch.tensor([exp(-(x - window_size // 2)**2/(2 * sigma ** 2))
                               for x in range(window_size)])
        return kernel/kernel.sum()

    def create_window(self, window_size, in_channels):
        _1D_window = self.gaussian_window(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(in_channels, 1, window_size, window_size).contiguous()
        return window

    def ssim(self,
             img1: Tensor,
             img2: Tensor,
             window_size: int,
             in_channel: int,
             size_average: bool) -> Tuple[Tensor, Tensor]:

        device = img1.device
        window = self.create_window(window_size, in_channel).to(device)
        mu1 = F.conv2d(img1, window, padding= window_size//2, groups=in_channel)
        mu2 = F.conv2d(img2, window, padding= window_size//2, groups=in_channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(img1 * img1, window, padding = window_size//2, groups=in_channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding = window_size//2, groups=in_channel) - mu2_sq
        sigma12   = F.conv2d(img1 * img2, window, padding = window_size//2, groups=in_channel) - mu1_mu2

        img_range = 1.0 #img1.max() - img1.min() # Dynamic range
        C1 = (0.01 * img_range) ** 2
        C2 = (0.03 * img_range) ** 2

        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # contrast sensitivity

        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

        if size_average:
            ret = ssim_map.mean()
        else:
            ret = ssim_map.mean(1).mean(1).mean(1)
        return ret, cs

    def forward(self, img1: Tensor, img2: Tensor) -> Tensor:
        device = img1.device
        weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
        levels = weights.size()[0]
        mssim = []
        mcs = []

        for _ in range(levels):
            sim, cs = self.ssim(img1, img2,
                                self.window_size,
                                self.in_channels,
                                self.size_average)
            mssim.append(sim)
            mcs.append(cs)

            img1 = F.avg_pool2d(img1, (2, 2))
            img2 = F.avg_pool2d(img2, (2, 2))

        mssim = torch.stack(mssim)
        mcs = torch.stack(mcs)

        # # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
        # if normalize:
        #     mssim = (mssim + 1) / 2
        #     mcs = (mcs + 1) / 2

        pow1 = mcs ** weights
        pow2 = mssim ** weights

        # Product of the contrast terms of the first levels times the SSIM of the last level
        output = torch.prod(pow1[:-1]) * pow2[-1]
        return 1 - output
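
# Illustrative sketch, not part of the model: gaussian_window in MSSIM builds
# a normalised 1D kernel that create_window turns into a 2D window. The
# list-based helper below is a standalone, standard-library version of the
# usual Gaussian form exp(-(x - c)^2 / (2 sigma^2)); it returns a plain list
# rather than a Tensor.

```python
import math
from typing import List


def gaussian_window(window_size: int, sigma: float) -> List[float]:
    # Centre the kernel on the middle tap and normalise it to sum to 1.
    center = window_size // 2
    weights = [math.exp(-((x - center) ** 2) / (2.0 * sigma ** 2))
               for x in range(window_size)]
    total = sum(weights)
    return [w / total for w in weights]


if __name__ == "__main__":
    window = gaussian_window(11, 1.5)
    print(window)
```

# The kernel peaks at the centre tap and decays symmetrically towards the
# edges; with a positive exponent it would instead weight the edges most.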




================================================
FILE: models/swae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from torch import distributions as dist
from .types_ import *


class SWAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 reg_weight: int = 100,
                 wasserstein_deg: float= 2.,
                 num_projections: int = 50,
                 projection_dist: str = 'normal',
                 **kwargs) -> None:
        super(SWAE, self).__init__()

        self.latent_dim = latent_dim
        self.reg_weight = reg_weight
        self.p = wasserstein_deg
        self.num_projections = num_projections
        self.proj_dist = projection_dist

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_z = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )



        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> Tensor:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) Latent codes [N x D]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Map the flattened features to a deterministic latent code
        # (SWAE has no mu / log_var heads)
        z = self.fc_z(result)
        return z

    def decode(self, z: Tensor) -> Tensor:
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        z = self.encode(input)
        return  [self.decode(z), input, z]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        recons = args[0]
        input = args[1]
        z = args[2]

        batch_size = input.size(0)
        bias_corr = batch_size *  (batch_size - 1)
        reg_weight = self.reg_weight / bias_corr

        recons_loss_l2 = F.mse_loss(recons, input)
        recons_loss_l1 = F.l1_loss(recons, input)

        swd_loss = self.compute_swd(z, self.p, reg_weight)

        loss = recons_loss_l2 + recons_loss_l1 + swd_loss
        return {'loss': loss, 'Reconstruction_Loss':(recons_loss_l2 + recons_loss_l1), 'SWD': swd_loss}

    def get_random_projections(self, latent_dim: int, num_samples: int) -> Tensor:
        """
        Returns random samples from latent distribution's (Gaussian)
        unit sphere for projecting the encoded samples and the
        distribution samples.

        :param latent_dim: (Int) Dimensionality of the latent space (D)
        :param num_samples: (Int) Number of samples required (S)
        :return: Random projections from the latent unit sphere
        """
        if self.proj_dist == 'normal':
            rand_samples = torch.randn(num_samples, latent_dim)
        elif self.proj_dist == 'cauchy':
            rand_samples = dist.Cauchy(torch.tensor([0.0]),
                                       torch.tensor([1.0])).sample((num_samples, latent_dim)).squeeze()
        else:
            raise ValueError('Unknown projection distribution.')

        rand_proj = rand_samples / rand_samples.norm(dim=1).view(-1,1)
        return rand_proj # [S x D]


    def compute_swd(self,
                    z: Tensor,
                    p: float,
                    reg_weight: float) -> Tensor:
        """
        Computes the Sliced Wasserstein Distance (SWD) - which consists of
        randomly projecting the encoded and prior vectors and computing
        their Wasserstein distance along those projections.

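        :param z: (Tensor) Latent samples [N x D]
        :param p: (Float) Order p of the Wasserstein distance
        :param reg_weight: (Float) Weight applied to the SWD term
        :return: (Tensor) Weighted sliced Wasserstein distance (scalar)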
        """
        prior_z = torch.randn_like(z) # [N x D]
        device = z.device

        proj_matrix = self.get_random_projections(self.latent_dim,
                                                  num_samples=self.num_projections).transpose(0,1).to(device)

        latent_projections = z.matmul(proj_matrix) # [N x S]
        prior_projections = prior_z.matmul(proj_matrix) # [N x S]

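        # In 1-D, the Wasserstein distance is computed by sorting the two
        # sets of projections across the batch and comparing them element-wise.
        w_dist = torch.sort(latent_projections.t(), dim=1)[0] - \
                 torch.sort(prior_projections.t(), dim=1)[0]
        # Use the absolute difference so that odd or fractional p stays
        # non-negative and free of NaNs (identical to pow(p) for even p).
        w_dist = w_dist.abs().pow(p)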
        return reg_weight * w_dist.mean()

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
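
The sort-and-compare idea described in the `compute_swd` docstring can be illustrated without PyTorch. The following is a minimal plain-Python sketch, not the repository's implementation: the function names `random_unit_vector` and `sliced_wasserstein`, the fixed seed, and the use of Gaussian directions only (no Cauchy option, no `reg_weight` scaling) are assumptions made for illustration.

```python
import math
import random


def random_unit_vector(dim, rng):
    """Draw a direction on the unit sphere by normalising a Gaussian sample."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def sliced_wasserstein(xs, ys, num_projections=50, p=2.0, seed=0):
    """Monte Carlo estimate of the sliced p-Wasserstein cost between two
    equally sized point sets, each given as a list of equal-length float lists."""
    assert len(xs) == len(ys), "both sets need the same number of points"
    rng = random.Random(seed)
    dim = len(xs[0])
    total = 0.0
    for _ in range(num_projections):
        theta = random_unit_vector(dim, rng)
        # Project every point onto the direction, then sort the 1-D values.
        px = sorted(sum(a * t for a, t in zip(x, theta)) for x in xs)
        py = sorted(sum(a * t for a, t in zip(y, theta)) for y in ys)
        # In 1-D, optimal transport pairs the sorted samples in order.
        total += sum(abs(a - b) ** p for a, b in zip(px, py)) / len(px)
    return total / num_projections


if __name__ == "__main__":
    square = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    shifted = [[x + 3.0, y] for x, y in square]
    print(sliced_wasserstein(square, square))   # identical sets give zero cost
    print(sliced_wasserstein(square, shifted))  # a shifted copy gives a positive cost
```

Because each projection is sorted before comparison, the estimate does not depend on the order in which the points of either set are listed.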


================================================
FILE: models/twostage_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class TwoStageVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 hidden_dims2: List = None,
                 **kwargs) -> None:
        super(TwoStageVAE, self).__init__()

        self.latent_dim = latent_dim

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        if hidden_dims2 is None:
            hidden_dims2 = [1024, 1024]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding  = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)


        # Build Decoder
        modules = []
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
        hidden_dims = hidden_dims[::-1]  # reversed copy; leaves the caller's list untouched

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )
        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

        #---------------------- Second VAE ---------------------------#
        encoder2 = []
        in_channels = self.latent_dim
        for h_dim in hidden_dims2:
            encoder2.append(nn.Sequential(
                                nn.Linear(in_channels, h_dim),
                                nn.BatchNorm1d(h_dim),
                                nn.LeakyReLU()))
            in_channels = h_dim
        self.encoder2 = nn.Sequential(*encoder2)
        self.fc_mu2 = nn.Linear(hidden_dims2[-1], self.latent_dim)
        self.fc_var2 = nn.Linear(hidden_dims2[-1], self.latent_dim)

        decoder2 = []
        hidden_dims2 = hidden_dims2[::-1]  # reversed copy; leaves the caller's list untouched

        in_channels = self.latent_dim
        for h_dim in hidden_dims2:
            decoder2.append(nn.Sequential(
                                nn.Linear(in_channels, h_dim),
                                nn.BatchNorm1d(h_dim),
                                nn.LeakyReLU()))
            in_channels = h_dim
        self.decoder2 = nn.Sequential(*decoder2)

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, 512, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)

        return [self.decode(z), input, mu, log_var]

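    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        r"""
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: [recons, input, mu, log_var], as returned by forward()
        :param kwargs: Must contain 'M_N', the weight applied to the KLD term
        :return: (dict) Total loss, reconstruction loss and KLD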
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
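
The closed-form KL term quoted in the `loss_function` docstring and the sampling step in `reparameterize` can be checked by hand on small inputs. The sketch below uses plain Python lists instead of tensors; the function names `kl_to_standard_normal` and `reparameterize_list` are hypothetical and exist only for this illustration.

```python
import math
import random


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions.

    Per dimension this equals log(1/sigma) + (sigma^2 + mu^2)/2 - 1/2,
    rewritten in terms of log_var = log(sigma^2).
    """
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv)
                      for m, lv in zip(mu, log_var))


def reparameterize_list(mu, log_var, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(0.5 * log_var)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]


if __name__ == "__main__":
    rng = random.Random(0)
    # A posterior equal to the prior has zero divergence.
    print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))
    # Unit variance with mean 1: log(1) + (1 + 1)/2 - 1/2 = 0.5
    print(kl_to_standard_normal([1.0], [0.0]))
    print(reparameterize_list([0.5, -0.5], [0.0, 0.0], rng))
```

Working with `log_var` rather than the standard deviation keeps the variance positive for any real-valued network output, which is why the sketch exponentiates before scaling the noise.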

================================================
FILE: models/types_.py
================================================
from typing import List, Callable, Union, Any, TypeVar, Tuple
# from torch import tensor as Tensor

Tensor = TypeVar('torch.tensor')


================================================
FILE: models/vampvae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *


class VampVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 num_components: int = 50,
                 **kwargs) -> None:
        super(VampVAE, self).__init__()

        self.latent_dim = latent_dim
        self.num_components = num_components

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]

        # Build Encoder
        for h_
SYMBOL INDEX (386 symbols across 46 files)

FILE: dataset.py
  class MyDataset (line 15) | class MyDataset(Dataset):
    method __init__ (line 16) | def __init__(self):
    method __len__ (line 20) | def __len__(self):
    method __getitem__ (line 23) | def __getitem__(self, idx):
  class MyCelebA (line 27) | class MyCelebA(CelebA):
    method _check_integrity (line 35) | def _check_integrity(self) -> bool:
  class OxfordPets (line 40) | class OxfordPets(Dataset):
    method __init__ (line 44) | def __init__(self,
    method __len__ (line 55) | def __len__(self):
    method __getitem__ (line 58) | def __getitem__(self, idx):
  class VAEDataset (line 66) | class VAEDataset(LightningDataModule):
    method __init__ (line 81) | def __init__(
    method setup (line 100) | def setup(self, stage: Optional[str] = None) -> None:
    method train_dataloader (line 155) | def train_dataloader(self) -> DataLoader:
    method val_dataloader (line 164) | def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
    method test_dataloader (line 173) | def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:

FILE: experiment.py
  class VAEXperiment (line 15) | class VAEXperiment(pl.LightningModule):
    method __init__ (line 17) | def __init__(self,
    method forward (line 31) | def forward(self, input: Tensor, **kwargs) -> Tensor:
    method training_step (line 34) | def training_step(self, batch, batch_idx, optimizer_idx = 0):
    method validation_step (line 48) | def validation_step(self, batch, batch_idx, optimizer_idx = 0):
    method on_validation_end (line 61) | def on_validation_end(self) -> None:
    method sample_images (line 64) | def sample_images(self):
    method configure_optimizers (line 92) | def configure_optimizers(self):

FILE: models/base.py
  class BaseVAE (line 5) | class BaseVAE(nn.Module):
    method __init__ (line 7) | def __init__(self) -> None:
    method encode (line 10) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 13) | def decode(self, input: Tensor) -> Any:
    method sample (line 16) | def sample(self, batch_size:int, current_device: int, **kwargs) -> Ten...
    method generate (line 19) | def generate(self, x: Tensor, **kwargs) -> Tensor:
    method forward (line 23) | def forward(self, *inputs: Tensor) -> Tensor:
    method loss_function (line 27) | def loss_function(self, *inputs: Any, **kwargs) -> Tensor:

FILE: models/beta_vae.py
  class BetaVAE (line 8) | class BetaVAE(BaseVAE):
    method __init__ (line 12) | def __init__(self,
    method encode (line 88) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 105) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 112) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 124) | def forward(self, input: Tensor, **kwargs) -> Tensor:
    method loss_function (line 129) | def loss_function(self,
    method sample (line 154) | def sample(self,
    method generate (line 172) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/betatc_vae.py
  class BetaTCVAE (line 9) | class BetaTCVAE(BaseVAE):
    method __init__ (line 12) | def __init__(self,
    method encode (line 84) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 102) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 115) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 127) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method log_density_gaussian (line 132) | def log_density_gaussian(self, x: Tensor, mu: Tensor, logvar: Tensor):
    method loss_function (line 144) | def loss_function(self,
    method sample (line 213) | def sample(self,
    method generate (line 231) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/cat_vae.py
  class CategoricalVAE (line 9) | class CategoricalVAE(BaseVAE):
    method __init__ (line 11) | def __init__(self,
    method encode (line 89) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 105) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 118) | def reparameterize(self, z: Tensor, eps:float = 1e-7) -> Tensor:
    method forward (line 134) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 139) | def loss_function(self,
    method sample (line 179) | def sample(self,
    method generate (line 202) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/cvae.py
  class ConditionalVAE (line 8) | class ConditionalVAE(BaseVAE):
    method __init__ (line 10) | def __init__(self,
    method encode (line 83) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 100) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 107) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 119) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 133) | def loss_function(self,
    method sample (line 149) | def sample(self,
    method generate (line 170) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/dfcvae.py
  class DFCVAE (line 9) | class DFCVAE(BaseVAE):
    method __init__ (line 11) | def __init__(self,
    method encode (line 90) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 107) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 120) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 132) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method extract_features (line 142) | def extract_features(self,
    method loss_function (line 163) | def loss_function(self,
    method sample (line 192) | def sample(self,
    method generate (line 210) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/dip_vae.py
  class DIPVAE (line 8) | class DIPVAE(BaseVAE):
    method __init__ (line 10) | def __init__(self,
    method encode (line 78) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 95) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 108) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 120) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 125) | def loss_function(self,
    method sample (line 166) | def sample(self,
    method generate (line 184) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/fvae.py
  class FactorVAE (line 8) | class FactorVAE(BaseVAE):
    method __init__ (line 10) | def __init__(self,
    method encode (line 92) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 109) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 122) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 134) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method permute_latent (line 139) | def permute_latent(self, z: Tensor) -> Tensor:
    method loss_function (line 151) | def loss_function(self,
    method sample (line 203) | def sample(self,
    method generate (line 221) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/gamma_vae.py
  class GammaVAE (line 10) | class GammaVAE(BaseVAE):
    method __init__ (line 12) | def __init__(self,
    method weight_init (line 85) | def weight_init(self):
    method encode (line 92) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 109) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 116) | def reparameterize(self, alpha: Tensor, beta: Tensor) -> Tensor:
    method h_func (line 137) | def h_func(self, alpha: Tensor, eps: Tensor) -> Tensor:
    method inv_h_func (line 148) | def inv_h_func(self, alpha: Tensor, z: Tensor) -> Tensor:
    method forward (line 158) | def forward(self, input: Tensor, **kwargs) -> Tensor:
    method I_function (line 168) | def I_function(self, a, b, c, d):
    method vae_gamma_kl_loss (line 171) | def vae_gamma_kl_loss(self, a, b, c, d):
    method loss_function (line 184) | def loss_function(self,
    method sample (line 214) | def sample(self,
    method generate (line 230) | def generate(self, x: Tensor, **kwargs) -> Tensor:
  function init_ (line 239) | def init_(m):

FILE: models/hvae.py
  class HVAE (line 8) | class HVAE(BaseVAE):
    method __init__ (line 10) | def __init__(self,
    method encode_z2 (line 115) | def encode_z2(self, input: Tensor) -> List[Tensor]:
    method encode_z1 (line 132) | def encode_z1(self, input: Tensor, z2: Tensor) -> List[Tensor]:
    method encode (line 145) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 153) | def decode(self, input: Tensor) -> Tensor:
    method reparameterize (line 158) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 170) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 192) | def loss_function(self,
    method sample (line 233) | def sample(self, batch_size:int, current_device: int, **kwargs) -> Ten...
    method generate (line 252) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/info_vae.py
  class InfoVAE (line 8) | class InfoVAE(BaseVAE):
    method __init__ (line 10) | def __init__(self,
    method encode (line 88) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 104) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 111) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 123) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 128) | def loss_function(self,
    method compute_kernel (line 150) | def compute_kernel(self,
    method compute_rbf (line 178) | def compute_rbf(self,
    method compute_inv_mult_quad (line 195) | def compute_inv_mult_quad(self,
    method compute_mmd (line 218) | def compute_mmd(self, z: Tensor) -> Tensor:
    method sample (line 231) | def sample(self,
    method generate (line 249) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/iwae.py
  class IWAE (line 8) | class IWAE(BaseVAE):
    method __init__ (line 10) | def __init__(self,
    method encode (line 78) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 95) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 111) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 121) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 129) | def loss_function(self,
    method sample (line 162) | def sample(self,
    method generate (line 180) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/joint_vae.py
  class JointVAE (line 9) | class JointVAE(BaseVAE):
    method __init__ (line 12) | def __init__(self,
    method encode (line 111) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 129) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 142) | def reparameterize(self,
    method forward (line 170) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 175) | def loss_function(self,
    method sample (line 236) | def sample(self,
    method generate (line 261) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/logcosh_vae.py
  class LogCoshVAE (line 8) | class LogCoshVAE(BaseVAE):
    method __init__ (line 10) | def __init__(self,
    method encode (line 78) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 95) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 108) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 120) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 125) | def loss_function(self,
    method sample (line 157) | def sample(self,
    method generate (line 175) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/lvae.py
  function conv_out_shape (line 9) | def conv_out_shape(img_size):
  class EncoderBlock (line 12) | class EncoderBlock(nn.Module):
    method __init__ (line 13) | def __init__(self,
    method forward (line 32) | def forward(self, input: Tensor) -> Tensor:
  class LadderBlock (line 43) | class LadderBlock(nn.Module):
    method __init__ (line 44) | def __init__(self,
    method forward (line 55) | def forward(self, z: Tensor) -> Tensor:
  class LVAE (line 62) | class LVAE(BaseVAE):
    method __init__ (line 64) | def __init__(self,
    method encode (line 134) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 151) | def decode(self, z: Tensor, post_params: List) -> Tuple:
    method merge_gauss (line 173) | def merge_gauss(self,
    method compute_kl_divergence (line 186) | def compute_kl_divergence(self, z: Tensor, q_params: Tuple, p_params: ...
    method reparameterize (line 197) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 209) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 218) | def loss_function(self,
    method sample (line 239) | def sample(self,
    method generate (line 264) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/miwae.py
  class MIWAE (line 9) | class MIWAE(BaseVAE):
    method __init__ (line 11) | def __init__(self,
    method encode (line 81) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 98) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 114) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 124) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 132) | def loss_function(self,
    method sample (line 166) | def sample(self,
    method generate (line 184) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/mssim_vae.py
  class MSSIMVAE (line 9) | class MSSIMVAE(BaseVAE):
    method __init__ (line 11) | def __init__(self,
    method encode (line 84) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 101) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 114) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 126) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 131) | def loss_function(self,
    method sample (line 155) | def sample(self,
    method generate (line 173) | def generate(self, x: Tensor, **kwargs) -> Tensor:
  class MSSIM (line 182) | class MSSIM(nn.Module):
    method __init__ (line 184) | def __init__(self,
    method gaussian_window (line 203) | def gaussian_window(self, window_size:int, sigma: float) -> Tensor:
    method create_window (line 208) | def create_window(self, window_size, in_channels):
    method ssim (line 214) | def ssim(self,
    method forward (line 250) | def forward(self, img1: Tensor, img2: Tensor) -> Tensor:

FILE: models/swae.py
  class SWAE (line 9) | class SWAE(BaseVAE):
    method __init__ (line 11) | def __init__(self,
    method encode (line 84) | def encode(self, input: Tensor) -> Tensor:
    method decode (line 99) | def decode(self, z: Tensor) -> Tensor:
    method forward (line 106) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 110) | def loss_function(self,
    method get_random_projections (line 129) | def get_random_projections(self, latent_dim: int, num_samples: int) ->...
    method compute_swd (line 151) | def compute_swd(self,
    method sample (line 181) | def sample(self,
    method generate (line 199) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/twostage_vae.py
  class TwoStageVAE (line 8) | class TwoStageVAE(BaseVAE):
    method __init__ (line 10) | def __init__(self,
    method encode (line 100) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 117) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 130) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 142) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 148) | def loss_function(self,
    method sample (line 172) | def sample(self,
    method generate (line 190) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/vampvae.py
  class VampVAE (line 8) | class VampVAE(BaseVAE):
    method __init__ (line 10) | def __init__(self,
    method encode (line 82) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 99) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 106) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 118) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 123) | def loss_function(self,
    method sample (line 170) | def sample(self,
    method generate (line 188) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/vanilla_vae.py
  class VanillaVAE (line 8) | class VanillaVAE(BaseVAE):
    method __init__ (line 11) | def __init__(self,
    method encode (line 77) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 94) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 107) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 119) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 124) | def loss_function(self,
    method sample (line 148) | def sample(self,
    method generate (line 166) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/vq_vae.py
  class VectorQuantizer (line 7) | class VectorQuantizer(nn.Module):
    method __init__ (line 12) | def __init__(self,
    method forward (line 24) | def forward(self, latents: Tensor) -> Tensor:
  class ResidualLayer (line 57) | class ResidualLayer(nn.Module):
    method __init__ (line 59) | def __init__(self,
    method forward (line 69) | def forward(self, input: Tensor) -> Tensor:
  class VQVAE (line 73) | class VQVAE(BaseVAE):
    method __init__ (line 75) | def __init__(self,
    method encode (line 168) | def encode(self, input: Tensor) -> List[Tensor]:
    method decode (line 178) | def decode(self, z: Tensor) -> Tensor:
    method forward (line 189) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 194) | def loss_function(self,
    method sample (line 213) | def sample(self,
    method generate (line 218) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: models/wae_mmd.py
  class WAE_MMD (line 8) | class WAE_MMD(BaseVAE):
    method __init__ (line 10) | def __init__(self,
    method encode (line 81) | def encode(self, input: Tensor) -> Tensor:
    method decode (line 96) | def decode(self, z: Tensor) -> Tensor:
    method forward (line 103) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    method loss_function (line 107) | def loss_function(self,
    method compute_kernel (line 125) | def compute_kernel(self,
    method compute_rbf (line 153) | def compute_rbf(self,
    method compute_inv_mult_quad (line 170) | def compute_inv_mult_quad(self,
    method compute_mmd (line 193) | def compute_mmd(self, z: Tensor, reg_weight: float) -> Tensor:
    method sample (line 206) | def sample(self,
    method generate (line 224) | def generate(self, x: Tensor, **kwargs) -> Tensor:

FILE: tests/bvae.py
  class TestVAE (line 7) | class TestVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 23) | def test_loss(self):

FILE: tests/test_betatcvae.py
  class TestBetaTCVAE (line 7) | class TestBetaTCVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 24) | def test_loss(self):
    method test_sample (line 31) | def test_sample(self):
    method test_generate (line 36) | def test_generate(self):

FILE: tests/test_cat_vae.py
  class TestVAE (line 7) | class TestVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 23) | def test_loss(self):
    method test_sample (line 31) | def test_sample(self):

FILE: tests/test_dfc.py
  class TestDFCVAE (line 7) | class TestDFCVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 18) | def test_forward(self):
    method test_loss (line 25) | def test_loss(self):
    method test_sample (line 32) | def test_sample(self):

FILE: tests/test_dipvae.py
  class TestDIPVAE (line 7) | class TestDIPVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 24) | def test_loss(self):
    method test_sample (line 31) | def test_sample(self):
    method test_generate (line 36) | def test_generate(self):

FILE: tests/test_fvae.py
  class TestFAE (line 7) | class TestFAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 20) | def test_forward(self):
    method test_loss (line 27) | def test_loss(self):
    method test_optim (line 36) | def test_optim(self):
    method test_sample (line 40) | def test_sample(self):

FILE: tests/test_gvae.py
  class TestGammaVAE (line 7) | class TestGammaVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 18) | def test_forward(self):
    method test_loss (line 25) | def test_loss(self):
    method test_sample (line 32) | def test_sample(self):

FILE: tests/test_hvae.py
  class TestHVAE (line 7) | class TestHVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 23) | def test_loss(self):

FILE: tests/test_iwae.py
  class TestIWAE (line 7) | class TestIWAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 23) | def test_loss(self):
    method test_sample (line 30) | def test_sample(self):

FILE: tests/test_joint_Vae.py
  class TestVAE (line 7) | class TestVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 23) | def test_loss(self):
    method test_sample (line 31) | def test_sample(self):

FILE: tests/test_logcosh.py
  class TestVAE (line 7) | class TestVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 23) | def test_loss(self):

FILE: tests/test_lvae.py
  class TestLVAE (line 7) | class TestLVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 24) | def test_loss(self):
    method test_sample (line 31) | def test_sample(self):

FILE: tests/test_miwae.py
  class TestMIWAE (line 7) | class TestMIWAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 23) | def test_loss(self):
    method test_sample (line 30) | def test_sample(self):
    method test_generate (line 35) | def test_generate(self):

FILE: tests/test_mssimvae.py
  class TestMSSIMVAE (line 7) | class TestMSSIMVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 18) | def test_forward(self):
    method test_loss (line 25) | def test_loss(self):
    method test_sample (line 32) | def test_sample(self):

FILE: tests/test_swae.py
  class TestSWAE (line 7) | class TestSWAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 12) | def test_summary(self):
    method test_forward (line 16) | def test_forward(self):
    method test_loss (line 22) | def test_loss(self):

FILE: tests/test_vae.py
  class TestVAE (line 7) | class TestVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 23) | def test_loss(self):

FILE: tests/test_vq_vae.py
  class TestVQVAE (line 7) | class TestVQVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 24) | def test_loss(self):
    method test_sample (line 31) | def test_sample(self):
    method test_generate (line 36) | def test_generate(self):

FILE: tests/test_wae.py
  class TestWAE (line 7) | class TestWAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 12) | def test_summary(self):
    method test_forward (line 16) | def test_forward(self):
    method test_loss (line 22) | def test_loss(self):

FILE: tests/text_cvae.py
  class TestCVAE (line 6) | class TestCVAE(unittest.TestCase):
    method setUp (line 8) | def setUp(self) -> None:
    method test_forward (line 12) | def test_forward(self):
    method test_loss (line 19) | def test_loss(self):

FILE: tests/text_vamp.py
  class TestVVAE (line 7) | class TestVVAE(unittest.TestCase):
    method setUp (line 9) | def setUp(self) -> None:
    method test_summary (line 13) | def test_summary(self):
    method test_forward (line 17) | def test_forward(self):
    method test_loss (line 23) | def test_loss(self):

FILE: utils.py
  function data_loader (line 8) | def data_loader(fn):

Condensed preview: 82 files, each showing path, character count, and a content snippet (the full structured content is 254K chars).
[
  {
    "path": ".gitignore",
    "chars": 81,
    "preview": "\nData/\nlogs/\n\nVanillaVAE/version_0/\n\n__pycache__/\n.ipynb_checkpoints/\n\nRun.ipynb\n"
  },
  {
    "path": ".idea/.gitignore",
    "chars": 39,
    "preview": "# Default ignored files\n/workspace.xml\n"
  },
  {
    "path": ".idea/PyTorch-VAE.iml",
    "chars": 518,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n  <component name=\"NewModuleRootManager"
  },
  {
    "path": ".idea/inspectionProfiles/profiles_settings.xml",
    "chars": 174,
    "preview": "<component name=\"InspectionProjectProfileManager\">\n  <settings>\n    <option name=\"USE_PROJECT_PROFILE\" value=\"false\" />\n"
  },
  {
    "path": ".idea/misc.xml",
    "chars": 299,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectRootManager\" version=\"2\" project-"
  },
  {
    "path": ".idea/modules.xml",
    "chars": 453,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectModuleManager\">\n    <modules>\n   "
  },
  {
    "path": ".idea/vcs.xml",
    "chars": 247,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"VcsDirectoryMappings\">\n    <mapping dire"
  },
  {
    "path": "LICENSE.md",
    "chars": 11455,
    "preview": "\n                                 Apache License\n                           Version 2.0, January 2004\n                  "
  },
  {
    "path": "README.md",
    "chars": 14526,
    "preview": "<h1 align=\"center\">\n  <b>PyTorch VAE</b><br>\n</h1>\n\n<p align=\"center\">\n      <a href=\"https://www.python.org/\">\n        "
  },
  {
    "path": "configs/bbvae.yaml",
    "chars": 492,
    "preview": "model_params:\n  name: 'BetaVAE'\n  in_channels: 3\n  latent_dim: 128\n  loss_type: 'B'\n  gamma: 10.0\n  max_capacity: 25\n  C"
  },
  {
    "path": "configs/betatc_vae.yaml",
    "chars": 455,
    "preview": "model_params:\n  name: 'BetaTCVAE'\n  in_channels: 3\n  latent_dim: 10\n  anneal_steps: 10000\n  alpha: 1.\n  beta:  6.\n  gamm"
  },
  {
    "path": "configs/bhvae.yaml",
    "chars": 423,
    "preview": "model_params:\n  name: 'BetaVAE'\n  in_channels: 3\n  latent_dim: 128\n  loss_type: 'H'\n  beta: 10.\n\ndata_params:\n  data_pat"
  },
  {
    "path": "configs/cat_vae.yaml",
    "chars": 508,
    "preview": "model_params:\n  name: 'CategoricalVAE'\n  in_channels: 3\n  latent_dim: 512\n  categorical_dim: 40\n  temperature: 0.5\n  ann"
  },
  {
    "path": "configs/cvae.yaml",
    "chars": 425,
    "preview": "model_params:\n  name: 'ConditionalVAE'\n  in_channels: 3\n  num_classes: 40\n  latent_dim: 128\n\ndata_params:\n  data_path: \""
  },
  {
    "path": "configs/dfc_vae.yaml",
    "chars": 392,
    "preview": "model_params:\n  name: 'DFCVAE'\n  in_channels: 3\n  latent_dim: 128\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_size:"
  },
  {
    "path": "configs/dip_vae.yaml",
    "chars": 449,
    "preview": "model_params:\n  name: 'DIPVAE'\n  in_channels: 3\n  latent_dim: 128\n  lambda_diag: 0.05\n  lambda_offdiag: 0.1\n\n\ndata_param"
  },
  {
    "path": "configs/factorvae.yaml",
    "chars": 515,
    "preview": "model_params:\n  name: 'FactorVAE'\n  in_channels: 3\n  latent_dim: 128\n  gamma: 6.4\n\ndata_params:\n  data_path: \"Data/\"\n  t"
  },
  {
    "path": "configs/gammavae.yaml",
    "chars": 479,
    "preview": "model_params:\n  name: 'GammaVAE'\n  in_channels: 3\n  latent_dim: 128\n  gamma_shape: 8.\n  prior_shape: 2.\n  prior_rate: 1."
  },
  {
    "path": "configs/hvae.yaml",
    "chars": 434,
    "preview": "model_params:\n  name: 'HVAE'\n  in_channels: 3\n  latent1_dim: 64\n  latent2_dim: 64\n  pseudo_input_size: 128\n\ndata_params:"
  },
  {
    "path": "configs/infovae.yaml",
    "chars": 569,
    "preview": "model_params:\n  name: 'InfoVAE'\n  in_channels: 3\n  latent_dim: 128\n  reg_weight: 110  # MMD weight\n  kernel_type: 'imq'\n"
  },
  {
    "path": "configs/iwae.yaml",
    "chars": 405,
    "preview": "model_params:\n  name: 'IWAE'\n  in_channels: 3\n  latent_dim: 128\n  num_samples: 5\n\ndata_params:\n  data_path: \"Data/\"\n  tr"
  },
  {
    "path": "configs/joint_vae.yaml",
    "chars": 718,
    "preview": "model_params:\n  name: 'JointVAE'\n  in_channels: 3\n  latent_dim: 512\n  categorical_dim: 40\n  latent_min_capacity: 0.0\n  l"
  },
  {
    "path": "configs/logcosh_vae.yaml",
    "chars": 427,
    "preview": "model_params:\n  name: 'LogCoshVAE'\n  in_channels: 3\n  latent_dim: 128\n  alpha: 10.0\n  beta: 1.0\n\ndata_params:\n  data_pat"
  },
  {
    "path": "configs/lvae.yaml",
    "chars": 439,
    "preview": "model_params:\n  name: 'LVAE'\n  in_channels: 3\n  latent_dims: [4,8,16,32,128]\n  hidden_dims: [32, 64,128, 256, 512]\n\ndata"
  },
  {
    "path": "configs/miwae.yaml",
    "chars": 427,
    "preview": "model_params:\n  name: 'MIWAE'\n  in_channels: 3\n  latent_dim: 128\n  num_samples: 5\n  num_estimates: 3\n\ndata_params:\n  dat"
  },
  {
    "path": "configs/mssim_vae.yaml",
    "chars": 396,
    "preview": "model_params:\n  name: 'MSSIMVAE'\n  in_channels: 3\n  latent_dim: 128\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_siz"
  },
  {
    "path": "configs/swae.yaml",
    "chars": 495,
    "preview": "model_params:\n  name: 'SWAE'\n  in_channels: 3\n  latent_dim: 128\n  reg_weight: 100\n  wasserstein_deg: 2.0\n  num_projectio"
  },
  {
    "path": "configs/vae.yaml",
    "chars": 405,
    "preview": "model_params:\n  name: 'VanillaVAE'\n  in_channels: 3\n  latent_dim: 128\n\n\ndata_params:\n  data_path: \"Data/\"\n  train_batch_"
  },
  {
    "path": "configs/vampvae.yaml",
    "chars": 393,
    "preview": "model_params:\n  name: 'VampVAE'\n  in_channels: 3\n  latent_dim: 128\n\nexp_params:\n  dataset: celeba\n  data_path: \"../../sh"
  },
  {
    "path": "configs/vq_vae.yaml",
    "chars": 441,
    "preview": "model_params:\n  name: 'VQVAE'\n  in_channels: 3\n  embedding_dim: 64\n  num_embeddings: 512\n  img_size: 64\n  beta: 0.25\n\nda"
  },
  {
    "path": "configs/wae_mmd_imq.yaml",
    "chars": 449,
    "preview": "model_params:\n  name: 'WAE_MMD'\n  in_channels: 3\n  latent_dim: 128\n  reg_weight: 100\n  kernel_type: 'imq'\n\ndata_params:\n"
  },
  {
    "path": "configs/wae_mmd_rbf.yaml",
    "chars": 450,
    "preview": "model_params:\n  name: 'WAE_MMD'\n  in_channels: 3\n  latent_dim: 128\n  reg_weight: 5000\n  kernel_type: 'rbf'\n\ndata_params:"
  },
  {
    "path": "dataset.py",
    "chars": 6315,
    "preview": "import os\nimport torch\nfrom torch import Tensor\nfrom pathlib import Path\nfrom typing import List, Optional, Sequence, Un"
  },
  {
    "path": "experiment.py",
    "chars": 4997,
    "preview": "import os\nimport math\nimport torch\nfrom torch import optim\nfrom models import BaseVAE\nfrom models.types_ import *\nfrom u"
  },
  {
    "path": "models/__init__.py",
    "chars": 1356,
    "preview": "from .base import *\nfrom .vanilla_vae import *\nfrom .gamma_vae import *\nfrom .beta_vae import *\nfrom .wae_mmd import *\nf"
  },
  {
    "path": "models/base.py",
    "chars": 733,
    "preview": "from .types_ import *\nfrom torch import nn\nfrom abc import abstractmethod\n\nclass BaseVAE(nn.Module):\n    \n    def __init"
  },
  {
    "path": "models/beta_vae.py",
    "chars": 6242,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/betatc_vae.py",
    "chars": 8558,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/cat_vae.py",
    "chars": 7531,
    "preview": "import torch\nimport numpy as np\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfro"
  },
  {
    "path": "models/cvae.py",
    "chars": 6079,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/dfcvae.py",
    "chars": 7315,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torchvision.models import vgg19_bn\nfrom torch.nn impor"
  },
  {
    "path": "models/dip_vae.py",
    "chars": 6597,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/fvae.py",
    "chars": 8251,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/gamma_vae.py",
    "chars": 8650,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.distributions import Gamma\nfrom torch.nn import "
  },
  {
    "path": "models/hvae.py",
    "chars": 9396,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/info_vae.py",
    "chars": 8538,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/iwae.py",
    "chars": 6694,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/joint_vae.py",
    "chars": 9837,
    "preview": "import torch\nimport numpy as np\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfro"
  },
  {
    "path": "models/logcosh_vae.py",
    "chars": 6292,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom models import BaseVAE\nfrom torch import nn\nfrom .types_ import *\n\n\ncla"
  },
  {
    "path": "models/lvae.py",
    "chars": 9666,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/miwae.py",
    "chars": 6969,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/mssim_vae.py",
    "chars": 9644,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/swae.py",
    "chars": 7340,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch import dist"
  },
  {
    "path": "models/twostage_vae.py",
    "chars": 6867,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/types_.py",
    "chars": 133,
    "preview": "from typing import List, Callable, Union, Any, TypeVar, Tuple\n# from torch import tensor as Tensor\n\nTensor = TypeVar('to"
  },
  {
    "path": "models/vampvae.py",
    "chars": 6760,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/vanilla_vae.py",
    "chars": 5757,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/vq_vae.py",
    "chars": 7576,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "models/wae_mmd.py",
    "chars": 7427,
    "preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
  },
  {
    "path": "requirements.txt",
    "chars": 108,
    "preview": "pytorch-lightning==1.5.6\nPyYAML==6.0\ntensorboard>=2.2.0\ntorch>=1.6.1\ntorchsummary==1.5.1\ntorchvision>=0.10.1"
  },
  {
    "path": "run.py",
    "chars": 2216,
    "preview": "import os\nimport yaml\nimport argparse\nimport numpy as np\nfrom pathlib import Path\nfrom models import *\nfrom experiment i"
  },
  {
    "path": "tests/bvae.py",
    "chars": 846,
    "preview": "import torch\nimport unittest\nfrom models import BetaVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.TestCa"
  },
  {
    "path": "tests/test_betatcvae.py",
    "chars": 1173,
    "preview": "import torch\nimport unittest\nfrom models import BetaTCVAE\nfrom torchsummary import summary\n\n\nclass TestBetaTCVAE(unittes"
  },
  {
    "path": "tests/test_cat_vae.py",
    "chars": 951,
    "preview": "import torch\nimport unittest\nfrom models import GumbelVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.Test"
  },
  {
    "path": "tests/test_dfc.py",
    "chars": 914,
    "preview": "import torch\nimport unittest\nfrom models import DFCVAE\nfrom torchsummary import summary\n\n\nclass TestDFCVAE(unittest.Test"
  },
  {
    "path": "tests/test_dipvae.py",
    "chars": 1145,
    "preview": "import torch\nimport unittest\nfrom models import DIPVAE\nfrom torchsummary import summary\n\n\nclass TestDIPVAE(unittest.Test"
  },
  {
    "path": "tests/test_fvae.py",
    "chars": 1368,
    "preview": "import torch\nimport unittest\nfrom models import FactorVAE\nfrom torchsummary import summary\n\n\nclass TestFAE(unittest.Test"
  },
  {
    "path": "tests/test_gvae.py",
    "chars": 920,
    "preview": "import torch\nimport unittest\nfrom models import GammaVAE\nfrom torchsummary import summary\n\n\nclass TestGammaVAE(unittest."
  },
  {
    "path": "tests/test_hvae.py",
    "chars": 840,
    "preview": "import torch\nimport unittest\nfrom models import HVAE\nfrom torchsummary import summary\n\n\nclass TestHVAE(unittest.TestCase"
  },
  {
    "path": "tests/test_iwae.py",
    "chars": 904,
    "preview": "import torch\nimport unittest\nfrom models import IWAE\nfrom torchsummary import summary\n\n\nclass TestIWAE(unittest.TestCase"
  },
  {
    "path": "tests/test_joint_Vae.py",
    "chars": 958,
    "preview": "import torch\nimport unittest\nfrom models import JointVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.TestC"
  },
  {
    "path": "tests/test_logcosh.py",
    "chars": 832,
    "preview": "import torch\nimport unittest\nfrom models import LogCoshVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.Tes"
  },
  {
    "path": "tests/test_lvae.py",
    "chars": 977,
    "preview": "import torch\nimport unittest\nfrom models import LVAE\nfrom torchsummary import summary\n\n\nclass TestLVAE(unittest.TestCase"
  },
  {
    "path": "tests/test_miwae.py",
    "chars": 1057,
    "preview": "import torch\nimport unittest\nfrom models import MIWAE\nfrom torchsummary import summary\n\n\nclass TestMIWAE(unittest.TestCa"
  },
  {
    "path": "tests/test_mssimvae.py",
    "chars": 920,
    "preview": "import torch\nimport unittest\nfrom models import MSSIMVAE\nfrom torchsummary import summary\n\n\nclass TestMSSIMVAE(unittest."
  },
  {
    "path": "tests/test_swae.py",
    "chars": 782,
    "preview": "import torch\nimport unittest\nfrom models import SWAE\nfrom torchsummary import summary\n\n\nclass TestSWAE(unittest.TestCase"
  },
  {
    "path": "tests/test_vae.py",
    "chars": 823,
    "preview": "import torch\nimport unittest\nfrom models import VanillaVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.Tes"
  },
  {
    "path": "tests/test_vq_vae.py",
    "chars": 1147,
    "preview": "import torch\nimport unittest\nfrom models import VQVAE\nfrom torchsummary import summary\n\n\nclass TestVQVAE(unittest.TestCa"
  },
  {
    "path": "tests/test_wae.py",
    "chars": 787,
    "preview": "import torch\nimport unittest\nfrom models import WAE_MMD\nfrom torchsummary import summary\n\n\nclass TestWAE(unittest.TestCa"
  },
  {
    "path": "tests/text_cvae.py",
    "chars": 705,
    "preview": "import torch\nimport unittest\nfrom models import CVAE\n\n\nclass TestCVAE(unittest.TestCase):\n\n    def setUp(self) -> None:\n"
  },
  {
    "path": "tests/text_vamp.py",
    "chars": 844,
    "preview": "import torch\nimport unittest\nfrom models import VampVAE\nfrom torchsummary import summary\n\n\nclass TestVVAE(unittest.TestC"
  },
  {
    "path": "utils.py",
    "chars": 622,
    "preview": "import pytorch_lightning as pl\n\n\n## Utils to handle newer PyTorch Lightning changes from version 0.6\n## ================"
  }
]
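
The condensed preview above is a JSON array of objects with `path`, `chars`, and `preview` keys. As a minimal sketch of how such a listing might be consumed, the snippet below totals character counts per top-level directory. It assumes the array is available as a string; the three sample entries reuse paths and character counts from the listing with abridged previews, and the helper name `chars_by_top_level` is hypothetical, not part of the repository.

```python
import json
from collections import defaultdict

# Three entries taken from the listing above (previews abridged);
# a real run would load the full array instead.
SAMPLE = """
[
  {"path": "models/base.py", "chars": 733, "preview": "from .types_ import *"},
  {"path": "models/types_.py", "chars": 133, "preview": "from typing import List"},
  {"path": "utils.py", "chars": 622, "preview": "import pytorch_lightning as pl"}
]
"""


def chars_by_top_level(entries):
    """Sum the `chars` field per top-level directory ('.' for root-level files)."""
    totals = defaultdict(int)
    for entry in entries:
        head, sep, _ = entry["path"].partition("/")
        # Paths without a slash are files in the repository root.
        totals[head if sep else "."] += entry["chars"]
    return dict(totals)


totals = chars_by_top_level(json.loads(SAMPLE))
print(totals)
```

Grouping on the first path component keeps the sketch independent of how deep the directory tree is; a finer breakdown could split on more components instead.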

About this extraction

This page contains the full source code of the AntixK/PyTorch-VAE GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 82 files (236.1 KB), approximately 61.2k tokens, and a symbol index with 386 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
