Repository: HJ-harry/DiffusionMBIR Branch: main Commit: bdfe460582ba Files: 97 Total size: 380.5 KB Directory structure: gitextract_en_fo40u/ ├── .gitignore ├── LICENSE ├── README.md ├── configs/ │ ├── default_celeba_configs.py │ ├── default_cifar10_configs.py │ ├── default_complex_configs.py │ ├── default_lsun_configs.py │ ├── subvp/ │ │ ├── cifar10_ddpm_continuous.py │ │ ├── cifar10_ddpmpp_continuous.py │ │ ├── cifar10_ddpmpp_deep_continuous.py │ │ ├── cifar10_ncsnpp_continuous.py │ │ └── cifar10_ncsnpp_deep_continuous.py │ ├── ve/ │ │ ├── AAPM_128_ncsnpp_continuous.py │ │ ├── AAPM_256_ncsnpp_continuous.py │ │ ├── Object5_fast.py │ │ ├── Object5_ncsnpp_continuous.py │ │ ├── bedroom_ncsnpp_continuous.py │ │ ├── celeba_ncsnpp.py │ │ ├── celebahq_256_ncsnpp_continuous.py │ │ ├── celebahq_ncsnpp_continuous.py │ │ ├── church_ncsnpp_continuous.py │ │ ├── cifar10_ddpm.py │ │ ├── cifar10_ncsnpp.py │ │ ├── cifar10_ncsnpp_continuous.py │ │ ├── cifar10_ncsnpp_deep_continuous.py │ │ ├── fastmri_knee_128_ncsnpp_continuous.py │ │ ├── fastmri_knee_256_ncsnpp_continuous.py │ │ ├── fastmri_knee_320_ncsnpp_continuous.py │ │ ├── fastmri_knee_320_ncsnpp_continuous_complex.py │ │ ├── fastmri_knee_320_ncsnpp_continuous_complex_magpha.py │ │ ├── fastmri_knee_320_ncsnpp_continuous_multi.py │ │ ├── ffhq_256_ncsnpp_continuous.py │ │ ├── ffhq_ncsnpp_continuous.py │ │ ├── ncsn/ │ │ │ ├── celeba.py │ │ │ ├── celeba_124.py │ │ │ ├── celeba_1245.py │ │ │ ├── celeba_5.py │ │ │ ├── cifar10.py │ │ │ ├── cifar10_124.py │ │ │ ├── cifar10_1245.py │ │ │ └── cifar10_5.py │ │ └── ncsnv2/ │ │ ├── bedroom.py │ │ ├── celeba.py │ │ └── cifar10.py │ └── vp/ │ ├── cifar10_ddpmpp.py │ ├── cifar10_ddpmpp_continuous.py │ ├── cifar10_ddpmpp_deep_continuous.py │ ├── cifar10_ncsnpp.py │ ├── cifar10_ncsnpp_continuous.py │ ├── cifar10_ncsnpp_deep_continuous.py │ └── ddpm/ │ ├── bedroom.py │ ├── celebahq.py │ ├── church.py │ ├── cifar10.py │ ├── cifar10_continuous.py │ └── cifar10_unconditional.py ├── 
controllable_generation_TV.py ├── datasets.py ├── environment.yml ├── evaluation.py ├── fastmri_utils.py ├── inverse_problem_solver_AAPM_3d_total.py ├── inverse_problem_solver_BRATS_MRI_3d_total.py ├── likelihood.py ├── losses.py ├── main.py ├── models/ │ ├── __init__.py │ ├── ddpm.py │ ├── ema.py │ ├── layers.py │ ├── layerspp.py │ ├── ncsnpp.py │ ├── ncsnv2.py │ ├── normalization.py │ ├── unet.py │ ├── up_or_down_sampling.py │ └── utils.py ├── op/ │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── physics/ │ ├── ct.py │ ├── inpainting.py │ └── radon/ │ ├── __init__.py │ ├── filters.py │ ├── radon.py │ ├── stackgram.py │ └── utils.py ├── run_lib.py ├── sampling.py ├── sde_lib.py ├── test/ │ └── test_TV.py ├── train_AAPM256.sh └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Compiled source # ################### *.o *.so *.pyc # Logs and temporaries # ######################## *.log *~ .coverage # Folders # ########### build/ dist/ *.egg-info/ __pycache__/ .eggs/ data/ exp/ results/ results_AAPM/ results_AAPM_tv/ workdir/ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # Solving 3D Inverse Problems using Pre-trained 2D Diffusion Models (CVPR 2023) Official PyTorch implementation of **DiffusionMBIR**, the CVPR 2023 paper "[Solving 3D Inverse Problems using Pre-trained 2D Diffusion Models](https://arxiv.org/abs/2211.10655)". Code modified from [score_sde_pytorch](https://github.com/yang-song/score_sde_pytorch). 
✅ If you would like to use an updated, faster version of DiffusionMBIR, you might want to use [DDS](https://github.com/hyungjin-chung/DDS) [![arXiv](https://img.shields.io/badge/arXiv-2211.10655-green)](https://arxiv.org/abs/2211.10655) [![paper](https://img.shields.io/badge/paper-CVPR2023-blue)](https://arxiv.org/abs/2211.10655) ![concept](./figs/forward_model.jpg) ![concept](./figs/cover_result.jpg) ## Getting started ### Download pre-trained model weights * **CT** experiments: [weights](https://drive.google.com/file/d/1-TaLbg3-4gLwKH2-Qf5VBFCBLG3RjY9j/view) ### Download the data * **CT** experiments (in-distribution) ```bash DATA_DIR=./data/CT/ind/256_sorted mkdir -p "$DATA_DIR" wget -O "$DATA_DIR"/256_sorted.zip "https://www.dropbox.com/sh/ibjpgo5seksjera/AADlhYqCWq5C4K0uWSrCL_JUa?dl=1" unzip -d "$DATA_DIR"/ "$DATA_DIR"/256_sorted.zip ``` * **CT** experiments (out-of-distribution) ```bash DATA_DIR=./data/CT/ood/256_sorted mkdir -p "$DATA_DIR" wget -O "$DATA_DIR"/slice.zip "https://www.dropbox.com/s/h3drrlx0pvutyoi/slice.zip?dl=1" unzip -d "$DATA_DIR"/ "$DATA_DIR"/slice.zip ``` * Create a conda environment and install the dependencies ```bash conda env create --file environment.yml ``` ## DiffusionMBIR (fast) reconstruction Once you have the pre-trained weights and the test data set up, you can run the following scripts. To change the experimental settings, edit the parameters directly in the Python scripts. ```bash conda activate diffusion-mbir python inverse_problem_solver_AAPM_3d_total.py python inverse_problem_solver_BRATS_MRI_3d_total.py ``` ## Training You can train the diffusion model on your own data using, e.g., ```bash bash train_AAPM256.sh ``` You can modify the training config with the `--config` flag.
## Citation If you find our work interesting, please consider citing: ``` @InProceedings{chung2023solving, title={Solving {3D} Inverse Problems using Pre-trained {2D} Diffusion Models}, author={Chung, Hyungjin and Ryu, Dohoon and McCann, Michael T and Klasky, Marc L and Ye, Jong Chul}, booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition}, year={2023} } ``` ================================================ FILE: configs/default_celeba_configs.py ================================================ import ml_collections import torch def get_default_configs(): config = ml_collections.ConfigDict() # training config.training = training = ml_collections.ConfigDict() # config.training.batch_size = 128 config.training.batch_size = 64 training.n_iters = 1300001 training.snapshot_freq = 50000 training.log_freq = 50 training.eval_freq = 100 ## store additional checkpoints for preemption in cloud computing environments training.snapshot_freq_for_preemption = 10000 ## produce samples at each snapshot. training.snapshot_sampling = True training.likelihood_weighting = False training.continuous = True training.reduce_mean = False # sampling config.sampling = sampling = ml_collections.ConfigDict() sampling.n_steps_each = 1 sampling.noise_removal = True sampling.probability_flow = False sampling.snr = 0.17 # evaluation config.eval = evaluate = ml_collections.ConfigDict() evaluate.begin_ckpt = 1 evaluate.end_ckpt = 26 evaluate.batch_size = 1024 evaluate.enable_sampling = True evaluate.num_samples = 50000 evaluate.enable_loss = True evaluate.enable_bpd = False evaluate.bpd_dataset = 'test' # data config.data = data = ml_collections.ConfigDict() data.dataset = 'CELEBA' data.image_size = 64 data.random_flip = True data.uniform_dequantization = False data.centered = False data.num_channels = 3 # model config.model = model = ml_collections.ConfigDict() model.sigma_max = 90. model.sigma_min = 0.01 model.num_scales = 1000 model.beta_min = 0.1 model.beta_max = 20.
model.dropout = 0.1 model.embedding_type = 'fourier' # optimization config.optim = optim = ml_collections.ConfigDict() optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 2e-4 optim.beta1 = 0.9 optim.eps = 1e-8 optim.warmup = 5000 optim.grad_clip = 1. config.seed = 42 config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') return config ================================================ FILE: configs/default_cifar10_configs.py ================================================ import ml_collections import torch def get_default_configs(): config = ml_collections.ConfigDict() # training config.training = training = ml_collections.ConfigDict() # config.training.batch_size = 128 config.training.batch_size = 4 training.n_iters = 1300001 training.snapshot_freq = 50000 training.log_freq = 50 training.eval_freq = 100 ## store additional checkpoints for preemption in cloud computing environments training.snapshot_freq_for_preemption = 10000 ## produce samples at each snapshot. 
training.snapshot_sampling = True training.likelihood_weighting = False training.continuous = True training.reduce_mean = False # sampling config.sampling = sampling = ml_collections.ConfigDict() sampling.n_steps_each = 1 sampling.noise_removal = True sampling.probability_flow = False sampling.snr = 0.16 # evaluation config.eval = evaluate = ml_collections.ConfigDict() evaluate.begin_ckpt = 9 evaluate.end_ckpt = 26 evaluate.batch_size = 1024 evaluate.enable_sampling = False evaluate.num_samples = 50000 evaluate.enable_loss = True evaluate.enable_bpd = False evaluate.bpd_dataset = 'test' # data config.data = data = ml_collections.ConfigDict() data.dataset = 'CIFAR10' data.image_size = 32 data.random_flip = True data.centered = False data.uniform_dequantization = False data.num_channels = 3 # data.num_channels = 1 # model config.model = model = ml_collections.ConfigDict() model.sigma_min = 0.01 model.sigma_max = 50 model.num_scales = 1000 model.beta_min = 0.1 model.beta_max = 20. model.dropout = 0.1 model.embedding_type = 'fourier' # optimization config.optim = optim = ml_collections.ConfigDict() optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 2e-4 optim.beta1 = 0.9 optim.eps = 1e-8 optim.warmup = 5000 optim.grad_clip = 1. config.seed = 42 config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') return config ================================================ FILE: configs/default_complex_configs.py ================================================ import ml_collections import torch def get_default_configs(): config = ml_collections.ConfigDict() # training config.training = training = ml_collections.ConfigDict() # config.training.batch_size = 64 # config.training.batch_size = 2 # seriously? 
config.training.batch_size = 1 # When using single GPU # training.n_iters = 2400001 training.epochs = 100 training.snapshot_freq = 50000 # training.log_freq = 50 training.log_freq = 25 training.eval_freq = 100 ## store additional checkpoints for preemption in cloud computing environments training.snapshot_freq_for_preemption = 5000 ## produce samples at each snapshot. training.snapshot_sampling = True training.likelihood_weighting = False training.continuous = True training.reduce_mean = False # sampling config.sampling = sampling = ml_collections.ConfigDict() sampling.n_steps_each = 1 sampling.noise_removal = True sampling.probability_flow = False sampling.snr = 0.075 # evaluation config.eval = evaluate = ml_collections.ConfigDict() evaluate.begin_ckpt = 50 evaluate.end_ckpt = 96 # evaluate.batch_size = 512 evaluate.batch_size = 8 evaluate.enable_sampling = True evaluate.num_samples = 50000 evaluate.enable_loss = True evaluate.enable_bpd = False evaluate.bpd_dataset = 'test' # data config.data = data = ml_collections.ConfigDict() # data.dataset = 'LSUN' data.image_size = 320 data.random_flip = True data.uniform_dequantization = False data.centered = False data.num_channels = 2 # model config.model = model = ml_collections.ConfigDict() model.sigma_max = 378 model.sigma_min = 0.01 model.num_scales = 2000 model.beta_min = 0.1 model.beta_max = 20. model.dropout = 0. model.embedding_type = 'fourier' # optimization config.optim = optim = ml_collections.ConfigDict() optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 2e-4 optim.beta1 = 0.9 optim.eps = 1e-8 optim.warmup = 5000 optim.grad_clip = 1. 
config.seed = 42 config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') return config ================================================ FILE: configs/default_lsun_configs.py ================================================ import ml_collections import torch def get_default_configs(): config = ml_collections.ConfigDict() # training config.training = training = ml_collections.ConfigDict() # config.training.batch_size = 64 # config.training.batch_size = 2 # seriously? config.training.batch_size = 1 # When using single GPU # training.n_iters = 2400001 training.epochs = 1000 training.snapshot_freq = 50000 # training.log_freq = 50 training.log_freq = 25 training.eval_freq = 100 ## store additional checkpoints for preemption in cloud computing environments training.snapshot_freq_for_preemption = 5000 ## produce samples at each snapshot. training.snapshot_sampling = True training.likelihood_weighting = False training.continuous = True training.reduce_mean = False # sampling config.sampling = sampling = ml_collections.ConfigDict() sampling.n_steps_each = 1 sampling.noise_removal = True sampling.probability_flow = False sampling.snr = 0.075 # evaluation config.eval = evaluate = ml_collections.ConfigDict() evaluate.begin_ckpt = 50 evaluate.end_ckpt = 96 # evaluate.batch_size = 512 evaluate.batch_size = 8 evaluate.enable_sampling = True evaluate.num_samples = 50000 evaluate.enable_loss = True evaluate.enable_bpd = False evaluate.bpd_dataset = 'test' # data config.data = data = ml_collections.ConfigDict() data.dataset = 'LSUN' data.image_size = 256 data.random_flip = True data.uniform_dequantization = False data.centered = False # data.num_channels = 3 data.num_channels = 1 # model config.model = model = ml_collections.ConfigDict() model.sigma_max = 378 model.sigma_min = 0.01 model.num_scales = 2000 model.beta_min = 0.1 model.beta_max = 20. model.dropout = 0. 
model.embedding_type = 'fourier' # optimization config.optim = optim = ml_collections.ConfigDict() optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 2e-4 optim.beta1 = 0.9 optim.eps = 1e-8 optim.warmup = 5000 optim.grad_clip = 1. config.seed = 42 config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') return config ================================================ FILE: configs/subvp/cifar10_ddpm_continuous.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Training DDPM with sub-VP SDE.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'subvpsde' training.continuous = True training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'euler_maruyama' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ddpm' model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 2 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True return config ================================================ FILE: configs/subvp/cifar10_ddpmpp_continuous.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'subvpsde' training.continuous = True training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'euler_maruyama' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ncsnpp' model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 4 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True model.fir = False model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'none' model.progressive_input = 'none' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.init_scale = 0. model.embedding_type = 'positional' model.fourier_scale = 16 model.conv_size = 3 return config ================================================ FILE: configs/subvp/cifar10_ddpmpp_deep_continuous.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'subvpsde' training.continuous = True training.reduce_mean = True training.n_iters = 950001 # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'euler_maruyama' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ncsnpp' model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 8 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True model.fir = False model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'none' model.progressive_input = 'none' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.init_scale = 0. model.embedding_type = 'positional' model.fourier_scale = 16 model.conv_size = 3 return config ================================================ FILE: configs/subvp/cifar10_ncsnpp_continuous.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Training NCSN++ on CIFAR-10 with sub-VP SDE.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'subvpsde' training.continuous = True training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'euler_maruyama' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ncsnpp' model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 4 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True model.fir = True model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'none' model.progressive_input = 'residual' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.embedding_type = 'positional' model.init_scale = 0. model.fourier_scale = 16 model.conv_size = 3 return config ================================================ FILE: configs/subvp/cifar10_ncsnpp_deep_continuous.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3
"""Training deep NCSN++ (num_res_blocks=8) on CIFAR-10 with sub-VP SDE."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  """Return the sub-VP SDE deep NCSN++ config, built on the CIFAR-10 defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'subvpsde'
  training.continuous = True
  training.n_iters = 950001
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'euler_maruyama'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.fourier_scale = 16
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 8
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.embedding_type = 'positional'
  model.init_scale = 0.0
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/AAPM_128_ncsnpp_continuous.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Training NCSN++ on 128x128 AAPM data with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  """Return the VE-SDE NCSN++ config for 128x128 AAPM, built on the LSUN defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  # NOTE(review): the 256x256 AAPM config uses 'AAPM' (upper case) here;
  # confirm which spelling the dataset loader actually matches.
  data.dataset = 'aapm'
  data.root = '/media/harry/tomo/AAPM_data/128'
  data.is_complex = False
  data.is_multi = False
  data.image_size = 128

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/AAPM_256_ncsnpp_continuous.py
# ================================================
"""Training NCSN++ on 256x256 AAPM data with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  """Return the VE-SDE NCSN++ config for 256x256 AAPM, built on the LSUN defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'AAPM'
  data.root = '/media/harry/tomo/AAPM_data/256'
  data.is_complex = False
  data.is_multi = False
  data.image_size = 256

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/Object5_fast.py
# ================================================
"""Reduced NCSN++ VE-SDE config for Object5 (3 training epochs, 3 sampling steps)."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  """Return the shortened Object5 VE-SDE config, built on the LSUN defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True
  training.epochs = 3

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'Object5Fast'
  data.root = './data/Object5/'
  data.is_complex = False
  data.is_multi = False
  data.image_size = 256

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3
  model.num_scales = 3  # number of sampling steps

  return config

# ================================================
# FILE: configs/ve/Object5_ncsnpp_continuous.py
# ================================================
"""Training NCSN++ on Object5 with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  """Return the VE-SDE NCSN++ config for Object5, built on the LSUN defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'Object5'
  data.root = './data/Object5/'
  data.is_complex = False
  data.is_multi = False
  data.image_size = 256

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/bedroom_ncsnpp_continuous.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on bedroom with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  """Return the VE-SDE NCSN++ config for LSUN bedroom, built on the LSUN defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.category = 'bedroom'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 1, 2, 2, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'output_skip'
  model.progressive_input = 'input_skip'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/celeba_ncsnpp.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on CelebA with SMLD."""

from configs.default_celeba_configs import get_default_configs


def get_config():
  """Return the discrete (SMLD) VE NCSN++ config, built on the CelebA defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  # NOTE(review): other VE configs set `model.sigma_max` rather than
  # `sigma_begin`; confirm this key is actually read anywhere.
  model.sigma_begin = 90
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.0
  model.conv_size = 3
  model.embedding_type = 'positional'

  return config

# ================================================
# FILE: configs/ve/celebahq_256_ncsnpp_continuous.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on CelebA-HQ (256x256) with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  """Return the VE-SDE NCSN++ config for 256x256 CelebA-HQ, built on the LSUN defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'CelebAHQ'
  data.image_size = 256
  data.tfrecords_path = '/home/yangsong/ncsc/celebahq/r08.tfrecords'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.sigma_max = 348
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 1, 2, 2, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'output_skip'
  model.progressive_input = 'input_skip'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/celebahq_ncsnpp_continuous.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on CelebAHQ with VE SDE."""

import ml_collections
import torch


def get_config():
  """Build the complete 1024x1024 CelebAHQ VE-SDE config from scratch.

  Unlike most configs here, this one does not start from a shared default
  module; every section (training, sampling, eval, data, model, optim) is
  constructed directly as an ``ml_collections.ConfigDict``.
  """
  config = ml_collections.ConfigDict()
  # training
  config.training = training = ml_collections.ConfigDict()
  training.batch_size = 8
  training.n_iters = 2400001
  training.snapshot_freq = 50000
  training.log_freq = 50
  training.eval_freq = 100
  training.snapshot_freq_for_preemption = 5000
  training.snapshot_sampling = True
  training.sde = 'vesde'
  training.continuous = True
  training.likelihood_weighting = False
  training.reduce_mean = False

  # sampling
  config.sampling = sampling = ml_collections.ConfigDict()
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'
  sampling.probability_flow = False
  sampling.snr = 0.15
  sampling.n_steps_each = 1
  sampling.noise_removal = True

  # eval
  config.eval = evaluate = ml_collections.ConfigDict()
  evaluate.batch_size = 1024
  evaluate.num_samples = 50000
  evaluate.begin_ckpt = 1
  evaluate.end_ckpt = 96

  # data
  config.data = data = ml_collections.ConfigDict()
  data.dataset = 'CelebAHQ'
  data.image_size = 1024
  data.centered = False
  data.random_flip = True
  data.uniform_dequantization = False
  data.num_channels = 3
  data.tfrecords_path = '/atlas/u/yangsong/celeba_hq/-r10.tfrecords'

  # model
  config.model = model = ml_collections.ConfigDict()
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.sigma_max = 1348
  model.num_scales = 2000
  model.ema_rate = 0.9999
  model.sigma_min = 0.01
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 16
  model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
  model.num_res_blocks = 1
  model.attn_resolutions = (16,)
  model.dropout = 0.
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'output_skip'
  model.progressive_input = 'input_skip'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3
  model.embedding_type = 'fourier'

  # optim
  config.optim = optim = ml_collections.ConfigDict()
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 2e-4
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 5000
  optim.grad_clip = 1.

  config.seed = 42
  # Use the first GPU when CUDA is available, otherwise fall back to CPU.
  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

  return config

# ================================================
# FILE: configs/ve/church_ncsnpp_continuous.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Training NCSN++ on Church with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  """Return the VE-SDE NCSN++ config for LSUN church_outdoor, built on the LSUN defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.category = 'church_outdoor'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.sigma_max = 380
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 1, 2, 2, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'output_skip'
  model.progressive_input = 'input_skip'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/cifar10_ddpm.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Train the original DDPM model with SMLD."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  """Return the discrete (SMLD) VE config for the DDPM architecture on CIFAR-10."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # model
  model = config.model
  model.name = 'ddpm'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/cifar10_ncsnpp.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Training NCSN++ on CIFAR-10 with SMLD."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  """Return the discrete (SMLD) VE NCSN++ config, built on the CIFAR-10 defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.0
  model.embedding_type = 'positional'
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/cifar10_ncsnpp_continuous.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Training NCSN++ on CIFAR-10 with VE SDE."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  """Return the continuous VE-SDE NCSN++ config, built on the CIFAR-10 defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/cifar10_ncsnpp_deep_continuous.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Training deep NCSN++ (num_res_blocks=8) on CIFAR-10 with VE SDE."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  """Return the continuous VE-SDE deep NCSN++ config, built on the CIFAR-10 defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True
  training.n_iters = 950001

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.fourier_scale = 16
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 8
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.0
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/fastmri_knee_128_ncsnpp_continuous.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Training NCSN++ on fastmri knee (128x128) with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  """Return the VE-SDE NCSN++ config for 128x128 fastMRI knee, built on the LSUN defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # training (regression)
  training.mask_type = 'gaussian2d'
  training.acc_factor = [8, 15]

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'fastmri_knee'
  data.root = '/media/harry/tomo/fastmri'
  data.is_complex = False
  data.is_multi = False
  data.image_size = 128

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/fastmri_knee_256_ncsnpp_continuous.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on fastmri knee (256x256) with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  """Return the VE-SDE NCSN++ config for 256x256 fastMRI knee, built on the LSUN defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'fastmri_knee'
  data.root = '/media/harry/tomo/fastmri'
  data.image_size = 256

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/fastmri_knee_320_ncsnpp_continuous.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on fastmri knee (320x320, single-channel) with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  """Return the VE-SDE NCSN++ config for 320x320 fastMRI knee, built on the LSUN defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'fastmri_knee'
  data.root = '/media/harry/tomo/fastmri'
  data.image_size = 320
  data.is_multi = False
  data.is_complex = False

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/fastmri_knee_320_ncsnpp_continuous_complex.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on fastmri knee (320x320, data.is_complex=True) with VE SDE."""

from configs.default_complex_configs import get_default_configs


def get_config():
  """Return the VE-SDE NCSN++ config for complex 320x320 fastMRI knee, built on the complex defaults."""
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'fastmri_knee'
  data.is_multi = False
  data.is_complex = True
  data.root = '/media/harry/tomo/fastmri'
  data.image_size = 320

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

# ================================================
# FILE: configs/ve/fastmri_knee_320_ncsnpp_continuous_complex_magpha.py
# ================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Training NCSN++ on fastmri knee with VE SDE.""" from configs.default_complex_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vesde' training.continuous = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'reverse_diffusion' sampling.corrector = 'langevin' # data data = config.data data.dataset = 'fastmri_knee' data.is_multi = False data.is_complex = True data.magpha = True data.root = '/media/harry/tomo/fastmri' data.image_size = 320 # model model = config.model model.name = 'ncsnpp' model.scale_by_sigma = True model.ema_rate = 0.999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 4 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True model.fir = True model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'none' model.progressive_input = 'residual' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.init_scale = 0. model.fourier_scale = 16 model.conv_size = 3 return config ================================================ FILE: configs/ve/fastmri_knee_320_ncsnpp_continuous_multi.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Training NCSN++ on fastmri knee with VE SDE.""" from configs.default_lsun_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vesde' training.continuous = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'reverse_diffusion' sampling.corrector = 'langevin' # data data = config.data data.dataset = 'fastmri_knee' data.is_complex = False data.is_multi = True data.root = '/media/harry/tomo/fastmri' data.image_size = 320 # model model = config.model model.name = 'ncsnpp' model.scale_by_sigma = True model.ema_rate = 0.999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 4 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True model.fir = True model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'none' model.progressive_input = 'residual' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.init_scale = 0. model.fourier_scale = 16 model.conv_size = 3 return config ================================================ FILE: configs/ve/ffhq_256_ncsnpp_continuous.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Training NCSN++ on FFHQ with VE SDE.""" from configs.default_lsun_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vesde' training.continuous = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'reverse_diffusion' sampling.corrector = 'langevin' # data data = config.data data.dataset = 'FFHQ' data.image_size = 256 data.tfrecords_path = '/media/harry/ExtDrive/PycharmProjects/score_sde_pytorch/dataset/FFHQ/ffhq-r08.tfrecords' # model model = config.model model.name = 'ncsnpp' model.sigma_max = 348 model.scale_by_sigma = True model.ema_rate = 0.999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 1, 2, 2, 2, 2, 2) model.num_res_blocks = 2 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True model.fir = True model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'output_skip' model.progressive_input = 'input_skip' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.init_scale = 0. model.fourier_scale = 16 model.conv_size = 3 return config ================================================ FILE: configs/ve/ffhq_ncsnpp_continuous.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Training NCSN++ on FFHQ with VE SDEs.""" import ml_collections import torch def get_config(): config = ml_collections.ConfigDict() # training config.training = training = ml_collections.ConfigDict() training.batch_size = 8 training.n_iters = 2400001 training.snapshot_freq = 50000 training.log_freq = 50 training.eval_freq = 100 training.snapshot_freq_for_preemption = 5000 training.snapshot_sampling = True training.sde = 'vesde' training.continuous = True training.likelihood_weighting = False training.reduce_mean = True # sampling config.sampling = sampling = ml_collections.ConfigDict() sampling.method = 'pc' sampling.predictor = 'reverse_diffusion' sampling.corrector = 'langevin' sampling.probability_flow = False sampling.snr = 0.15 sampling.n_steps_each = 1 sampling.noise_removal = True # eval config.eval = evaluate = ml_collections.ConfigDict() evaluate.batch_size = 1024 evaluate.num_samples = 50000 evaluate.begin_ckpt = 1 evaluate.end_ckpt = 96 # data config.data = data = ml_collections.ConfigDict() data.dataset = 'FFHQ' data.image_size = 1024 data.centered = False data.random_flip = True data.uniform_dequantization = False data.num_channels = 3 # Plug in your own path to the tfrecords file. data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords' # model config.model = model = ml_collections.ConfigDict() model.name = 'ncsnpp' model.scale_by_sigma = True model.sigma_max = 1348 model.num_scales = 2000 model.ema_rate = 0.9999 model.sigma_min = 0.01 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 16 model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32) model.num_res_blocks = 1 model.attn_resolutions = (16,) model.dropout = 0. 
model.resamp_with_conv = True model.conditional = True model.fir = True model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'output_skip' model.progressive_input = 'input_skip' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.init_scale = 0. model.fourier_scale = 16 model.conv_size = 3 model.embedding_type = 'fourier' # optim config.optim = optim = ml_collections.ConfigDict() optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 2e-4 optim.beta1 = 0.9 optim.amsgrad = False optim.eps = 1e-8 optim.warmup = 5000 optim.grad_clip = 1. config.seed = 42 config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') return config ================================================ FILE: configs/ve/ncsn/celeba.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for reproducing NCSNv1 on CelebA.""" from configs.default_celeba_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.loss = 'vesde' training.continuous = False # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'none' sampling.corrector = 'ald' sampling.n_steps_each = 100 sampling.snr = 0.316 # model model = config.model model.name = 'ncsn' model.scale_by_sigma = False model.sigma_max = 1 model.num_scales = 10 model.ema_rate = 0. model.normalization = 'InstanceNorm++' model.nonlinearity = 'elu' model.nf = 128 model.interpolation = 'bilinear' # optim optim = config.optim optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 1e-3 optim.beta1 = 0.9 optim.amsgrad = False optim.eps = 1e-8 optim.warmup = 0 optim.grad_clip = -1. return config ================================================ FILE: configs/ve/ncsn/celeba_124.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for training NCSN with technique 1,2,4 only.""" from configs.default_celeba_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vesde' training.continuous = False # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'none' sampling.corrector = 'ald' sampling.n_steps_each = 5 sampling.snr = 0.128 # model model = config.model model.name = 'ncsn' model.scale_by_sigma = False model.num_scales = 500 model.ema_rate = 0. model.normalization = 'InstanceNorm++' model.nonlinearity = 'elu' model.nf = 128 model.interpolation = 'bilinear' # optim optim = config.optim optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 1e-3 optim.beta1 = 0.9 optim.amsgrad = False optim.eps = 1e-8 optim.warmup = 0 optim.grad_clip = -1. return config ================================================ FILE: configs/ve/ncsn/celeba_1245.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for training NCSN with technique 1245 only.""" from configs.default_celeba_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vesde' training.continuous = False # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'none' sampling.corrector = 'ald' sampling.n_steps_each = 5 sampling.snr = 0.128 # model model = config.model model.name = 'ncsn' model.scale_by_sigma = False model.num_scales = 500 model.ema_rate = 0.999 model.normalization = 'InstanceNorm++' model.nonlinearity = 'elu' model.nf = 128 model.interpolation = 'bilinear' # optim optim = config.optim optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 1e-3 optim.beta1 = 0.9 optim.amsgrad = False optim.eps = 1e-8 optim.warmup = 0 optim.grad_clip = -1. return config ================================================ FILE: configs/ve/ncsn/celeba_5.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for training NCSNv1 model with technique 5 only.""" from configs.default_celeba_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vesde' training.continuous = False # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'none' sampling.corrector = 'ald' sampling.n_steps_each = 100 sampling.snr = 0.316 # model model = config.model model.name = 'ncsn' model.scale_by_sigma = False model.sigma_max = 1. model.num_scales = 10 model.ema_rate = 0.999 model.normalization = 'InstanceNorm++' model.nonlinearity = 'elu' model.nf = 128 model.interpolation = 'bilinear' # optim optim = config.optim optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 1e-3 optim.beta1 = 0.9 optim.amsgrad = False optim.eps = 1e-8 optim.warmup = 0 optim.grad_clip = -1. return config ================================================ FILE: configs/ve/ncsn/cifar10.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for reproducing NCSNv1 on CIFAR-10.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vesde' training.continuous = False # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'none' sampling.corrector = 'ald' sampling.n_steps_each = 100 sampling.snr = 0.316 # model model = config.model model.name = 'ncsn' model.scale_by_sigma = False model.sigma_max = 1 model.num_scales = 10 model.ema_rate = 0. model.normalization = 'InstanceNorm++' model.nonlinearity = 'elu' model.nf = 128 model.interpolation = 'bilinear' # optim optim = config.optim optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 1e-3 optim.beta1 = 0.9 optim.amsgrad = False optim.eps = 1e-8 optim.warmup = 0 optim.grad_clip = -1. return config ================================================ FILE: configs/ve/ncsn/cifar10_124.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for training NCSN with technique 1,2,4 only.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vesde' training.continuous = False # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'none' sampling.corrector = 'ald' sampling.n_steps_each = 5 sampling.snr = 0.176 # model model = config.model model.name = 'ncsn' model.scale_by_sigma = False model.num_scales = 232 model.ema_rate = 0. model.normalization = 'InstanceNorm++' model.nonlinearity = 'elu' model.nf = 128 model.interpolation = 'bilinear' # optim optim = config.optim optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 1e-3 optim.beta1 = 0.9 optim.amsgrad = False optim.eps = 1e-8 optim.warmup = 0 optim.grad_clip = -1. return config ================================================ FILE: configs/ve/ncsn/cifar10_1245.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for training NCSN with technique 1,2,4,5 only.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vesde' training.continuous = False # shared configs for sample generation step_size = 0.0000062 n_steps_each = 5 ckpt_id = 300000 final_only = True noise_removal = False # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'none' sampling.corrector = 'ald' sampling.n_steps_each = 5 sampling.snr = 0.176 # model model = config.model model.name = 'ncsn' model.scale_by_sigma = False model.num_scales = 232 model.ema_rate = 0.999 model.normalization = 'InstanceNorm++' model.nonlinearity = 'elu' model.nf = 128 model.interpolation = 'bilinear' # optim optim = config.optim optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 1e-3 optim.beta1 = 0.9 optim.amsgrad = False optim.eps = 1e-8 optim.warmup = 0 optim.grad_clip = -1. return config ================================================ FILE: configs/ve/ncsn/cifar10_5.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for training NCSN with technique 5 only.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vesde' training.continuous = False # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'none' sampling.corrector = 'ald' sampling.snr = 0.316 sampling.n_steps_each = 100 # model model = config.model model.name = 'ncsn' model.scale_by_sigma = False model.sigma_max = 1 model.num_scales = 10 model.ema_rate = 0.999 model.normalization = 'InstanceNorm++' model.nonlinearity = 'elu' model.nf = 128 model.interpolation = 'bilinear' # optim optim = config.optim optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 1e-3 optim.beta1 = 0.9 optim.amsgrad = False optim.eps = 1e-8 optim.warmup = 0 optim.grad_clip = -1. return config ================================================ FILE: configs/ve/ncsnv2/bedroom.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for training NCSNv2 on bedroom.""" from configs.default_lsun_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.batch_size = 128 training.sde = 'vesde' training.continuouse = False # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'none' sampling.corrector = 'ald' sampling.n_steps_each = 3 sampling.snr = 0.095 # data data = config.data data.category = 'bedroom' data.image_size = 128 # model model = config.model model.name = 'ncsnv2_128' model.scale_by_sigma = True model.sigma_max = 190 model.num_scales = 1086 model.ema_rate = 0.9999 model.sigma_min = 0.01 model.normalization = 'InstanceNorm++' model.nonlinearity = 'elu' model.nf = 128 model.interpolation = 'bilinear' # optim optim = config.optim optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 1e-4 optim.beta1 = 0.9 optim.amsgrad = False optim.eps = 1e-8 optim.warmup = 0 optim.grad_clip = -1 return config ================================================ FILE: configs/ve/ncsnv2/celeba.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for training NCSNv2 on CelebA.""" from configs.default_celeba_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vesde' training.continuous = False # shared configs for sample generation step_size = 0.0000033 n_steps_each = 5 ckpt_id = 210000 final_only = True noise_removal = False # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'none' sampling.corrector = 'ald' sampling.n_steps_each = 5 sampling.snr = 0.128 # model model = config.model model.name = 'ncsnv2_64' model.scale_by_sigma = True model.num_scales = 500 model.ema_rate = 0.999 model.normalization = 'InstanceNorm++' model.nonlinearity = 'elu' model.nf = 128 model.interpolation = 'bilinear' # optim optim = config.optim optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 1e-4 optim.beta1 = 0.9 optim.amsgrad = False optim.eps = 1e-8 optim.warmup = 0 optim.grad_clip = -1. return config ================================================ FILE: configs/ve/ncsnv2/cifar10.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for training NCSNv2 on CIFAR-10.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vesde' training.continuous = False # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'none' sampling.corrector = 'ald' sampling.n_steps_each = 5 sampling.snr = 0.176 # model model = config.model model.name = 'ncsnv2_64' model.scale_by_sigma = True model.num_scales = 232 model.ema_rate = 0.999 model.normalization = 'InstanceNorm++' model.nonlinearity = 'elu' model.nf = 128 model.interpolation = 'bilinear' # optim optim = config.optim optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 1e-4 optim.beta1 = 0.9 optim.amsgrad = False optim.eps = 1e-8 optim.warmup = 0 optim.grad_clip = -1. return config ================================================ FILE: configs/vp/cifar10_ddpmpp.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vpsde' training.continuous = False training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'ancestral_sampling' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ncsnpp' model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 4 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True model.fir = False model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'none' model.progressive_input = 'none' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.init_scale = 0. model.embedding_type = 'positional' model.fourier_scale = 16 model.conv_size = 3 return config ================================================ FILE: configs/vp/cifar10_ddpmpp_continuous.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vpsde' training.continuous = True training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'euler_maruyama' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ncsnpp' model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 4 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True model.fir = False model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'none' model.progressive_input = 'none' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.init_scale = 0. model.embedding_type = 'positional' model.fourier_scale = 16 model.conv_size = 3 return config ================================================ FILE: configs/vp/cifar10_ddpmpp_deep_continuous.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vpsde' training.continuous = True training.reduce_mean = True training.n_iters = 950001 # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'euler_maruyama' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ncsnpp' model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 8 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True model.fir = False model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'none' model.progressive_input = 'none' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.init_scale = 0. model.embedding_type = 'positional' model.fourier_scale = 16 model.conv_size = 3 return config ================================================ FILE: configs/vp/cifar10_ncsnpp.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Training NCSN++ on CIFAR-10 with DDPM.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vpsde' training.continuous = False training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'reverse_diffusion' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ncsnpp' model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 4 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True model.fir = True model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'none' model.progressive_input = 'residual' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.init_scale = 0.0 model.embedding_type = 'positional' model.conv_size = 3 return config ================================================ FILE: configs/vp/cifar10_ncsnpp_continuous.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Training NCSN++ on CIFAR-10 with VP SDE.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vpsde' training.continuous = True training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'euler_maruyama' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ncsnpp' model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 4 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True model.fir = True model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'none' model.progressive_input = 'residual' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.embedding_type = 'positional' model.init_scale = 0. model.fourier_scale = 16 model.conv_size = 3 return config ================================================ FILE: configs/vp/cifar10_ncsnpp_deep_continuous.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Training NCSN++ on CIFAR-10.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vpsde' training.continuous = True training.n_iters = 950001 training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'euler_maruyama' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ncsnpp' model.fourier_scale = 16 model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 8 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True model.fir = True model.fir_kernel = [1, 3, 3, 1] model.skip_rescale = True model.resblock_type = 'biggan' model.progressive = 'none' model.progressive_input = 'residual' model.progressive_combine = 'sum' model.attention_type = 'ddpm' model.embedding_type = 'positional' model.init_scale = 0.0 model.conv_size = 3 return config ================================================ FILE: configs/vp/ddpm/bedroom.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for reproducing the results of DDPM on bedrooms.""" from configs.default_lsun_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vpsde' training.continuous = False training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'ancestral_sampling' sampling.corrector = 'none' # data data = config.data data.category = 'bedroom' data.centered = True # model model = config.model model.name = 'ddpm' model.scale_by_sigma = False model.num_scales = 1000 model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 1, 2, 2, 4, 4) model.num_res_blocks = 2 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True # optim optim = config.optim optim.lr = 2e-5 return config ================================================ FILE: configs/vp/ddpm/celebahq.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for reproducing the results of DDPM on bedrooms.""" from configs.default_lsun_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vpsde' training.continuous = False training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'ancestral_sampling' sampling.corrector = 'none' # data data = config.data data.dataset = 'CelebAHQ' data.centered = True data.tfrecords_path = '/atlas/u/yangsong/celeba_hq/-r10.tfrecords' data.image_size = 256 # model model = config.model model.name = 'ddpm' model.scale_by_sigma = False model.num_scales = 1000 model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 1, 2, 2, 4, 4) model.num_res_blocks = 2 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True # optim optim = config.optim optim.lr = 2e-5 return config ================================================ FILE: configs/vp/ddpm/church.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for reproducing the results of DDPM on church_outdoor.""" from configs.default_lsun_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vpsde' training.continuous = False training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'ancestral_sampling' sampling.corrector = 'none' # data data = config.data data.category = 'church_outdoor' data.centered = True # model model = config.model model.name = 'ddpm' model.scale_by_sigma = False model.num_scales = 1000 model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 1, 2, 2, 4, 4) model.num_res_blocks = 2 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True # optim optim = config.optim optim.lr = 2e-5 return config ================================================ FILE: configs/vp/ddpm/cifar10.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Config file for reproducing the results of DDPM on cifar-10.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vpsde' training.continuous = False training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'ancestral_sampling' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ddpm' model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 2 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True return config ================================================ FILE: configs/vp/ddpm/cifar10_continuous.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python3 """Training DDPM with VP SDE.""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vpsde' training.continuous = True training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'euler_maruyama' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ddpm' model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 2 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = True return config ================================================ FILE: configs/vp/ddpm/cifar10_unconditional.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Training DDPM on CIFAR-10 without explicitly conditioning on time steps. 
(NCSNv2 technique 3)""" from configs.default_cifar10_configs import get_default_configs def get_config(): config = get_default_configs() # training training = config.training training.sde = 'vpsde' training.continuous = False training.reduce_mean = True # sampling sampling = config.sampling sampling.method = 'pc' sampling.predictor = 'ancestral_sampling' sampling.corrector = 'none' # data data = config.data data.centered = True # model model = config.model model.name = 'ddpm' model.scale_by_sigma = False model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.nonlinearity = 'swish' model.nf = 128 model.ch_mult = (1, 2, 2, 2) model.num_res_blocks = 2 model.attn_resolutions = (16,) model.resamp_with_conv = True model.conditional = False return config ================================================ FILE: controllable_generation_TV.py ================================================ import functools import time import torch from numpy.testing._private.utils import measure import numpy as np import matplotlib.pyplot as plt from tqdm import tqdm from models import utils as mutils from sampling import NoneCorrector, NonePredictor, shared_corrector_update_fn, shared_predictor_update_fn from utils import fft2, ifft2, fft2_m, ifft2_m from physics.ct import * from utils import show_samples, show_samples_gray, clear, clear_color, batchfy class lambda_schedule: def __init__(self, total=2000): self.total = total def get_current_lambda(self, i): pass class lambda_schedule_linear(lambda_schedule): def __init__(self, start_lamb=1.0, end_lamb=0.0): super().__init__() self.start_lamb = start_lamb self.end_lamb = end_lamb def get_current_lambda(self, i): return self.start_lamb + (self.end_lamb - self.start_lamb) * (i / self.total) class lambda_schedule_const(lambda_schedule): def __init__(self, lamb=1.0): super().__init__() self.lamb = lamb def get_current_lambda(self, i): return self.lamb def _Dz(x): # Batch direction y = torch.zeros_like(x) y[:-1] = x[1:] y[-1] = x[0] 
return y - x def _DzT(x): # Batch direction y = torch.zeros_like(x) y[:-1] = x[1:] y[-1] = x[0] tempt = -(y-x) difft = tempt[:-1] y[1:] = difft y[0] = x[-1] - x[0] return y def _Dx(x): # Batch direction y = torch.zeros_like(x) y[:, :, :-1, :] = x[:, :, 1:, :] y[:, :, -1, :] = x[:, :, 0, :] return y - x def _DxT(x): # Batch direction y = torch.zeros_like(x) y[:, :, :-1, :] = x[:, :, 1:, :] y[:, :, -1, :] = x[:, :, 0, :] tempt = -(y - x) difft = tempt[:, :, :-1, :] y[:, :, 1:, :] = difft y[:, :, 0, :] = x[:, :, -1, :] - x[:, :, 0, :] return y def _Dy(x): # Batch direction y = torch.zeros_like(x) y[:, :, :, :-1] = x[:, :, :, 1:] y[:, :, :, -1] = x[:, :, :, 0] return y - x def _DyT(x): # Batch direction y = torch.zeros_like(x) y[:, :, :, :-1] = x[:, :, :, 1:] y[:, :, :, -1] = x[:, :, :, 0] tempt = -(y - x) difft = tempt[:, :, :, :-1] y[:, :, :, 1:] = difft y[:, :, :, 0] = x[:, :, :, -1] - x[:, :, :, 0] return y def get_pc_radon_ADMM_TV(sde, predictor, corrector, inverse_scaler, snr, n_steps=1, probability_flow=False, continuous=False, denoise=True, eps=1e-5, radon=None, save_progress=False, save_root=None, final_consistency=False, img_cache=None, img_shape=None, lamb_1=5, rho=10): """ Sparse application of measurement consistency """ # Define predictor & corrector predictor_update_fn = functools.partial(shared_predictor_update_fn, sde=sde, predictor=predictor, probability_flow=probability_flow, continuous=continuous) corrector_update_fn = functools.partial(shared_corrector_update_fn, sde=sde, corrector=corrector, continuous=continuous, snr=snr, n_steps=n_steps) if img_cache != None : img_shape[0] += 1 del_z = torch.zeros(img_shape) udel_z = torch.zeros(img_shape) eps = 1e-10 def _A(x): return radon.A(x) def _AT(sinogram): return radon.AT(sinogram) def kaczmarz(x, x_mean, measurement=None, lamb=1.0, i=None, norm_const=None): x = x + lamb * _AT(measurement - _A(x))/norm_const x_mean = x return x, x_mean def A_cg(x): return _AT(_A(x)) + rho * _DzT(_Dz(x)) def 
CG(A_fn,b_cg,x,n_inner=10): r = b_cg - A_fn(x) p = r rs_old = torch.matmul(r.view(1,-1),r.view(1,-1).T) for i in range(n_inner): Ap = A_fn(p) a = rs_old/torch.matmul(p.view(1,-1),Ap.view(1,-1).T) x += a * p r -= a * Ap rs_new = torch.matmul(r.view(1,-1),r.view(1,-1).T) if torch.sqrt(rs_new) < eps : break p = r + (rs_new/rs_old) * p rs_old = rs_new return x def CS_routine(x,ATy, niter=20): if img_cache != None : x = torch.cat([img_cache,x],dim=0) idx = list(range(len(x),0,-1)) x = x[idx] nonlocal del_z, udel_z if del_z.device != x.device : del_z = del_z.to(x.device) udel_z = del_z.to(x.device) for i in range(niter): b_cg = ATy + rho * (_DzT(del_z)-_DzT(udel_z)) x = CG(A_cg, b_cg, x, n_inner=1) del_z = shrink(_Dz(x) + udel_z, lamb_1/rho) udel_z = _Dz(x) - del_z + udel_z if img_cache != None : x = x[idx] x = x[1:] del_z[-1] = 0 udel_z[-1] = 0 x_mean = x return x, x_mean def get_update_fn(update_fn): def radon_update_fn(model, data, x, t): with torch.no_grad(): vec_t = torch.ones(data.shape[0], device=data.device) * t x, x_mean = update_fn(x, vec_t, model=model) return x, x_mean return radon_update_fn def get_corrector_update_fn(update_fn): def radon_update_fn(model, data, x, t, measurement=None): with torch.no_grad(): vec_t = torch.ones(data.shape[0], device=data.device) * t x, x_mean = update_fn(x, vec_t, model=model) ATy = _AT(measurement) x, x_mean = CS_routine(x, ATy, niter=1) return x, x_mean return radon_update_fn predictor_denoise_update_fn = get_update_fn(predictor_update_fn) corrector_radon_update_fn = get_corrector_update_fn(corrector_update_fn) def pc_radon(model, data, measurement=None): with torch.no_grad(): x = sde.prior_sampling(data.shape).to(data.device) ones = torch.ones_like(x).to(data.device) norm_const = _AT(_A(ones)) timesteps = torch.linspace(sde.T, eps, sde.N) for i in tqdm(range(sde.N)): t = timesteps[i] x, x_mean = predictor_denoise_update_fn(model, data, x, t) x, x_mean = corrector_radon_update_fn(model, data, x, t, measurement=measurement) 
if save_progress: if (i % 50) == 0: print(f'iter: {i}/{sde.N}') plt.imsave(save_root / 'recon' / f'progress{i}.png', clear(x_mean[0:1]), cmap='gray') # Final step which coerces the data fidelity error term to be zero, # and thereby satisfying Ax = y if final_consistency: x, x_mean = kaczmarz(x, x_mean, measurement, lamb=1.0, norm_const=norm_const) return inverse_scaler(x_mean if denoise else x) return pc_radon def get_pc_radon_ADMM_TV_vol(sde, predictor, corrector, inverse_scaler, snr, n_steps=1, probability_flow=False, continuous=False, denoise=True, eps=1e-5, radon=None, save_progress=False, save_root=None, final_consistency=False, img_shape=None, lamb_1=5, rho=10): """ Sparse application of measurement consistency """ # Define predictor & corrector predictor_update_fn = functools.partial(shared_predictor_update_fn, sde=sde, predictor=predictor, probability_flow=probability_flow, continuous=continuous) corrector_update_fn = functools.partial(shared_corrector_update_fn, sde=sde, corrector=corrector, continuous=continuous, snr=snr, n_steps=n_steps) del_z = torch.zeros(img_shape) udel_z = torch.zeros(img_shape) eps = 1e-10 def _A(x): return radon.A(x) def _AT(sinogram): return radon.AT(sinogram) def kaczmarz(x, x_mean, measurement=None, lamb=1.0, i=None, norm_const=None): x = x + lamb * _AT(measurement - _A(x)) / norm_const x_mean = x return x, x_mean def A_cg(x): return _AT(_A(x)) + rho * _DzT(_Dz(x)) def CG(A_fn, b_cg, x, n_inner=10): r = b_cg - A_fn(x) p = r rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T) for i in range(n_inner): Ap = A_fn(p) a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T) x += a * p r -= a * Ap rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T) if torch.sqrt(rs_new) < eps: break p = r + (rs_new / rs_old) * p rs_old = rs_new return x def CS_routine(x, ATy, niter=20): nonlocal del_z, udel_z if del_z.device != x.device: del_z = del_z.to(x.device) udel_z = del_z.to(x.device) for i in range(niter): b_cg = ATy + rho * (_DzT(del_z) 
- _DzT(udel_z)) x = CG(A_cg, b_cg, x, n_inner=1) del_z = shrink(_Dz(x) + udel_z, lamb_1 / rho) udel_z = _Dz(x) - del_z + udel_z x_mean = x return x, x_mean def get_update_fn(update_fn): def radon_update_fn(model, data, x, t): with torch.no_grad(): vec_t = torch.ones(x.shape[0], device=x.device) * t x, x_mean = update_fn(x, vec_t, model=model) return x, x_mean return radon_update_fn def get_ADMM_TV_fn(): def ADMM_TV_fn(x, measurement=None): with torch.no_grad(): ATy = _AT(measurement) x, x_mean = CS_routine(x, ATy, niter=1) return x, x_mean return ADMM_TV_fn predictor_denoise_update_fn = get_update_fn(predictor_update_fn) corrector_denoise_update_fn = get_update_fn(corrector_update_fn) mc_update_fn = get_ADMM_TV_fn() def pc_radon(model, data, measurement=None): with torch.no_grad(): x = sde.prior_sampling(data.shape).to(data.device) ones = torch.ones_like(x).to(data.device) norm_const = _AT(_A(ones)) timesteps = torch.linspace(sde.T, eps, sde.N) for i in tqdm(range(sde.N)): t = timesteps[i] # 1. batchify into sizes that fit into the GPU x_batch = batchfy(x, 12) # 2. Run PC step for each batch x_agg = list() for idx, x_batch_sing in enumerate(x_batch): x_batch_sing, _ = predictor_denoise_update_fn(model, data, x_batch_sing, t) x_batch_sing, _ = corrector_denoise_update_fn(model, data, x_batch_sing, t) x_agg.append(x_batch_sing) # 3. Aggregate to run ADMM TV x = torch.cat(x_agg, dim=0) # 4. 
Run ADMM TV x, x_mean = mc_update_fn(x, measurement=measurement) if save_progress: if (i % 50) == 0: print(f'iter: {i}/{sde.N}') plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x_mean[0:1]), cmap='gray') # Final step which coerces the data fidelity error term to be zero, # and thereby satisfying Ax = y if final_consistency: x, x_mean = kaczmarz(x, x, measurement, lamb=1.0, norm_const=norm_const) return inverse_scaler(x_mean if denoise else x) return pc_radon def get_pc_radon_ADMM_TV_all_vol(sde, predictor, corrector, inverse_scaler, snr, n_steps=1, probability_flow=False, continuous=False, denoise=True, eps=1e-5, radon=None, save_progress=False, save_root=None, final_consistency=False, img_shape=None, lamb_1=5, rho=10): """ Sparse application of measurement consistency """ # Define predictor & corrector predictor_update_fn = functools.partial(shared_predictor_update_fn, sde=sde, predictor=predictor, probability_flow=probability_flow, continuous=continuous) corrector_update_fn = functools.partial(shared_corrector_update_fn, sde=sde, corrector=corrector, continuous=continuous, snr=snr, n_steps=n_steps) del_x = torch.zeros(img_shape) del_y = torch.zeros(img_shape) del_z = torch.zeros(img_shape) udel_x = torch.zeros(img_shape) udel_y = torch.zeros(img_shape) udel_z = torch.zeros(img_shape) eps = 1e-10 def _A(x): return radon.A(x) def _AT(sinogram): return radon.AT(sinogram) def kaczmarz(x, x_mean, measurement=None, lamb=1.0, i=None, norm_const=None): x = x + lamb * _AT(measurement - _A(x)) / norm_const x_mean = x return x, x_mean def A_cg(x): return _AT(_A(x)) + rho * (_DxT(_Dx(x)) + _DyT(_Dy(x)) + _DzT(_Dz(x))) def CG(A_fn, b_cg, x, n_inner=10): r = b_cg - A_fn(x) p = r rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T) for i in range(n_inner): Ap = A_fn(p) a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T) x += a * p r -= a * Ap rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T) if torch.sqrt(rs_new) < eps: break p = r + (rs_new 
/ rs_old) * p rs_old = rs_new return x def CS_routine(x, ATy, niter=20): nonlocal del_x, del_y, del_z, udel_x, udel_y, udel_z if del_z.device != x.device: del_x = del_x.to(x.device) del_y = del_y.to(x.device) del_z = del_z.to(x.device) udel_x = udel_x.to(x.device) udel_y = udel_y.to(x.device) udel_z = udel_z.to(x.device) for i in range(niter): b_cg = ATy + rho * ((_DxT(del_x) - _DxT(udel_x)) + (_DyT(del_y) - _DyT(udel_y)) + (_DzT(del_z) - _DzT(udel_z))) x = CG(A_cg, b_cg, x, n_inner=1) del_x = shrink(_Dx(x) + udel_x, lamb_1 / rho) del_y = shrink(_Dy(x) + udel_y, lamb_1 / rho) del_z = shrink(_Dz(x) + udel_z, lamb_1 / rho) udel_x = _Dx(x) - del_x + udel_x udel_y = _Dy(x) - del_y + udel_y udel_z = _Dz(x) - del_z + udel_z x_mean = x return x, x_mean def get_update_fn(update_fn): def radon_update_fn(model, data, x, t): with torch.no_grad(): vec_t = torch.ones(x.shape[0], device=x.device) * t x, x_mean = update_fn(x, vec_t, model=model) return x, x_mean return radon_update_fn def get_ADMM_TV_fn(): def ADMM_TV_fn(x, measurement=None): with torch.no_grad(): ATy = _AT(measurement) x, x_mean = CS_routine(x, ATy, niter=1) return x, x_mean return ADMM_TV_fn predictor_denoise_update_fn = get_update_fn(predictor_update_fn) corrector_denoise_update_fn = get_update_fn(corrector_update_fn) mc_update_fn = get_ADMM_TV_fn() def pc_radon(model, data, measurement=None): with torch.no_grad(): x = sde.prior_sampling(data.shape).to(data.device) ones = torch.ones_like(x).to(data.device) norm_const = _AT(_A(ones)) timesteps = torch.linspace(sde.T, eps, sde.N) for i in tqdm(range(sde.N)): t = timesteps[i] # 1. batchify into sizes that fit into the GPU x_batch = batchfy(x, 12) # 2. Run PC step for each batch x_agg = list() for idx, x_batch_sing in enumerate(x_batch): x_batch_sing, _ = predictor_denoise_update_fn(model, data, x_batch_sing, t) x_batch_sing, _ = corrector_denoise_update_fn(model, data, x_batch_sing, t) x_agg.append(x_batch_sing) # 3. 
Aggregate to run ADMM TV x = torch.cat(x_agg, dim=0) # 4. Run ADMM TV x, x_mean = mc_update_fn(x, measurement=measurement) if save_progress: if (i % 50) == 0: print(f'iter: {i}/{sde.N}') plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x_mean[0:1]), cmap='gray') # Final step which coerces the data fidelity error term to be zero, # and thereby satisfying Ax = y if final_consistency: x, x_mean = kaczmarz(x, x, measurement, lamb=1.0, norm_const=norm_const) return inverse_scaler(x_mean if denoise else x) return pc_radon def get_ADMM_TV(eps=1e-5, radon=None, save_progress=False, save_root=None, img_shape=None, lamb_1=5, rho=10, outer_iter=30, inner_iter=20): del_x = torch.zeros(img_shape) del_y = torch.zeros(img_shape) del_z = torch.zeros(img_shape) udel_x = torch.zeros(img_shape) udel_y = torch.zeros(img_shape) udel_z = torch.zeros(img_shape) eps = 1e-10 def _A(x): return radon.A(x) def _AT(sinogram): return radon.AT(sinogram) def A_cg(x): return _AT(_A(x)) + rho * (_DxT(_Dx(x)) + _DyT(_Dy(x)) + _DzT(_Dz(x))) def CG(A_fn, b_cg, x, n_inner=20): r = b_cg - A_fn(x) p = r rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T) for i in range(n_inner): Ap = A_fn(p) a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T) x += a * p r -= a * Ap rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T) if torch.sqrt(rs_new) < eps: break p = r + (rs_new / rs_old) * p rs_old = rs_new return x def CS_routine(x, ATy, niter=30): nonlocal del_x, del_y, del_z, udel_x, udel_y, udel_z if del_z.device != x.device: del_x = del_x.to(x.device) del_y = del_y.to(x.device) del_z = del_z.to(x.device) udel_x = udel_x.to(x.device) udel_y = udel_y.to(x.device) udel_z = udel_z.to(x.device) for i in tqdm(range(niter)): b_cg = ATy + rho * ((_DxT(del_x) - _DxT(udel_x)) + (_DyT(del_y) - _DyT(udel_y)) + (_DzT(del_z) - _DzT(udel_z))) x = CG(A_cg, b_cg, x, n_inner=inner_iter) if save_progress: plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x[0:1]), 
cmap='gray') del_x = shrink(_Dx(x) + udel_x, lamb_1 / rho) del_y = shrink(_Dy(x) + udel_y, lamb_1 / rho) del_z = shrink(_Dz(x) + udel_z, lamb_1 / rho) udel_x = _Dx(x) - del_x + udel_x udel_y = _Dy(x) - del_y + udel_y udel_z = _Dz(x) - del_z + udel_z return x def get_ADMM_TV_fn(): def ADMM_TV_fn(x, measurement=None): with torch.no_grad(): ATy = _AT(measurement) x, x_mean = CS_routine(x, ATy, niter=outer_iter) return x, x_mean return ADMM_TV_fn mc_update_fn = get_ADMM_TV_fn() def ADMM_TV(data, measurement=None): with torch.no_grad(): x = torch.zeros(data.shape).to(data.device) x = mc_update_fn(x, measurement=measurement) return x return ADMM_TV def get_ADMM_TV_isotropic(eps=1e-5, radon=None, save_progress=False, save_root=None, img_shape=None, lamb_1=5, rho=10, outer_iter=30, inner_iter=20): """ (get_ADMM_TV): implements anisotropic TV-ADMM In contrast, this function implements isotropic TV, which regularizes with |TV|_{1,2} """ del_x = torch.zeros(img_shape) del_y = torch.zeros(img_shape) del_z = torch.zeros(img_shape) udel_x = torch.zeros(img_shape) udel_y = torch.zeros(img_shape) udel_z = torch.zeros(img_shape) eps = 1e-10 def _A(x): return radon.A(x) def _AT(sinogram): return radon.AT(sinogram) def A_cg(x): return _AT(_A(x)) + rho * (_DxT(_Dx(x)) + _DyT(_Dy(x)) + _DzT(_Dz(x))) def CG(A_fn, b_cg, x, n_inner=20): r = b_cg - A_fn(x) p = r rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T) for i in range(n_inner): Ap = A_fn(p) a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T) x += a * p r -= a * Ap rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T) if torch.sqrt(rs_new) < eps: break p = r + (rs_new / rs_old) * p rs_old = rs_new return x def CS_routine(x, ATy, niter=30): nonlocal del_x, del_y, del_z, udel_x, udel_y, udel_z if del_z.device != x.device: del_x = del_x.to(x.device) del_y = del_y.to(x.device) del_z = del_z.to(x.device) udel_x = udel_x.to(x.device) udel_y = udel_y.to(x.device) udel_z = udel_z.to(x.device) for i in tqdm(range(niter)): b_cg = 
ATy + rho * ((_DxT(del_x) - _DxT(udel_x)) + (_DyT(del_y) - _DyT(udel_y)) + (_DzT(del_z) - _DzT(udel_z))) x = CG(A_cg, b_cg, x, n_inner=inner_iter) if save_progress: plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x[0:1]), cmap='gray') # Each of shape [448, 1, 256, 256] _Dxx = _Dx(x) _Dyx = _Dy(x) _Dzx = _Dz(x) # shape [448, 3, 256, 256]. dim=1 gradient dimension _Dxa = torch.cat((_Dxx, _Dyx, _Dzx), dim=1) udel_a = torch.cat((udel_x, udel_y, udel_z), dim=1) # prox del_a = prox_l21(_Dxa + udel_a, lamb_1 / rho, dim=1) # split del_x, del_y, del_z = torch.split(del_a, 1, dim=1) # del_x = prox_l21(_Dxx + udel_x, lamb_1 / rho, -2) # del_y = prox_l21(_Dyx + udel_y, lamb_1 / rho, -1) # del_z = prox_l21(_Dzx + udel_z, lamb_1 / rho, 0) udel_x = _Dxx - del_x + udel_x udel_y = _Dyx - del_y + udel_y udel_z = _Dzx - del_z + udel_z return x def get_ADMM_TV_fn(): def ADMM_TV_fn(x, measurement=None): with torch.no_grad(): ATy = _AT(measurement) x = CS_routine(x, ATy, niter=outer_iter) return x return ADMM_TV_fn mc_update_fn = get_ADMM_TV_fn() def ADMM_TV(data, measurement=None): with torch.no_grad(): x = torch.zeros(data.shape).to(data.device) x = mc_update_fn(x, measurement=measurement) return x return ADMM_TV def prox_l21(src, lamb, dim): """ src.shape = [448(z), 1, 256(x), 256(y)] """ weight_src = torch.linalg.norm(src, dim=dim, keepdim=True) weight_src_shrink = shrink(weight_src, lamb) weight = weight_src_shrink / weight_src return src * weight def shrink(weight_src, lamb): return torch.sign(weight_src) * torch.max(torch.abs(weight_src) - lamb, torch.zeros_like(weight_src)) def get_pc_radon_ADMM_TV_mri(sde, predictor, corrector, inverse_scaler, snr, mask=None, n_steps=1, probability_flow=False, continuous=False, denoise=True, eps=1e-5, save_progress=False, save_root=None, img_shape=None, lamb_1=5, rho=10): predictor_update_fn = functools.partial(shared_predictor_update_fn, sde=sde, predictor=predictor, probability_flow=probability_flow, 
continuous=continuous) corrector_update_fn = functools.partial(shared_corrector_update_fn, sde=sde, corrector=corrector, continuous=continuous, snr=snr, n_steps=n_steps) del_z = torch.zeros(img_shape) udel_z = torch.zeros(img_shape) eps = 1e-10 def _A(x): return fft2(x) * mask def _AT(kspace): return torch.real(ifft2(kspace)) def _Dz(x): # Batch direction y = torch.zeros_like(x) y[:-1] = x[1:] y[-1] = x[0] return y - x def _DzT(x): # Batch direction y = torch.zeros_like(x) y[:-1] = x[1:] y[-1] = x[0] tempt = -(y - x) difft = tempt[:-1] y[1:] = difft y[0] = x[-1] - x[0] return y def A_cg(x): return _AT(_A(x)) + rho * _DzT(_Dz(x)) def shrink(src, lamb): return torch.sign(src) * torch.max(torch.abs(src) - lamb, torch.zeros_like(src)) def CG(A_fn, b_cg, x, n_inner=10): r = b_cg - A_fn(x) p = r rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T) for i in range(n_inner): Ap = A_fn(p) a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T) x += a * p r -= a * Ap rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T) if torch.sqrt(rs_new) < eps: break p = r + (rs_new / rs_old) * p rs_old = rs_new return x def CS_routine(x, ATy, niter=20): nonlocal del_z, udel_z if del_z.device != x.device: del_z = del_z.to(x.device) udel_z = del_z.to(x.device) for i in range(niter): b_cg = ATy + rho * (_DzT(del_z) - _DzT(udel_z)) x = CG(A_cg, b_cg, x, n_inner=1) del_z = shrink(_Dz(x) + udel_z, lamb_1 / rho) udel_z = _Dz(x) - del_z + udel_z x_mean = x return x, x_mean def get_update_fn(update_fn): def radon_update_fn(model, data, x, t): with torch.no_grad(): vec_t = torch.ones(x.shape[0], device=x.device) * t x, x_mean = update_fn(x, vec_t, model=model) return x, x_mean return radon_update_fn def get_ADMM_TV_fn(): def ADMM_TV_fn(x, measurement=None): with torch.no_grad(): ATy = _AT(measurement) x, x_mean = CS_routine(x, ATy, niter=1) return x, x_mean return ADMM_TV_fn predictor_denoise_update_fn = get_update_fn(predictor_update_fn) corrector_denoise_update_fn = 
get_update_fn(corrector_update_fn) mc_update_fn = get_ADMM_TV_fn() def pc_radon(model, data, measurement=None): with torch.no_grad(): x = sde.prior_sampling(data.shape).to(data.device) timesteps = torch.linspace(sde.T, eps, sde.N) for i in tqdm(range(sde.N)): t = timesteps[i] # 1. batchify into sizes that fit into the GPU x_batch = batchfy(x, 20) # 2. Run PC step for each batch x_agg = list() for idx, x_batch_sing in enumerate(x_batch): x_batch_sing, _ = predictor_denoise_update_fn(model, data, x_batch_sing, t) x_batch_sing, _ = corrector_denoise_update_fn(model, data, x_batch_sing, t) x_agg.append(x_batch_sing) # 3. Aggregate to run ADMM TV x = torch.cat(x_agg, dim=0) # 4. Run ADMM TV x, x_mean = mc_update_fn(x, measurement=measurement) if save_progress: if (i % 50) == 0: print(f'iter: {i}/{sde.N}') plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x_mean[0:1]), cmap='gray') return inverse_scaler(x_mean if denoise else x) return pc_radon ================================================ FILE: datasets.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pylint: skip-file """Return training and evaluation/test datasets from config files.""" from torch.utils.data import Dataset, DataLoader import numpy as np def get_data_scaler(config): """Data normalizer. 
Assume data are always in [0, 1].""" if config.data.centered: # Rescale to [-1, 1] return lambda x: x * 2. - 1. else: return lambda x: x def get_data_inverse_scaler(config): """Inverse data normalizer.""" if config.data.centered: # Rescale [-1, 1] to [0, 1] return lambda x: (x + 1.) / 2. else: return lambda x: x def crop_resize(image, resolution): """Crop and resize an image to the given resolution.""" crop = tf.minimum(tf.shape(image)[0], tf.shape(image)[1]) h, w = tf.shape(image)[0], tf.shape(image)[1] image = image[(h - crop) // 2:(h + crop) // 2, (w - crop) // 2:(w + crop) // 2] image = tf.image.resize( image, size=(resolution, resolution), antialias=True, method=tf.image.ResizeMethod.BICUBIC) return tf.cast(image, tf.uint8) def resize_small(image, resolution): """Shrink an image to the given resolution.""" h, w = image.shape[0], image.shape[1] ratio = resolution / min(h, w) h = tf.round(h * ratio, tf.int32) w = tf.round(w * ratio, tf.int32) return tf.image.resize(image, [h, w], antialias=True) def central_crop(image, size): """Crop the center of an image to the given size.""" top = (image.shape[0] - size) // 2 left = (image.shape[1] - size) // 2 return tf.image.crop_to_bounding_box(image, top, left, size, size) def get_dataset(config, uniform_dequantization=False, evaluation=False): """Create data loaders for training and evaluation. Args: config: A ml_collection.ConfigDict parsed from config files. uniform_dequantization: If `True`, add uniform dequantization to images. evaluation: If `True`, fix number of epochs to 1. Returns: train_ds, eval_ds, dataset_builder. """ # Compute batch size for this worker. 
batch_size = config.training.batch_size if not evaluation else config.eval.batch_size if batch_size % jax.device_count() != 0: raise ValueError(f'Batch sizes ({batch_size} must be divided by' f'the number of devices ({jax.device_count()})') # Reduce this when image resolution is too large and data pointer is stored shuffle_buffer_size = 10000 prefetch_size = tf.data.experimental.AUTOTUNE num_epochs = None if not evaluation else 1 # Create dataset builders for each dataset. if config.data.dataset == 'CIFAR10': dataset_builder = tfds.builder('cifar10') train_split_name = 'train' eval_split_name = 'test' def resize_op(img): img = tf.image.convert_image_dtype(img, tf.float32) # Added to train grayscale models # img = tf.image.rgb_to_grayscale(img) return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True) elif config.data.dataset == 'SVHN': dataset_builder = tfds.builder('svhn_cropped') train_split_name = 'train' eval_split_name = 'test' def resize_op(img): img = tf.image.convert_image_dtype(img, tf.float32) return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True) elif config.data.dataset == 'CELEBA': dataset_builder = tfds.builder('celeb_a') train_split_name = 'train' eval_split_name = 'validation' def resize_op(img): img = tf.image.convert_image_dtype(img, tf.float32) img = central_crop(img, 140) img = resize_small(img, config.data.image_size) return img elif config.data.dataset == 'LSUN': dataset_builder = tfds.builder(f'lsun/{config.data.category}') train_split_name = 'train' eval_split_name = 'validation' if config.data.image_size == 128: def resize_op(img): img = tf.image.convert_image_dtype(img, tf.float32) img = resize_small(img, config.data.image_size) img = central_crop(img, config.data.image_size) return img else: def resize_op(img): img = crop_resize(img, config.data.image_size) img = tf.image.convert_image_dtype(img, tf.float32) return img elif config.data.dataset in ['FFHQ', 
'CelebAHQ']: dataset_builder = tf.data.TFRecordDataset(config.data.tfrecords_path) train_split_name = eval_split_name = 'train' else: raise NotImplementedError( f'Dataset {config.data.dataset} not yet supported.') # Customize preprocess functions for each dataset. if config.data.dataset in ['FFHQ', 'CelebAHQ']: def preprocess_fn(d): sample = tf.io.parse_single_example(d, features={ 'shape': tf.io.FixedLenFeature([3], tf.int64), 'data': tf.io.FixedLenFeature([], tf.string)}) data = tf.io.decode_raw(sample['data'], tf.uint8) data = tf.reshape(data, sample['shape']) data = tf.transpose(data, (1, 2, 0)) img = tf.image.convert_image_dtype(data, tf.float32) if config.data.random_flip and not evaluation: img = tf.image.random_flip_left_right(img) if uniform_dequantization: img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256. return dict(image=img, label=None) else: def preprocess_fn(d): """Basic preprocessing function scales data to [0, 1) and randomly flips.""" img = resize_op(d['image']) if config.data.random_flip and not evaluation: img = tf.image.random_flip_left_right(img) if uniform_dequantization: img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256. 
return dict(image=img, label=d.get('label', None)) def create_dataset(dataset_builder, split): dataset_options = tf.data.Options() dataset_options.experimental_optimization.map_parallelization = True dataset_options.experimental_threading.private_threadpool_size = 48 dataset_options.experimental_threading.max_intra_op_parallelism = 1 read_config = tfds.ReadConfig(options=dataset_options) if isinstance(dataset_builder, tfds.core.DatasetBuilder): dataset_builder.download_and_prepare() ds = dataset_builder.as_dataset( split=split, shuffle_files=True, read_config=read_config) else: ds = dataset_builder.with_options(dataset_options) ds = ds.repeat(count=num_epochs) ds = ds.shuffle(shuffle_buffer_size) ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) ds = ds.batch(batch_size, drop_remainder=True) return ds.prefetch(prefetch_size) train_ds = create_dataset(dataset_builder, train_split_name) eval_ds = create_dataset(dataset_builder, eval_split_name) return train_ds, eval_ds, dataset_builder from pathlib import Path class fastmri_knee(Dataset): """ Simple pytorch dataset for fastmri knee singlecoil dataset """ def __init__(self, root, is_complex=False): self.root = root self.data_list = list(root.glob('*/*.npy')) self.is_complex = is_complex def __len__(self): return len(self.data_list) def __getitem__(self, idx): fname = self.data_list[idx] if not self.is_complex: data = np.load(fname) else: data = np.load(fname).astype(np.complex64) data = np.expand_dims(data, axis=0) return data class AAPM(Dataset): def __init__(self, root, sort): self.root = root self.data_list = list(root.glob('full_dose/*.npy')) self.sort = sort if sort: self.data_list = sorted(self.data_list) def __len__(self): return len(self.data_list) def __getitem__(self, idx): fname = self.data_list[idx] data = np.load(fname) data = np.expand_dims(data, axis=0) return data class Object5(Dataset): def __init__(self, root, slice, fast=False): """ slice - range of the 2000 _volumes_ that 
you want, but the dataset will return images, so will be 256 times longer fast - set to true to get a tiny version of the dataset """ if fast: self.NUM_SLICES = 10 else: self.NUM_SLICES = 256 self.root = root self.data_list = list(root.glob('*.npz')) if len(self.data_list) == 0: raise ValueError(f"No npz files found in {root}") self.data_list = sorted(self.data_list)[slice] def __len__(self): return len(self.data_list) * self.NUM_SLICES def __getitem__(self, idx): vol_index = idx // self.NUM_SLICES slice_index = idx % self.NUM_SLICES fname = self.data_list[vol_index] data = np.load(fname)['x'][slice_index] data = np.expand_dims(data, axis=0) return data class fastmri_knee_infer(Dataset): """ Simple pytorch dataset for fastmri knee singlecoil dataset """ def __init__(self, root, sort=True, is_complex=False): self.root = root self.data_list = list(root.glob('*/*.npy')) self.is_complex = is_complex if sort: self.data_list = sorted(self.data_list) def __len__(self): return len(self.data_list) def __getitem__(self, idx): fname = self.data_list[idx] if not self.is_complex: data = np.load(fname) else: data = np.load(fname).astype(np.complex64) data = np.expand_dims(data, axis=0) return data, str(fname) class fastmri_knee_magpha(Dataset): """ Simple pytorch dataset for fastmri knee singlecoil dataset """ def __init__(self, root): self.root = root self.data_list = list(root.glob('*/*.npy')) def __len__(self): return len(self.data_list) def __getitem__(self, idx): fname = self.data_list[idx] data = np.load(fname).astype(np.float32) return data class fastmri_knee_magpha_infer(Dataset): """ Simple pytorch dataset for fastmri knee singlecoil dataset """ def __init__(self, root, sort=True): self.root = root self.data_list = list(root.glob('*/*.npy')) if sort: self.data_list = sorted(self.data_list) def __len__(self): return len(self.data_list) def __getitem__(self, idx): fname = self.data_list[idx] data = np.load(fname).astype(np.float32) return data, str(fname) def 
create_dataloader(configs, evaluation=False, sort=True): shuffle = True if not evaluation else False if configs.data.dataset == 'Object5': train_dataset = Object5(Path(configs.data.root), slice(None,1800)) val_dataset = Object5(Path(configs.data.root), slice(1800,None)) elif configs.data.dataset == 'Object5Fast': train_dataset = Object5(Path(configs.data.root), slice(None,1), fast=True) val_dataset = Object5(Path(configs.data.root), slice(1,2), fast=True) elif configs.data.dataset == 'AAPM': train_dataset = AAPM(Path(configs.data.root) / f'train', sort=False) val_dataset = AAPM(Path(configs.data.root) / f'test', sort=True) elif configs.data.is_multi: train_dataset = fastmri_knee(Path(configs.data.root) / f'knee_multicoil_{configs.data.image_size}_train') val_dataset = fastmri_knee_infer(Path(configs.data.root) / f'knee_{configs.data.image_size}_val', sort=sort) elif configs.data.is_complex: if configs.data.magpha: train_dataset = fastmri_knee_magpha(Path(configs.data.root) / f'knee_complex_magpha_{configs.data.image_size}_train') val_dataset = fastmri_knee_magpha_infer(Path(configs.data.root) / f'knee_complex_magpha_{configs.data.image_size}_val') else: train_dataset = fastmri_knee(Path(configs.data.root) / f'knee_complex_{configs.data.image_size}_train', is_complex=True) val_dataset = fastmri_knee_infer(Path(configs.data.root) / f'knee_complex_{configs.data.image_size}_val', is_complex=True) elif configs.data.dataset == 'fastmri_knee': train_dataset = fastmri_knee(Path(configs.data.root) / f'knee_{configs.data.image_size}_train') val_dataset = fastmri_knee_infer(Path(configs.data.root) / f'knee_{configs.data.image_size}_val', sort=sort) else: raise ValueError(f'Dataset {configs.data.dataset} not recognized.') train_loader = DataLoader( dataset=train_dataset, batch_size=configs.training.batch_size, shuffle=shuffle, drop_last=True ) val_loader = DataLoader( dataset=val_dataset, batch_size=configs.training.batch_size, # shuffle=False, shuffle=True, drop_last=True ) 
return train_loader, val_loader def create_dataloader_regression(configs, evaluation=False): shuffle = True if not evaluation else False train_dataset = fastmri_knee(Path(configs.root) / f'knee_{configs.image_size}_train') val_dataset = fastmri_knee_infer(Path(configs.root) / f'knee_{configs.image_size}_val') train_loader = DataLoader( dataset=train_dataset, batch_size=configs.batch_size, shuffle=shuffle, drop_last=True ) val_loader = DataLoader( dataset=val_dataset, batch_size=configs.batch_size, shuffle=False, drop_last=True ) return train_loader, val_loader ================================================ FILE: environment.yml ================================================ name: diffusion-mbir channels: - conda-forge - defaults dependencies: - python=3.8 - numpy - matplotlib - scikit-image - sporco - tqdm - ninja - pytorch::pytorch - pytorch::torchvision - tensorboard - pip - pip: - ml_collections - ninja ================================================ FILE: evaluation.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Utility functions for computing FID/Inception scores."""

import numpy as np
import six

# NOTE(review): `tf`, `tfhub`, `tfgan` and `jax` are used below but only numpy and six
# are imported here. Defining `_DEFAULT_DTYPES` (tf.float32) and the `@tf.function`
# decorators therefore fail with NameError when this module is imported.
# TODO restore the TensorFlow/JAX imports or drop this module.

INCEPTION_TFHUB = 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1'
INCEPTION_OUTPUT = 'logits'
INCEPTION_FINAL_POOL = 'pool_3'
_DEFAULT_DTYPES = {
    INCEPTION_OUTPUT: tf.float32,
    INCEPTION_FINAL_POOL: tf.float32
}
INCEPTION_DEFAULT_IMAGE_SIZE = 299


def get_inception_model(inceptionv3=False):
    """Load an Inception model from TF-Hub.

    Loads the InceptionV3 feature-vector module when `inceptionv3` is True, otherwise
    the TFGAN Inception module at INCEPTION_TFHUB.
    """
    if inceptionv3:
        return tfhub.load(
            'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4')
    else:
        return tfhub.load(INCEPTION_TFHUB)


def load_dataset_stats(config):
    """Load the pre-computed dataset statistics (an .npz under assets/stats) for `config.data.dataset`."""
    if config.data.dataset == 'CIFAR10':
        filename = 'assets/stats/cifar10_stats.npz'
    elif config.data.dataset == 'CELEBA':
        filename = 'assets/stats/celeba_stats.npz'
    elif config.data.dataset == 'LSUN':
        filename = f'assets/stats/lsun_{config.data.category}_{config.data.image_size}_stats.npz'
    else:
        raise ValueError(f'Dataset {config.data.dataset} stats not found.')

    with tf.io.gfile.GFile(filename, 'rb') as fin:
        stats = np.load(fin)
        return stats


def classifier_fn_from_tfhub(output_fields, inception_model,
                             return_tensor=False):
    """Return a function that can be used as a classifier function.

    Copied from tfgan, but avoids loading the model each time _classifier_fn is called.

    Args:
      output_fields: A string, list, or `None`. If present, assume the module
        outputs a dictionary, and select this field.
      inception_model: A model loaded from TFHub.
      return_tensor: If `True`, return a single tensor instead of a dictionary.

    Returns:
      A one-argument function that takes an image Tensor and returns outputs.
    """
    if isinstance(output_fields, six.string_types):
        output_fields = [output_fields]

    def _classifier_fn(images):
        output = inception_model(images)
        if output_fields is not None:
            output = {x: output[x] for x in output_fields}
        if return_tensor:
            assert len(output) == 1
            output = list(output.values())[0]
        # flatten every output tensor
        return tf.nest.map_structure(tf.compat.v1.layers.flatten, output)

    return _classifier_fn


@tf.function
def run_inception_jit(inputs,
                      inception_model,
                      num_batches=1,
                      inceptionv3=False):
    """Run the inception network. Assumes input is within [0, 255].

    Inputs are rescaled to [-1, 1] for the TFGAN model, or to [0, 1] for InceptionV3.
    """
    if not inceptionv3:
        inputs = (tf.cast(inputs, tf.float32) - 127.5) / 127.5
    else:
        inputs = tf.cast(inputs, tf.float32) / 255.

    return tfgan.eval.run_classifier_fn(
        inputs,
        num_batches=num_batches,
        classifier_fn=classifier_fn_from_tfhub(None, inception_model),
        dtypes=_DEFAULT_DTYPES)


@tf.function
def run_inception_distributed(input_tensor,
                              inception_model,
                              num_batches=1,
                              inceptionv3=False):
    """Distribute the inception network computation to all available TPUs.

    Args:
      input_tensor: The input images. Assumed to be within [0, 255].
      inception_model: The inception network model obtained from `tfhub`.
      num_batches: The number of batches used for dividing the input.
      inceptionv3: If `True`, use InceptionV3, otherwise use InceptionV1.

    Returns:
      A dictionary with key `pool_3` and `logits`, representing the pool_3 and
        logits of the inception network respectively. `logits` is None when
        `inceptionv3` is True.
    """
    num_tpus = jax.local_device_count()
    # split the batch evenly across local devices (TPU if present, else GPU)
    input_tensors = tf.split(input_tensor, num_tpus, axis=0)
    pool3 = []
    logits = [] if not inceptionv3 else None
    device_format = '/TPU:{}' if 'TPU' in str(jax.devices()[0]) else '/GPU:{}'
    for i, tensor in enumerate(input_tensors):
        with tf.device(device_format.format(i)):
            tensor_on_device = tf.identity(tensor)
            res = run_inception_jit(
                tensor_on_device, inception_model, num_batches=num_batches,
                inceptionv3=inceptionv3)

            if not inceptionv3:
                pool3.append(res['pool_3'])
                logits.append(res['logits'])  # pytype: disable=attribute-error
            else:
                pool3.append(res)

    # gather per-device results on the CPU
    with tf.device('/CPU'):
        return {
            'pool_3': tf.concat(pool3, axis=0),
            'logits': tf.concat(logits, axis=0) if not inceptionv3 else None
        }

================================================
FILE: fastmri_utils.py
================================================
"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from typing import List, Optional

import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse("1.7.0"):
    import torch.fft  # type: ignore


def fft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply centered 2 dimensional Fast Fourier Transform.

    NOTE(review): this calls `torch.fft(...)` as a function (the pre-1.8 API). When the
    guarded `import torch.fft` above has run, `torch.fft` is the FFT module rather than
    a function, so prefer fft2c_new there.

    Args:
        data: Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has size
            2. All other dimensions are assumed to be batch dimensions.
        norm: Whether to include normalization. Must be one of ``"backward"``
            or ``"ortho"``. See ``torch.fft.fft`` on PyTorch 1.9.0 for details.

    Returns:
        The FFT of the input.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")
    if norm not in ("ortho", "backward"):
        raise ValueError("norm must be 'ortho' or 'backward'.")
    normalized = True if norm == "ortho" else False

    data = ifftshift(data, dim=[-3, -2])
    data = torch.fft(data, 2, normalized=normalized)
    data = fftshift(data, dim=[-3, -2])

    return data


def ifft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply centered 2-dimensional Inverse Fast Fourier Transform.

    NOTE(review): uses the legacy `torch.ifft(...)` function API; see fft2c_old.

    Args:
        data: Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has size
            2. All other dimensions are assumed to be batch dimensions.
        norm: Whether to include normalization. Must be one of ``"backward"``
            or ``"ortho"``. See ``torch.fft.ifft`` on PyTorch 1.9.0 for details.

    Returns:
        The IFFT of the input.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")
    if norm not in ("ortho", "backward"):
        raise ValueError("norm must be 'ortho' or 'backward'.")
    normalized = True if norm == "ortho" else False

    data = ifftshift(data, dim=[-3, -2])
    data = torch.ifft(data, 2, normalized=normalized)
    data = fftshift(data, dim=[-3, -2])

    return data


def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply centered 2 dimensional Fast Fourier Transform.

    The trailing size-2 axis is viewed as a complex tensor, transformed over its
    last two dimensions, and viewed back as real/imaginary pairs.

    Args:
        data: Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has size
            2. All other dimensions are assumed to be batch dimensions.
        norm: Normalization mode. See ``torch.fft.fft``.

    Returns:
        The FFT of the input.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")

    data = ifftshift(data, dim=[-3, -2])
    data = torch.view_as_real(
        torch.fft.fftn(  # type: ignore
            torch.view_as_complex(data), dim=(-2, -1), norm=norm
        )
    )
    data = fftshift(data, dim=[-3, -2])

    return data


def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply centered 2-dimensional Inverse Fast Fourier Transform.

    Args:
        data: Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has size
            2. All other dimensions are assumed to be batch dimensions.
        norm: Normalization mode. See ``torch.fft.ifft``.

    Returns:
        The IFFT of the input.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")

    data = ifftshift(data, dim=[-3, -2])
    data = torch.view_as_real(
        torch.fft.ifftn(  # type: ignore
            torch.view_as_complex(data), dim=(-2, -1), norm=norm
        )
    )
    data = fftshift(data, dim=[-3, -2])

    return data


# Helper functions


def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """
    Similar to roll but for only one dim.

    Args:
        x: A PyTorch tensor.
        shift: Amount to roll (taken modulo the size of `dim`).
        dim: Which dimension to roll.

    Returns:
        Rolled version of x (x itself when the effective shift is 0).
    """
    shift = shift % x.size(dim)
    if shift == 0:
        return x

    left = x.narrow(dim, 0, x.size(dim) - shift)
    right = x.narrow(dim, x.size(dim) - shift, shift)

    return torch.cat((right, left), dim=dim)


def roll(
    x: torch.Tensor,
    shift: List[int],
    dim: List[int],
) -> torch.Tensor:
    """
    Similar to np.roll but applies to PyTorch Tensors.

    Args:
        x: A PyTorch tensor.
        shift: Amount to roll, one entry per dimension in `dim`.
        dim: Which dimension to roll.

    Returns:
        Rolled version of x.

    Raises:
        ValueError: if `shift` and `dim` differ in length.
    """
    if len(shift) != len(dim):
        raise ValueError("len(shift) must match len(dim)")

    for (s, d) in zip(shift, dim):
        x = roll_one_dim(x, s, d)

    return x


def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Similar to np.fft.fftshift but applies to PyTorch Tensors

    Each listed dimension is rolled by floor(size / 2).

    Args:
        x: A PyTorch tensor.
        dim: Which dimension to fftshift.

    Returns:
        fftshifted version of x.
    """
    if dim is None:
        # this weird code is necessary for torch.jit.script typing
        dim = [0] * (x.dim())
        for i in range(1, x.dim()):
            dim[i] = i

    # also necessary for torch.jit.script
    shift = [0] * len(dim)
    for i, dim_num in enumerate(dim):
        shift[i] = x.shape[dim_num] // 2

    return roll(x, shift, dim)


def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors

    Each listed dimension is rolled by ceil(size / 2), undoing fftshift for odd sizes too.

    Args:
        x: A PyTorch tensor.
        dim: Which dimension to ifftshift.

    Returns:
        ifftshifted version of x.
    """
    if dim is None:
        # this weird code is necessary for torch.jit.script typing
        dim = [0] * (x.dim())
        for i in range(1, x.dim()):
            dim[i] = i

    # also necessary for torch.jit.script
    shift = [0] * len(dim)
    for i, dim_num in enumerate(dim):
        shift[i] = (x.shape[dim_num] + 1) // 2

    return roll(x, shift, dim)

================================================
FILE: inverse_problem_solver_AAPM_3d_total.py
================================================
# Sparse-view CT reconstruction of a whole volume (all slices of one patient) with
# diffusion PC sampling + ADMM-TV, via controllable_generation_TV.get_pc_radon_ADMM_TV_vol.
import torch
from torch._C import device
from losses import get_optimizer
from models.ema import ExponentialMovingAverage

import numpy as np
import controllable_generation_TV

from utils import restore_checkpoint, clear, batchfy, patient_wise_min_max, img_wise_min_max
from pathlib import Path

from models import utils as mutils
from models import ncsnpp
from sde_lib import VESDE
from sampling import (ReverseDiffusionPredictor,
                      LangevinCorrector)
import datasets
import time
# for radon
from physics.ct import CT
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

###############################################
# Configurations
###############################################
problem = 'sparseview_CT_ADMM_TV_total'
config_name = 'AAPM_256_ncsnpp_continuous'
sde = 'VESDE'  # rebound to a VESDE instance below
num_scales = 2000
ckpt_num = 185
N = num_scales

vol_name = 'L067'
root = Path(f'./data/CT/ind/256_sorted/{vol_name}')

# Parameters for the inverse problem
# NOTE(review): det_spacing, det_count, freq, random_seed and `angles` (below) are
# not used later in this script; in particular no RNG is seeded with random_seed.
Nview = 8
det_spacing = 1.0
size = 256
det_count = int((size * (2 * torch.ones(1)).sqrt()).ceil())
lamb = 0.04
rho = 10
freq = 1

if sde.lower() == 'vesde':
    from configs.ve import AAPM_256_ncsnpp_continuous as configs
    ckpt_filename = f"exp/ve/{config_name}/checkpoint_{ckpt_num}.pth"
    config = configs.get_config()
    config.model.num_scales = N
    sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
    sde.N = N
    sampling_eps = 1e-5
predictor = ReverseDiffusionPredictor
corrector = LangevinCorrector
probability_flow = False
snr = 0.16
n_steps = 1

batch_size = 12
config.training.batch_size = batch_size
config.eval.batch_size = batch_size
random_seed = 0

sigmas = mutils.get_sigmas(config)
scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)
score_model = mutils.create_model(config)
## model
optimizer = get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(),
                               decay=config.model.ema_rate)
state = dict(step=0, optimizer=optimizer,
             model=score_model, ema=ema)

state = restore_checkpoint(ckpt_filename, state, config.device, skip_sigma=True, skip_optimizer=True)
# sample with the EMA weights
ema.copy_to(score_model.parameters())

# Specify save directory for saving generated samples
save_root = Path(f'./results/{config_name}/{problem}/m{Nview}/rho{rho}/lambda{lamb}')
save_root.mkdir(parents=True, exist_ok=True)

irl_types = ['input', 'recon', 'label', 'BP', 'sinogram']
for t in irl_types:
    if t == 'recon':
        save_root_f = save_root / t / 'progress'
        save_root_f.mkdir(exist_ok=True, parents=True)
    else:
        save_root_f = save_root / t
        save_root_f.mkdir(parents=True, exist_ok=True)

# read all data
# file names are sorted numerically by the part before the first '.'
fname_list = os.listdir(root)
fname_list = sorted(fname_list, key=lambda x: float(x.split(".")[0]))
print(fname_list)
all_img = []

print("Loading all data")
for fname in tqdm(fname_list):
    just_name = fname.split('.')[0]
    img = torch.from_numpy(np.load(os.path.join(root, fname), allow_pickle=True))
    h, w = img.shape
    img = img.view(1, 1, h, w)
    all_img.append(img)
    plt.imsave(os.path.join(save_root, 'label', f'{just_name}.png'), clear(img), cmap='gray')
# stack the slices along dim 0 -> [num_slices, 1, h, w]
all_img = torch.cat(all_img, dim=0)
print(f"Data loaded shape : {all_img.shape}")

# full
angles = np.linspace(0, np.pi, 180, endpoint=False)

radon = CT(img_width=h, radon_view=Nview, circle=False, device=config.device)

predicted_sinogram = []
label_sinogram = []
img_cache = None

img = all_img.to(config.device)

pc_radon = controllable_generation_TV.get_pc_radon_ADMM_TV_vol(sde,
                                                                predictor, corrector,
                                                                inverse_scaler,
                                                                snr=snr,
                                                                n_steps=n_steps,
                                                                probability_flow=probability_flow,
                                                                continuous=config.training.continuous,
                                                                denoise=True,
                                                                radon=radon,
                                                                save_progress=True,
                                                                save_root=save_root,
                                                                final_consistency=True,
                                                                img_shape=img.shape,
                                                                lamb_1=lamb,
                                                                rho=rho)
# Sparse by masking
sinogram = radon.A(img)

# A_dagger
bp = radon.AT(sinogram)

# Recon Image
x = pc_radon(score_model, scaler(img), measurement=sinogram)
# NOTE(review): `img_cahce` looks like a typo for `img_cache`; the value is not used afterwards.
img_cahce = x[-1].unsqueeze(0)

count = 0
for i, recon_img in enumerate(x):
    plt.imsave(save_root / 'BP' / f'{count}.png', clear(bp[i]), cmap='gray')
    plt.imsave(save_root / 'label' / f'{count}.png', clear(img[i]), cmap='gray')
    plt.imsave(save_root / 'recon' / f'{count}.png', clear(recon_img), cmap='gray')
    count += 1

# Recon and Save Sinogram
label_sinogram.append(radon.A_all(img))
predicted_sinogram.append(radon.A_all(x))

original_sinogram = torch.cat(label_sinogram, dim=0).detach().cpu().numpy()
recon_sinogram = torch.cat(predicted_sinogram, dim=0).detach().cpu().numpy()

np.save(str(save_root / 'sinogram' / f'original_{count}.npy'), original_sinogram)
np.save(str(save_root / 'sinogram' / f'recon_{count}.npy'), recon_sinogram)

================================================
FILE: inverse_problem_solver_BRATS_MRI_3d_total.py
================================================
# Fourier CS-MRI reconstruction of a BRATS volume with diffusion PC sampling + ADMM-TV
# along the slice direction (controllable_generation_TV.get_pc_radon_ADMM_TV_mri).
from pathlib import Path
from models import utils as mutils
import sampling
from sde_lib import VESDE
from sampling import (ReverseDiffusionPredictor,
                      LangevinCorrector,
                      LangevinCorrectorCS)
from models import ncsnpp
from itertools import islice
from losses import get_optimizer
import datasets
import time
import controllable_generation_TV
from utils import restore_checkpoint, fft2, ifft2, show_samples_gray, get_mask, clear
import torch
import torch.nn as nn
import numpy as np
from models.ema import ExponentialMovingAverage
from scipy.io import savemat, loadmat
from tqdm import tqdm
import matplotlib.pyplot as plt
import importlib

###############################################
# Configurations
###############################################
problem = 'Fourier_CS_3d_admm_tv'
config_name = 'fastmri_knee_320_ncsnpp_continuous'
sde = 'VESDE'  # rebound to a VESDE instance below
num_scales = 2000
ckpt_num = 95
N = num_scales

root = './data/MRI/BRATS'
vol = 'Brats18_CBICA_AAM_1'

if sde.lower() == 'vesde':
    # from configs.ve import fastmri_knee_320_ncsnpp_continuous as configs
    configs = importlib.import_module(f"configs.ve.{config_name}")
    # NOTE(review): ckpt_filename is only assigned for these two config names; any
    # other config_name leaves it undefined when the checkpoint is restored.
    if config_name == 'fastmri_knee_320_ncsnpp_continuous':
        ckpt_filename = f"./exp/ve/{config_name}/checkpoint_{ckpt_num}.pth"
    elif config_name == 'ffhq_256_ncsnpp_continuous':
        ckpt_filename = f"exp/ve/{config_name}/checkpoint_48.pth"
    config = configs.get_config()
    config.model.num_scales = num_scales
    sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
    sde.N = N
    sampling_eps = 1e-5

img_size = 240
batch_size = 1
config.training.batch_size = batch_size
predictor = ReverseDiffusionPredictor
corrector = LangevinCorrector
probability_flow = False
snr = 0.16
n_steps = 1

# parameters for Fourier CS recon
mask_type = 'uniform1d'
use_measurement_noise = False acc_factor = 2.0 center_fraction = 0.15 # ADMM TV parameters lamb_list = [0.005] rho_list = [0.01] random_seed = 0 sigmas = mutils.get_sigmas(config) scaler = datasets.get_data_scaler(config) inverse_scaler = datasets.get_data_inverse_scaler(config) score_model = mutils.create_model(config) optimizer = get_optimizer(config, score_model.parameters()) ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) state = dict(step=0, optimizer=optimizer, model=score_model, ema=ema) state = restore_checkpoint(ckpt_filename, state, config.device, skip_sigma=True) ema.copy_to(score_model.parameters()) fname_list = sorted(list((Path(root) / vol).glob('*.npy'))) all_img = [] for fname in tqdm(fname_list): img = np.load(fname) img = torch.from_numpy(img) h, w = img.shape img = img.view(1, 1, h, w) all_img.append(img) all_img = torch.cat(all_img, dim=0) # normalize the volume to be in proper range vmax = all_img.max() all_img /= (vmax + 1e-5) img = all_img.to(config.device) b = img.shape[0] for lamb in lamb_list: for rho in rho_list: print(f'lambda: {lamb}') print(f'rho: {rho}') # Specify save directory for saving generated samples save_root = Path(f'./results/{config_name}/{problem}/{mask_type}/acc{acc_factor}/lamb{lamb}/rho{rho}/{vol}') save_root.mkdir(parents=True, exist_ok=True) irl_types = ['input', 'recon', 'label'] for t in irl_types: save_root_f = save_root / t save_root_f.mkdir(parents=True, exist_ok=True) ############################################### # Inference ############################################### # forward model kspace = fft2(img) # generate mask mask = get_mask(torch.zeros(1, 1, h, w), img_size, batch_size, type=mask_type, acc_factor=acc_factor, center_fraction=center_fraction) mask = mask.to(img.device) mask = mask.repeat(b, 1, 1, 1) pc_fouriercs = controllable_generation_TV.get_pc_radon_ADMM_TV_mri(sde, predictor, corrector, inverse_scaler, mask=mask, lamb_1=lamb, rho=rho, img_shape=img.shape, 
snr=snr, n_steps=n_steps, probability_flow=probability_flow, continuous=config.training.continuous) # undersampling under_kspace = kspace * mask under_img = torch.real(ifft2(under_kspace)) count = 0 for i, recon_img in enumerate(under_img): plt.imsave(save_root / 'input' / f'{count}.png', clear(under_img[i]), cmap='gray') plt.imsave(save_root / 'label' / f'{count}.png', clear(img[i]), cmap='gray') count += 1 x = pc_fouriercs(score_model, scaler(under_img), measurement=under_kspace) count = 0 for i, recon_img in enumerate(x): plt.imsave(save_root / 'input' / f'{count}.png', clear(under_img[i]), cmap='gray') plt.imsave(save_root / 'label' / f'{count}.png', clear(img[i]), cmap='gray') plt.imsave(save_root / 'recon' / f'{count}.png', clear(recon_img), cmap='gray') np.save(str(save_root / 'input' / f'{count}.npy'), clear(under_img[i], normalize=False)) np.save(str(save_root / 'recon' / f'{count}.npy'), clear(x[i], normalize=False)) np.save(str(save_root / 'label' / f'{count}.npy'), clear(img[i], normalize=False)) count += 1 ================================================ FILE: likelihood.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# pylint: skip-file # pytype: skip-file """Various sampling methods.""" import torch import numpy as np from scipy import integrate from models import utils as mutils def get_div_fn(fn): """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator.""" def div_fn(x, t, eps): with torch.enable_grad(): x.requires_grad_(True) fn_eps = torch.sum(fn(x, t) * eps) grad_fn_eps = torch.autograd.grad(fn_eps, x)[0] x.requires_grad_(False) return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape)))) return div_fn def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher', rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5): """Create a function to compute the unbiased log-likelihood estimate of a given data point. Args: sde: A `sde_lib.SDE` object that represents the forward SDE. inverse_scaler: The inverse data normalizer. hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator. rtol: A `float` number. The relative tolerance level of the black-box ODE solver. atol: A `float` number. The absolute tolerance level of the black-box ODE solver. method: A `str`. The algorithm for the black-box ODE solver. See documentation for `scipy.integrate.solve_ivp`. eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability. Returns: A function that a batch of data points and returns the log-likelihoods in bits/dim, the latent code, and the number of function evaluations cost by computation. 
""" def drift_fn(model, x, t): """The drift function of the reverse-time SDE.""" score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True) # Probability flow ODE is a special case of Reverse SDE rsde = sde.reverse(score_fn, probability_flow=True) return rsde.sde(x, t)[0] def div_fn(model, x, t, noise): return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise) def likelihood_fn(model, data): """Compute an unbiased estimate to the log-likelihood in bits/dim. Args: model: A score model. data: A PyTorch tensor. Returns: bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim. z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the probability flow ODE. nfe: An integer. The number of function evaluations used for running the black-box ODE solver. """ with torch.no_grad(): shape = data.shape if hutchinson_type == 'Gaussian': epsilon = torch.randn_like(data) elif hutchinson_type == 'Rademacher': epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1. 
else: raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.") def ode_func(t, x): sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32) vec_t = torch.ones(sample.shape[0], device=sample.device) * t drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t)) logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon)) return np.concatenate([drift, logp_grad], axis=0) init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0) solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method) nfe = solution.nfev zp = solution.y[:, -1] z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32) delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32) prior_logp = sde.prior_logp(z) bpd = -(prior_logp + delta_logp) / np.log(2) N = np.prod(shape[1:]) bpd = bpd / N # A hack to convert log-likelihoods to bits/dim offset = 7. - inverse_scaler(-1.) bpd = bpd + offset return bpd, z, nfe return likelihood_fn ================================================ FILE: losses.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """All functions related to loss computation and optimization. 
""" import torch import torch.optim as optim import numpy as np from models import utils as mutils from sde_lib import VESDE, VPSDE from utils import fft2, ifft2, get_mask import numpy as np def get_optimizer(config, params): """Returns a flax optimizer object based on `config`.""" if config.optim.optimizer == 'Adam': optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps, weight_decay=config.optim.weight_decay) else: raise NotImplementedError( f'Optimizer {config.optim.optimizer} not supported yet!') return optimizer def optimization_manager(config): """Returns an optimize_fn based on `config`.""" def optimize_fn(optimizer, params, step, lr=config.optim.lr, warmup=config.optim.warmup, grad_clip=config.optim.grad_clip): """Optimizes with warmup and gradient clipping (disabled if negative).""" if warmup > 0: for g in optimizer.param_groups: g['lr'] = lr * np.minimum(step / warmup, 1.0) if grad_clip >= 0: torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip) optimizer.step() return optimize_fn def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5): """Create a loss function for training with arbirary SDEs. Args: sde: An `sde_lib.SDE` object that represents the forward SDE. train: `True` for training loss and `False` for evaluation loss. reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions. continuous: `True` indicates that the model is defined to take continuous time steps. Otherwise it requires ad-hoc interpolation to take continuous time steps. likelihood_weighting: If `True`, weight the mixture of score matching losses according to https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended in our paper. eps: A `float` number. The smallest time step to sample from. Returns: A loss function. 
""" reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs) def loss_fn(model, batch): """Compute the loss function. Args: model: A score model. batch: A mini-batch of training data. Returns: loss: A scalar that represents the average loss value across the mini-batch. """ score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous) t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps z = torch.randn_like(batch) mean, std = sde.marginal_prob(batch, t) perturbed_data = mean + std[:, None, None, None] * z score = score_fn(perturbed_data, t) if not likelihood_weighting: losses = torch.square(score * std[:, None, None, None] + z) losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) else: g2 = sde.sde(torch.zeros_like(batch), t)[1] ** 2 losses = torch.square(score + z / std[:, None, None, None]) losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * g2 loss = torch.mean(losses) return loss return loss_fn def get_smld_loss_fn(vesde, train, reduce_mean=False): """Legacy code to reproduce previous results on SMLD(NCSN). Not recommended for new work.""" assert isinstance(vesde, VESDE), "SMLD training only works for VESDEs." 
# Previous SMLD models assume descending sigmas smld_sigma_array = torch.flip(vesde.discrete_sigmas, dims=(0,)) reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs) def loss_fn(model, batch): model_fn = mutils.get_model_fn(model, train=train) labels = torch.randint(0, vesde.N, (batch.shape[0],), device=batch.device) sigmas = smld_sigma_array.to(batch.device)[labels] noise = torch.randn_like(batch) * sigmas[:, None, None, None] perturbed_data = noise + batch score = model_fn(perturbed_data, labels) target = -noise / (sigmas ** 2)[:, None, None, None] losses = torch.square(score - target) losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * sigmas ** 2 loss = torch.mean(losses) return loss return loss_fn def get_ddpm_loss_fn(vpsde, train, reduce_mean=True): """Legacy code to reproduce previous results on DDPM. Not recommended for new work.""" assert isinstance(vpsde, VPSDE), "DDPM training only works for VPSDEs." reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs) def loss_fn(model, batch): model_fn = mutils.get_model_fn(model, train=train) labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device) sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device) sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device) noise = torch.randn_like(batch) perturbed_data = sqrt_alphas_cumprod[labels, None, None, None] * batch + \ sqrt_1m_alphas_cumprod[labels, None, None, None] * noise score = model_fn(perturbed_data, labels) losses = torch.square(score - noise) losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) loss = torch.mean(losses) return loss return loss_fn def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False): """Create a one-step training/evaluation function. Args: sde: An `sde_lib.SDE` object that represents the forward SDE. 
optimize_fn: An optimization function. reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions. continuous: `True` indicates that the model is defined to take continuous time steps. likelihood_weighting: If `True`, weight the mixture of score matching losses according to https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper. Returns: A one-step function for training or evaluation. """ if continuous: loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean, continuous=True, likelihood_weighting=likelihood_weighting) else: assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training." if isinstance(sde, VESDE): loss_fn = get_smld_loss_fn(sde, train, reduce_mean=reduce_mean) elif isinstance(sde, VPSDE): loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean) else: raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.") def step_fn(state, batch): """Running one step of training or evaluation. This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together for faster execution. Args: state: A dictionary of training information, containing the score model, optimizer, EMA status, and number of optimization steps. batch: A mini-batch of training/evaluation data. Returns: loss: The average loss value of this state. 
""" model = state['model'] if train: optimizer = state['optimizer'] optimizer.zero_grad() loss = loss_fn(model, batch) loss.backward() optimize_fn(optimizer, model.parameters(), step=state['step']) state['step'] += 1 state['ema'].update(model.parameters()) else: with torch.no_grad(): ema = state['ema'] ema.store(model.parameters()) ema.copy_to(model.parameters()) loss = loss_fn(model, batch) ema.restore(model.parameters()) return loss return step_fn def get_step_fn_regression(train, config, mask=None, loss_fn=None, optimize_fn=None): def step_fn(state, batch): model = state['model'] if train: optimizer = state['optimizer'] optimizer.zero_grad() # fft kspace = fft2(batch) # sample mask acc_factor = np.random.choice(config.training.acc_factor) mask = get_mask(batch, config.data.image_size, config.training.batch_size, type=config.training.mask_type, acc_factor=acc_factor, fix=True) # undersampling under_kspace = kspace * mask under_img = torch.abs(ifft2(under_kspace)) est_img = model(under_img) loss = loss_fn(est_img, batch) loss.backward() optimize_fn(optimizer, model.parameters(), step=state['step']) state['step'] += 1 state['ema'].update(model.parameters()) return loss else: with torch.no_grad(): ema = state['ema'] ema.store(model.parameters()) ema.copy_to(model.parameters()) # fft kspace = fft2(batch) # sample mask mask = get_mask(batch, config.data.image_size, config.traiing.batch_size, type=config.training.mask_type, acc_factor=config.training.acc_factor) # undersampling under_kspace = kspace * mask under_img = torch.real(ifft2(under_kspace)) est_img = model(under_img) ema.restore(model.parameters()) return est_img return step_fn ================================================ FILE: main.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training and evaluation"""

import os
from pathlib import Path

from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import logging
import run_lib

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)
flags.DEFINE_string("workdir", None, "Work directory.")
flags.DEFINE_enum("mode", None, ["train", "train_regression", "eval"],
                  "Running mode: train, train_regression, or eval")
flags.DEFINE_string("eval_folder", "eval",
                    "The folder name for storing evaluation results")
flags.mark_flags_as_required(["workdir", "config", "mode"])


def main(argv):
    """Dispatch to the training or evaluation pipeline selected by `--mode`.

    Args:
      argv: Positional command-line arguments forwarded by `absl.app.run` (unused).

    Raises:
      NotImplementedError: if `--mode=train_regression` is requested but
        `run_lib` defines no `train_regression` entry point.
      ValueError: if `--mode` is not recognized.
    """
    print(FLAGS.config)
    if FLAGS.mode == "train" or FLAGS.mode == "train_regression":
        # Create the working directory
        Path(FLAGS.workdir).mkdir(parents=True, exist_ok=True)
        # Mirror log records to <workdir>/stdout.txt in addition to the console.
        # The file is closed and the handler detached once training returns.
        with open(os.path.join(FLAGS.workdir, 'stdout.txt'), 'w') as gfile_stream:
            handler = logging.StreamHandler(gfile_stream)
            formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
            handler.setFormatter(formatter)
            logger = logging.getLogger()
            logger.addHandler(handler)
            logger.setLevel('INFO')
            try:
                # Run the training pipeline
                if FLAGS.mode == "train":
                    run_lib.train(FLAGS.config, FLAGS.workdir)
                else:
                    # Previously this mode was accepted but silently did nothing.
                    train_regression = getattr(run_lib, "train_regression", None)
                    if train_regression is None:
                        raise NotImplementedError(
                            "Mode train_regression requires run_lib.train_regression, which is not defined.")
                    train_regression(FLAGS.config, FLAGS.workdir)
            finally:
                logger.removeHandler(handler)
    elif FLAGS.mode == "eval":
        # Run the evaluation pipeline
        run_lib.evaluate(FLAGS.config, FLAGS.workdir, FLAGS.eval_folder)
    else:
        raise ValueError(f"Mode {FLAGS.mode} not recognized.")


if __name__ == "__main__":
    app.run(main)
================================================ FILE: models/__init__.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: models/ddpm.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pylint: skip-file """DDPM model. This code is the pytorch equivalent of: https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py """ import torch import torch.nn as nn import functools from . 
import utils, layers, normalization RefineBlock = layers.RefineBlock ResidualBlock = layers.ResidualBlock ResnetBlockDDPM = layers.ResnetBlockDDPM Upsample = layers.Upsample Downsample = layers.Downsample conv3x3 = layers.ddpm_conv3x3 get_act = layers.get_act get_normalization = normalization.get_normalization default_initializer = layers.default_init @utils.register_model(name='ddpm') class DDPM(nn.Module): def __init__(self, config): super().__init__() self.act = act = get_act(config) self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config))) self.nf = nf = config.model.nf ch_mult = config.model.ch_mult self.num_res_blocks = num_res_blocks = config.model.num_res_blocks self.attn_resolutions = attn_resolutions = config.model.attn_resolutions dropout = config.model.dropout resamp_with_conv = config.model.resamp_with_conv self.num_resolutions = num_resolutions = len(ch_mult) self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)] AttnBlock = functools.partial(layers.AttnBlock) self.conditional = conditional = config.model.conditional ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout) if conditional: # Condition on noise levels. 
modules = [nn.Linear(nf, nf * 4)] modules[0].weight.data = default_initializer()(modules[0].weight.data.shape) nn.init.zeros_(modules[0].bias) modules.append(nn.Linear(nf * 4, nf * 4)) modules[1].weight.data = default_initializer()(modules[1].weight.data.shape) nn.init.zeros_(modules[1].bias) self.centered = config.data.centered channels = config.data.num_channels # Downsampling block modules.append(conv3x3(channels, nf)) hs_c = [nf] in_ch = nf for i_level in range(num_resolutions): # Residual blocks for this resolution for i_block in range(num_res_blocks): out_ch = nf * ch_mult[i_level] modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) in_ch = out_ch if all_resolutions[i_level] in attn_resolutions: modules.append(AttnBlock(channels=in_ch)) hs_c.append(in_ch) if i_level != num_resolutions - 1: modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv)) hs_c.append(in_ch) in_ch = hs_c[-1] modules.append(ResnetBlock(in_ch=in_ch)) modules.append(AttnBlock(channels=in_ch)) modules.append(ResnetBlock(in_ch=in_ch)) # Upsampling block for i_level in reversed(range(num_resolutions)): for i_block in range(num_res_blocks + 1): out_ch = nf * ch_mult[i_level] modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) in_ch = out_ch if all_resolutions[i_level] in attn_resolutions: modules.append(AttnBlock(channels=in_ch)) if i_level != 0: modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv)) assert not hs_c modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6)) modules.append(conv3x3(in_ch, channels, init_scale=0.)) self.all_modules = nn.ModuleList(modules) self.scale_by_sigma = config.model.scale_by_sigma def forward(self, x, labels): modules = self.all_modules m_idx = 0 if self.conditional: # timestep/scale embedding timesteps = labels temb = layers.get_timestep_embedding(timesteps, self.nf) temb = modules[m_idx](temb) m_idx += 1 temb = modules[m_idx](self.act(temb)) m_idx += 1 else: temb = None if self.centered: # 
Input is in [-1, 1] h = x else: # Input is in [0, 1] h = 2 * x - 1. # Downsampling block hs = [modules[m_idx](h)] m_idx += 1 for i_level in range(self.num_resolutions): # Residual blocks for this resolution for i_block in range(self.num_res_blocks): h = modules[m_idx](hs[-1], temb) m_idx += 1 if h.shape[-1] in self.attn_resolutions: h = modules[m_idx](h) m_idx += 1 hs.append(h) if i_level != self.num_resolutions - 1: hs.append(modules[m_idx](hs[-1])) m_idx += 1 h = hs[-1] h = modules[m_idx](h, temb) m_idx += 1 h = modules[m_idx](h) m_idx += 1 h = modules[m_idx](h, temb) m_idx += 1 # Upsampling block for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) m_idx += 1 if h.shape[-1] in self.attn_resolutions: h = modules[m_idx](h) m_idx += 1 if i_level != 0: h = modules[m_idx](h) m_idx += 1 assert not hs h = self.act(modules[m_idx](h)) m_idx += 1 h = modules[m_idx](h) m_idx += 1 assert m_idx == len(modules) if self.scale_by_sigma: # Divide the output by sigmas. Useful for training with the NCSN loss. # The DDPM loss scales the network output by sigma in the loss function, # so no need of doing it here. used_sigmas = self.sigmas[labels, None, None, None] h = h / used_sigmas return h ================================================ FILE: models/ema.py ================================================ # Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py from __future__ import division from __future__ import unicode_literals import torch # Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py class ExponentialMovingAverage: """ Maintains (exponential) moving average of a set of parameters. """ def __init__(self, parameters, decay, use_num_updates=True): """ Args: parameters: Iterable of `torch.nn.Parameter`; usually the result of `model.parameters()`. 
decay: The exponential decay. use_num_updates: Whether to use number of updates when computing averages. """ if decay < 0.0 or decay > 1.0: raise ValueError('Decay must be between 0 and 1') self.decay = decay self.num_updates = 0 if use_num_updates else None self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] self.collected_params = [] def update(self, parameters): """ Update currently maintained parameters. Call this every time the parameters are updated, such as the result of the `optimizer.step()` call. Args: parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to initialize this object. """ decay = self.decay if self.num_updates is not None: self.num_updates += 1 decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) one_minus_decay = 1.0 - decay with torch.no_grad(): parameters = [p for p in parameters if p.requires_grad] for s_param, param in zip(self.shadow_params, parameters): s_param.sub_(one_minus_decay * (s_param - param)) def copy_to(self, parameters): """ Copy current parameters into given collection of parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving averages. """ parameters = [p for p in parameters if p.requires_grad] for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: param.data.copy_(s_param.data) def store(self, parameters): """ Save the current parameters for restoring later. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. """ self.collected_params = [param.clone() for param in parameters] def restore(self, parameters): """ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without affecting the original optimization process. Store the parameters before the `copy_to` method. After validation (or model saving), use this to restore the former parameters. 
Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. """ for c_param, param in zip(self.collected_params, parameters): param.data.copy_(c_param.data) def state_dict(self): return dict(decay=self.decay, num_updates=self.num_updates, shadow_params=self.shadow_params) def load_state_dict(self, state_dict): self.decay = state_dict['decay'] self.num_updates = state_dict['num_updates'] self.shadow_params = state_dict['shadow_params'] ================================================ FILE: models/layers.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pylint: skip-file """Common layers for defining score networks. 
""" import math import string from functools import partial import torch.nn as nn import torch import torch.nn.functional as F import numpy as np from .normalization import ConditionalInstanceNorm2dPlus class SiLU(nn.Module): def forward(self, x): return x * torch.sigmoid(x) def get_act(config): """Get activation functions from the config file.""" if config.model.nonlinearity.lower() == 'elu': return nn.ELU() elif config.model.nonlinearity.lower() == 'relu': return nn.ReLU() elif config.model.nonlinearity.lower() == 'lrelu': return nn.LeakyReLU(negative_slope=0.2) elif config.model.nonlinearity.lower() == 'swish': return nn.SiLU() else: raise NotImplementedError('activation function does not exist!') def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0): """1x1 convolution. Same as NCSNv1/v2.""" conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation, padding=padding) init_scale = 1e-10 if init_scale == 0 else init_scale conv.weight.data *= init_scale conv.bias.data *= init_scale return conv def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device='cpu'): """Ported from JAX. 
""" def _compute_fans(shape, in_axis=1, out_axis=0): receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] fan_in = shape[in_axis] * receptive_field_size fan_out = shape[out_axis] * receptive_field_size return fan_in, fan_out def init(shape, dtype=dtype, device=device): fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) if mode == "fan_in": denominator = fan_in elif mode == "fan_out": denominator = fan_out elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2 else: raise ValueError( "invalid mode for variance scaling initializer: {}".format(mode)) variance = scale / denominator if distribution == "normal": return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) elif distribution == "uniform": return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance) else: raise ValueError("invalid distribution for variance scaling initializer") return init def default_init(scale=1.): """The same initialization used in DDPM.""" scale = 1e-10 if scale == 0 else scale return variance_scaling(scale, 'fan_avg', 'uniform') class Dense(nn.Module): """Linear layer with `default_init`.""" def __init__(self): super().__init__() def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0): """1x1 convolution with DDPM initialization.""" conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias) conv.weight.data = default_init(init_scale)(conv.weight.data.shape) nn.init.zeros_(conv.bias) return conv def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1): """3x3 convolution with PyTorch initialization. 
Same as NCSNv1/NCSNv2.""" init_scale = 1e-10 if init_scale == 0 else init_scale conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias, dilation=dilation, padding=padding, kernel_size=3) conv.weight.data *= init_scale conv.bias.data *= init_scale return conv def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1): """3x3 convolution with DDPM initialization.""" conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias) conv.weight.data = default_init(init_scale)(conv.weight.data.shape) nn.init.zeros_(conv.bias) return conv ########################################################################### # Functions below are ported over from the NCSNv1/NCSNv2 codebase: # https://github.com/ermongroup/ncsn # https://github.com/ermongroup/ncsnv2 ########################################################################### class CRPBlock(nn.Module): def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True): super().__init__() self.convs = nn.ModuleList() for i in range(n_stages): self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False)) self.n_stages = n_stages if maxpool: self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2) else: self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) self.act = act def forward(self, x): x = self.act(x) path = x for i in range(self.n_stages): path = self.pool(path) path = self.convs[i](path) x = path + x return x class CondCRPBlock(nn.Module): def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()): super().__init__() self.convs = nn.ModuleList() self.norms = nn.ModuleList() self.normalizer = normalizer for i in range(n_stages): self.norms.append(normalizer(features, num_classes, bias=True)) self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False)) self.n_stages = n_stages self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2) self.act = act def 
forward(self, x, y): x = self.act(x) path = x for i in range(self.n_stages): path = self.norms[i](path, y) path = self.pool(path) path = self.convs[i](path) x = path + x return x class RCUBlock(nn.Module): def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()): super().__init__() for i in range(n_blocks): for j in range(n_stages): setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False)) self.stride = 1 self.n_blocks = n_blocks self.n_stages = n_stages self.act = act def forward(self, x): for i in range(self.n_blocks): residual = x for j in range(self.n_stages): x = self.act(x) x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x) x += residual return x class CondRCUBlock(nn.Module): def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()): super().__init__() for i in range(n_blocks): for j in range(n_stages): setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True)) setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False)) self.stride = 1 self.n_blocks = n_blocks self.n_stages = n_stages self.act = act self.normalizer = normalizer def forward(self, x, y): for i in range(self.n_blocks): residual = x for j in range(self.n_stages): x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y) x = self.act(x) x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x) x += residual return x class MSFBlock(nn.Module): def __init__(self, in_planes, features): super().__init__() assert isinstance(in_planes, list) or isinstance(in_planes, tuple) self.convs = nn.ModuleList() self.features = features for i in range(len(in_planes)): self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True)) def forward(self, xs, shape): sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) for i in range(len(self.convs)): h = self.convs[i](xs[i]) h = F.interpolate(h, size=shape, 
mode='bilinear', align_corners=True) sums += h return sums class CondMSFBlock(nn.Module): def __init__(self, in_planes, features, num_classes, normalizer): super().__init__() assert isinstance(in_planes, list) or isinstance(in_planes, tuple) self.convs = nn.ModuleList() self.norms = nn.ModuleList() self.features = features self.normalizer = normalizer for i in range(len(in_planes)): self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True)) self.norms.append(normalizer(in_planes[i], num_classes, bias=True)) def forward(self, xs, y, shape): sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device) for i in range(len(self.convs)): h = self.norms[i](xs[i], y) h = self.convs[i](h) h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True) sums += h return sums class RefineBlock(nn.Module): def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True): super().__init__() assert isinstance(in_planes, tuple) or isinstance(in_planes, list) self.n_blocks = n_blocks = len(in_planes) self.adapt_convs = nn.ModuleList() for i in range(n_blocks): self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act)) self.output_convs = RCUBlock(features, 3 if end else 1, 2, act) if not start: self.msf = MSFBlock(in_planes, features) self.crp = CRPBlock(features, 2, act, maxpool=maxpool) def forward(self, xs, output_shape): assert isinstance(xs, tuple) or isinstance(xs, list) hs = [] for i in range(len(xs)): h = self.adapt_convs[i](xs[i]) hs.append(h) if self.n_blocks > 1: h = self.msf(hs, output_shape) else: h = hs[0] h = self.crp(h) h = self.output_convs(h) return h class CondRefineBlock(nn.Module): def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False): super().__init__() assert isinstance(in_planes, tuple) or isinstance(in_planes, list) self.n_blocks = n_blocks = len(in_planes) self.adapt_convs = nn.ModuleList() for i in range(n_blocks): 
self.adapt_convs.append( CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act) ) self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act) if not start: self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer) self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act) def forward(self, xs, y, output_shape): assert isinstance(xs, tuple) or isinstance(xs, list) hs = [] for i in range(len(xs)): h = self.adapt_convs[i](xs[i], y) hs.append(h) if self.n_blocks > 1: h = self.msf(hs, y, output_shape) else: h = hs[0] h = self.crp(h, y) h = self.output_convs(h, y) return h class ConvMeanPool(nn.Module): def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False): super().__init__() if not adjust_padding: conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) self.conv = conv else: conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) self.conv = nn.Sequential( nn.ZeroPad2d((1, 0, 1, 0)), conv ) def forward(self, inputs): output = self.conv(inputs) output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4. return output class MeanPoolConv(nn.Module): def __init__(self, input_dim, output_dim, kernel_size=3, biases=True): super().__init__() self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) def forward(self, inputs): output = inputs output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2], output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4. 
return self.conv(output) class UpsampleConv(nn.Module): def __init__(self, input_dim, output_dim, kernel_size=3, biases=True): super().__init__() self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases) self.pixelshuffle = nn.PixelShuffle(upscale_factor=2) def forward(self, inputs): output = inputs output = torch.cat([output, output, output, output], dim=1) output = self.pixelshuffle(output) return self.conv(output) class ConditionalResidualBlock(nn.Module): def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(), normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None): super().__init__() self.non_linearity = act self.input_dim = input_dim self.output_dim = output_dim self.resample = resample self.normalization = normalization if resample == 'down': if dilation > 1: self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation) self.normalize2 = normalization(input_dim, num_classes) self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) else: self.conv1 = ncsn_conv3x3(input_dim, input_dim) self.normalize2 = normalization(input_dim, num_classes) self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding) conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding) elif resample is None: if dilation > 1: conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) self.normalize2 = normalization(output_dim, num_classes) self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation) else: conv_shortcut = nn.Conv2d self.conv1 = ncsn_conv3x3(input_dim, output_dim) self.normalize2 = normalization(output_dim, num_classes) self.conv2 = ncsn_conv3x3(output_dim, output_dim) else: raise Exception('invalid resample value') if output_dim != input_dim or resample is not None: self.shortcut = 
conv_shortcut(input_dim, output_dim) self.normalize1 = normalization(input_dim, num_classes) def forward(self, x, y): output = self.normalize1(x, y) output = self.non_linearity(output) output = self.conv1(output) output = self.normalize2(output, y) output = self.non_linearity(output) output = self.conv2(output) if self.output_dim == self.input_dim and self.resample is None: shortcut = x else: shortcut = self.shortcut(x) return shortcut + output class ResidualBlock(nn.Module): def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(), normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1): super().__init__() self.non_linearity = act self.input_dim = input_dim self.output_dim = output_dim self.resample = resample self.normalization = normalization if resample == 'down': if dilation > 1: self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation) self.normalize2 = normalization(input_dim) self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) else: self.conv1 = ncsn_conv3x3(input_dim, input_dim) self.normalize2 = normalization(input_dim) self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding) conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding) elif resample is None: if dilation > 1: conv_shortcut = partial(ncsn_conv3x3, dilation=dilation) self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation) self.normalize2 = normalization(output_dim) self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation) else: # conv_shortcut = nn.Conv2d ### Something wierd here. 
conv_shortcut = partial(ncsn_conv1x1) self.conv1 = ncsn_conv3x3(input_dim, output_dim) self.normalize2 = normalization(output_dim) self.conv2 = ncsn_conv3x3(output_dim, output_dim) else: raise Exception('invalid resample value') if output_dim != input_dim or resample is not None: self.shortcut = conv_shortcut(input_dim, output_dim) self.normalize1 = normalization(input_dim) def forward(self, x): output = self.normalize1(x) output = self.non_linearity(output) output = self.conv1(output) output = self.normalize2(output) output = self.non_linearity(output) output = self.conv2(output) if self.output_dim == self.input_dim and self.resample is None: shortcut = x else: shortcut = self.shortcut(x) return shortcut + output ########################################################################### # Functions below are ported over from the DDPM codebase: # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py ########################################################################### def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 half_dim = embedding_dim // 2 # magic number 10000 is from transformers emb = math.log(max_positions) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = F.pad(emb, (0, 1), mode='constant') assert emb.shape == (timesteps.shape[0], embedding_dim) return emb def _einsum(a, b, c, x, y): einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) return torch.einsum(einsum_str, x, y) def contract_inner(x, y): """tensordot(x, y, 1).""" x_chars = list(string.ascii_lowercase[:len(x.shape)]) y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)]) y_chars[0] = x_chars[-1] # first axis of y and last of x get 
summed out_chars = x_chars[:-1] + y_chars[1:] return _einsum(x_chars, y_chars, out_chars, x, y) class NIN(nn.Module): def __init__(self, in_dim, num_units, init_scale=0.1): super().__init__() self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True) self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) def forward(self, x): x = x.permute(0, 2, 3, 1) y = contract_inner(x, self.W) + self.b return y.permute(0, 3, 1, 2) class AttnBlock(nn.Module): """Channel-wise self-attention block.""" def __init__(self, channels): super().__init__() self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6) self.NIN_0 = NIN(channels, channels) self.NIN_1 = NIN(channels, channels) self.NIN_2 = NIN(channels, channels) self.NIN_3 = NIN(channels, channels, init_scale=0.) def forward(self, x): B, C, H, W = x.shape h = self.GroupNorm_0(x) q = self.NIN_0(h) k = self.NIN_1(h) v = self.NIN_2(h) w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) w = torch.reshape(w, (B, H, W, H * W)) w = F.softmax(w, dim=-1) w = torch.reshape(w, (B, H, W, H, W)) h = torch.einsum('bhwij,bcij->bchw', w, v) h = self.NIN_3(h) return x + h class Upsample(nn.Module): def __init__(self, channels, with_conv=False): super().__init__() if with_conv: self.Conv_0 = ddpm_conv3x3(channels, channels) self.with_conv = with_conv def forward(self, x): B, C, H, W = x.shape h = F.interpolate(x, (H * 2, W * 2), mode='nearest') if self.with_conv: h = self.Conv_0(h) return h class Downsample(nn.Module): def __init__(self, channels, with_conv=False): super().__init__() if with_conv: self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0) self.with_conv = with_conv def forward(self, x): B, C, H, W = x.shape # Emulate 'SAME' padding if self.with_conv: x = F.pad(x, (0, 1, 0, 1)) x = self.Conv_0(x) else: x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0) assert x.shape == (B, C, H // 2, W // 2) return x class ResnetBlockDDPM(nn.Module): 
"""The ResNet Blocks used in DDPM.""" def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1): super().__init__() if out_ch is None: out_ch = in_ch self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6) self.act = act self.Conv_0 = ddpm_conv3x3(in_ch, out_ch) if temb_dim is not None: self.Dense_0 = nn.Linear(temb_dim, out_ch) self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) nn.init.zeros_(self.Dense_0.bias) self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6) self.Dropout_0 = nn.Dropout(dropout) self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.) if in_ch != out_ch: if conv_shortcut: self.Conv_2 = ddpm_conv3x3(in_ch, out_ch) else: self.NIN_0 = NIN(in_ch, out_ch) self.out_ch = out_ch self.in_ch = in_ch self.conv_shortcut = conv_shortcut def forward(self, x, temb=None): B, C, H, W = x.shape assert C == self.in_ch out_ch = self.out_ch if self.out_ch else self.in_ch h = self.act(self.GroupNorm_0(x)) h = self.Conv_0(h) # Add bias to each feature map conditioned on the time embedding if temb is not None: h += self.Dense_0(self.act(temb))[:, :, None, None] h = self.act(self.GroupNorm_1(h)) h = self.Dropout_0(h) h = self.Conv_1(h) if C != out_ch: if self.conv_shortcut: x = self.Conv_2(x) else: x = self.NIN_0(x) return x + h ================================================ FILE: models/layerspp.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Layers for defining NCSN++.
"""
from . import layers
from . import up_or_down_sampling
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np

conv1x1 = layers.ddpm_conv1x1
conv3x3 = layers.ddpm_conv3x3
NIN = layers.NIN
default_init = layers.default_init


class GaussianFourierProjection(nn.Module):
  """Gaussian Fourier embeddings for noise levels.

  `W` is a fixed (requires_grad=False) random vector drawn as
  randn(embedding_size) * scale. For a 1-D input `x` of shape (B,), the output
  is [sin(2*pi*x*W), cos(2*pi*x*W)] with shape (B, 2 * embedding_size).
  """

  def __init__(self, embedding_size=256, scale=1.0):
    super().__init__()
    self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

  def forward(self, x):
    x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
    return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


class Combine(nn.Module):
  """Combine information from skip connections.

  `x` is projected from `dim1` to `dim2` channels by a 1x1 conv. It is then
  either concatenated with `y` along dim 1 ('cat') or added to `y` ('sum').

  Raises:
    ValueError: if `method` is not 'cat' or 'sum'.
  """

  def __init__(self, dim1, dim2, method='cat'):
    super().__init__()
    # Reject an unknown method at construction time. Otherwise a misconfigured
    # model builds successfully and only fails on its first forward pass.
    if method not in ('cat', 'sum'):
      raise ValueError(f'Method {method} not recognized.')
    self.Conv_0 = conv1x1(dim1, dim2)
    self.method = method

  def forward(self, x, y):
    h = self.Conv_0(x)
    if self.method == 'cat':
      return torch.cat([h, y], dim=1)
    elif self.method == 'sum':
      return h + y
    else:
      raise ValueError(f'Method {self.method} not recognized.')


class AttnBlockpp(nn.Module):
  """Self-attention over spatial positions. Modified from DDPM.

  Every (h, w) location attends to all H*W locations. The attention logits are
  the channel dot products of per-pixel `NIN` projections, scaled by C**-0.5.
  The attended values are projected by `NIN_3` (init scale `init_scale`) and
  added to `x`; with `skip_rescale` the sum is divided by sqrt(2).
  GroupNorm uses min(channels // 4, 32) groups.
  """

  def __init__(self, channels, skip_rescale=False, init_scale=0.):
    super().__init__()
    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
                                    eps=1e-6)
    self.NIN_0 = NIN(channels, channels)
    self.NIN_1 = NIN(channels, channels)
    self.NIN_2 = NIN(channels, channels)
    self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
    self.skip_rescale = skip_rescale

  def forward(self, x):
    B, C, H, W = x.shape
    h = self.GroupNorm_0(x)
    q = self.NIN_0(h)
    k = self.NIN_1(h)
    v = self.NIN_2(h)

    w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
    w = torch.reshape(w, (B, H, W, H * W))
    w = F.softmax(w, dim=-1)
    w = torch.reshape(w, (B, H, W, H, W))
    h = torch.einsum('bhwij,bcij->bchw', w, v)
    h = self.NIN_3(h)
    if not self.skip_rescale:
      return x + h
    else:
      return (x + h) / np.sqrt(2.)


class Upsample(nn.Module):
  """2x spatial upsampling.

  Without `fir`: nearest-neighbour interpolation, optionally followed by a 3x3
  conv (`with_conv`). With `fir`: `up_or_down_sampling.upsample_2d` with
  `fir_kernel`, or a fused upsampling conv `up_or_down_sampling.Conv2d(up=True)`
  when `with_conv`. `out_ch` (default `in_ch`) only matters when a conv is built.
  """

  def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
               fir_kernel=(1, 3, 3, 1)):
    super().__init__()
    out_ch = out_ch if out_ch else in_ch
    if not fir:
      if with_conv:
        self.Conv_0 = conv3x3(in_ch, out_ch)
    else:
      if with_conv:
        self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
                                                 kernel=3, up=True,
                                                 resample_kernel=fir_kernel,
                                                 use_bias=True,
                                                 kernel_init=default_init())
    self.fir = fir
    self.with_conv = with_conv
    self.fir_kernel = fir_kernel
    self.out_ch = out_ch

  def forward(self, x):
    B, C, H, W = x.shape
    if not self.fir:
      h = F.interpolate(x, (H * 2, W * 2), 'nearest')
      if self.with_conv:
        h = self.Conv_0(h)
    else:
      if not self.with_conv:
        h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
      else:
        h = self.Conv2d_0(x)

    return h


class Downsample(nn.Module):
  """2x spatial downsampling.

  Without `fir`: a stride-2 3x3 conv after padding right/bottom by one pixel
  (`with_conv`), else 2x2 average pooling. With `fir`:
  `up_or_down_sampling.downsample_2d` with `fir_kernel`, or a fused
  downsampling conv `up_or_down_sampling.Conv2d(down=True)` when `with_conv`.
  """

  def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
               fir_kernel=(1, 3, 3, 1)):
    super().__init__()
    out_ch = out_ch if out_ch else in_ch
    if not fir:
      if with_conv:
        self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
    else:
      if with_conv:
        self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
                                                 kernel=3, down=True,
                                                 resample_kernel=fir_kernel,
                                                 use_bias=True,
                                                 kernel_init=default_init())
    self.fir = fir
    self.fir_kernel = fir_kernel
    self.with_conv = with_conv
    self.out_ch = out_ch

  def forward(self, x):
    B, C, H, W = x.shape
    if not self.fir:
      if self.with_conv:
        # Asymmetric padding (right/bottom only) before the stride-2, padding-0 conv.
        x = F.pad(x, (0, 1, 0, 1))
        x = self.Conv_0(x)
      else:
        x = F.avg_pool2d(x, 2, stride=2)
    else:
      if not self.with_conv:
        x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
      else:
        x = self.Conv2d_0(x)

    return x


class ResnetBlockDDPMpp(nn.Module):
  """ResBlock adapted from DDPM.

  GroupNorm -> act -> conv -> (+ Dense(act(temb)) per-channel bias) ->
  GroupNorm -> act -> dropout -> conv (init scale `init_scale`), plus a shortcut
  projected by a 3x3 conv (`conv_shortcut`) or NIN when the channel count
  changes. With `skip_rescale`, the residual sum is divided by sqrt(2).
  GroupNorms use min(channels // 4, 32) groups.
  """

  def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
               dropout=0.1, skip_rescale=False, init_scale=0.):
    super().__init__()
    out_ch = out_ch if out_ch else in_ch
    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
    self.Conv_0 = conv3x3(in_ch, out_ch)
    if temb_dim is not None:
      self.Dense_0 = nn.Linear(temb_dim, out_ch)
      self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
      nn.init.zeros_(self.Dense_0.bias)
    self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
    self.Dropout_0 = nn.Dropout(dropout)
    self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
    if in_ch != out_ch:
      if conv_shortcut:
        self.Conv_2 = conv3x3(in_ch, out_ch)
      else:
        self.NIN_0 = NIN(in_ch, out_ch)

    self.skip_rescale = skip_rescale
    self.act = act
    self.out_ch = out_ch
    self.conv_shortcut = conv_shortcut

  def forward(self, x, temb=None):
    h = self.act(self.GroupNorm_0(x))
    h = self.Conv_0(h)
    if temb is not None:
      # Per-channel bias from the time embedding, broadcast over H and W.
      h += self.Dense_0(self.act(temb))[:, :, None, None]
    h = self.act(self.GroupNorm_1(h))
    h = self.Dropout_0(h)
    h = self.Conv_1(h)
    if x.shape[1] != self.out_ch:
      if self.conv_shortcut:
        x = self.Conv_2(x)
      else:
        x = self.NIN_0(x)
    if not self.skip_rescale:
      return x + h
    else:
      return (x + h) / np.sqrt(2.)
class ResnetBlockBigGANpp(nn.Module):
  """BigGAN-style residual block used by NCSN++, with optional 2x resampling.

  Args:
    act: activation module, applied after each GroupNorm and to `temb`.
    in_ch: number of input channels.
    out_ch: number of output channels (falls back to `in_ch` when falsy).
    temb_dim: width of the time embedding. When given, a Dense layer turns
      `act(temb)` into a per-channel bias.
    up / down: resample both branches 2x up or down. `up` takes precedence
      when both are set.
    dropout: dropout probability before the second conv.
    fir / fir_kernel: use the FIR resamplers from `up_or_down_sampling` with
      this kernel instead of the naive ones.
    skip_rescale: divide the residual sum by sqrt(2).
    init_scale: init scale of the second conv.
  """

  def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
               dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
               skip_rescale=True, init_scale=0.):
    super().__init__()
    out_ch = out_ch or in_ch

    # Submodules are created in a fixed order: it determines state_dict keys and
    # the sequence of random initializations.
    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
    self.Conv_0 = conv3x3(in_ch, out_ch)
    if temb_dim is not None:
      dense = nn.Linear(temb_dim, out_ch)
      dense.weight.data = default_init()(dense.weight.shape)
      nn.init.zeros_(dense.bias)
      self.Dense_0 = dense
    self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
    self.Dropout_0 = nn.Dropout(dropout)
    self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)

    needs_projection = in_ch != out_ch or up or down
    if needs_projection:
      self.Conv_2 = conv1x1(in_ch, out_ch)

    self.up, self.down = up, down
    self.fir, self.fir_kernel = fir, fir_kernel
    self.skip_rescale = skip_rescale
    self.act = act
    self.in_ch, self.out_ch = in_ch, out_ch

  def _resample(self, t):
    """Return `t` resampled 2x up/down as configured, or unchanged."""
    if self.up:
      if self.fir:
        return up_or_down_sampling.upsample_2d(t, self.fir_kernel, factor=2)
      return up_or_down_sampling.naive_upsample_2d(t, factor=2)
    if self.down:
      if self.fir:
        return up_or_down_sampling.downsample_2d(t, self.fir_kernel, factor=2)
      return up_or_down_sampling.naive_downsample_2d(t, factor=2)
    return t

  def forward(self, x, temb=None):
    # Both the main branch (after norm + act) and the skip branch are resampled.
    hidden = self._resample(self.act(self.GroupNorm_0(x)))
    skip = self._resample(x)

    hidden = self.Conv_0(hidden)
    if temb is not None:
      # Add bias to each feature map conditioned on the time embedding.
      hidden = hidden + self.Dense_0(self.act(temb))[:, :, None, None]
    hidden = self.Conv_1(self.Dropout_0(self.act(self.GroupNorm_1(hidden))))

    if self.in_ch != self.out_ch or self.up or self.down:
      skip = self.Conv_2(skip)

    total = skip + hidden
    return total / np.sqrt(2.) if self.skip_rescale else total
================================================
FILE: models/ncsnpp.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file

from . import utils, layers, layerspp, normalization
import torch.nn as nn
import functools
import torch
import numpy as np

ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
Combine = layerspp.Combine
conv3x3 = layerspp.conv3x3
conv1x1 = layerspp.conv1x1
get_act = layers.get_act
get_normalization = normalization.get_normalization
default_initializer = layers.default_init


@utils.register_model(name='ncsnpp')
class NCSNpp(nn.Module):
  """NCSN++ model (registered with `utils.register_model` as 'ncsnpp').

  A U-Net whose layout is read from `config.model` / `config.data`.
  IMPORTANT: every trainable submodule is appended, in construction order, to a
  single flat `nn.ModuleList` (`self.all_modules`). `forward` walks that list
  with a running index `m_idx` and checks `m_idx == len(modules)` at the end.
  Any change to the order or conditions of appends in `__init__` must therefore
  be mirrored exactly in `forward` (and it changes checkpoint keys).
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    self.act = act = get_act(config)
    # Noise levels from utils.get_sigmas(config); indexed by integer time in the
    # 'positional' embedding branch of forward.
    self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))

    self.nf = nf = config.model.nf
    ch_mult = config.model.ch_mult
    self.num_res_blocks = num_res_blocks = config.model.num_res_blocks
    self.attn_resolutions = attn_resolutions = config.model.attn_resolutions
    dropout = config.model.dropout
    resamp_with_conv = config.model.resamp_with_conv
    self.num_resolutions = num_resolutions = len(ch_mult)
    # Spatial size at each level: image_size halved once per level.
    self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)]

    self.conditional = conditional = config.model.conditional  # noise-conditional
    fir = config.model.fir
    fir_kernel = config.model.fir_kernel
    self.skip_rescale = skip_rescale = config.model.skip_rescale
    self.resblock_type = resblock_type = config.model.resblock_type.lower()
    self.progressive = progressive = config.model.progressive.lower()
    self.progressive_input = progressive_input = config.model.progressive_input.lower()
    self.embedding_type = embedding_type = config.model.embedding_type.lower()
    init_scale = config.model.init_scale
    assert progressive in ['none', 'output_skip', 'residual']
    assert progressive_input in ['none', 'input_skip', 'residual']
    assert embedding_type in ['fourier', 'positional']
    combine_method = config.model.progressive_combine.lower()
    combiner = functools.partial(Combine, method=combine_method)

    modules = []
    # timestep/noise_level embedding; only for continuous training
    if embedding_type == 'fourier':
      # Gaussian Fourier features embeddings.
      assert config.training.continuous, "Fourier features are only used for continuous training."

      modules.append(layerspp.GaussianFourierProjection(
        embedding_size=nf, scale=config.model.fourier_scale
      ))
      embed_dim = 2 * nf

    elif embedding_type == 'positional':
      embed_dim = nf

    else:
      raise ValueError(f'embedding type {embedding_type} unknown.')

    # Two-layer MLP on the embedding (embed_dim -> 4nf -> 4nf); its output is
    # the `temb` passed to every ResnetBlock (temb_dim=nf * 4 below).
    if conditional:
      modules.append(nn.Linear(embed_dim, nf * 4))
      modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
      nn.init.zeros_(modules[-1].bias)
      modules.append(nn.Linear(nf * 4, nf * 4))
      modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
      nn.init.zeros_(modules[-1].bias)

    AttnBlock = functools.partial(layerspp.AttnBlockpp,
                                  init_scale=init_scale,
                                  skip_rescale=skip_rescale)

    Upsample = functools.partial(layerspp.Upsample,
                                 with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

    # The pyramid resamplers without parameters are stored as attributes
    # (not in `modules`), so they do not consume an m_idx slot.
    if progressive == 'output_skip':
      self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
    elif progressive == 'residual':
      pyramid_upsample = functools.partial(layerspp.Upsample,
                                           fir=fir, fir_kernel=fir_kernel, with_conv=True)

    Downsample = functools.partial(layerspp.Downsample,
                                   with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

    if progressive_input == 'input_skip':
      self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
    elif progressive_input == 'residual':
      pyramid_downsample = functools.partial(layerspp.Downsample,
                                             fir=fir, fir_kernel=fir_kernel, with_conv=True)

    if resblock_type == 'ddpm':
      ResnetBlock = functools.partial(ResnetBlockDDPM,
                                      act=act,
                                      dropout=dropout,
                                      init_scale=init_scale,
                                      skip_rescale=skip_rescale,
                                      temb_dim=nf * 4)

    elif resblock_type == 'biggan':
      ResnetBlock = functools.partial(ResnetBlockBigGAN,
                                      act=act,
                                      dropout=dropout,
                                      fir=fir,
                                      fir_kernel=fir_kernel,
                                      init_scale=init_scale,
                                      skip_rescale=skip_rescale,
                                      temb_dim=nf * 4)

    else:
      raise ValueError(f'resblock type {resblock_type} unrecognized.')

    # Downsampling block

    channels = config.data.num_channels
    if progressive_input != 'none':
      input_pyramid_ch = channels

    modules.append(conv3x3(channels, nf))
    # hs_c tracks the channel count of every skip activation that forward pushes
    # onto `hs`; the upsampling path pops one per ResnetBlock.
    hs_c = [nf]

    in_ch = nf
    for i_level in range(num_resolutions):
      # Residual blocks for this resolution
      for i_block in range(num_res_blocks):
        out_ch = nf * ch_mult[i_level]
        modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
        in_ch = out_ch

        if all_resolutions[i_level] in attn_resolutions:
          modules.append(AttnBlock(channels=in_ch))
        hs_c.append(in_ch)

      if i_level != num_resolutions - 1:
        if resblock_type == 'ddpm':
          modules.append(Downsample(in_ch=in_ch))
        else:
          modules.append(ResnetBlock(down=True, in_ch=in_ch))

        if progressive_input == 'input_skip':
          modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
          # Concatenation doubles the channels of the combined activation.
          if combine_method == 'cat':
            in_ch *= 2

        elif progressive_input == 'residual':
          modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
          input_pyramid_ch = in_ch

        hs_c.append(in_ch)

    # Bottleneck: ResnetBlock -> AttnBlock -> ResnetBlock.
    in_ch = hs_c[-1]
    modules.append(ResnetBlock(in_ch=in_ch))
    modules.append(AttnBlock(channels=in_ch))
    modules.append(ResnetBlock(in_ch=in_ch))

    pyramid_ch = 0
    # Upsampling block
    for i_level in reversed(range(num_resolutions)):
      # One more block than on the way down: each consumes one skip activation.
      for i_block in range(num_res_blocks + 1):
        out_ch = nf * ch_mult[i_level]
        modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(),
                                   out_ch=out_ch))
        in_ch = out_ch

      # Attention once per level on the way up (not once per block).
      if all_resolutions[i_level] in attn_resolutions:
        modules.append(AttnBlock(channels=in_ch))

      if progressive != 'none':
        if i_level == num_resolutions - 1:
          if progressive == 'output_skip':
            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                        num_channels=in_ch, eps=1e-6))
            modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
            pyramid_ch = channels
          elif progressive == 'residual':
            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                        num_channels=in_ch, eps=1e-6))
            modules.append(conv3x3(in_ch, in_ch, bias=True))
            pyramid_ch = in_ch
          else:
            raise ValueError(f'{progressive} is not a valid name.')
        else:
          if progressive == 'output_skip':
            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                        num_channels=in_ch, eps=1e-6))
            modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
            pyramid_ch = channels
          elif progressive == 'residual':
            modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
            pyramid_ch = in_ch
          else:
            raise ValueError(f'{progressive} is not a valid name')

      if i_level != 0:
        if resblock_type == 'ddpm':
          modules.append(Upsample(in_ch=in_ch))
        else:
          modules.append(ResnetBlock(in_ch=in_ch, up=True))

    # Every recorded skip channel count must have been consumed.
    assert not hs_c

    # With 'output_skip' the pyramid already ends in a conv to `channels`.
    if progressive != 'output_skip':
      modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                  num_channels=in_ch, eps=1e-6))
      modules.append(conv3x3(in_ch, channels, init_scale=init_scale))

    self.all_modules = nn.ModuleList(modules)

  def forward(self, x, time_cond):
    """Run the network on `x` conditioned on `time_cond`.

    For 'fourier' embeddings, `time_cond` is used directly as the noise level
    (the embedding is taken of its log). For 'positional' embeddings it is used
    both as the timestep to embed and, cast to long, as an index into
    `self.sigmas`. If `config.model.scale_by_sigma`, the output is divided by
    that noise level.
    """
    # timestep/noise_level embedding; only for continuous training
    modules = self.all_modules
    m_idx = 0
    if self.embedding_type == 'fourier':
      # Gaussian Fourier features embeddings.
      used_sigmas = time_cond
      temb = modules[m_idx](torch.log(used_sigmas))
      m_idx += 1

    elif self.embedding_type == 'positional':
      # Sinusoidal positional embeddings.
      timesteps = time_cond
      used_sigmas = self.sigmas[time_cond.long()]
      temb = layers.get_timestep_embedding(timesteps, self.nf)

    else:
      raise ValueError(f'embedding type {self.embedding_type} unknown.')

    if self.conditional:
      temb = modules[m_idx](temb)
      m_idx += 1
      temb = modules[m_idx](self.act(temb))
      m_idx += 1
    else:
      temb = None

    if not self.config.data.centered:
      # If input data is in [0, 1], map it to [-1, 1].
      x = 2 * x - 1.

    # Downsampling block
    input_pyramid = None
    if self.progressive_input != 'none':
      input_pyramid = x

    # `hs` holds every skip activation, consumed in reverse by the upsampling path.
    hs = [modules[m_idx](x)]
    m_idx += 1
    for i_level in range(self.num_resolutions):
      # Residual blocks for this resolution
      for i_block in range(self.num_res_blocks):
        h = modules[m_idx](hs[-1], temb)
        m_idx += 1
        # Mirrors the `all_resolutions[i_level] in attn_resolutions` test in
        # __init__, using the width of the current activation.
        if h.shape[-1] in self.attn_resolutions:
          h = modules[m_idx](h)
          m_idx += 1

        hs.append(h)

      if i_level != self.num_resolutions - 1:
        if self.resblock_type == 'ddpm':
          h = modules[m_idx](hs[-1])
          m_idx += 1
        else:
          h = modules[m_idx](hs[-1], temb)
          m_idx += 1

        if self.progressive_input == 'input_skip':
          input_pyramid = self.pyramid_downsample(input_pyramid)
          h = modules[m_idx](input_pyramid, h)
          m_idx += 1

        elif self.progressive_input == 'residual':
          input_pyramid = modules[m_idx](input_pyramid)
          m_idx += 1
          if self.skip_rescale:
            input_pyramid = (input_pyramid + h) / np.sqrt(2.)
          else:
            input_pyramid = input_pyramid + h
          h = input_pyramid

        hs.append(h)

    # Bottleneck: ResnetBlock -> AttnBlock -> ResnetBlock.
    h = hs[-1]
    h = modules[m_idx](h, temb)
    m_idx += 1
    h = modules[m_idx](h)
    m_idx += 1
    h = modules[m_idx](h, temb)
    m_idx += 1

    pyramid = None

    # Upsampling block
    for i_level in reversed(range(self.num_resolutions)):
      for i_block in range(self.num_res_blocks + 1):
        tmp = hs.pop()
        h = modules[m_idx](torch.cat([h, tmp], dim=1), temb)
        m_idx += 1

      if h.shape[-1] in self.attn_resolutions:
        h = modules[m_idx](h)
        m_idx += 1

      if self.progressive != 'none':
        if i_level == self.num_resolutions - 1:
          if self.progressive == 'output_skip':
            pyramid = self.act(modules[m_idx](h))
            m_idx += 1
            pyramid = modules[m_idx](pyramid)
            m_idx += 1
          elif self.progressive == 'residual':
            pyramid = self.act(modules[m_idx](h))
            m_idx += 1
            pyramid = modules[m_idx](pyramid)
            m_idx += 1
          else:
            raise ValueError(f'{self.progressive} is not a valid name.')
        else:
          if self.progressive == 'output_skip':
            pyramid = self.pyramid_upsample(pyramid)
            pyramid_h = self.act(modules[m_idx](h))
            m_idx += 1
            pyramid_h = modules[m_idx](pyramid_h)
            m_idx += 1
            pyramid = pyramid + pyramid_h
          elif self.progressive == 'residual':
            pyramid = modules[m_idx](pyramid)
            m_idx += 1
            if self.skip_rescale:
              pyramid = (pyramid + h) / np.sqrt(2.)
            else:
              pyramid = pyramid + h
            h = pyramid
          else:
            raise ValueError(f'{self.progressive} is not a valid name')

      if i_level != 0:
        if self.resblock_type == 'ddpm':
          h = modules[m_idx](h)
          m_idx += 1
        else:
          h = modules[m_idx](h, temb)
          m_idx += 1

    assert not hs

    if self.progressive == 'output_skip':
      h = pyramid
    else:
      h = self.act(modules[m_idx](h))
      m_idx += 1
      h = modules[m_idx](h)
      m_idx += 1

    # __init__ and forward must consume exactly the same modules.
    assert m_idx == len(modules)
    if self.config.model.scale_by_sigma:
      # Reshape to (B, 1, 1, ...) so the division broadcasts per sample.
      used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
      h = h / used_sigmas

    return h

================================================
FILE: models/ncsnv2.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
"""NCSNv2 score networks (and the NCSN v1 variant) for SMLD-style noise levels.

Each model takes a noisy batch ``x`` and integer noise-level labels ``y`` and
returns a tensor with ``config.data.channels`` channels. The NCSNv2 variants
divide the network output by ``sigmas[y]`` (the noise levels from
``get_sigmas(config)``); ``NCSN`` does not.
"""
import torch
import torch.nn as nn
import functools

from .utils import get_sigmas, register_model
from .layers import (CondRefineBlock, RefineBlock, ResidualBlock, ncsn_conv3x3,
                     ConditionalResidualBlock, get_act)
from .normalization import get_normalization

# Module-level aliases; not referenced elsewhere in this module (presumably kept
# for external importers).
CondResidualBlock = ConditionalResidualBlock
conv3x3 = ncsn_conv3x3


def get_network(config):
  """Pick the NCSNv2 variant that matches ``config.data.image_size``.

  Returns a ``functools.partial`` of the model class with ``config`` bound:
    * image_size < 96          -> NCSNv2
    * 96 <= image_size <= 128  -> NCSNv2_128
    * 128 < image_size <= 256  -> NCSNv2_256

  Raises:
    NotImplementedError: for image sizes above 256.
  """
  if config.data.image_size < 96:
    return functools.partial(NCSNv2, config=config)
  elif 96 <= config.data.image_size <= 128:
    return functools.partial(NCSNv2_128, config=config)
  elif 128 < config.data.image_size <= 256:
    return functools.partial(NCSNv2_256, config=config)
  else:
    raise NotImplementedError(
      f'No network suitable for {config.data.image_size}px implemented yet.')


@register_model(name='ncsnv2_64')
class NCSNv2(nn.Module):
  """NCSNv2 for small images (selected by ``get_network`` for image_size < 96).

  Four residual stages (res1-res4, where res2-res4 each open with a 'down'
  block) feed four RefineBlocks that combine each stage's features. The result
  is normalized, activated, projected back to ``config.data.channels`` and
  divided by ``sigmas[y]``.
  """

  def __init__(self, config):
    super().__init__()
    self.centered = config.data.centered
    self.norm = get_normalization(config)
    self.nf = nf = config.model.nf

    self.act = act = get_act(config)
    # Noise levels as a buffer so they move with the module's device.
    self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))
    self.config = config

    self.begin_conv = nn.Conv2d(config.data.channels, nf, 3, stride=1, padding=1)

    # NOTE(review): num_scales is passed as the 2nd positional argument. For
    # InstanceNorm2dPlus / VarianceNorm2d that parameter is `bias` (so it only
    # acts as a truthy flag); for nn.InstanceNorm2d it would be `eps` and for
    # nn.GroupNorm(num_groups, num_channels) the meanings differ entirely.
    # Confirm only 'InstanceNorm++'/'VarianceNorm' are used with this model.
    self.normalizer = self.norm(nf, config.model.num_scales)
    self.end_conv = nn.Conv2d(nf, config.data.channels, 3, stride=1, padding=1)

    self.res1 = nn.ModuleList([
      ResidualBlock(self.nf, self.nf, resample=None, act=act,
                    normalization=self.norm),
      ResidualBlock(self.nf, self.nf, resample=None, act=act,
                    normalization=self.norm)]
    )

    self.res2 = nn.ModuleList([
      ResidualBlock(self.nf, 2 * self.nf, resample='down', act=act,
                    normalization=self.norm),
      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,
                    normalization=self.norm)]
    )

    self.res3 = nn.ModuleList([
      ResidualBlock(2 * self.nf, 2 * self.nf, resample='down', act=act,
                    normalization=self.norm, dilation=2),
      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,
                    normalization=self.norm, dilation=2)]
    )

    # 28px inputs (e.g. MNIST-sized) need adjust_padding on the last 'down'
    # block; all other sizes use the default padding.
    if config.data.image_size == 28:
      self.res4 = nn.ModuleList([
        ResidualBlock(2 * self.nf, 2 * self.nf, resample='down', act=act,
                      normalization=self.norm, adjust_padding=True, dilation=4),
        ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,
                      normalization=self.norm, dilation=4)]
      )
    else:
      self.res4 = nn.ModuleList([
        ResidualBlock(2 * self.nf, 2 * self.nf, resample='down', act=act,
                      normalization=self.norm, adjust_padding=False, dilation=4),
        ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,
                      normalization=self.norm, dilation=4)]
      )

    self.refine1 = RefineBlock([2 * self.nf], 2 * self.nf, act=act, start=True)
    self.refine2 = RefineBlock([2 * self.nf, 2 * self.nf], 2 * self.nf, act=act)
    self.refine3 = RefineBlock([2 * self.nf, 2 * self.nf], self.nf, act=act)
    self.refine4 = RefineBlock([self.nf, self.nf], self.nf, act=act, end=True)

  def _compute_cond_module(self, module, x):
    """Apply each block of ``module`` in order (no conditioning input is used)."""
    for m in module:
      x = m(x)
    return x

  def forward(self, x, y):
    """Return the sigma-scaled network output for batch ``x`` at noise indices ``y``.

    ``y`` is used to index ``self.sigmas``; presumably a LongTensor of shape
    (batch,) — TODO confirm against callers.
    """
    # Uncentered data is mapped with 2x - 1 (presumably from [0, 1] to [-1, 1]).
    if not self.centered:
      h = 2 * x - 1.
    else:
      h = x

    output = self.begin_conv(h)

    layer1 = self._compute_cond_module(self.res1, output)
    layer2 = self._compute_cond_module(self.res2, layer1)
    layer3 = self._compute_cond_module(self.res3, layer2)
    layer4 = self._compute_cond_module(self.res4, layer3)

    # Each refine block receives the spatial size of the stage it is merging into.
    ref1 = self.refine1([layer4], layer4.shape[2:])
    ref2 = self.refine2([layer3, ref1], layer3.shape[2:])
    ref3 = self.refine3([layer2, ref2], layer2.shape[2:])
    output = self.refine4([layer1, ref3], layer1.shape[2:])

    output = self.normalizer(output)
    output = self.act(output)
    output = self.end_conv(output)

    # Reshape per-sample sigma to (batch, 1, 1, ...) so it broadcasts over x.
    used_sigmas = self.sigmas[y].view(x.shape[0], *([1] * len(x.shape[1:])))

    output = output / used_sigmas

    return output


@register_model(name='ncsn')
class NCSN(nn.Module):
  """NCSN (v1): same layout as NCSNv2 but with label-conditional blocks.

  Unlike the NCSNv2 classes, the output is NOT divided by sigma.
  """

  def __init__(self, config):
    super().__init__()
    self.centered = config.data.centered
    # NOTE(review): this uses the *unconditional* normalization factory, yet
    # forward() calls self.normalizer(output, y). The unconditional layers in
    # normalization.py (e.g. InstanceNorm2dPlus.forward(self, x)) take a single
    # input, so this looks like it would fail at runtime — confirm whether
    # get_normalization(config, conditional=True) was intended.
    self.norm = get_normalization(config)
    self.nf = nf = config.model.nf
    self.act = act = get_act(config)
    self.config = config

    self.begin_conv = nn.Conv2d(config.data.channels, nf, 3, stride=1, padding=1)

    self.normalizer = self.norm(nf, config.model.num_scales)
    self.end_conv = nn.Conv2d(nf, config.data.channels, 3, stride=1, padding=1)

    # Every block is conditioned on the noise-level label (num_scales classes).
    self.res1 = nn.ModuleList([
      ConditionalResidualBlock(self.nf, self.nf, config.model.num_scales, resample=None, act=act,
                               normalization=self.norm),
      ConditionalResidualBlock(self.nf, self.nf, config.model.num_scales, resample=None, act=act,
                               normalization=self.norm)]
    )

    self.res2 = nn.ModuleList([
      ConditionalResidualBlock(self.nf, 2 * self.nf, config.model.num_scales, resample='down', act=act,
                               normalization=self.norm),
      ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample=None, act=act,
                               normalization=self.norm)]
    )

    self.res3 = nn.ModuleList([
      ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample='down', act=act,
                               normalization=self.norm, dilation=2),
      ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample=None, act=act,
                               normalization=self.norm, dilation=2)]
    )

    if config.data.image_size == 28:
      self.res4 = nn.ModuleList([
        ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample='down', act=act,
                                 normalization=self.norm, adjust_padding=True, dilation=4),
        ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample=None, act=act,
                                 normalization=self.norm, dilation=4)]
      )
    else:
      self.res4 = nn.ModuleList([
        ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample='down', act=act,
                                 normalization=self.norm, adjust_padding=False, dilation=4),
        ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample=None, act=act,
                                 normalization=self.norm, dilation=4)]
      )

    self.refine1 = CondRefineBlock([2 * self.nf], 2 * self.nf, config.model.num_scales, self.norm, act=act, start=True)
    self.refine2 = CondRefineBlock([2 * self.nf, 2 * self.nf], 2 * self.nf, config.model.num_scales, self.norm, act=act)
    self.refine3 = CondRefineBlock([2 * self.nf, 2 * self.nf], self.nf, config.model.num_scales, self.norm, act=act)
    self.refine4 = CondRefineBlock([self.nf, self.nf], self.nf, config.model.num_scales, self.norm, act=act, end=True)

  def _compute_cond_module(self, module, x, y):
    """Apply each conditional block of ``module`` in order, passing labels ``y``."""
    for m in module:
      x = m(x, y)
    return x

  def forward(self, x, y):
    """Return the (unscaled) network output for batch ``x`` with labels ``y``."""
    if not self.centered:
      h = 2 * x - 1.
    else:
      h = x

    output = self.begin_conv(h)

    layer1 = self._compute_cond_module(self.res1, output, y)
    layer2 = self._compute_cond_module(self.res2, layer1, y)
    layer3 = self._compute_cond_module(self.res3, layer2, y)
    layer4 = self._compute_cond_module(self.res4, layer3, y)

    ref1 = self.refine1([layer4], y, layer4.shape[2:])
    ref2 = self.refine2([layer3, ref1], y, layer3.shape[2:])
    ref3 = self.refine3([layer2, ref2], y, layer2.shape[2:])
    output = self.refine4([layer1, ref3], y, layer1.shape[2:])

    output = self.normalizer(output, y)
    output = self.act(output)
    output = self.end_conv(output)

    return output


@register_model(name='ncsnv2_128')
class NCSNv2_128(nn.Module):
  """NCSNv2 model architecture for 128px images.

  Selected by ``get_network`` for 96 <= image_size <= 128. Five residual
  stages (channels nf, 2nf, 2nf, 4nf, 4nf) and five RefineBlocks; the output
  is divided by ``sigmas[y]``.
  """
  def __init__(self, config):
    super().__init__()
    self.centered = config.data.centered
    self.norm = get_normalization(config)
    self.nf = nf = config.model.nf
    self.act = act = get_act(config)
    self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))
    self.config = config

    self.begin_conv = nn.Conv2d(config.data.channels, nf, 3, stride=1, padding=1)
    # See the NOTE(review) in NCSNv2.__init__ about the 2nd positional argument.
    self.normalizer = self.norm(nf, config.model.num_scales)

    self.end_conv = nn.Conv2d(nf, config.data.channels, 3, stride=1, padding=1)

    self.res1 = nn.ModuleList([
      ResidualBlock(self.nf, self.nf, resample=None, act=act,
                    normalization=self.norm),
      ResidualBlock(self.nf, self.nf, resample=None, act=act,
                    normalization=self.norm)]
    )

    self.res2 = nn.ModuleList([
      ResidualBlock(self.nf, 2 * self.nf, resample='down', act=act,
                    normalization=self.norm),
      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,
                    normalization=self.norm)]
    )

    self.res3 = nn.ModuleList([
      ResidualBlock(2 * self.nf, 2 * self.nf, resample='down', act=act,
                    normalization=self.norm),
      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,
                    normalization=self.norm)]
    )

    self.res4 = nn.ModuleList([
      ResidualBlock(2 * self.nf, 4 * self.nf, resample='down', act=act,
                    normalization=self.norm, dilation=2),
      ResidualBlock(4 * self.nf, 4 * self.nf, resample=None, act=act,
                    normalization=self.norm, dilation=2)]
    )

    self.res5 = nn.ModuleList([
      ResidualBlock(4 * self.nf, 4 * self.nf, resample='down', act=act,
                    normalization=self.norm, dilation=4),
      ResidualBlock(4 * self.nf, 4 * self.nf, resample=None, act=act,
                    normalization=self.norm, dilation=4)]
    )

    self.refine1 = RefineBlock([4 * self.nf], 4 * self.nf, act=act, start=True)
    self.refine2 = RefineBlock([4 * self.nf, 4 * self.nf], 2 * self.nf, act=act)
    self.refine3 = RefineBlock([2 * self.nf, 2 * self.nf], 2 * self.nf, act=act)
    self.refine4 = RefineBlock([2 * self.nf, 2 * self.nf], self.nf, act=act)
    self.refine5 = RefineBlock([self.nf, self.nf], self.nf, act=act, end=True)

  def _compute_cond_module(self, module, x):
    """Apply each block of ``module`` in order (no conditioning input is used)."""
    for m in module:
      x = m(x)
    return x

  def forward(self, x, y):
    """Return the sigma-scaled network output for batch ``x`` at noise indices ``y``."""
    if not self.centered:
      h = 2 * x - 1.
    else:
      h = x

    output = self.begin_conv(h)

    layer1 = self._compute_cond_module(self.res1, output)
    layer2 = self._compute_cond_module(self.res2, layer1)
    layer3 = self._compute_cond_module(self.res3, layer2)
    layer4 = self._compute_cond_module(self.res4, layer3)
    layer5 = self._compute_cond_module(self.res5, layer4)

    ref1 = self.refine1([layer5], layer5.shape[2:])
    ref2 = self.refine2([layer4, ref1], layer4.shape[2:])
    ref3 = self.refine3([layer3, ref2], layer3.shape[2:])
    ref4 = self.refine4([layer2, ref3], layer2.shape[2:])
    output = self.refine5([layer1, ref4], layer1.shape[2:])

    output = self.normalizer(output)
    output = self.act(output)
    output = self.end_conv(output)

    # Per-sample sigma reshaped to (batch, 1, 1, ...) for broadcasting.
    used_sigmas = self.sigmas[y].view(x.shape[0], *([1] * len(x.shape[1:])))

    output = output / used_sigmas

    return output


@register_model(name='ncsnv2_256')
class NCSNv2_256(nn.Module):
  """NCSNv2 model architecture for 256px images.

  Selected by ``get_network`` for 128 < image_size <= 256. Same as
  NCSNv2_128 plus one extra downsampling stage (res31) and its matching
  RefineBlock (refine31); the output is divided by ``sigmas[y]``.
  """
  def __init__(self, config):
    super().__init__()
    self.centered = config.data.centered
    self.norm = get_normalization(config)
    self.nf = nf = config.model.nf
    self.act = act = get_act(config)
    self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))
    self.config = config

    self.begin_conv = nn.Conv2d(config.data.channels, nf, 3, stride=1, padding=1)
    # See the NOTE(review) in NCSNv2.__init__ about the 2nd positional argument.
    self.normalizer = self.norm(nf, config.model.num_scales)

    self.end_conv = nn.Conv2d(nf, config.data.channels, 3, stride=1, padding=1)

    self.res1 = nn.ModuleList([
      ResidualBlock(self.nf, self.nf, resample=None, act=act,
                    normalization=self.norm),
      ResidualBlock(self.nf, self.nf, resample=None, act=act,
                    normalization=self.norm)]
    )

    self.res2 = nn.ModuleList([
      ResidualBlock(self.nf, 2 * self.nf, resample='down', act=act,
                    normalization=self.norm),
      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,
                    normalization=self.norm)]
    )

    self.res3 = nn.ModuleList([
      ResidualBlock(2 * self.nf, 2 * self.nf, resample='down', act=act,
                    normalization=self.norm),
      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,
                    normalization=self.norm)]
    )

    # Extra downsampling stage relative to NCSNv2_128.
    self.res31 = nn.ModuleList([
      ResidualBlock(2 * self.nf, 2 * self.nf, resample='down', act=act,
                    normalization=self.norm),
      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,
                    normalization=self.norm)]
    )

    self.res4 = nn.ModuleList([
      ResidualBlock(2 * self.nf, 4 * self.nf, resample='down', act=act,
                    normalization=self.norm, dilation=2),
      ResidualBlock(4 * self.nf, 4 * self.nf, resample=None, act=act,
                    normalization=self.norm, dilation=2)]
    )

    self.res5 = nn.ModuleList([
      ResidualBlock(4 * self.nf, 4 * self.nf, resample='down', act=act,
                    normalization=self.norm, dilation=4),
      ResidualBlock(4 * self.nf, 4 * self.nf, resample=None, act=act,
                    normalization=self.norm, dilation=4)]
    )

    # Registration order (refine31 between refine3 and refine4) determines
    # parameter order; keep it stable for checkpoint/optimizer compatibility.
    self.refine1 = RefineBlock([4 * self.nf], 4 * self.nf, act=act, start=True)
    self.refine2 = RefineBlock([4 * self.nf, 4 * self.nf], 2 * self.nf, act=act)
    self.refine3 = RefineBlock([2 * self.nf, 2 * self.nf], 2 * self.nf, act=act)
    self.refine31 = RefineBlock([2 * self.nf, 2 * self.nf], 2 * self.nf, act=act)
    self.refine4 = RefineBlock([2 * self.nf, 2 * self.nf], self.nf, act=act)
    self.refine5 = RefineBlock([self.nf, self.nf], self.nf, act=act, end=True)

  def _compute_cond_module(self, module, x):
    """Apply each block of ``module`` in order (no conditioning input is used)."""
    for m in module:
      x = m(x)
    return x

  def forward(self, x, y):
    """Return the sigma-scaled network output for batch ``x`` at noise indices ``y``."""
    if not self.centered:
      h = 2 * x - 1.
    else:
      h = x

    output = self.begin_conv(h)

    layer1 = self._compute_cond_module(self.res1, output)
    layer2 = self._compute_cond_module(self.res2, layer1)
    layer3 = self._compute_cond_module(self.res3, layer2)
    layer31 = self._compute_cond_module(self.res31, layer3)
    layer4 = self._compute_cond_module(self.res4, layer31)
    layer5 = self._compute_cond_module(self.res5, layer4)

    ref1 = self.refine1([layer5], layer5.shape[2:])
    ref2 = self.refine2([layer4, ref1], layer4.shape[2:])
    ref31 = self.refine31([layer31, ref2], layer31.shape[2:])
    ref3 = self.refine3([layer3, ref31], layer3.shape[2:])
    ref4 = self.refine4([layer2, ref3], layer2.shape[2:])
    output = self.refine5([layer1, ref4], layer1.shape[2:])

    output = self.normalizer(output)
    output = self.act(output)
    output = self.end_conv(output)

    # Per-sample sigma reshaped to (batch, 1, 1, ...) for broadcasting.
    used_sigmas = self.sigmas[y].view(x.shape[0], *([1] * len(x.shape[1:])))

    output = output / used_sigmas

    return output

================================================
FILE: models/normalization.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers.""" import torch.nn as nn import torch import functools def get_normalization(config, conditional=False): """Obtain normalization modules from the config file.""" norm = config.model.normalization if conditional: if norm == 'InstanceNorm++': return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes) else: raise NotImplementedError(f'{norm} not implemented yet.') else: if norm == 'InstanceNorm': return nn.InstanceNorm2d elif norm == 'InstanceNorm++': return InstanceNorm2dPlus elif norm == 'VarianceNorm': return VarianceNorm2d elif norm == 'GroupNorm': return nn.GroupNorm else: raise ValueError('Unknown normalization: %s' % norm) class ConditionalBatchNorm2d(nn.Module): def __init__(self, num_features, num_classes, bias=True): super().__init__() self.num_features = num_features self.bias = bias self.bn = nn.BatchNorm2d(num_features, affine=False) if self.bias: self.embed = nn.Embedding(num_classes, num_features * 2) self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 else: self.embed = nn.Embedding(num_classes, num_features) self.embed.weight.data.uniform_() def forward(self, x, y): out = self.bn(x) if self.bias: gamma, beta = self.embed(y).chunk(2, dim=1) out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) else: gamma = self.embed(y) out = gamma.view(-1, self.num_features, 1, 1) * out return out class ConditionalInstanceNorm2d(nn.Module): def __init__(self, num_features, num_classes, bias=True): super().__init__() self.num_features = num_features self.bias = bias self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) if bias: self.embed = nn.Embedding(num_classes, num_features * 2) self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) self.embed.weight.data[:, num_features:].zero_() # Initialise 
bias at 0 else: self.embed = nn.Embedding(num_classes, num_features) self.embed.weight.data.uniform_() def forward(self, x, y): h = self.instance_norm(x) if self.bias: gamma, beta = self.embed(y).chunk(2, dim=-1) out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) else: gamma = self.embed(y) out = gamma.view(-1, self.num_features, 1, 1) * h return out class ConditionalVarianceNorm2d(nn.Module): def __init__(self, num_features, num_classes, bias=False): super().__init__() self.num_features = num_features self.bias = bias self.embed = nn.Embedding(num_classes, num_features) self.embed.weight.data.normal_(1, 0.02) def forward(self, x, y): vars = torch.var(x, dim=(2, 3), keepdim=True) h = x / torch.sqrt(vars + 1e-5) gamma = self.embed(y) out = gamma.view(-1, self.num_features, 1, 1) * h return out class VarianceNorm2d(nn.Module): def __init__(self, num_features, bias=False): super().__init__() self.num_features = num_features self.bias = bias self.alpha = nn.Parameter(torch.zeros(num_features)) self.alpha.data.normal_(1, 0.02) def forward(self, x): vars = torch.var(x, dim=(2, 3), keepdim=True) h = x / torch.sqrt(vars + 1e-5) out = self.alpha.view(-1, self.num_features, 1, 1) * h return out class ConditionalNoneNorm2d(nn.Module): def __init__(self, num_features, num_classes, bias=True): super().__init__() self.num_features = num_features self.bias = bias if bias: self.embed = nn.Embedding(num_classes, num_features * 2) self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 else: self.embed = nn.Embedding(num_classes, num_features) self.embed.weight.data.uniform_() def forward(self, x, y): if self.bias: gamma, beta = self.embed(y).chunk(2, dim=-1) out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1) else: gamma = self.embed(y) out = gamma.view(-1, self.num_features, 1, 1) * x return out class 
NoneNorm2d(nn.Module): def __init__(self, num_features, bias=True): super().__init__() def forward(self, x): return x class InstanceNorm2dPlus(nn.Module): def __init__(self, num_features, bias=True): super().__init__() self.num_features = num_features self.bias = bias self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) self.alpha = nn.Parameter(torch.zeros(num_features)) self.gamma = nn.Parameter(torch.zeros(num_features)) self.alpha.data.normal_(1, 0.02) self.gamma.data.normal_(1, 0.02) if bias: self.beta = nn.Parameter(torch.zeros(num_features)) def forward(self, x): means = torch.mean(x, dim=(2, 3)) m = torch.mean(means, dim=-1, keepdim=True) v = torch.var(means, dim=-1, keepdim=True) means = (means - m) / (torch.sqrt(v + 1e-5)) h = self.instance_norm(x) if self.bias: h = h + means[..., None, None] * self.alpha[..., None, None] out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1) else: h = h + means[..., None, None] * self.alpha[..., None, None] out = self.gamma.view(-1, self.num_features, 1, 1) * h return out class ConditionalInstanceNorm2dPlus(nn.Module): def __init__(self, num_features, num_classes, bias=True): super().__init__() self.num_features = num_features self.bias = bias self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) if bias: self.embed = nn.Embedding(num_classes, num_features * 3) self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02) self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0 else: self.embed = nn.Embedding(num_classes, 2 * num_features) self.embed.weight.data.normal_(1, 0.02) def forward(self, x, y): means = torch.mean(x, dim=(2, 3)) m = torch.mean(means, dim=-1, keepdim=True) v = torch.var(means, dim=-1, keepdim=True) means = (means - m) / (torch.sqrt(v + 1e-5)) h = self.instance_norm(x) if self.bias: gamma, alpha, beta = 
self.embed(y).chunk(3, dim=-1) h = h + means[..., None, None] * alpha[..., None, None] out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) else: gamma, alpha = self.embed(y).chunk(2, dim=-1) h = h + means[..., None, None] * alpha[..., None, None] out = gamma.view(-1, self.num_features, 1, 1) * h return out ================================================ FILE: models/unet.py ================================================ from . import utils import torch from torch import nn from torch.nn import functional as F class ConvBlock(nn.Module): """ A Convolutional Block that consists of two convolution layers each followed by instance normalization, relu activation and dropout. """ def __init__(self, in_chans, out_chans, stride=2): """ Args: in_chans (int): Number of channels in the input. out_chans (int): Number of channels in the output. drop_prob (float): Dropout probability. """ super().__init__() self.in_chans = in_chans self.out_chans = out_chans self.layers = nn.Sequential( nn.Conv2d(in_chans, out_chans, kernel_size=3, stride=stride, padding=1), nn.GroupNorm(num_groups=8, num_channels=out_chans), nn.LeakyReLU(), nn.Conv2d(out_chans, out_chans, kernel_size=3, stride=1, padding=1), nn.GroupNorm(num_groups=8, num_channels=out_chans), nn.LeakyReLU(), ) def forward(self, tensor): """ Args: tensor (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] Returns: (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] """ return self.layers(tensor) def __repr__(self): return f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans})' @utils.register_model(name='unet') class Unet(nn.Module): def __init__(self, in_chans=1, out_chans=1, chans=64, num_pool_layers=4, use_residual=True): super().__init__() # self.config = config # self.in_chans = config.model.in_chans # self.out_chans = config.model.out_chans # self.chans = config.model.chans # self.num_pool_layers = 
config.model.num_pool_layers # self.use_residual = config.model.use_residual self.in_chans = in_chans self.out_chans = out_chans self.chans = chans self.num_pool_layers = num_pool_layers self.use_residual = use_residual ch = self.chans self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, stride=1)]) for i in range(self.num_pool_layers - 1): self.down_sample_layers += [ConvBlock(ch, ch * 2, stride=2)] ch *= 2 # Size reduction happens at the beginning of a block, hence the need for stride here self.conv = ConvBlock(ch, ch, stride=2) self.up_sample_layers = nn.ModuleList() for i in range(self.num_pool_layers - 1): self.up_sample_layers += [ConvBlock(ch * 2, ch // 2, stride=1)] ch //= 2 self.up_sample_layers += [ConvBlock(ch * 2, ch, stride=1)] self.conv2 = nn.Conv2d(ch, self.out_chans, kernel_size=1) def forward(self, tensor): """ Args: tensor (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] Returns: (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] """ stack = list() output = tensor # Apply down-sampling layers for layer in self.down_sample_layers: output = layer(output) stack.append(output) # output = F.avg_pool2d(output, kernel_size=2) output = self.conv(output) # Apply up-sampling layers for layer in self.up_sample_layers: output = F.interpolate(output, scale_factor=2, mode='bilinear', align_corners=False) output = torch.cat((output, stack.pop()), dim=1) output = layer(output) output = self.conv2(output) if self.use_residual: output = output + tensor return output ================================================ FILE: models/up_or_down_sampling.py ================================================ """Layers used for up-sampling or down-sampling images. Many functions are ported from https://github.com/NVlabs/stylegan2. 
""" import torch.nn as nn import torch import torch.nn.functional as F import numpy as np from op import upfirdn2d # Function ported from StyleGAN2 def get_weight(module, shape, weight_var='weight', kernel_init=None): """Get/create weight tensor for a convolution or fully-connected layer.""" return module.param(weight_var, kernel_init, shape) class Conv2d(nn.Module): """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).""" def __init__(self, in_ch, out_ch, kernel, up=False, down=False, resample_kernel=(1, 3, 3, 1), use_bias=True, kernel_init=None): super().__init__() assert not (up and down) assert kernel >= 1 and kernel % 2 == 1 self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel)) if kernel_init is not None: self.weight.data = kernel_init(self.weight.data.shape) if use_bias: self.bias = nn.Parameter(torch.zeros(out_ch)) self.up = up self.down = down self.resample_kernel = resample_kernel self.kernel = kernel self.use_bias = use_bias def forward(self, x): if self.up: x = upsample_conv_2d(x, self.weight, k=self.resample_kernel) elif self.down: x = conv_downsample_2d(x, self.weight, k=self.resample_kernel) else: x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2) if self.use_bias: x = x + self.bias.reshape(1, -1, 1, 1) return x def naive_upsample_2d(x, factor=2): _N, C, H, W = x.shape x = torch.reshape(x, (-1, C, H, 1, W, 1)) x = x.repeat(1, 1, 1, factor, 1, factor) return torch.reshape(x, (-1, C, H * factor, W * factor)) def naive_downsample_2d(x, factor=2): _N, C, H, W = x.shape x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor)) return torch.mean(x, dim=(3, 5)) def upsample_conv_2d(x, w, k=None, factor=2, gain=1): """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. Padding is performed only once at the beginning, not between the operations. The fused op is considerably more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary order. 
Args: x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). Returns: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as `x`. """ assert isinstance(factor, int) and factor >= 1 # Check weight shape. assert len(w.shape) == 4 convH = w.shape[2] convW = w.shape[3] inC = w.shape[1] outC = w.shape[0] assert convW == convH # Setup filter kernel. if k is None: k = [1] * factor k = _setup_kernel(k) * (gain * (factor ** 2)) p = (k.shape[0] - factor) - (convW - 1) stride = (factor, factor) # Determine data dimensions. stride = [1, 1, factor, factor] output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW) assert output_padding[0] >= 0 and output_padding[1] >= 0 num_groups = _shape(x, 1) // inC # Transpose weights. w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) ## Original TF code. # x = tf.nn.conv2d_transpose( # x, # w, # output_shape=output_shape, # strides=stride, # padding='VALID', # data_format=data_format) ## JAX equivalent return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) def conv_downsample_2d(x, w, k=None, factor=2, gain=1): """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. 
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary order. Args: x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). Returns: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype as `x`. """ assert isinstance(factor, int) and factor >= 1 _outC, _inC, convH, convW = w.shape assert convW == convH if k is None: k = [1] * factor k = _setup_kernel(k) * gain p = (k.shape[0] - factor) + (convW - 1) s = [factor, factor] x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2)) return F.conv2d(x, w, stride=s, padding=0) def _setup_kernel(k): k = np.asarray(k, dtype=np.float32) if k.ndim == 1: k = np.outer(k, k) k /= np.sum(k) assert k.ndim == 2 assert k.shape[0] == k.shape[1] return k def _shape(x, dim): return x.shape[dim] def upsample_2d(x, k=None, factor=2, gain=1): r"""Upsample a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a multiple of the upsampling factor. Args: x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. 
k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). Returns: Tensor of the shape `[N, C, H * factor, W * factor]` """ assert isinstance(factor, int) and factor >= 1 if k is None: k = [1] * factor k = _setup_kernel(k) * (gain * (factor ** 2)) p = k.shape[0] - factor return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) def downsample_2d(x, k=None, factor=2, gain=1): r"""Downsample a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a multiple of the downsampling factor. Args: x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0). Returns: Tensor of the shape `[N, C, H // factor, W // factor]` """ assert isinstance(factor, int) and factor >= 1 if k is None: k = [1] * factor k = _setup_kernel(k) * gain p = k.shape[0] - factor return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) ================================================ FILE: models/utils.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All functions and modules related to model definition.
"""

import torch
import sde_lib
import numpy as np


# Registry of model classes keyed by name; populated by `register_model`.
_MODELS = {}


def register_model(cls=None, *, name=None):
  """A decorator for registering model classes.

  Usable bare (`@register_model`) or with a name (`@register_model(name=...)`).
  The class is stored under `name`, or under `cls.__name__` when no name is
  given, and is returned unchanged.

  Raises:
    ValueError: if a model is already registered under the same name.
  """

  def _register(cls):
    if name is None:
      local_name = cls.__name__
    else:
      local_name = name
    if local_name in _MODELS:
      raise ValueError(f'Already registered model with name: {local_name}')
    _MODELS[local_name] = cls
    return cls

  # Bare use passes the class directly; parameterized use passes only `name`
  # and expects a decorator back.
  if cls is None:
    return _register
  else:
    return _register(cls)


def get_model(name):
  """Return the model class registered under `name` (KeyError if absent)."""
  return _MODELS[name]


def get_sigmas(config):
  """Get sigmas --- the set of noise levels for SMLD from config files.

  The levels are geometrically spaced, ordered from `config.model.sigma_max`
  to `config.model.sigma_min`.

  Args:
    config: A ConfigDict object parsed from the config file

  Returns:
    sigmas: a numpy array of `config.model.num_scales` noise levels
  """
  sigmas = np.exp(
    np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales))

  return sigmas


def get_ddpm_params(config):
  """Get betas and alphas --- parameters used in the original DDPM paper.

  Returns:
    A dict with the per-step `betas` and `alphas`, the cumulative product
    `alphas_cumprod` with its square-root forms, the rescaled
    `beta_min`/`beta_max`, and `num_diffusion_timesteps`.
  """
  # NOTE(review): the step count is fixed at 1000 while the betas are divided
  # by `config.model.num_scales`; the two agree only when num_scales == 1000.
  num_diffusion_timesteps = 1000
  # parameters need to be adapted if number of time steps differs from 1000
  beta_start = config.model.beta_min / config.model.num_scales
  beta_end = config.model.beta_max / config.model.num_scales
  betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)

  alphas = 1. - betas
  alphas_cumprod = np.cumprod(alphas, axis=0)
  sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
  sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)

  return {
    'betas': betas,
    'alphas': alphas,
    'alphas_cumprod': alphas_cumprod,
    'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
    'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
    'beta_min': beta_start * (num_diffusion_timesteps - 1),
    'beta_max': beta_end * (num_diffusion_timesteps - 1),
    'num_diffusion_timesteps': num_diffusion_timesteps
  }


def create_model(config):
  """Create the score model.

  Instantiates the class registered as `config.model.name`, moves it to
  `config.device`, and wraps it in `torch.nn.DataParallel`.
  """
  model_name = config.model.name
  score_model = get_model(model_name)(config)
  score_model = score_model.to(config.device)
  score_model = torch.nn.DataParallel(score_model)
  return score_model


def get_model_fn(model, train=False):
  """Create a function to give the output of the score-based model.

  Args:
    model: The score model.
    train: `True` for training and `False` for evaluation.

  Returns:
    A model function.
  """

  def model_fn(x, labels):
    """Compute the output of the score-based model.

    The model is switched to train/eval mode on every call, according to the
    enclosing `train` flag.

    Args:
      x: A mini-batch of input data.
      labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
        for different models.

    Returns:
      The model output for `(x, labels)`.
    """
    if not train:
      model.eval()
      return model(x, labels)
    else:
      model.train()
      return model(x, labels)

  return model_fn


def get_score_fn(sde, model, train=False, continuous=False):
  """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    model: A score model.
    train: `True` for training and `False` for evaluation.
    continuous: If `True`, the score-based model is expected to directly take continuous time steps.

  Returns:
    A score function.

  Raises:
    NotImplementedError: if `sde` is not a VPSDE, subVPSDE or VESDE.
  """
  model_fn = get_model_fn(model, train=train)

  if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):
    def score_fn(x, t):
      # Scale neural network output by standard deviation and flip sign
      if continuous or isinstance(sde, sde_lib.subVPSDE):
        # For VP-trained models, t=0 corresponds to the lowest noise level
        # The maximum value of time embedding is assumed to 999 for
        # continuously-trained models.
        labels = t * 999
        score = model_fn(x, labels)
        std = sde.marginal_prob(torch.zeros_like(x), t)[1]
      else:
        # For VP-trained models, t=0 corresponds to the lowest noise level
        labels = t * (sde.N - 1)
        score = model_fn(x, labels)
        std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]

      # NOTE(review): the [:, None, None, None] broadcast assumes 4-D
      # (batch, C, H, W) inputs — confirm for other data layouts.
      score = -score / std[:, None, None, None]
      return score

  elif isinstance(sde, sde_lib.VESDE):
    def score_fn(x, t):
      if continuous:
        # Continuous VE models are conditioned on the marginal std at t.
        labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
      else:
        # For VE-trained models, t=0 corresponds to the highest noise level
        labels = sde.T - t
        labels *= sde.N - 1
        labels = torch.round(labels).long()

      score = model_fn(x, labels)
      return score

  else:
    raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

  return score_fn


def to_flattened_numpy(x):
  """Flatten a torch tensor `x` and convert it to numpy."""
  return x.detach().cpu().numpy().reshape((-1,))


def from_flattened_numpy(x, shape):
  """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
  return torch.from_numpy(x.reshape(shape))

================================================
FILE: op/__init__.py
================================================
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d

================================================
FILE: op/fused_act.py
================================================
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


module_path =
os.path.dirname(__file__) fused = load( "fused", sources=[ os.path.join(module_path, "fused_bias_act.cpp"), os.path.join(module_path, "fused_bias_act_kernel.cu"), ], ) class FusedLeakyReLUFunctionBackward(Function): @staticmethod def forward(ctx, grad_output, out, negative_slope, scale): ctx.save_for_backward(out) ctx.negative_slope = negative_slope ctx.scale = scale empty = grad_output.new_empty(0) grad_input = fused.fused_bias_act( grad_output, empty, out, 3, 1, negative_slope, scale ) dim = [0] if grad_input.ndim > 2: dim += list(range(2, grad_input.ndim)) grad_bias = grad_input.sum(dim).detach() return grad_input, grad_bias @staticmethod def backward(ctx, gradgrad_input, gradgrad_bias): out, = ctx.saved_tensors gradgrad_out = fused.fused_bias_act( gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale ) return gradgrad_out, None, None, None class FusedLeakyReLUFunction(Function): @staticmethod def forward(ctx, input, bias, negative_slope, scale): empty = input.new_empty(0) out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) ctx.save_for_backward(out) ctx.negative_slope = negative_slope ctx.scale = scale return out @staticmethod def backward(ctx, grad_output): out, = ctx.saved_tensors grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( grad_output, out, ctx.negative_slope, ctx.scale ) return grad_input, grad_bias, None, None class FusedLeakyReLU(nn.Module): def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): super().__init__() self.bias = nn.Parameter(torch.zeros(channel)) self.negative_slope = negative_slope self.scale = scale def forward(self, input): return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): if input.device.type == "cpu": rest_dim = [1] * (input.ndim - bias.ndim - 1) return ( F.leaky_relu( input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 ) * scale ) else: return 
FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) ================================================ FILE: op/fused_bias_act.cpp ================================================ #include torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, int act, int grad, float alpha, float scale); #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, int act, int grad, float alpha, float scale) { CHECK_CUDA(input); CHECK_CUDA(bias); return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); } ================================================ FILE: op/fused_bias_act_kernel.cu ================================================ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. // // This work is made available under the Nvidia Source Code License-NC. // To view a copy of this license, visit // https://nvlabs.github.io/stylegan2/license.html #include #include #include #include #include #include #include template static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; scalar_t zero = 0.0; for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { scalar_t x = p_x[xi]; if (use_bias) { x += p_b[(xi / step_b) % size_b]; } scalar_t ref = use_ref ? 
p_ref[xi] : zero;

        scalar_t y;

        // Dispatch on (act, grad). act 1 behaves as identity, act 3 is leaky
        // ReLU. grad 0 is the forward value, grad 1 the first derivative (taken
        // from the sign of `ref`), grad 2 the second derivative (zero for both).
        switch (act * 10 + grad) {
            default:
            case 10: y = x; break;
            case 11: y = x; break;
            case 12: y = 0.0; break;

            case 30: y = (x > 0.0) ? x : x * alpha; break;
            case 31: y = (ref > 0.0) ? x : x * alpha; break;
            case 32: y = 0.0; break;
        }

        out[xi] = y * scale;
    }
}


// Host launcher for fused_bias_act_kernel. Empty `bias`/`refer` tensors disable
// the bias add / reference lookup.
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    int use_bias = b.numel() ? 1 : 0;
    int use_ref = ref.numel() ? 1 : 0;

    int size_x = x.numel();
    int size_b = b.numel();
    // Number of elements per index of dim 1, so the kernel's bias index is
    // (xi / step_b) % size_b, i.e. the bias is broadcast along dim 1.
    int step_b = 1;

    for (int i = 1 + 1; i < x.dim(); i++) {
        step_b *= x.size(i);
    }

    int loop_x = 4;
    int block_size = 4 * 32;
    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<<>>(
            y.data_ptr(),
            x.data_ptr(),
            b.data_ptr(),
            ref.data_ptr(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}

================================================
FILE: op/upfirdn2d.cpp
================================================
#include


torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
    int up_x, int up_y, int down_x, int down_y,
    int pad_x0, int pad_x1, int pad_y0, int pad_y1);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Python-facing entry point: checks that both tensors are on the GPU, then
// forwards to the CUDA implementation.
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
    int up_x, int up_y, int down_x, int down_y,
    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    CHECK_CUDA(input);
    CHECK_CUDA(kernel);

    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}

================================================
FILE: op/upfirdn2d.py
================================================
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


module_path = os.path.dirname(__file__)
# JIT-build the upfirdn2d extension from the .cpp/.cu sources at import time.
upfirdn2d_op = load(
    "upfirdn2d",
    sources=[
        os.path.join(module_path, "upfirdn2d.cpp"),
        os.path.join(module_path, "upfirdn2d_kernel.cu"),
    ],
)


class UpFirDn2dBackward(Function):
    """Backward of `UpFirDn2d` as its own autograd op (enables double backward).

    The input gradient is another upfirdn2d call with `up` and `down` swapped,
    the flipped kernel (`grad_kernel`) and the gradient padding `g_pad`, both
    prepared in `UpFirDn2d.forward`.
    """

    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        # The CUDA op works on (major, H, W, minor); fold N*C into major.
        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        # The backward map is linear, so its gradient is the original forward op.
        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):
    """Autograd wrapper around the CUDA upfirdn2d op for (N, C, H, W) inputs."""

    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        # Fold N*C into the op's major dimension, with a singleton minor dim.
        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        # Padding for the transposed upfirdn2d call used in the backward pass.
        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, FIR-filter and downsample a batch of (N, C, H, W) images.

    The same `up`, `down` and `pad` are used for both spatial axes. CPU tensors
    go through `upfirdn2d_native`; other devices use the CUDA op.

    Args:
        input: Tensor of shape (N, C, H, W).
        kernel: 2-D FIR kernel tensor.
        up: Integer upsampling factor (zero insertion).
        down: Integer downsampling factor (subsampling).
        pad: (before, after) padding applied to both H and W; negative values
            crop.

    Returns:
        Tensor of shape (N, C, out_h, out_w), where each output size is
        (in * up + pad[0] + pad[1] - kernel_size) // down + 1.
    """
    if input.device.type == "cpu":
        out = upfirdn2d_native(
            input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
        )

    else:
        out = UpFirDn2d.apply(
            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
        )

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    """Pure PyTorch implementation of upfirdn2d (used for CPU tensors).

    Takes an (N, C, H, W) `input` with separate x/y factors and paddings, and
    returns an (N, C, out_h, out_w) tensor.
    """
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Upsample by zero insertion: append up-1 zeros after every sample.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive padding adds zeros; negative padding is applied as a crop.
    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    # F.conv2d computes cross-correlation; flipping the kernel turns it into a
    # true convolution with `kernel`.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    # Downsample by keeping every down-th sample.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)

================================================
FILE: op/upfirdn2d_kernel.cu
================================================
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

// NOTE(review): #include targets, template parameter lists, static_cast /
// data_ptr type arguments and <<<...>>> launch configurations appear to have
// been stripped from this dump — restore them from the original before
// compiling.
#include
#include
#include
#include
#include

#include
#include

// Integer division rounding toward negative infinity (C++ `/` truncates toward
// zero); coordinates here can be negative because of padding.
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
    int c = a / b;

    if (c * b > a) {
        c--;
    }

    return c;
}


// Geometry of one upfirdn2d call, passed by value to the kernels. Tensors are
// laid out as (major, h, w, minor); loop_major / loop_x set how many major
// indices / output columns each thread iterates over.
struct UpFirDn2DKernelParams {
    int up_x;
    int up_y;
    int down_x;
    int down_y;
    int pad_x0;
    int pad_x1;
    int pad_y0;
    int pad_y1;

    int major_dim;
    int in_h;
    int in_w;
    int minor_dim;
    int kernel_h;
    int kernel_w;
    int out_h;
    int out_w;
    int loop_major;
    int loop_x;
};


// Generic fallback (the `default` case in upfirdn2d_op): each thread computes
// output pixels directly from global memory.
template __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                                 const scalar_t *kernel,
                                                 const UpFirDn2DKernelParams p) {
    // x indexes (out_y, minor) jointly, y indexes output columns, z indexes
    // chunks of the major dimension.
    int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int out_y = minor_idx / p.minor_dim;
    minor_idx -= out_y * p.minor_dim;
    int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
    int major_idx_base = blockIdx.z * p.loop_major;

    if (out_x_base >= p.out_w || out_y >= p.out_h ||
        major_idx_base >= p.major_dim) {
        return;
    }

    // Map the output row back to the contributing input rows (clamped to the
    // image) and the first kernel tap that overlaps them.
    int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
    int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
    int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
    int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

    for (int loop_major = 0, major_idx = major_idx_base;
         loop_major < p.loop_major && major_idx < p.major_dim;
         loop_major++, major_idx++) {
        for (int loop_x = 0, out_x = out_x_base;
             loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
            int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
            int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
            int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
            int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

            const scalar_t *x_p =
                &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                           p.minor_dim +
                       minor_idx];
            const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
            // The input pointer steps forward one sample while the kernel
            // pointer steps back `up` taps per input sample.
            int x_px = p.minor_dim;
            int k_px = -p.up_x;
            int x_py = p.in_w * p.minor_dim;
            int k_py = -p.up_y * p.kernel_w;

            scalar_t v = 0.0f;

            for (int y = 0; y < h; y++) {
                for (int x = 0; x < w; x++) {
                    v += static_cast(*x_p) * static_cast(*k_p);
                    x_p += x_px;
                    k_p += k_px;
                }

                x_p += x_py - w * x_px;
                k_p += k_py - w * k_px;
            }

            out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
                minor_idx] = v;
        }
    }
}


// Tiled kernel for small fixed configurations. up_*, down_*, kernel_* and
// tile_out_* are presumably template parameters (their declarations are
// missing from this dump). The flipped kernel and one input tile are staged in
// __shared__ arrays; each thread then accumulates output pixels of the tile.
template __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                           const scalar_t *kernel,
                                           const UpFirDn2DKernelParams p) {
    // Input extent needed to produce one tile_out_h x tile_out_w output tile.
    const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
    const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

    // NOTE(review): the shared buffers are `float` whatever scalar_t is, so
    // double inputs are rounded to float here — confirm this is acceptable.
    __shared__ volatile float sk[kernel_h][kernel_w];
    __shared__ volatile float sx[tile_in_h][tile_in_w];

    int minor_idx = blockIdx.x;
    int tile_out_y = minor_idx / p.minor_dim;
    minor_idx -= tile_out_y * p.minor_dim;
    tile_out_y *= tile_out_h;
    int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
    int major_idx_base = blockIdx.z * p.loop_major;

    if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
        major_idx_base >= p.major_dim) {
        return;
    }

    // Stage the kernel flipped in both axes; taps beyond the runtime kernel
    // size (kernel_h/kernel_w are upper bounds) are zero-filled.
    for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
         tap_idx += blockDim.x) {
        int ky = tap_idx / kernel_w;
        int kx = tap_idx - ky * kernel_w;
        scalar_t v = 0.0;

        if (kx < p.kernel_w & ky < p.kernel_h) {
            v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
        }

        sk[ky][kx] = v;
    }

    for (int loop_major = 0, major_idx = major_idx_base;
         loop_major < p.loop_major & major_idx < p.major_dim;
         loop_major++, major_idx++) {
        for (int loop_x = 0, tile_out_x = tile_out_x_base;
             loop_x < p.loop_x & tile_out_x < p.out_w;
             loop_x++, tile_out_x += tile_out_w) {
            int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
            int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
            int tile_in_x = floor_div(tile_mid_x, up_x);
            int tile_in_y = floor_div(tile_mid_y, up_y);

            // Barrier: `sk` is fully staged and no thread still reads the
            // previous tile from `sx`.
            __syncthreads();

            // Load the input tile; samples outside the image read as zero.
            for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
                 in_idx += blockDim.x) {
                int rel_in_y = in_idx / tile_in_w;
                int rel_in_x = in_idx - rel_in_y * tile_in_w;
                int in_x = rel_in_x + tile_in_x;
                int in_y = rel_in_y + tile_in_y;

                scalar_t v = 0.0;

                if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
                    v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                                  p.minor_dim +
                              minor_idx];
                }

                sx[rel_in_y][rel_in_x] = v;
            }

            // Barrier: the new tile is visible to every thread.
            __syncthreads();
            for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
                 out_idx += blockDim.x) {
                int rel_out_y = out_idx / tile_out_w;
                int rel_out_x = out_idx - rel_out_y * tile_out_w;
                int out_x = rel_out_x + tile_out_x;
                int out_y = rel_out_y + tile_out_y;

                int mid_x = tile_mid_x + rel_out_x * down_x;
                int mid_y = tile_mid_y + rel_out_y * down_y;
                int in_x = floor_div(mid_x, up_x);
                int in_y = floor_div(mid_y, up_y);
                int rel_in_x = in_x - tile_in_x;
                int rel_in_y = in_y - tile_in_y;
                int kernel_x = (in_x + 1) * up_x - mid_x - 1;
                int kernel_y = (in_y + 1) * up_y - mid_y - 1;

                scalar_t v = 0.0;

#pragma unroll
                for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
                    for (int x = 0; x < kernel_w / up_x; x++)
                        v += sx[rel_in_y + y][rel_in_x + x] *
                             sk[kernel_y + y * up_y][kernel_x + x * up_x];

                // Edge tiles can extend past the output; skip those pixels.
                if (out_x < p.out_w & out_y < p.out_h) {
                    out[((major_idx * p.out_h + out_y) * p.out_w + out_x) *
                            p.minor_dim +
                        minor_idx] = v;
                }
            }
        }
    }
}


// Host launcher. `input` is (major, in_h, in_w, minor), `kernel` is
// (kernel_h, kernel_w); returns a new (major, out_h, out_w, minor) tensor.
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    UpFirDn2DKernelParams p;

    auto x = input.contiguous();
    auto k = kernel.contiguous();

    p.major_dim = x.size(0);
    p.in_h = x.size(1);
    p.in_w = x.size(2);
    p.minor_dim = x.size(3);
    p.kernel_h = k.size(0);
    p.kernel_w = k.size(1);
    p.up_x = up_x;
    p.up_y = up_y;
    p.down_x = down_x;
    p.down_y = down_y;
    p.pad_x0 = pad_x0;
    p.pad_x1 = pad_x1;
    p.pad_y0 = pad_y0;
    p.pad_y1 = pad_y1;

    // (in * up + pad0 + pad1 - kernel) / down + 1, folded into one division.
    p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
              p.down_y;
    p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
              p.down_x;

    auto out =
        at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

    // Select a specialised tiled kernel when the configuration matches one.
    // The checks are not exclusive: a later match overrides an earlier one, so
    // the tightest kernel-size bound wins. mode -1 means the generic kernel.
    int mode = -1;

    int tile_out_h = -1;
    int tile_out_w = -1;

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
        p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 1;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
        p.kernel_h <= 3 && p.kernel_w <= 3) {
        mode = 2;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
        p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 3;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
        p.kernel_h <= 2 && p.kernel_w <= 2) {
        mode = 4;
        tile_out_h = 16;
        tile_out_w = 64;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
        p.kernel_h <= 4 && p.kernel_w <= 4) {
        mode = 5;
        tile_out_h = 8;
        tile_out_w = 32;
    }

    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
        p.kernel_h <= 2 && p.kernel_w <= 2) {
        mode = 6;
        tile_out_h = 8;
        tile_out_w = 32;
    }

    dim3 block_size;
    dim3 grid_size;

    // Each z-block walks loop_major consecutive major indices, which keeps
    // gridDim.z at roughly 16384 or less.
    if (tile_out_h > 0 && tile_out_w > 0) {
        p.loop_major = (p.major_dim - 1) / 16384 + 1;
        p.loop_x = 1;
        block_size = dim3(32 * 8, 1, 1);
        grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                         (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                         (p.major_dim - 1) / p.loop_major + 1);
    } else {
        p.loop_major = (p.major_dim - 1) / 16384 + 1;
        p.loop_x = 4;
        block_size = dim3(4, 32, 1);
        grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                         (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                         (p.major_dim - 1) / p.loop_major + 1);
    }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
        switch (mode) {
        case 1:
            upfirdn2d_kernel <<>>(out.data_ptr(), x.data_ptr(), k.data_ptr(), p);

            break;

        case 2:
            upfirdn2d_kernel <<>>(out.data_ptr(), x.data_ptr(), k.data_ptr(), p);

            break;

        case 3:
            upfirdn2d_kernel <<>>(out.data_ptr(), x.data_ptr(), k.data_ptr(), p);

            break;

        case 4:
            upfirdn2d_kernel <<>>(out.data_ptr(), x.data_ptr(), k.data_ptr(), p);

            break;

        case 5:
            upfirdn2d_kernel <<>>(out.data_ptr(), x.data_ptr(), k.data_ptr(), p);

            break;

        case 6:
            upfirdn2d_kernel <<>>(out.data_ptr(), x.data_ptr(), k.data_ptr(), p);

            break;

        default:
            upfirdn2d_kernel_large<<>>(
                out.data_ptr(), x.data_ptr(),
                k.data_ptr(), p);
        }
    });

    return out;
}

================================================
FILE: physics/ct.py
================================================
import torch
import numpy as np
from .radon import Radon, IRadon


class CT():
    """Sparse-view CT operators built from Radon / IRadon.

    `A`, `A_dagger` and `AT` use the `radon_view` measured angles; `A_all` and
    `A_all_dagger` use the angle set `theta_all`.
    """
    def __init__(self, img_width, radon_view, uniform=True, circle=False, device='cuda:0'):
        if uniform:
            # `radon_view` evenly spaced angles in [0, 180) degrees; the full
            # set has one view per degree.
            theta = np.linspace(0, 180, radon_view, endpoint=False)
            theta_all = np.linspace(0, 180, 180, endpoint=False)
        else:
            # NOTE(review): in this branch theta_all equals theta (degrees
            # 0..radon_view-1), so A_all is not a full-angle operator here —
            # confirm this is intended.
            theta = torch.arange(radon_view)
            theta_all = torch.arange(radon_view)
        self.radon = Radon(img_width, theta, circle).to(device)
        self.radon_all = Radon(img_width, theta_all, circle).to(device)
        self.iradon_all = IRadon(img_width, theta_all, circle).to(device)
        self.iradon = IRadon(img_width, theta, circle).to(device)
        # Unfiltered backprojection (use_filter=None), exposed as `AT`.
        self.radont = IRadon(img_width, theta, circle, use_filter=None).to(device)

    def A(self, x):
        """Project images onto the measured angles (image -> sinogram)."""
        return self.radon(x)

    def A_all(self, x):
        """Project images onto all angles in `theta_all`."""
        return self.radon_all(x)

    def A_all_dagger(self, x):
        """Filtered backprojection from a `theta_all` sinogram."""
        return self.iradon_all(x)

    def A_dagger(self, y):
        """Filtered backprojection from a measured-angle sinogram."""
        return self.iradon(y)

    def AT(self, y):
        """Unfiltered backprojection from a measured-angle sinogram."""
        return self.radont(y)


class CT_LA():
    """
    Limited Angle tomography
    """
    def __init__(self, img_width, radon_view, uniform=True, circle=False, device='cuda:0'):
        if uniform:
            theta = np.linspace(0, 180, radon_view, endpoint=False)
        else:
            theta = torch.arange(radon_view)
        self.radon = Radon(img_width, theta, circle).to(device)
        self.iradon = IRadon(img_width, theta, circle).to(device)
        # Unfiltered backprojection (use_filter=None), exposed as `AT`.
        self.radont = IRadon(img_width, theta, circle, use_filter=None).to(device)

    def A(self, x):
        """Project images onto the configured angles (image -> sinogram)."""
        return self.radon(x)

    def A_dagger(self, y):
        """Filtered backprojection."""
        return self.iradon(y)

    def AT(self, y):
        """Unfiltered backprojection."""
        return self.radont(y)
================================================ FILE: physics/inpainting.py ================================================ import os import torch class Inpainting(): def __init__(self, img_heigth=512, img_width=512, mode='random', mask_rate=0.3, resize=False, device='cuda:0'): mask_path = './physics/mask_random{}.pt'.format(mask_rate) if os.path.exists(mask_path): self.mask = torch.load(mask_path).to(device) else: self.mask = torch.ones(img_heigth, img_width, device=device) self.mask[torch.rand_like(self.mask) > 1 - mask_rate] = 0 torch.save(self.mask, mask_path) def A(self, x): return torch.einsum('kl,ijkl->ijkl', self.mask, x) def A_dagger(self, x): return torch.einsum('kl,ijkl->ijkl', self.mask, x) ================================================ FILE: physics/radon/__init__.py ================================================ from .radon import Radon, IRadon from .stackgram import Stackgram, IStackgram ================================================ FILE: physics/radon/filters.py ================================================ import torch from torch import nn import torch.nn.functional as F from .utils import PI, fftfreq '''source: https://github.com/matteo-ronchetti/torch-radon''' class AbstractFilter(nn.Module): def __init__(self): super(AbstractFilter, self).__init__() def forward(self, x): input_size = x.shape[2] projection_size_padded = \ max(64, int(2 ** (2 * torch.tensor(input_size)).float().log2().ceil())) pad_width = projection_size_padded - input_size padded_tensor = F.pad(x, (0,0,0,pad_width)) f = self._get_fourier_filter(padded_tensor.shape[2]).to(x.device) fourier_filter = self.create_filter(f)[..., None] projection = torch.fft.fft(padded_tensor, dim=2) * fourier_filter return torch.real(torch.fft.ifft(projection, dim=2)[:,:,:input_size,:]) def _get_fourier_filter(self, size): n = torch.cat([ torch.arange(1, size / 2 + 1, 2), torch.arange(size / 2 - 1, 0, -2) ]) f = torch.zeros(size) f[0] = 0.25 f[1::2] = -1 / (PI * n) ** 2 fourier_filter = 
torch.fft.fft(f) return 2*fourier_filter def create_filter(self, f): raise NotImplementedError class RampFilter(AbstractFilter): def __init__(self): super(RampFilter, self).__init__() def create_filter(self, f): return f class HannFilter(AbstractFilter): def __init__(self): super(HannFilter, self).__init__() def create_filter(self, f): n = torch.arange(0, f.shape[0]) hann = 0.5 - 0.5*(2.0*PI*n/(f.shape[0]-1)).cos() return f*hann.roll(hann.shape[0]//2,0).unsqueeze(-1) class LearnableFilter(AbstractFilter): def __init__(self, filter_size): super(LearnableFilter, self).__init__() self.filter = nn.Parameter(2*fftfreq(filter_size).abs().view(-1, 1)) def forward(self, x): fourier_filter = self.filter.unsqueeze(-1).repeat(1,1,2).to(x.device) projection = torch.rfft(x.transpose(2,3), 1, onesided=False).transpose(2,3) * fourier_filter return torch.irfft(projection.transpose(2,3), 1, onesided=False).transpose(2,3) # projection = torch.fft.rfft(x.transpose(2, 3), 1).transpose(2, 3) * fourier_filter # return torch.fft.irfft(projection.transpose(2, 3), 1).transpose(2, 3) ================================================ FILE: physics/radon/radon.py ================================================ import torch from torch import nn import torch.nn.functional as F from physics.radon.filters import RampFilter from physics.radon.utils import PI, SQRT2, deg2rad, affine_grid, grid_sample '''source: https://github.com/matteo-ronchetti/torch-radon''' class Radon(nn.Module): def __init__(self, in_size=None, theta=None, circle=True, dtype=torch.float): super(Radon, self).__init__() self.circle = circle self.theta = theta if theta is None: self.theta = torch.arange(180) self.dtype = dtype self.all_grids = None if in_size is not None: self.all_grids = self._create_grids(self.theta, in_size, circle) def forward(self, x): N, C, W, H = x.shape assert (W == H) if self.all_grids is None: self.all_grids = self._create_grids(self.theta, W, self.circle) if not self.circle: diagonal = SQRT2 * W pad = 
int((diagonal - W).ceil()) new_center = (W + pad) // 2 old_center = W // 2 pad_before = new_center - old_center pad_width = (pad_before, pad - pad_before) x = F.pad(x, (pad_width[0], pad_width[1], pad_width[0], pad_width[1])) N, C, W, _ = x.shape out = torch.zeros(N, C, W, len(self.theta), device=x.device, dtype=self.dtype) for i in range(len(self.theta)): rotated = grid_sample(x, self.all_grids[i].repeat(N, 1, 1, 1).to(x.device)) out[..., i] = rotated.sum(2) return out def _create_grids(self, angles, grid_size, circle): if not circle: grid_size = int((SQRT2 * grid_size).ceil()) all_grids = [] for theta in angles: theta = deg2rad(theta) R = torch.tensor([[ [theta.cos(), theta.sin(), 0], [-theta.sin(), theta.cos(), 0], ]], dtype=self.dtype) all_grids.append(affine_grid(R, torch.Size([1, 1, grid_size, grid_size]))) return all_grids class IRadon(nn.Module): def __init__(self, in_size=None, theta=None, circle=True, use_filter=RampFilter(), out_size=None, dtype=torch.float): super(IRadon, self).__init__() self.circle = circle self.theta = theta if theta is not None else torch.arange(180) self.out_size = out_size self.in_size = in_size self.dtype = dtype self.ygrid, self.xgrid, self.all_grids = None, None, None if in_size is not None: self.ygrid, self.xgrid = self._create_yxgrid(in_size, circle) self.all_grids = self._create_grids(self.theta, in_size, circle) self.filter = use_filter if use_filter is not None else lambda x: x def forward(self, x): it_size = x.shape[2] ch_size = x.shape[1] if self.in_size is None: self.in_size = int((it_size / SQRT2).floor()) if not self.circle else it_size # if None in [self.ygrid, self.xgrid, self.all_grids]: if self.ygrid is None or self.xgrid is None or self.all_grids is None : self.ygrid, self.xgrid = self._create_yxgrid(self.in_size, self.circle) self.all_grids = self._create_grids(self.theta, self.in_size, self.circle) # sinogram x = self.filter(x) reco = torch.zeros(x.shape[0], ch_size, it_size, it_size, device=x.device, 
dtype=self.dtype) for i_theta in range(len(self.theta)): reco += grid_sample(x, self.all_grids[i_theta].repeat(reco.shape[0], 1, 1, 1).to(x.device)) if not self.circle: W = self.in_size diagonal = it_size pad = int(torch.tensor(diagonal - W, dtype=torch.float).ceil()) new_center = (W + pad) // 2 old_center = W // 2 pad_before = new_center - old_center pad_width = (pad_before, pad - pad_before) reco = F.pad(reco, (-pad_width[0], -pad_width[1], -pad_width[0], -pad_width[1])) if self.circle: reconstruction_circle = (self.xgrid ** 2 + self.ygrid ** 2) <= 1 reconstruction_circle = reconstruction_circle.repeat(x.shape[0], ch_size, 1, 1) reco[~reconstruction_circle] = 0. reco = reco * PI.item() / (2 * len(self.theta)) if self.out_size is not None: pad = (self.out_size - self.in_size) // 2 reco = F.pad(reco, (pad, pad, pad, pad)) return reco def _create_yxgrid(self, in_size, circle): if not circle: in_size = int((SQRT2 * in_size).ceil()) unitrange = torch.linspace(-1, 1, in_size, dtype=self.dtype) return torch.meshgrid(unitrange, unitrange) def _XYtoT(self, theta): T = self.xgrid * (deg2rad(theta)).cos() - self.ygrid * (deg2rad(theta)).sin() return T def _create_grids(self, angles, grid_size, circle): if not circle: grid_size = int((SQRT2 * grid_size).ceil()) all_grids = [] for i_theta in range(len(angles)): X = torch.ones(grid_size, dtype=self.dtype).view(-1, 1).repeat(1, grid_size) * i_theta * 2. / ( len(angles) - 1) - 1. 
Y = self._XYtoT(angles[i_theta]) all_grids.append(torch.cat((X.unsqueeze(-1), Y.unsqueeze(-1)), dim=-1).unsqueeze(0)) return all_grids if __name__ == '__main__': img_width = 2 num_proj = 180 device = 'cuda:0' radon = Radon(in_size=img_width, theta=torch.arange(num_proj), circle=False).to(device) iradon = IRadon(in_size=img_width, theta=torch.arange(num_proj), circle=False).to(device) img = torch.randn([1, 1, 2, 2]).to(device) sinogram = radon(img) b_img = iradon(sinogram) ================================================ FILE: physics/radon/stackgram.py ================================================ import torch from torch import nn import torch.nn.functional as F from .utils import SQRT2, deg2rad, affine_grid, grid_sample '''source: https://github.com/matteo-ronchetti/torch-radon''' class Stackgram(nn.Module): def __init__(self, out_size, theta=None, circle=True, mode='nearest', dtype=torch.float): super(Stackgram, self).__init__() self.circle = circle self.theta = theta if theta is None: self.theta = torch.arange(180) self.out_size = out_size self.in_size = in_size = out_size if circle else int((SQRT2*out_size).ceil()) self.dtype = dtype self.all_grids = self._create_grids(self.theta, in_size) self.mode = mode def forward(self, x): stackgram = torch.zeros(x.shape[0], len(self.theta), self.in_size, self.in_size, device=x.device, dtype=self.dtype) for i_theta in range(len(self.theta)): repline = x[...,i_theta] repline = repline.unsqueeze(-1).repeat(1,1,1,repline.shape[2]) linogram = grid_sample(repline, self.all_grids[i_theta].repeat(x.shape[0],1,1,1).to(x.device), mode=self.mode) stackgram[:,i_theta] = linogram return stackgram def _create_grids(self, angles, grid_size): all_grids = [] for i_theta in range(len(angles)): t = deg2rad(angles[i_theta]) R = torch.tensor([[t.sin(), t.cos(), 0.],[t.cos(), -t.sin(), 0.]], dtype=self.dtype).unsqueeze(0) all_grids.append(affine_grid(R, torch.Size([1,1,grid_size,grid_size]))) return all_grids class IStackgram(nn.Module): 
def __init__(self, out_size, theta=None, circle=True, mode='bilinear', dtype=torch.float): super(IStackgram, self).__init__() self.circle = circle self.theta = theta if theta is None: self.theta = torch.arange(180) self.out_size = out_size self.in_size = in_size = out_size if circle else int((SQRT2*out_size).ceil()) self.dtype = dtype self.all_grids = self._create_grids(self.theta, in_size) self.mode = mode def forward(self, x): sinogram = torch.zeros(x.shape[0], 1, self.in_size, len(self.theta), device=x.device, dtype=self.dtype) for i_theta in range(len(self.theta)): linogram = x[:,i_theta].unsqueeze(1) repline = grid_sample(linogram, self.all_grids[i_theta].repeat(x.shape[0],1,1,1).to(x.device), mode=self.mode) repline = repline[...,repline.shape[-1]//2] sinogram[...,i_theta] = repline return sinogram def _create_grids(self, angles, grid_size): all_grids = [] for i_theta in range(len(angles)): t = deg2rad(angles[i_theta]) R = torch.tensor([[t.sin(), t.cos(), 0.],[t.cos(), -t.sin(), 0.]], dtype=self.dtype).unsqueeze(0) all_grids.append(affine_grid(R, torch.Size([1,1,grid_size,grid_size]))) return all_grids ================================================ FILE: physics/radon/utils.py ================================================ import torch import torch.nn.functional as F '''source: https://github.com/matteo-ronchetti/torch-radon''' if torch.__version__>'1.2.0': affine_grid = lambda theta, size: F.affine_grid(theta, size, align_corners=True) grid_sample = lambda input, grid, mode='bilinear': F.grid_sample(input, grid, align_corners=True, mode=mode) else: affine_grid = F.affine_grid grid_sample = F.grid_sample # constants PI = 4*torch.ones(1).atan() SQRT2 = (2*torch.ones(1)).sqrt() def fftfreq(n): val = 1.0/n results = torch.zeros(n) N = (n-1)//2 + 1 p1 = torch.arange(0, N) results[:N] = p1 p2 = torch.arange(-(n//2), 0) results[N:] = p2 return results*val def deg2rad(x): return x*PI/180 ================================================ FILE: run_lib.py 
================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pylint: skip-file """Training and evaluation for score-based generative models. """ import gc import io import os import time from pathlib import Path import numpy as np import logging # Keep the import below for registering all model definitions from models import ddpm, ncsnv2, ncsnpp, unet import losses import sampling from models import utils as mutils from models.ema import ExponentialMovingAverage import datasets #import evaluation import likelihood import sde_lib from absl import flags import torch from torch import nn from torch.utils import tensorboard from torchvision.utils import make_grid, save_image from utils import save_checkpoint, restore_checkpoint, get_mask, kspace_to_nchw, root_sum_of_squares FLAGS = flags.FLAGS def train(config, workdir): """Runs the training pipeline. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ # Create directories for experimental logs sample_dir = os.path.join(workdir, "samples") Path(sample_dir).mkdir(parents=True, exist_ok=True) tb_dir = os.path.join(workdir, "tensorboard") Path(tb_dir).mkdir(parents=True, exist_ok=True) writer = tensorboard.SummaryWriter(tb_dir) # Initialize model. 
score_model = mutils.create_model(config) ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) optimizer = losses.get_optimizer(config, score_model.parameters()) state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0) # Create checkpoints directory checkpoint_dir = os.path.join(workdir, "checkpoints") checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta") Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) Path(checkpoint_meta_dir).mkdir(parents=True, exist_ok=True) # Resume training when intermediate checkpoints are detected state = restore_checkpoint(checkpoint_meta_dir, state, config.device) initial_step = int(state['step']) # Build pytorch dataloader for training train_dl, eval_dl = datasets.create_dataloader(config) num_data = len(train_dl.dataset) # Create data normalizer and its inverse scaler = datasets.get_data_scaler(config) inverse_scaler = datasets.get_data_inverse_scaler(config) # Setup SDEs if config.training.sde.lower() == 'vpsde': sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) sampling_eps = 1e-3 elif config.training.sde.lower() == 'subvpsde': sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) sampling_eps = 1e-3 elif config.training.sde.lower() == 'vesde': sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales) sampling_eps = 1e-5 else: raise NotImplementedError(f"SDE {config.training.sde} unknown.") # Build one-step training and evaluation functions optimize_fn = losses.optimization_manager(config) continuous = config.training.continuous reduce_mean = config.training.reduce_mean likelihood_weighting = config.training.likelihood_weighting train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, likelihood_weighting=likelihood_weighting) 
eval_step_fn = losses.get_step_fn(sde, train=False, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, likelihood_weighting=likelihood_weighting) # Building sampling functions if config.training.snapshot_sampling: sampling_shape = (config.training.batch_size, config.data.num_channels, config.data.image_size, config.data.image_size) sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps) # In case there are multiple hosts (e.g., TPU pods), only log to host 0 logging.info("Starting training loop at step %d." % (initial_step,)) for epoch in range(1, config.training.epochs): print('=================================================') print(f'Epoch: {epoch}') print('=================================================') for step, batch in enumerate(train_dl, start=1): batch = scaler(batch.to(config.device)) # (b, 1, 320, 320, 2) --> (b, 2, 320, 320) # batch = kspace_to_nchw(torch.view_as_real(batch)) # Execute one training step loss = train_step_fn(state, batch) if step % config.training.log_freq == 0: logging.info("step: %d, training_loss: %.5e" % (step, loss.item())) global_step = num_data * epoch + step writer.add_scalar("training_loss", scalar_value=loss, global_step=global_step) if step != 0 and step % config.training.snapshot_freq_for_preemption == 0: save_checkpoint(checkpoint_meta_dir, state) # Report the loss on an evaluation dataset periodically # if step % config.training.eval_freq == 0: # eval_batch = scaler(next(iter(eval_dl)).to(config.device)) # eval_loss = eval_step_fn(state, eval_batch) # logging.info("step: %d, eval_loss: %.5e" % (step, eval_loss.item())) # global_step = num_data * epoch + step # writer.add_scalar("eval_loss", scalar_value=eval_loss.item(), global_step=global_step) # Save a checkpoint for every epoch save_checkpoint(checkpoint_dir, state, name=f'checkpoint_{epoch}.pth') # Generate and save samples for every epoch if config.training.snapshot_sampling: print('sampling') 
ema.store(score_model.parameters()) ema.copy_to(score_model.parameters()) sample, n = sampling_fn(score_model) if config.data.is_complex: sample = root_sum_of_squares(sample, dim=1).unsqueeze(dim=0) ema.restore(score_model.parameters()) this_sample_dir = os.path.join(sample_dir, "iter_{}".format(epoch)) Path(this_sample_dir).mkdir(parents=True, exist_ok=True) nrow = int(np.sqrt(sample.shape[0])) image_grid = make_grid(sample, nrow, padding=2) sample = np.clip(sample.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) np.save(os.path.join(this_sample_dir, "sample"), sample) save_image(image_grid, os.path.join(this_sample_dir, "sample.png")) def evaluate(config, workdir, eval_folder="eval"): """Evaluate trained models. Args: config: Configuration to use. workdir: Working directory for checkpoints. eval_folder: The subfolder for storing evaluation results. Default to "eval". """ # Create directory to eval_folder eval_dir = os.path.join(workdir, eval_folder) Path(eval_dir).mkdir(parents=True, exist_ok=True) # Build pytorch dataloader for training train_dl, eval_dl = datasets.create_dataloader(config) num_data = len(train_dl.dataset) # Create data normalizer and its inverse scaler = datasets.get_data_scaler(config) inverse_scaler = datasets.get_data_inverse_scaler(config) # Initialize model score_model = mutils.create_model(config) optimizer = losses.get_optimizer(config, score_model.parameters()) ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0) checkpoint_dir = os.path.join(workdir, "checkpoints") # Setup SDEs if config.training.sde.lower() == 'vpsde': sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) sampling_eps = 1e-3 elif config.training.sde.lower() == 'subvpsde': sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) 
sampling_eps = 1e-3 elif config.training.sde.lower() == 'vesde': sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales) sampling_eps = 1e-5 else: raise NotImplementedError(f"SDE {config.training.sde} unknown.") # Create the one-step evaluation function when loss computation is enabled if config.eval.enable_loss: optimize_fn = losses.optimization_manager(config) continuous = config.training.continuous likelihood_weighting = config.training.likelihood_weighting reduce_mean = config.training.reduce_mean eval_step = losses.get_step_fn(sde, train=False, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, likelihood_weighting=likelihood_weighting) # Create data loaders for likelihood evaluation. Only evaluate on uniformly dequantized data train_ds_bpd, eval_ds_bpd, _ = datasets.get_dataset(config, uniform_dequantization=True, evaluation=True) if config.eval.bpd_dataset.lower() == 'train': ds_bpd = train_ds_bpd bpd_num_repeats = 1 elif config.eval.bpd_dataset.lower() == 'test': # Go over the dataset 5 times when computing likelihood on the test dataset ds_bpd = eval_ds_bpd bpd_num_repeats = 5 else: raise ValueError(f"No bpd dataset {config.eval.bpd_dataset} recognized.") # Build the likelihood computation function when likelihood is enabled if config.eval.enable_bpd: likelihood_fn = likelihood.get_likelihood_fn(sde, inverse_scaler) # Build the sampling function when sampling is enabled if config.eval.enable_sampling: sampling_shape = (config.eval.batch_size, config.data.num_channels, config.data.image_size, config.data.image_size) sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps) # Use inceptionV3 for images with resolution higher than 256. 
inceptionv3 = config.data.image_size >= 256 inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3) begin_ckpt = config.eval.begin_ckpt logging.info("begin checkpoint: %d" % (begin_ckpt,)) for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1): # Wait if the target checkpoint doesn't exist yet waiting_message_printed = False ckpt_filename = os.path.join(checkpoint_dir, "checkpoint_{}.pth".format(ckpt)) while not tf.io.gfile.exists(ckpt_filename): if not waiting_message_printed: logging.warning("Waiting for the arrival of checkpoint_%d" % (ckpt,)) waiting_message_printed = True time.sleep(60) # Wait for 2 additional mins in case the file exists but is not ready for reading ckpt_path = os.path.join(checkpoint_dir, f'checkpoint_{ckpt}.pth') try: state = restore_checkpoint(ckpt_path, state, device=config.device) except: time.sleep(60) try: state = restore_checkpoint(ckpt_path, state, device=config.device) except: time.sleep(120) state = restore_checkpoint(ckpt_path, state, device=config.device) ema.copy_to(score_model.parameters()) # Compute the loss function on the full evaluation dataset if loss computation is enabled if config.eval.enable_loss: all_losses = [] eval_iter = iter(eval_ds) # pytype: disable=wrong-arg-types for i, batch in enumerate(eval_iter): eval_batch = torch.from_numpy(batch['image']._numpy()).to(config.device).float() eval_batch = eval_batch.permute(0, 3, 1, 2) eval_batch = scaler(eval_batch) eval_loss = eval_step(state, eval_batch) all_losses.append(eval_loss.item()) if (i + 1) % 1000 == 0: logging.info("Finished %dth step loss evaluation" % (i + 1)) # Save loss values to disk or Google Cloud Storage all_losses = np.asarray(all_losses) with tf.io.gfile.GFile(os.path.join(eval_dir, f"ckpt_{ckpt}_loss.npz"), "wb") as fout: io_buffer = io.BytesIO() np.savez_compressed(io_buffer, all_losses=all_losses, mean_loss=all_losses.mean()) fout.write(io_buffer.getvalue()) # Compute log-likelihoods (bits/dim) if enabled if 
config.eval.enable_bpd: bpds = [] for repeat in range(bpd_num_repeats): bpd_iter = iter(ds_bpd) # pytype: disable=wrong-arg-types for batch_id in range(len(ds_bpd)): batch = next(bpd_iter) eval_batch = torch.from_numpy(batch['image']._numpy()).to(config.device).float() eval_batch = eval_batch.permute(0, 3, 1, 2) eval_batch = scaler(eval_batch) bpd = likelihood_fn(score_model, eval_batch)[0] bpd = bpd.detach().cpu().numpy().reshape(-1) bpds.extend(bpd) logging.info( "ckpt: %d, repeat: %d, batch: %d, mean bpd: %6f" % (ckpt, repeat, batch_id, np.mean(np.asarray(bpds)))) bpd_round_id = batch_id + len(ds_bpd) * repeat # Save bits/dim to disk or Google Cloud Storage with tf.io.gfile.GFile(os.path.join(eval_dir, f"{config.eval.bpd_dataset}_ckpt_{ckpt}_bpd_{bpd_round_id}.npz"), "wb") as fout: io_buffer = io.BytesIO() np.savez_compressed(io_buffer, bpd) fout.write(io_buffer.getvalue()) # Generate samples and compute IS/FID/KID when enabled if config.eval.enable_sampling: num_sampling_rounds = config.eval.num_samples // config.eval.batch_size + 1 for r in range(num_sampling_rounds): logging.info("sampling -- ckpt: %d, round: %d" % (ckpt, r)) # Directory to save samples. 
Different for each host to avoid writing conflicts this_sample_dir = os.path.join( eval_dir, f"ckpt_{ckpt}") tf.io.gfile.makedirs(this_sample_dir) samples, n = sampling_fn(score_model) samples = np.clip(samples.permute(0, 2, 3, 1).cpu().numpy() * 255., 0, 255).astype(np.uint8) samples = samples.reshape( (-1, config.data.image_size, config.data.image_size, config.data.num_channels)) # Write samples to disk or Google Cloud Storage with tf.io.gfile.GFile( os.path.join(this_sample_dir, f"samples_{r}.npz"), "wb") as fout: io_buffer = io.BytesIO() np.savez_compressed(io_buffer, samples=samples) fout.write(io_buffer.getvalue()) # Force garbage collection before calling TensorFlow code for Inception network gc.collect() latents = evaluation.run_inception_distributed(samples, inception_model, inceptionv3=inceptionv3) # Force garbage collection again before returning to JAX code gc.collect() # Save latent represents of the Inception network to disk or Google Cloud Storage with tf.io.gfile.GFile( os.path.join(this_sample_dir, f"statistics_{r}.npz"), "wb") as fout: io_buffer = io.BytesIO() np.savez_compressed( io_buffer, pool_3=latents["pool_3"], logits=latents["logits"]) fout.write(io_buffer.getvalue()) # Compute inception scores, FIDs and KIDs. # Load all statistics that have been previously computed and saved for each host all_logits = [] all_pools = [] this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}") stats = tf.io.gfile.glob(os.path.join(this_sample_dir, "statistics_*.npz")) for stat_file in stats: with tf.io.gfile.GFile(stat_file, "rb") as fin: stat = np.load(fin) if not inceptionv3: all_logits.append(stat["logits"]) all_pools.append(stat["pool_3"]) if not inceptionv3: all_logits = np.concatenate(all_logits, axis=0)[:config.eval.num_samples] all_pools = np.concatenate(all_pools, axis=0)[:config.eval.num_samples] # Load pre-computed dataset statistics. 
data_stats = evaluation.load_dataset_stats(config) data_pools = data_stats["pool_3"] # Compute FID/KID/IS on all samples together. if not inceptionv3: inception_score = tfgan.eval.classifier_score_from_logits(all_logits) else: inception_score = -1 fid = tfgan.eval.frechet_classifier_distance_from_activations( data_pools, all_pools) # Hack to get tfgan KID work for eager execution. tf_data_pools = tf.convert_to_tensor(data_pools) tf_all_pools = tf.convert_to_tensor(all_pools) kid = tfgan.eval.kernel_classifier_distance_from_activations( tf_data_pools, tf_all_pools).numpy() del tf_data_pools, tf_all_pools logging.info( "ckpt-%d --- inception_score: %.6e, FID: %.6e, KID: %.6e" % ( ckpt, inception_score, fid, kid)) with tf.io.gfile.GFile(os.path.join(eval_dir, f"report_{ckpt}.npz"), "wb") as f: io_buffer = io.BytesIO() np.savez_compressed(io_buffer, IS=inception_score, fid=fid, kid=kid) f.write(io_buffer.getvalue()) ================================================ FILE: sampling.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# pylint: skip-file # pytype: skip-file """Various sampling methods.""" import functools import time import torch import numpy as np import abc from models.utils import from_flattened_numpy, to_flattened_numpy, get_score_fn from scipy import integrate import sde_lib from models import utils as mutils _CORRECTORS = {} _PREDICTORS = {} def register_predictor(cls=None, *, name=None): """A decorator for registering predictor classes.""" def _register(cls): if name is None: local_name = cls.__name__ else: local_name = name if local_name in _PREDICTORS: raise ValueError(f'Already registered model with name: {local_name}') _PREDICTORS[local_name] = cls return cls if cls is None: return _register else: return _register(cls) def register_corrector(cls=None, *, name=None): """A decorator for registering corrector classes.""" def _register(cls): if name is None: local_name = cls.__name__ else: local_name = name if local_name in _CORRECTORS: raise ValueError(f'Already registered model with name: {local_name}') _CORRECTORS[local_name] = cls return cls if cls is None: return _register else: return _register(cls) def get_predictor(name): return _PREDICTORS[name] def get_corrector(name): return _CORRECTORS[name] def get_sampling_fn(config, sde, shape, inverse_scaler, eps): """Create a sampling function. Args: config: A `ml_collections.ConfigDict` object that contains all configuration information. sde: A `sde_lib.SDE` object that represents the forward SDE. shape: A sequence of integers representing the expected shape of a single sample. inverse_scaler: The inverse data normalizer function. eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability. Returns: A function that takes random states and a replicated training state and outputs samples with the trailing dimensions matching `shape`. 
""" sampler_name = config.sampling.method # Probability flow ODE sampling with black-box ODE solvers if sampler_name.lower() == 'ode': sampling_fn = get_ode_sampler(sde=sde, shape=shape, inverse_scaler=inverse_scaler, denoise=config.sampling.noise_removal, eps=eps, device=config.device) # Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases. elif sampler_name.lower() == 'pc': predictor = get_predictor(config.sampling.predictor.lower()) corrector = get_corrector(config.sampling.corrector.lower()) sampling_fn = get_pc_sampler(sde=sde, shape=shape, predictor=predictor, corrector=corrector, inverse_scaler=inverse_scaler, snr=config.sampling.snr, n_steps=config.sampling.n_steps_each, probability_flow=config.sampling.probability_flow, continuous=config.training.continuous, denoise=config.sampling.noise_removal, eps=eps, device=config.device) else: raise ValueError(f"Sampler name {sampler_name} unknown.") return sampling_fn class Predictor(abc.ABC): """The abstract class for a predictor algorithm.""" def __init__(self, sde, score_fn, probability_flow=False): super().__init__() self.sde = sde # Compute the reverse SDE/ODE self.rsde = sde.reverse(score_fn, probability_flow) self.score_fn = score_fn @abc.abstractmethod def update_fn(self, x, t): """One update of the predictor. Args: x: A PyTorch tensor representing the current state t: A Pytorch tensor representing the current time step. Returns: x: A PyTorch tensor of the next state. x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising. """ pass class Corrector(abc.ABC): """The abstract class for a corrector algorithm.""" def __init__(self, sde, score_fn, snr, n_steps): super().__init__() self.sde = sde self.score_fn = score_fn self.snr = snr self.n_steps = n_steps @abc.abstractmethod def update_fn(self, x, t): """One update of the corrector. Args: x: A PyTorch tensor representing the current state t: A PyTorch tensor representing the current time step. 
Returns: x: A PyTorch tensor of the next state. x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising. """ pass @register_predictor(name='euler_maruyama') class EulerMaruyamaPredictor(Predictor): def __init__(self, sde, score_fn, probability_flow=False): super().__init__(sde, score_fn, probability_flow) def update_fn(self, x, t): dt = -1. / self.rsde.N z = torch.randn_like(x) drift, diffusion = self.rsde.sde(x, t) x_mean = x + drift * dt x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z return x, x_mean @register_predictor(name='reverse_diffusion') class ReverseDiffusionPredictor(Predictor): def __init__(self, sde, score_fn, probability_flow=False): super().__init__(sde, score_fn, probability_flow) def update_fn(self, x, t): f, G = self.rsde.discretize(x, t) z = torch.randn_like(x) x_mean = x - f x = x_mean + G[:, None, None, None] * z return x, x_mean @register_predictor(name='ancestral_sampling') class AncestralSamplingPredictor(Predictor): """The ancestral sampling predictor. 
Currently only supports VE/VP SDEs.""" def __init__(self, sde, score_fn, probability_flow=False): super().__init__(sde, score_fn, probability_flow) if not isinstance(sde, sde_lib.VPSDE) and not isinstance(sde, sde_lib.VESDE): raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") assert not probability_flow, "Probability flow not supported by ancestral sampling" def vesde_update_fn(self, x, t): sde = self.sde timestep = (t * (sde.N - 1) / sde.T).long() sigma = sde.discrete_sigmas[timestep] adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), sde.discrete_sigmas.to(t.device)[timestep - 1]) score = self.score_fn(x, t) x_mean = x + score * (sigma ** 2 - adjacent_sigma ** 2)[:, None, None, None] std = torch.sqrt((adjacent_sigma ** 2 * (sigma ** 2 - adjacent_sigma ** 2)) / (sigma ** 2)) noise = torch.randn_like(x) x = x_mean + std[:, None, None, None] * noise return x, x_mean def vpsde_update_fn(self, x, t): sde = self.sde timestep = (t * (sde.N - 1) / sde.T).long() beta = sde.discrete_betas.to(t.device)[timestep] score = self.score_fn(x, t) x_mean = (x + beta[:, None, None, None] * score) / torch.sqrt(1. 
- beta)[:, None, None, None] noise = torch.randn_like(x) x = x_mean + torch.sqrt(beta)[:, None, None, None] * noise return x, x_mean def update_fn(self, x, t): if isinstance(self.sde, sde_lib.VESDE): return self.vesde_update_fn(x, t) elif isinstance(self.sde, sde_lib.VPSDE): return self.vpsde_update_fn(x, t) @register_predictor(name='none') class NonePredictor(Predictor): """An empty predictor that does nothing.""" def __init__(self, sde, score_fn, probability_flow=False): pass def update_fn(self, x, t): return x, x @register_corrector(name='langevin') class LangevinCorrector(Corrector): def __init__(self, sde, score_fn, snr, n_steps): super().__init__(sde, score_fn, snr, n_steps) if not isinstance(sde, sde_lib.VPSDE) \ and not isinstance(sde, sde_lib.VESDE) \ and not isinstance(sde, sde_lib.subVPSDE): raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") def update_fn(self, x, t): sde = self.sde score_fn = self.score_fn n_steps = self.n_steps target_snr = self.snr if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE): timestep = (t * (sde.N - 1) / sde.T).long() alpha = sde.alphas.to(t.device)[timestep] else: alpha = torch.ones_like(t) for i in range(n_steps): grad = score_fn(x, t) noise = torch.randn_like(x) grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha x_mean = x + step_size[:, None, None, None] * grad x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise return x, x_mean class LangevinCorrectorCS(Corrector): """ Modified Langevin Corrector to solve for p(x|y) """ def __init__(self, sde, score_fn, snr, n_steps, sigma_min, sigma_max, N): super().__init__(sde, score_fn, snr, n_steps) self.N = N self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), N)) if not isinstance(sde, sde_lib.VESDE): raise 
NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") def update_fn(self, x, t, y, discrete_sigmas): """ Args: x: current estimate x_i t: current time step y: measurement in the image domain discrete_sigmas: list of values of \sigma that are indexable with t """ sde = self.sde score_fn = self.score_fn n_steps = self.n_steps target_snr = self.snr if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE): timestep = (t * (sde.N - 1) / sde.T).long() alpha = sde.alphas.to(t.device)[timestep] else: alpha = torch.ones_like(t) for i in range(n_steps): timestep = (t * (self.N - 1) / 1).long() sigma = self.discrete_sigmas.to(t.device)[timestep] grad = score_fn(x, t) grad_likelihood = (x - y) / (sigma[0] ** 2) noise = torch.randn_like(x) grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha x_mean = x + step_size[:, None, None, None] * (grad + grad_likelihood) x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise return x, x_mean @register_corrector(name='ald') class AnnealedLangevinDynamics(Corrector): """The original annealed Langevin dynamics predictor in NCSN/NCSNv2. We include this corrector only for completeness. It was not directly used in our paper. 
""" def __init__(self, sde, score_fn, snr, n_steps): super().__init__(sde, score_fn, snr, n_steps) if not isinstance(sde, sde_lib.VPSDE) \ and not isinstance(sde, sde_lib.VESDE) \ and not isinstance(sde, sde_lib.subVPSDE): raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") def update_fn(self, x, t): sde = self.sde score_fn = self.score_fn n_steps = self.n_steps target_snr = self.snr if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE): timestep = (t * (sde.N - 1) / sde.T).long() alpha = sde.alphas.to(t.device)[timestep] else: alpha = torch.ones_like(t) std = self.sde.marginal_prob(x, t)[1] for i in range(n_steps): grad = score_fn(x, t) noise = torch.randn_like(x) step_size = (target_snr * std) ** 2 * 2 * alpha x_mean = x + step_size[:, None, None, None] * grad x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None] return x, x_mean @register_corrector(name='none') class NoneCorrector(Corrector): """An empty corrector that does nothing.""" def __init__(self, sde, score_fn, snr, n_steps): pass def update_fn(self, x, t): return x, x def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous): """A wrapper that configures and returns the update function of predictors.""" score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous) if predictor is None: # Corrector-only sampler predictor_obj = NonePredictor(sde, score_fn, probability_flow) else: predictor_obj = predictor(sde, score_fn, probability_flow) return predictor_obj.update_fn(x, t) def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps, cs=False, sigma_min=None, sigma_max=None, N=None, y=None, discrete_sigmas=None): """A wrapper tha configures and returns the update function of correctors.""" score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous) if corrector is None: # Predictor-only sampler corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps) 
fn = corrector_obj.update_fn(x, t) else: if cs: corrector_obj = corrector(sde, score_fn, snr, n_steps, sigma_min, sigma_max, N) fn = corrector_obj.update_fn(x, t, y, discrete_sigmas) else: corrector_obj = corrector(sde, score_fn, snr, n_steps) fn = corrector_obj.update_fn(x, t) return fn def get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr, n_steps=1, probability_flow=False, continuous=False, denoise=True, eps=1e-3, device='cuda'): """Create a Predictor-Corrector (PC) sampler. Args: sde: An `sde_lib.SDE` object representing the forward SDE. shape: A sequence of integers. The expected shape of a single sample. predictor: A subclass of `sampling.Predictor` representing the predictor algorithm. corrector: A subclass of `sampling.Corrector` representing the corrector algorithm. inverse_scaler: The inverse data normalizer. snr: A `float` number. The signal-to-noise ratio for configuring correctors. n_steps: An integer. The number of corrector steps per predictor update. probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor. continuous: `True` indicates that the score model was continuously trained. denoise: If `True`, add one-step denoising to the final samples. eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues. device: PyTorch device. Returns: A sampling function that returns samples and the number of function evaluations during sampling. """ # Create predictor & corrector update functions predictor_update_fn = functools.partial(shared_predictor_update_fn, sde=sde, predictor=predictor, probability_flow=probability_flow, continuous=continuous) corrector_update_fn = functools.partial(shared_corrector_update_fn, sde=sde, corrector=corrector, continuous=continuous, snr=snr, n_steps=n_steps) def pc_sampler(model): """ The PC sampler funciton. Args: model: A score model. Returns: Samples, number of function evaluations. 
""" with torch.no_grad(): # Initial sample x = sde.prior_sampling(shape).to(device) timesteps = torch.linspace(sde.T, eps, sde.N, device=device) time_corrector_tot = 0 time_predictor_tot = 0 for i in range(sde.N): t = timesteps[i] vec_t = torch.ones(shape[0], device=t.device) * t tic_corrector = time.time() x, x_mean = corrector_update_fn(x, vec_t, model=model) time_corrector_tot += time.time() - tic_corrector tic_predictor = time.time() x, x_mean = predictor_update_fn(x, vec_t, model=model) time_predictor_tot += time.time() - tic_predictor print(f'Average time for corrector step: {time_corrector_tot / sde.N} sec.') print(f'Average time for predictor step: {time_predictor_tot / sde.N} sec.') return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1) return pc_sampler def get_ode_sampler(sde, shape, inverse_scaler, denoise=False, rtol=1e-5, atol=1e-5, method='RK45', eps=1e-3, device='cuda'): """Probability flow ODE sampler with the black-box ODE solver. Args: sde: An `sde_lib.SDE` object that represents the forward SDE. shape: A sequence of integers. The expected shape of a single sample. inverse_scaler: The inverse data normalizer. denoise: If `True`, add one-step denoising to final samples. rtol: A `float` number. The relative tolerance level of the ODE solver. atol: A `float` number. The absolute tolerance level of the ODE solver. method: A `str`. The algorithm used for the black-box ODE solver. See the documentation of `scipy.integrate.solve_ivp`. eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability. device: PyTorch device. Returns: A sampling function that returns samples and the number of function evaluations during sampling. 
""" def denoise_update_fn(model, x): score_fn = get_score_fn(sde, model, train=False, continuous=True) # Reverse diffusion predictor for denoising predictor_obj = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False) vec_eps = torch.ones(x.shape[0], device=x.device) * eps _, x = predictor_obj.update_fn(x, vec_eps) return x def drift_fn(model, x, t): """Get the drift function of the reverse-time SDE.""" score_fn = get_score_fn(sde, model, train=False, continuous=True) rsde = sde.reverse(score_fn, probability_flow=True) return rsde.sde(x, t)[0] # returns only the drift term because diffusion = 0 for probability_flow def ode_sampler(model, z=None): """The probability flow ODE sampler with black-box ODE solver. Args: model: A score model. z: If present, generate samples from latent code `z`. Returns: samples, number of function evaluations. """ with torch.no_grad(): # Initial sample if z is None: # If not represent, sample the latent code from the prior distibution of the SDE. x = sde.prior_sampling(shape).to(device) else: x = z def ode_func(t, x): x = from_flattened_numpy(x, shape).to(device).type(torch.float32) vec_t = torch.ones(shape[0], device=x.device) * t drift = drift_fn(model, x, vec_t) return to_flattened_numpy(drift) # Black-box ODE solver for the probability flow ODE solution = integrate.solve_ivp(ode_func, (sde.T, eps), to_flattened_numpy(x), rtol=rtol, atol=atol, method=method) nfe = solution.nfev x = torch.tensor(solution.y[:, -1]).reshape(shape).to(device).type(torch.float32) # Denoising is equivalent to running one predictor step without adding noise if denoise: x = denoise_update_fn(model, x) x = inverse_scaler(x) return x, nfe return ode_sampler ================================================ FILE: sde_lib.py ================================================ """Abstract SDE classes, Reverse SDE, and VE/VP SDEs.""" import abc import torch import numpy as np class SDE(abc.ABC): """SDE abstract class. 
Functions are designed for a mini-batch of inputs.""" def __init__(self, N): """Construct an SDE. Args: N: number of discretization time steps. """ super().__init__() self.N = N @property @abc.abstractmethod def T(self): """End time of the SDE.""" pass @abc.abstractmethod def sde(self, x, t): pass @abc.abstractmethod def marginal_prob(self, x, t): """Parameters to determine the marginal distribution of the SDE, $p_t(x)$.""" pass @abc.abstractmethod def prior_sampling(self, shape): """Generate one sample from the prior distribution, $p_T(x)$.""" pass @abc.abstractmethod def prior_logp(self, z): """Compute log-density of the prior distribution. Useful for computing the log-likelihood via probability flow ODE. Args: z: latent code Returns: log probability density """ pass def discretize(self, x, t): """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i. Useful for reverse diffusion sampling and probabiliy flow sampling. Defaults to Euler-Maruyama discretization. Args: x: a torch tensor t: a torch float representing the time step (from 0 to `self.T`) Returns: f, G """ dt = 1 / self.N drift, diffusion = self.sde(x, t) f = drift * dt G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device)) return f, G def reverse(self, score_fn, probability_flow=False): """Create the reverse-time SDE/ODE. Args: score_fn: A time-dependent score-based model that takes x and t and returns the score. probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling. """ N = self.N T = self.T sde_fn = self.sde discretize_fn = self.discretize # Build the class for reverse-time SDE. 
class RSDE(self.__class__): def __init__(self): self.N = N self.probability_flow = probability_flow @property def T(self): return T def sde(self, x, t): """Create the drift and diffusion functions for the reverse SDE/ODE.""" drift, diffusion = sde_fn(x, t) score = score_fn(x, t) drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.) # Set the diffusion function to zero for ODEs. diffusion = 0. if self.probability_flow else diffusion return drift, diffusion def discretize(self, x, t): """Create discretized iteration rules for the reverse diffusion sampler.""" f, G = discretize_fn(x, t) rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.) rev_G = torch.zeros_like(G) if self.probability_flow else G return rev_f, rev_G return RSDE() class VPSDE(SDE): def __init__(self, beta_min=0.1, beta_max=20, N=1000): """Construct a Variance Preserving SDE. Args: beta_min: value of beta(0) beta_max: value of beta(1) N: number of discretization steps """ super().__init__(N) self.beta_0 = beta_min self.beta_1 = beta_max self.N = N self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N) self.alphas = 1. - self.discrete_betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod) @property def T(self): return 1 def sde(self, x, t): beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) drift = -0.5 * beta_t[:, None, None, None] * x diffusion = torch.sqrt(beta_t) return drift, diffusion def marginal_prob(self, x, t): log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 mean = torch.exp(log_mean_coeff[:, None, None, None]) * x std = torch.sqrt(1. - torch.exp(2. 
* log_mean_coeff)) return mean, std def prior_sampling(self, shape): return torch.randn(*shape) def prior_logp(self, z): shape = z.shape N = np.prod(shape[1:]) logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2. return logps def discretize(self, x, t): """DDPM discretization.""" timestep = (t * (self.N - 1) / self.T).long() beta = self.discrete_betas.to(x.device)[timestep] alpha = self.alphas.to(x.device)[timestep] sqrt_beta = torch.sqrt(beta) f = torch.sqrt(alpha)[:, None, None, None] * x - x G = sqrt_beta return f, G class subVPSDE(SDE): def __init__(self, beta_min=0.1, beta_max=20, N=1000): """Construct the sub-VP SDE that excels at likelihoods. Args: beta_min: value of beta(0) beta_max: value of beta(1) N: number of discretization steps """ super().__init__(N) self.beta_0 = beta_min self.beta_1 = beta_max self.N = N @property def T(self): return 1 def sde(self, x, t): beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) drift = -0.5 * beta_t[:, None, None, None] * x discount = 1. - torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2) diffusion = torch.sqrt(beta_t * discount) return drift, diffusion def marginal_prob(self, x, t): log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 mean = torch.exp(log_mean_coeff)[:, None, None, None] * x std = 1 - torch.exp(2. * log_mean_coeff) return mean, std def prior_sampling(self, shape): return torch.randn(*shape) def prior_logp(self, z): shape = z.shape N = np.prod(shape[1:]) return -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2. class VESDE(SDE): def __init__(self, sigma_min=0.01, sigma_max=50, N=1000): """Construct a Variance Exploding SDE. Args: sigma_min: smallest sigma. sigma_max: largest sigma. 
N: number of discretization steps """ super().__init__(N) self.sigma_min = sigma_min self.sigma_max = sigma_max self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) self.N = N @property def T(self): return 1 def sde(self, x, t): sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t drift = torch.zeros_like(x) diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), device=t.device)) return drift, diffusion def marginal_prob(self, x, t): std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t mean = x return mean, std def prior_sampling(self, shape): return torch.randn(*shape) * self.sigma_max def prior_logp(self, z): shape = z.shape N = np.prod(shape[1:]) return -N / 2. * np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2) def discretize(self, x, t): """SMLD(NCSN) discretization.""" timestep = (t * (self.N - 1) / self.T).long() sigma = self.discrete_sigmas.to(t.device)[timestep] adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device)) f = torch.zeros_like(x) G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2) return f, G ================================================ FILE: test/test_TV.py ================================================ """ python -m pytest """ import sys import pytest import torch import matplotlib.pyplot as plt import skimage import controllable_generation_TV as TV @pytest.mark.parametrize( ["A", "AT"], [ [TV._Dz, TV._DzT], [TV._Dx, TV._DxT], [TV._Dy, TV._DyT], ] ) def test_adjoint(A, AT): x = torch.randn(10, 10, 10, 10) y = torch.randn(10, 10, 10, 10) torch.testing.assert_allclose( torch.dot(A(x).ravel(), y.ravel()), torch.dot(x.ravel(), AT(y).ravel()) ) def test_prox_l21(): prox_val = .75 Dx = torch.randn(1, 1, 1, 1) Dy = torch.randn(1, 1, 1, 1) Dz = torch.randn(1, 1, 1, 1) Dq = torch.cat((Dx, Dy, Dz), dim=1) Dq_norm = 
torch.linalg.norm(Dq) Dq_prox = TV.prox_l21(Dq, prox_val, dim=1) Dq_prox_norm = torch.linalg.norm(Dq_prox) torch.testing.assert_allclose( max(Dq_norm, 0) - prox_val, Dq_prox_norm, ) torch.testing.assert_allclose( Dq / Dq_norm, Dq_prox / Dq_prox_norm, ) class Identity: @staticmethod def A(x): return x @staticmethod def AT(y): return y def test_ADMM_TV_isotropic(): x_gt = skimage.data.astronaut().mean(axis=2) / 255 x_gt = torch.tensor(x_gt).reshape((1, 1) + x_gt.shape) y = x_gt + 0.5 * torch.randn_like(x_gt) x0 = torch.zeros_like(y) ADMM_TV = TV.get_ADMM_TV_isotropic( radon=Identity(), img_shape=y.shape, lamb_1 = 1e0, rho=1e2) x_recon = ADMM_TV(x0, y) args = dict(vmin=-0.2, vmax=1.2) fig, ax = plt.subplots() im = ax.imshow(x_gt.squeeze(), **args) fig.colorbar(im) fig.savefig('x_gt.png') fig, ax = plt.subplots() im = ax.imshow(y.squeeze(), **args) fig.colorbar(im) fig.savefig('y.png') fig, ax = plt.subplots() im = ax.imshow(x_recon.squeeze(), **args) fig.colorbar(im) fig.savefig('x_recon.png') ================================================ FILE: train_AAPM256.sh ================================================ #!/bin/bash python main.py \ --config=configs/ve/AAPM_256_ncsnpp_continuous.py \ --eval_folder=eval/AAPM256 \ --mode='train' \ --workdir=workdir/AAPM256 ================================================ FILE: utils.py ================================================ from pathlib import Path import torch import os import logging import matplotlib.pyplot as plt import numpy as np from fastmri_utils import fft2c_new, ifft2c_new from statistics import mean, stdev from skimage.metrics import peak_signal_noise_ratio, structural_similarity from sporco.metric import gmsd, mse from scipy.ndimage import gaussian_laplace import functools def clear_color(x): x = x.detach().cpu().squeeze().numpy() return np.transpose(x, (1, 2, 0)) def clear(x, normalize=True): x = x.detach().cpu().squeeze().numpy() if normalize: x = normalize_np(x) return x def restore_checkpoint(ckpt_dir, 
state, device, skip_sigma=False, skip_optimizer=False): ckpt_dir = Path(ckpt_dir) # import ipdb; ipdb.set_trace() # ckpt = ckpt_dir / "checkpoint.pth" if not ckpt_dir.exists(): logging.warning(f"No checkpoint found at {ckpt_dir}. " f"Returned the same state as input") return state else: loaded_state = torch.load(ckpt_dir, map_location=device) if not skip_optimizer: state['optimizer'].load_state_dict(loaded_state['optimizer']) loaded_model_state = loaded_state['model'] if skip_sigma: loaded_model_state.pop('module.sigmas') state['model'].load_state_dict(loaded_model_state, strict=False) state['ema'].load_state_dict(loaded_state['ema']) state['step'] = loaded_state['step'] print(f'loaded checkpoint dir from {ckpt_dir}') return state def save_checkpoint(ckpt_dir, state, name="checkpoint.pth"): ckpt_dir = Path(ckpt_dir) saved_state = { 'optimizer': state['optimizer'].state_dict(), 'model': state['model'].state_dict(), 'ema': state['ema'].state_dict(), 'step': state['step'] } torch.save(saved_state, ckpt_dir / name) """ Helper functions for new types of inverse problems """ def fft2(x): """ FFT with shifting DC to the center of the image""" return torch.fft.fftshift(torch.fft.fft2(x), dim=[-1, -2]) def ifft2(x): """ IFFT with shifting DC to the corner of the image prior to transform""" return torch.fft.ifft2(torch.fft.ifftshift(x, dim=[-1, -2])) def fft2_m(x): """ FFT for multi-coil """ return torch.view_as_complex(fft2c_new(torch.view_as_real(x))) def ifft2_m(x): """ IFFT for multi-coil """ return torch.view_as_complex(ifft2c_new(torch.view_as_real(x))) def crop_center(img, cropx, cropy): c, y, x = img.shape startx = x // 2 - (cropx // 2) starty = y // 2 - (cropy // 2) return img[:, starty:starty + cropy, startx:startx + cropx] def normalize(img): """ Normalize img in arbitrary range to [0, 1] """ img -= torch.min(img) img /= torch.max(img) return img def normalize_np(img): """ Normalize img in arbitrary range to [0, 1] """ img -= np.min(img) img /= np.max(img) return 
img def normalize_np_kwarg(img, maxv=1.0, minv=0.0): """ Normalize img in arbitrary range to [0, 1] """ img -= minv img /= maxv return img def normalize_complex(img): """ normalizes the magnitude of complex-valued image to range [0, 1] """ abs_img = normalize(torch.abs(img)) # ang_img = torch.angle(img) ang_img = normalize(torch.angle(img)) return abs_img * torch.exp(1j * ang_img) def batchfy(tensor, batch_size): n = len(tensor) num_batches = n // batch_size + 1 return tensor.chunk(num_batches, dim=0) def img_wise_min_max(img): img_flatten = img.view(img.shape[0], -1) img_min = torch.min(img_flatten, dim=-1)[0].view(-1, 1, 1, 1) img_max = torch.max(img_flatten, dim=-1)[0].view(-1, 1, 1, 1) return (img - img_min) / (img_max - img_min) def patient_wise_min_max(img): std_upper = 3 img_flatten = img.view(img.shape[0], -1) std = torch.std(img) mean = torch.mean(img) img_min = torch.min(img_flatten, dim=-1)[0].view(-1, 1, 1, 1) img_max = torch.max(img_flatten, dim=-1)[0].view(-1, 1, 1, 1) min_max_scaled = (img - img_min) / (img_max - img_min) min_max_scaled_std = (std - img_min) / (img_max - img_min) min_max_scaled_mean = (mean - img_min) / (img_max - img_min) min_max_scaled[min_max_scaled > min_max_scaled_mean + std_upper * min_max_scaled_std] = 1 return min_max_scaled def create_sphere(cx, cy, cz, r, resolution=256): ''' create sphere with center (cx, cy, cz) and radius r ''' phi = np.linspace(0, 2 * np.pi, 2 * resolution) theta = np.linspace(0, np.pi, resolution) theta, phi = np.meshgrid(theta, phi) r_xy = r * np.sin(theta) x = cx + np.cos(phi) * r_xy y = cy + np.sin(phi) * r_xy z = cz + r * np.cos(theta) return np.stack([x, y, z]) class lambda_schedule: def __init__(self, total=2000): self.total = total def get_current_lambda(self, i): pass class lambda_schedule_linear(lambda_schedule): def __init__(self, start_lamb=1.0, end_lamb=0.0): super().__init__() self.start_lamb = start_lamb self.end_lamb = end_lamb def get_current_lambda(self, i): return self.start_lamb + 
(self.end_lamb - self.start_lamb) * (i / self.total) class lambda_schedule_const(lambda_schedule): def __init__(self, lamb=1.0): super().__init__() self.lamb = lamb def get_current_lambda(self, i): return self.lamb def image_grid(x, sz=32): size = sz channels = 3 img = x.reshape(-1, size, size, channels) w = int(np.sqrt(img.shape[0])) img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels)) return img def show_samples(x, sz=32): x = x.permute(0, 2, 3, 1).detach().cpu().numpy() img = image_grid(x, sz) plt.figure(figsize=(8, 8)) plt.axis('off') plt.imshow(img) plt.show() def image_grid_gray(x, size=32): img = x.reshape(-1, size, size) w = int(np.sqrt(img.shape[0])) img = img.reshape((w, w, size, size)).transpose((0, 2, 1, 3)).reshape((w * size, w * size)) return img def show_samples_gray(x, size=32, save=False, save_fname=None): x = x.detach().cpu().numpy() img = image_grid_gray(x, size=size) plt.figure(figsize=(8, 8)) plt.axis('off') plt.imshow(img, cmap='gray') plt.show() if save: plt.imsave(save_fname, img, cmap='gray') def get_mask(img, size, batch_size, type='gaussian2d', acc_factor=8, center_fraction=0.04, fix=False): mux_in = size ** 2 if type.endswith('2d'): Nsamp = mux_in // acc_factor elif type.endswith('1d'): Nsamp = size // acc_factor if type == 'gaussian2d': mask = torch.zeros_like(img) cov_factor = size * (1.5 / 128) mean = [size // 2, size // 2] cov = [[size * cov_factor, 0], [0, size * cov_factor]] if fix: samples = np.random.multivariate_normal(mean, cov, int(Nsamp)) int_samples = samples.astype(int) int_samples = np.clip(int_samples, 0, size - 1) mask[..., int_samples[:, 0], int_samples[:, 1]] = 1 else: for i in range(batch_size): # sample different masks for batch samples = np.random.multivariate_normal(mean, cov, int(Nsamp)) int_samples = samples.astype(int) int_samples = np.clip(int_samples, 0, size - 1) mask[i, :, int_samples[:, 0], int_samples[:, 1]] = 1 elif type == 'uniformrandom2d': 
mask = torch.zeros_like(img) if fix: mask_vec = torch.zeros([1, size * size]) samples = np.random.choice(size * size, int(Nsamp)) mask_vec[:, samples] = 1 mask_b = mask_vec.view(size, size) mask[:, ...] = mask_b else: for i in range(batch_size): # sample different masks for batch mask_vec = torch.zeros([1, size * size]) samples = np.random.choice(size * size, int(Nsamp)) mask_vec[:, samples] = 1 mask_b = mask_vec.view(size, size) mask[i, ...] = mask_b elif type == 'gaussian1d': mask = torch.zeros_like(img) mean = size // 2 std = size * (15.0 / 128) Nsamp_center = int(size * center_fraction) if fix: samples = np.random.normal(loc=mean, scale=std, size=int(Nsamp * 1.2)) int_samples = samples.astype(int) int_samples = np.clip(int_samples, 0, size - 1) mask[... , int_samples] = 1 c_from = size // 2 - Nsamp_center // 2 mask[... , c_from:c_from + Nsamp_center] = 1 else: for i in range(batch_size): samples = np.random.normal(loc=mean, scale=std, size=int(Nsamp*1.2)) int_samples = samples.astype(int) int_samples = np.clip(int_samples, 0, size - 1) mask[i, :, :, int_samples] = 1 c_from = size // 2 - Nsamp_center // 2 mask[i, :, :, c_from:c_from + Nsamp_center] = 1 elif type == 'uniform1d': mask = torch.zeros_like(img) if fix: Nsamp_center = int(size * center_fraction) samples = np.random.choice(size, int(Nsamp - Nsamp_center)) mask[..., samples] = 1 # ACS region c_from = size // 2 - Nsamp_center // 2 mask[..., c_from:c_from + Nsamp_center] = 1 else: for i in range(batch_size): Nsamp_center = int(size * center_fraction) samples = np.random.choice(size, int(Nsamp - Nsamp_center)) mask[i, :, :, samples] = 1 # ACS region c_from = size // 2 - Nsamp_center // 2 mask[i, :, :, c_from:c_from+Nsamp_center] = 1 else: NotImplementedError(f'Mask type {type} is currently not supported.') return mask def kspace_to_nchw(tensor): """ Convert torch tensor in (Slice, Coil, Height, Width, Complex) 5D format to (N, C, H, W) 4D format for processing by 2D CNNs. 
Complex indicates (real, imag) as 2 channels, the complex data format for Pytorch. C is the coils interleaved with real and imaginary values as separate channels. C is therefore always 2 * Coil. Singlecoil data is assumed to be in the 5D format with Coil = 1 Args: tensor (torch.Tensor): Input data in 5D kspace tensor format. Returns: tensor (torch.Tensor): tensor in 4D NCHW format to be fed into a CNN. """ assert isinstance(tensor, torch.Tensor) assert tensor.dim() == 5 s = tensor.shape assert s[-1] == 2 tensor = tensor.permute(dims=(0, 1, 4, 2, 3)).reshape(shape=(s[0], 2 * s[1], s[2], s[3])) return tensor def nchw_to_kspace(tensor): """ Convert a torch tensor in (N, C, H, W) format to the (Slice, Coil, Height, Width, Complex) format. This function assumes that the real and imaginary values of a coil are always adjacent to one another in C. If the coil dimension is not divisible by 2, the function assumes that the input data is 'real' data, and thus pads the imaginary dimension as 0. """ assert isinstance(tensor, torch.Tensor) assert tensor.dim() == 4 s = tensor.shape if tensor.shape[1] == 1: imag_tensor = torch.zeros(s, device=tensor.device) tensor = torch.cat((tensor, imag_tensor), dim=1) s = tensor.shape tensor = tensor.view(size=(s[0], s[1] // 2, 2, s[2], s[3])).permute(dims=(0, 1, 3, 4, 2)) return tensor def root_sum_of_squares(data, dim=0): """ Compute the Root Sum of Squares (RSS) transform along a given dimension of a tensor. 
Args: data (torch.Tensor): The input tensor dim (int): The dimensions along which to apply the RSS transform Returns: torch.Tensor: The RSS value """ return torch.sqrt((data ** 2).sum(dim)) def save_data(fname, arr): """ Save data as .npy and .png """ np.save(fname + '.npy', arr) plt.imsave(fname + '.png', arr, cmap='gray') def mean_std(vals: list): return mean(vals), stdev(vals) def cal_metric(comp, label): LoG = functools.partial(gaussian_laplace, sigma=1.5) psnr_val = peak_signal_noise_ratio(comp, label) ssim_val = structural_similarity(comp, label) hfen_val = mse(LoG(comp), LoG(label)) gmsd_val = gmsd(label, comp) return psnr_val, ssim_val, hfen_val, gmsd_val