Showing preview only (254K chars total). Download the full file or copy to clipboard to get everything.
Repository: AntixK/PyTorch-VAE
Branch: master
Commit: a6896b944c91
Files: 82
Total size: 236.1 KB
Directory structure:
gitextract_vk1fb46d/
├── .gitignore
├── .idea/
│ ├── .gitignore
│ ├── PyTorch-VAE.iml
│ ├── inspectionProfiles/
│ │ └── profiles_settings.xml
│ ├── misc.xml
│ ├── modules.xml
│ └── vcs.xml
├── LICENSE.md
├── README.md
├── configs/
│ ├── bbvae.yaml
│ ├── betatc_vae.yaml
│ ├── bhvae.yaml
│ ├── cat_vae.yaml
│ ├── cvae.yaml
│ ├── dfc_vae.yaml
│ ├── dip_vae.yaml
│ ├── factorvae.yaml
│ ├── gammavae.yaml
│ ├── hvae.yaml
│ ├── infovae.yaml
│ ├── iwae.yaml
│ ├── joint_vae.yaml
│ ├── logcosh_vae.yaml
│ ├── lvae.yaml
│ ├── miwae.yaml
│ ├── mssim_vae.yaml
│ ├── swae.yaml
│ ├── vae.yaml
│ ├── vampvae.yaml
│ ├── vq_vae.yaml
│ ├── wae_mmd_imq.yaml
│ └── wae_mmd_rbf.yaml
├── dataset.py
├── experiment.py
├── models/
│ ├── __init__.py
│ ├── base.py
│ ├── beta_vae.py
│ ├── betatc_vae.py
│ ├── cat_vae.py
│ ├── cvae.py
│ ├── dfcvae.py
│ ├── dip_vae.py
│ ├── fvae.py
│ ├── gamma_vae.py
│ ├── hvae.py
│ ├── info_vae.py
│ ├── iwae.py
│ ├── joint_vae.py
│ ├── logcosh_vae.py
│ ├── lvae.py
│ ├── miwae.py
│ ├── mssim_vae.py
│ ├── swae.py
│ ├── twostage_vae.py
│ ├── types_.py
│ ├── vampvae.py
│ ├── vanilla_vae.py
│ ├── vq_vae.py
│ └── wae_mmd.py
├── requirements.txt
├── run.py
├── tests/
│ ├── bvae.py
│ ├── test_betatcvae.py
│ ├── test_cat_vae.py
│ ├── test_dfc.py
│ ├── test_dipvae.py
│ ├── test_fvae.py
│ ├── test_gvae.py
│ ├── test_hvae.py
│ ├── test_iwae.py
│ ├── test_joint_Vae.py
│ ├── test_logcosh.py
│ ├── test_lvae.py
│ ├── test_miwae.py
│ ├── test_mssimvae.py
│ ├── test_swae.py
│ ├── test_vae.py
│ ├── test_vq_vae.py
│ ├── test_wae.py
│ ├── text_cvae.py
│ └── text_vamp.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
Data/
logs/
VanillaVAE/version_0/
__pycache__/
.ipynb_checkpoints/
Run.ipynb
================================================
FILE: .idea/.gitignore
================================================
# Default ignored files
/workspace.xml
================================================
FILE: .idea/PyTorch-VAE.iml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.7 (main)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="module" module-name="somic_research" />
</component>
<component name="ReSTService">
<option name="DOC_DIR" value="$MODULE_DIR$/../Project_S/somic_research/docs" />
</component>
</module>
================================================
FILE: .idea/inspectionProfiles/profiles_settings.xml
================================================
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
================================================
FILE: .idea/misc.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (main)" project-jdk-type="Python SDK" />
<component name="PyCharmProfessionalAdvertiser">
<option name="shown" value="true" />
</component>
</project>
================================================
FILE: .idea/modules.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/PyTorch-VAE.iml" filepath="$PROJECT_DIR$/.idea/PyTorch-VAE.iml" />
<module fileurl="file://$PROJECT_DIR$/../Project_S/somic_research/.idea/somic_research.iml" filepath="$PROJECT_DIR$/../Project_S/somic_research/.idea/somic_research.iml" />
</modules>
</component>
</project>
================================================
FILE: .idea/vcs.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/../Project_S/somic_research" vcs="Git" />
</component>
</project>
================================================
FILE: LICENSE.md
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
Copyright Anand Krishnamoorthy Subramanian 2020
anandkrish894@gmail.com
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<h1 align="center">
<b>PyTorch VAE</b><br>
</h1>
<p align="center">
<a href="https://www.python.org/">
<img src="https://img.shields.io/badge/Python-3.5-ff69b4.svg" /></a>
<a href= "https://pytorch.org/">
<img src="https://img.shields.io/badge/PyTorch-1.3-2BAF2B.svg" /></a>
<a href= "https://github.com/AntixK/PyTorch-VAE/blob/master/LICENSE.md">
<img src="https://img.shields.io/badge/license-Apache2.0-blue.svg" /></a>
<a href= "https://twitter.com/intent/tweet?text=PyTorch-VAE:%20Collection%20of%20VAE%20models%20in%20PyTorch.&url=https://github.com/AntixK/PyTorch-VAE">
<img src="https://img.shields.io/twitter/url/https/shields.io.svg?style=social" /></a>
</p>
**Update 22/12/2021:** Added support for PyTorch Lightning version 1.5.6 and cleaned up the code.
A collection of Variational AutoEncoders (VAEs) implemented in PyTorch with a focus on reproducibility. The aim of this project is to provide
a quick and simple working example for many of the cool VAE models out there. All the models are trained on the [CelebA dataset](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
for consistency and comparison. The architectures of all the models are kept as similar as possible, with the same layers, except where the original paper necessitates
a radically different architecture (e.g., VQ-VAE uses residual layers and no batch normalization, unlike the other models).
Here are the [results](https://github.com/AntixK/PyTorch-VAE/blob/master/README.md#--results) of each model.
### Requirements
- Python >= 3.5
- PyTorch >= 1.3
- PyTorch Lightning >= 0.6.0 ([GitHub Repo](https://github.com/PyTorchLightning/pytorch-lightning/tree/deb1581e26b7547baf876b7a94361e60bb200d32))
- CUDA enabled computing device
### Installation
```
$ git clone https://github.com/AntixK/PyTorch-VAE
$ cd PyTorch-VAE
$ pip install -r requirements.txt
```
### Usage
```
$ cd PyTorch-VAE
$ python run.py -c configs/<config-file-name.yaml>
```
**Config file template**
```yaml
model_params:
  name: "<name of VAE model>"
  in_channels: 3
  latent_dim:
    .         # Other parameters required by the model
    .
    .
data_params:
  data_path: "<path to the celebA dataset>"
  train_batch_size: 64 # Better to have a square number
  val_batch_size: 64
  patch_size: 64 # Models are designed to work for this size
  num_workers: 4
exp_params:
  manual_seed: 1265
  LR: 0.005
  weight_decay:
    .         # Other arguments required for training, like scheduler etc.
    .
    .
trainer_params:
  gpus: 1
  max_epochs: 100
  gradient_clip_val: 1.5
    .
    .
    .
logging_params:
  save_dir: "logs/"
  name: "<experiment name>"
```
**View TensorBoard Logs**
```
$ cd logs/<experiment name>/version_<the version you want>
$ tensorboard --logdir .
```
**Note:** The default dataset is CelebA. However, there have been many issues with downloading the dataset from Google Drive (owing to some file structure changes). So, the recommendation is to download the [file](https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing) from Google Drive directly and extract it to the path of your choice. The default path assumed in the config files is `Data/celeba/img_align_celeba`, but you can change it according to your preference.
----
<h2 align="center">
<b>Results</b><br>
</h2>
| Model | Paper |Reconstruction | Samples |
|------------------------------------------------------------------------|--------------------------------------------------|---------------|---------|
| VAE ([Code][vae_code], [Config][vae_config]) |[Link](https://arxiv.org/abs/1312.6114) | ![][2] | ![][1] |
| Conditional VAE ([Code][cvae_code], [Config][cvae_config]) |[Link](https://openreview.net/forum?id=rJWXGDWd-H)| ![][16] | ![][15] |
| WAE - MMD (RBF Kernel) ([Code][wae_code], [Config][wae_rbf_config]) |[Link](https://arxiv.org/abs/1711.01558) | ![][4] | ![][3] |
| WAE - MMD (IMQ Kernel) ([Code][wae_code], [Config][wae_imq_config]) |[Link](https://arxiv.org/abs/1711.01558) | ![][6] | ![][5] |
| Beta-VAE ([Code][bvae_code], [Config][bbvae_config]) |[Link](https://openreview.net/forum?id=Sy2fzU9gl) | ![][8] | ![][7] |
| Disentangled Beta-VAE ([Code][bvae_code], [Config][bhvae_config]) |[Link](https://arxiv.org/abs/1804.03599) | ![][22] | ![][21] |
| Beta-TC-VAE ([Code][btcvae_code], [Config][btcvae_config]) |[Link](https://arxiv.org/abs/1802.04942) | ![][34] | ![][33] |
| IWAE (*K = 5*) ([Code][iwae_code], [Config][iwae_config]) |[Link](https://arxiv.org/abs/1509.00519) | ![][10] | ![][9] |
| MIWAE (*K = 5, M = 3*) ([Code][miwae_code], [Config][miwae_config]) |[Link](https://arxiv.org/abs/1802.04537) | ![][30] | ![][29] |
| DFCVAE ([Code][dfcvae_code], [Config][dfcvae_config]) |[Link](https://arxiv.org/abs/1610.00291) | ![][12] | ![][11] |
| MSSIM VAE ([Code][mssimvae_code], [Config][mssimvae_config]) |[Link](https://arxiv.org/abs/1511.06409) | ![][14] | ![][13] |
| Categorical VAE ([Code][catvae_code], [Config][catvae_config]) |[Link](https://arxiv.org/abs/1611.01144) | ![][18] | ![][17] |
| Joint VAE ([Code][jointvae_code], [Config][jointvae_config]) |[Link](https://arxiv.org/abs/1804.00104) | ![][20] | ![][19] |
| Info VAE ([Code][infovae_code], [Config][infovae_config]) |[Link](https://arxiv.org/abs/1706.02262) | ![][24] | ![][23] |
| LogCosh VAE ([Code][logcoshvae_code], [Config][logcoshvae_config]) |[Link](https://openreview.net/forum?id=rkglvsC9Ym)| ![][26] | ![][25] |
| SWAE (200 Projections) ([Code][swae_code], [Config][swae_config]) |[Link](https://arxiv.org/abs/1804.01947) | ![][28] | ![][27] |
| VQ-VAE (*K = 512, D = 64*) ([Code][vqvae_code], [Config][vqvae_config])|[Link](https://arxiv.org/abs/1711.00937) | ![][31] | **N/A** |
| DIP VAE ([Code][dipvae_code], [Config][dipvae_config]) |[Link](https://arxiv.org/abs/1711.00848) | ![][36] | ![][35] |
<!-- | Gamma VAE |[Link](https://arxiv.org/abs/1610.05683) | ![][16] | ![][15] |-->
<!--
### TODO
- [x] VanillaVAE
- [x] Beta VAE
- [x] DFC VAE
- [x] MSSIM VAE
- [x] IWAE
- [x] MIWAE
- [x] WAE-MMD
- [x] Conditional VAE
- [ ] PixelVAE
- [x] Categorical VAE (Gumbel-Softmax VAE)
- [x] Joint VAE
- [x] Disentangled beta-VAE
- [x] InfoVAE
- [x] LogCosh VAE
- [x] SWAE
- [x] VQVAE
- [x] Beta TC-VAE
- [x] DIP VAE
- [ ] Ladder VAE (Doesn't work well)
- [ ] Gamma VAE (Doesn't work well)
- [ ] Vamp VAE (Doesn't work well)
-->
### Contributing
If you have trained a better model using these implementations (e.g., by fine-tuning the hyperparameters in the config file),
I would be happy to include your result (along with your config file) in this repo, citing your name 😊.
Additionally, if you would like to contribute some models, please submit a PR.
### License
**Apache License 2.0**
| Permissions | Limitations | Conditions |
|------------------|-------------------|----------------------------------|
| ✔️ Commercial use | ❌ Trademark use | ⓘ License and copyright notice |
| ✔️ Modification | ❌ Liability | ⓘ State changes |
| ✔️ Distribution | ❌ Warranty | |
| ✔️ Patent use | | |
| ✔️ Private use | | |
### Citation
```
@misc{Subramanian2020,
author = {Subramanian, A.K},
title = {PyTorch-VAE},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/AntixK/PyTorch-VAE}}
}
```
-----------
[vae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py
[cvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/cvae.py
[bvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py
[btcvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/betatc_vae.py
[wae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/wae_mmd.py
[iwae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/iwae.py
[miwae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/miwae.py
[swae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/swae.py
[jointvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/joint_vae.py
[dfcvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/dfcvae.py
[mssimvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/mssim_vae.py
[logcoshvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/logcosh_vae.py
[catvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/cat_vae.py
[infovae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/info_vae.py
[vqvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
[dipvae_code]: https://github.com/AntixK/PyTorch-VAE/blob/master/models/dip_vae.py
[vae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vae.yaml
[cvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cvae.yaml
[bbvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/bbvae.yaml
[bhvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/bhvae.yaml
[btcvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/betatc_vae.yaml
[wae_rbf_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_rbf.yaml
[wae_imq_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/wae_mmd_imq.yaml
[iwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/iwae.yaml
[miwae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/miwae.yaml
[swae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/swae.yaml
[jointvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/joint_vae.yaml
[dfcvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dfc_vae.yaml
[mssimvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/mssim_vae.yaml
[logcoshvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/logcosh_vae.yaml
[catvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/cat_vae.yaml
[infovae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/infovae.yaml
[vqvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/vq_vae.yaml
[dipvae_config]: https://github.com/AntixK/PyTorch-VAE/blob/master/configs/dip_vae.yaml
[1]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/Vanilla%20VAE_25.png
[2]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_Vanilla%20VAE_25.png
[3]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/WAE_RBF_18.png
[4]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_WAE_RBF_19.png
[5]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/WAE_IMQ_15.png
[6]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_WAE_IMQ_15.png
[7]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaVAE_H_20.png
[8]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaVAE_H_20.png
[9]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/IWAE_19.png
[10]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_IWAE_19.png
[11]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/DFCVAE_49.png
[12]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_DFCVAE_49.png
[13]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MSSIMVAE_29.png
[14]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MSSIMVAE_29.png
[15]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/ConditionalVAE_20.png
[16]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_ConditionalVAE_20.png
[17]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/CategoricalVAE_49.png
[18]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_CategoricalVAE_49.png
[19]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/JointVAE_49.png
[20]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_JointVAE_49.png
[21]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaVAE_B_35.png
[22]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaVAE_B_35.png
[23]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/InfoVAE_31.png
[24]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_InfoVAE_31.png
[25]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/LogCoshVAE_49.png
[26]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_LogCoshVAE_49.png
[27]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/SWAE_49.png
[28]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_SWAE_49.png
[29]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/MIWAE_29.png
[30]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_MIWAE_29.png
[31]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_VQVAE_29.png
[33]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/BetaTCVAE_49.png
[34]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_BetaTCVAE_49.png
[35]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/DIPVAE_83.png
[36]: https://github.com/AntixK/PyTorch-VAE/blob/master/assets/recons_DIPVAE_83.png
[python-image]: https://img.shields.io/badge/Python-3.5-ff69b4.svg
[python-url]: https://www.python.org/
[pytorch-image]: https://img.shields.io/badge/PyTorch-1.3-2BAF2B.svg
[pytorch-url]: https://pytorch.org/
[twitter-image]:https://img.shields.io/twitter/url/https/shields.io.svg?style=social
[twitter-url]:https://twitter.com/intent/tweet?text=PyTorch-VAE:%20Collection%20of%20VAE%20models%20in%20PyTorch.&url=https://github.com/AntixK/PyTorch-VAE
[license-image]:https://img.shields.io/badge/license-Apache2.0-blue.svg
[license-url]:https://github.com/AntixK/PyTorch-VAE/blob/master/LICENSE.md
================================================
FILE: configs/bbvae.yaml
================================================
model_params:
name: 'BetaVAE'
in_channels: 3
latent_dim: 128
loss_type: 'B'
gamma: 10.0
max_capacity: 25
Capacity_max_iter: 10000
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
manual_seed: 1265
name: 'BetaVAE'
================================================
FILE: configs/betatc_vae.yaml
================================================
model_params:
name: 'BetaTCVAE'
in_channels: 3
latent_dim: 10
anneal_steps: 10000
alpha: 1.
beta: 6.
gamma: 1.
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: 'BetaTCVAE'
================================================
FILE: configs/bhvae.yaml
================================================
model_params:
name: 'BetaVAE'
in_channels: 3
latent_dim: 128
loss_type: 'H'
beta: 10.
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: 'BetaVAE'
================================================
FILE: configs/cat_vae.yaml
================================================
model_params:
name: 'CategoricalVAE'
in_channels: 3
latent_dim: 512
categorical_dim: 40
temperature: 0.5
anneal_rate: 0.00003
anneal_interval: 100
alpha: 1.0
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "CategoricalVAE"
================================================
FILE: configs/cvae.yaml
================================================
model_params:
name: 'ConditionalVAE'
in_channels: 3
num_classes: 40
latent_dim: 128
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "ConditionalVAE"
================================================
FILE: configs/dfc_vae.yaml
================================================
model_params:
name: 'DFCVAE'
in_channels: 3
latent_dim: 128
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "DFCVAE"
================================================
FILE: configs/dip_vae.yaml
================================================
model_params:
name: 'DIPVAE'
in_channels: 3
latent_dim: 128
lambda_diag: 0.05
lambda_offdiag: 0.1
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.001
weight_decay: 0.0
scheduler_gamma: 0.97
kld_weight: 1
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "DIPVAE"
manual_seed: 1265
================================================
FILE: configs/factorvae.yaml
================================================
model_params:
name: 'FactorVAE'
in_channels: 3
latent_dim: 128
gamma: 6.4
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
submodel: 'discriminator'
retain_first_backpass: True
LR: 0.005
weight_decay: 0.0
LR_2: 0.005
scheduler_gamma_2: 0.95
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "FactorVAE"
================================================
FILE: configs/gammavae.yaml
================================================
model_params:
name: 'GammaVAE'
in_channels: 3
latent_dim: 128
gamma_shape: 8.
prior_shape: 2.
prior_rate: 1.
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.003
weight_decay: 0.00005
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
gradient_clip_val: 0.8
logging_params:
save_dir: "logs/"
name: "GammaVAE"
================================================
FILE: configs/hvae.yaml
================================================
model_params:
name: 'HVAE'
in_channels: 3
latent1_dim: 64
latent2_dim: 64
pseudo_input_size: 128
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "HVAE"
================================================
FILE: configs/infovae.yaml
================================================
model_params:
name: 'InfoVAE'
in_channels: 3
latent_dim: 128
reg_weight: 110 # MMD weight
kernel_type: 'imq'
alpha: -9.0 # KLD weight
beta: 10.5 # Reconstruction weight
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
gradient_clip_val: 0.8
logging_params:
save_dir: "logs/"
name: "InfoVAE"
manual_seed: 1265
================================================
FILE: configs/iwae.yaml
================================================
model_params:
name: 'IWAE'
in_channels: 3
latent_dim: 128
num_samples: 5
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.007
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "IWAE"
================================================
FILE: configs/joint_vae.yaml
================================================
model_params:
name: 'JointVAE'
in_channels: 3
latent_dim: 512
categorical_dim: 40
latent_min_capacity: 0.0
latent_max_capacity: 20.0
latent_gamma: 10.
latent_num_iter: 25000
categorical_min_capacity: 0.0
categorical_max_capacity: 20.0
categorical_gamma: 10.
categorical_num_iter: 25000
temperature: 0.5
anneal_rate: 0.00003
anneal_interval: 100
alpha: 10.0
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "JointVAE"
================================================
FILE: configs/logcosh_vae.yaml
================================================
model_params:
name: 'LogCoshVAE'
in_channels: 3
latent_dim: 128
alpha: 10.0
beta: 1.0
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.97
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "LogCoshVAE"
================================================
FILE: configs/lvae.yaml
================================================
model_params:
name: 'LVAE'
in_channels: 3
latent_dims: [4,8,16,32,128]
hidden_dims: [32, 64,128, 256, 512]
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "LVAE"
================================================
FILE: configs/miwae.yaml
================================================
model_params:
name: 'MIWAE'
in_channels: 3
latent_dim: 128
num_samples: 5
num_estimates: 3
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "MIWAE"
================================================
FILE: configs/mssim_vae.yaml
================================================
model_params:
name: 'MSSIMVAE'
in_channels: 3
latent_dim: 128
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "MSSIMVAE"
================================================
FILE: configs/swae.yaml
================================================
model_params:
name: 'SWAE'
in_channels: 3
latent_dim: 128
reg_weight: 100
wasserstein_deg: 2.0
num_projections: 200
projection_dist: "normal" #"cauchy"
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "SWAE"
================================================
FILE: configs/vae.yaml
================================================
model_params:
name: 'VanillaVAE'
in_channels: 3
latent_dim: 128
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 100
logging_params:
save_dir: "logs/"
name: "VanillaVAE"
================================================
FILE: configs/vampvae.yaml
================================================
model_params:
name: 'VampVAE'
in_channels: 3
latent_dim: 128
exp_params:
dataset: celeba
data_path: "../../shared/Data/"
img_size: 64
batch_size: 144 # Better to have a square number
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
trainer_params:
gpus: 1
max_nb_epochs: 50
max_epochs: 50
logging_params:
save_dir: "logs/"
name: "VampVAE"
manual_seed: 1265
================================================
FILE: configs/vq_vae.yaml
================================================
model_params:
name: 'VQVAE'
in_channels: 3
embedding_dim: 64
num_embeddings: 512
img_size: 64
beta: 0.25
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.0
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: 'VQVAE'
================================================
FILE: configs/wae_mmd_imq.yaml
================================================
model_params:
name: 'WAE_MMD'
in_channels: 3
latent_dim: 128
reg_weight: 100
kernel_type: 'imq'
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "WassersteinVAE_IMQ"
================================================
FILE: configs/wae_mmd_rbf.yaml
================================================
model_params:
name: 'WAE_MMD'
in_channels: 3
latent_dim: 128
reg_weight: 5000
kernel_type: 'rbf'
data_params:
data_path: "Data/"
train_batch_size: 64
val_batch_size: 64
patch_size: 64
num_workers: 4
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 10
logging_params:
save_dir: "logs/"
name: "WassersteinVAE_RBF"
================================================
FILE: dataset.py
================================================
import os
import torch
from torch import Tensor
from pathlib import Path
from typing import List, Optional, Sequence, Union, Any, Callable
from torchvision.datasets.folder import default_loader
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import CelebA
import zipfile
# Add your custom dataset class here
class MyDataset(Dataset):
    """Template for a user-defined dataset.

    All three methods are intentional no-op stubs that return ``None``;
    replace them with real loading logic before use.
    """

    def __init__(self):
        # Nothing to set up in the template.
        return None

    def __len__(self):
        # Stub: reports no length yet.
        return None

    def __getitem__(self, idx):
        # Stub: produces no item yet.
        return None
class MyCelebA(CelebA):
    """CelebA dataset whose integrity check always succeeds.

    A work-around to address issues with pytorch's celebA dataset class.
    Download and Extract
    URL : https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing
    """

    def _check_integrity(self) -> bool:
        # Always report the files as valid. Presumably this lets a manually
        # downloaded/extracted copy be used without torchvision's check.
        files_ok = True
        return files_ok
class OxfordPets(Dataset):
    """Oxford-IIIT Pet images, split deterministically into train/val subsets.

    URL = https://www.robots.ox.ac.uk/~vgg/data/pets/

    The images are the ``*.jpg`` files directly inside ``<data_path>/OxfordPets``,
    sorted by path. The first ``train_split`` fraction of them forms the
    training subset and the remainder forms the held-out subset.

    Args:
        data_path: directory that contains the ``OxfordPets`` folder.
        split: ``"train"`` selects the training subset; any other value selects
            the held-out subset.
        transform: optional callable applied to every loaded image.
        train_split: fraction of the sorted images used for training
            (default 0.75, the previously hard-coded value).

    Raises:
        ValueError: if ``train_split`` is outside ``[0, 1]``.
    """

    def __init__(self,
                 data_path: str,
                 split: str,
                 transform: Callable,
                 train_split: float = 0.75,
                 **kwargs):
        if not 0.0 <= train_split <= 1.0:
            raise ValueError(f"train_split must be within [0, 1], got {train_split}")

        self.data_dir = Path(data_path) / "OxfordPets"
        self.transforms = transform
        imgs = sorted(f for f in self.data_dir.iterdir() if f.suffix == '.jpg')

        # Compute the boundary once so both subsets agree on it exactly.
        cut = int(len(imgs) * train_split)
        self.imgs = imgs[:cut] if split == "train" else imgs[cut:]

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img = default_loader(self.imgs[idx])

        if self.transforms is not None:
            img = self.transforms(img)

        return img, 0.0  # dummy label keeps the (image, label) batch format
class VAEDataset(LightningDataModule):
    """
    PyTorch Lightning data module

    Args:
        data_path: root directory of your dataset.
        train_batch_size: the batch size to use during training.
        val_batch_size: the batch size to use during validation.
        patch_size: target size passed to ``transforms.Resize``; it is applied
            after a 148-pixel center crop of the original images.
        num_workers: the number of parallel workers to create to load data
            items (see PyTorch's Dataloader documentation for more details).
        pin_memory: whether prepared items should be loaded into pinned memory
            or not. This can improve performance on GPUs.
    """

    def __init__(
        self,
        data_path: str,
        train_batch_size: int = 8,
        val_batch_size: int = 8,
        patch_size: Union[int, Sequence[int]] = (256, 256),
        num_workers: int = 0,
        pin_memory: bool = False,
        **kwargs,
    ):
        super().__init__()

        self.data_dir = data_path
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.patch_size = patch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the train and validation datasets (CelebA by default)."""
        # ========================= OxfordPets Dataset =========================

        # train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
        #                                       transforms.CenterCrop(self.patch_size),
        # #                                       transforms.Resize(self.patch_size),
        #                                       transforms.ToTensor(),
        #                                       transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

        # val_transforms = transforms.Compose([transforms.CenterCrop(self.patch_size),
        # #                                       transforms.Resize(self.patch_size),
        #                                       transforms.ToTensor(),
        #                                       transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

        # self.train_dataset = OxfordPets(
        #     self.data_dir,
        #     split='train',
        #     transform=train_transforms,
        # )

        # self.val_dataset = OxfordPets(
        #     self.data_dir,
        #     split='val',
        #     transform=val_transforms,
        # )

        # ========================= CelebA Dataset =========================

        # Random flips augment the training data only.
        train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
                                               transforms.CenterCrop(148),
                                               transforms.Resize(self.patch_size),
                                               transforms.ToTensor(),])

        # Validation preprocessing is deterministic (no random flip), so that
        # validation losses are comparable between epochs.
        val_transforms = transforms.Compose([transforms.CenterCrop(148),
                                             transforms.Resize(self.patch_size),
                                             transforms.ToTensor(),])

        self.train_dataset = MyCelebA(
            self.data_dir,
            split='train',
            transform=train_transforms,
            download=False,
        )

        # Replace CelebA with your dataset
        self.val_dataset = MyCelebA(
            self.data_dir,
            split='test',
            transform=val_transforms,
            download=False,
        )
        # ===============================================================

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=self.pin_memory,
        )

    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        # 144 images (a 12 x 12 grid), shuffled so that successive calls show
        # different validation images.
        return DataLoader(
            self.val_dataset,
            batch_size=144,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=self.pin_memory,
        )
================================================
FILE: experiment.py
================================================
import os
import math
import torch
from torch import optim
from models import BaseVAE
from models.types_ import *
from utils import data_loader
import pytorch_lightning as pl
from torchvision import transforms
import torchvision.utils as vutils
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader
class VAEXperiment(pl.LightningModule):
    """Lightning module that trains a ``BaseVAE`` and saves image grids.

    :param vae_model: the VAE to optimise. This class calls its ``forward``,
        ``loss_function``, ``generate`` and ``sample`` methods.
    :param params: experiment settings.
        Required keys: ``LR``, ``weight_decay`` and ``kld_weight``.
        Optional keys: ``retain_first_backpass``; ``LR_2`` together with
        ``submodel`` (second optimizer); ``scheduler_gamma`` and
        ``scheduler_gamma_2`` (exponential LR decay).
    """

    def __init__(self,
                 vae_model: BaseVAE,
                 params: dict) -> None:
        super(VAEXperiment, self).__init__()

        self.model = vae_model
        self.params = params
        # Device of the most recent batch; sample_images moves data there.
        self.curr_device = None
        # Optional flag, False unless the config sets it.
        # NOTE(review): it is stored but not read anywhere in this class.
        self.hold_graph = self.params.get('retain_first_backpass', False)

    def forward(self, input: Tensor, **kwargs) -> Tensor:
        return self.model(input, **kwargs)

    def training_step(self, batch, batch_idx, optimizer_idx = 0):
        real_img, labels = batch
        self.curr_device = real_img.device

        results = self.forward(real_img, labels = labels)
        train_loss = self.model.loss_function(*results,
                                              M_N = self.params['kld_weight'], # constant KL weight from the config
                                              optimizer_idx=optimizer_idx,
                                              batch_idx = batch_idx)

        self.log_dict({key: val.item() for key, val in train_loss.items()}, sync_dist=True)

        return train_loss['loss']

    def validation_step(self, batch, batch_idx, optimizer_idx = 0):
        real_img, labels = batch
        self.curr_device = real_img.device

        results = self.forward(real_img, labels = labels)
        val_loss = self.model.loss_function(*results,
                                            M_N = 1.0, # unweighted KL term during validation
                                            optimizer_idx = optimizer_idx,
                                            batch_idx = batch_idx)

        self.log_dict({f"val_{key}": val.item() for key, val in val_loss.items()}, sync_dist=True)

    def on_validation_end(self) -> None:
        self.sample_images()

    def sample_images(self):
        """Save a reconstruction grid and, if supported, a 144-sample grid.

        The files are written under ``self.logger.log_dir`` in the
        ``Reconstructions`` and ``Samples`` sub-directories.
        NOTE(review): nothing here creates those directories; they are
        assumed to exist already.
        """
        # Get sample reconstruction image
        test_input, test_label = next(iter(self.trainer.datamodule.test_dataloader()))
        test_input = test_input.to(self.curr_device)
        test_label = test_label.to(self.curr_device)

        recons = self.model.generate(test_input, labels = test_label)
        vutils.save_image(recons.data,
                          os.path.join(self.logger.log_dir ,
                                       "Reconstructions",
                                       f"recons_{self.logger.name}_Epoch_{self.current_epoch}.png"),
                          normalize=True,
                          nrow=12)

        # Models without a usable sampler can signal it by raising ``Warning``;
        # in that case the sample grid is skipped.
        try:
            samples = self.model.sample(144,
                                        self.curr_device,
                                        labels = test_label)
            vutils.save_image(samples.cpu().data,
                              os.path.join(self.logger.log_dir ,
                                           "Samples",
                                           f"{self.logger.name}_Epoch_{self.current_epoch}.png"),
                              normalize=True,
                              nrow=12)
        except Warning:
            pass

    def configure_optimizers(self):
        """Build the Adam optimizer(s) and the optional exponential LR schedulers.

        :return: the list of optimizers when ``scheduler_gamma`` is absent or
            None; otherwise a tuple ``(optimizers, schedulers)``.
        :raises KeyError: if ``LR_2`` is set but ``submodel`` is missing.
        :raises AttributeError: if the model has no attribute named ``submodel``.
        """
        optims = [optim.Adam(self.model.parameters(),
                             lr=self.params['LR'],
                             weight_decay=self.params['weight_decay'])]

        # A second optimizer over ``self.model.<submodel>`` (used for adversarial
        # training). A misconfiguration now surfaces instead of being swallowed.
        if self.params.get('LR_2') is not None:
            submodel = getattr(self.model, self.params['submodel'])
            optims.append(optim.Adam(submodel.parameters(),
                                     lr=self.params['LR_2']))

        gamma = self.params.get('scheduler_gamma')
        if gamma is None:
            return optims

        scheds = [optim.lr_scheduler.ExponentialLR(optims[0], gamma = gamma)]

        # Schedule the second optimizer only if it exists and a rate is configured.
        gamma_2 = self.params.get('scheduler_gamma_2')
        if gamma_2 is not None and len(optims) > 1:
            scheds.append(optim.lr_scheduler.ExponentialLR(optims[1], gamma = gamma_2))

        return optims, scheds
================================================
FILE: models/__init__.py
================================================
from .base import *
from .vanilla_vae import *
from .gamma_vae import *
from .beta_vae import *
from .wae_mmd import *
from .cvae import *
from .hvae import *
from .vampvae import *
from .iwae import *
from .dfcvae import *
from .mssim_vae import MSSIMVAE
from .fvae import *
from .cat_vae import *
from .joint_vae import *
from .info_vae import *
# from .twostage_vae import *
from .lvae import LVAE
from .logcosh_vae import *
from .swae import *
from .miwae import *
from .vq_vae import *
from .betatc_vae import *
from .dip_vae import *
# Aliases
VAE = GaussianVAE = VanillaVAE
CVAE = ConditionalVAE
GumbelVAE = CategoricalVAE

# Registry of model classes. The keys match the ``name`` values used under
# ``model_params`` in the config files.
vae_models = dict(
    HVAE=HVAE,
    LVAE=LVAE,
    IWAE=IWAE,
    SWAE=SWAE,
    MIWAE=MIWAE,
    VQVAE=VQVAE,
    DFCVAE=DFCVAE,
    DIPVAE=DIPVAE,
    BetaVAE=BetaVAE,
    InfoVAE=InfoVAE,
    WAE_MMD=WAE_MMD,
    VampVAE=VampVAE,
    GammaVAE=GammaVAE,
    MSSIMVAE=MSSIMVAE,
    JointVAE=JointVAE,
    BetaTCVAE=BetaTCVAE,
    FactorVAE=FactorVAE,
    LogCoshVAE=LogCoshVAE,
    VanillaVAE=VanillaVAE,
    ConditionalVAE=ConditionalVAE,
    CategoricalVAE=CategoricalVAE,
)
================================================
FILE: models/base.py
================================================
from .types_ import *
from torch import nn
from abc import abstractmethod
class BaseVAE(nn.Module):
    """Interface shared by the VAE variants in this package.

    The non-abstract hooks raise ``NotImplementedError`` until a subclass
    overrides them. The class does not use ``ABCMeta``, so the
    ``@abstractmethod`` markers only document intent: they do not block
    instantiation.
    """

    def __init__(self) -> None:
        super().__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        """Map an input batch to latent-distribution parameters."""
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        """Map latent codes back to data space."""
        raise NotImplementedError

    def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
        """Draw ``batch_size`` samples on ``current_device``."""
        raise NotImplementedError

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Return the model's reconstruction of ``x``."""
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        """Full forward pass; concrete models override this (returns None here)."""

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        """Training objective; concrete models override this (returns None here)."""
================================================
FILE: models/beta_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
class BetaVAE(BaseVAE):
    """Convolutional beta-VAE supporting two KL-weighting schemes.

    * ``loss_type='H'`` (https://openreview.net/forum?id=Sy2fzU9gl):
      ``recons + beta * M_N * KL``.
    * ``loss_type='B'`` (https://arxiv.org/pdf/1804.03599.pdf):
      ``recons + gamma * M_N * |KL - C|``. The capacity ``C`` grows linearly
      from 0 to ``max_capacity`` over ``Capacity_max_iter`` training iterations.

    The encoder halves the spatial size once per entry of ``hidden_dims``. The
    linear layers assume a 2x2 final feature map (e.g. 64x64 inputs with five
    layers).
    """

    # Class-level default. The first ``self.num_iter += 1`` shadows it with a
    # per-instance counter.
    num_iter = 0

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 beta: int = 4,
                 gamma:float = 1000.,
                 max_capacity: int = 25,
                 Capacity_max_iter: int = 1e5,
                 loss_type:str = 'B',
                 **kwargs) -> None:
        super(BetaVAE, self).__init__()

        self.latent_dim = latent_dim
        self.beta = beta
        self.gamma = gamma
        self.loss_type = loss_type
        self.C_max = torch.Tensor([max_capacity])
        self.C_stop_iter = Capacity_max_iter

        # Work on a copy: the list is reversed below, and that must not leak
        # back into the caller's (possibly shared) list.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)
        # Channel count of the deepest 2x2 feature map; decode() reshapes to it.
        self.last_hidden_dim = hidden_dims[-1]

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 3, stride= 2, padding = 1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes [mu, log_var], each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps latent codes onto the image space.
        :param z: (Tensor) [B x latent_dim]
        :return: (Tensor) [B x 3 x H x W]
        """
        result = self.decoder_input(z)
        # Reshape to the encoder's deepest feature map (channels x 2 x 2).
        result = result.view(-1, self.last_hidden_dim, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: draws z = mu + eps * exp(0.5 * logvar) with
        eps ~ N(0, I), so gradients flow through mu and logvar.
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> Tensor:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the beta-VAE loss.
        :param args: [recons, input, mu, log_var] as returned by forward
        :param kwargs: ``M_N`` - weight applied to the KL term
        :return: dict with 'loss', 'Reconstruction_Loss' and 'KLD'
        :raises ValueError: if ``loss_type`` is neither 'H' nor 'B'
        """
        # Only training iterations advance the capacity schedule; validation
        # batches must not speed it up.
        if self.training:
            self.num_iter += 1
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

        recons_loss =F.mse_loss(recons, input)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        if self.loss_type == 'H': # https://openreview.net/forum?id=Sy2fzU9gl
            loss = recons_loss + self.beta * kld_weight * kld_loss
        elif self.loss_type == 'B': # https://arxiv.org/pdf/1804.03599.pdf
            self.C_max = self.C_max.to(input.device)
            # Linearly increasing capacity, capped at C_max.
            C = torch.clamp(self.C_max/self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
            loss = recons_loss + self.gamma * kld_weight* (kld_loss - C).abs()
        else:
            raise ValueError('Undefined loss type.')

        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
================================================
FILE: models/betatc_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
import math
class BetaTCVAE(BaseVAE):
    """beta-TC-VAE using the minibatch-stratified estimator of the KL decomposition.

    The KL term is split into three parts:
    * index-code mutual information, weighted by ``alpha``;
    * total correlation, weighted by ``beta``;
    * dimension-wise KL, weighted by ``gamma`` and linearly annealed over
      ``anneal_steps`` training iterations.

    The encoder applies one stride-2 conv per entry of ``hidden_dims``. The
    linear layers assume a 4x4 final feature map (e.g. 64x64 inputs with four
    layers).
    """

    # Class-level default. The first ``self.num_iter += 1`` shadows it with a
    # per-instance counter.
    num_iter = 0

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 anneal_steps: int = 200,
                 alpha: float = 1.,
                 beta: float =  6.,
                 gamma: float = 1.,
                 **kwargs) -> None:
        super(BetaTCVAE, self).__init__()

        self.latent_dim = latent_dim
        self.anneal_steps = anneal_steps

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        # Work on a copy: the list is reversed below, and that must not leak
        # back into the caller's list.
        if hidden_dims is None:
            hidden_dims = [32, 32, 32, 32]
        else:
            hidden_dims = list(hidden_dims)
        # Channel count of the deepest 4x4 feature map; decode() reshapes to it.
        self.last_hidden_dim = hidden_dims[-1]

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size= 4, stride= 2, padding  = 1),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)

        self.fc = nn.Linear(hidden_dims[-1]*16, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_var = nn.Linear(256, latent_dim)

        # Build Decoder
        modules = []

        # One 4x4 map with last_hidden_dim channels. For the defaults this is
        # 32 * 16 = 512 features, the same size as the former hard-coded 256 * 2.
        self.decoder_input = nn.Linear(latent_dim, self.last_hidden_dim * 16)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                      kernel_size= 3, padding= 1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes [mu, log_var], each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        result = self.fc(result)
        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, self.last_hidden_dim, 4, 4)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), input, mu, log_var, z]

    def log_density_gaussian(self, x: Tensor, mu: Tensor, logvar: Tensor):
        """
        Computes the element-wise log pdf of the Gaussian with parameters mu
        and logvar at x.
        :param x: (Tensor) Point at which the Gaussian PDF is to be evaluated
        :param mu: (Tensor) Mean of the Gaussian distribution
        :param logvar: (Tensor) Log variance of the Gaussian distribution
        :return: (Tensor) log-density, broadcast over the inputs
        """
        norm = - 0.5 * (math.log(2 * math.pi) + logvar)
        log_density = norm - 0.5 * ((x - mu) ** 2 * torch.exp(-logvar))
        return log_density

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the beta-TC-VAE loss:
        recons/B + alpha * MI + 1 * (beta * TC + anneal * gamma * dim-wise KL).

        :param args: [recons, input, mu, log_var, z] as returned by forward
        :param kwargs: ``M_N`` - its reciprocal times the batch size is used as
            the dataset size of the stratified estimator
        :return: dict with 'loss', 'Reconstruction_Loss' (summed MSE), 'KLD',
            'TC_Loss' and 'MI_Loss'
        :raises ValueError: for batches with fewer than two samples, where the
            stratified importance weights (which divide by batch_size - 1) are
            undefined
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        z = args[4]

        batch_size, latent_dim = z.shape
        if batch_size < 2:
            raise ValueError("BetaTCVAE.loss_function needs a batch of at least 2 samples "
                             f"for the minibatch-stratified estimator, got {batch_size}.")

        weight = 1  # fixed weight of the TC / KL terms (M_N is used for dataset size)

        recons_loss =F.mse_loss(recons, input, reduction='sum')

        log_q_zx = self.log_density_gaussian(z, mu, log_var).sum(dim = 1)

        zeros = torch.zeros_like(z)
        log_p_z = self.log_density_gaussian(z, zeros, zeros).sum(dim = 1)

        # Pairwise densities: entry [i, j, d] is log q(z_i[d] | x_j).
        mat_log_q_z = self.log_density_gaussian(z.view(batch_size, 1, latent_dim),
                                                mu.view(1, batch_size, latent_dim),
                                                log_var.view(1, batch_size, latent_dim))

        # Reference
        # [1] https://github.com/YannDubs/disentangling-vae/blob/535bbd2e9aeb5a200663a4f82f1d34e084c4ba8d/disvae/utils/math.py#L54
        dataset_size = (1 / kwargs['M_N']) * batch_size # dataset size
        strat_weight = (dataset_size - batch_size + 1) / (dataset_size * (batch_size - 1))
        importance_weights = torch.Tensor(batch_size, batch_size).fill_(1 / (batch_size -1)).to(input.device)
        importance_weights.view(-1)[::batch_size] = 1 / dataset_size
        importance_weights.view(-1)[1::batch_size] = strat_weight
        importance_weights[batch_size - 2, 0] = strat_weight
        log_importance_weights = importance_weights.log()

        mat_log_q_z += log_importance_weights.view(batch_size, batch_size, 1)

        log_q_z = torch.logsumexp(mat_log_q_z.sum(2), dim=1, keepdim=False)
        log_prod_q_z = torch.logsumexp(mat_log_q_z, dim=1, keepdim=False).sum(1)

        mi_loss  = (log_q_zx - log_q_z).mean()
        tc_loss = (log_q_z - log_prod_q_z).mean()
        kld_loss = (log_prod_q_z - log_p_z).mean()

        # Only training iterations advance the linear KL annealing.
        if self.training:
            self.num_iter += 1
            anneal_rate = min(0 + 1 * self.num_iter / self.anneal_steps, 1)
        else:
            anneal_rate = 1.

        loss = recons_loss/batch_size + \
               self.alpha * mi_loss + \
               weight * (self.beta * tc_loss +
                         anneal_rate * self.gamma * kld_loss)

        return {'loss': loss,
                'Reconstruction_Loss':recons_loss,
                'KLD':kld_loss,
                'TC_Loss':tc_loss,
                'MI_Loss':mi_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
================================================
FILE: models/cat_vae.py
================================================
import torch
import numpy as np
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
class CategoricalVAE(BaseVAE):
def __init__(self,
             in_channels: int,
             latent_dim: int,
             categorical_dim: int = 40, # Num classes
             hidden_dims: List = None,
             temperature: float = 0.5,
             anneal_rate: float = 3e-5,
             anneal_interval: int = 100, # every 100 batches
             alpha: float = 30.,
             min_temperature: float = None,
             **kwargs) -> None:
    """Build the Gumbel-softmax (categorical) VAE.

    :param in_channels: channels of the input images
    :param latent_dim: number of categorical latent variables (D)
    :param categorical_dim: number of classes per latent variable (Q)
    :param hidden_dims: encoder channel sizes (copied; defaults to
        [32, 64, 128, 256, 512])
    :param temperature: initial Gumbel-softmax temperature
    :param anneal_rate: decay rate used when annealing the temperature
    :param anneal_interval: anneal every this many batches
    :param alpha: weight of the reconstruction term
    :param min_temperature: floor for temperature annealing. It defaults to
        ``temperature``, which keeps the temperature fixed (the previous
        behaviour); pass a smaller value to let it actually decay.
    """
    super(CategoricalVAE, self).__init__()

    self.latent_dim = latent_dim
    self.categorical_dim = categorical_dim
    self.temp = temperature
    self.min_temp = temperature if min_temperature is None else min_temperature
    self.anneal_rate = anneal_rate
    self.anneal_interval = anneal_interval
    self.alpha = alpha

    # Work on a copy: the list is reversed below, and that must not leak back
    # into the caller's list.
    if hidden_dims is None:
        hidden_dims = [32, 64, 128, 256, 512]
    else:
        hidden_dims = list(hidden_dims)

    # Build Encoder
    modules = []
    for h_dim in hidden_dims:
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels=h_dim,
                          kernel_size= 3, stride= 2, padding  = 1),
                nn.BatchNorm2d(h_dim),
                nn.LeakyReLU())
        )
        in_channels = h_dim

    self.encoder = nn.Sequential(*modules)
    # One logit per (latent variable, class) pair; assumes a 2x2 final map.
    self.fc_z = nn.Linear(hidden_dims[-1]*4,
                          self.latent_dim * self.categorical_dim)

    # Build Decoder
    modules = []

    self.decoder_input = nn.Linear(self.latent_dim * self.categorical_dim
                                   , hidden_dims[-1] * 4)

    hidden_dims.reverse()

    for i in range(len(hidden_dims) - 1):
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hidden_dims[i],
                                   hidden_dims[i + 1],
                                   kernel_size=3,
                                   stride = 2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(hidden_dims[i + 1]),
                nn.LeakyReLU())
        )

    self.decoder = nn.Sequential(*modules)

    self.final_layer = nn.Sequential(
                        nn.ConvTranspose2d(hidden_dims[-1],
                                           hidden_dims[-1],
                                           kernel_size=3,
                                           stride=2,
                                           padding=1,
                                           output_padding=1),
                        nn.BatchNorm2d(hidden_dims[-1]),
                        nn.LeakyReLU(),
                        nn.Conv2d(hidden_dims[-1], out_channels= 3,
                                  kernel_size= 3, padding= 1),
                        nn.Tanh())
    # NOTE(review): probs has shape (categorical_dim, 1), so the last
    # (category) axis has size 1. Confirm this is intended wherever
    # sampling_dist is used.
    self.sampling_dist = torch.distributions.OneHotCategorical(1. / categorical_dim * torch.ones((self.categorical_dim, 1)))
def encode(self, input: Tensor) -> List[Tensor]:
    """
    Run the convolutional encoder and project to per-class logits.

    :param input: (Tensor) Input tensor to encoder [B x C x H x W]
    :return: (List[Tensor]) single-element list holding logits [B x D x Q]
    """
    features = torch.flatten(self.encoder(input), start_dim=1)
    logits = self.fc_z(features).view(-1, self.latent_dim, self.categorical_dim)
    return [logits]
def decode(self, z: Tensor) -> Tensor:
    """
    Maps the given latent codes onto the image space.

    ``decoder_input`` emits ``4 * C`` features that are reshaped into a
    ``C x 2 x 2`` feature map. ``C`` is derived from that output rather
    than hard-coded to 512, so non-default ``hidden_dims`` also work.

    :param z: (Tensor) Flattened (relaxed) one-hot codes [B x (D * Q)]
    :return: (Tensor) [B x C x H x W]
    """
    result = self.decoder_input(z)
    # [B x 4C] -> [B x C x 2 x 2]
    result = result.view(-1, result.size(-1) // 4, 2, 2)
    result = self.decoder(result)
    result = self.final_layer(result)
    return result
def reparameterize(self, z: Tensor, eps:float = 1e-7) -> Tensor:
    """
    Draw a relaxed one-hot sample with the Gumbel-softmax trick at the
    current temperature ``self.temp``.

    :param z: (Tensor) Categorical logits [B x D x Q]
    :param eps: (float) Small constant keeping both logarithms finite
    :return: (Tensor) Relaxed samples flattened to [B x (D * Q)]
    """
    uniform = torch.rand_like(z)
    # Gumbel(0, 1) noise via the inverse CDF: -log(-log(u))
    gumbel = - torch.log(- torch.log(uniform + eps) + eps)
    relaxed = F.softmax((z + gumbel) / self.temp, dim=-1)
    return relaxed.view(-1, self.latent_dim * self.categorical_dim)
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    """
    Encode to logits, draw a Gumbel-softmax code and decode it.

    :param input: (Tensor) [B x C x H x W]
    :return: (List[Tensor]) [reconstruction, input, logits [B x D x Q]]
    """
    logits = self.encode(input)[0]
    code = self.reparameterize(logits)
    reconstruction = self.decode(code)
    return [reconstruction, input, logits]
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    """
    Computes the categorical VAE loss:

        loss = alpha * MSE(recons, input) + M_N * KL(q(z|x) || Uniform(Q))

    where KL(q || U) = sum q * (log q - log(1/Q)), summed over the D latent
    variables and Q categories and averaged over the batch.

    Side effect: when training and ``batch_idx`` is a multiple of
    ``self.anneal_interval``, ``self.temp`` is decayed (floored at
    ``self.min_temp``).

    :param args: (recons, input, q) with q the logits [B x D x Q]
    :param kwargs: 'M_N' -- KL weight (minibatch / dataset size),
                   'batch_idx' -- index of the current batch
    :return: dict with 'loss', 'Reconstruction_Loss' and 'KLD'
             (the negated KL term)
    """
    recons, input, q = args[0], args[1], args[2]
    kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
    batch_idx = kwargs['batch_idx']

    # Anneal the temperature at regular intervals.
    # NOTE(review): the decay multiplies the *current* temperature by
    # exp(-rate * batch_idx), so it compounds across calls instead of
    # following tau0 * exp(-rate * t) -- confirm this is intended.
    if batch_idx % self.anneal_interval == 0 and self.training:
        decayed = self.temp * np.exp(- self.anneal_rate * batch_idx)
        self.temp = np.maximum(decayed, self.min_temp)

    recons_loss = F.mse_loss(recons, input, reduction='mean')

    eps = 1e-7
    probs = F.softmax(q, dim=-1)  # logits -> per-latent category probabilities
    neg_entropy = probs * torch.log(probs + eps)
    uniform_log_prob = probs * np.log(1. / self.categorical_dim + eps)
    kld_loss = torch.mean(torch.sum(neg_entropy - uniform_log_prob, dim=(1, 2)), dim=0)

    loss = self.alpha * recons_loss + kld_weight * kld_loss
    return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}
def sample(self,
           num_samples:int,
           current_device: int, **kwargs) -> Tensor:
    """
    Draw uniformly random one-hot codes from the categorical prior and
    decode them into images.

    :param num_samples: (Int) Number of samples
    :param current_device: (Int) Device to run the model
    :return: (Tensor) Decoded samples
    """
    total = num_samples * self.latent_dim
    # One uniformly chosen category for every latent variable of every sample
    picks = np.random.choice(self.categorical_dim, total)
    one_hot = np.eye(self.categorical_dim, dtype=np.float32)[picks]  # [S*D x Q]
    z = torch.from_numpy(one_hot)
    z = z.view(num_samples, self.latent_dim * self.categorical_dim).to(current_device)
    return self.decode(z)
def generate(self, x: Tensor, **kwargs) -> Tensor:
    """
    Reconstruct ``x`` with one stochastic forward pass.

    :param x: (Tensor) [B x C x H x W]
    :return: (Tensor) [B x C x H x W]
    """
    outputs = self.forward(x)
    return outputs[0]
================================================
FILE: models/cvae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
class ConditionalVAE(BaseVAE):
    """
    Conditional VAE. The one-hot class label is embedded as an extra
    image-sized channel for the encoder and concatenated to the latent
    code for the decoder.

    NOTE(review): the encoder output is flattened as ``hidden_dims[-1] * 4``,
    i.e. the last feature map is assumed to be 2 x 2 (64 x 64 inputs with
    five stride-2 layers) -- confirm before using other image sizes.
    """

    def __init__(self,
                 in_channels: int,
                 num_classes: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 img_size:int = 64,
                 **kwargs) -> None:
        """
        :param in_channels: (Int) Number of input image channels
        :param num_classes: (Int) Length of the one-hot label vector
        :param latent_dim: (Int) Dimension of the latent Gaussian
        :param hidden_dims: (List) Encoder channel widths (not modified)
        :param img_size: (Int) Side length of the square input images
        """
        super(ConditionalVAE, self).__init__()

        self.latent_dim = latent_dim
        self.img_size = img_size

        # Label -> one img_size x img_size plane stacked onto the image
        self.embed_class = nn.Linear(num_classes, img_size * img_size)
        self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            # Work on a copy: reverse() below must not mutate the caller's list
            hidden_dims = list(hidden_dims)

        in_channels += 1  # To account for the extra label channel
        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder (input is the latent code concatenated with the label)
        modules = []

        self.decoder_input = nn.Linear(latent_dim + num_classes, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels=3,
                                      kernel_size=3, padding=1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the (image + label channel) input by passing it through
        the encoder network and returns the latent Gaussian parameters.

        :param input: (Tensor) Input tensor to encoder [N x (C+1) x H x W]
        :return: (List[Tensor]) [mu, log_var], each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps [latent code, label] vectors onto the image space.

        :param z: (Tensor) [N x (latent_dim + num_classes)]
        :return: (Tensor) [N x 3 x H x W]
        """
        result = self.decoder_input(z)
        # [N x 4C] -> [N x C x 2 x 2]; C follows hidden_dims instead of a fixed 512
        result = result.view(-1, result.size(-1) // 4, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: sample z ~ N(mu, exp(logvar)) as
        mu + std * eps with eps ~ N(0, I). A single sample is drawn.

        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log-variance of the latent Gaussian
        :return: (Tensor) Sample with the same shape as mu
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        :param input: (Tensor) [N x C x img_size x img_size]
        :param kwargs: 'labels' -- one-hot class labels [N x num_classes]
        :return: (List[Tensor]) [reconstruction, input, mu, log_var]
        """
        y = kwargs['labels'].float()
        embedded_class = self.embed_class(y)
        embedded_class = embedded_class.view(-1, self.img_size, self.img_size).unsqueeze(1)
        embedded_input = self.embed_data(input)

        x = torch.cat([embedded_input, embedded_class], dim=1)
        mu, log_var = self.encode(x)

        z = self.reparameterize(mu, log_var)

        z = torch.cat([z, y], dim=1)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        MSE reconstruction + M_N-weighted KL(N(mu, sigma) || N(0, I)).

        :param args: (recons, input, mu, log_var) as returned by forward
        :param kwargs: 'M_N' -- KL weight (minibatch / dataset size)
        :return: dict with 'loss', 'Reconstruction_Loss', 'KLD' (negated KL)
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int,
               **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.

        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :param kwargs: 'labels' -- one-hot labels [num_samples x num_classes]
        :return: (Tensor)
        """
        # Labels are moved with z so that the concatenation is on one device
        y = kwargs['labels'].float().to(current_device)
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        z = torch.cat([z, y], dim=1)
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x (and its 'labels' kwarg), returns the
        reconstructed image.

        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x, **kwargs)[0]
================================================
FILE: models/dfcvae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torchvision.models import vgg19_bn
from torch.nn import functional as F
from .types_ import *
class DFCVAE(BaseVAE):
    """
    Deep Feature Consistent VAE: besides the pixel MSE, the reconstruction
    is compared with the input in the feature space of a frozen,
    pretrained VGG19-BN network.

    NOTE(review): images are fed to VGG without ImageNet normalisation --
    confirm this matches the data pipeline.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 alpha:float = 1,
                 beta:float = 0.5,
                 **kwargs) -> None:
        """
        :param in_channels: (Int) Number of input image channels
        :param latent_dim: (Int) Dimension of the latent Gaussian
        :param hidden_dims: (List) Encoder channel widths (not modified)
        :param alpha: (Float) Weight of the KL term
        :param beta: (Float) Weight of the pixel + feature reconstruction terms
        """
        super(DFCVAE, self).__init__()

        self.latent_dim = latent_dim
        self.alpha = alpha
        self.beta = beta

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            # Work on a copy: reverse() below must not mutate the caller's list
            hidden_dims = list(hidden_dims)

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels=3,
                                      kernel_size=3, padding=1),
                            nn.Tanh())

        self.feature_network = vgg19_bn(pretrained=True)

        # Freeze the pretrained feature network
        for param in self.feature_network.parameters():
            param.requires_grad = False

        self.feature_network.eval()

    def train(self, mode: bool = True):
        """
        Same as ``nn.Module.train`` but keeps the frozen feature network in
        eval mode. Without this, ``model.train()`` would switch VGG's
        BatchNorm layers back to training mode and every feature extraction
        would overwrite their pretrained running statistics.

        :param mode: (Bool) True for training mode, False for eval mode
        :return: self
        """
        super(DFCVAE, self).train(mode)
        self.feature_network.eval()
        return self

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.

        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) [mu, log_var], each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.

        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # [B x 4C] -> [B x C x 2 x 2]; C follows hidden_dims instead of a fixed 512
        result = result.view(-1, result.size(-1) // 4, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from N(0,1).

        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        :param input: (Tensor) [B x C x H x W]
        :return: (List) [recons, input, recons_features, input_features,
                 mu, log_var]
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        recons = self.decode(z)

        recons_features = self.extract_features(recons)
        input_features = self.extract_features(input)

        return [recons, input, recons_features, input_features, mu, log_var]

    def extract_features(self,
                         input: Tensor,
                         feature_layers: List = None) -> List[Tensor]:
        """
        Extracts the features from the pretrained model
        at the layers indicated by feature_layers.

        :param input: (Tensor) [B x C x H x W]
        :param feature_layers: List of string IDs of children of
                               ``feature_network.features``
        :return: List of the extracted features, in network order
        """
        if feature_layers is None:
            feature_layers = ['14', '24', '34', '43']
        features = []
        result = input
        for (key, module) in self.feature_network.features._modules.items():
            result = module(result)
            if key in feature_layers:
                features.append(result)

        return features

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        loss = beta * (pixel MSE + sum of feature MSEs)
               + alpha * M_N * KL(N(mu, sigma) || N(0, I))

        :param args: outputs of forward
        :param kwargs: 'M_N' -- KL weight (minibatch / dataset size)
        :return: dict with 'loss', 'Reconstruction_Loss', 'KLD' (negated KL)
        """
        recons = args[0]
        input = args[1]
        recons_features = args[2]
        input_features = args[3]
        mu = args[4]
        log_var = args[5]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        feature_loss = 0.0
        for (r, i) in zip(recons_features, input_features):
            feature_loss += F.mse_loss(r, i)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = self.beta * (recons_loss + feature_loss) + self.alpha * kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':-kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.

        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image

        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]
================================================
FILE: models/dip_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
class DIPVAE(BaseVAE):
    """
    DIP-VAE (type II): a VAE whose loss additionally pushes the aggregate
    posterior covariance Cov_q(z) towards the identity matrix.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 lambda_diag: float = 10.,
                 lambda_offdiag: float = 5.,
                 **kwargs) -> None:
        """
        :param in_channels: (Int) Number of input image channels
        :param latent_dim: (Int) Dimension of the latent Gaussian
        :param hidden_dims: (List) Encoder channel widths (not modified)
        :param lambda_diag: (Float) Weight pulling Cov_q(z) diagonal to 1
        :param lambda_offdiag: (Float) Weight pulling off-diagonals to 0
        """
        super(DIPVAE, self).__init__()

        self.latent_dim = latent_dim
        self.lambda_diag = lambda_diag
        self.lambda_offdiag = lambda_offdiag

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            # Work on a copy: reverse() below must not mutate the caller's list
            hidden_dims = list(hidden_dims)

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels=3,
                                      kernel_size=3, padding=1),
                            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.

        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) [mu, log_var], each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.

        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # [B x 4C] -> [B x C x 2 x 2]; C follows hidden_dims instead of a fixed 512
        result = result.view(-1, result.size(-1) // 4, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from N(0,1).

        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """:return: (List[Tensor]) [reconstruction, input, mu, log_var]"""
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        loss = sum-MSE + M_N * sum-KL + DIP penalty, where the DIP-II penalty is

            lambda_offdiag * sum_{i != j} Cov_q(z)_ij^2
            + lambda_diag * sum_i (Cov_q(z)_ii - 1)^2

        with Cov_q(z) = E_p(x)[diag(exp(log_var))] + Cov_p(x)[mu(x)], both
        estimated over the batch.

        :param args: (recons, input, mu, log_var) as returned by forward
        :param kwargs: 'M_N' -- KL weight (minibatch / dataset size)
        :return: dict with 'loss', 'Reconstruction_Loss', 'KLD' (negated KL),
                 'DIP_Loss'
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input, reduction='sum')

        kld_loss = torch.sum(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        # DIP Loss
        # Covariance of the encoder means across the batch (not across latents)
        batch_size = mu.size(0)
        centered_mu = mu - mu.mean(dim=0, keepdim=True)  # [B x D]
        cov_mu = centered_mu.t().matmul(centered_mu) / batch_size  # [D x D]

        # DIP Loss II: add the expected encoder variance, which is diagonal,
        # so only the diagonal of the covariance grows
        cov_z = cov_mu + torch.diag(log_var.exp().mean(dim=0))  # [D x D]
        # For DIP Loss I use instead: cov_z = cov_mu

        cov_diag = torch.diag(cov_z)  # [D]
        cov_offdiag = cov_z - torch.diag(cov_diag)  # [D x D]
        dip_loss = self.lambda_offdiag * torch.sum(cov_offdiag ** 2) + \
                   self.lambda_diag * torch.sum((cov_diag - 1) ** 2)

        loss = recons_loss + kld_weight * kld_loss + dip_loss
        return {'loss': loss,
                'Reconstruction_Loss':recons_loss,
                'KLD':-kld_loss,
                'DIP_Loss':dip_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.

        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image

        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]
================================================
FILE: models/fvae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
class FactorVAE(BaseVAE):
    """
    FactorVAE: a VAE whose loss penalises the total correlation of q(z),
    estimated by a discriminator that tells q(z) apart from the product of
    its marginals (approximated by permuting each latent dimension across
    the batch).
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 gamma: float = 40.,
                 **kwargs) -> None:
        """
        :param in_channels: (Int) Number of input image channels
        :param latent_dim: (Int) Dimension of the latent Gaussian
        :param hidden_dims: (List) Encoder channel widths (not modified)
        :param gamma: (Float) Weight of the total-correlation penalty
        """
        super(FactorVAE, self).__init__()

        self.latent_dim = latent_dim
        self.gamma = gamma

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            # Work on a copy: reverse() below must not mutate the caller's list
            hidden_dims = list(hidden_dims)

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels=3,
                                      kernel_size=3, padding=1),
                            nn.Tanh())

        # Discriminator network for the Total Correlation (TC) loss.
        # Logit 0 scores "sample from q(z)", logit 1 "permuted sample".
        self.discriminator = nn.Sequential(nn.Linear(self.latent_dim, 1000),
                                           nn.BatchNorm1d(1000),
                                           nn.LeakyReLU(0.2),
                                           nn.Linear(1000, 1000),
                                           nn.BatchNorm1d(1000),
                                           nn.LeakyReLU(0.2),
                                           nn.Linear(1000, 1000),
                                           nn.BatchNorm1d(1000),
                                           nn.LeakyReLU(0.2),
                                           nn.Linear(1000, 2))
        # Discriminator logits of the last VAE step (kept for inspection only)
        self.D_z_reserve = None

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.

        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) [mu, log_var], each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.

        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # [B x 4C] -> [B x C x 2 x 2]; C follows hidden_dims instead of a fixed 512
        result = result.view(-1, result.size(-1) // 4, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from N(0,1).

        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """:return: (List[Tensor]) [reconstruction, input, mu, log_var, z]"""
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var, z]

    def permute_latent(self, z: Tensor) -> Tensor:
        """
        Independently shuffles every latent dimension across the batch, so
        the result approximates samples from prod_d q(z_d) (Kim & Mnih,
        2018, Algorithm 1). The previous version shuffled the dimensions
        *within* each sample, which does not break the dependencies
        between them.

        :param z: [B x D]
        :return: [B x D]
        """
        B, D = z.size()
        # Column d of `perm` is an independent random permutation of range(B)
        perm = torch.argsort(torch.rand(B, D, device=z.device), dim=0)
        return torch.gather(z, 0, perm)

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        optimizer_idx == 0 (VAE step):
            MSE + M_N * KL(N(mu, sigma) || N(0, I)) + gamma * TC estimate,
            where TC = mean(D(z)[:, 0] - D(z)[:, 1]).
        optimizer_idx == 1 (discriminator step):
            cross-entropy of D on detached z (label 0) and on permuted z
            (label 1).

        :param args: (recons, input, mu, log_var, z) as returned by forward
        :param kwargs: 'M_N' -- KL weight, 'optimizer_idx' -- 0 or 1
        :return: dict of losses for the selected step
        :raises ValueError: if optimizer_idx is neither 0 nor 1
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        z = args[4]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        optimizer_idx = kwargs['optimizer_idx']

        # Update the VAE
        if optimizer_idx == 0:
            recons_loss = F.mse_loss(recons, input)
            kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

            self.D_z_reserve = self.discriminator(z)
            vae_tc_loss = (self.D_z_reserve[:, 0] - self.D_z_reserve[:, 1]).mean()

            loss = recons_loss + kld_weight * kld_loss + self.gamma * vae_tc_loss

            return {'loss': loss,
                    'Reconstruction_Loss':recons_loss,
                    'KLD':-kld_loss,
                    'VAE_TC_Loss': vae_tc_loss}

        # Update the Discriminator
        elif optimizer_idx == 1:
            device = input.device
            true_labels = torch.ones(input.size(0), dtype=torch.long, device=device)
            false_labels = torch.zeros(input.size(0), dtype=torch.long, device=device)

            z = z.detach()  # Detach so that VAE is not trained again
            # Score the detached codes afresh: the logits cached by the VAE step
            # belong to a graph that was already backpropagated (or is None if
            # this step runs first), and they would leak gradients to the encoder.
            D_z = self.discriminator(z)
            z_perm = self.permute_latent(z)
            D_z_perm = self.discriminator(z_perm)
            D_tc_loss = 0.5 * (F.cross_entropy(D_z, false_labels) +
                               F.cross_entropy(D_z_perm, true_labels))

            return {'loss': D_tc_loss,
                    'D_TC_Loss':D_tc_loss}

        raise ValueError(f"optimizer_idx must be 0 or 1, got {optimizer_idx}")

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.

        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image

        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]
================================================
FILE: models/gamma_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.distributions import Gamma
from torch.nn import functional as F
from .types_ import *
import torch.nn.init as init
class GammaVAE(BaseVAE):
    """
    VAE with a Gamma latent posterior q(z|x) = Gamma(shape=alpha, rate=beta)
    and a Gamma(prior_shape, prior_rate) prior.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 gamma_shape: float = 8.,
                 prior_shape: float = 2.0,
                 prior_rate: float = 1.,
                 **kwargs) -> None:
        """
        :param in_channels: (Int) Number of input image channels
        :param latent_dim: (Int) Number of Gamma latent variables
        :param hidden_dims: (List) Encoder channel widths (not modified)
        :param gamma_shape: (Float) Shape augmentation B added to alpha
        :param prior_shape: (Float) Shape of the Gamma prior
        :param prior_rate: (Float) Rate of the Gamma prior
        """
        super(GammaVAE, self).__init__()
        self.latent_dim = latent_dim
        self.B = gamma_shape

        self.prior_alpha = torch.tensor([prior_shape])
        self.prior_beta = torch.tensor([prior_rate])

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            # Work on a copy: reverse() below must not mutate the caller's list
            hidden_dims = list(hidden_dims)

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # Explicit dim=1 (the feature axis); the implicit-dim form is deprecated
        self.fc_mu = nn.Sequential(nn.Linear(hidden_dims[-1] * 4, latent_dim),
                                   nn.Softmax(dim=1))
        self.fc_var = nn.Sequential(nn.Linear(hidden_dims[-1] * 4, latent_dim),
                                    nn.Softmax(dim=1))

        # Build Decoder
        modules = []

        self.decoder_input = nn.Sequential(nn.Linear(latent_dim, hidden_dims[-1] * 4))

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Sigmoid())
        self.weight_init()

    def weight_init(self):
        """
        Apply ``init_`` to every submodule, recursively. Iterating
        ``self._modules`` only reached the top-level children, so the layers
        nested inside the encoder/decoder Sequential blocks were skipped.
        """
        for m in self.modules():
            init_(m)

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input and returns the Gamma posterior parameters.

        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) [alpha (shape), beta (rate)], each
                 [N x latent_dim] and softmax-normalised over latents
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        alpha = self.fc_mu(result)
        beta = self.fc_var(result)

        return [alpha, beta]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.

        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x 3 x H x W]
        """
        result = self.decoder_input(z)
        # [B x 4C] -> [B x C x 2 x 2]; C follows hidden_dims instead of a fixed 512
        result = result.view(-1, result.size(-1) // 4, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, alpha: Tensor, beta: Tensor) -> Tensor:
        """
        Reparameterize the Gamma distribution by the shape augmentation trick.
        Reference:
        [1] https://arxiv.org/pdf/1610.05683.pdf

        A Gamma(alpha + B, 1) sample is drawn, mapped back to the eps that
        generates it, and pushed through h() again so that gradients reach
        alpha; the result is divided by the rate beta.

        NOTE(review): [1] additionally multiplies by
        prod_{i=1..B} u_i^{1/(alpha+i-1)}, u_i ~ U(0, 1), to recover a
        Gamma(alpha) sample. That factor is not applied, so z follows
        Gamma(alpha + B, rate=beta) -- confirm whether this is intended.

        :param alpha: (Tensor) Shape parameter of the latent Gamma
        :param beta: (Tensor) Rate parameter of the latent Gamma
        :return: (Tensor) Positive samples with the shape of alpha
        """
        # Sample from Gamma to guarantee acceptance
        alpha_ = alpha.clone().detach()
        z_hat = Gamma(alpha_ + self.B, torch.ones_like(alpha_)).sample()

        # Compute the eps ~ N(0,1) that produces z_hat
        eps = self.inv_h_func(alpha + self.B, z_hat)
        z = self.h_func(alpha + self.B, eps)

        # When beta != 1, scale by beta
        return z / beta

    def h_func(self, alpha: Tensor, eps: Tensor) -> Tensor:
        """
        Marsaglia-Tsang transform: maps eps ~ N(0, 1) to an (accepted)
        Gamma(alpha, 1) proposal.

        :param alpha: (Tensor) Shape parameter
        :param eps: (Tensor) Random sample to reparameterize
        :return: (Tensor)
        """
        z = (alpha - 1./3.) * (1 + eps / torch.sqrt(9. * alpha - 3.))**3
        return z

    def inv_h_func(self, alpha: Tensor, z: Tensor) -> Tensor:
        """
        Inverse of ``h_func``: recovers the eps that produces z.

        :param alpha: (Tensor)
        :param z: (Tensor)
        :return: (Tensor)
        """
        eps = torch.sqrt(9. * alpha - 3.) * ((z / (alpha - 1./3.))**(1. / 3.) - 1.)
        return eps

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """:return: (List[Tensor]) [reconstruction, input, alpha, beta]"""
        alpha, beta = self.encode(input)
        z = self.reparameterize(alpha, beta)
        return [self.decode(z), input, alpha, beta]

    def I_function(self, a, b, c, d):
        """
        E_q[log p(x)] for p = Gamma(shape b, scale a) and
        q = Gamma(shape d, scale c).
        """
        return - c * d / a - b * torch.log(a) - torch.lgamma(b) + (b - 1) * (torch.digamma(d) + torch.log(c))

    def vae_gamma_kl_loss(self, a, b, c, d):
        """
        KL(Gamma(shape d, rate c) || Gamma(shape b, rate a)), summed over dim 1.
        https://stats.stackexchange.com/questions/11646/kullback-leibler-divergence-between-two-gamma-distributions

        b and d are Gamma shape parameters; a and c are *rate* parameters
        (they are inverted to scales below). All must be positive.
        """
        a = 1 / a
        c = 1 / c
        losses = self.I_function(c, d, c, d) - self.I_function(a, b, c, d)
        return torch.sum(losses, dim=1)

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        loss = mean over batch of [per-sample MSE + KL(q(z|x) || p(z))].

        :param args: (recons, input, alpha, beta) as returned by forward
        :return: dict with 'loss'
        """
        recons = args[0]
        input = args[1]
        alpha = args[2]
        beta = args[3]

        curr_device = input.device
        recons_loss = torch.mean(F.mse_loss(recons, input, reduction='none'), dim=(1, 2, 3))

        self.prior_alpha = self.prior_alpha.to(curr_device)
        self.prior_beta = self.prior_beta.to(curr_device)

        # KL(q || p) with q = Gamma(shape=alpha, rate=beta) and
        # p = Gamma(shape=prior_alpha, rate=prior_beta). vae_gamma_kl_loss takes
        # (rate_p, shape_p, rate_q, shape_q); previously the posterior and prior
        # were swapped and rates were passed as shapes.
        kld_loss = self.vae_gamma_kl_loss(self.prior_beta, self.prior_alpha, beta, alpha)

        loss = recons_loss + kld_loss
        loss = torch.mean(loss, dim=0)
        return {'loss': loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the Gamma prior and return the corresponding
        image space map.

        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = Gamma(self.prior_alpha, self.prior_beta).sample((num_samples, self.latent_dim))  # [S x D x 1]
        # Drop only the trailing parameter axis so S == 1 or D == 1 keep their axes
        z = z.squeeze(-1).to(current_device)
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image

        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]
def init_(m):
    """
    Initialise a single module in place: orthogonal weights for Linear and
    Conv2d layers, unit scale for BatchNorm1d/2d, and zero bias for all of
    them. Any other module type is left untouched.

    :param m: (nn.Module) Module to initialise
    """
    is_affine = isinstance(m, (nn.Linear, nn.Conv2d))
    is_norm = isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d))
    if not (is_affine or is_norm):
        return
    if is_affine:
        init.orthogonal_(m.weight)
    else:
        m.weight.data.fill_(1)
    if m.bias is not None:
        m.bias.data.fill_(0)
================================================
FILE: models/hvae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
class HVAE(BaseVAE):
def __init__(self,
in_channels: int,
latent1_dim: int,
latent2_dim: int,
hidden_dims: List = None,
img_size:int = 64,
pseudo_input_size: int = 128,
**kwargs) -> None:
super(HVAE, self).__init__()
self.latent1_dim = latent1_dim
self.latent2_dim = latent2_dim
self.img_size = img_size
modules = []
if hidden_dims is None:
hidden_dims = [32, 64, 128, 256, 512]
channels = in_channels
# Build z2 Encoder
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(channels, out_channels=h_dim,
kernel_size= 3, stride= 2, padding = 1),
nn.BatchNorm2d(h_dim),
nn.LeakyReLU())
)
channels = h_dim
self.encoder_z2_layers = nn.Sequential(*modules)
self.fc_z2_mu = nn.Linear(hidden_dims[-1]*4, latent2_dim)
self.fc_z2_var = nn.Linear(hidden_dims[-1]*4, latent2_dim)
# ========================================================================#
# Build z1 Encoder
self.embed_z2_code = nn.Linear(latent2_dim, img_size * img_size)
self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1)
modules = []
channels = in_channels + 1 # One more channel for the latent code
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(channels, out_channels=h_dim,
kernel_size= 3, stride= 2, padding = 1),
nn.BatchNorm2d(h_dim),
nn.LeakyReLU())
)
channels = h_dim
self.encoder_z1_layers = nn.Sequential(*modules)
self.fc_z1_mu = nn.Linear(hidden_dims[-1]*4, latent1_dim)
self.fc_z1_var = nn.Linear(hidden_dims[-1]*4, latent1_dim)
#========================================================================#
# Build z2 Decoder
self.recons_z1_mu = nn.Linear(latent2_dim, latent1_dim)
self.recons_z1_log_var = nn.Linear(latent2_dim, latent1_dim)
# ========================================================================#
# Build z1 Decoder
self.debed_z1_code = nn.Linear(latent1_dim, 1024)
self.debed_z2_code = nn.Linear(latent2_dim, 1024)
modules = []
hidden_dims.reverse()
for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hidden_dims[i],
hidden_dims[i + 1],
kernel_size=3,
stride = 2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[i + 1]),
nn.LeakyReLU())
)
self.decoder = nn.Sequential(*modules)
self.final_layer = nn.Sequential(
nn.ConvTranspose2d(hidden_dims[-1],
hidden_dims[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
nn.Conv2d(hidden_dims[-1], out_channels= 3,
kernel_size= 3, padding= 1),
nn.Tanh())
# ========================================================================#
# Pesudo Input for the Vamp-Prior
# self.pseudo_input = torch.eye(pseudo_input_size,
# requires_grad=False).view(1, 1, pseudo_input_size, -1)
#
#
# self.pseudo_layer = nn.Conv2d(1, out_channels=in_channels,
# kernel_size=3, stride=2, padding=1)
def encode_z2(self, input: Tensor) -> List[Tensor]:
    """
    Top-level inference network q(z2 | x).

    The z2 convolutional encoder is applied to the images, its feature map
    is flattened per sample, and two linear heads produce the Gaussian
    parameters.

    :param input: (Tensor) Images [N x C x H x W]
    :return: (List[Tensor]) [mu, log_var] of q(z2 | x)
    """
    features = self.encoder_z2_layers(input).flatten(start_dim=1)
    # Both heads read the same flattened features; mean first, then log-variance.
    return [self.fc_z2_mu(features), self.fc_z2_var(features)]
def encode_z1(self, input: Tensor, z2: Tensor) -> List[Tensor]:
    """
    Lower-level inference network q(z1 | x, z2).

    The z2 code is projected to an img_size x img_size plane and appended to
    the embedded image as one extra channel before the z1 encoder runs.

    :param input: (Tensor) Images [N x C x H x W]; H and W are expected to
                  equal self.img_size so the channel concatenation succeeds.
    :param z2: (Tensor) Top-level latent sample, one row per image
    :return: (List[Tensor]) [mu, log_var] of q(z1 | x, z2)
    """
    side = self.img_size
    image_features = self.embed_data(input)
    # [N x side*side] -> [N x 1 x side x side]: one conditioning channel
    code_plane = self.embed_z2_code(z2).view(-1, side, side).unsqueeze(1)

    stacked = torch.cat([image_features, code_plane], dim=1)
    hidden = torch.flatten(self.encoder_z1_layers(stacked), start_dim=1)
    return [self.fc_z1_mu(hidden), self.fc_z1_var(hidden)]
def encode(self, input: Tensor) -> List[Tensor]:
    """
    Full two-level inference pass: sample z2 ~ q(z2 | x), then compute the
    parameters of q(z1 | x, z2) conditioned on that sample.

    :param input: (Tensor) Images [N x C x H x W]
    :return: (List[Tensor]) [z1_mu, z1_log_var, z2_mu, z2_log_var, z2]
    """
    top_mu, top_log_var = self.encode_z2(input)
    top_sample = self.reparameterize(top_mu, top_log_var)
    bottom_mu, bottom_log_var = self.encode_z1(input, top_sample)
    return [bottom_mu, bottom_log_var, top_mu, top_log_var, top_sample]
def decode(self, input: Tensor) -> Tensor:
    """
    Upsamples an already-reshaped feature map through the transposed-conv
    decoder and then the output head.

    :param input: (Tensor) Feature map prepared by forward()/sample()
    :return: (Tensor) Reconstructed images
    """
    upsampled = self.decoder(input)
    return self.final_layer(upsampled)
def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    """
    Reparameterisation trick: draws one sample from N(mu, exp(logvar)) as
    mu + eps * std with eps ~ N(0, I), so the sample stays differentiable
    with respect to mu and logvar.

    NOTE(review): only a single sample is drawn per call; whether one sample
    suffices for the loss's expectation is left open.

    :param mu: (Tensor) Mean of the latent Gaussian
    :param logvar: (Tensor) Log-variance of the latent Gaussian
    :return: (Tensor) Sample with the same shape as mu
    """
    std = (0.5 * logvar).exp()
    noise = torch.randn_like(std)
    return mu + noise * std
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    """
    Encodes x into z2 ~ q(z2 | x) and z1 ~ q(z1 | x, z2), then reconstructs
    x ~ p(x | z1, z2) from both codes.

    Each code is projected by its debed layer; the projections are
    concatenated and reshaped to [N x 512 x 2 x 2] before decoding, so
    together they must provide 2048 features per sample.

    :param input: (Tensor) Images [N x C x H x W]
    :return: (List[Tensor]) [recons, input, z1_mu, z1_log_var,
             z2_mu, z2_log_var, z1, z2]
    """
    z1_mu, z1_log_var, z2_mu, z2_log_var, z2 = self.encode(input)
    z1 = self.reparameterize(z1_mu, z1_log_var)

    joint_code = torch.cat([self.debed_z1_code(z1), self.debed_z2_code(z2)], dim=1)
    recons = self.decode(joint_code.view(-1, 512, 2, 2))

    return [recons, input, z1_mu, z1_log_var, z2_mu, z2_log_var, z1, z2]
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    """
    Hierarchical VAE loss: mean-squared reconstruction error plus a
    minibatch-weighted KL-style term.

    :param args: Output of forward():
                 (recons, input, z1_mu, z1_log_var, z2_mu, z2_log_var, z1, z2)
    :param kwargs: Must contain 'M_N', the minibatch KL weight.
    :return: (dict) 'loss', 'Reconstruction Loss', and 'KLD' (the negation
             of the KL term that is added to the loss).
    """
    recons = args[0]
    input = args[1]
    z1_mu = args[2]
    z1_log_var = args[3]
    z2_mu = args[4]
    z2_log_var = args[5]
    z1 = args[6]
    z2 = args[7]

    # Conditional prior p(z1 | z2), predicted from the sampled z2.
    z1_p_mu = self.recons_z1_mu(z2)
    z1_p_log_var = self.recons_z1_log_var(z2)

    kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
    recons_loss = F.mse_loss(recons, input)

    # Batch-averaged KL(q(z1 | x, z2) || N(0, I)) and KL(q(z2 | x) || N(0, I)).
    z1_kld = torch.mean(-0.5 * torch.sum(1 + z1_log_var - z1_mu ** 2 - z1_log_var.exp(), dim=1),
                        dim=0)
    z2_kld = torch.mean(-0.5 * torch.sum(1 + z2_log_var - z2_mu ** 2 - z2_log_var.exp(), dim=1),
                        dim=0)

    # NOTE(review): this term scores z1 under p(z1 | z2), but it is neither the
    # Gaussian log-density (the squared error is not divided by the variance)
    # nor a closed-form KL. It is kept unchanged to preserve training behaviour.
    z1_p_kld = torch.mean(-0.5 * torch.sum(1 + z1_p_log_var - (z1 - z1_p_mu) ** 2 - z1_p_log_var.exp(),
                                           dim=1),
                          dim=0)

    kld_loss = -(z1_p_kld - z1_kld - z2_kld)
    loss = recons_loss + kld_weight * kld_loss
    return {'loss': loss, 'Reconstruction Loss': recons_loss, 'KLD': -kld_loss}
def sample(self, batch_size:int, current_device: int, **kwargs) -> Tensor:
    """
    Ancestral sampling: z2 ~ N(0, I), z1 ~ p(z1 | z2), then decode both.

    :param batch_size: (Int) Number of images to generate
    :param current_device: Device to sample on (CUDA index, or any value
                           accepted by Tensor.to, e.g. 'cpu')
    :return: (Tensor) Generated images
    """
    # Tensor.to (rather than .cuda) keeps CPU-only runs working and matches
    # how the other models place their prior samples.
    z2 = torch.randn(batch_size, self.latent2_dim).to(current_device)

    z1_mu = self.recons_z1_mu(z2)
    z1_log_var = self.recons_z1_log_var(z2)
    z1 = self.reparameterize(z1_mu, z1_log_var)

    joint_code = torch.cat([self.debed_z1_code(z1), self.debed_z2_code(z2)], dim=1)
    samples = self.decode(joint_code.view(-1, 512, 2, 2))
    return samples
def generate(self, x: Tensor, **kwargs) -> Tensor:
    """
    Reconstructs the given images: runs forward() and keeps only its first
    output (the reconstruction).

    :param x: (Tensor) [B x C x H x W]
    :return: (Tensor) [B x C x H x W]
    """
    outputs = self.forward(x)
    return outputs[0]
================================================
FILE: models/info_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
class InfoVAE(BaseVAE):
    """
    InfoVAE: a VAE whose loss combines a weighted reconstruction term, a
    weighted KL term and an MMD penalty between the encoded latents and
    samples from a standard-normal prior.

    The encoder halves the spatial size once per entry of ``hidden_dims``,
    and its linear heads expect a 2 x 2 final feature map
    (``hidden_dims[-1] * 4`` features). With the default five layers this
    corresponds to 64 x 64 inputs.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 alpha: float = -0.5,
                 beta: float = 5.0,
                 reg_weight: int = 100,
                 kernel_type: str = 'imq',
                 latent_var: float = 2.,
                 **kwargs) -> None:
        """
        :param in_channels: (Int) Channels of the input images; the
                            reconstruction has the same number of channels.
        :param latent_dim: (Int) Dimension of the latent code
        :param hidden_dims: (List) Encoder channel widths (copied, never mutated)
        :param alpha: (Float) InfoVAE alpha, must be <= 0
        :param beta: (Float) Weight of the reconstruction term
        :param reg_weight: (Int) Weight of the MMD term
        :param kernel_type: (Str) 'imq' or 'rbf'
        :param latent_var: (Float) Variance scale used by the MMD kernels
        """
        super(InfoVAE, self).__init__()

        self.latent_dim = latent_dim
        self.reg_weight = reg_weight
        self.kernel_type = kernel_type
        self.z_var = latent_var

        assert alpha <= 0, 'alpha must be negative or zero.'

        self.alpha = alpha
        self.beta = beta

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Copy so the in-place reverse() below never mutates the caller's list
        # (e.g. a config list reused to build a second model).
        hidden_dims = list(hidden_dims)
        # Width of the 2 x 2 map the decoder starts from (previously hard-coded 512).
        self.last_hidden_dim = hidden_dims[-1]
        # Reconstruct as many channels as the input has (previously hard-coded 3).
        out_channels = in_channels

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=out_channels,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) [mu, log_var], each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps latent codes onto the image space.
        :param z: (Tensor) [B x latent_dim]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, self.last_hidden_dim, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        :param input: (Tensor) [B x C x H x W]
        :return: (List[Tensor]) [recons, input, z, mu, log_var]
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, z, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        InfoVAE objective:
            beta * MSE + (1 - alpha) * M_N * KL
            + (alpha + reg_weight - 1) / (N * (N - 1)) * MMD
        where MMD is the off-diagonal kernel-sum estimate from compute_mmd().

        :param args: (recons, input, z, mu, log_var) as returned by forward()
        :param kwargs: Must contain 'M_N', the minibatch KL weight.
        :return: (dict) 'loss', 'Reconstruction_Loss', 'MMD', 'KLD'
        """
        recons = args[0]
        input = args[1]
        z = args[2]
        mu = args[3]
        log_var = args[4]

        batch_size = input.size(0)
        # Number of ordered off-diagonal pairs; normalises the MMD kernel sums.
        bias_corr = batch_size * (batch_size - 1)
        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

        recons_loss = F.mse_loss(recons, input)
        mmd_loss = self.compute_mmd(z)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = self.beta * recons_loss + \
               (1. - self.alpha) * kld_weight * kld_loss + \
               (self.alpha + self.reg_weight - 1.) / bias_corr * mmd_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'MMD': mmd_loss, 'KLD': -kld_loss}

    def compute_kernel(self,
                       x1: Tensor,
                       x2: Tensor) -> Tensor:
        """
        Evaluates the configured kernel over all pairs of rows of x1 and x2
        and returns the sum over off-diagonal pairs.

        :param x1: (Tensor) [N x D]
        :param x2: (Tensor) [N x D]
        :return: (Tensor) Scalar off-diagonal kernel sum
        :raises ValueError: If kernel_type is neither 'rbf' nor 'imq'.
        """
        # Convert the tensors into row and column vectors
        D = x1.size(1)
        N = x1.size(0)

        x1 = x1.unsqueeze(-2)  # Make it into a column tensor
        x2 = x2.unsqueeze(-3)  # Make it into a row tensor

        # Broadcast both to [N x N x D] so every pair (i, j) is compared.
        # Usually not required, but useful when x1 and x2 have different
        # sizes along the 0th dimension.
        x1 = x1.expand(N, N, D)
        x2 = x2.expand(N, N, D)

        if self.kernel_type == 'rbf':
            result = self.compute_rbf(x1, x2)
        elif self.kernel_type == 'imq':
            result = self.compute_inv_mult_quad(x1, x2)
        else:
            raise ValueError('Undefined kernel type.')

        return result

    def compute_rbf(self,
                    x1: Tensor,
                    x2: Tensor,
                    eps: float = 1e-7) -> Tensor:
        """
        Computes the RBF kernel between all pairs in x1 and x2 and returns the
        sum over off-diagonal pairs.

        The diagonal is excluded exactly as in compute_inv_mult_quad, so both
        kernels are normalised once by N * (N - 1) in loss_function(). The
        previous version returned the full matrix, which compute_mmd averaged
        and loss_function then divided by N * (N - 1) a second time.

        :param x1: (Tensor) [N x N x D]
        :param x2: (Tensor) [N x N x D]
        :param eps: (Float) Unused; kept for interface compatibility.
        :return: (Tensor) Scalar off-diagonal kernel sum
        """
        z_dim = x2.size(-1)
        sigma = 2. * z_dim * self.z_var

        kernel = torch.exp(-((x1 - x2).pow(2).mean(-1) / sigma))

        # Exclude diagonal elements
        return kernel.sum() - kernel.diag().sum()

    def compute_inv_mult_quad(self,
                              x1: Tensor,
                              x2: Tensor,
                              eps: float = 1e-7) -> Tensor:
        r"""
        Computes the Inverse Multi-Quadratics Kernel between x1 and x2,
        given by

                k(x_1, x_2) = \sum \frac{C}{C + \|x_1 - x_2 \|^2}

        summed over off-diagonal pairs.

        :param x1: (Tensor) [N x N x D]
        :param x2: (Tensor) [N x N x D]
        :param eps: (Float) Added to the denominator for numerical safety
        :return: (Tensor) Scalar off-diagonal kernel sum
        """
        z_dim = x2.size(-1)
        C = 2 * z_dim * self.z_var
        kernel = C / (eps + C + (x1 - x2).pow(2).sum(dim=-1))

        # Exclude diagonal elements
        result = kernel.sum() - kernel.diag().sum()

        return result

    def compute_mmd(self, z: Tensor) -> Tensor:
        """
        Unnormalised MMD estimate between the encoded codes and an equally
        sized batch drawn from the N(0, I) prior.

        Every kernel term is an off-diagonal sum; loss_function() divides the
        result by N * (N - 1).

        :param z: (Tensor) Latent codes [N x D]
        :return: (Tensor) Scalar MMD estimate
        """
        # Sample from prior (Gaussian) distribution
        prior_z = torch.randn_like(z)

        prior_z__kernel = self.compute_kernel(prior_z, prior_z)
        z__kernel = self.compute_kernel(z, z)
        priorz_z__kernel = self.compute_kernel(prior_z, z)

        mmd = prior_z__kernel + z__kernel - 2 * priorz_z__kernel
        return mmd

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]
================================================
FILE: models/iwae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
class IWAE(BaseVAE):
def __init__(self,
in_channels: int,
latent_dim: int,
hidden_dims: List = None,
num_samples: int = 5,
**kwargs) -> None:
super(IWAE, self).__init__()
self.latent_dim = latent_dim
self.num_samples = num_samples
modules = []
if hidden_dims is None:
hidden_dims = [32, 64, 128, 256, 512]
# Build Encoder
for h_dim in hidden_dims:
modules.append(
nn.Sequential(
nn.Conv2d(in_channels, out_channels=h_dim,
kernel_size= 3, stride= 2, padding = 1),
nn.BatchNorm2d(h_dim),
nn.LeakyReLU())
)
in_channels = h_dim
self.encoder = nn.Sequential(*modules)
self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)
self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)
# Build Decoder
modules = []
self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
hidden_dims.reverse()
for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hidden_dims[i],
hidden_dims[i + 1],
kernel_size=3,
stride = 2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[i + 1]),
nn.LeakyReLU())
)
self.decoder = nn.Sequential(*modules)
self.final_layer = nn.Sequential(
nn.ConvTranspose2d(hidden_dims[-1],
hidden_dims[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
nn.Conv2d(hidden_dims[-1], out_channels= 3,
kernel_size= 3, padding= 1),
nn.Tanh())
def encode(self, input: Tensor) -> List[Tensor]:
"""
Encodes the input by passing through the encoder network
and returns the latent codes.
:param input: (Tensor) Input tensor to encoder [N x C x H x W]
:return: (Tensor) List of latent codes
"""
result = self.encoder(input)
result = torch.flatten(result, start_dim=1)
# Split the result into mu and var components
# of the latent Gaussian distribution
mu = self.fc_mu(result)
log_var = self.fc_var(result)
return [mu, log_var]
def decode(self, z: Tensor) -> Tensor:
"""
Maps the given latent codes of S samples
onto the image space.
:param z: (Tensor) [B x S x D]
:return: (Tensor) [B x S x C x H x W]
"""
B, _, _ = z.size()
z = z.view(-1, self.latent_dim) #[BS x D]
result = self.decoder_input(z)
result = result.view(-1, 512, 2, 2)
result = self.decoder(result)
result = self.final_layer(result) #[BS x C x H x W ]
result = result.view([B, -1, result.size(1), result.size(2), result.size(3)]) #[B x S x C x H x W]
return result
def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
"""
:param mu: (Tensor) Mean of the latent Gaussian
:param logvar: (Tensor) Standard deviation of the latent Gaussian
:return:
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
mu, log_var = self.encode(input)
mu = mu.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D]
log_var = log_var.repeat(self.num_samples, 1, 1).permute(1, 0, 2) # [B x S x D]
z= self.reparameterize(mu, log_var) # [B x S x D]
eps = (z - mu) / log_var # Prior samples
return [self.decode(z), input, mu, log_var, z, eps]
def loss_function(self,
*args,
**kwargs) -> dict:
"""
KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
:param args:
:param kwargs:
:return:
"""
recons = args[0]
input = args[1]
mu = args[2]
log_var = args[3]
z = args[4]
eps = args[5]
input = input.repeat(self.num_samples, 1, 1, 1, 1).permute(1, 0, 2, 3, 4) #[B x S x C x H x W]
kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset
log_p_x_z = ((recons - input) ** 2).flatten(2).mean(-1) # Reconstruction Loss [B x S]
kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=2) ## [B x S]
# Get importance weights
log_weight = (log_p_x_z + kld_weight * kld_loss) #.detach().data
# Rescale the weights (along the sample dim) to lie in [0, 1] and sum to 1
weight = F.softmax(log_weight, dim = -1)
# kld_loss = torch.mean(kld_loss, dim = 0)
loss = torch.mean(torch.sum(weight * log_weight, dim=-1), dim = 0)
return {'loss': loss, 'Reconstruction_Loss':log_p_x_z.mean(), 'KLD':-kld_loss.mean()}
def sample(self,
num_samples:int,
current_device: int, **kwargs) -> Tensor:
"""
Samples from the latent space and return the corresponding
image space map.
:param num_samples: (Int) Number of samples
:param current_device: (Int) Device to run the model
:return: (Tensor)
"""
z = torch.randn(num_samples, 1,
self.latent_dim)
z = z.to(current_device)
samples = self.decode(z).squeeze()
return samples
def generate(self, x: Tensor, **kwargs) -> Tensor:
"""
Given an input image x, returns the reconstructed image.
Returns only the first reconstructed sample
:param x: (Tensor) [B x C x H x W]
:return: (Tensor) [B x C x H x W]
"""
return self.forward(x)[0][:, 0, :]
================================================
FILE: models/joint_vae.py
================================================
import torch
import numpy as np
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
class JointVAE(BaseVAE):
    """
    JointVAE: a VAE with a continuous Gaussian latent and a categorical latent
    (relaxed with Gumbel-Softmax), regularised by capacity-controlled KL terms.

    The encoder halves the spatial size once per entry of ``hidden_dims``,
    and its linear heads expect a 2 x 2 final feature map
    (``hidden_dims[-1] * 4`` features).
    """
    num_iter = 1

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 categorical_dim: int,
                 latent_min_capacity: float = 0.,
                 latent_max_capacity: float = 25.,
                 latent_gamma: float = 30.,
                 latent_num_iter: int = 25000,
                 categorical_min_capacity: float = 0.,
                 categorical_max_capacity: float = 25.,
                 categorical_gamma: float = 30.,
                 categorical_num_iter: int = 25000,
                 hidden_dims: List = None,
                 temperature: float = 0.5,
                 anneal_rate: float = 3e-5,
                 anneal_interval: int = 100,  # every 100 batches
                 alpha: float = 30.,
                 **kwargs) -> None:
        """
        :param in_channels: (Int) Channels of the input images; the
                            reconstruction has the same number of channels.
        :param latent_dim: (Int) Dimension D of the continuous latent
        :param categorical_dim: (Int) Number Q of categories
        :param hidden_dims: (List) Encoder channel widths (copied, never mutated)
        :param temperature: (Float) Gumbel-Softmax temperature
        :param alpha: (Float) Weight of the reconstruction term
        """
        super(JointVAE, self).__init__()

        self.latent_dim = latent_dim
        self.categorical_dim = categorical_dim
        self.temp = temperature
        # NOTE(review): min_temp equals the starting temperature, so the
        # annealing in loss_function() can never lower self.temp — confirm intent.
        self.min_temp = temperature
        self.anneal_rate = anneal_rate
        self.anneal_interval = anneal_interval
        self.alpha = alpha

        self.cont_min = latent_min_capacity
        self.cont_max = latent_max_capacity

        self.disc_min = categorical_min_capacity
        self.disc_max = categorical_max_capacity

        self.cont_gamma = latent_gamma
        self.disc_gamma = categorical_gamma

        self.cont_iter = latent_num_iter
        self.disc_iter = categorical_num_iter

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Copy so the in-place reverse() below never mutates the caller's list.
        hidden_dims = list(hidden_dims)
        # Width of the 2 x 2 map the decoder starts from (previously hard-coded 512).
        self.last_hidden_dim = hidden_dims[-1]
        # Reconstruct as many channels as the input has (previously hard-coded 3).
        out_channels = in_channels

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, self.latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, self.latent_dim)
        self.fc_z = nn.Linear(hidden_dims[-1] * 4, self.categorical_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(self.latent_dim + self.categorical_dim,
                                       hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=out_channels,
                      kernel_size=3, padding=1),
            nn.Tanh())
        # Uniform one-hot categorical; only referenced by commented-out code in sample().
        self.sampling_dist = torch.distributions.OneHotCategorical(1. / categorical_dim * torch.ones((self.categorical_dim, 1)))

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [B x C x H x W]
        :return: (List[Tensor]) [mu [B x D], log_var [B x D],
                 categorical logits [B x Q]]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        z = self.fc_z(result)
        z = z.view(-1, self.categorical_dim)
        return [mu, log_var, z]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) Concatenated continuous and categorical codes [B x (D + Q)]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, self.last_hidden_dim, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self,
                       mu: Tensor,
                       log_var: Tensor,
                       q: Tensor,
                       eps: float = 1e-7) -> Tensor:
        """
        Samples the Gaussian latent with the reparameterisation trick and the
        categorical latent with the Gumbel-softmax trick.
        :param mu: (Tensor) mean of the latent Gaussian [B x D]
        :param log_var: (Tensor) Log variance of the latent Gaussian [B x D]
        :param q: (Tensor) Categorical latent Codes [B x Q]
        :param eps: (Float) Keeps the Gumbel logs finite
        :return: (Tensor) [B x (D + Q)]
        """
        std = torch.exp(0.5 * log_var)
        e = torch.randn_like(std)
        z = e * std + mu

        # Sample from Gumbel
        u = torch.rand_like(q)
        g = - torch.log(- torch.log(u + eps) + eps)

        # Gumbel-Softmax sample
        s = F.softmax((q + g) / self.temp, dim=-1)
        s = s.view(-1, self.categorical_dim)

        return torch.cat([z, s], dim=1)

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        :param input: (Tensor) [B x C x H x W]
        :return: (List[Tensor]) [recons, input, q, mu, log_var]
        """
        mu, log_var, q = self.encode(input)
        z = self.reparameterize(mu, log_var, q)
        return [self.decode(z), input, q, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        r"""
        Computes the JointVAE loss: alpha * MSE + M_N * capacity loss, where
        the capacity loss penalises the distance of each KL term from a
        linearly increasing target capacity.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: (recons, input, q, mu, log_var) as returned by forward()
        :param kwargs: Must contain 'M_N' (KL weight) and 'batch_idx'.
        :return: (dict) 'loss', 'Reconstruction_Loss', 'Capacity_Loss'
        """
        recons = args[0]
        input = args[1]
        q = args[2]
        mu = args[3]
        log_var = args[4]

        q_p = F.softmax(q, dim=-1)  # Convert the categorical codes into probabilities

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        batch_idx = kwargs['batch_idx']

        # Anneal the temperature at regular intervals
        if batch_idx % self.anneal_interval == 0 and self.training:
            self.temp = np.maximum(self.temp * np.exp(- self.anneal_rate * batch_idx),
                                   self.min_temp)

        recons_loss = F.mse_loss(recons, input, reduction='mean')

        # Adaptively increase the discrete capacity, capped at log(Q)
        disc_curr = (self.disc_max - self.disc_min) * \
                    self.num_iter / float(self.disc_iter) + self.disc_min
        disc_curr = min(disc_curr, np.log(self.categorical_dim))

        # KL divergence between gumbel-softmax distribution
        eps = 1e-7

        # Entropy of the logits
        h1 = q_p * torch.log(q_p + eps)

        # Cross entropy with the categorical distribution
        h2 = q_p * np.log(1. / self.categorical_dim + eps)
        kld_disc_loss = torch.mean(torch.sum(h1 - h2, dim=1), dim=0)

        # Compute Continuous loss
        # Adaptively increase the continuous capacity
        cont_curr = (self.cont_max - self.cont_min) * \
                    self.num_iter / float(self.cont_iter) + self.cont_min
        cont_curr = min(cont_curr, self.cont_max)

        kld_cont_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(),
                                                    dim=1),
                                   dim=0)
        capacity_loss = self.disc_gamma * torch.abs(disc_curr - kld_disc_loss) + \
                        self.cont_gamma * torch.abs(cont_curr - kld_cont_loss)

        loss = self.alpha * recons_loss + kld_weight * capacity_loss

        if self.training:
            self.num_iter += 1
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'Capacity_Loss': capacity_loss}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        # Continuous part: [S x D] from N(0, I)
        z = torch.randn(num_samples,
                        self.latent_dim)

        # Categorical part: one uniformly chosen one-hot row per sample [S x Q]
        M = num_samples
        np_y = np.zeros((M, self.categorical_dim), dtype=np.float32)
        np_y[range(M), np.random.choice(self.categorical_dim, M)] = 1
        q = torch.from_numpy(np_y)

        z = torch.cat([z, q], dim=1).to(current_device)
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]
================================================
FILE: models/logcosh_vae.py
================================================
import torch
import torch.nn.functional as F
from models import BaseVAE
from torch import nn
from .types_ import *
class LogCoshVAE(BaseVAE):
    """
    VAE trained with a log-cosh reconstruction loss (scaled by ``alpha``)
    and a ``beta``-weighted KL term.

    The encoder halves the spatial size once per entry of ``hidden_dims``,
    and its linear heads expect a 2 x 2 final feature map
    (``hidden_dims[-1] * 4`` features).
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 alpha: float = 100.,
                 beta: float = 10.,
                 **kwargs) -> None:
        """
        :param in_channels: (Int) Channels of the input images; the
                            reconstruction has the same number of channels.
        :param latent_dim: (Int) Dimension of the latent code
        :param hidden_dims: (List) Encoder channel widths (copied, never mutated)
        :param alpha: (Float) Sharpness of the log-cosh loss
        :param beta: (Float) Weight of the KL term
        """
        super(LogCoshVAE, self).__init__()

        self.latent_dim = latent_dim
        self.alpha = alpha
        self.beta = beta

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Copy so the in-place reverse() below never mutates the caller's list.
        hidden_dims = list(hidden_dims)
        # Width of the 2 x 2 map the decoder starts from (previously hard-coded 512).
        self.last_hidden_dim = hidden_dims[-1]
        # Reconstruct as many channels as the input has (previously hard-coded 3).
        out_channels = in_channels

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=out_channels,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) [mu, log_var], each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, self.last_hidden_dim, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        :param input: (Tensor) [B x C x H x W]
        :return: (List[Tensor]) [recons, input, mu, log_var]
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        r"""
        Computes the VAE loss function:
            (1 / alpha) * mean(log cosh(alpha * (recons - input))) + beta * M_N * KL
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: (recons, input, mu, log_var) as returned by forward()
        :param kwargs: Must contain 'M_N', the minibatch KL weight.
        :return: (dict) 'loss', 'Reconstruction_Loss', 'KLD'
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        t = recons - input

        # log(cosh(a)) = |a| + log(1 + exp(-2|a|)) - log(2).
        # Working with |a| keeps the exponent non-positive, so exp() cannot
        # overflow for large |alpha * t|; the a + log(1 + exp(-2a)) form
        # returns inf once -2a exceeds ~88 in float32.
        abs_scaled = torch.abs(self.alpha * t)
        recons_loss = abs_scaled + \
                      F.softplus(-2. * abs_scaled) - \
                      torch.log(torch.tensor(2.0))

        recons_loss = (1. / self.alpha) * recons_loss.mean()

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss + self.beta * kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]
================================================
FILE: models/lvae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
from math import floor, pi, log
def conv_out_shape(img_size):
    """
    Side length produced by a 3x3, stride-2, padding-1 convolution applied
    to a square map of side ``img_size``:

        floor((img_size + 2 * padding - kernel) / stride) + 1
    """
    kernel, stride, padding = 3, 2., 1
    return floor((img_size + 2 * padding - kernel) / stride) + 1
class EncoderBlock(nn.Module):
    """
    One bottom-up rung of the ladder encoder: a stride-2 conv block whose
    flattened output is also mapped to the mean and log-variance of that
    rung's latent Gaussian.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 latent_dim: int,
                 img_size: int):
        """
        :param in_channels: (Int) Channels of the incoming feature map
        :param out_channels: (Int) Channels produced by the conv block
        :param latent_dim: (Int) Dimension of this rung's latent
        :param img_size: (Int) Side of the incoming square feature map
        """
        super(EncoderBlock, self).__init__()

        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.encoder = nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.LeakyReLU())

        # Flattened size of the (square) conv output feeding both heads
        flat_features = out_channels * conv_out_shape(img_size) ** 2
        self.encoder_mu = nn.Linear(flat_features, latent_dim)
        self.encoder_var = nn.Linear(flat_features, latent_dim)

    def forward(self, input: Tensor) -> Tensor:
        """
        :param input: (Tensor) Feature map [B x in_channels x H x W]
        :return: (List[Tensor]) [conv output handed to the next rung, mu, log_var]
        """
        features = self.encoder(input)
        flat = features.flatten(start_dim=1)
        return [features, self.encoder_mu(flat), self.encoder_var(flat)]
class LadderBlock(nn.Module):
    """
    Top-down ladder step: maps a sample of the rung above to the mean and
    log-variance of the Gaussian over the rung below.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int):
        """
        :param in_channels: (Int) Dimension of the incoming latent sample
        :param latent_dim: (Int) Dimension of the latent being predicted
        """
        super(LadderBlock, self).__init__()

        # Shared projection + batch normalisation feeding both heads
        self.decode = nn.Sequential(
            nn.Linear(in_channels, latent_dim),
            nn.BatchNorm1d(latent_dim),
        )
        self.fc_mu = nn.Linear(latent_dim, latent_dim)
        self.fc_var = nn.Linear(latent_dim, latent_dim)

    def forward(self, z: Tensor) -> Tensor:
        """
        :param z: (Tensor) Latent sample of the rung above [B x in_channels]
        :return: (List[Tensor]) [mu, log_var], each [B x latent_dim]
        """
        hidden = self.decode(z)
        return [self.fc_mu(hidden), self.fc_var(hidden)]
class LVAE(BaseVAE):
def __init__(self,
in_channels: int,
latent_dims: List,
hidden_dims: List,
**kwargs) -> None:
super(LVAE, self).__init__()
self.latent_dims = latent_dims
self.hidden_dims = hidden_dims
self.num_rungs = len(latent_dims)
assert len(latent_dims) == len(hidden_dims), "Length of the latent" \
"and hidden dims must be the same"
# Build Encoder
modules = []
img_size = 64
for i, h_dim in enumerate(hidden_dims):
modules.append(EncoderBlock(in_channels,
h_dim,
latent_dims[i],
img_size))
img_size = conv_out_shape(img_size)
in_channels = h_dim
self.encoders = nn.Sequential(*modules)
# ====================================================================== #
# Build Decoder
modules = []
for i in range(self.num_rungs -1, 0, -1):
modules.append(LadderBlock(latent_dims[i],
latent_dims[i-1]))
self.ladders = nn.Sequential(*modules)
self.decoder_input = nn.Linear(latent_dims[0], hidden_dims[-1] * 4)
hidden_dims.reverse()
modules = []
for i in range(len(hidden_dims) - 1):
modules.append(
nn.Sequential(
nn.ConvTranspose2d(hidden_dims[i],
hidden_dims[i + 1],
kernel_size=3,
stride = 2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[i + 1]),
nn.LeakyReLU())
)
self.decoder = nn.Sequential(*modules)
self.final_layer = nn.Sequential(
nn.ConvTranspose2d(hidden_dims[-1],
hidden_dims[-1],
kernel_size=3,
stride=2,
padding=1,
output_padding=1),
nn.BatchNorm2d(hidden_dims[-1]),
nn.LeakyReLU(),
nn.Conv2d(hidden_dims[-1], out_channels= 3,
kernel_size= 3, padding= 1),
nn.Tanh())
hidden_dims.reverse()
def encode(self, input: Tensor) -> List[Tensor]:
"""
Encodes the input by passing through the encoder network
and returns the latent codes.
:param input: (Tensor) Input tensor to encoder [N x C x H x W]
:return: (Tensor) List of latent codes
"""
h = input
# Posterior Parameters
post_params = []
for encoder_block in self.encoders:
h, mu, log_var = encoder_block(h)
post_params.append((mu, log_var))
return post_params
def decode(self, z: Tensor, post_params: List) -> Tuple:
"""
Maps the given latent codes
onto the image space.
:param z: (Tensor) [B x D]
:return: (Tensor) [B x C x H x W]
"""
kl_div = 0
post_params.reverse()
for i, ladder_block in enumerate(self.ladders):
mu_e, log_var_e = post_params[i]
mu_t, log_var_t = ladder_block(z)
mu, log_var = self.merge_gauss(mu_e, mu_t,
log_var_e, log_var_t)
z = self.reparameterize(mu, log_var)
kl_div += self.compute_kl_divergence(z, (mu, log_var), (mu_e, log_var_e))
result = self.decoder_input(z)
result = result.view(-1, self.hidden_dims[-1], 2, 2)
result = self.decoder(result)
return self.final_layer(result), kl_div
def merge_gauss(self,
mu_1: Tensor,
mu_2: Tensor,
log_var_1: Tensor,
log_var_2: Tensor) -> List:
p_1 = 1. / (log_var_1.exp() + 1e-7)
p_2 = 1. / (log_var_2.exp() + 1e-7)
mu = (mu_1 * p_1 + mu_2 * p_2)/(p_1 + p_2)
log_var = torch.log(1./(p_1 + p_2))
return [mu, log_var]
def compute_kl_divergence(self, z: Tensor, q_params: Tuple, p_params: Tuple):
mu_q, log_var_q = q_params
mu_p, log_var_p = p_params
#
# qz = -0.5 * torch.sum(1 + log_var_q + (z - mu_q) ** 2 / (2 * log_var_q.exp() + 1e-8), dim=1)
# pz = -0.5 * torch.sum(1 + log_var_p + (z - mu_p) ** 2 / (2 * log_var_p.exp() + 1e-8), dim=1)
kl = (log_var_p - log_var_q) + (log_var_q.exp() + (mu_q - mu_p)**2)/(2 * log_var_p.exp()) - 0.5
kl = torch.sum(kl, dim = -1)
return kl
def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
"""
Reparameterization trick to sample from N(mu, var) from
N(0,1).
:param mu: (Tensor) Mean of the latent Gaussian [B x D]
:param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
:return: (Tensor) [B x D]
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
    """
    Full pass: bottom-up encoding, sampling of the top-most latent, then the
    top-down ``decode`` which also returns the accumulated KL term.

    Only the KL returned by ``decode`` is passed on; no KL term for the
    top-most latent itself is added here.

    :param input: (Tensor) [B x C x H x W]
    :return: (List) ``[reconstruction, input, kl_div]``
    """
    posteriors = self.encode(input)
    # The last bottom-up entry parameterizes the top-most latent layer.
    top_mu, top_log_var = posteriors.pop()
    top_z = self.reparameterize(top_mu, top_log_var)
    reconstruction, kl_div = self.decode(top_z, posteriors)
    return [reconstruction, input, kl_div]
def loss_function(self,
                  *args,
                  **kwargs) -> dict:
    """
    Computes the loss: MSE reconstruction plus the batch-averaged KL term
    scaled by ``M_N``.

    :param args: ``(recons, input, kl_div)`` as returned by ``forward``
    :param kwargs: must contain ``M_N``, the KL weight (minibatch scaling)
    :return: (dict) ``loss``, ``Reconstruction_Loss`` and ``KLD`` (KLD is
        reported with a negated sign)
    """
    recons = args[0]
    input = args[1]
    kl_div = args[2]
    kld_weight = kwargs['M_N']  # Accounts for the minibatch samples from the dataset
    recons_loss = F.mse_loss(recons, input)
    kld_loss = torch.mean(kl_div, dim=0)
    total = recons_loss + kld_weight * kld_loss
    return {'loss': total,
            'Reconstruction_Loss': recons_loss,
            'KLD': -kld_loss}
def sample(self,
           num_samples: int,
           current_device: int, **kwargs) -> Tensor:
    """
    Generates images by drawing the top-most latent from N(0, I), passing it
    through every ladder block (sampling at each level), and decoding.

    :param num_samples: (Int) Number of samples
    :param current_device: (Int) Device to run the model
    :return: (Tensor) Decoded samples
    """
    latent = torch.randn(num_samples, self.latent_dims[-1]).to(current_device)
    for ladder_block in self.ladders:
        level_mu, level_log_var = ladder_block(latent)
        latent = self.reparameterize(level_mu, level_log_var)
    features = self.decoder_input(latent).view(-1, self.hidden_dims[-1], 2, 2)
    return self.final_layer(self.decoder(features))
def generate(self, x: Tensor, **kwargs) -> Tensor:
    """
    Reconstructs ``x`` with a full forward pass and returns only the
    reconstruction.

    :param x: (Tensor) [B x C x H x W]
    :return: (Tensor) [B x C x H x W]
    """
    outputs = self.forward(x)
    return outputs[0]
================================================
FILE: models/miwae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
from torch.distributions import Normal
class MIWAE(BaseVAE):
    """
    Multiply Importance Weighted Autoencoder.

    Each input is encoded once. The resulting diagonal Gaussian is sampled
    ``num_estimates`` (M) x ``num_samples`` (S) times, and the S samples of
    each estimate are combined with normalized importance weights in the loss.

    The encoder flattens a ``hidden_dims[-1] x 2 x 2`` feature map, so inputs
    must shrink to 2x2 after ``len(hidden_dims)`` stride-2 convolutions
    (e.g. 64x64 inputs with the default five layers).
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 num_samples: int = 5,
                 num_estimates: int = 5,
                 **kwargs) -> None:
        super(MIWAE, self).__init__()

        self.latent_dim = latent_dim
        self.num_samples = num_samples  # S: samples per estimate
        self.num_estimates = num_estimates  # M: estimates per input
        # Reconstructions must have as many channels as the input images.
        image_channels = in_channels

        # Copy the list: it is reversed below and must not change for the caller.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)
        # Channel count of the 2x2 bottleneck map that decode() reshapes to.
        self.last_hidden_dim = hidden_dims[-1]

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
        hidden_dims.reverse()
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )
        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=image_channels,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input and returns the parameters of the latent Gaussian.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List) ``[mu, log_var]``, each [N x D]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps latent codes of M estimates x S samples onto the image space.
        :param z: (Tensor) [B x M x S x D]
        :return: (Tensor) [B x M x S x C x H x W]
        """
        B, M, S, _ = z.size()
        z = z.contiguous().view(-1, self.latent_dim)  # [BMS x D]
        result = self.decoder_input(z)
        result = result.view(-1, self.last_hidden_dim, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)  # [BMS x C x H x W]
        return result.view(B, M, S, *result.shape[-3:])

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: mu + eps * exp(0.5 * logvar), eps ~ N(0, I).
        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log-variance of the latent Gaussian
        :return: (Tensor) Sample shaped like ``mu``
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        :param input: (Tensor) [B x C x H x W]
        :return: (List) ``[recons, input, mu, log_var, z, eps]`` where recons
            is [B x M x S x C x H x W] and mu/log_var/z/eps are [B x M x S x D]
        """
        mu, log_var = self.encode(input)
        mu = mu.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3)  # [B x M x S x D]
        log_var = log_var.repeat(self.num_estimates, self.num_samples, 1, 1).permute(2, 0, 1, 3)  # [B x M x S x D]
        z = self.reparameterize(mu, log_var)  # [B x M x S x D]
        # Recover the standard-normal noise: z = mu + eps * std, std = exp(0.5 * log_var).
        eps = (z - mu) / torch.exp(0.5 * log_var)
        return [self.decode(z), input, mu, log_var, z, eps]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Importance-weighted loss averaged over the M estimates and the batch.

        Per sample, the negative log importance weight is approximated by the
        mean squared reconstruction error plus ``M_N`` times the analytic KL
        to N(0, I). The S samples of each estimate are combined with
        normalized importance weights ``softmax(-neg_log_weight)``.

        :param args: output of ``forward``
        :param kwargs: must contain ``M_N``, the KL weight
        :return: (dict) ``loss``, ``Reconstruction_Loss`` and ``KLD``
            (KLD reported negated)
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        # args[4] (z) and args[5] (eps) are returned by forward() but not used here.

        input = input.repeat(self.num_estimates,
                             self.num_samples, 1, 1, 1, 1).permute(2, 0, 1, 3, 4, 5)  # [B x M x S x C x H x W]
        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

        recons_nll = ((recons - input) ** 2).flatten(3).mean(-1)  # [B x M x S]
        # mu/log_var are repeated per sample, so this term is identical across S.
        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=3)  # [B x M x S]
        neg_log_weight = recons_nll + kld_weight * kld_loss  # [B x M x S]
        # Normalized importance weights are softmax(log w) = softmax(-neg_log_weight),
        # so lower-loss samples get more weight. Detaching them makes the surrogate's
        # gradient sum_s w_s * grad(neg_log_weight_s), the importance-weighted gradient.
        weight = F.softmax(-neg_log_weight, dim=-1).detach()  # [B x M x S]
        loss = torch.mean(torch.mean(torch.sum(weight * neg_log_weight, dim=-1), dim=-2), dim=0)
        return {'loss': loss, 'Reconstruction_Loss': recons_nll.mean(), 'KLD': -kld_loss.mean()}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the N(0, I) prior and returns the decoded images.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor) [num_samples x C x H x W]
        """
        z = torch.randn(num_samples, 1, 1,
                        self.latent_dim)
        z = z.to(current_device)
        # Drop only the singleton M and S axes. squeeze() would also drop the
        # batch (or channel) axis when it has size 1.
        return self.decode(z)[:, 0, 0]

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the first reconstructed sample.
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0][:, 0, 0, :]
================================================
FILE: models/mssim_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
from math import exp
class MSSIMVAE(BaseVAE):
    """
    Convolutional VAE whose reconstruction term is an MS-SSIM loss
    (``self.mssim_loss``) instead of a pixel-wise error.

    The encoder flattens a ``hidden_dims[-1] x 2 x 2`` feature map, so inputs
    must shrink to 2x2 after ``len(hidden_dims)`` stride-2 convolutions
    (e.g. 64x64 inputs with the default five layers).
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 window_size: int = 11,
                 size_average: bool = True,
                 **kwargs) -> None:
        super(MSSIMVAE, self).__init__()

        self.latent_dim = latent_dim
        self.in_channels = in_channels

        # Copy the list: it is reversed below and must not change for the caller.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)
        # Channel count of the 2x2 bottleneck map that decode() reshapes to.
        self.last_hidden_dim = hidden_dims[-1]

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
        hidden_dims.reverse()
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )
        self.decoder = nn.Sequential(*modules)

        # Output as many channels as the input: the MS-SSIM window below is
        # built for self.in_channels and is applied to both images.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=self.in_channels,
                      kernel_size=3, padding=1),
            nn.Tanh())

        self.mssim_loss = MSSIM(self.in_channels,
                                window_size,
                                size_average)

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input and returns the parameters of the latent Gaussian.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List) ``[mu, log_var]``, each [N x D]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, self.last_hidden_dim, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: mu + eps * exp(0.5 * logvar), eps ~ N(0, I).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args: Any,
                      **kwargs) -> dict:
        """
        MS-SSIM reconstruction loss plus the ``M_N``-weighted, batch-averaged
        KL divergence to N(0, I).
        :param args: ``(recons, input, mu, log_var)`` as returned by ``forward``
        :param kwargs: must contain ``M_N``, the KL weight
        :return: (dict) ``loss``, ``Reconstruction_Loss`` and ``KLD`` (KLD reported negated)
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = self.mssim_loss(recons, input)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the N(0, I) prior and returns the decoded images.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)
        # .to() works for CPU and GPU devices alike (.cuda() fails on CPU-only hosts).
        z = z.to(current_device)
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]
class MSSIM(nn.Module):

    def __init__(self,
                 in_channels: int = 3,
                 window_size: int = 11,
                 size_average: bool = True) -> None:
        """
        Computes the differentiable MS-SSIM loss, returned as 1 - MS-SSIM.
        Reference:
        [1] https://github.com/jorge-pessoa/pytorch-msssim/blob/dev/pytorch_msssim/__init__.py
            (MIT License)

        :param in_channels: (Int) Image channels; the Gaussian window is applied
            to each channel separately (grouped convolution)
        :param window_size: (Int) Side length of the Gaussian window
        :param size_average: (Bool) If True, SSIM is averaged over the whole
            batch; otherwise one value per image is produced
        """
        super(MSSIM, self).__init__()
        self.in_channels = in_channels
        self.window_size = window_size
        self.size_average = size_average

    def gaussian_window(self, window_size: int, sigma: float) -> Tensor:
        """
        1-D Gaussian kernel of length ``window_size``, centred at
        ``window_size // 2`` and normalized to sum to 1.
        """
        center = window_size // 2
        # The exponent must be negative for a bell-shaped (center-peaked) kernel.
        kernel = torch.tensor([exp(-(x - center) ** 2 / (2 * sigma ** 2))
                               for x in range(window_size)])
        return kernel / kernel.sum()

    def create_window(self, window_size, in_channels):
        """
        Separable 2-D Gaussian window (sigma = 1.5), replicated per channel:
        shape [in_channels x 1 x window_size x window_size].
        """
        _1D_window = self.gaussian_window(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(in_channels, 1, window_size, window_size).contiguous()
        return window

    def ssim(self,
             img1: Tensor,
             img2: Tensor,
             window_size: int,
             in_channel: int,
             size_average: bool) -> Tensor:
        """
        Single-scale SSIM.
        :return: ``(ssim, cs)``. ``ssim`` is a scalar when ``size_average``,
            otherwise one value per image. ``cs`` is the mean
            contrast-structure term (always a scalar).
        """
        device = img1.device
        window = self.create_window(window_size, in_channel).to(device)
        mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=in_channel)
        mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=in_channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=in_channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=in_channel) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=in_channel) - mu1_mu2

        img_range = 1.0  # img1.max() - img1.min() # Dynamic range
        C1 = (0.01 * img_range) ** 2
        C2 = (0.03 * img_range) ** 2

        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # contrast sensitivity

        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

        if size_average:
            ret = ssim_map.mean()
        else:
            ret = ssim_map.mean(1).mean(1).mean(1)
        return ret, cs

    def forward(self, img1: Tensor, img2: Tensor) -> Tensor:
        """
        :return: (Tensor) 1 - MS-SSIM over five scales (each scale halves the
            images with 2x2 average pooling). The result is 0 for identical images.
        """
        device = img1.device
        weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], device=device)
        levels = weights.size(0)
        mssim = []
        mcs = []
        for _ in range(levels):
            sim, cs = self.ssim(img1, img2,
                                self.window_size,
                                self.in_channels,
                                self.size_average)
            mssim.append(sim)
            mcs.append(cs)

            img1 = F.avg_pool2d(img1, (2, 2))
            img2 = F.avg_pool2d(img2, (2, 2))

        mssim = torch.stack(mssim)  # [levels] or [levels x B]
        mcs = torch.stack(mcs)  # [levels]

        # # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
        # if normalize:
        #     mssim = (mssim + 1) / 2
        #     mcs = (mcs + 1) / 2

        # Align the per-scale weights with the leading (scale) axis so that
        # per-image SSIM values (size_average=False) broadcast correctly.
        ssim_weights = weights.view(-1, *([1] * (mssim.dim() - 1)))
        pow1 = mcs ** weights
        pow2 = mssim ** ssim_weights
        # MS-SSIM = prod_{j < M} cs_j^{w_j} * ssim_M^{w_M}: the full SSIM term
        # enters only once, at the coarsest scale.
        output = torch.prod(pow1[:-1]) * pow2[-1]
        return 1 - output
================================================
FILE: models/swae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from torch import distributions as dist
from .types_ import *
class SWAE(BaseVAE):
    """
    Sliced-Wasserstein autoencoder: a deterministic autoencoder whose latent
    codes are regularized towards N(0, I) with a sliced Wasserstein distance
    computed on random 1-D projections.

    The encoder flattens a ``hidden_dims[-1] x 2 x 2`` feature map, so inputs
    must shrink to 2x2 after ``len(hidden_dims)`` stride-2 convolutions
    (e.g. 64x64 inputs with the default five layers).
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 reg_weight: int = 100,
                 wasserstein_deg: float = 2.,
                 num_projections: int = 50,
                 projection_dist: str = 'normal',
                 **kwargs) -> None:
        super(SWAE, self).__init__()

        self.latent_dim = latent_dim
        self.reg_weight = reg_weight
        self.p = wasserstein_deg
        self.num_projections = num_projections
        self.proj_dist = projection_dist
        # Reconstructions must have as many channels as the input images.
        image_channels = in_channels

        # Copy the list: it is reversed below and must not change for the caller.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)
        # Channel count of the 2x2 bottleneck map that decode() reshapes to.
        self.last_hidden_dim = hidden_dims[-1]

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_z = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
        hidden_dims.reverse()
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )
        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=image_channels,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> Tensor:
        """
        Encodes the input into a deterministic latent code.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) Latent codes [N x D]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        z = self.fc_z(result)
        return z

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps latent codes [B x D] onto the image space [B x C x H x W].
        """
        result = self.decoder_input(z)
        result = result.view(-1, self.last_hidden_dim, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        z = self.encode(input)
        return [self.decode(z), input, z]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        MSE + L1 reconstruction plus the sliced Wasserstein regularizer.
        :param args: ``(recons, input, z)`` as returned by ``forward``
        :return: (dict) ``loss``, ``Reconstruction_Loss`` and ``SWD``
        """
        recons = args[0]
        input = args[1]
        z = args[2]

        batch_size = input.size(0)
        # Scale the regularizer by N(N-1). Clamp at 1 so a batch of a single
        # sample does not divide by zero.
        bias_corr = max(batch_size * (batch_size - 1), 1)
        reg_weight = self.reg_weight / bias_corr

        recons_loss_l2 = F.mse_loss(recons, input)
        recons_loss_l1 = F.l1_loss(recons, input)

        swd_loss = self.compute_swd(z, self.p, reg_weight)

        loss = recons_loss_l2 + recons_loss_l1 + swd_loss
        return {'loss': loss, 'Reconstruction_Loss': (recons_loss_l2 + recons_loss_l1), 'SWD': swd_loss}

    def get_random_projections(self, latent_dim: int, num_samples: int) -> Tensor:
        """
        Returns random unit-norm projection directions, drawn from a normal or
        Cauchy distribution (per ``self.proj_dist``) and normalized per row.

        :param latent_dim: (Int) Dimensionality of the latent space (D)
        :param num_samples: (Int) Number of samples required (S)
        :return: Random projections [S x D]
        :raises ValueError: for an unknown projection distribution
        """
        if self.proj_dist == 'normal':
            rand_samples = torch.randn(num_samples, latent_dim)
        elif self.proj_dist == 'cauchy':
            rand_samples = dist.Cauchy(torch.tensor([0.0]),
                                       torch.tensor([1.0])).sample((num_samples, latent_dim)).squeeze()
        else:
            raise ValueError('Unknown projection distribution.')

        rand_proj = rand_samples / rand_samples.norm(dim=1).view(-1, 1)
        return rand_proj  # [S x D]

    def compute_swd(self,
                    z: Tensor,
                    p: float,
                    reg_weight: float) -> Tensor:
        """
        Sliced Wasserstein Distance (SWD): randomly projects the encoded and
        prior samples to 1-D and averages |sorted difference|^p over all
        projections.

        :param z: Latent samples [N x D]
        :param p: Value for the p^th Wasserstein distance
        :param reg_weight: Multiplier applied to the mean distance
        :return: (Tensor) Scalar regularization term
        """
        prior_z = torch.randn_like(z)  # [N x D]
        device = z.device

        proj_matrix = self.get_random_projections(self.latent_dim,
                                                  num_samples=self.num_projections).transpose(0, 1).to(device)

        latent_projections = z.matmul(proj_matrix)  # [N x S]
        prior_projections = prior_z.matmul(proj_matrix)  # [N x S]

        # In 1-D the optimal transport plan matches sorted samples, so the
        # distance per projection is computed on sorted values.
        w_dist = torch.sort(latent_projections.t(), dim=1)[0] - \
                 torch.sort(prior_projections.t(), dim=1)[0]
        # |x - y|^p. The absolute value keeps odd p (e.g. p = 1) non-negative.
        w_dist = w_dist.abs().pow(p)
        return reg_weight * w_dist.mean()

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the N(0, I) prior and returns the decoded images.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)
        z = z.to(current_device)
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]
================================================
FILE: models/twostage_vae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
class TwoStageVAE(BaseVAE):
    """
    Convolutional VAE with an additional fully-connected second-stage VAE
    built over the latent space.

    The encoder flattens a ``hidden_dims[-1] x 2 x 2`` feature map, so inputs
    must shrink to 2x2 after ``len(hidden_dims)`` stride-2 convolutions
    (e.g. 64x64 inputs with the default five layers).

    NOTE(review): the second-stage modules (encoder2, fc_mu2, fc_var2,
    decoder2) are built but not used by forward(), loss_function(), sample()
    or generate(). decoder2 also ends at ``hidden_dims2[0]`` features with no
    projection back to ``latent_dim``. Confirm the intended design.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 hidden_dims2: List = None,
                 **kwargs) -> None:
        super(TwoStageVAE, self).__init__()

        self.latent_dim = latent_dim
        # Reconstructions must have as many channels as the input images.
        image_channels = in_channels

        # Copy both lists: they are reversed below and must not change for the caller.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)
        if hidden_dims2 is None:
            hidden_dims2 = [1024, 1024]
        else:
            hidden_dims2 = list(hidden_dims2)
        # Channel count of the 2x2 bottleneck map that decode() reshapes to.
        self.last_hidden_dim = hidden_dims[-1]

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
        hidden_dims.reverse()
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )
        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=image_channels,
                      kernel_size=3, padding=1),
            nn.Tanh())

        #---------------------- Second VAE ---------------------------#
        encoder2 = []
        in_channels = self.latent_dim
        for h_dim in hidden_dims2:
            encoder2.append(nn.Sequential(
                nn.Linear(in_channels, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.LeakyReLU()))
            in_channels = h_dim
        self.encoder2 = nn.Sequential(*encoder2)
        self.fc_mu2 = nn.Linear(hidden_dims2[-1], self.latent_dim)
        self.fc_var2 = nn.Linear(hidden_dims2[-1], self.latent_dim)

        decoder2 = []
        hidden_dims2.reverse()

        in_channels = self.latent_dim
        for h_dim in hidden_dims2:
            decoder2.append(nn.Sequential(
                nn.Linear(in_channels, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.LeakyReLU()))
            in_channels = h_dim
        self.decoder2 = nn.Sequential(*decoder2)

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input and returns the parameters of the latent Gaussian.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List) ``[mu, log_var]``, each [N x D]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, self.last_hidden_dim, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: mu + eps * exp(0.5 * logvar), eps ~ N(0, I).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        MSE reconstruction plus the ``M_N``-weighted, batch-averaged KL
        divergence to N(0, I).
        :param args: ``(recons, input, mu, log_var)`` as returned by ``forward``
        :param kwargs: must contain ``M_N``, the KL weight
        :return: (dict) ``loss``, ``Reconstruction_Loss`` and ``KLD`` (KLD reported negated)
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': -kld_loss}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the N(0, I) prior and returns the decoded images.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)
        z = z.to(current_device)
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]
================================================
FILE: models/types_.py
================================================
from typing import List, Callable, Union, Any, TypeVar, Tuple
# from torch import tensor as Tensor
# Annotation-only alias used by the model modules (a typing.TypeVar, not the
# real torch.Tensor class). PEP 484 requires the TypeVar's name argument to
# match the variable it is assigned to, so type checkers accept it.
Tensor = TypeVar('Tensor')
================================================
FILE: models/vampvae.py
================================================
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from .types_ import *
class VampVAE(BaseVAE):
def __init__(self,
in_channels: int,
latent_dim: int,
hidden_dims: List = None,
num_components: int = 50,
**kwargs) -> None:
super(VampVAE, self).__init__()
self.latent_dim = latent_dim
self.num_components = num_components
modules = []
if hidden_dims is None:
hidden_dims = [32, 64, 128, 256, 512]
# Build Encoder
for h_
gitextract_vk1fb46d/ ├── .gitignore ├── .idea/ │ ├── .gitignore │ ├── PyTorch-VAE.iml │ ├── inspectionProfiles/ │ │ └── profiles_settings.xml │ ├── misc.xml │ ├── modules.xml │ └── vcs.xml ├── LICENSE.md ├── README.md ├── configs/ │ ├── bbvae.yaml │ ├── betatc_vae.yaml │ ├── bhvae.yaml │ ├── cat_vae.yaml │ ├── cvae.yaml │ ├── dfc_vae.yaml │ ├── dip_vae.yaml │ ├── factorvae.yaml │ ├── gammavae.yaml │ ├── hvae.yaml │ ├── infovae.yaml │ ├── iwae.yaml │ ├── joint_vae.yaml │ ├── logcosh_vae.yaml │ ├── lvae.yaml │ ├── miwae.yaml │ ├── mssim_vae.yaml │ ├── swae.yaml │ ├── vae.yaml │ ├── vampvae.yaml │ ├── vq_vae.yaml │ ├── wae_mmd_imq.yaml │ └── wae_mmd_rbf.yaml ├── dataset.py ├── experiment.py ├── models/ │ ├── __init__.py │ ├── base.py │ ├── beta_vae.py │ ├── betatc_vae.py │ ├── cat_vae.py │ ├── cvae.py │ ├── dfcvae.py │ ├── dip_vae.py │ ├── fvae.py │ ├── gamma_vae.py │ ├── hvae.py │ ├── info_vae.py │ ├── iwae.py │ ├── joint_vae.py │ ├── logcosh_vae.py │ ├── lvae.py │ ├── miwae.py │ ├── mssim_vae.py │ ├── swae.py │ ├── twostage_vae.py │ ├── types_.py │ ├── vampvae.py │ ├── vanilla_vae.py │ ├── vq_vae.py │ └── wae_mmd.py ├── requirements.txt ├── run.py ├── tests/ │ ├── bvae.py │ ├── test_betatcvae.py │ ├── test_cat_vae.py │ ├── test_dfc.py │ ├── test_dipvae.py │ ├── test_fvae.py │ ├── test_gvae.py │ ├── test_hvae.py │ ├── test_iwae.py │ ├── test_joint_Vae.py │ ├── test_logcosh.py │ ├── test_lvae.py │ ├── test_miwae.py │ ├── test_mssimvae.py │ ├── test_swae.py │ ├── test_vae.py │ ├── test_vq_vae.py │ ├── test_wae.py │ ├── text_cvae.py │ └── text_vamp.py └── utils.py
SYMBOL INDEX (386 symbols across 46 files)
FILE: dataset.py
class MyDataset (line 15) | class MyDataset(Dataset):
method __init__ (line 16) | def __init__(self):
method __len__ (line 20) | def __len__(self):
method __getitem__ (line 23) | def __getitem__(self, idx):
class MyCelebA (line 27) | class MyCelebA(CelebA):
method _check_integrity (line 35) | def _check_integrity(self) -> bool:
class OxfordPets (line 40) | class OxfordPets(Dataset):
method __init__ (line 44) | def __init__(self,
method __len__ (line 55) | def __len__(self):
method __getitem__ (line 58) | def __getitem__(self, idx):
class VAEDataset (line 66) | class VAEDataset(LightningDataModule):
method __init__ (line 81) | def __init__(
method setup (line 100) | def setup(self, stage: Optional[str] = None) -> None:
method train_dataloader (line 155) | def train_dataloader(self) -> DataLoader:
method val_dataloader (line 164) | def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
method test_dataloader (line 173) | def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
FILE: experiment.py
class VAEXperiment (line 15) | class VAEXperiment(pl.LightningModule):
method __init__ (line 17) | def __init__(self,
method forward (line 31) | def forward(self, input: Tensor, **kwargs) -> Tensor:
method training_step (line 34) | def training_step(self, batch, batch_idx, optimizer_idx = 0):
method validation_step (line 48) | def validation_step(self, batch, batch_idx, optimizer_idx = 0):
method on_validation_end (line 61) | def on_validation_end(self) -> None:
method sample_images (line 64) | def sample_images(self):
method configure_optimizers (line 92) | def configure_optimizers(self):
FILE: models/base.py
class BaseVAE (line 5) | class BaseVAE(nn.Module):
method __init__ (line 7) | def __init__(self) -> None:
method encode (line 10) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 13) | def decode(self, input: Tensor) -> Any:
method sample (line 16) | def sample(self, batch_size:int, current_device: int, **kwargs) -> Ten...
method generate (line 19) | def generate(self, x: Tensor, **kwargs) -> Tensor:
method forward (line 23) | def forward(self, *inputs: Tensor) -> Tensor:
method loss_function (line 27) | def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
FILE: models/beta_vae.py
class BetaVAE (line 8) | class BetaVAE(BaseVAE):
method __init__ (line 12) | def __init__(self,
method encode (line 88) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 105) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 112) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 124) | def forward(self, input: Tensor, **kwargs) -> Tensor:
method loss_function (line 129) | def loss_function(self,
method sample (line 154) | def sample(self,
method generate (line 172) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/betatc_vae.py
class BetaTCVAE (line 9) | class BetaTCVAE(BaseVAE):
method __init__ (line 12) | def __init__(self,
method encode (line 84) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 102) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 115) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 127) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method log_density_gaussian (line 132) | def log_density_gaussian(self, x: Tensor, mu: Tensor, logvar: Tensor):
method loss_function (line 144) | def loss_function(self,
method sample (line 213) | def sample(self,
method generate (line 231) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/cat_vae.py
class CategoricalVAE (line 9) | class CategoricalVAE(BaseVAE):
method __init__ (line 11) | def __init__(self,
method encode (line 89) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 105) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 118) | def reparameterize(self, z: Tensor, eps:float = 1e-7) -> Tensor:
method forward (line 134) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 139) | def loss_function(self,
method sample (line 179) | def sample(self,
method generate (line 202) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/cvae.py
class ConditionalVAE (line 8) | class ConditionalVAE(BaseVAE):
method __init__ (line 10) | def __init__(self,
method encode (line 83) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 100) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 107) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 119) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 133) | def loss_function(self,
method sample (line 149) | def sample(self,
method generate (line 170) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/dfcvae.py
class DFCVAE (line 9) | class DFCVAE(BaseVAE):
method __init__ (line 11) | def __init__(self,
method encode (line 90) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 107) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 120) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 132) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method extract_features (line 142) | def extract_features(self,
method loss_function (line 163) | def loss_function(self,
method sample (line 192) | def sample(self,
method generate (line 210) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/dip_vae.py
class DIPVAE (line 8) | class DIPVAE(BaseVAE):
method __init__ (line 10) | def __init__(self,
method encode (line 78) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 95) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 108) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 120) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 125) | def loss_function(self,
method sample (line 166) | def sample(self,
method generate (line 184) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/fvae.py
class FactorVAE (line 8) | class FactorVAE(BaseVAE):
method __init__ (line 10) | def __init__(self,
method encode (line 92) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 109) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 122) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 134) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method permute_latent (line 139) | def permute_latent(self, z: Tensor) -> Tensor:
method loss_function (line 151) | def loss_function(self,
method sample (line 203) | def sample(self,
method generate (line 221) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/gamma_vae.py
class GammaVAE (line 10) | class GammaVAE(BaseVAE):
method __init__ (line 12) | def __init__(self,
method weight_init (line 85) | def weight_init(self):
method encode (line 92) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 109) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 116) | def reparameterize(self, alpha: Tensor, beta: Tensor) -> Tensor:
method h_func (line 137) | def h_func(self, alpha: Tensor, eps: Tensor) -> Tensor:
method inv_h_func (line 148) | def inv_h_func(self, alpha: Tensor, z: Tensor) -> Tensor:
method forward (line 158) | def forward(self, input: Tensor, **kwargs) -> Tensor:
method I_function (line 168) | def I_function(self, a, b, c, d):
method vae_gamma_kl_loss (line 171) | def vae_gamma_kl_loss(self, a, b, c, d):
method loss_function (line 184) | def loss_function(self,
method sample (line 214) | def sample(self,
method generate (line 230) | def generate(self, x: Tensor, **kwargs) -> Tensor:
function init_ (line 239) | def init_(m):
FILE: models/hvae.py
class HVAE (line 8) | class HVAE(BaseVAE):
method __init__ (line 10) | def __init__(self,
method encode_z2 (line 115) | def encode_z2(self, input: Tensor) -> List[Tensor]:
method encode_z1 (line 132) | def encode_z1(self, input: Tensor, z2: Tensor) -> List[Tensor]:
method encode (line 145) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 153) | def decode(self, input: Tensor) -> Tensor:
method reparameterize (line 158) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 170) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 192) | def loss_function(self,
method sample (line 233) | def sample(self, batch_size:int, current_device: int, **kwargs) -> Ten...
method generate (line 252) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/info_vae.py
class InfoVAE (line 8) | class InfoVAE(BaseVAE):
method __init__ (line 10) | def __init__(self,
method encode (line 88) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 104) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 111) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 123) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 128) | def loss_function(self,
method compute_kernel (line 150) | def compute_kernel(self,
method compute_rbf (line 178) | def compute_rbf(self,
method compute_inv_mult_quad (line 195) | def compute_inv_mult_quad(self,
method compute_mmd (line 218) | def compute_mmd(self, z: Tensor) -> Tensor:
method sample (line 231) | def sample(self,
method generate (line 249) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/iwae.py
class IWAE (line 8) | class IWAE(BaseVAE):
method __init__ (line 10) | def __init__(self,
method encode (line 78) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 95) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 111) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 121) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 129) | def loss_function(self,
method sample (line 162) | def sample(self,
method generate (line 180) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/joint_vae.py
class JointVAE (line 9) | class JointVAE(BaseVAE):
method __init__ (line 12) | def __init__(self,
method encode (line 111) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 129) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 142) | def reparameterize(self,
method forward (line 170) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 175) | def loss_function(self,
method sample (line 236) | def sample(self,
method generate (line 261) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/logcosh_vae.py
class LogCoshVAE (line 8) | class LogCoshVAE(BaseVAE):
method __init__ (line 10) | def __init__(self,
method encode (line 78) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 95) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 108) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 120) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 125) | def loss_function(self,
method sample (line 157) | def sample(self,
method generate (line 175) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/lvae.py
function conv_out_shape (line 9) | def conv_out_shape(img_size):
class EncoderBlock (line 12) | class EncoderBlock(nn.Module):
method __init__ (line 13) | def __init__(self,
method forward (line 32) | def forward(self, input: Tensor) -> Tensor:
class LadderBlock (line 43) | class LadderBlock(nn.Module):
method __init__ (line 44) | def __init__(self,
method forward (line 55) | def forward(self, z: Tensor) -> Tensor:
class LVAE (line 62) | class LVAE(BaseVAE):
method __init__ (line 64) | def __init__(self,
method encode (line 134) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 151) | def decode(self, z: Tensor, post_params: List) -> Tuple:
method merge_gauss (line 173) | def merge_gauss(self,
method compute_kl_divergence (line 186) | def compute_kl_divergence(self, z: Tensor, q_params: Tuple, p_params: ...
method reparameterize (line 197) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 209) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 218) | def loss_function(self,
method sample (line 239) | def sample(self,
method generate (line 264) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/miwae.py
class MIWAE (line 9) | class MIWAE(BaseVAE):
method __init__ (line 11) | def __init__(self,
method encode (line 81) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 98) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 114) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 124) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 132) | def loss_function(self,
method sample (line 166) | def sample(self,
method generate (line 184) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/mssim_vae.py
class MSSIMVAE (line 9) | class MSSIMVAE(BaseVAE):
method __init__ (line 11) | def __init__(self,
method encode (line 84) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 101) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 114) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 126) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 131) | def loss_function(self,
method sample (line 155) | def sample(self,
method generate (line 173) | def generate(self, x: Tensor, **kwargs) -> Tensor:
class MSSIM (line 182) | class MSSIM(nn.Module):
method __init__ (line 184) | def __init__(self,
method gaussian_window (line 203) | def gaussian_window(self, window_size:int, sigma: float) -> Tensor:
method create_window (line 208) | def create_window(self, window_size, in_channels):
method ssim (line 214) | def ssim(self,
method forward (line 250) | def forward(self, img1: Tensor, img2: Tensor) -> Tensor:
FILE: models/swae.py
class SWAE (line 9) | class SWAE(BaseVAE):
method __init__ (line 11) | def __init__(self,
method encode (line 84) | def encode(self, input: Tensor) -> Tensor:
method decode (line 99) | def decode(self, z: Tensor) -> Tensor:
method forward (line 106) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 110) | def loss_function(self,
method get_random_projections (line 129) | def get_random_projections(self, latent_dim: int, num_samples: int) ->...
method compute_swd (line 151) | def compute_swd(self,
method sample (line 181) | def sample(self,
method generate (line 199) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/twostage_vae.py
class TwoStageVAE (line 8) | class TwoStageVAE(BaseVAE):
method __init__ (line 10) | def __init__(self,
method encode (line 100) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 117) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 130) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 142) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 148) | def loss_function(self,
method sample (line 172) | def sample(self,
method generate (line 190) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/vampvae.py
class VampVAE (line 8) | class VampVAE(BaseVAE):
method __init__ (line 10) | def __init__(self,
method encode (line 82) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 99) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 106) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 118) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 123) | def loss_function(self,
method sample (line 170) | def sample(self,
method generate (line 188) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/vanilla_vae.py
class VanillaVAE (line 8) | class VanillaVAE(BaseVAE):
method __init__ (line 11) | def __init__(self,
method encode (line 77) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 94) | def decode(self, z: Tensor) -> Tensor:
method reparameterize (line 107) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
method forward (line 119) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 124) | def loss_function(self,
method sample (line 148) | def sample(self,
method generate (line 166) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/vq_vae.py
class VectorQuantizer (line 7) | class VectorQuantizer(nn.Module):
method __init__ (line 12) | def __init__(self,
method forward (line 24) | def forward(self, latents: Tensor) -> Tensor:
class ResidualLayer (line 57) | class ResidualLayer(nn.Module):
method __init__ (line 59) | def __init__(self,
method forward (line 69) | def forward(self, input: Tensor) -> Tensor:
class VQVAE (line 73) | class VQVAE(BaseVAE):
method __init__ (line 75) | def __init__(self,
method encode (line 168) | def encode(self, input: Tensor) -> List[Tensor]:
method decode (line 178) | def decode(self, z: Tensor) -> Tensor:
method forward (line 189) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 194) | def loss_function(self,
method sample (line 213) | def sample(self,
method generate (line 218) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: models/wae_mmd.py
class WAE_MMD (line 8) | class WAE_MMD(BaseVAE):
method __init__ (line 10) | def __init__(self,
method encode (line 81) | def encode(self, input: Tensor) -> Tensor:
method decode (line 96) | def decode(self, z: Tensor) -> Tensor:
method forward (line 103) | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
method loss_function (line 107) | def loss_function(self,
method compute_kernel (line 125) | def compute_kernel(self,
method compute_rbf (line 153) | def compute_rbf(self,
method compute_inv_mult_quad (line 170) | def compute_inv_mult_quad(self,
method compute_mmd (line 193) | def compute_mmd(self, z: Tensor, reg_weight: float) -> Tensor:
method sample (line 206) | def sample(self,
method generate (line 224) | def generate(self, x: Tensor, **kwargs) -> Tensor:
FILE: tests/bvae.py
class TestVAE (line 7) | class TestVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 23) | def test_loss(self):
FILE: tests/test_betatcvae.py
class TestBetaTCVAE (line 7) | class TestBetaTCVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 24) | def test_loss(self):
method test_sample (line 31) | def test_sample(self):
method test_generate (line 36) | def test_generate(self):
FILE: tests/test_cat_vae.py
class TestVAE (line 7) | class TestVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 23) | def test_loss(self):
method test_sample (line 31) | def test_sample(self):
FILE: tests/test_dfc.py
class TestDFCVAE (line 7) | class TestDFCVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 18) | def test_forward(self):
method test_loss (line 25) | def test_loss(self):
method test_sample (line 32) | def test_sample(self):
FILE: tests/test_dipvae.py
class TestDIPVAE (line 7) | class TestDIPVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 24) | def test_loss(self):
method test_sample (line 31) | def test_sample(self):
method test_generate (line 36) | def test_generate(self):
FILE: tests/test_fvae.py
class TestFAE (line 7) | class TestFAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 20) | def test_forward(self):
method test_loss (line 27) | def test_loss(self):
method test_optim (line 36) | def test_optim(self):
method test_sample (line 40) | def test_sample(self):
FILE: tests/test_gvae.py
class TestGammaVAE (line 7) | class TestGammaVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 18) | def test_forward(self):
method test_loss (line 25) | def test_loss(self):
method test_sample (line 32) | def test_sample(self):
FILE: tests/test_hvae.py
class TestHVAE (line 7) | class TestHVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 23) | def test_loss(self):
FILE: tests/test_iwae.py
class TestIWAE (line 7) | class TestIWAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 23) | def test_loss(self):
method test_sample (line 30) | def test_sample(self):
FILE: tests/test_joint_Vae.py
class TestVAE (line 7) | class TestVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 23) | def test_loss(self):
method test_sample (line 31) | def test_sample(self):
FILE: tests/test_logcosh.py
class TestVAE (line 7) | class TestVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 23) | def test_loss(self):
FILE: tests/test_lvae.py
class TestLVAE (line 7) | class TestLVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 24) | def test_loss(self):
method test_sample (line 31) | def test_sample(self):
FILE: tests/test_miwae.py
class TestMIWAE (line 7) | class TestMIWAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 23) | def test_loss(self):
method test_sample (line 30) | def test_sample(self):
method test_generate (line 35) | def test_generate(self):
FILE: tests/test_mssimvae.py
class TestMSSIMVAE (line 7) | class TestMSSIMVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 18) | def test_forward(self):
method test_loss (line 25) | def test_loss(self):
method test_sample (line 32) | def test_sample(self):
FILE: tests/test_swae.py
class TestSWAE (line 7) | class TestSWAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 12) | def test_summary(self):
method test_forward (line 16) | def test_forward(self):
method test_loss (line 22) | def test_loss(self):
FILE: tests/test_vae.py
class TestVAE (line 7) | class TestVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 23) | def test_loss(self):
FILE: tests/test_vq_vae.py
class TestVQVAE (line 7) | class TestVQVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 24) | def test_loss(self):
method test_sample (line 31) | def test_sample(self):
method test_generate (line 36) | def test_generate(self):
FILE: tests/test_wae.py
class TestWAE (line 7) | class TestWAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 12) | def test_summary(self):
method test_forward (line 16) | def test_forward(self):
method test_loss (line 22) | def test_loss(self):
FILE: tests/text_cvae.py
class TestCVAE (line 6) | class TestCVAE(unittest.TestCase):
method setUp (line 8) | def setUp(self) -> None:
method test_forward (line 12) | def test_forward(self):
method test_loss (line 19) | def test_loss(self):
FILE: tests/text_vamp.py
class TestVVAE (line 7) | class TestVVAE(unittest.TestCase):
method setUp (line 9) | def setUp(self) -> None:
method test_summary (line 13) | def test_summary(self):
method test_forward (line 17) | def test_forward(self):
method test_loss (line 23) | def test_loss(self):
FILE: utils.py
function data_loader (line 8) | def data_loader(fn):
Condensed preview — 82 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (254K chars).
[
{
"path": ".gitignore",
"chars": 81,
"preview": "\nData/\nlogs/\n\nVanillaVAE/version_0/\n\n__pycache__/\n.ipynb_checkpoints/\n\nRun.ipynb\n"
},
{
"path": ".idea/.gitignore",
"chars": 39,
"preview": "# Default ignored files\n/workspace.xml\n"
},
{
"path": ".idea/PyTorch-VAE.iml",
"chars": 518,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n <component name=\"NewModuleRootManager"
},
{
"path": ".idea/inspectionProfiles/profiles_settings.xml",
"chars": 174,
"preview": "<component name=\"InspectionProjectProfileManager\">\n <settings>\n <option name=\"USE_PROJECT_PROFILE\" value=\"false\" />\n"
},
{
"path": ".idea/misc.xml",
"chars": 299,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"ProjectRootManager\" version=\"2\" project-"
},
{
"path": ".idea/modules.xml",
"chars": 453,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"ProjectModuleManager\">\n <modules>\n "
},
{
"path": ".idea/vcs.xml",
"chars": 247,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"VcsDirectoryMappings\">\n <mapping dire"
},
{
"path": "LICENSE.md",
"chars": 11455,
"preview": "\n Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 14526,
"preview": "<h1 align=\"center\">\n <b>PyTorch VAE</b><br>\n</h1>\n\n<p align=\"center\">\n <a href=\"https://www.python.org/\">\n "
},
{
"path": "configs/bbvae.yaml",
"chars": 492,
"preview": "model_params:\n name: 'BetaVAE'\n in_channels: 3\n latent_dim: 128\n loss_type: 'B'\n gamma: 10.0\n max_capacity: 25\n C"
},
{
"path": "configs/betatc_vae.yaml",
"chars": 455,
"preview": "model_params:\n name: 'BetaTCVAE'\n in_channels: 3\n latent_dim: 10\n anneal_steps: 10000\n alpha: 1.\n beta: 6.\n gamm"
},
{
"path": "configs/bhvae.yaml",
"chars": 423,
"preview": "model_params:\n name: 'BetaVAE'\n in_channels: 3\n latent_dim: 128\n loss_type: 'H'\n beta: 10.\n\ndata_params:\n data_pat"
},
{
"path": "configs/cat_vae.yaml",
"chars": 508,
"preview": "model_params:\n name: 'CategoricalVAE'\n in_channels: 3\n latent_dim: 512\n categorical_dim: 40\n temperature: 0.5\n ann"
},
{
"path": "configs/cvae.yaml",
"chars": 425,
"preview": "model_params:\n name: 'ConditionalVAE'\n in_channels: 3\n num_classes: 40\n latent_dim: 128\n\ndata_params:\n data_path: \""
},
{
"path": "configs/dfc_vae.yaml",
"chars": 392,
"preview": "model_params:\n name: 'DFCVAE'\n in_channels: 3\n latent_dim: 128\n\ndata_params:\n data_path: \"Data/\"\n train_batch_size:"
},
{
"path": "configs/dip_vae.yaml",
"chars": 449,
"preview": "model_params:\n name: 'DIPVAE'\n in_channels: 3\n latent_dim: 128\n lambda_diag: 0.05\n lambda_offdiag: 0.1\n\n\ndata_param"
},
{
"path": "configs/factorvae.yaml",
"chars": 515,
"preview": "model_params:\n name: 'FactorVAE'\n in_channels: 3\n latent_dim: 128\n gamma: 6.4\n\ndata_params:\n data_path: \"Data/\"\n t"
},
{
"path": "configs/gammavae.yaml",
"chars": 479,
"preview": "model_params:\n name: 'GammaVAE'\n in_channels: 3\n latent_dim: 128\n gamma_shape: 8.\n prior_shape: 2.\n prior_rate: 1."
},
{
"path": "configs/hvae.yaml",
"chars": 434,
"preview": "model_params:\n name: 'HVAE'\n in_channels: 3\n latent1_dim: 64\n latent2_dim: 64\n pseudo_input_size: 128\n\ndata_params:"
},
{
"path": "configs/infovae.yaml",
"chars": 569,
"preview": "model_params:\n name: 'InfoVAE'\n in_channels: 3\n latent_dim: 128\n reg_weight: 110 # MMD weight\n kernel_type: 'imq'\n"
},
{
"path": "configs/iwae.yaml",
"chars": 405,
"preview": "model_params:\n name: 'IWAE'\n in_channels: 3\n latent_dim: 128\n num_samples: 5\n\ndata_params:\n data_path: \"Data/\"\n tr"
},
{
"path": "configs/joint_vae.yaml",
"chars": 718,
"preview": "model_params:\n name: 'JointVAE'\n in_channels: 3\n latent_dim: 512\n categorical_dim: 40\n latent_min_capacity: 0.0\n l"
},
{
"path": "configs/logcosh_vae.yaml",
"chars": 427,
"preview": "model_params:\n name: 'LogCoshVAE'\n in_channels: 3\n latent_dim: 128\n alpha: 10.0\n beta: 1.0\n\ndata_params:\n data_pat"
},
{
"path": "configs/lvae.yaml",
"chars": 439,
"preview": "model_params:\n name: 'LVAE'\n in_channels: 3\n latent_dims: [4,8,16,32,128]\n hidden_dims: [32, 64,128, 256, 512]\n\ndata"
},
{
"path": "configs/miwae.yaml",
"chars": 427,
"preview": "model_params:\n name: 'MIWAE'\n in_channels: 3\n latent_dim: 128\n num_samples: 5\n num_estimates: 3\n\ndata_params:\n dat"
},
{
"path": "configs/mssim_vae.yaml",
"chars": 396,
"preview": "model_params:\n name: 'MSSIMVAE'\n in_channels: 3\n latent_dim: 128\n\ndata_params:\n data_path: \"Data/\"\n train_batch_siz"
},
{
"path": "configs/swae.yaml",
"chars": 495,
"preview": "model_params:\n name: 'SWAE'\n in_channels: 3\n latent_dim: 128\n reg_weight: 100\n wasserstein_deg: 2.0\n num_projectio"
},
{
"path": "configs/vae.yaml",
"chars": 405,
"preview": "model_params:\n name: 'VanillaVAE'\n in_channels: 3\n latent_dim: 128\n\n\ndata_params:\n data_path: \"Data/\"\n train_batch_"
},
{
"path": "configs/vampvae.yaml",
"chars": 393,
"preview": "model_params:\n name: 'VampVAE'\n in_channels: 3\n latent_dim: 128\n\nexp_params:\n dataset: celeba\n data_path: \"../../sh"
},
{
"path": "configs/vq_vae.yaml",
"chars": 441,
"preview": "model_params:\n name: 'VQVAE'\n in_channels: 3\n embedding_dim: 64\n num_embeddings: 512\n img_size: 64\n beta: 0.25\n\nda"
},
{
"path": "configs/wae_mmd_imq.yaml",
"chars": 449,
"preview": "model_params:\n name: 'WAE_MMD'\n in_channels: 3\n latent_dim: 128\n reg_weight: 100\n kernel_type: 'imq'\n\ndata_params:\n"
},
{
"path": "configs/wae_mmd_rbf.yaml",
"chars": 450,
"preview": "model_params:\n name: 'WAE_MMD'\n in_channels: 3\n latent_dim: 128\n reg_weight: 5000\n kernel_type: 'rbf'\n\ndata_params:"
},
{
"path": "dataset.py",
"chars": 6315,
"preview": "import os\nimport torch\nfrom torch import Tensor\nfrom pathlib import Path\nfrom typing import List, Optional, Sequence, Un"
},
{
"path": "experiment.py",
"chars": 4997,
"preview": "import os\nimport math\nimport torch\nfrom torch import optim\nfrom models import BaseVAE\nfrom models.types_ import *\nfrom u"
},
{
"path": "models/__init__.py",
"chars": 1356,
"preview": "from .base import *\nfrom .vanilla_vae import *\nfrom .gamma_vae import *\nfrom .beta_vae import *\nfrom .wae_mmd import *\nf"
},
{
"path": "models/base.py",
"chars": 733,
"preview": "from .types_ import *\nfrom torch import nn\nfrom abc import abstractmethod\n\nclass BaseVAE(nn.Module):\n \n def __init"
},
{
"path": "models/beta_vae.py",
"chars": 6242,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/betatc_vae.py",
"chars": 8558,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/cat_vae.py",
"chars": 7531,
"preview": "import torch\nimport numpy as np\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfro"
},
{
"path": "models/cvae.py",
"chars": 6079,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/dfcvae.py",
"chars": 7315,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torchvision.models import vgg19_bn\nfrom torch.nn impor"
},
{
"path": "models/dip_vae.py",
"chars": 6597,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/fvae.py",
"chars": 8251,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/gamma_vae.py",
"chars": 8650,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.distributions import Gamma\nfrom torch.nn import "
},
{
"path": "models/hvae.py",
"chars": 9396,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/info_vae.py",
"chars": 8538,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/iwae.py",
"chars": 6694,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/joint_vae.py",
"chars": 9837,
"preview": "import torch\nimport numpy as np\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfro"
},
{
"path": "models/logcosh_vae.py",
"chars": 6292,
"preview": "import torch\nimport torch.nn.functional as F\nfrom models import BaseVAE\nfrom torch import nn\nfrom .types_ import *\n\n\ncla"
},
{
"path": "models/lvae.py",
"chars": 9666,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/miwae.py",
"chars": 6969,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/mssim_vae.py",
"chars": 9644,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/swae.py",
"chars": 7340,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch import dist"
},
{
"path": "models/twostage_vae.py",
"chars": 6867,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/types_.py",
"chars": 133,
"preview": "from typing import List, Callable, Union, Any, TypeVar, Tuple\n# from torch import tensor as Tensor\n\nTensor = TypeVar('to"
},
{
"path": "models/vampvae.py",
"chars": 6760,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/vanilla_vae.py",
"chars": 5757,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/vq_vae.py",
"chars": 7576,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "models/wae_mmd.py",
"chars": 7427,
"preview": "import torch\nfrom models import BaseVAE\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom .types_ import *\n"
},
{
"path": "requirements.txt",
"chars": 108,
"preview": "pytorch-lightning==1.5.6\nPyYAML==6.0\ntensorboard>=2.2.0\ntorch>=1.6.1\ntorchsummary==1.5.1\ntorchvision>=0.10.1"
},
{
"path": "run.py",
"chars": 2216,
"preview": "import os\nimport yaml\nimport argparse\nimport numpy as np\nfrom pathlib import Path\nfrom models import *\nfrom experiment i"
},
{
"path": "tests/bvae.py",
"chars": 846,
"preview": "import torch\nimport unittest\nfrom models import BetaVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.TestCa"
},
{
"path": "tests/test_betatcvae.py",
"chars": 1173,
"preview": "import torch\nimport unittest\nfrom models import BetaTCVAE\nfrom torchsummary import summary\n\n\nclass TestBetaTCVAE(unittes"
},
{
"path": "tests/test_cat_vae.py",
"chars": 951,
"preview": "import torch\nimport unittest\nfrom models import GumbelVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.Test"
},
{
"path": "tests/test_dfc.py",
"chars": 914,
"preview": "import torch\nimport unittest\nfrom models import DFCVAE\nfrom torchsummary import summary\n\n\nclass TestDFCVAE(unittest.Test"
},
{
"path": "tests/test_dipvae.py",
"chars": 1145,
"preview": "import torch\nimport unittest\nfrom models import DIPVAE\nfrom torchsummary import summary\n\n\nclass TestDIPVAE(unittest.Test"
},
{
"path": "tests/test_fvae.py",
"chars": 1368,
"preview": "import torch\nimport unittest\nfrom models import FactorVAE\nfrom torchsummary import summary\n\n\nclass TestFAE(unittest.Test"
},
{
"path": "tests/test_gvae.py",
"chars": 920,
"preview": "import torch\nimport unittest\nfrom models import GammaVAE\nfrom torchsummary import summary\n\n\nclass TestGammaVAE(unittest."
},
{
"path": "tests/test_hvae.py",
"chars": 840,
"preview": "import torch\nimport unittest\nfrom models import HVAE\nfrom torchsummary import summary\n\n\nclass TestHVAE(unittest.TestCase"
},
{
"path": "tests/test_iwae.py",
"chars": 904,
"preview": "import torch\nimport unittest\nfrom models import IWAE\nfrom torchsummary import summary\n\n\nclass TestIWAE(unittest.TestCase"
},
{
"path": "tests/test_joint_Vae.py",
"chars": 958,
"preview": "import torch\nimport unittest\nfrom models import JointVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.TestC"
},
{
"path": "tests/test_logcosh.py",
"chars": 832,
"preview": "import torch\nimport unittest\nfrom models import LogCoshVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.Tes"
},
{
"path": "tests/test_lvae.py",
"chars": 977,
"preview": "import torch\nimport unittest\nfrom models import LVAE\nfrom torchsummary import summary\n\n\nclass TestLVAE(unittest.TestCase"
},
{
"path": "tests/test_miwae.py",
"chars": 1057,
"preview": "import torch\nimport unittest\nfrom models import MIWAE\nfrom torchsummary import summary\n\n\nclass TestMIWAE(unittest.TestCa"
},
{
"path": "tests/test_mssimvae.py",
"chars": 920,
"preview": "import torch\nimport unittest\nfrom models import MSSIMVAE\nfrom torchsummary import summary\n\n\nclass TestMSSIMVAE(unittest."
},
{
"path": "tests/test_swae.py",
"chars": 782,
"preview": "import torch\nimport unittest\nfrom models import SWAE\nfrom torchsummary import summary\n\n\nclass TestSWAE(unittest.TestCase"
},
{
"path": "tests/test_vae.py",
"chars": 823,
"preview": "import torch\nimport unittest\nfrom models import VanillaVAE\nfrom torchsummary import summary\n\n\nclass TestVAE(unittest.Tes"
},
{
"path": "tests/test_vq_vae.py",
"chars": 1147,
"preview": "import torch\nimport unittest\nfrom models import VQVAE\nfrom torchsummary import summary\n\n\nclass TestVQVAE(unittest.TestCa"
},
{
"path": "tests/test_wae.py",
"chars": 787,
"preview": "import torch\nimport unittest\nfrom models import WAE_MMD\nfrom torchsummary import summary\n\n\nclass TestWAE(unittest.TestCa"
},
{
"path": "tests/text_cvae.py",
"chars": 705,
"preview": "import torch\nimport unittest\nfrom models import CVAE\n\n\nclass TestCVAE(unittest.TestCase):\n\n def setUp(self) -> None:\n"
},
{
"path": "tests/text_vamp.py",
"chars": 844,
"preview": "import torch\nimport unittest\nfrom models import VampVAE\nfrom torchsummary import summary\n\n\nclass TestVVAE(unittest.TestC"
},
{
"path": "utils.py",
"chars": 622,
"preview": "import pytorch_lightning as pl\n\n\n## Utils to handle newer PyTorch Lightning changes from version 0.6\n## ================"
}
]
About this extraction
This page contains the full source code of the AntixK/PyTorch-VAE GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 82 files (236.1 KB), approximately 61.2k tokens, and a symbol index with 386 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.