Repository: HJ-harry/DiffusionMBIR
Branch: main
Commit: bdfe460582ba
Files: 97
Total size: 380.5 KB

Directory structure:
DiffusionMBIR/

├── .gitignore
├── LICENSE
├── README.md
├── configs/
│   ├── default_celeba_configs.py
│   ├── default_cifar10_configs.py
│   ├── default_complex_configs.py
│   ├── default_lsun_configs.py
│   ├── subvp/
│   │   ├── cifar10_ddpm_continuous.py
│   │   ├── cifar10_ddpmpp_continuous.py
│   │   ├── cifar10_ddpmpp_deep_continuous.py
│   │   ├── cifar10_ncsnpp_continuous.py
│   │   └── cifar10_ncsnpp_deep_continuous.py
│   ├── ve/
│   │   ├── AAPM_128_ncsnpp_continuous.py
│   │   ├── AAPM_256_ncsnpp_continuous.py
│   │   ├── Object5_fast.py
│   │   ├── Object5_ncsnpp_continuous.py
│   │   ├── bedroom_ncsnpp_continuous.py
│   │   ├── celeba_ncsnpp.py
│   │   ├── celebahq_256_ncsnpp_continuous.py
│   │   ├── celebahq_ncsnpp_continuous.py
│   │   ├── church_ncsnpp_continuous.py
│   │   ├── cifar10_ddpm.py
│   │   ├── cifar10_ncsnpp.py
│   │   ├── cifar10_ncsnpp_continuous.py
│   │   ├── cifar10_ncsnpp_deep_continuous.py
│   │   ├── fastmri_knee_128_ncsnpp_continuous.py
│   │   ├── fastmri_knee_256_ncsnpp_continuous.py
│   │   ├── fastmri_knee_320_ncsnpp_continuous.py
│   │   ├── fastmri_knee_320_ncsnpp_continuous_complex.py
│   │   ├── fastmri_knee_320_ncsnpp_continuous_complex_magpha.py
│   │   ├── fastmri_knee_320_ncsnpp_continuous_multi.py
│   │   ├── ffhq_256_ncsnpp_continuous.py
│   │   ├── ffhq_ncsnpp_continuous.py
│   │   ├── ncsn/
│   │   │   ├── celeba.py
│   │   │   ├── celeba_124.py
│   │   │   ├── celeba_1245.py
│   │   │   ├── celeba_5.py
│   │   │   ├── cifar10.py
│   │   │   ├── cifar10_124.py
│   │   │   ├── cifar10_1245.py
│   │   │   └── cifar10_5.py
│   │   └── ncsnv2/
│   │       ├── bedroom.py
│   │       ├── celeba.py
│   │       └── cifar10.py
│   └── vp/
│       ├── cifar10_ddpmpp.py
│       ├── cifar10_ddpmpp_continuous.py
│       ├── cifar10_ddpmpp_deep_continuous.py
│       ├── cifar10_ncsnpp.py
│       ├── cifar10_ncsnpp_continuous.py
│       ├── cifar10_ncsnpp_deep_continuous.py
│       └── ddpm/
│           ├── bedroom.py
│           ├── celebahq.py
│           ├── church.py
│           ├── cifar10.py
│           ├── cifar10_continuous.py
│           └── cifar10_unconditional.py
├── controllable_generation_TV.py
├── datasets.py
├── environment.yml
├── evaluation.py
├── fastmri_utils.py
├── inverse_problem_solver_AAPM_3d_total.py
├── inverse_problem_solver_BRATS_MRI_3d_total.py
├── likelihood.py
├── losses.py
├── main.py
├── models/
│   ├── __init__.py
│   ├── ddpm.py
│   ├── ema.py
│   ├── layers.py
│   ├── layerspp.py
│   ├── ncsnpp.py
│   ├── ncsnv2.py
│   ├── normalization.py
│   ├── unet.py
│   ├── up_or_down_sampling.py
│   └── utils.py
├── op/
│   ├── __init__.py
│   ├── fused_act.py
│   ├── fused_bias_act.cpp
│   ├── fused_bias_act_kernel.cu
│   ├── upfirdn2d.cpp
│   ├── upfirdn2d.py
│   └── upfirdn2d_kernel.cu
├── physics/
│   ├── ct.py
│   ├── inpainting.py
│   └── radon/
│       ├── __init__.py
│       ├── filters.py
│       ├── radon.py
│       ├── stackgram.py
│       └── utils.py
├── run_lib.py
├── sampling.py
├── sde_lib.py
├── test/
│   └── test_TV.py
├── train_AAPM256.sh
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Compiled source #
###################
*.o
*.so
*.pyc

# Logs and temporaries #
########################
*.log
*~
.coverage

# Folders #
###########
build/
dist/
*.egg-info/
__pycache__/
.eggs/

data/
exp/
results/
results_AAPM/
results_AAPM_tv/
workdir/

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# Solving 3D Inverse Problems using Pre-trained 2D Diffusion Models (CVPR 2023)

Official PyTorch implementation of **DiffusionMBIR**, the CVPR 2023 paper "[Solving 3D Inverse Problems using Pre-trained 2D Diffusion Models](https://arxiv.org/abs/2211.10655)". Code modified from [score_sde_pytorch](https://github.com/yang-song/score_sde_pytorch).

✅ For an updated, faster version of DiffusionMBIR, see [DDS](https://github.com/hyungjin-chung/DDS).

[![arXiv](https://img.shields.io/badge/arXiv-2211.10655-green)](https://arxiv.org/abs/2211.10655)
[![paper](https://img.shields.io/badge/paper-CVPR2023-blue)](https://arxiv.org/abs/2211.10655)
![concept](./figs/forward_model.jpg)
![concept](./figs/cover_result.jpg)

## Getting started

### Download pre-trained model weights
* **CT** experiments: [weights](https://drive.google.com/file/d/1-TaLbg3-4gLwKH2-Qf5VBFCBLG3RjY9j/view)

### Download the data
* **CT** experiments (in-distribution)
```bash
DATA_DIR=./data/CT/ind/256_sorted
mkdir -p "$DATA_DIR"
wget -O "$DATA_DIR"/256_sorted.zip https://www.dropbox.com/sh/ibjpgo5seksjera/AADlhYqCWq5C4K0uWSrCL_JUa?dl=1
unzip -d "$DATA_DIR"/ "$DATA_DIR"/256_sorted.zip
```
* **CT** experiments (out-of-distribution)
```bash
DATA_DIR=./data/CT/ood/256_sorted
mkdir -p "$DATA_DIR"
wget -O "$DATA_DIR"/slice.zip "https://www.dropbox.com/s/h3drrlx0pvutyoi/slice.zip?dl=1"
unzip -d "$DATA_DIR"/ "$DATA_DIR"/slice.zip
```

### Set up the environment
Create a conda environment and install the dependencies:
```bash
conda env create --file environment.yml
```

## DiffusionMBIR (fast) reconstruction
Once the pre-trained weights and the test data are in place, run the following scripts. To change the experimental settings, edit the parameters directly in the Python scripts.

```bash
conda activate diffusion-mbir
python inverse_problem_solver_AAPM_3d_total.py
python inverse_problem_solver_BRATS_MRI_3d_total.py
```

## Training
You can train the diffusion model on your own data by running, e.g.,
```bash
bash train_AAPM256.sh
```
You can change the training configuration with the `--config` flag.
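The configs under `configs/` follow score_sde_pytorch, so for the VE configs (`configs/ve/`) the noise range is set by `model.sigma_min`, `model.sigma_max` and `model.num_scales`. The snippet below is a minimal sketch, assuming `sde_lib.py` keeps score_sde_pytorch's VE SDE (that file is not shown here): the continuous schedule is geometric, `sigma(t) = sigma_min * (sigma_max / sigma_min) ** t` for `t` in `[0, 1]`, and the discrete levels are `num_scales` values evenly spaced in log space between the two bounds. The helper names are hypothetical, and the numbers are the defaults from `configs/default_lsun_configs.py`.

```python
import math
from typing import List


def ve_sigma(t: float, sigma_min: float, sigma_max: float) -> float:
    """Noise std of the VE SDE at time t in [0, 1] (geometric schedule)."""
    return sigma_min * (sigma_max / sigma_min) ** t


def ve_diffusion(t: float, sigma_min: float, sigma_max: float) -> float:
    """Diffusion coefficient g(t) = sqrt(d[sigma(t)^2]/dt) of the drift-free VE SDE."""
    return ve_sigma(t, sigma_min, sigma_max) * math.sqrt(2.0 * math.log(sigma_max / sigma_min))


def ve_discrete_sigmas(sigma_min: float, sigma_max: float, num_scales: int) -> List[float]:
    """num_scales noise levels, evenly spaced in log space from sigma_min to sigma_max."""
    if num_scales < 2:
        raise ValueError("num_scales must be at least 2")
    lo, hi = math.log(sigma_min), math.log(sigma_max)
    step = (hi - lo) / (num_scales - 1)
    return [math.exp(lo + i * step) for i in range(num_scales)]


# model.sigma_min / model.sigma_max / model.num_scales from configs/default_lsun_configs.py
SIGMA_MIN, SIGMA_MAX, NUM_SCALES = 0.01, 378.0, 2000

sigmas = ve_discrete_sigmas(SIGMA_MIN, SIGMA_MAX, NUM_SCALES)
print(f"{NUM_SCALES} levels from {sigmas[0]:.3g} to {sigmas[-1]:.4g}")
print(f"sigma(0.5) = {ve_sigma(0.5, SIGMA_MIN, SIGMA_MAX):.4f}")
print(f"g(0.5)     = {ve_diffusion(0.5, SIGMA_MIN, SIGMA_MAX):.4f}")
```

Raising `sigma_max` widens the noise range the reverse process starts from; `num_scales` only changes how finely that range is discretized, not its endpoints.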

## Citation
If you find our work interesting, please consider citing

```bibtex
@InProceedings{chung2023solving,
  title={Solving 3D Inverse Problems using Pre-trained 2D Diffusion Models},
  author={Chung, Hyungjin and Ryu, Dohoon and McCann, Michael T and Klasky, Marc L and Ye, Jong Chul},
  booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2023}
}
```


================================================
FILE: configs/default_celeba_configs.py
================================================
import ml_collections
import torch


def get_default_configs():
  config = ml_collections.ConfigDict()
  # training
  config.training = training = ml_collections.ConfigDict()
  # config.training.batch_size = 128
  config.training.batch_size = 64
  training.n_iters = 1300001
  training.snapshot_freq = 50000
  training.log_freq = 50
  training.eval_freq = 100
  ## store additional checkpoints for preemption in cloud computing environments
  training.snapshot_freq_for_preemption = 10000
  ## produce samples at each snapshot.
  training.snapshot_sampling = True
  training.likelihood_weighting = False
  training.continuous = True
  training.reduce_mean = False

  # sampling
  config.sampling = sampling = ml_collections.ConfigDict()
  sampling.n_steps_each = 1
  sampling.noise_removal = True
  sampling.probability_flow = False
  sampling.snr = 0.17

  # evaluation
  config.eval = evaluate = ml_collections.ConfigDict()
  evaluate.begin_ckpt = 1
  evaluate.end_ckpt = 26
  evaluate.batch_size = 1024
  evaluate.enable_sampling = True
  evaluate.num_samples = 50000
  evaluate.enable_loss = True
  evaluate.enable_bpd = False
  evaluate.bpd_dataset = 'test'

  # data
  config.data = data = ml_collections.ConfigDict()
  data.dataset = 'CELEBA'
  data.image_size = 64
  data.random_flip = True
  data.uniform_dequantization = False
  data.centered = False
  data.num_channels = 3

  # model
  config.model = model = ml_collections.ConfigDict()
  model.sigma_max = 90.
  model.sigma_min = 0.01
  model.num_scales = 1000
  model.beta_min = 0.1
  model.beta_max = 20.
  model.dropout = 0.1
  model.embedding_type = 'fourier'

  # optimization
  config.optim = optim = ml_collections.ConfigDict()
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 2e-4
  optim.beta1 = 0.9
  optim.eps = 1e-8
  optim.warmup = 5000
  optim.grad_clip = 1.

  config.seed = 42
  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

  return config

================================================
FILE: configs/default_cifar10_configs.py
================================================
import ml_collections
import torch


def get_default_configs():
  config = ml_collections.ConfigDict()
  # training
  config.training = training = ml_collections.ConfigDict()
  # config.training.batch_size = 128
  config.training.batch_size = 4
  training.n_iters = 1300001
  training.snapshot_freq = 50000
  training.log_freq = 50
  training.eval_freq = 100
  ## store additional checkpoints for preemption in cloud computing environments
  training.snapshot_freq_for_preemption = 10000
  ## produce samples at each snapshot.
  training.snapshot_sampling = True
  training.likelihood_weighting = False
  training.continuous = True
  training.reduce_mean = False

  # sampling
  config.sampling = sampling = ml_collections.ConfigDict()
  sampling.n_steps_each = 1
  sampling.noise_removal = True
  sampling.probability_flow = False
  sampling.snr = 0.16

  # evaluation
  config.eval = evaluate = ml_collections.ConfigDict()
  evaluate.begin_ckpt = 9
  evaluate.end_ckpt = 26
  evaluate.batch_size = 1024
  evaluate.enable_sampling = False
  evaluate.num_samples = 50000
  evaluate.enable_loss = True
  evaluate.enable_bpd = False
  evaluate.bpd_dataset = 'test'

  # data
  config.data = data = ml_collections.ConfigDict()
  data.dataset = 'CIFAR10'
  data.image_size = 32
  data.random_flip = True
  data.centered = False
  data.uniform_dequantization = False
  data.num_channels = 3
  # data.num_channels = 1

  # model
  config.model = model = ml_collections.ConfigDict()
  model.sigma_min = 0.01
  model.sigma_max = 50
  model.num_scales = 1000
  model.beta_min = 0.1
  model.beta_max = 20.
  model.dropout = 0.1
  model.embedding_type = 'fourier'

  # optimization
  config.optim = optim = ml_collections.ConfigDict()
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 2e-4
  optim.beta1 = 0.9
  optim.eps = 1e-8
  optim.warmup = 5000
  optim.grad_clip = 1.

  config.seed = 42
  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

  return config

================================================
FILE: configs/default_complex_configs.py
================================================
import ml_collections
import torch


def get_default_configs():
  config = ml_collections.ConfigDict()
  # training
  config.training = training = ml_collections.ConfigDict()
  # config.training.batch_size = 64
  # config.training.batch_size = 2  # seriously?
  config.training.batch_size = 1  # When using single GPU
  # training.n_iters = 2400001
  training.epochs = 100
  training.snapshot_freq = 50000
  # training.log_freq = 50
  training.log_freq = 25
  training.eval_freq = 100
  ## store additional checkpoints for preemption in cloud computing environments
  training.snapshot_freq_for_preemption = 5000
  ## produce samples at each snapshot.
  training.snapshot_sampling = True
  training.likelihood_weighting = False
  training.continuous = True
  training.reduce_mean = False

  # sampling
  config.sampling = sampling = ml_collections.ConfigDict()
  sampling.n_steps_each = 1
  sampling.noise_removal = True
  sampling.probability_flow = False
  sampling.snr = 0.075

  # evaluation
  config.eval = evaluate = ml_collections.ConfigDict()
  evaluate.begin_ckpt = 50
  evaluate.end_ckpt = 96
  # evaluate.batch_size = 512
  evaluate.batch_size = 8
  evaluate.enable_sampling = True
  evaluate.num_samples = 50000
  evaluate.enable_loss = True
  evaluate.enable_bpd = False
  evaluate.bpd_dataset = 'test'

  # data
  config.data = data = ml_collections.ConfigDict()
  # data.dataset = 'LSUN'
  data.image_size = 320
  data.random_flip = True
  data.uniform_dequantization = False
  data.centered = False
  data.num_channels = 2

  # model
  config.model = model = ml_collections.ConfigDict()
  model.sigma_max = 378
  model.sigma_min = 0.01
  model.num_scales = 2000
  model.beta_min = 0.1
  model.beta_max = 20.
  model.dropout = 0.
  model.embedding_type = 'fourier'

  # optimization
  config.optim = optim = ml_collections.ConfigDict()
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 2e-4
  optim.beta1 = 0.9
  optim.eps = 1e-8
  optim.warmup = 5000
  optim.grad_clip = 1.

  config.seed = 42
  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

  return config

================================================
FILE: configs/default_lsun_configs.py
================================================
import ml_collections
import torch


def get_default_configs():
  config = ml_collections.ConfigDict()
  # training
  config.training = training = ml_collections.ConfigDict()
  # config.training.batch_size = 64
  # config.training.batch_size = 2  # seriously?
  config.training.batch_size = 1  # When using single GPU
  # training.n_iters = 2400001
  training.epochs = 1000
  training.snapshot_freq = 50000
  # training.log_freq = 50
  training.log_freq = 25
  training.eval_freq = 100
  ## store additional checkpoints for preemption in cloud computing environments
  training.snapshot_freq_for_preemption = 5000
  ## produce samples at each snapshot.
  training.snapshot_sampling = True
  training.likelihood_weighting = False
  training.continuous = True
  training.reduce_mean = False

  # sampling
  config.sampling = sampling = ml_collections.ConfigDict()
  sampling.n_steps_each = 1
  sampling.noise_removal = True
  sampling.probability_flow = False
  sampling.snr = 0.075

  # evaluation
  config.eval = evaluate = ml_collections.ConfigDict()
  evaluate.begin_ckpt = 50
  evaluate.end_ckpt = 96
  # evaluate.batch_size = 512
  evaluate.batch_size = 8
  evaluate.enable_sampling = True
  evaluate.num_samples = 50000
  evaluate.enable_loss = True
  evaluate.enable_bpd = False
  evaluate.bpd_dataset = 'test'

  # data
  config.data = data = ml_collections.ConfigDict()
  data.dataset = 'LSUN'
  data.image_size = 256
  data.random_flip = True
  data.uniform_dequantization = False
  data.centered = False
  # data.num_channels = 3
  data.num_channels = 1

  # model
  config.model = model = ml_collections.ConfigDict()
  model.sigma_max = 378
  model.sigma_min = 0.01
  model.num_scales = 2000
  model.beta_min = 0.1
  model.beta_max = 20.
  model.dropout = 0.
  model.embedding_type = 'fourier'

  # optimization
  config.optim = optim = ml_collections.ConfigDict()
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 2e-4
  optim.beta1 = 0.9
  optim.eps = 1e-8
  optim.warmup = 5000
  optim.grad_clip = 1.

  config.seed = 42
  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

  return config

================================================
FILE: configs/subvp/cifar10_ddpm_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training DDPM with sub-VP SDE."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()

  # training
  training = config.training
  training.sde = 'subvpsde'
  training.continuous = True
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'euler_maruyama'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ddpm'
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True

  return config


================================================
FILE: configs/subvp/cifar10_ddpmpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training DDPM++ on CIFAR-10 with the sub-VP SDE."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'subvpsde'
  training.continuous = True
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'euler_maruyama'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = False
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'none'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.embedding_type = 'positional'
  model.fourier_scale = 16
  model.conv_size = 3

  return config


================================================
FILE: configs/subvp/cifar10_ddpmpp_deep_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training DDPM++ (deep) on CIFAR-10 with the sub-VP SDE."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'subvpsde'
  training.continuous = True
  training.reduce_mean = True
  training.n_iters = 950001

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'euler_maruyama'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 8
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = False
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'none'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.embedding_type = 'positional'
  model.fourier_scale = 16
  model.conv_size = 3

  return config

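The sub-VP configs set `training.sde = 'subvpsde'` with `training.continuous = True`. The following is a minimal sketch of the sub-VP perturbation kernel p_t(x | x0) under a linear beta schedule. The values `beta_min = 0.1` and `beta_max = 20` are assumptions taken from the defaults commonly used with these configs; they are not set anywhere in this file.

```python
import numpy as np


def subvp_marginal(x0, t, beta_min=0.1, beta_max=20.0):
    """Mean and std of p_t(x | x0) for the sub-VP SDE with a linear beta(t).

    beta_min / beta_max are assumed defaults, not values read from these configs.
    """
    # log of the mean coefficient: -1/2 * integral_0^t beta(s) ds
    log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    mean = np.exp(log_mean_coeff) * x0
    # sub-VP std is 1 - exp(-integral beta), i.e. 1 - exp(2 * log_mean_coeff);
    # the VP SDE would use sqrt(1 - exp(2 * log_mean_coeff)) instead.
    std = 1.0 - np.exp(2.0 * log_mean_coeff)
    return mean, std


if __name__ == "__main__":
    x0 = np.ones(4)
    for t in (1e-5, 0.5, 1.0):
        mean, std = subvp_marginal(x0, t)
        print(f"t={t:g}  mean={mean[0]:.5f}  std={std:.5f}")
```

Because `1 - a <= sqrt(1 - a)` for `a` in [0, 1], the sub-VP std never exceeds the VP std at the same `t`. This tighter noise level is the main difference between the two SDEs.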

================================================
FILE: configs/subvp/cifar10_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on CIFAR-10 with sub-VP SDE."""
from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'subvpsde'
  training.continuous = True
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'euler_maruyama'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.embedding_type = 'positional'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config


================================================
FILE: configs/subvp/cifar10_ncsnpp_deep_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ (deep) on CIFAR-10 with sub-VP SDE."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'subvpsde'
  training.continuous = True
  training.n_iters = 950001
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'euler_maruyama'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.fourier_scale = 16
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 8
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.embedding_type = 'positional'
  model.init_scale = 0.0
  model.conv_size = 3

  return config

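This config sets `model.embedding_type = 'positional'`, which conditions the network on time through a sinusoidal embedding rather than random Fourier features. The sketch below shows the Transformer/DDPM-style embedding under stated assumptions. `max_positions = 10000` is assumed, and how the repository scales `t` before embedding is not shown here.

```python
import math

import numpy as np


def timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Sinusoidal embedding of a 1-D array of timesteps (DDPM-style sketch)."""
    assert timesteps.ndim == 1
    half_dim = embedding_dim // 2
    # geometric progression of frequencies from 1 down to 1 / max_positions
    scale = math.log(max_positions) / (half_dim - 1)
    freqs = np.exp(np.arange(half_dim, dtype=np.float64) * -scale)
    args = timesteps.astype(np.float64)[:, None] * freqs[None, :]
    emb = np.concatenate([np.sin(args), np.cos(args)], axis=1)
    if embedding_dim % 2 == 1:
        # zero-pad odd embedding sizes so the output width matches embedding_dim
        emb = np.pad(emb, ((0, 0), (0, 1)))
    return emb


if __name__ == "__main__":
    # nf = 128 in these configs; the embedding width here is illustrative
    print(timestep_embedding(np.array([0, 10, 999]), 128).shape)
```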

================================================
FILE: configs/ve/AAPM_128_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on AAPM CT (128x128) with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'aapm'
  data.root = '/media/harry/tomo/AAPM_data/128'
  data.is_complex = False
  data.is_multi = False
  data.image_size = 128

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

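`model.ch_mult`, `model.nf`, and `model.attn_resolutions` together determine the U-Net layout. The sketch below lists the levels, assuming (as in NCSN++) that each `ch_mult` entry is one resolution level, that spatial size halves between levels, and that self-attention runs at any level whose resolution appears in `attn_resolutions`. The helper name `unet_levels` is hypothetical.

```python
def unet_levels(image_size, nf, ch_mult, attn_resolutions):
    """Return (resolution, channels, has_attention) for each U-Net level.

    Hypothetical helper illustrating how the config fields interact.
    """
    levels = []
    for i, mult in enumerate(ch_mult):
        res = image_size // (2 ** i)  # each level halves the spatial size
        levels.append((res, nf * mult, res in attn_resolutions))
    return levels


if __name__ == "__main__":
    # values from AAPM_128_ncsnpp_continuous.py
    for level in unet_levels(128, 128, (1, 2, 2, 2), (16,)):
        print(level)
```

For the 128x128 AAPM config, only the lowest level (16x16) gets attention. The 256x256 LSUN-style configs use seven multipliers, which brings the bottleneck down to 4x4.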

================================================
FILE: configs/ve/AAPM_256_ncsnpp_continuous.py
================================================
from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'AAPM'
  data.root = '/media/harry/tomo/AAPM_data/256'
  data.is_complex = False
  data.is_multi = False
  data.image_size = 256

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

================================================
FILE: configs/ve/Object5_fast.py
================================================
from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True
  training.epochs = 3

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'Object5Fast'
  data.root = './data/Object5/'
  data.is_complex = False
  data.is_multi = False
  data.image_size = 256

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  model.num_scales = 3  # number of sampling steps

  return config

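`Object5_fast.py` overrides `model.num_scales = 3`, so the VE SDE is discretized into only three noise levels. The sketch below shows the geometric VE noise schedule. `sigma_min = 0.01` and `sigma_max = 378` are illustrative assumptions, because this file inherits both values from `get_default_configs`, which is not shown here.

```python
import numpy as np


def ve_discrete_sigmas(sigma_min, sigma_max, num_scales):
    """Geometric noise levels sigma_i = sigma_min * (sigma_max / sigma_min) ** (i / (N - 1))."""
    return np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_scales))


def ve_sigma(t, sigma_min, sigma_max):
    """Continuous-time VE noise level sigma(t) for t in [0, 1]."""
    return sigma_min * (sigma_max / sigma_min) ** t


if __name__ == "__main__":
    # sampling visits these levels from largest to smallest
    print(ve_discrete_sigmas(0.01, 378.0, 3)[::-1])
```

With three scales, the middle level is the geometric mean of the endpoints. That is a very coarse discretization, which suits a quick smoke test rather than final reconstructions.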
================================================
FILE: configs/ve/Object5_ncsnpp_continuous.py
================================================
from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'Object5'
  data.root = './data/Object5/'
  data.is_complex = False
  data.is_multi = False
  data.image_size = 256

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

================================================
FILE: configs/ve/bedroom_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on bedroom with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.category = 'bedroom'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 1, 2, 2, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'output_skip'
  model.progressive_input = 'input_skip'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

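With `model.fir = True`, the resampling layers blur with the FIR filter given by `model.fir_kernel = [1, 3, 3, 1]`. The sketch below shows the idea for 2x downsampling: build a separable, unit-gain 2-D kernel, then filter and keep every second pixel with zero padding. It is modeled loosely on upfirdn-style resampling and is not the repository's own implementation. Upsampling variants typically also scale the kernel by `factor ** 2`.

```python
import numpy as np


def make_fir_kernel(k):
    """Turn a 1-D FIR kernel such as [1, 3, 3, 1] into a normalized 2-D kernel."""
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)  # separable 2-D kernel
    k /= np.sum(k)           # unit DC gain: a constant image stays constant
    return k


def fir_downsample2x(x, k=(1, 3, 3, 1)):
    """Blur a 2-D array with the FIR kernel, then subsample by 2 (zero padding)."""
    kernel = make_fir_kernel(k)
    taps = kernel.shape[0]
    p = taps - 2  # total padding so an even H maps to H // 2
    xp = np.pad(x, ((p + 1) // 2, p // 2))
    out_h = (xp.shape[0] - taps) // 2 + 1
    out_w = (xp.shape[1] - taps) // 2 + 1
    out = np.empty((out_h, out_w), dtype=np.float32)
    for i in range(out_h):
        for j in range(out_w):
            patch = xp[2 * i:2 * i + taps, 2 * j:2 * j + taps]
            out[i, j] = np.sum(patch * kernel)  # kernel is symmetric, so no flip needed
    return out
```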

================================================
FILE: configs/ve/celeba_ncsnpp.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on CelebA with SMLD."""

from configs.default_celeba_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.sigma_begin = 90
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.0
  model.conv_size = 3
  model.embedding_type = 'positional'

  return config


================================================
FILE: configs/ve/celebahq_256_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on CelebAHQ (256x256) with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'CelebAHQ'
  data.image_size = 256
  data.tfrecords_path = '/home/yangsong/ncsc/celebahq/r08.tfrecords'


  # model
  model = config.model
  model.name = 'ncsnpp'
  model.sigma_max = 348
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 1, 2, 2, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'output_skip'
  model.progressive_input = 'input_skip'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

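These VE configs use `sampling.predictor = 'reverse_diffusion'`. For the VE SDE the forward drift is zero, so one discretized reverse step moves along the score, scaled by the squared noise increment, and then adds fresh noise. This is a minimal single-sample sketch. The demo uses a toy analytic score for data concentrated at the origin, with `sigma_max = 348` from this config and an assumed `sigma_min = 0.01`.

```python
import numpy as np


def ve_reverse_diffusion_step(x, score, sigma, sigma_prev, rng):
    """One reverse-diffusion predictor step for the discretized VE SDE.

    sigma is the current noise level and sigma_prev the next (smaller) one.
    Returns (noisy sample, noise-free mean).
    """
    g2 = sigma ** 2 - sigma_prev ** 2  # squared diffusion of this discrete step
    x_mean = x + g2 * score            # VE has zero drift, so only the score term remains
    noise = rng.standard_normal(x.shape)
    return x_mean + np.sqrt(g2) * noise, x_mean


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    sigmas = np.exp(np.linspace(np.log(0.01), np.log(348.0), 100))[::-1]
    x = rng.standard_normal(4) * sigmas[0]
    for sigma, sigma_prev in zip(sigmas[:-1], sigmas[1:]):
        # toy score of N(0, sigma^2), i.e. data concentrated at the origin
        x, x_mean = ve_reverse_diffusion_step(x, -x / sigma ** 2, sigma, sigma_prev, rng)
    print(x_mean)
```

Samplers usually return the noise-free `x_mean` at the last step, which is what the `noise_removal` option in these configs refers to.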

================================================
FILE: configs/ve/celebahq_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on CelebAHQ with VE SDE."""

import ml_collections
import torch


def get_config():
  config = ml_collections.ConfigDict()
  # training
  config.training = training = ml_collections.ConfigDict()
  training.batch_size = 8
  training.n_iters = 2400001
  training.snapshot_freq = 50000
  training.log_freq = 50
  training.eval_freq = 100
  training.snapshot_freq_for_preemption = 5000
  training.snapshot_sampling = True
  training.sde = 'vesde'
  training.continuous = True
  training.likelihood_weighting = False
  training.reduce_mean = False

  # sampling
  config.sampling = sampling = ml_collections.ConfigDict()
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'
  sampling.probability_flow = False
  sampling.snr = 0.15
  sampling.n_steps_each = 1
  sampling.noise_removal = True

  # eval
  config.eval = evaluate = ml_collections.ConfigDict()
  evaluate.batch_size = 1024
  evaluate.num_samples = 50000
  evaluate.begin_ckpt = 1
  evaluate.end_ckpt = 96

  # data
  config.data = data = ml_collections.ConfigDict()
  data.dataset = 'CelebAHQ'
  data.image_size = 1024
  data.centered = False
  data.random_flip = True
  data.uniform_dequantization = False
  data.num_channels = 3
  data.tfrecords_path = '/atlas/u/yangsong/celeba_hq/-r10.tfrecords'

  # model
  config.model = model = ml_collections.ConfigDict()
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.sigma_max = 1348
  model.num_scales = 2000
  model.ema_rate = 0.9999
  model.sigma_min = 0.01
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 16
  model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
  model.num_res_blocks = 1
  model.attn_resolutions = (16,)
  model.dropout = 0.
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'output_skip'
  model.progressive_input = 'input_skip'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3
  model.embedding_type = 'fourier'

  # optim
  config.optim = optim = ml_collections.ConfigDict()
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 2e-4
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 5000
  optim.grad_clip = 1.

  config.seed = 42
  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

  return config

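This config sets `model.embedding_type = 'fourier'` with `model.fourier_scale = 16`. Instead of a sinusoidal timestep table, the noise level is projected onto fixed random frequencies drawn from N(0, 16^2). The sketch below is a NumPy illustration under stated assumptions. To our understanding, NCSN++ feeds `log(sigma)` into this layer, and the embedding width and seed here are illustrative.

```python
import numpy as np


class GaussianFourierProjection:
    """Random Fourier features of a scalar input (frequencies fixed, not trained)."""

    def __init__(self, embedding_size=256, scale=16.0, seed=0):
        rng = np.random.default_rng(seed)
        # frequencies drawn once from N(0, scale^2); fourier_scale controls the spread
        self.W = rng.standard_normal(embedding_size) * scale

    def __call__(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return np.concatenate([np.sin(x_proj), np.cos(x_proj)], axis=-1)


if __name__ == "__main__":
    proj = GaussianFourierProjection(embedding_size=64, scale=16.0)
    sigmas = np.array([0.01, 1.0, 1348.0])  # sigma_min .. sigma_max of this config
    print(proj(np.log(sigmas)).shape)
```

A larger `fourier_scale` gives higher-frequency features, which lets the network separate nearby noise levels more sharply.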

================================================
FILE: configs/ve/church_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on Church with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.category = 'church_outdoor'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.sigma_max = 380
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 1, 2, 2, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'output_skip'
  model.progressive_input = 'input_skip'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

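The VE configs pair the predictor with `sampling.corrector = 'langevin'`. The corrector step size is not fixed. It is chosen so that the ratio between the score-driven move and the injected noise equals `sampling.snr`, which the CelebA-HQ config above sets to 0.15. The sketch below handles a single sample and assumes `alpha = 1`, which is the VE case. Batched implementations typically average the norms over the batch instead.

```python
import numpy as np


def langevin_corrector_step(x, score, snr, rng, alpha=1.0):
    """One Langevin corrector step with an SNR-adapted step size.

    Returns (noisy sample, noise-free mean).
    """
    noise = rng.standard_normal(x.shape)
    grad_norm = np.linalg.norm(score)
    noise_norm = np.linalg.norm(noise)
    # pick the step so that ||step * score|| / ||sqrt(2 * step) * noise|| == snr
    step_size = 2.0 * alpha * (snr * noise_norm / grad_norm) ** 2
    x_mean = x + step_size * score
    return x_mean + np.sqrt(2.0 * step_size) * noise, x_mean


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal(16)
    for _ in range(3):  # n_steps_each corrector steps per noise level
        x, _ = langevin_corrector_step(x, -x, 0.15, rng)  # toy score of N(0, I)
```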

================================================
FILE: configs/ve/cifar10_ddpm.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Train the original DDPM model with SMLD."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # model
  model = config.model
  model.name = 'ddpm'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.conv_size = 3

  return config

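Every config sets `model.ema_rate`: 0.999 here and 0.9999 in the sub-VP configs. Sampling typically uses an exponential moving average of the weights rather than the raw weights. The sketch below shows the fixed-decay update on a dict of NumPy arrays. To our knowledge, the original score_sde implementation also ramps the decay up during early updates, using `min(decay, (1 + n) / (10 + n))`. That refinement is omitted here.

```python
import numpy as np


class EMA:
    """Exponential moving average of parameters with a fixed decay (ema_rate)."""

    def __init__(self, params, decay=0.999):
        self.decay = decay
        # float copies so in-place updates never touch the live parameters
        self.shadow = {k: np.array(v, dtype=np.float64) for k, v in params.items()}

    def update(self, params):
        for k, v in params.items():
            # shadow <- decay * shadow + (1 - decay) * current
            self.shadow[k] -= (1.0 - self.decay) * (self.shadow[k] - v)
```

After `n` updates toward a constant target, the shadow has covered a fraction `1 - decay ** n` of the gap. With `ema_rate = 0.9999` that takes roughly ten times as many steps as with 0.999, which is one reason the deep sub-VP runs train longer.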

================================================
FILE: configs/ve/cifar10_ncsnpp.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on CIFAR-10 with SMLD."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.0
  model.embedding_type = 'positional'
  model.conv_size = 3

  return config

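The optimizer fields `optim.lr`, `optim.warmup`, and `optim.grad_clip` are spelled out only in the CelebA-HQ config above. The CIFAR-10 configs inherit theirs from `get_default_configs`. The sketch below shows how such fields are commonly applied: a linear learning-rate warmup followed by global-norm gradient clipping. It assumes the training loop behaves this way and is not taken from the repository's optimizer code. The defaults mirror the CelebA-HQ values (2e-4, 5000, 1.0).

```python
import numpy as np


def warmup_lr(step, base_lr=2e-4, warmup=5000):
    """Linear warmup: ramp lr from 0 to base_lr over `warmup` steps, then hold."""
    if warmup <= 0:
        return base_lr
    return base_lr * min(step / warmup, 1.0)


def clip_by_global_norm(grads, max_norm=1.0):
    """Scale gradient arrays so their joint L2 norm is at most max_norm.

    Returns (clipped gradients, norm before clipping).
    """
    total = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    scale = min(1.0, max_norm / (total + 1e-6))
    return [g * scale for g in grads], total
```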

================================================
FILE: configs/ve/cifar10_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on CIFAR-10 with VE SDE."""
from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config


================================================
FILE: configs/ve/cifar10_ncsnpp_deep_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ (deep) on CIFAR-10 with VE SDE."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True
  training.n_iters = 950001

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.fourier_scale = 16
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 8
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.0
  model.conv_size = 3

  return config


================================================
FILE: configs/ve/fastmri_knee_128_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on fastmri knee with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # training (regression)
  training.mask_type = 'gaussian2d'
  training.acc_factor = [8, 15]

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'fastmri_knee'
  data.root = '/media/harry/tomo/fastmri'
  data.is_complex = False
  data.is_multi = False
  data.image_size = 128

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config

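This config sets `training.mask_type = 'gaussian2d'` with `training.acc_factor = [8, 15]`. The repository's mask generator is not part of this excerpt, so the sketch below is purely hypothetical. It illustrates a 2-D variable-density k-space mask that keeps about `H * W / acc_factor` locations and favors the center through an isotropic Gaussian. The function name, `sigma_frac`, and the sampling-without-replacement scheme are all assumptions.

```python
import numpy as np


def gaussian2d_mask(shape, acc_factor, sigma_frac=0.25, seed=0):
    """Hypothetical variable-density 2-D undersampling mask.

    Keeps round(H * W / acc_factor) k-space locations, drawn without
    replacement with probability proportional to a centered Gaussian.
    """
    h, w = shape
    rng = np.random.default_rng(seed)
    yy, xx = np.meshgrid(np.arange(h) - h / 2, np.arange(w) - w / 2, indexing='ij')
    sigma = sigma_frac * min(h, w)
    density = np.exp(-(xx ** 2 + yy ** 2) / (2 * sigma ** 2)).ravel()
    density /= density.sum()
    n_keep = int(round(h * w / acc_factor))
    idx = rng.choice(h * w, size=n_keep, replace=False, p=density)
    mask = np.zeros(h * w, dtype=np.float32)
    mask[idx] = 1.0
    return mask.reshape(h, w)


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    acc = int(rng.integers(8, 16))  # e.g. draw a factor from the [8, 15] range
    m = gaussian2d_mask((128, 128), acc)
    print(acc, m.mean())
```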

================================================
FILE: configs/ve/fastmri_knee_256_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on fastmri knee with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'fastmri_knee'
  data.root = '/media/harry/tomo/fastmri'
  data.image_size = 256

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config
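These configs pair a `reverse_diffusion` predictor with a `langevin` corrector. The following minimal NumPy sketch covers only the predictor half. It runs a VE SDE on a 1-D Gaussian toy distribution whose score is known in closed form. `ve_sigmas` and `reverse_diffusion_sample` are illustrative names. The update follows the standard VE reverse-diffusion discretization and is not copied from this repository.

```python
import numpy as np


def ve_sigmas(sigma_min, sigma_max, n):
    # Geometric noise levels, increasing from sigma_min to sigma_max.
    return np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), n))


def reverse_diffusion_sample(score_fn, sigmas, n_samples, seed=0):
    """Predictor-only VE sampler (no Langevin corrector) on a 1-D toy problem.

    Step i: g2 = sigma_i^2 - sigma_{i-1}^2 (with sigma_{-1} = 0),
            x_mean = x + g2 * score(x, sigma_i),  x = x_mean + sqrt(g2) * z.
    Returns the final x_mean, i.e. the last step is taken without noise.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n_samples) * sigmas[-1]  # prior N(0, sigma_max^2)
    x_mean = x
    for i in range(len(sigmas) - 1, -1, -1):
        sigma_prev = sigmas[i - 1] if i > 0 else 0.0
        g2 = sigmas[i] ** 2 - sigma_prev ** 2
        x_mean = x + g2 * score_fn(x, sigmas[i])
        x = x_mean + np.sqrt(g2) * rng.standard_normal(n_samples)
    return x_mean


# Toy data N(0, s^2): after perturbation by sigma, its score is -x / (s^2 + sigma^2).
s = 2.0


def toy_score(x, sigma):
    return -x / (s ** 2 + sigma ** 2)


samples = reverse_diffusion_sample(toy_score, ve_sigmas(0.01, 50.0, 1000), 20000)
print(samples.mean(), samples.std())  # expected to be near 0 and s
```

With a learned model, the same loop would call the score network in place of `toy_score`. A corrector step would be interleaved after each predictor step.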


================================================
FILE: configs/ve/fastmri_knee_320_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on fastMRI knee (320x320, real-valued) with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'fastmri_knee'
  data.root = '/media/harry/tomo/fastmri'
  data.image_size = 320
  data.is_multi = False
  data.is_complex = False

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config
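Assume this NCSN++ follows the score_sde convention, where level `i` runs at `image_size // 2**i` (check the model code before relying on this). Under that assumption, `ch_mult = (1, 2, 2, 2)` gives four levels. At 320x320 none of them is 16x16, so `attn_resolutions = (16,)` would add no per-level attention. The 128x128 config does reach 16x16. A small helper to list the levels (names are illustrative):

```python
def level_resolutions(image_size, ch_mult):
    # Assumed NCSN++ convention: level i operates at image_size // 2**i.
    return [image_size // (2 ** i) for i in range(len(ch_mult))]


def describe_levels(image_size, nf, ch_mult, attn_resolutions):
    """Return (resolution, channels, has_attention) for each U-Net level."""
    return [
        (res, nf * mult, res in attn_resolutions)
        for res, mult in zip(level_resolutions(image_size, ch_mult), ch_mult)
    ]


# Usage with the values from the 320x320 and 128x128 fastMRI configs.
for size in (320, 128):
    print(size, describe_levels(size, nf=128, ch_mult=(1, 2, 2, 2), attn_resolutions=(16,)))
```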


================================================
FILE: configs/ve/fastmri_knee_320_ncsnpp_continuous_complex.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on complex-valued fastMRI knee images (320x320) with VE SDE."""

from configs.default_complex_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'fastmri_knee'
  data.is_multi = False
  data.is_complex = True
  data.root = '/media/harry/tomo/fastmri'
  data.image_size = 320

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config


================================================
FILE: configs/ve/fastmri_knee_320_ncsnpp_continuous_complex_magpha.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on complex-valued fastMRI knee images (320x320, magnitude/phase) with VE SDE."""

from configs.default_complex_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'fastmri_knee'
  data.is_multi = False
  data.is_complex = True
  data.magpha = True
  data.root = '/media/harry/tomo/fastmri'
  data.image_size = 320

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config
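`data.is_complex = True` implies that complex images are fed to the network as real-valued channels. `data.magpha = True` suggests a magnitude/phase split in place of real/imaginary. Below is a sketch of both conversions. It assumes a two-channel layout; the channel order and the function names are hypothetical.

```python
import numpy as np


def complex_to_channels(img, magpha=False):
    """Map a complex (H, W) image to a real (2, H, W) array.

    magpha=False -> (real, imaginary); magpha=True -> (magnitude, phase).
    The channel order is an assumption, not taken from the repository.
    """
    if magpha:
        return np.stack([np.abs(img), np.angle(img)], axis=0)
    return np.stack([img.real, img.imag], axis=0)


def channels_to_complex(x, magpha=False):
    """Inverse of complex_to_channels."""
    if magpha:
        return x[0] * np.exp(1j * x[1])
    return x[0] + 1j * x[1]


# Usage: round-trip a random complex 320x320 image through both representations.
rng = np.random.default_rng(0)
z = rng.standard_normal((320, 320)) + 1j * rng.standard_normal((320, 320))
for mp in (False, True):
    x = complex_to_channels(z, magpha=mp)
    print(mp, x.shape, np.allclose(channels_to_complex(x, magpha=mp), z))
```

One practical difference is that the phase channel wraps at +/- pi, while real/imaginary channels vary smoothly. This can make the two representations behave differently as network inputs.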


================================================
FILE: configs/ve/fastmri_knee_320_ncsnpp_continuous_multi.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on fastmri knee with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'fastmri_knee'
  data.is_complex = False
  data.is_multi = True
  data.root = '/media/harry/tomo/fastmri'
  data.image_size = 320

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config
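`data.is_multi = True` presumably selects multi-coil fastMRI data. If so, a common way to form a single training image from several coils is the root-sum-of-squares (RSS) combination sketched below. Whether this repository actually uses RSS is an assumption.

```python
import numpy as np


def rss(coil_images, axis=0):
    """Root-sum-of-squares coil combination: sqrt(sum_c |x_c|^2) over the coil axis."""
    return np.sqrt(np.sum(np.abs(coil_images) ** 2, axis=axis))


# Usage: 8 hypothetical complex coil images of size 320x320.
rng = np.random.default_rng(0)
coils = rng.standard_normal((8, 320, 320)) + 1j * rng.standard_normal((8, 320, 320))
target = rss(coils)
print(target.shape)
```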


================================================
FILE: configs/ve/ffhq_256_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on FFHQ with VE SDE."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'

  # data
  data = config.data
  data.dataset = 'FFHQ'
  data.image_size = 256
  data.tfrecords_path = '/media/harry/ExtDrive/PycharmProjects/score_sde_pytorch/dataset/FFHQ/ffhq-r08.tfrecords'

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.sigma_max = 348
  model.scale_by_sigma = True
  model.ema_rate = 0.999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 1, 2, 2, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'output_skip'
  model.progressive_input = 'input_skip'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config
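With `training.sde = 'vesde'` and `model.sigma_max = 348`, the noise scale grows geometrically from `sigma_min` to `sigma_max`. The sketch below covers the continuous schedule, its discretization, and the forward perturbation. It assumes `sigma_min = 0.01` and `num_scales = 2000` are inherited from the default LSUN config, which is not shown here.

```python
import numpy as np

SIGMA_MIN = 0.01   # assumed default, not set in this file
SIGMA_MAX = 348.0  # model.sigma_max above
NUM_SCALES = 2000  # assumed default, not set in this file


def ve_sigma(t, sigma_min=SIGMA_MIN, sigma_max=SIGMA_MAX):
    """Continuous VE noise scale sigma(t) = sigma_min * (sigma_max / sigma_min) ** t, t in [0, 1]."""
    return sigma_min * (sigma_max / sigma_min) ** t


def ve_discrete_sigmas(sigma_min=SIGMA_MIN, sigma_max=SIGMA_MAX, num_scales=NUM_SCALES):
    """Geometric (log-linearly spaced) grid from sigma_min to sigma_max."""
    return np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_scales))


def perturb(x0, t, rng):
    """Sample x_t ~ N(x0, sigma(t)^2 I); the VE SDE adds noise without shrinking x0."""
    return x0 + ve_sigma(t) * rng.standard_normal(x0.shape)


# Usage: perturb a dummy 3x256x256 image at t = 0.5.
rng = np.random.default_rng(0)
xt = perturb(np.zeros((3, 256, 256)), 0.5, rng)
print(ve_sigma(0.5), xt.std())
```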


================================================
FILE: configs/ve/ffhq_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on FFHQ with VE SDEs."""

import ml_collections
import torch

def get_config():
  config = ml_collections.ConfigDict()
  # training
  config.training = training = ml_collections.ConfigDict()
  training.batch_size = 8
  training.n_iters = 2400001
  training.snapshot_freq = 50000
  training.log_freq = 50
  training.eval_freq = 100
  training.snapshot_freq_for_preemption = 5000
  training.snapshot_sampling = True
  training.sde = 'vesde'
  training.continuous = True
  training.likelihood_weighting = False
  training.reduce_mean = True

  # sampling
  config.sampling = sampling = ml_collections.ConfigDict()
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'langevin'
  sampling.probability_flow = False
  sampling.snr = 0.15
  sampling.n_steps_each = 1
  sampling.noise_removal = True

  # eval
  config.eval = evaluate = ml_collections.ConfigDict()
  evaluate.batch_size = 1024
  evaluate.num_samples = 50000
  evaluate.begin_ckpt = 1
  evaluate.end_ckpt = 96

  # data
  config.data = data = ml_collections.ConfigDict()
  data.dataset = 'FFHQ'
  data.image_size = 1024
  data.centered = False
  data.random_flip = True
  data.uniform_dequantization = False
  data.num_channels = 3
  # Plug in your own path to the tfrecords file.
  data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'

  # model
  config.model = model = ml_collections.ConfigDict()
  model.name = 'ncsnpp'
  model.scale_by_sigma = True
  model.sigma_max = 1348
  model.num_scales = 2000
  model.ema_rate = 0.9999
  model.sigma_min = 0.01
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 16
  model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)
  model.num_res_blocks = 1
  model.attn_resolutions = (16,)
  model.dropout = 0.
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'output_skip'
  model.progressive_input = 'input_skip'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3
  model.embedding_type = 'fourier'

  # optim
  config.optim = optim = ml_collections.ConfigDict()
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 2e-4
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 5000
  optim.grad_clip = 1.

  config.seed = 42
  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

  return config
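`model.embedding_type = 'fourier'` with `model.fourier_scale = 16` refers to Gaussian random Fourier features of the noise level. The NumPy sketch below follows the form used in score_sde's NCSN++ as I understand it: frozen weights `W ~ N(0, fourier_scale^2)`, input `log(sigma)`, and embedding size `nf`. Treat those details as assumptions.

```python
import numpy as np


class GaussianFourierProjection:
    """Random Fourier features of a scalar (here log(sigma)), NCSN++-style sketch.

    W ~ N(0, scale^2) is drawn once and then frozen; an input s maps to
    [sin(2*pi*s*W), cos(2*pi*s*W)], so the output has 2 * embedding_size features.
    """

    def __init__(self, embedding_size=16, scale=16.0, seed=0):
        rng = np.random.default_rng(seed)
        self.W = rng.standard_normal(embedding_size) * scale

    def __call__(self, s):
        proj = np.asarray(s, dtype=float)[:, None] * self.W[None, :] * 2.0 * np.pi
        return np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)


# Usage with nf = 16 and fourier_scale = 16 from the config above.
embed = GaussianFourierProjection(embedding_size=16, scale=16.0)
sigmas = np.array([0.01, 1.0, 1348.0])  # sigma_min, 1, sigma_max
print(embed(np.log(sigmas)).shape)
```

A large `fourier_scale` makes the features oscillate quickly in `log(sigma)`. Presumably this helps the network tell nearby noise levels apart.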


================================================
FILE: configs/ve/ncsn/celeba.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for reproducing NCSNv1 on CelebA."""

from configs.default_celeba_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False
  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'none'
  sampling.corrector = 'ald'
  sampling.n_steps_each = 100
  sampling.snr = 0.316
  # model
  model = config.model
  model.name = 'ncsn'
  model.scale_by_sigma = False
  model.sigma_max = 1
  model.num_scales = 10
  model.ema_rate = 0.
  model.normalization = 'InstanceNorm++'
  model.nonlinearity = 'elu'
  model.nf = 128
  model.interpolation = 'bilinear'
  # optim
  optim = config.optim
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 1e-3
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 0
  optim.grad_clip = -1.

  return config
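`optim.warmup` and `optim.grad_clip` control a linear learning-rate warmup and global-norm gradient clipping. In this file both are disabled (`warmup = 0`, `grad_clip = -1.`). The stdlib sketch below mirrors, as far as I recall, the optimizer wrapper in score_sde-style training code. It assumes that a non-positive warmup skips the ramp and that a negative clip threshold skips clipping. For simplicity, gradients are a flat list of floats.

```python
import math


def warmup_lr(step, base_lr, warmup):
    """Linear warmup from 0 to base_lr over `warmup` steps; warmup <= 0 disables it."""
    if warmup > 0:
        return base_lr * min(step / warmup, 1.0)
    return base_lr


def clip_by_global_norm(grads, max_norm):
    """Rescale grads so their global L2 norm is at most max_norm; max_norm < 0 disables clipping."""
    grads = list(grads)
    if max_norm < 0:
        return grads
    total = math.sqrt(sum(g * g for g in grads))
    if total <= max_norm:
        return grads
    scale = max_norm / total
    return [g * scale for g in grads]


# Usage: this file's settings (lr = 1e-3, warmup = 0, grad_clip = -1.) leave both untouched,
# while the FFHQ settings (lr = 2e-4, warmup = 5000, grad_clip = 1.) ramp and clip.
print(warmup_lr(10, 1e-3, 0), clip_by_global_norm([3.0, 4.0], -1.0))
print(warmup_lr(2500, 2e-4, 5000), clip_by_global_norm([3.0, 4.0], 1.0))
```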


================================================
FILE: configs/ve/ncsn/celeba_124.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for training NCSN with technique 1,2,4 only."""

from configs.default_celeba_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False
  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'none'
  sampling.corrector = 'ald'
  sampling.n_steps_each = 5
  sampling.snr = 0.128
  # model
  model = config.model
  model.name = 'ncsn'
  model.scale_by_sigma = False
  model.num_scales = 500
  model.ema_rate = 0.
  model.normalization = 'InstanceNorm++'
  model.nonlinearity = 'elu'
  model.nf = 128
  model.interpolation = 'bilinear'
  # optim
  optim = config.optim
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 1e-3
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 0
  optim.grad_clip = -1.

  return config


================================================
FILE: configs/ve/ncsn/celeba_1245.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for training NCSN with technique 1,2,4,5 only."""

from configs.default_celeba_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False
  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'none'
  sampling.corrector = 'ald'
  sampling.n_steps_each = 5
  sampling.snr = 0.128
  # model
  model = config.model
  model.name = 'ncsn'
  model.scale_by_sigma = False
  model.num_scales = 500
  model.ema_rate = 0.999
  model.normalization = 'InstanceNorm++'
  model.nonlinearity = 'elu'
  model.nf = 128
  model.interpolation = 'bilinear'
  # optim
  optim = config.optim
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 1e-3
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 0
  optim.grad_clip = -1.

  return config


================================================
FILE: configs/ve/ncsn/celeba_5.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for training NCSNv1 model with technique 5 only."""

from configs.default_celeba_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False
  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'none'
  sampling.corrector = 'ald'
  sampling.n_steps_each = 100
  sampling.snr = 0.316
  # model
  model = config.model
  model.name = 'ncsn'
  model.scale_by_sigma = False
  model.sigma_max = 1.
  model.num_scales = 10
  model.ema_rate = 0.999
  model.normalization = 'InstanceNorm++'
  model.nonlinearity = 'elu'
  model.nf = 128
  model.interpolation = 'bilinear'
  # optim
  optim = config.optim
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 1e-3
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 0
  optim.grad_clip = -1.

  return config
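NCSNv2's technique 5 is an exponential moving average (EMA) of the model weights. This is consistent with `model.ema_rate = 0.999` in this file, compared with `ema_rate = 0.` in `celeba.py` and `celeba_124.py`. Below is a minimal sketch of the update rule. It omits any decay warm-up that the actual EMA helper may apply.

```python
def ema_update(shadow, params, rate):
    """One EMA step, element-wise: shadow <- rate * shadow + (1 - rate) * params."""
    return [rate * s + (1.0 - rate) * p for s, p in zip(shadow, params)]


# Usage: with rate = 0. the EMA copy simply tracks the latest weights (technique 5 off);
# with rate = 0.999 it averages over roughly the last 1 / (1 - 0.999) = 1000 steps.
print(ema_update([1.0, 2.0], [5.0, 6.0], 0.0))
shadow = [0.0]
for _ in range(5000):
    shadow = ema_update(shadow, [1.0], 0.999)
print(shadow)
```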


================================================
FILE: configs/ve/ncsn/cifar10.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for reproducing NCSNv1 on CIFAR-10."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False
  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'none'
  sampling.corrector = 'ald'
  sampling.n_steps_each = 100
  sampling.snr = 0.316
  # model
  model = config.model
  model.name = 'ncsn'
  model.scale_by_sigma = False
  model.sigma_max = 1
  model.num_scales = 10
  model.ema_rate = 0.
  model.normalization = 'InstanceNorm++'
  model.nonlinearity = 'elu'
  model.nf = 128
  model.interpolation = 'bilinear'
  # optim
  optim = config.optim
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 1e-3
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 0
  optim.grad_clip = -1.

  return config
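With `predictor = 'none'` and `corrector = 'ald'`, sampling is pure annealed Langevin dynamics: `n_steps_each` Langevin steps per noise level, with a step size set by `snr`. The sketch runs this on a 1-D Gaussian toy with an exact score. It makes two assumptions: the step size is `eps = 2 * (snr * sigma)^2`, as in score_sde's ALD corrector for VE SDEs, and `sigma_min = 0.01` comes from the default config.

```python
import numpy as np


def annealed_langevin(score_fn, sigmas, n_steps_each, snr, n_samples, seed=0):
    """Corrector-only annealed Langevin dynamics on a 1-D toy problem.

    For each noise level sigma (largest first), run n_steps_each updates
    x <- x + eps * score(x, sigma) + sqrt(2 * eps) * z, with eps = 2 * (snr * sigma) ** 2.
    """
    rng = np.random.default_rng(seed)
    x = rng.uniform(-1.0, 1.0, n_samples)  # uniform initialization, as in the original NCSN
    for sigma in sigmas[::-1]:
        eps = 2.0 * (snr * sigma) ** 2
        for _ in range(n_steps_each):
            z = rng.standard_normal(n_samples)
            x = x + eps * score_fn(x, sigma) + np.sqrt(2.0 * eps) * z
    return x


# Usage: toy data N(1, 0.5^2) with this file's snr = 0.316, n_steps_each = 100 and
# num_scales = 10 levels from an assumed sigma_min = 0.01 up to sigma_max = 1.
mu, s = 1.0, 0.5


def toy_score(x, sigma):
    # Exact score of N(mu, s^2) convolved with N(0, sigma^2).
    return -(x - mu) / (s ** 2 + sigma ** 2)


levels = np.exp(np.linspace(np.log(0.01), np.log(1.0), 10))
x = annealed_langevin(toy_score, levels, n_steps_each=100, snr=0.316, n_samples=20000)
print(x.mean(), x.std())
```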


================================================
FILE: configs/ve/ncsn/cifar10_124.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for training NCSN with technique 1,2,4 only."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False
  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'none'
  sampling.corrector = 'ald'
  sampling.n_steps_each = 5
  sampling.snr = 0.176
  # model
  model = config.model
  model.name = 'ncsn'
  model.scale_by_sigma = False
  model.num_scales = 232
  model.ema_rate = 0.
  model.normalization = 'InstanceNorm++'
  model.nonlinearity = 'elu'
  model.nf = 128
  model.interpolation = 'bilinear'
  # optim
  optim = config.optim
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 1e-3
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 0
  optim.grad_clip = -1.

  return config


================================================
FILE: configs/ve/ncsn/cifar10_1245.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for training NCSN with technique 1,2,4,5 only."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False
  # shared configs for sample generation
  step_size = 0.0000062
  n_steps_each = 5
  ckpt_id = 300000
  final_only = True
  noise_removal = False
  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'none'
  sampling.corrector = 'ald'
  sampling.n_steps_each = 5
  sampling.snr = 0.176
  # model
  model = config.model
  model.name = 'ncsn'
  model.scale_by_sigma = False
  model.num_scales = 232
  model.ema_rate = 0.999
  model.normalization = 'InstanceNorm++'
  model.nonlinearity = 'elu'
  model.nf = 128
  model.interpolation = 'bilinear'
  # optim
  optim = config.optim
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 1e-3
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 0
  optim.grad_clip = -1.

  return config
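The unused local `step_size = 0.0000062` is consistent with `sampling.snr = 0.176` under the relation `eps = 2 * (snr * sigma_min)^2`, assuming `sigma_min = 0.01`. The same relation maps `step_size = 0.0000033` in `ncsnv2/celeba.py` to `snr = 0.128`. Below is a small helper to convert between the two parameterizations; the names are illustrative.

```python
import math


def ald_step_size(snr, sigma):
    """Langevin step size at noise level sigma: eps = 2 * (snr * sigma) ** 2 (VE SDE, alpha = 1)."""
    return 2.0 * (snr * sigma) ** 2


def snr_from_step_size(eps, sigma):
    """Invert ald_step_size: snr = sqrt(eps / 2) / sigma."""
    return math.sqrt(eps / 2.0) / sigma


SIGMA_MIN = 0.01  # assumed default sigma_min
print(ald_step_size(0.176, SIGMA_MIN))           # compare with step_size = 0.0000062 above
print(snr_from_step_size(0.0000033, SIGMA_MIN))  # compare with snr = 0.128 in ncsnv2/celeba.py
```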


================================================
FILE: configs/ve/ncsn/cifar10_5.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for training NCSN with technique 5 only."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False
  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'none'
  sampling.corrector = 'ald'
  sampling.snr = 0.316
  sampling.n_steps_each = 100
  # model
  model = config.model
  model.name = 'ncsn'
  model.scale_by_sigma = False
  model.sigma_max = 1
  model.num_scales = 10
  model.ema_rate = 0.999
  model.normalization = 'InstanceNorm++'
  model.nonlinearity = 'elu'
  model.nf = 128
  model.interpolation = 'bilinear'
  # optim
  optim = config.optim
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 1e-3
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 0
  optim.grad_clip = -1.

  return config


================================================
FILE: configs/ve/ncsnv2/bedroom.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for training NCSNv2 on bedroom."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.batch_size = 128
  training.sde = 'vesde'
  training.continuous = False
  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'none'
  sampling.corrector = 'ald'
  sampling.n_steps_each = 3
  sampling.snr = 0.095
  # data
  data = config.data
  data.category = 'bedroom'
  data.image_size = 128
  # model
  model = config.model
  model.name = 'ncsnv2_128'
  model.scale_by_sigma = True
  model.sigma_max = 190
  model.num_scales = 1086
  model.ema_rate = 0.9999
  model.sigma_min = 0.01
  model.normalization = 'InstanceNorm++'
  model.nonlinearity = 'elu'
  model.nf = 128
  model.interpolation = 'bilinear'
  # optim
  optim = config.optim
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 1e-4
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 0
  optim.grad_clip = -1

  return config
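NCSNv2's technique 2 picks a geometric noise schedule whose ratio `gamma` makes `Phi(sqrt(2D)(gamma - 1) + 3 gamma) - Phi(sqrt(2D)(gamma - 1) - 3 gamma)` roughly 0.5, where `D` is the data dimension. The inputs are this file's `sigma_min = 0.01`, `sigma_max = 190`, and `num_scales = 1086`, together with 128x128 RGB images (assuming 3 channels from the default LSUN config). With these values the criterion comes out at about 0.57. A stdlib check (function names are illustrative):

```python
import math


def std_normal_cdf(x):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def technique2_criterion(sigma_min, sigma_max, num_scales, dim):
    """Phi(sqrt(2D)(g - 1) + 3g) - Phi(sqrt(2D)(g - 1) - 3g) for the geometric ratio g."""
    gamma = (sigma_max / sigma_min) ** (1.0 / (num_scales - 1))
    a = math.sqrt(2.0 * dim) * (gamma - 1.0)
    return std_normal_cdf(a + 3.0 * gamma) - std_normal_cdf(a - 3.0 * gamma)


# Usage: the bedroom settings above, with D = 128 * 128 * 3.
print(technique2_criterion(0.01, 190.0, 1086, 128 * 128 * 3))
```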


================================================
FILE: configs/ve/ncsnv2/celeba.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for training NCSNv2 on CelebA."""

from configs.default_celeba_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False
  # shared configs for sample generation
  step_size = 0.0000033
  n_steps_each = 5
  ckpt_id = 210000
  final_only = True
  noise_removal = False
  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'none'
  sampling.corrector = 'ald'
  sampling.n_steps_each = 5
  sampling.snr = 0.128
  # model
  model = config.model
  model.name = 'ncsnv2_64'
  model.scale_by_sigma = True
  model.num_scales = 500
  model.ema_rate = 0.999
  model.normalization = 'InstanceNorm++'
  model.nonlinearity = 'elu'
  model.nf = 128
  model.interpolation = 'bilinear'
  # optim
  optim = config.optim
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 1e-4
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 0
  optim.grad_clip = -1.

  return config
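With `predictor = 'none'` and `corrector = 'ald'`, sampling reduces to annealed Langevin dynamics. The sampler runs `n_steps_each` Langevin steps at each noise level, and the step size is tied to `snr`. The NumPy sketch below shows one such step. It assumes the VE form of score_sde's `AnnealedLangevinDynamics` corrector as we understand it: step size `2 * alpha * (snr * sigma)**2` with `alpha = 1`. The toy Gaussian score and the name `ald_step` are illustrative only.

```python
import numpy as np


def ald_step(x, score_fn, sigma, snr, rng, alpha=1.0):
    """One annealed-Langevin update at noise level sigma (VE case: alpha = 1)."""
    step_size = 2.0 * alpha * (snr * sigma) ** 2
    noise = rng.standard_normal(x.shape)
    x_mean = x + step_size * score_fn(x)
    x_new = x_mean + np.sqrt(2.0 * step_size) * noise
    return x_new, x_mean


# Toy target: data ~ N(0, 1), so the sigma-perturbed density is N(0, 1 + sigma^2)
sigma, snr = 1.0, 0.128


def score(x):
    return -x / (1.0 + sigma ** 2)


rng = np.random.default_rng(0)
x = 3.0 * rng.standard_normal(20_000)  # deliberately over-dispersed start
for _ in range(2_000):
    x, _ = ald_step(x, score, sigma, snr, rng)
print(x.var())  # close to 1 + sigma^2 = 2, up to a small discretization bias
```

A smaller `snr` gives smaller, safer steps but needs more of them to mix, which is the trade-off `n_steps_each` and `snr` balance.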


================================================
FILE: configs/ve/ncsnv2/cifar10.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for training NCSNv2 on CIFAR-10."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vesde'
  training.continuous = False
  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'none'
  sampling.corrector = 'ald'
  sampling.n_steps_each = 5
  sampling.snr = 0.176
  # model
  model = config.model
  model.name = 'ncsnv2_64'
  model.scale_by_sigma = True
  model.num_scales = 232
  model.ema_rate = 0.999
  model.normalization = 'InstanceNorm++'
  model.nonlinearity = 'elu'
  model.nf = 128
  model.interpolation = 'bilinear'
  # optim
  optim = config.optim
  optim.weight_decay = 0
  optim.optimizer = 'Adam'
  optim.lr = 1e-4
  optim.beta1 = 0.9
  optim.amsgrad = False
  optim.eps = 1e-8
  optim.warmup = 0
  optim.grad_clip = -1.

  return config


================================================
FILE: configs/vp/cifar10_ddpmpp.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training DDPM++ on CIFAR-10 with the discrete-time VP SDE."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vpsde'
  training.continuous = False
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'ancestral_sampling'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = False
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'none'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.embedding_type = 'positional'
  model.fourier_scale = 16
  model.conv_size = 3

  return config
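With the discrete VP SDE and `predictor = 'ancestral_sampling'`, each reverse step is the DDPM ancestral update. The sketch below assumes the upstream score_sde defaults `beta_min = 0.1` and `beta_max = 20`, which this file does not set. It also uses a toy 1-D Gaussian whose score is known exactly. For N(0, 1) data every VP marginal is N(0, 1), so the sample variance should stay near 1. The helper names are hypothetical.

```python
import numpy as np


def vp_discrete_betas(beta_min=0.1, beta_max=20.0, num_scales=1000):
    # Linear beta schedule scaled by 1/N (assumed score_sde VPSDE convention)
    return np.linspace(beta_min / num_scales, beta_max / num_scales, num_scales)


def ancestral_step(x, score, beta, rng):
    """DDPM-style ancestral update for one discrete VP step."""
    x_mean = (x + beta * score) / np.sqrt(1.0 - beta)
    return x_mean + np.sqrt(beta) * rng.standard_normal(x.shape), x_mean


betas = vp_discrete_betas()
print(betas[0], betas[-1])  # roughly 1e-4 and 0.02, the DDPM linear schedule

rng = np.random.default_rng(0)
x = rng.standard_normal(20_000)       # sample from the N(0, 1) prior
for beta in betas[::-1]:              # walk from t = T back to t = 0
    x, _ = ancestral_step(x, -x, beta, rng)  # exact score of N(0, 1) is -x
print(x.var())
```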


================================================
FILE: configs/vp/cifar10_ddpmpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training DDPM++ on CIFAR-10 with the continuous-time VP SDE."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vpsde'
  training.continuous = True
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'euler_maruyama'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = False
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'none'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.embedding_type = 'positional'
  model.fourier_scale = 16
  model.conv_size = 3

  return config
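For the continuous VP SDE with `predictor = 'euler_maruyama'`, the sampler integrates the reverse-time SDE with a fixed negative step `dt = -1/N`. The sketch below assumes `beta(t) = beta_min + t * (beta_max - beta_min)` with the score_sde defaults `beta_min = 0.1` and `beta_max = 20`, and a final time `eps = 1e-3`; none of these are set in this file. It reuses the toy N(0, 1) data, whose score is exactly `-x`. `reverse_em_step` is a hypothetical name.

```python
import numpy as np


def reverse_em_step(x, t, score, dt, rng, beta_min=0.1, beta_max=20.0):
    """Euler-Maruyama step of the reverse-time VP SDE (dt < 0)."""
    beta_t = beta_min + t * (beta_max - beta_min)
    drift = -0.5 * beta_t * x - beta_t * score  # f(x, t) - g(t)^2 * score, with g^2 = beta(t)
    x_mean = x + drift * dt
    x_new = x_mean + np.sqrt(beta_t) * np.sqrt(-dt) * rng.standard_normal(x.shape)
    return x_new, x_mean


N, T, eps = 1000, 1.0, 1e-3
rng = np.random.default_rng(0)
x = rng.standard_normal(20_000)
for t in np.linspace(T, eps, N):
    x, x_mean = reverse_em_step(x, t, -x, -1.0 / N, rng)
print(x.var())  # stays near 1 for this toy target
```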


================================================
FILE: configs/vp/cifar10_ddpmpp_deep_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training DDPM++ (deep, 8 residual blocks per level) on CIFAR-10 with the continuous-time VP SDE."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vpsde'
  training.continuous = True
  training.reduce_mean = True
  training.n_iters = 950001

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'euler_maruyama'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 8
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = False
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'none'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.
  model.embedding_type = 'positional'
  model.fourier_scale = 16
  model.conv_size = 3

  return config


================================================
FILE: configs/vp/cifar10_ncsnpp.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on CIFAR-10 with DDPM."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vpsde'
  training.continuous = False
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'reverse_diffusion'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.init_scale = 0.0
  model.embedding_type = 'positional'
  model.conv_size = 3

  return config
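`predictor = 'reverse_diffusion'` differs from ancestral sampling. It discretizes the forward VP step as `x_{i+1} = x_i + f + G z`, with `f = sqrt(1 - beta) x - x` and `G = sqrt(beta)`, then reverses it with the score-corrected drift. The sketch below follows our reading of score_sde's `ReverseDiffusionPredictor`. It assumes the same default betas and the toy N(0, 1) target as before, and the helper name is hypothetical.

```python
import numpy as np


def reverse_diffusion_step(x, score, beta, rng):
    """Reverse-diffusion predictor for one discrete VP step."""
    f = np.sqrt(1.0 - beta) * x - x   # forward drift increment
    G = np.sqrt(beta)                 # forward noise scale
    rev_f = f - G ** 2 * score        # reverse-time drift increment
    x_mean = x - rev_f
    return x_mean + G * rng.standard_normal(x.shape), x_mean


betas = np.linspace(0.1 / 1000, 20.0 / 1000, 1000)
rng = np.random.default_rng(0)
x = rng.standard_normal(20_000)
for beta in betas[::-1]:
    x, _ = reverse_diffusion_step(x, -x, beta, rng)  # exact score of N(0, 1) is -x
print(x.var())
```

For small `beta` this agrees with the ancestral update to first order. The two differ only in higher-order terms.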


================================================
FILE: configs/vp/cifar10_ncsnpp_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on CIFAR-10 with VP SDE."""
from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vpsde'
  training.continuous = True
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'euler_maruyama'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 4
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.embedding_type = 'positional'
  model.init_scale = 0.
  model.fourier_scale = 16
  model.conv_size = 3

  return config
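`embedding_type = 'positional'` selects the sinusoidal time embedding familiar from DDPM and Transformers. `fourier_scale` appears to be used only when `embedding_type = 'fourier'`. The NumPy sketch below assumes the DDPM-style layout (sines then cosines, `max_positions = 10000`) and an embedding width equal to `nf = 128`. It is an illustration, not the repo's layer.

```python
import math

import numpy as np


def timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Sinusoidal embedding of scalar time steps, shape (len(timesteps), embedding_dim)."""
    timesteps = np.asarray(timesteps, dtype=np.float32)
    half_dim = embedding_dim // 2
    scale = math.log(max_positions) / (half_dim - 1)
    freqs = np.exp(np.arange(half_dim, dtype=np.float32) * -scale)  # geometric frequencies
    args = timesteps[:, None] * freqs[None, :]
    emb = np.concatenate([np.sin(args), np.cos(args)], axis=1)
    if embedding_dim % 2 == 1:  # zero-pad odd widths
        emb = np.pad(emb, ((0, 0), (0, 1)))
    return emb


emb = timestep_embedding([0.0, 10.0, 999.0], embedding_dim=128)
print(emb.shape)  # (3, 128)
```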


================================================
FILE: configs/vp/cifar10_ncsnpp_deep_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on CIFAR-10."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()
  # training
  training = config.training
  training.sde = 'vpsde'
  training.continuous = True
  training.n_iters = 950001
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'euler_maruyama'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ncsnpp'
  model.fourier_scale = 16
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 8
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True
  model.fir = True
  model.fir_kernel = [1, 3, 3, 1]
  model.skip_rescale = True
  model.resblock_type = 'biggan'
  model.progressive = 'none'
  model.progressive_input = 'residual'
  model.progressive_combine = 'sum'
  model.attention_type = 'ddpm'
  model.embedding_type = 'positional'
  model.init_scale = 0.0
  model.conv_size = 3

  return config


================================================
FILE: configs/vp/ddpm/bedroom.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for reproducing the results of DDPM on bedrooms."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()

  # training
  training = config.training
  training.sde = 'vpsde'
  training.continuous = False
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'ancestral_sampling'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.category = 'bedroom'
  data.centered = True

  # model
  model = config.model
  model.name = 'ddpm'
  model.scale_by_sigma = False
  model.num_scales = 1000
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 1, 2, 2, 4, 4)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True

  # optim
  optim = config.optim
  optim.lr = 2e-5

  return config
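`ch_mult` sets both the number of U-Net resolution levels and the channel width at each level (`nf * mult`). `attn_resolutions` picks the levels that get self-attention. The sketch below tabulates this for the bedroom config. It assumes the LSUN default `image_size = 256` and DDPM-style halving between levels, with no downsampling after the last level. `unet_levels` is a hypothetical helper.

```python
def unet_levels(image_size, nf, ch_mult, attn_resolutions):
    """List of (level, spatial resolution, channels, has_attention) for a DDPM-style encoder."""
    levels = []
    res = image_size
    for i, mult in enumerate(ch_mult):
        levels.append((i, res, nf * mult, res in attn_resolutions))
        if i != len(ch_mult) - 1:  # downsample between levels only
            res //= 2
    return levels


for level in unet_levels(256, 128, (1, 1, 2, 2, 4, 4), (16,)):
    print(level)
```

Attention is therefore applied only at the 16x16 level (512 channels), where its quadratic cost in spatial size stays affordable.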


================================================
FILE: configs/vp/ddpm/celebahq.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for reproducing the results of DDPM on CelebA-HQ."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()

  # training
  training = config.training
  training.sde = 'vpsde'
  training.continuous = False
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'ancestral_sampling'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.dataset = 'CelebAHQ'
  data.centered = True
  data.tfrecords_path = '/atlas/u/yangsong/celeba_hq/-r10.tfrecords'
  data.image_size = 256

  # model
  model = config.model
  model.name = 'ddpm'
  model.scale_by_sigma = False
  model.num_scales = 1000
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 1, 2, 2, 4, 4)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True

  # optim
  optim = config.optim
  optim.lr = 2e-5

  return config


================================================
FILE: configs/vp/ddpm/church.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for reproducing the results of DDPM on church_outdoor."""

from configs.default_lsun_configs import get_default_configs


def get_config():
  config = get_default_configs()

  # training
  training = config.training
  training.sde = 'vpsde'
  training.continuous = False
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'ancestral_sampling'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.category = 'church_outdoor'
  data.centered = True

  # model
  model = config.model
  model.name = 'ddpm'
  model.scale_by_sigma = False
  model.num_scales = 1000
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 1, 2, 2, 4, 4)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True

  # optim
  optim = config.optim
  optim.lr = 2e-5

  return config


================================================
FILE: configs/vp/ddpm/cifar10.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Config file for reproducing the results of DDPM on cifar-10."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()

  # training
  training = config.training
  training.sde = 'vpsde'
  training.continuous = False
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'ancestral_sampling'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ddpm'
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True

  return config


================================================
FILE: configs/vp/ddpm/cifar10_continuous.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training DDPM with VP SDE."""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()

  # training
  training = config.training
  training.sde = 'vpsde'
  training.continuous = True
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'euler_maruyama'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ddpm'
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = True

  return config


================================================
FILE: configs/vp/ddpm/cifar10_unconditional.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training DDPM on CIFAR-10 without explicitly conditioning on time steps. (NCSNv2 technique 3)"""

from configs.default_cifar10_configs import get_default_configs


def get_config():
  config = get_default_configs()

  # training
  training = config.training
  training.sde = 'vpsde'
  training.continuous = False
  training.reduce_mean = True

  # sampling
  sampling = config.sampling
  sampling.method = 'pc'
  sampling.predictor = 'ancestral_sampling'
  sampling.corrector = 'none'

  # data
  data = config.data
  data.centered = True

  # model
  model = config.model
  model.name = 'ddpm'
  model.scale_by_sigma = False
  model.ema_rate = 0.9999
  model.normalization = 'GroupNorm'
  model.nonlinearity = 'swish'
  model.nf = 128
  model.ch_mult = (1, 2, 2, 2)
  model.num_res_blocks = 2
  model.attn_resolutions = (16,)
  model.resamp_with_conv = True
  model.conditional = False

  return config


================================================
FILE: controllable_generation_TV.py
================================================
import functools
import time

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from models import utils as mutils
from sampling import NoneCorrector, NonePredictor, shared_corrector_update_fn, shared_predictor_update_fn
from utils import fft2, ifft2, fft2_m, ifft2_m
from physics.ct import *
from utils import show_samples, show_samples_gray, clear, clear_color, batchfy



class lambda_schedule:
  def __init__(self, total=2000):
    self.total = total

  def get_current_lambda(self, i):
    pass


class lambda_schedule_linear(lambda_schedule):
  def __init__(self, start_lamb=1.0, end_lamb=0.0):
    super().__init__()
    self.start_lamb = start_lamb
    self.end_lamb = end_lamb

  def get_current_lambda(self, i):
    return self.start_lamb + (self.end_lamb - self.start_lamb) * (i / self.total)


class lambda_schedule_const(lambda_schedule):
  def __init__(self, lamb=1.0):
    super().__init__()
    self.lamb = lamb

  def get_current_lambda(self, i):
    return self.lamb


def _Dz(x): # Batch direction
    y = torch.zeros_like(x)
    y[:-1] = x[1:]
    y[-1] = x[0]
    return y - x


def _DzT(x): # Batch direction
    y = torch.zeros_like(x)
    y[:-1] = x[1:]
    y[-1] = x[0]

    tempt = -(y-x)
    difft = tempt[:-1]
    y[1:] = difft
    y[0] = x[-1] - x[0]

    return y
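`_Dz` is a circular forward difference along dim 0, and `_DzT` should be its exact adjoint, because the ADMM solver relies on `A^T A + rho * Dz^T Dz` being symmetric. The NumPy sketch below expresses both operators with `np.roll` and ports `_DzT` line by line for comparison. It then checks the adjoint identity `<Dz x, v> = <x, Dz^T v>`. The same check applies to `_Dx`/`_DxT` (axis 2) and `_Dy`/`_DyT` (axis 3).

```python
import numpy as np


def dz(x):
    # (Dz x)[i] = x[i+1] - x[i], wrapping the last slice to the first
    return np.roll(x, -1, axis=0) - x


def dz_t(x):
    # Adjoint: (Dz^T v)[i] = v[i-1] - v[i], with the same circular wrap
    return np.roll(x, 1, axis=0) - x


def dz_t_loop(x):
    # Line-by-line NumPy port of _DzT, for comparison
    y = np.zeros_like(x)
    y[:-1] = x[1:]
    y[-1] = x[0]
    tempt = -(y - x)
    y[1:] = tempt[:-1]
    y[0] = x[-1] - x[0]
    return y


rng = np.random.default_rng(0)
x = rng.standard_normal((5, 1, 4, 4))  # (slices, channels, H, W)
v = rng.standard_normal((5, 1, 4, 4))
print(np.vdot(dz(x), v), np.vdot(x, dz_t(v)))  # equal up to rounding
```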

def _Dx(x):  # Circular forward difference along dim 2 (image height)
    y = torch.zeros_like(x)
    y[:, :, :-1, :] = x[:, :, 1:, :]
    y[:, :, -1, :] = x[:, :, 0, :]
    return y - x


def _DxT(x):  # Adjoint of _Dx (dim 2)
    y = torch.zeros_like(x)
    y[:, :, :-1, :] = x[:, :, 1:, :]
    y[:, :, -1, :] = x[:, :, 0, :]
    tempt = -(y - x)
    difft = tempt[:, :, :-1, :]
    y[:, :, 1:, :] = difft
    y[:, :, 0, :] = x[:, :, -1, :] - x[:, :, 0, :]
    return y


def _Dy(x):  # Circular forward difference along dim 3 (image width)
    y = torch.zeros_like(x)
    y[:, :, :, :-1] = x[:, :, :, 1:]
    y[:, :, :, -1] = x[:, :, :, 0]
    return y - x


def _DyT(x):  # Adjoint of _Dy (dim 3)
    y = torch.zeros_like(x)
    y[:, :, :, :-1] = x[:, :, :, 1:]
    y[:, :, :, -1] = x[:, :, :, 0]
    tempt = -(y - x)
    difft = tempt[:, :, :, :-1]
    y[:, :, :, 1:] = difft
    y[:, :, :, 0] = x[:, :, :, -1] - x[:, :, :, 0]
    return y
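`CS_routine` inside `get_pc_radon_ADMM_TV` runs ADMM on `0.5 * ||A x - y||^2 + lamb_1 * ||Dz x||_1` with the split `z = Dz x`. The x-update solves `(A^T A + rho Dz^T Dz) x = A^T y + rho Dz^T (z - u)` with conjugate gradients. The z-update soft-thresholds at `lamb_1 / rho`, and the dual update accumulates `Dz x - z`. The NumPy sketch below takes `A` to be the identity (pure TV denoising of a 1-D stack of "slices") so that it runs standalone. It also assumes the repo's `shrink` is ordinary soft thresholding; all helper names here are hypothetical.

```python
import numpy as np


def dz(x):
    return np.roll(x, -1, axis=0) - x


def dz_t(x):
    return np.roll(x, 1, axis=0) - x


def soft_threshold(v, tau):
    # Proximal operator of tau * ||.||_1 (assumed to match `shrink`)
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)


def cg(apply_a, b, x, n_inner=10, tol=1e-10):
    """Conjugate gradient for apply_a(x) = b with apply_a symmetric positive definite."""
    r = b - apply_a(x)
    p = r.copy()
    rs_old = np.vdot(r, r)
    for _ in range(n_inner):
        if np.sqrt(rs_old) < tol:
            break
        ap = apply_a(p)
        a = rs_old / np.vdot(p, ap)
        x = x + a * p
        r = r - a * ap
        rs_new = np.vdot(r, r)
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


def admm_tv_z(y, lamb=1.0, rho=10.0, n_outer=200, n_inner=5):
    """ADMM for 0.5 * ||x - y||^2 + lamb * ||Dz x||_1 (A = identity)."""
    x, z, u = y.copy(), np.zeros_like(y), np.zeros_like(y)

    def apply_a(v):  # A^T A + rho * Dz^T Dz with A = I
        return v + rho * dz_t(dz(v))

    for _ in range(n_outer):
        x = cg(apply_a, y + rho * dz_t(z - u), x, n_inner=n_inner)  # warm-started x-update
        z = soft_threshold(dz(x) + u, lamb / rho)
        u = u + dz(x) - z
    return x


rng = np.random.default_rng(0)
clean = np.repeat([0.0, 2.0, -1.0, 1.0], 8)  # piecewise-constant along z
noisy = clean + 0.3 * rng.standard_normal(clean.shape)
recon = admm_tv_z(noisy)


def objective(x):
    return 0.5 * np.sum((x - noisy) ** 2) + np.abs(dz(x)).sum()


print(objective(noisy), objective(recon))
```

The sampler calls `CS_routine` with `niter=1` and `n_inner=1`, i.e. a single warm-started ADMM iteration per diffusion step. That works because consecutive diffusion iterates change only slightly.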


def get_pc_radon_ADMM_TV(sde, predictor, corrector, inverse_scaler, snr,
                         n_steps=1, probability_flow=False, continuous=False,
                         denoise=True, eps=1e-5, radon=None, save_progress=False, save_root=None,
                         final_consistency=False, img_cache=None, img_shape=None, lamb_1=5, rho=10):
    """ Predictor-corrector sampler with ADMM-TV data consistency (TV along the z/batch axis) applied after every corrector step """
    # Define predictor & corrector
    predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                            sde=sde,
                                            predictor=predictor,
                                            probability_flow=probability_flow,
                                            continuous=continuous)
    corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                            sde=sde,
                                            corrector=corrector,
                                            continuous=continuous,
                                            snr=snr,
                                            n_steps=n_steps)

    if img_cache is not None:
        img_shape = (img_shape[0] + 1, *img_shape[1:])  # one extra slot for the cached slice
    del_z = torch.zeros(img_shape)
    udel_z = torch.zeros(img_shape)
    cg_eps = 1e-10  # CG stopping tolerance; a separate name so the `eps` argument still sets the time grid

    def _A(x):
        return radon.A(x)

    def _AT(sinogram):
        return radon.AT(sinogram)

    def kaczmarz(x, x_mean, measurement=None, lamb=1.0, i=None,
                 norm_const=None):
        x = x + lamb * _AT(measurement - _A(x))/norm_const
        x_mean = x
        return x, x_mean

    def A_cg(x):
        return _AT(_A(x)) + rho * _DzT(_Dz(x))

    def CG(A_fn, b_cg, x, n_inner=10):
        r = b_cg - A_fn(x)
        p = r
        rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T)

        for i in range(n_inner):
            Ap = A_fn(p)
            a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)

            # Out-of-place updates: p initially aliases r, so `r -= ...` would also overwrite p
            x = x + a * p
            r = r - a * Ap

            rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T)
            if torch.sqrt(rs_new) < cg_eps:
                break
            p = r + (rs_new / rs_old) * p
            rs_old = rs_new
        return x

    def CS_routine(x,ATy, niter=20):
        if img_cache is not None:
            x = torch.cat([img_cache, x], dim=0)
            idx = list(range(len(x) - 1, -1, -1))  # reverse the slice order (valid indices 0..len(x)-1)
            x = x[idx]

        nonlocal del_z, udel_z
        if del_z.device != x.device :
            del_z = del_z.to(x.device)
            udel_z = udel_z.to(x.device)
        for i in range(niter):
            b_cg = ATy + rho * (_DzT(del_z)-_DzT(udel_z))
            x = CG(A_cg, b_cg, x, n_inner=1)

            del_z = shrink(_Dz(x) + udel_z, lamb_1/rho)
            udel_z = _Dz(x) - del_z + udel_z
        if img_cache is not None:
            x = x[idx]
            x = x[1:]
            del_z[-1] = 0
            udel_z[-1] = 0
        x_mean = x
        return x, x_mean

    def get_update_fn(update_fn):
        def radon_update_fn(model, data, x, t):
            with torch.no_grad():
                vec_t = torch.ones(data.shape[0], device=data.device) * t
                x, x_mean = update_fn(x, vec_t, model=model)
                return x, x_mean
        return radon_update_fn

    def get_corrector_update_fn(update_fn):
        def radon_update_fn(model, data, x, t, measurement=None):
            with torch.no_grad():
                vec_t = torch.ones(data.shape[0], device=data.device) * t
                x, x_mean = update_fn(x, vec_t, model=model)
                ATy = _AT(measurement)
                x, x_mean = CS_routine(x, ATy, niter=1)
                return x, x_mean
        return radon_update_fn

    predictor_denoise_update_fn = get_update_fn(predictor_update_fn)
    corrector_radon_update_fn = get_corrector_update_fn(corrector_update_fn)

    def pc_radon(model, data, measurement=None):
        with torch.no_grad():
            x = sde.prior_sampling(data.shape).to(data.device)

            ones = torch.ones_like(x).to(data.device)
            norm_const = _AT(_A(ones))
            timesteps = torch.linspace(sde.T, eps, sde.N)
            for i in tqdm(range(sde.N)):
                t = timesteps[i]
                x, x_mean = predictor_denoise_update_fn(model, data, x, t)
                x, x_mean = corrector_radon_update_fn(model, data, x, t, measurement=measurement)
                if save_progress:
                    if (i % 50) == 0:
                        print(f'iter: {i}/{sde.N}')
                        plt.imsave(save_root / 'recon' / f'progress{i}.png', clear(x_mean[0:1]), cmap='gray')
            # Optional final step: one normalized Kaczmarz-style data-consistency update,
            # which pushes the residual Ax - y toward zero (it does not enforce Ax = y exactly)
            if final_consistency:
                x, x_mean = kaczmarz(x, x_mean, measurement, lamb=1.0, norm_const=norm_const)

            return inverse_scaler(x_mean if denoise else x)

    return pc_radon
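The `final_consistency` option applies `kaczmarz` once: `x + lamb * A^T (y - A x) / (A^T A 1)`. Despite the name, this is a simultaneous, SIRT-like update rather than row-by-row Kaczmarz. For a nonnegative system matrix, such as a discretized Radon transform, the residual norm cannot increase under this step with `lamb = 1`. It shrinks the residual but does not drive it to zero in one step. The dense-matrix NumPy sketch below illustrates this; `normalized_update` and the random matrix are hypothetical stand-ins for `radon.A` / `radon.AT`.

```python
import numpy as np


def normalized_update(x, A, y, lamb=1.0):
    """x + lamb * A^T (y - A x) / (A^T A 1), the form used by `kaczmarz`."""
    norm_const = A.T @ (A @ np.ones(A.shape[1]))
    return x + lamb * A.T @ (y - A @ x) / norm_const


rng = np.random.default_rng(0)
A = rng.random((30, 20))  # nonnegative, like a projection matrix
x_true = rng.random(20)
y = A @ x_true
x = np.zeros(20)
residuals = [np.linalg.norm(A @ x - y)]
for _ in range(5):
    x = normalized_update(x, A, y)
    residuals.append(np.linalg.norm(A @ x - y))
print(residuals)  # non-increasing, but not zero after one step
```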


def get_pc_radon_ADMM_TV_vol(sde, predictor, corrector, inverse_scaler, snr,
                             n_steps=1, probability_flow=False, continuous=False,
                             denoise=True, eps=1e-5, radon=None, save_progress=False, save_root=None,
                             final_consistency=False, img_shape=None, lamb_1=5, rho=10):
    """ Predictor-corrector sampler with ADMM-TV data consistency (TV along the z/batch axis) applied after every corrector step """
    # Define predictor & corrector
    predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                            sde=sde,
                                            predictor=predictor,
                                            probability_flow=probability_flow,
                                            continuous=continuous)
    corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                            sde=sde,
                                            corrector=corrector,
                                            continuous=continuous,
                                            snr=snr,
                                            n_steps=n_steps)

    del_z = torch.zeros(img_shape)
    udel_z = torch.zeros(img_shape)
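    eps = 1e-10  # CG tolerance; note this rebinds the 'eps' argument, so the time grid in pc_radon also ends at 1e-10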

    def _A(x):
        return radon.A(x)

    def _AT(sinogram):
        return radon.AT(sinogram)

    def kaczmarz(x, x_mean, measurement=None, lamb=1.0, i=None,
                 norm_const=None):
        x = x + lamb * _AT(measurement - _A(x)) / norm_const
        x_mean = x
        return x, x_mean

    def A_cg(x):
        return _AT(_A(x)) + rho * _DzT(_Dz(x))

    def CG(A_fn, b_cg, x, n_inner=10):
        r = b_cg - A_fn(x)
        p = r.clone()  # clone: r is updated in place below and must not alias p
        rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T)

        for i in range(n_inner):
            Ap = A_fn(p)
            a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)

            x += a * p
            r -= a * Ap

            rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T)
            if torch.sqrt(rs_new) < eps:
                break
            p = r + (rs_new / rs_old) * p
            rs_old = rs_new
        return x

    def CS_routine(x, ATy, niter=20):
        nonlocal del_z, udel_z
        if del_z.device != x.device:
            del_z = del_z.to(x.device)
            udel_z = udel_z.to(x.device)
        for i in range(niter):
            b_cg = ATy + rho * (_DzT(del_z) - _DzT(udel_z))
            x = CG(A_cg, b_cg, x, n_inner=1)

            del_z = shrink(_Dz(x) + udel_z, lamb_1 / rho)
            udel_z = _Dz(x) - del_z + udel_z
        x_mean = x
        return x, x_mean

    def get_update_fn(update_fn):
        def radon_update_fn(model, data, x, t):
            with torch.no_grad():
                vec_t = torch.ones(x.shape[0], device=x.device) * t
                x, x_mean = update_fn(x, vec_t, model=model)
                return x, x_mean

        return radon_update_fn

    def get_ADMM_TV_fn():
        def ADMM_TV_fn(x, measurement=None):
            with torch.no_grad():
                ATy = _AT(measurement)
                x, x_mean = CS_routine(x, ATy, niter=1)
                return x, x_mean
        return ADMM_TV_fn

    predictor_denoise_update_fn = get_update_fn(predictor_update_fn)
    corrector_denoise_update_fn = get_update_fn(corrector_update_fn)
    mc_update_fn = get_ADMM_TV_fn()

    def pc_radon(model, data, measurement=None):
        with torch.no_grad():
            x = sde.prior_sampling(data.shape).to(data.device)

            ones = torch.ones_like(x).to(data.device)
            norm_const = _AT(_A(ones))
            timesteps = torch.linspace(sde.T, eps, sde.N)
            for i in tqdm(range(sde.N)):
                t = timesteps[i]
                # 1. batchify into sizes that fit into the GPU
                x_batch = batchfy(x, 12)
                # 2. Run PC step for each batch
                x_agg = list()
                for idx, x_batch_sing in enumerate(x_batch):
                    x_batch_sing, _ = predictor_denoise_update_fn(model, data, x_batch_sing, t)
                    x_batch_sing, _ = corrector_denoise_update_fn(model, data, x_batch_sing, t)
                    x_agg.append(x_batch_sing)
                # 3. Aggregate to run ADMM TV
                x = torch.cat(x_agg, dim=0)
                # 4. Run ADMM TV
                x, x_mean = mc_update_fn(x, measurement=measurement)

                if save_progress:
                    if (i % 50) == 0:
                        print(f'iter: {i}/{sde.N}')
                        plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x_mean[0:1]), cmap='gray')
            # Optional final step: one normalized back-projection (Kaczmarz-style) update
            # that pushes x toward satisfying the data consistency Ax = y
            if final_consistency:
                x, x_mean = kaczmarz(x, x, measurement, lamb=1.0, norm_const=norm_const)

            return inverse_scaler(x_mean if denoise else x)

    return pc_radon
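`CS_routine` solves the x-subproblem (A^T A + rho D_z^T D_z) x = A^T y + rho D_z^T (z - u) with a few conjugate-gradient iterations on `A_cg`. The sketch below uses NumPy and small dense matrices in place of the Radon and difference operators (assumptions for illustration); `conjugate_gradient` is a hypothetical name. It copies `r` into `p`, because the in-place `r -= ...` update would otherwise also modify `p` on the first iteration.

```python
import numpy as np


def conjugate_gradient(apply_M, b, x, n_inner=10, tol=1e-10):
    """Solve M x = b for symmetric positive-definite M, given only v -> M v."""
    x = x.astype(float).copy()
    r = b - apply_M(x)
    p = r.copy()  # separate buffer: r is updated in place below
    rs_old = r @ r
    for _ in range(n_inner):
        Mp = apply_M(p)
        alpha = rs_old / (p @ Mp)
        x += alpha * p
        r -= alpha * Mp
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


# Toy stand-ins: A for the forward operator, D for a circular difference along one axis.
A = np.array([[1.0, 2.0, 0.0], [0.0, 1.0, 1.0], [1.0, 0.0, 1.0]])
D = np.roll(np.eye(3), 1, axis=1) - np.eye(3)
rho = 0.5
M = A.T @ A + rho * D.T @ D
b = A.T @ np.array([1.0, 2.0, 3.0])
x = conjugate_gradient(lambda v: M @ v, b, np.zeros(3))
```

In exact arithmetic CG converges in at most n iterations for an n-dimensional SPD system. The sampler calls it with `n_inner=1`, so each ADMM step only takes one CG step.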


def get_pc_radon_ADMM_TV_all_vol(sde, predictor, corrector, inverse_scaler, snr,
                             n_steps=1, probability_flow=False, continuous=False,
                             denoise=True, eps=1e-5, radon=None, save_progress=False, save_root=None,
                             final_consistency=False, img_shape=None, lamb_1=5, rho=10):
    """ PC sampling with ADMM-TV measurement consistency (TV along x, y and z) applied after every PC step """
    # Define predictor & corrector
    predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                            sde=sde,
                                            predictor=predictor,
                                            probability_flow=probability_flow,
                                            continuous=continuous)
    corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                            sde=sde,
                                            corrector=corrector,
                                            continuous=continuous,
                                            snr=snr,
                                            n_steps=n_steps)

    del_x = torch.zeros(img_shape)
    del_y = torch.zeros(img_shape)
    del_z = torch.zeros(img_shape)
    udel_x = torch.zeros(img_shape)
    udel_y = torch.zeros(img_shape)
    udel_z = torch.zeros(img_shape)
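    eps = 1e-10  # CG tolerance; note this rebinds the 'eps' argument, so the time grid in pc_radon also ends at 1e-10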

    def _A(x):
        return radon.A(x)

    def _AT(sinogram):
        return radon.AT(sinogram)

    def kaczmarz(x, x_mean, measurement=None, lamb=1.0, i=None,
                 norm_const=None):
        x = x + lamb * _AT(measurement - _A(x)) / norm_const
        x_mean = x
        return x, x_mean


    def A_cg(x):
        return _AT(_A(x)) + rho * (_DxT(_Dx(x)) + _DyT(_Dy(x)) + _DzT(_Dz(x)))

    def CG(A_fn, b_cg, x, n_inner=10):
        r = b_cg - A_fn(x)
        p = r.clone()  # clone: r is updated in place below and must not alias p
        rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T)

        for i in range(n_inner):
            Ap = A_fn(p)
            a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)

            x += a * p
            r -= a * Ap

            rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T)
            if torch.sqrt(rs_new) < eps:
                break
            p = r + (rs_new / rs_old) * p
            rs_old = rs_new
        return x

    def CS_routine(x, ATy, niter=20):
        nonlocal del_x, del_y, del_z, udel_x, udel_y, udel_z
        if del_z.device != x.device:
            del_x = del_x.to(x.device)
            del_y = del_y.to(x.device)
            del_z = del_z.to(x.device)
            udel_x = udel_x.to(x.device)
            udel_y = udel_y.to(x.device)
            udel_z = udel_z.to(x.device)
        for i in range(niter):
            b_cg = ATy + rho * ((_DxT(del_x) - _DxT(udel_x))
                                + (_DyT(del_y) - _DyT(udel_y))
                                + (_DzT(del_z) - _DzT(udel_z)))
            x = CG(A_cg, b_cg, x, n_inner=1)

            del_x = shrink(_Dx(x) + udel_x, lamb_1 / rho)
            del_y = shrink(_Dy(x) + udel_y, lamb_1 / rho)
            del_z = shrink(_Dz(x) + udel_z, lamb_1 / rho)
            udel_x = _Dx(x) - del_x + udel_x
            udel_y = _Dy(x) - del_y + udel_y
            udel_z = _Dz(x) - del_z + udel_z
        x_mean = x
        return x, x_mean

    def get_update_fn(update_fn):
        def radon_update_fn(model, data, x, t):
            with torch.no_grad():
                vec_t = torch.ones(x.shape[0], device=x.device) * t
                x, x_mean = update_fn(x, vec_t, model=model)
                return x, x_mean

        return radon_update_fn

    def get_ADMM_TV_fn():
        def ADMM_TV_fn(x, measurement=None):
            with torch.no_grad():
                ATy = _AT(measurement)
                x, x_mean = CS_routine(x, ATy, niter=1)
                return x, x_mean
        return ADMM_TV_fn

    predictor_denoise_update_fn = get_update_fn(predictor_update_fn)
    corrector_denoise_update_fn = get_update_fn(corrector_update_fn)
    mc_update_fn = get_ADMM_TV_fn()

    def pc_radon(model, data, measurement=None):
        with torch.no_grad():
            x = sde.prior_sampling(data.shape).to(data.device)

            ones = torch.ones_like(x).to(data.device)
            norm_const = _AT(_A(ones))
            timesteps = torch.linspace(sde.T, eps, sde.N)
            for i in tqdm(range(sde.N)):
                t = timesteps[i]
                # 1. batchify into sizes that fit into the GPU
                x_batch = batchfy(x, 12)
                # 2. Run PC step for each batch
                x_agg = list()
                for idx, x_batch_sing in enumerate(x_batch):
                    x_batch_sing, _ = predictor_denoise_update_fn(model, data, x_batch_sing, t)
                    x_batch_sing, _ = corrector_denoise_update_fn(model, data, x_batch_sing, t)
                    x_agg.append(x_batch_sing)
                # 3. Aggregate to run ADMM TV
                x = torch.cat(x_agg, dim=0)
                # 4. Run ADMM TV
                x, x_mean = mc_update_fn(x, measurement=measurement)

                if save_progress:
                    if (i % 50) == 0:
                        print(f'iter: {i}/{sde.N}')
                        plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x_mean[0:1]), cmap='gray')
            # Optional final step: one normalized back-projection (Kaczmarz-style) update
            # that pushes x toward satisfying the data consistency Ax = y
            if final_consistency:
                x, x_mean = kaczmarz(x, x, measurement, lamb=1.0, norm_const=norm_const)

            return inverse_scaler(x_mean if denoise else x)

    return pc_radon
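Each sampling step splits the volume into chunks of 12 slices, runs the predictor and corrector on each chunk, and concatenates the chunks before the ADMM-TV update. The `batchfy` helper is not shown in this section. The sketch below is a hypothetical NumPy stand-in (`batchify`) that assumes the helper splits along the first (slice) axis and preserves order.

```python
import numpy as np


def batchify(x, batch_size):
    """Hypothetical stand-in: split x along axis 0 into consecutive chunks of at most batch_size."""
    return [x[i:i + batch_size] for i in range(0, x.shape[0], batch_size)]


volume = np.random.rand(30, 1, 8, 8)  # [Z, 1, H, W] toy volume with 30 slices
chunks = batchify(volume, 12)
print([c.shape[0] for c in chunks])  # [12, 12, 6]
restored = np.concatenate(chunks, axis=0)  # what torch.cat(x_agg, dim=0) does per step
```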



def get_ADMM_TV(eps=1e-5, radon=None, save_progress=False, save_root=None,
                img_shape=None, lamb_1=5, rho=10, outer_iter=30, inner_iter=20):

    del_x = torch.zeros(img_shape)
    del_y = torch.zeros(img_shape)
    del_z = torch.zeros(img_shape)
    udel_x = torch.zeros(img_shape)
    udel_y = torch.zeros(img_shape)
    udel_z = torch.zeros(img_shape)
    eps = 1e-10

    def _A(x):
        return radon.A(x)

    def _AT(sinogram):
        return radon.AT(sinogram)

    def A_cg(x):
        return _AT(_A(x)) + rho * (_DxT(_Dx(x)) + _DyT(_Dy(x)) + _DzT(_Dz(x)))

    def CG(A_fn, b_cg, x, n_inner=20):
        r = b_cg - A_fn(x)
        p = r.clone()  # clone: r is updated in place below and must not alias p
        rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T)

        for i in range(n_inner):
            Ap = A_fn(p)
            a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)

            x += a * p
            r -= a * Ap

            rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T)
            if torch.sqrt(rs_new) < eps:
                break
            p = r + (rs_new / rs_old) * p
            rs_old = rs_new
        return x

    def CS_routine(x, ATy, niter=30):
        nonlocal del_x, del_y, del_z, udel_x, udel_y, udel_z
        if del_z.device != x.device:
            del_x = del_x.to(x.device)
            del_y = del_y.to(x.device)
            del_z = del_z.to(x.device)
            udel_x = udel_x.to(x.device)
            udel_y = udel_y.to(x.device)
            udel_z = udel_z.to(x.device)
        for i in tqdm(range(niter)):
            b_cg = ATy + rho * ((_DxT(del_x) - _DxT(udel_x))
                                + (_DyT(del_y) - _DyT(udel_y))
                                + (_DzT(del_z) - _DzT(udel_z)))
            x = CG(A_cg, b_cg, x, n_inner=inner_iter)
            if save_progress:
                plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x[0:1]), cmap='gray')

            del_x = shrink(_Dx(x) + udel_x, lamb_1 / rho)
            del_y = shrink(_Dy(x) + udel_y, lamb_1 / rho)
            del_z = shrink(_Dz(x) + udel_z, lamb_1 / rho)
            udel_x = _Dx(x) - del_x + udel_x
            udel_y = _Dy(x) - del_y + udel_y
            udel_z = _Dz(x) - del_z + udel_z
        return x

    def get_ADMM_TV_fn():
        def ADMM_TV_fn(x, measurement=None):
            with torch.no_grad():
                ATy = _AT(measurement)
                # CS_routine returns a single tensor here, so do not unpack it
                x = CS_routine(x, ATy, niter=outer_iter)
                return x
        return ADMM_TV_fn

    mc_update_fn = get_ADMM_TV_fn()

    def ADMM_TV(data, measurement=None):
        with torch.no_grad():
            x = torch.zeros(data.shape).to(data.device)
            x = mc_update_fn(x, measurement=measurement)
            return x

    return ADMM_TV
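`get_ADMM_TV` alternates three updates: a CG solve for x, soft-thresholding of the finite differences (`del_*`), and a scaled dual update (`udel_*`). The dense 1-D NumPy sketch below uses the same three updates under several hypothetical simplifications: the operator is a small matrix, differences are circular, and the x-step is solved exactly with `np.linalg.solve` instead of CG.

```python
import numpy as np


def soft_threshold(v, t):
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def admm_tv_1d(A, y, lamb=0.1, rho=1.0, n_iter=300):
    """min_x 0.5 * ||A x - y||^2 + lamb * ||D x||_1 via scaled ADMM with the split z = D x."""
    n = A.shape[1]
    D = np.roll(np.eye(n), 1, axis=1) - np.eye(n)  # circular forward difference: (Dx)[k] = x[k+1] - x[k]
    M = A.T @ A + rho * D.T @ D
    ATy = A.T @ y
    z = np.zeros(n)  # plays the role of del_*
    u = np.zeros(n)  # plays the role of udel_*
    x = np.zeros(n)
    for _ in range(n_iter):
        x = np.linalg.solve(M, ATy + rho * D.T @ (z - u))  # exact x-step (CG in the code above)
        z = soft_threshold(D @ x + u, lamb / rho)
        u = u + D @ x - z
    return x


def tv_objective(A, x, y, lamb):
    return 0.5 * np.sum((A @ x - y) ** 2) + lamb * np.sum(np.abs(np.roll(x, -1) - x))


y = np.array([0.0, 0.1, 0.0, 1.0, 0.9, 1.0])  # noisy step signal
A = np.eye(6)                                  # denoising: identity operator
x = admm_tv_1d(A, y, lamb=0.1)
print(tv_objective(A, y, y, 0.1), tv_objective(A, x, y, 0.1))
```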


def get_ADMM_TV_isotropic(eps=1e-5, radon=None, save_progress=False, save_root=None,
                          img_shape=None, lamb_1=5, rho=10, outer_iter=30, inner_iter=20):
    """
    Isotropic TV-ADMM. Unlike get_ADMM_TV, which regularizes the anisotropic TV
    (sum of |Dx| + |Dy| + |Dz|), this function penalizes the l_{2,1} norm of the
    gradient field, i.e. the sum over voxels of sqrt(Dx^2 + Dy^2 + Dz^2).
    """
    del_x = torch.zeros(img_shape)
    del_y = torch.zeros(img_shape)
    del_z = torch.zeros(img_shape)
    udel_x = torch.zeros(img_shape)
    udel_y = torch.zeros(img_shape)
    udel_z = torch.zeros(img_shape)
    eps = 1e-10

    def _A(x):
        return radon.A(x)

    def _AT(sinogram):
        return radon.AT(sinogram)

    def A_cg(x):
        return _AT(_A(x)) + rho * (_DxT(_Dx(x)) + _DyT(_Dy(x)) + _DzT(_Dz(x)))

    
    def CG(A_fn, b_cg, x, n_inner=20):
        r = b_cg - A_fn(x)
        p = r.clone()  # clone: r is updated in place below and must not alias p
        rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T)

        for i in range(n_inner):
            Ap = A_fn(p)
            a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)

            x += a * p
            r -= a * Ap

            rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T)
            if torch.sqrt(rs_new) < eps:
                break
            p = r + (rs_new / rs_old) * p
            rs_old = rs_new
        return x

    def CS_routine(x, ATy, niter=30):
        nonlocal del_x, del_y, del_z, udel_x, udel_y, udel_z
        if del_z.device != x.device:
            del_x = del_x.to(x.device)
            del_y = del_y.to(x.device)
            del_z = del_z.to(x.device)
            udel_x = udel_x.to(x.device)
            udel_y = udel_y.to(x.device)
            udel_z = udel_z.to(x.device)
        for i in tqdm(range(niter)):
            b_cg = ATy + rho * ((_DxT(del_x) - _DxT(udel_x))
                                + (_DyT(del_y) - _DyT(udel_y))
                                + (_DzT(del_z) - _DzT(udel_z)))
            x = CG(A_cg, b_cg, x, n_inner=inner_iter)
            if save_progress:
                plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x[0:1]), cmap='gray')

            # Each of shape [448, 1, 256, 256]
            _Dxx = _Dx(x)
            _Dyx = _Dy(x)
            _Dzx = _Dz(x)
            # shape [448, 3, 256, 256]; dim=1 indexes the gradient direction (x, y, z)
            _Dxa = torch.cat((_Dxx, _Dyx, _Dzx), dim=1)
            udel_a = torch.cat((udel_x, udel_y, udel_z), dim=1)

            # prox
            del_a = prox_l21(_Dxa + udel_a, lamb_1 / rho, dim=1)

            # split
            del_x, del_y, del_z = torch.split(del_a, 1, dim=1)

            # del_x = prox_l21(_Dxx + udel_x, lamb_1 / rho, -2)
            # del_y = prox_l21(_Dyx + udel_y, lamb_1 / rho, -1)
            # del_z = prox_l21(_Dzx + udel_z, lamb_1 / rho, 0)

            udel_x = _Dxx - del_x + udel_x
            udel_y = _Dyx - del_y + udel_y
            udel_z = _Dzx - del_z + udel_z
        return x

    def get_ADMM_TV_fn():
        def ADMM_TV_fn(x, measurement=None):
            with torch.no_grad():
                ATy = _AT(measurement)
                x = CS_routine(x, ATy, niter=outer_iter)
                return x
        return ADMM_TV_fn

    mc_update_fn = get_ADMM_TV_fn()

    def ADMM_TV(data, measurement=None):
        with torch.no_grad():
            x = torch.zeros(data.shape).to(data.device)
            x = mc_update_fn(x, measurement=measurement)
            return x

    return ADMM_TV
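The anisotropic and isotropic solvers differ only in the shrinkage (`del_*`) step. The anisotropic version soft-thresholds each difference separately. The isotropic version shrinks the 3-vector of differences at each voxel jointly. The NumPy sketch below compares the two penalties on a toy volume. It assumes circular forward differences, since the `_Dx`, `_Dy`, `_Dz` helpers used above are not shown in this section; `tv_norms` is a hypothetical name.

```python
import numpy as np


def tv_norms(vol):
    """Return (anisotropic, isotropic) TV of a 3-D array, using circular forward differences."""
    g = np.stack([np.roll(vol, -1, axis=a) - vol for a in range(vol.ndim)], axis=0)
    aniso = np.abs(g).sum()                    # sum over voxels of |dz| + |dx| + |dy|
    iso = np.sqrt((g ** 2).sum(axis=0)).sum()  # sum over voxels of sqrt(dz^2 + dx^2 + dy^2)
    return aniso, iso


vol = np.zeros((2, 2, 2))
vol[0, 0, 0] = 1.0
aniso, iso = tv_norms(vol)  # aniso = 6, iso = 3 + sqrt(3)
```

The isotropic penalty is never larger than the anisotropic one, because the l2 norm of a vector is at most its l1 norm.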

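def prox_l21(src, lamb, dim):
    """
    Proximal operator of lamb * ||.||_{2,1}: group soft-thresholding of `src`,
    where each group is the vector along `dim`.
    e.g. src.shape = [448(z), 3, 256(x), 256(y)] with dim=1 (gradient direction)
    """
    weight_src = torch.linalg.norm(src, dim=dim, keepdim=True)
    weight_src_shrink = shrink(weight_src, lamb)

    # clamp avoids 0/0 = NaN where a group norm is exactly zero (the group stays zero)
    weight = weight_src_shrink / torch.clamp(weight_src, min=1e-12)
    return src * weight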


def shrink(weight_src, lamb):
    return torch.sign(weight_src) * torch.max(torch.abs(weight_src) - lamb, torch.zeros_like(weight_src))
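`shrink` is elementwise soft-thresholding, and `prox_l21` applies the same shrinkage to the l2 norm of each group and then rescales the group. The NumPy sketch below mirrors both functions under hypothetical names (`soft_threshold`, `group_soft_threshold`). It guards the division so that an all-zero group stays zero instead of producing NaN.

```python
import numpy as np


def soft_threshold(v, lamb):
    """Elementwise shrinkage: the proximal operator of lamb * ||v||_1."""
    return np.sign(v) * np.maximum(np.abs(v) - lamb, 0.0)


def group_soft_threshold(src, lamb, axis, tiny=1e-12):
    """Proximal operator of lamb * (sum of l2 norms taken along `axis`)."""
    norm = np.linalg.norm(src, axis=axis, keepdims=True)
    scale = soft_threshold(norm, lamb) / np.maximum(norm, tiny)  # zero group -> scale 0, no NaN
    return src * scale


v = np.array([-3.0, -0.5, 0.5, 3.0])
g = np.array([[3.0, 0.0],
              [4.0, 0.0]])  # two groups along axis 0 with norms 5 and 0
shrunk = soft_threshold(v, 1.0)                    # [-2, 0, 0, 2]
grouped = group_soft_threshold(g, 1.0, axis=0)     # first group scaled by 4/5, second stays 0
```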


def get_pc_radon_ADMM_TV_mri(sde, predictor, corrector, inverse_scaler, snr, mask=None,
                             n_steps=1, probability_flow=False, continuous=False,
                             denoise=True, eps=1e-5, save_progress=False, save_root=None,
                             img_shape=None, lamb_1=5, rho=10):
    predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                            sde=sde,
                                            predictor=predictor,
                                            probability_flow=probability_flow,
                                            continuous=continuous)
    corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                            sde=sde,
                                            corrector=corrector,
                                            continuous=continuous,
                                            snr=snr,
                                            n_steps=n_steps)

    del_z = torch.zeros(img_shape)
    udel_z = torch.zeros(img_shape)
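    eps = 1e-10  # CG tolerance; note this rebinds the 'eps' argument, so the time grid in pc_radon also ends at 1e-10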

    def _A(x):
        return fft2(x) * mask

    def _AT(kspace):
        return torch.real(ifft2(kspace))

    def _Dz(x):  # Batch direction
        y = torch.zeros_like(x)
        y[:-1] = x[1:]
        y[-1] = x[0]
        return y - x

    def _DzT(x):  # Batch direction
        y = torch.zeros_like(x)
        y[:-1] = x[1:]
        y[-1] = x[0]

        tempt = -(y - x)
        difft = tempt[:-1]
        y[1:] = difft
        y[0] = x[-1] - x[0]

        return y

    def A_cg(x):
        return _AT(_A(x)) + rho * _DzT(_Dz(x))

    def shrink(src, lamb):
        return torch.sign(src) * torch.max(torch.abs(src) - lamb, torch.zeros_like(src))

    def CG(A_fn, b_cg, x, n_inner=10):
        r = b_cg - A_fn(x)
        p = r.clone()  # clone: r is updated in place below and must not alias p
        rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T)

        for i in range(n_inner):
            Ap = A_fn(p)
            a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)

            x += a * p
            r -= a * Ap

            rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T)
            if torch.sqrt(rs_new) < eps:
                break
            p = r + (rs_new / rs_old) * p
            rs_old = rs_new
        return x

    def CS_routine(x, ATy, niter=20):
        nonlocal del_z, udel_z
        if del_z.device != x.device:
            del_z = del_z.to(x.device)
            udel_z = udel_z.to(x.device)
        for i in range(niter):
            b_cg = ATy + rho * (_DzT(del_z) - _DzT(udel_z))
            x = CG(A_cg, b_cg, x, n_inner=1)

            del_z = shrink(_Dz(x) + udel_z, lamb_1 / rho)
            udel_z = _Dz(x) - del_z + udel_z
        x_mean = x
        return x, x_mean

    def get_update_fn(update_fn):
        def radon_update_fn(model, data, x, t):
            with torch.no_grad():
                vec_t = torch.ones(x.shape[0], device=x.device) * t
                x, x_mean = update_fn(x, vec_t, model=model)
                return x, x_mean

        return radon_update_fn

    def get_ADMM_TV_fn():
        def ADMM_TV_fn(x, measurement=None):
            with torch.no_grad():
                ATy = _AT(measurement)
                x, x_mean = CS_routine(x, ATy, niter=1)
                return x, x_mean
        return ADMM_TV_fn

    predictor_denoise_update_fn = get_update_fn(predictor_update_fn)
    corrector_denoise_update_fn = get_update_fn(corrector_update_fn)
    mc_update_fn = get_ADMM_TV_fn()

    def pc_radon(model, data, measurement=None):
        with torch.no_grad():
            x = sde.prior_sampling(data.shape).to(data.device)
            timesteps = torch.linspace(sde.T, eps, sde.N)
            for i in tqdm(range(sde.N)):
                t = timesteps[i]
                # 1. batchify into sizes that fit into the GPU
                x_batch = batchfy(x, 20)
                # 2. Run PC step for each batch
                x_agg = list()
                for idx, x_batch_sing in enumerate(x_batch):
                    x_batch_sing, _ = predictor_denoise_update_fn(model, data, x_batch_sing, t)
                    x_batch_sing, _ = corrector_denoise_update_fn(model, data, x_batch_sing, t)
                    x_agg.append(x_batch_sing)
                # 3. Aggregate to run ADMM TV
                x = torch.cat(x_agg, dim=0)
                # 4. Run ADMM TV
                x, x_mean = mc_update_fn(x, measurement=measurement)

                if save_progress:
                    if (i % 50) == 0:
                        print(f'iter: {i}/{sde.N}')
                        plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x_mean[0:1]), cmap='gray')

            return inverse_scaler(x_mean if denoise else x)

    return pc_radon
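`A_cg` relies on `_DzT` being the exact adjoint of `_Dz`. With circular boundaries, `_Dz` computes y[k] = x[k+1] - x[k], and `_DzT` computes y[k] = x[k-1] - x[k]. The dot-product test below checks that pair after re-expressing the two helpers in NumPy for illustration (`dz` and `dz_T` are hypothetical names).

```python
import numpy as np


def dz(x):
    """Circular forward difference along axis 0: y[k] = x[k+1] - x[k] (same as _Dz)."""
    return np.roll(x, -1, axis=0) - x


def dz_T(x):
    """Adjoint of dz: y[k] = x[k-1] - x[k] (same as _DzT)."""
    return np.roll(x, 1, axis=0) - x


rng = np.random.default_rng(0)
a = rng.standard_normal((5, 1, 4, 4))  # [slices, 1, H, W]
b = rng.standard_normal((5, 1, 4, 4))
lhs = np.sum(dz(a) * b)    # <Dz a, b>
rhs = np.sum(a * dz_T(b))  # <a, Dz^T b>
```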

================================================
FILE: datasets.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Return training and evaluation/test datasets from config files."""
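from torch.utils.data import Dataset, DataLoader
import numpy as np

try:
  # TensorFlow / TFDS / JAX are only needed by the tf.data pipeline (get_dataset and its helpers)
  import tensorflow as tf
  import tensorflow_datasets as tfds
  import jax
except ImportError:
  tf = tfds = jax = None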


def get_data_scaler(config):
  """Data normalizer. Assume data are always in [0, 1]."""
  if config.data.centered:
    # Rescale to [-1, 1]
    return lambda x: x * 2. - 1.
  else:
    return lambda x: x


def get_data_inverse_scaler(config):
  """Inverse data normalizer."""
  if config.data.centered:
    # Rescale [-1, 1] to [0, 1]
    return lambda x: (x + 1.) / 2.
  else:
    return lambda x: x


def crop_resize(image, resolution):
  """Crop and resize an image to the given resolution."""
  crop = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
  h, w = tf.shape(image)[0], tf.shape(image)[1]
  image = image[(h - crop) // 2:(h + crop) // 2,
          (w - crop) // 2:(w + crop) // 2]
  image = tf.image.resize(
    image,
    size=(resolution, resolution),
    antialias=True,
    method=tf.image.ResizeMethod.BICUBIC)
  return tf.cast(image, tf.uint8)


def resize_small(image, resolution):
  """Shrink an image to the given resolution."""
  h, w = image.shape[0], image.shape[1]
  ratio = resolution / min(h, w)
  h = tf.cast(tf.round(h * ratio), tf.int32)
  w = tf.cast(tf.round(w * ratio), tf.int32)
  return tf.image.resize(image, [h, w], antialias=True)


def central_crop(image, size):
  """Crop the center of an image to the given size."""
  top = (image.shape[0] - size) // 2
  left = (image.shape[1] - size) // 2
  return tf.image.crop_to_bounding_box(image, top, left, size, size)


def get_dataset(config, uniform_dequantization=False, evaluation=False):
  """Create data loaders for training and evaluation.

  Args:
    config: A ml_collection.ConfigDict parsed from config files.
    uniform_dequantization: If `True`, add uniform dequantization to images.
    evaluation: If `True`, fix number of epochs to 1.

  Returns:
    train_ds, eval_ds, dataset_builder.
  """
  # Compute batch size for this worker.
  batch_size = config.training.batch_size if not evaluation else config.eval.batch_size
  if batch_size % jax.device_count() != 0:
    raise ValueError(f'Batch size ({batch_size}) must be divisible by '
                     f'the number of devices ({jax.device_count()})')

  # Reduce this when image resolution is too large and data pointer is stored
  shuffle_buffer_size = 10000
  prefetch_size = tf.data.experimental.AUTOTUNE
  num_epochs = None if not evaluation else 1

  # Create dataset builders for each dataset.
  if config.data.dataset == 'CIFAR10':
    dataset_builder = tfds.builder('cifar10')
    train_split_name = 'train'
    eval_split_name = 'test'

    def resize_op(img):
      img = tf.image.convert_image_dtype(img, tf.float32)
      # Added to train grayscale models
      # img = tf.image.rgb_to_grayscale(img)
      return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True)


  elif config.data.dataset == 'SVHN':
    dataset_builder = tfds.builder('svhn_cropped')
    train_split_name = 'train'
    eval_split_name = 'test'

    def resize_op(img):
      img = tf.image.convert_image_dtype(img, tf.float32)
      return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True)

  elif config.data.dataset == 'CELEBA':
    dataset_builder = tfds.builder('celeb_a')
    train_split_name = 'train'
    eval_split_name = 'validation'

    def resize_op(img):
      img = tf.image.convert_image_dtype(img, tf.float32)
      img = central_crop(img, 140)
      img = resize_small(img, config.data.image_size)
      return img

  elif config.data.dataset == 'LSUN':
    dataset_builder = tfds.builder(f'lsun/{config.data.category}')
    train_split_name = 'train'
    eval_split_name = 'validation'

    if config.data.image_size == 128:
      def resize_op(img):
        img = tf.image.convert_image_dtype(img, tf.float32)
        img = resize_small(img, config.data.image_size)
        img = central_crop(img, config.data.image_size)
        return img

    else:
      def resize_op(img):
        img = crop_resize(img, config.data.image_size)
        img = tf.image.convert_image_dtype(img, tf.float32)
        return img

  elif config.data.dataset in ['FFHQ', 'CelebAHQ']:
    dataset_builder = tf.data.TFRecordDataset(config.data.tfrecords_path)
    train_split_name = eval_split_name = 'train'

  else:
    raise NotImplementedError(
      f'Dataset {config.data.dataset} not yet supported.')

  # Customize preprocess functions for each dataset.
  if config.data.dataset in ['FFHQ', 'CelebAHQ']:
    def preprocess_fn(d):
      sample = tf.io.parse_single_example(d, features={
        'shape': tf.io.FixedLenFeature([3], tf.int64),
        'data': tf.io.FixedLenFeature([], tf.string)})
      data = tf.io.decode_raw(sample['data'], tf.uint8)
      data = tf.reshape(data, sample['shape'])
      data = tf.transpose(data, (1, 2, 0))
      img = tf.image.convert_image_dtype(data, tf.float32)
      if config.data.random_flip and not evaluation:
        img = tf.image.random_flip_left_right(img)
      if uniform_dequantization:
        img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256.
      return dict(image=img, label=None)

  else:
    def preprocess_fn(d):
      """Basic preprocessing function scales data to [0, 1) and randomly flips."""
      img = resize_op(d['image'])
      if config.data.random_flip and not evaluation:
        img = tf.image.random_flip_left_right(img)
      if uniform_dequantization:
        img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256.

      return dict(image=img, label=d.get('label', None))

  def create_dataset(dataset_builder, split):
    dataset_options = tf.data.Options()
    dataset_options.experimental_optimization.map_parallelization = True
    dataset_options.experimental_threading.private_threadpool_size = 48
    dataset_options.experimental_threading.max_intra_op_parallelism = 1
    read_config = tfds.ReadConfig(options=dataset_options)
    if isinstance(dataset_builder, tfds.core.DatasetBuilder):
      dataset_builder.download_and_prepare()
      ds = dataset_builder.as_dataset(
        split=split, shuffle_files=True, read_config=read_config)
    else:
      ds = dataset_builder.with_options(dataset_options)
    ds = ds.repeat(count=num_epochs)
    ds = ds.shuffle(shuffle_buffer_size)
    ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(prefetch_size)

  train_ds = create_dataset(dataset_builder, train_split_name)
  eval_ds = create_dataset(dataset_builder, eval_split_name)
  return train_ds, eval_ds, dataset_builder
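With `uniform_dequantization=True`, `preprocess_fn` maps each 8-bit level k (stored as k/255 after `convert_image_dtype`) to a uniform sample in [k/256, (k+1)/256). This turns the discrete pixel values into a continuous density on [0, 1). The check below uses NumPy in place of TensorFlow for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
levels = np.arange(256)                        # every 8-bit intensity level k
img = levels / 255.0                           # k / 255 in [0, 1]
noise = rng.uniform(0.0, 1.0, size=img.shape)  # U[0, 1) noise
deq = (noise + img * 255.0) / 256.0            # same formula as in preprocess_fn
```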


from pathlib import Path

class fastmri_knee(Dataset):
  """ Simple pytorch dataset for fastmri knee singlecoil dataset """
  def __init__(self, root, is_complex=False):
    self.root = root
    self.data_list = list(root.glob('*/*.npy'))
    self.is_complex = is_complex

  def __len__(self):
    return len(self.data_list)

  def __getitem__(self, idx):
    fname = self.data_list[idx]
    if not self.is_complex:
      data = np.load(fname)
    else:
      data = np.load(fname).astype(np.complex64)
    data = np.expand_dims(data, axis=0)
    return data


class AAPM(Dataset):
  def __init__(self, root, sort):
    self.root = root
    self.data_list = list(root.glob('full_dose/*.npy'))
    self.sort = sort
    if sort:
      self.data_list = sorted(self.data_list)

  def __len__(self):
    return len(self.data_list)

  def __getitem__(self, idx):
    fname = self.data_list[idx]
    data = np.load(fname)
    data = np.expand_dims(data, axis=0)
    return data


class Object5(Dataset):
  def __init__(self, root, slice, fast=False):
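    """
    root - directory containing the .npz volumes (each stores its slices under key 'x')
    slice - a Python slice selecting which of the (sorted) 2000 volumes to use;
      the dataset returns individual 2-D slices, so its length is NUM_SLICES
      times the number of selected volumes
    fast - if True, use only the first 10 slices of each volume (tiny debug dataset)
    """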
    if fast:
      self.NUM_SLICES = 10
    else:
      self.NUM_SLICES = 256


    self.root = root
    self.data_list = list(root.glob('*.npz'))

    if len(self.data_list) == 0:
      raise ValueError(f"No npz files found in {root}")

    self.data_list = sorted(self.data_list)[slice]

  def __len__(self):
    return len(self.data_list) * self.NUM_SLICES

  def __getitem__(self, idx):
    vol_index = idx // self.NUM_SLICES
    slice_index = idx % self.NUM_SLICES
    fname = self.data_list[vol_index]
    data = np.load(fname)['x'][slice_index]
    data = np.expand_dims(data, axis=0)
    return data
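`Object5` flattens (volume, slice) pairs into a single index, so `__getitem__` recovers them with integer division and remainder. The hypothetical helper below performs the same arithmetic. The length example assumes that at least 1800 volumes exist for the `slice(None, 1800)` training split in `create_dataloader`.

```python
def flat_index_to_volume_slice(idx, num_slices=256):
    """Hypothetical helper: same arithmetic as Object5.__getitem__ (idx // N, idx % N)."""
    return divmod(idx, num_slices)


print(1800 * 256)                          # 460800 items for 1800 volumes of 256 slices
print(flat_index_to_volume_slice(513))     # (2, 1): volume 2, slice 1
print(flat_index_to_volume_slice(25, 10))  # (2, 5) with fast=True (10 slices per volume)
```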

class fastmri_knee_infer(Dataset):
  """ Simple pytorch dataset for fastmri knee singlecoil dataset """
  def __init__(self, root, sort=True, is_complex=False):
    self.root = root
    self.data_list = list(root.glob('*/*.npy'))
    self.is_complex = is_complex
    if sort:
      self.data_list = sorted(self.data_list)

  def __len__(self):
    return len(self.data_list)

  def __getitem__(self, idx):
    fname = self.data_list[idx]
    if not self.is_complex:
      data = np.load(fname)
    else:
      data = np.load(fname).astype(np.complex64)
    data = np.expand_dims(data, axis=0)
    return data, str(fname)


class fastmri_knee_magpha(Dataset):
  """ Simple pytorch dataset for fastmri knee singlecoil dataset """
  def __init__(self, root):
    self.root = root
    self.data_list = list(root.glob('*/*.npy'))

  def __len__(self):
    return len(self.data_list)

  def __getitem__(self, idx):
    fname = self.data_list[idx]
    data = np.load(fname).astype(np.float32)
    return data


class fastmri_knee_magpha_infer(Dataset):
  """ Simple pytorch dataset for fastmri knee singlecoil dataset """
  def __init__(self, root, sort=True):
    self.root = root
    self.data_list = list(root.glob('*/*.npy'))
    if sort:
      self.data_list = sorted(self.data_list)

  def __len__(self):
    return len(self.data_list)

  def __getitem__(self, idx):
    fname = self.data_list[idx]
    data = np.load(fname).astype(np.float32)
    return data, str(fname)


def create_dataloader(configs, evaluation=False, sort=True):
  shuffle = not evaluation
  if configs.data.dataset == 'Object5':
    train_dataset = Object5(Path(configs.data.root), slice(None, 1800))
    val_dataset = Object5(Path(configs.data.root), slice(1800, None))
  elif configs.data.dataset == 'Object5Fast':
    train_dataset = Object5(Path(configs.data.root), slice(None, 1), fast=True)
    val_dataset = Object5(Path(configs.data.root), slice(1, 2), fast=True)
  elif configs.data.dataset == 'AAPM':
    train_dataset = AAPM(Path(configs.data.root) / 'train', sort=False)
    val_dataset = AAPM(Path(configs.data.root) / 'test', sort=True)
  elif configs.data.is_multi:
    train_dataset = fastmri_knee(Path(configs.data.root) / f'knee_multicoil_{configs.data.image_size}_train')
    val_dataset = fastmri_knee_infer(Path(configs.data.root) / f'knee_{configs.data.image_size}_val', sort=sort)
  elif configs.data.is_complex:
    if configs.data.magpha:
      train_dataset = fastmri_knee_magpha(Path(configs.data.root) / f'knee_complex_magpha_{configs.data.image_size}_train')
      val_dataset = fastmri_knee_magpha_infer(Path(configs.data.root) / f'knee_complex_magpha_{configs.data.image_size}_val', sort=sort)
    else:
      train_dataset = fastmri_knee(Path(configs.data.root) / f'knee_complex_{configs.data.image_size}_train', is_complex=True)
      val_dataset = fastmri_knee_infer(Path(configs.data.root) / f'knee_complex_{configs.data.image_size}_val', sort=sort, is_complex=True)
  elif configs.data.dataset == 'fastmri_knee':
    train_dataset = fastmri_knee(Path(configs.data.root) / f'knee_{configs.data.image_size}_train')
    val_dataset = fastmri_knee_infer(Path(configs.data.root) / f'knee_{configs.data.image_size}_val', sort=sort)
  else:
    raise ValueError(f'Dataset {configs.data.dataset} not recognized.')

  train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=configs.training.batch_size,
    shuffle=shuffle,
    drop_last=True
  )
  val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=configs.training.batch_size,
    # NOTE: the validation loader is always shuffled, regardless of `evaluation`
    shuffle=True,
    drop_last=True
  )
  return train_loader, val_loader
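# A small sketch of what drop_last=True means for the loaders above. The
# helper name and toy dataset are hypothetical. A DataLoader with
# drop_last=True yields len(dataset) // batch_size batches and discards the
# remainder, so a dataset smaller than training.batch_size yields no batches.
import torch
from torch.utils.data import DataLoader, TensorDataset


def count_batches_drop_last(num_items, batch_size):
  """Return (len(loader), number of batches actually iterated) for a toy dataset."""
  toy = TensorDataset(torch.arange(num_items, dtype=torch.float32))
  loader = DataLoader(toy, batch_size=batch_size, shuffle=False, drop_last=True)
  return len(loader), sum(1 for _ in loader)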



def create_dataloader_regression(configs, evaluation=False):
  shuffle = not evaluation
  train_dataset = fastmri_knee(Path(configs.root) / f'knee_{configs.image_size}_train')
  val_dataset = fastmri_knee_infer(Path(configs.root) / f'knee_{configs.image_size}_val')

  train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=configs.batch_size,
    shuffle=shuffle,
    drop_last=True
  )
  val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=configs.batch_size,
    shuffle=False,
    drop_last=True
  )
  return train_loader, val_loader


================================================
FILE: environment.yml
================================================
name: diffusion-mbir
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.8
  - numpy
  - matplotlib
  - scikit-image
  - sporco
  - tqdm
  - ninja
  - pytorch::pytorch
  - pytorch::torchvision
  - tensorboard
  - pip
  - pip:
      - ml_collections


================================================
FILE: evaluation.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for computing FID/Inception scores."""

import jax
import numpy as np
import six
import tensorflow as tf
import tensorflow_gan as tfgan
import tensorflow_hub as tfhub

INCEPTION_TFHUB = 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1'
INCEPTION_OUTPUT = 'logits'
INCEPTION_FINAL_POOL = 'pool_3'
_DEFAULT_DTYPES = {
  INCEPTION_OUTPUT: tf.float32,
  INCEPTION_FINAL_POOL: tf.float32
}
INCEPTION_DEFAULT_IMAGE_SIZE = 299


def get_inception_model(inceptionv3=False):
  if inceptionv3:
    return tfhub.load(
      'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4')
  else:
    return tfhub.load(INCEPTION_TFHUB)


def load_dataset_stats(config):
  """Load the pre-computed dataset statistics."""
  if config.data.dataset == 'CIFAR10':
    filename = 'assets/stats/cifar10_stats.npz'
  elif config.data.dataset == 'CELEBA':
    filename = 'assets/stats/celeba_stats.npz'
  elif config.data.dataset == 'LSUN':
    filename = f'assets/stats/lsun_{config.data.category}_{config.data.image_size}_stats.npz'
  else:
    raise ValueError(f'Dataset {config.data.dataset} stats not found.')

  with tf.io.gfile.GFile(filename, 'rb') as fin:
    stats = np.load(fin)
    return stats


def classifier_fn_from_tfhub(output_fields, inception_model,
                             return_tensor=False):
  """Returns a function that can be used as a classifier function.

  Copied from tfgan, but avoids reloading the model each time _classifier_fn is called.

  Args:
    output_fields: A string, list, or `None`. If present, assume the module
      outputs a dictionary, and select this field.
    inception_model: A model loaded from TFHub.
    return_tensor: If `True`, return a single tensor instead of a dictionary.

  Returns:
    A one-argument function that takes an image Tensor and returns outputs.
  """
  if isinstance(output_fields, six.string_types):
    output_fields = [output_fields]

  def _classifier_fn(images):
    output = inception_model(images)
    if output_fields is not None:
      output = {x: output[x] for x in output_fields}
    if return_tensor:
      assert len(output) == 1
      output = list(output.values())[0]
    return tf.nest.map_structure(tf.compat.v1.layers.flatten, output)

  return _classifier_fn
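# A sketch of the selection logic inside _classifier_fn, written without
# TensorFlow. The helper name is hypothetical, and NumPy reshape stands in for
# tf.compat.v1.layers.flatten, which keeps the batch dimension and flattens the
# rest. A dict output is filtered to `output_fields`. With return_tensor=True,
# the single remaining value is unwrapped.
def select_and_flatten(output, output_fields=None, return_tensor=False):
  if isinstance(output_fields, str):
    output_fields = [output_fields]
  if output_fields is not None:
    output = {k: output[k] for k in output_fields}
  if return_tensor:
    assert len(output) == 1
    output = list(output.values())[0]
    return output.reshape(output.shape[0], -1)
  return {k: v.reshape(v.shape[0], -1) for k, v in output.items()}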


@tf.function
def run_inception_jit(inputs,
                      inception_model,
                      num_batches=1,
                      inceptionv3=False):
  """Run the inception network. Assumes inputs are within [0, 255]."""
  if not inceptionv3:
    inputs = (tf.cast(inputs, tf.float32) - 127.5) / 127.5
  else:
    inputs = tf.cast(inputs, tf.float32) / 255.

  return tfgan.eval.run_classifier_fn(
    inputs,
    num_batches=num_batches,
    classifier_fn=classifier_fn_from_tfhub(None, inception_model),
    dtypes=_DEFAULT_DTYPES)
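# A NumPy sketch of the input normalization in run_inception_jit. The helper
# name is hypothetical. Pixels in [0, 255] map to [-1, 1] for the default
# TF-GAN Inception model, and to [0, 1] when inceptionv3=True.
import numpy as np


def scale_inception_inputs(images, inceptionv3=False):
  x = np.asarray(images, dtype=np.float32)
  if not inceptionv3:
    return (x - 127.5) / 127.5
  return x / 255.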


@tf.function
def run_inception_distributed(input_tensor,
                              inception_model,
                              num_batches=1,
                              inceptionv3=False):
  """Distribute the inception network computation to all available TPUs.

  Args:
    input_tensor: The input images. Assumed to be within [0, 255].
    inception_model: The inception network model obtained from `tfhub`.
    num_batches: The number of batches used for dividing the input.
    inceptionv3: If `True`, use InceptionV3, otherwise use InceptionV1.

  Returns:
    A dictionary with key `pool_3` and `logits`, representing the pool_3 and
      logits of the inception network respectively.
  """
  num_tpus = jax.local_device_count()
  input_tensors = tf.split(input_tensor, num_tpus, axis=0)
  pool3 = []
  logits = [] if not inceptionv3 else None
  device_format = '/TPU:{}' if 'TPU' in str(jax.devices()[0]) else '/GPU:{}'
  for i, tensor in enumerate(input_tensors):
    with tf.device(device_format.format(i)):
      tensor_on_device = tf.identity(tensor)
      res = run_inception_jit(
        tensor_on_device, inception_model, num_batches=num_batches,
        inceptionv3=inceptionv3)

      if not inceptionv3:
        pool3.append(res['pool_3'])
        logits.append(res['logits'])  # pytype: disable=attribute-error
      else:
        pool3.append(res)

  with tf.device('/CPU'):
    return {
      'pool_3': tf.concat(pool3, axis=0),
      'logits': tf.concat(logits, axis=0) if not inceptionv3 else None
    }
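# A sketch of the split / process / concatenate pattern in
# run_inception_distributed. NumPy stands in for tf.split and tf.concat, and
# the helper name is hypothetical. As with tf.split given an integer count,
# the batch must divide evenly across devices. Each shard is processed
# independently, and the results are re-joined along axis 0 in the original
# order.
import numpy as np


def shard_apply_concat(batch, num_devices, fn):
  if batch.shape[0] % num_devices != 0:
    raise ValueError("batch size must be divisible by the number of devices")
  shards = np.split(batch, num_devices, axis=0)
  return np.concatenate([fn(s) for s in shards], axis=0)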


================================================
FILE: fastmri_utils.py
================================================
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from typing import List, Optional

import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse("1.7.0"):
    import torch.fft  # type: ignore


def fft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply centered 2 dimensional Fast Fourier Transform.
    Args:
        data: Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has size
            2. All other dimensions are assumed to be batch dimensions.
        norm: Whether to include normalization. Must be one of ``"backward"``
            or ``"ortho"``. See ``torch.fft.fft`` on PyTorch 1.9.0 for details.
    Returns:
        The FFT of the input.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")
    if norm not in ("ortho", "backward"):
        raise ValueError("norm must be 'ortho' or 'backward'.")
    normalized = True if norm == "ortho" else False

    data = ifftshift(data, dim=[-3, -2])
    data = torch.fft(data, 2, normalized=normalized)
    data = fftshift(data, dim=[-3, -2])

    return data


def ifft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply centered 2-dimensional Inverse Fast Fourier Transform.
    Args:
        data: Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has size
            2. All other dimensions are assumed to be batch dimensions.
        norm: Whether to include normalization. Must be one of ``"backward"``
            or ``"ortho"``. See ``torch.fft.ifft`` on PyTorch 1.9.0 for
            details.
    Returns:
        The IFFT of the input.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")
    if norm not in ("ortho", "backward"):
        raise ValueError("norm must be 'ortho' or 'backward'.")
    normalized = True if norm == "ortho" else False

    data = ifftshift(data, dim=[-3, -2])
    data = torch.ifft(data, 2, normalized=normalized)
    data = fftshift(data, dim=[-3, -2])

    return data


def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply centered 2-dimensional Fast Fourier Transform.
    Args:
        data: Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has size
            2. All other dimensions are assumed to be batch dimensions.
        norm: Normalization mode. See ``torch.fft.fft``.
    Returns:
        The FFT of the input.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")

    data = ifftshift(data, dim=[-3, -2])
    data = torch.view_as_real(
        torch.fft.fftn(  # type: ignore
            torch.view_as_complex(data), dim=(-2, -1), norm=norm
        )
    )
    data = fftshift(data, dim=[-3, -2])

    return data


def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
    """
    Apply centered 2-dimensional Inverse Fast Fourier Transform.
    Args:
        data: Complex valued input data containing at least 3 dimensions:
            dimensions -3 & -2 are spatial dimensions and dimension -1 has size
            2. All other dimensions are assumed to be batch dimensions.
        norm: Normalization mode. See ``torch.fft.ifft``.
    Returns:
        The IFFT of the input.
    """
    if not data.shape[-1] == 2:
        raise ValueError("Tensor does not have separate complex dim.")

    data = ifftshift(data, dim=[-3, -2])
    data = torch.view_as_real(
        torch.fft.ifftn(  # type: ignore
            torch.view_as_complex(data), dim=(-2, -1), norm=norm
        )
    )
    data = fftshift(data, dim=[-3, -2])

    return data
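# A sketch of the same centered, orthonormal 2-D FFT written directly on
# complex tensors with torch.fft.fftshift / torch.fft.ifftshift (PyTorch >= 1.8).
# The helper names are hypothetical. fft2c_new / ifft2c_new above instead use a
# real view with a trailing dimension of size 2. ifftshift moves the image
# center to index 0 before the transform, and fftshift moves DC back to the
# center afterwards. With norm="ortho", the two functions are exact inverses.
import torch


def centered_fft2(x: torch.Tensor) -> torch.Tensor:
    """Centered FFT over the last two dims of a complex tensor."""
    x = torch.fft.ifftshift(x, dim=(-2, -1))
    x = torch.fft.fftn(x, dim=(-2, -1), norm="ortho")
    return torch.fft.fftshift(x, dim=(-2, -1))


def centered_ifft2(x: torch.Tensor) -> torch.Tensor:
    """Inverse of centered_fft2."""
    x = torch.fft.ifftshift(x, dim=(-2, -1))
    x = torch.fft.ifftn(x, dim=(-2, -1), norm="ortho")
    return torch.fft.fftshift(x, dim=(-2, -1))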


# Helper functions


def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """
    Similar to roll but for only one dim.
    Args:
        x: A PyTorch tensor.
        shift: Amount to roll.
        dim: Which dimension to roll.
    Returns:
        Rolled version of x.
    """
    shift = shift % x.size(dim)
    if shift == 0:
        return x

    left = x.narrow(dim, 0, x.size(dim) - shift)
    right = x.narrow(dim, x.size(dim) - shift, shift)

    return torch.cat((right, left), dim=dim)


def roll(
    x: torch.Tensor,
    shift: List[int],
    dim: List[int],
) -> torch.Tensor:
    """
    Similar to np.roll but applies to PyTorch Tensors.
    Args:
        x: A PyTorch tensor.
        shift: Amount to roll.
        dim: Which dimension to roll.
    Returns:
        Rolled version of x.
    """
    if len(shift) != len(dim):
        raise ValueError("len(shift) must match len(dim)")

    for (s, d) in zip(shift, dim):
        x = roll_one_dim(x, s, d)

    return x
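# A sketch showing that the narrow/cat rolling used by roll_one_dim (and
# therefore roll) matches the built-in torch.roll, including negative shifts.
# This works because the shift is first reduced modulo the dimension size. The
# helper name is hypothetical and reproduces roll_one_dim for a standalone
# check.
import torch


def roll_via_narrow(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    shift = shift % x.size(dim)
    if shift == 0:
        return x
    left = x.narrow(dim, 0, x.size(dim) - shift)
    right = x.narrow(dim, x.size(dim) - shift, shift)
    return torch.cat((right, left), dim=dim)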


def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Similar to np.fft.fftshift but applies to PyTorch Tensors
    Args:
        x: A PyTorch tensor.
        dim: Which dimension to fftshift.
    Returns:
        fftshifted version of x.
    """
    if dim is None:
        # this weird code is necessary for torch.jit.script typing
        dim = [0] * (x.dim())
        for i in range(1, x.dim()):
            dim[i] = i

    # also necessary for torch.jit.script
    shift = [0] * len(dim)
    for i, dim_num in enumerate(dim):
        shift[i] = x.shape[dim_num] // 2

    return roll(x, shift, dim)


def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Similar to np.fft.ifftshift but applies to PyTorch Tensors
    Args:
        x: A PyTorch tensor.
        dim: Which dimension to ifftshift.
    Returns:
        ifftshifted version of x.
    """
    if dim is None:
        # this weird code is necessary for torch.jit.script typing
        dim = [0] * (x.dim())
        for i in range(1, x.dim()):
            dim[i] = i

    # also necessary for torch.jit.script
    shift = [0] * len(dim)
    for i, dim_num in enumerate(dim):
        shift[i] = (x.shape[dim_num] + 1) // 2

    return roll(x, shift, dim)
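# A sketch of why fftshift and ifftshift use different shift amounts. The
# helper name is hypothetical. For odd lengths, n // 2 and (n + 1) // 2
# differ, so the two shifts are not the same operation, though they still
# undo each other (the total shift is n). For even lengths they coincide. The
# expected lists match np.fft.fftshift / np.fft.ifftshift of arange(n).
import torch


def shifted_pair(n: int):
    """Return (fftshift(arange(n)), ifftshift(arange(n))) built with torch.roll."""
    x = torch.arange(n)
    return torch.roll(x, n // 2).tolist(), torch.roll(x, (n + 1) // 2).tolist()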

================================================
FILE: inverse_problem_solver_AAPM_3d_total.py
================================================
import torch
from losses import get_optimizer
from models.ema import ExponentialMovingAverage

import numpy as np
import controllable_generation_TV

from utils import restore_checkpoint, clear, batchfy, patient_wise_min_max, img_wise_min_max
from pathlib import Path
from models import utils as mutils
from models import ncsnpp  # imported for its side effect: registers the ncsnpp model
from sde_lib import VESDE
from sampling import (ReverseDiffusionPredictor,
                      LangevinCorrector)
import datasets
import time
# for radon
from physics.ct import CT
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

###############################################
# Configurations
###############################################
problem = 'sparseview_CT_ADMM_TV_total'
config_name = 'AAPM_256_ncsnpp_continuous'
sde = 'VESDE'
num_scales = 2000
ckpt_num = 185
N = num_scales

vol_name = 'L067'
root = Path(f'./data/CT/ind/256_sorted/{vol_name}')

# Parameters for the inverse problem
Nview = 8
det_spacing = 1.0
size = 256
det_count = int((size * (2 * torch.ones(1)).sqrt()).ceil())
lamb = 0.04
rho = 10
freq = 1

if sde.lower() == 'vesde':
    from configs.ve import AAPM_256_ncsnpp_continuous as configs
    ckpt_filename = f"exp/ve/{config_name}/checkpoint_{ckpt_num}.pth"
    config = configs.get_config()
    config.model.num_scales = N
    sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
    sde.N = N
    sampling_eps = 1e-5
predictor = ReverseDiffusionPredictor
corrector = LangevinCorrector
probability_flow = False
snr = 0.16
n_steps = 1

batch_size = 12
config.training.batch_size = batch_size
config.eval.batch_size = batch_size
random_seed = 0

sigmas = mutils.get_sigmas(config)
scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)
score_model = mutils.create_model(config)  ## model

optimizer = get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(),
                               decay=config.model.ema_rate)
state = dict(step=0, optimizer=optimizer,
             model=score_model, ema=ema)

state = restore_checkpoint(ckpt_filename, state, config.device, skip_sigma=True, skip_optimizer=True)
ema.copy_to(score_model.parameters())

# Specify save directory for saving generated samples
save_root = Path(f'./results/{config_name}/{problem}/m{Nview}/rho{rho}/lambda{lamb}')
save_root.mkdir(parents=True, exist_ok=True)

irl_types = ['input', 'recon', 'label', 'BP', 'sinogram']
for t in irl_types:
    if t == 'recon':
        save_root_f = save_root / t / 'progress'
        save_root_f.mkdir(exist_ok=True, parents=True)
    else:
        save_root_f = save_root / t
        save_root_f.mkdir(parents=True, exist_ok=True)

# read all data
fname_list = os.listdir(root)
fname_list = sorted(fname_list, key=lambda x: float(x.split(".")[0]))
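# A sketch of why the key above sorts by the numeric stem. It assumes the
# slice files are named by a numeric index, as the float() key implies. A
# plain lexicographic sort would put "10.npy" before "2.npy" and scramble the
# slice order of the volume. The helper name is hypothetical.
def sort_slice_names(names):
    """Sort slice filenames by their numeric stem, mirroring the key used above."""
    return sorted(names, key=lambda x: float(x.split(".")[0]))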
print(fname_list)
all_img = []

print("Loading all data")
for fname in tqdm(fname_list):
    just_name = fname.split('.')[0]
    img = torch.from_numpy(np.load(os.path.join(root, fname), allow_pickle=True))
    h, w = img.shape
    img = img.view(1, 1, h, w)
    all_img.append(img)
    plt.imsave(os.path.join(save_root, 'label', f'{just_name}.png'), clear(img), cmap='gray')
all_img = torch.cat(all_img, dim=0)
print(f"Data loaded shape : {all_img.shape}")

# full
angles = np.linspace(0, np.pi, 180, endpoint=False)
radon = CT(img_width=h, radon_view=Nview, circle=False, device=config.device)

predicted_sinogram = []
label_sinogram = []
img_cache = None

img = all_img.to(config.device)
pc_radon = controllable_generation_TV.get_pc_radon_ADMM_TV_vol(sde,
                                                               predictor, corrector,
                                                               inverse_scaler,
                                                               snr=snr,
                                                               n_steps=n_steps,
                                                               probability_flow=probability_flow,
                                                               continuous=config.training.continuous,
                                                               denoise=True,
                                                               radon=radon,
                                                               save_progress=True,
                                                               save_root=save_root,
                                                               final_consistency=True,
                                                               img_shape=img.shape,
                                                               lamb_1=lamb,
                                                               rho=rho)
# Sparse by masking
sinogram = radon.A(img)

# Backprojection: A^T applied to the sparse-view sinogram (adjoint, not a pseudo-inverse)
bp = radon.AT(sinogram)
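# A sketch of a standard sanity check for a projector/backprojector pair, the
# dot-product test <A x, y> == <x, A^T y>. A random NumPy matrix stands in for
# radon.A here; this is not the CT class, and it assumes radon.AT applies the
# adjoint (transpose) of radon.A, as its name suggests. Because A^T is only
# the adjoint, bp above is a backprojection rather than a reconstruction.
import numpy as np


def dot_product_test(A, seed=0):
    """Return (<A x, y>, <x, A^T y>) for random x, y; they agree for a true adjoint."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(A.shape[1])
    y = rng.standard_normal(A.shape[0])
    return float(np.dot(A @ x, y)), float(np.dot(x, A.T @ y))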

# Recon Image
x = pc_radon(score_model, scaler(img), measurement=sinogram)
img_cache = x[-1].unsqueeze(0)

count = 0
for i, recon_img in enumerate(x):
    plt.imsave(save_root / 'BP' / f'{count}.png', clear(bp[i]), cmap='gray')
    plt.imsave(save_root / 'label' / f'{count}.png', clear(img[i]), cmap='gray')
    plt.imsave(save_root / 'recon' / f'{count}.png', clear(recon_img), cmap='gray')

    count += 1

# Recon and Save Sinogram
label_sinogram.append(radon.A_all(img))
predicted_sinogram.append(radon.A_all(x))

original_sinogram = torch.cat(label_sinogram, dim=0).detach().cpu().numpy()
recon_sinogram = torch.cat(predicted_sinogram, dim=0).detach().cpu().numpy()

np.save(str(save_root / 'sinogram' / f'original_{count}.npy'), original_sinogram)
np.save(str(save_root / 'sinogram' / f'recon_{count}.npy'), recon_sinogram)

================================================
FILE: inverse_problem_solver_BRATS_MRI_3d_total.py
================================================
from pathlib import Path
from models import utils as mutils
import sampling
from sde_lib import VESDE
from sampling import (ReverseDiffusionPredictor,
                      LangevinCorrector,
                      LangevinCorrectorCS)
from models import ncsnpp  # imported for its side effect: registers the ncsnpp model
from itertools import islice
from losses import get_optimizer
import datasets
import time
import controllable_generation_TV
from utils import restore_checkpoint, fft2, ifft2, show_samples_gray, get_mask, clear
import torch
import torch.nn as nn
import numpy as np
from models.ema import ExponentialMovingAverage
from scipy.io import savemat, loadmat
from tqdm import tqdm
import matplotlib.pyplot as plt
import importlib


###############################################
# Configurations
###############################################
problem = 'Fourier_CS_3d_admm_tv'
config_name = 'fastmri_knee_320_ncsnpp_continuous'
sde = 'VESDE'
num_scales = 2000
ckpt_num = 95
N = num_scales

root = './data/MRI/BRATS'
vol = 'Brats18_CBICA_AAM_1'

if sde.lower() == 'vesde':
  # from configs.ve import fastmri_knee_320_ncsnpp_continuous as configs
  configs = importlib.import_module(f"configs.ve.{config_name}")
  if config_name == 'fastmri_knee_320_ncsnpp_continuous':
    ckpt_filename = f"./exp/ve/{config_name}/checkpoint_{ckpt_num}.pth"
  elif config_name == 'ffhq_256_ncsnpp_continuous':
    ckpt_filename = f"exp/ve/{config_name}/checkpoint_48.pth"
  else:
    raise ValueError(f"No checkpoint path configured for {config_name}")
  config = configs.get_config()
  config.model.num_scales = num_scales
  sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
  sde.N = N
  sampling_eps = 1e-5

img_size = 240
batch_size = 1
config.training.batch_size = batch_size
predictor = ReverseDiffusionPredictor
corrector = LangevinCorrector
probability_flow = False
snr = 0.16
n_steps = 1

# parameters for Fourier CS recon
mask_type = 'uniform1d'
use_measurement_noise = False
acc_factor = 2.0
center_fraction = 0.15

# ADMM TV parameters
lamb_list = [0.005]
rho_list = [0.01]

random_seed = 0

sigmas = mutils.get_sigmas(config)
scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)
score_model = mutils.create_model(config)

optimizer = get_optimizer(config, score_model.parameters())
ema = ExponentialMovingAverage(score_model.parameters(),
                               decay=config.model.ema_rate)
state = dict(step=0, optimizer=optimizer,
             model=score_model, ema=ema)
state = restore_checkpoint(ckpt_filename, state, config.device, skip_sigma=True)
ema.copy_to(score_model.parameters())

fname_list = sorted(list((Path(root) / vol).glob('*.npy')))
all_img = []
for fname in tqdm(fname_list):
    img = np.load(fname)
    img = torch.from_numpy(img)
    h, w = img.shape
    img = img.view(1, 1, h, w)
    all_img.append(img)

all_img = torch.cat(all_img, dim=0)

# scale the volume by its maximum intensity (roughly [0, 1] for non-negative magnitude images)
vmax = all_img.max()
all_img /= (vmax + 1e-5)

img = all_img.to(config.device)
b = img.shape[0]

for lamb in lamb_list:
    for rho in rho_list:
        print(f'lambda: {lamb}')
        print(f'rho:    {rho}')
        # Specify save directory for saving generated samples
        save_root = Path(f'./results/{config_name}/{problem}/{mask_type}/acc{acc_factor}/lamb{lamb}/rho{rho}/{vol}')
        save_root.mkdir(parents=True, exist_ok=True)

        irl_types = ['input', 'recon', 'label']
        for t in irl_types:
            save_root_f = save_root / t
            save_root_f.mkdir(parents=True, exist_ok=True)

        ###############################################
        # Inference
        ###############################################

        # forward model
        kspace = fft2(img)

        # generate mask
        mask = get_mask(torch.zeros(1, 1, h, w), img_size, batch_size,
                        type=mask_type, acc_factor=acc_factor, center_fraction=center_fraction)
        mask = mask.to(img.device)
        mask = mask.repeat(b, 1, 1, 1)

        pc_fouriercs = controllable_generation_TV.get_pc_radon_ADMM_TV_mri(sde,
                                                                           predictor, corrector,
                                                                           inverse_scaler,
                                                                           mask=mask,
                                                                           lamb_1=lamb,
                                                                           rho=rho,
                                                                           img_shape=img.shape,
                                                                           snr=snr,
                                                                           n_steps=n_steps,
                                                                           probability_flow=probability_flow,
                                                                           continuous=config.training.continuous)

        # undersampling
        under_kspace = kspace * mask
        under_img = torch.real(ifft2(under_kspace))

        count = 0
        for i in range(under_img.shape[0]):
            plt.imsave(save_root / 'input' / f'{count}.png', clear(under_img[i]), cmap='gray')
            plt.imsave(save_root / 'label' / f'{count}.png', clear(img[i]), cmap='gray')
            count += 1

        x = pc_fouriercs(score_model, scaler(under_img), measurement=under_kspace)

        count = 0
        for i, recon_img in enumerate(x):
            plt.imsave(save_root / 'input' / f'{count}.png', clear(under_img[i]), cmap='gray')
            plt.imsave(save_root / 'label' / f'{count}.png', clear(img[i]), cmap='gray')
            plt.imsave(save_root / 'recon' / f'{count}.png', clear(recon_img), cmap='gray')
            np.save(str(save_root / 'input' / f'{count}.npy'), clear(under_img[i], normalize=False))
            np.save(str(save_root / 'recon' / f'{count}.npy'), clear(x[i], normalize=False))
            np.save(str(save_root / 'label' / f'{count}.npy'), clear(img[i], normalize=False))
            count += 1
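# A sketch of uniform 1-D Cartesian undersampling followed by zero-filled
# reconstruction, mirroring the under_kspace / under_img steps above. The
# helpers are hypothetical, and the repo's get_mask and fft2 may differ in
# detail. The mask keeps every `acc_factor`-th column plus a fully sampled
# central band of `center_fraction * w` columns. It is applied to centered
# k-space, and the real part of the centered inverse FFT is returned.
import numpy as np


def uniform1d_mask(w, acc_factor, center_fraction):
    mask = np.zeros(w, dtype=np.float32)
    mask[::int(acc_factor)] = 1.0
    n_center = int(round(center_fraction * w))
    start = (w - n_center) // 2
    mask[start:start + n_center] = 1.0
    return mask


def zero_filled_recon(img, mask_1d):
    """Mask centered k-space column-wise and return the real part of the IFFT."""
    kspace = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(img), norm="ortho"))
    under = kspace * mask_1d[None, :]
    return np.real(np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(under), norm="ortho")))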



================================================
FILE: likelihood.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pytype: skip-file
"""Log-likelihood (bits/dim) computation via the probability flow ODE."""

import torch
import numpy as np
from scipy import integrate
from models import utils as mutils


def get_div_fn(fn):
  """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator."""

  def div_fn(x, t, eps):
    with torch.enable_grad():
      x.requires_grad_(True)
      fn_eps = torch.sum(fn(x, t) * eps)
      grad_fn_eps = torch.autograd.grad(fn_eps, x)[0]
    x.requires_grad_(False)
    return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape))))

  return div_fn
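# A sketch of the Hutchinson-Skilling estimator used by get_div_fn. The helper
# names are hypothetical. For f(x) = A x, the divergence is trace(A), and
# eps^T A eps with Rademacher eps is an unbiased estimate of it. The same
# autograd trick as div_fn computes eps^T J eps without forming the Jacobian,
# and the estimates are averaged over many samples.
import torch


def hutchinson_divergence(fn, x, eps):
  with torch.enable_grad():
    x = x.detach().requires_grad_(True)
    fn_eps = torch.sum(fn(x) * eps)
    grad_fn_eps = torch.autograd.grad(fn_eps, x)[0]
  return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, x.dim())))


def hutchinson_trace_demo(n_samples=20000, dim=4, seed=0):
  """Return (mean Hutchinson estimate, exact trace) for a random linear map."""
  g = torch.Generator().manual_seed(seed)
  A = torch.randn(dim, dim, generator=g)
  x = torch.randn(n_samples, dim, generator=g)
  eps = torch.randint(0, 2, (n_samples, dim), generator=g).float() * 2 - 1
  est = hutchinson_divergence(lambda v: v @ A.T, x, eps)
  return est.mean().item(), torch.trace(A).item()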


def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher',
                      rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5):
  """Create a function to compute the unbiased log-likelihood estimate of a given data point.

  Args:
    sde: A `sde_lib.SDE` object that represents the forward SDE.
    inverse_scaler: The inverse data normalizer.
    hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator.
    rtol: A `float` number. The relative tolerance level of the black-box ODE solver.
    atol: A `float` number. The absolute tolerance level of the black-box ODE solver.
    method: A `str`. The algorithm for the black-box ODE solver.
      See documentation for `scipy.integrate.solve_ivp`.
    eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability.

  Returns:
    A function that takes a batch of data points and returns the log-likelihoods in bits/dim,
      the latent code, and the number of function evaluations used by the ODE solver.
  """

  def drift_fn(model, x, t):
    """The drift function of the probability flow ODE (the reverse SDE with probability_flow=True)."""
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True)
    # Probability flow ODE is a special case of Reverse SDE
    rsde = sde.reverse(score_fn, probability_flow=True)
    return rsde.sde(x, t)[0]

  def div_fn(model, x, t, noise):
    return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise)

  def likelihood_fn(model, data):
    """Compute an unbiased estimate to the log-likelihood in bits/dim.

    Args:
      model: A score model.
      data: A PyTorch tensor.

    Returns:
      bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim.
      z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the
        probability flow ODE.
      nfe: An integer. The number of function evaluations used for running the black-box ODE solver.
    """
    with torch.no_grad():
      shape = data.shape
      if hutchinson_type == 'Gaussian':
        epsilon = torch.randn_like(data)
      elif hutchinson_type == 'Rademacher':
        epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1.
      else:
        raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

      def ode_func(t, x):
        sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32)
        vec_t = torch.ones(sample.shape[0], device=sample.device) * t
        drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t))
        logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon))
        return np.concatenate([drift, logp_grad], axis=0)

      init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0)
      solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)
      nfe = solution.nfev
      zp = solution.y[:, -1]
      z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32)
      delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32)
      prior_logp = sde.prior_logp(z)
      bpd = -(prior_logp + delta_logp) / np.log(2)
      N = np.prod(shape[1:])
      bpd = bpd / N
      # A hack to convert log-likelihoods to bits/dim
      offset = 7. - inverse_scaler(-1.)
      bpd = bpd + offset
      return bpd, z, nfe

  return likelihood_fn
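# A sketch of the bits/dim conversion at the end of likelihood_fn. The helper
# name is hypothetical. The negative log-likelihood (in nats) is divided by
# ln 2 and by the number of dimensions. Then log2(256 / data_range) bits are
# added to account for mapping 8-bit pixels onto the model's data range. This
# matches the `7. - inverse_scaler(-1.)` offset, assuming inverse_scaler maps
# [-1, 1] to [0, 1] for centered data (offset 7) and is the identity otherwise
# (offset 8).
import numpy as np


def nats_to_bpd(log_likelihood_nats, num_dims, data_range):
  bpd = -log_likelihood_nats / np.log(2) / num_dims
  return bpd + np.log2(256. / data_range)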


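# Illustrative sketch, not part of the original repository: the arithmetic behind the
# "hack" offset in `likelihood_fn`. A per-example negative log-likelihood in nats becomes
# bits/dim after dividing by ln 2 and by the number of dimensions N. Assuming 8-bit data
# rescaled to [0, 1], a continuous density there corresponds to log2(256) = 8 extra bits/dim
# for the discrete data. If the model instead sees data centered to [-1, 1] (so that
# `inverse_scaler(-1.) == 0.`), the factor-of-2 change of scale removes one bit, giving an
# offset of 7. The function name `nats_to_bpd_sketch` is hypothetical.
import math


def nats_to_bpd_sketch(neg_log_likelihood_nats, num_dims, data_centered=True):
  """Convert a per-example negative log-likelihood (nats) to bits/dim with the same offset."""
  # Mirrors `offset = 7. - inverse_scaler(-1.)`: inverse_scaler(-1.) is assumed to be 0 when
  # the data were centered to [-1, 1] and -1 when they were left in [0, 1].
  inverse_scaled_minus_one = 0. if data_centered else -1.
  offset = 7. - inverse_scaled_minus_one
  return neg_log_likelihood_nats / math.log(2) / num_dims + offset
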
================================================
FILE: losses.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All functions related to loss computation and optimization.
"""

import torch
import torch.optim as optim
import numpy as np
from models import utils as mutils
from sde_lib import VESDE, VPSDE
from utils import fft2, ifft2, get_mask


def get_optimizer(config, params):
  """Returns a PyTorch optimizer object based on `config`."""
  if config.optim.optimizer == 'Adam':
    optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,
                           weight_decay=config.optim.weight_decay)
  else:
    raise NotImplementedError(
      f'Optimizer {config.optim.optimizer} not supported yet!')

  return optimizer


def optimization_manager(config):
  """Returns an optimize_fn based on `config`."""

  def optimize_fn(optimizer, params, step, lr=config.optim.lr,
                  warmup=config.optim.warmup,
                  grad_clip=config.optim.grad_clip):
    """Optimizes with warmup and gradient clipping (disabled if negative)."""
    if warmup > 0:
      for g in optimizer.param_groups:
        g['lr'] = lr * np.minimum(step / warmup, 1.0)
    if grad_clip >= 0:
      torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
    optimizer.step()

  return optimize_fn
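

# Illustrative sketch, not part of the original repository: the learning-rate warmup and
# gradient clipping rules of `optimize_fn`, written with NumPy so the numbers are easy to
# check. The learning rate ramps linearly from 0 to `lr` over `warmup` steps and then stays
# at `lr`. Clipping rescales all gradients jointly when their global L2 norm exceeds
# `max_norm`; the 1e-6 guard is assumed to mirror how `torch.nn.utils.clip_grad_norm_`
# computes its scaling coefficient. Both helper names are hypothetical.
import numpy as np


def warmup_lr_sketch(lr, step, warmup):
  """Learning rate used at `step` under linear warmup (no warmup if `warmup <= 0`)."""
  if warmup <= 0:
    return lr
  return lr * np.minimum(step / warmup, 1.0)


def clip_by_global_norm_sketch(grads, max_norm):
  """Return copies of `grads` scaled so that their joint L2 norm is at most about `max_norm`."""
  total_norm = np.sqrt(sum(float(np.sum(np.square(g))) for g in grads))
  clip_coef = min(max_norm / (total_norm + 1e-6), 1.0)
  return [g * clip_coef for g in grads]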


def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5):
  """Create a loss function for training with arbitrary SDEs.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    train: `True` for training loss and `False` for evaluation loss.
    reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.
    continuous: `True` indicates that the model is defined to take continuous time steps. Otherwise it requires
      ad-hoc interpolation to take continuous time steps.
    likelihood_weighting: If `True`, weight the mixture of score matching losses
      according to https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended in our paper.
    eps: A `float` number. The smallest time step to sample from.

  Returns:
    A loss function.
  """
  reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)

  def loss_fn(model, batch):
    """Compute the loss function.
    Args:
      model: A score model.
      batch: A mini-batch of training data.

    Returns:
      loss: A scalar that represents the average loss value across the mini-batch.
    """
    score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
    t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
    z = torch.randn_like(batch)
    mean, std = sde.marginal_prob(batch, t)
    perturbed_data = mean + std[:, None, None, None] * z
    score = score_fn(perturbed_data, t)

    if not likelihood_weighting:
      losses = torch.square(score * std[:, None, None, None] + z)
      losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)
    else:
      g2 = sde.sde(torch.zeros_like(batch), t)[1] ** 2
      losses = torch.square(score + z / std[:, None, None, None])
      losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * g2

    loss = torch.mean(losses)
    return loss

  return loss_fn
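

# Illustrative sketch, not part of the original repository: the per-example losses computed
# by `loss_fn` above, given a score estimate, the Gaussian noise `z`, and the marginal standard
# deviation `std` at the sampled times. For perturbed data mean + std * z, the conditional score
# is -z / std, so a perfect model gives zero loss under either weighting. Without likelihood
# weighting the loss is ||std * score + z||^2, which equals std^2 * ||score + z / std||^2, i.e.
# the weighting lambda(t) = std(t)^2; with likelihood weighting the factor is g(t)^2 instead.
# The helper name `dsm_losses_sketch` and this NumPy port are hypothetical.
import numpy as np


def dsm_losses_sketch(score, z, std, g2=None, reduce_mean=True):
  """Per-example denoising score matching losses; `g2=None` means no likelihood weighting."""
  batch = score.shape[0]
  # Broadcast one std per example over the remaining (channel, height, width) axes.
  std_b = std.reshape(batch, *([1] * (score.ndim - 1)))
  if g2 is None:
    squared = np.square(score * std_b + z)
  else:
    squared = np.square(score + z / std_b)
  flat = squared.reshape(batch, -1)
  per_example = flat.mean(axis=-1) if reduce_mean else 0.5 * flat.sum(axis=-1)
  return per_example if g2 is None else per_example * g2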


def get_smld_loss_fn(vesde, train, reduce_mean=False):
  """Legacy code to reproduce previous results on SMLD(NCSN). Not recommended for new work."""
  assert isinstance(vesde, VESDE), "SMLD training only works for VESDEs."

  # Previous SMLD models assume descending sigmas
  smld_sigma_array = torch.flip(vesde.discrete_sigmas, dims=(0,))
  reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)

  def loss_fn(model, batch):
    model_fn = mutils.get_model_fn(model, train=train)
    labels = torch.randint(0, vesde.N, (batch.shape[0],), device=batch.device)
    sigmas = smld_sigma_array.to(batch.device)[labels]
    noise = torch.randn_like(batch) * sigmas[:, None, None, None]
    perturbed_data = noise + batch
    score = model_fn(perturbed_data, labels)
    target = -noise / (sigmas ** 2)[:, None, None, None]
    losses = torch.square(score - target)
    losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * sigmas ** 2
    loss = torch.mean(losses)
    return loss

  return loss_fn
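

# Illustrative sketch, not part of the original repository: the noise levels and loss used by
# `get_smld_loss_fn`. It assumes `VESDE.discrete_sigmas` is a geometric sequence running from
# sigma_min up to sigma_max; flipping it gives the descending order that legacy NCSN models
# expect, so label 0 is the largest sigma. The loss compares the score with -noise / sigma^2 and
# weights each example by sigma^2, using the 0.5 * sum reduction that `reduce_mean=False`
# selects. Both helper names and the default values are hypothetical, chosen for illustration.
import numpy as np


def smld_sigmas_sketch(sigma_min=0.01, sigma_max=50., num_scales=10):
  """Descending geometric noise levels, mirroring `torch.flip(vesde.discrete_sigmas, ...)`."""
  ascending = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_scales))
  return ascending[::-1]


def smld_loss_sketch(score, noise, sigmas):
  """Mean over the batch of sigma^2-weighted squared errors against the SMLD target."""
  target = -noise / (sigmas ** 2)[:, None, None, None]
  squared = np.square(score - target).reshape(score.shape[0], -1)
  return float(np.mean(0.5 * squared.sum(axis=-1) * sigmas ** 2))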


def get_ddpm_loss_fn(vpsde, train, reduce_mean=True):
  """Legacy code to reproduce previous results on DDPM. Not recommended for new work."""
  assert isinstance(vpsde, VPSDE), "DDPM training only works for VPSDEs."

  reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)

  def loss_fn(model, batch):
    model_fn = mutils.get_model_fn(model, train=train)
    labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device)
    sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device)
    sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device)
    noise = torch.randn_like(batch)
    perturbed_data = sqrt_alphas_cumprod[labels, None, None, None] * batch + \
                     sqrt_1m_alphas_cumprod[labels, None, None, None] * noise
    score = model_fn(perturbed_data, labels)
    losses = torch.square(score - noise)
    losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)
    loss = torch.mean(losses)
    return loss

  return loss_fn
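

# Illustrative sketch, not part of the original repository: the coefficients behind
# `vpsde.sqrt_alphas_cumprod` and `vpsde.sqrt_1m_alphas_cumprod`. It assumes the discrete
# betas are linearly spaced between beta_min / N and beta_max / N. The two coefficients satisfy
# a^2 + b^2 = 1, so the perturbed sample a * x + b * noise has unit variance whenever x has
# unit variance and is independent of the noise, which is why the SDE is called variance
# preserving; the DDPM loss then asks the model to predict `noise` itself. The helper name
# `ddpm_coefficients_sketch` and its default values are hypothetical.
import numpy as np


def ddpm_coefficients_sketch(beta_min=0.1, beta_max=20., num_steps=1000):
  """Return (sqrt_alphas_cumprod, sqrt_1m_alphas_cumprod) for a linear beta schedule."""
  betas = np.linspace(beta_min / num_steps, beta_max / num_steps, num_steps)
  alphas_cumprod = np.cumprod(1. - betas)
  return np.sqrt(alphas_cumprod), np.sqrt(1. - alphas_cumprod)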


def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False):
  """Create a one-step training/evaluation function.

  Args:
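    sde: An `sde_lib.SDE` object that represents the forward SDE.
    train: `True` for a training step and `False` for an evaluation step.
    optimize_fn: An optimization function (only used when `train` is `True`).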
    reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.
    continuous: `True` indicates that the model is defined to take continuous time steps.
    likelihood_weighting: If `True`, weight the mixture of score matching losses according to
      https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper.

  Returns:
    A one-step function for training or evaluation.
  """
  if continuous:
    loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean,
                              continuous=True, likelihood_weighting=likelihood_weighting)
  else:
    assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training."
    if isinstance(sde, VESDE):
      loss_fn = get_smld_loss_fn(sde, train, reduce_mean=reduce_mean)
    elif isinstance(sde, VPSDE):
      loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean)
    else:
      raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")

  def step_fn(state, batch):
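    """Run one step of training or evaluation.

    In training mode, compute the loss, backpropagate, apply `optimize_fn`, increment `state['step']`,
    and update the EMA of the model parameters. In evaluation mode, compute the loss under
    `torch.no_grad()` with the EMA parameters temporarily copied into the model.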

    Args:
      state: A dictionary of training information, containing the score model, optimizer,
       EMA status, and number of optimization steps.
      batch: A mini-batch of training/evaluation data.

    Returns:
      loss: The average loss value on this mini-batch.
    """
    model = state['model']
    if train:
      optimizer = state['optimizer']
      optimizer.zero_grad()
      loss = loss_fn(model, batch)
      loss.backward()
      optimize_fn(optimizer, model.parameters(), step=state['step'])
      state['step'] += 1
      state['ema'].update(model.parameters())
    else:
      with torch.no_grad():
        ema = state['ema']
        ema.store(model.parameters())
        ema.copy_to(model.parameters())
        loss = loss_fn(model, batch)
        ema.restore(model.parameters())

    return loss

  return step_fn
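

# Illustrative sketch, not part of the original repository: the store / copy_to / restore
# pattern that the evaluation branch of `step_fn` uses to score the EMA weights without losing
# the raw training weights. This minimal NumPy class uses a fixed decay; the repository's
# `models.ema.ExponentialMovingAverage` may differ in details (for example, how the decay is
# scheduled), so treat the class name `EMASketch` and its update rule as assumptions.
import numpy as np


class EMASketch:
  """Keeps shadow copies of parameters, updated as shadow -= (1 - decay) * (shadow - param)."""

  def __init__(self, params, decay):
    self.decay = decay
    self.shadow = [p.copy() for p in params]
    self.backup = None

  def update(self, params):
    # In-place update of each shadow array toward the current parameter values.
    for s, p in zip(self.shadow, params):
      s -= (1. - self.decay) * (s - p)

  def store(self, params):
    # Save the current (raw) parameters so they can be restored after evaluation.
    self.backup = [p.copy() for p in params]

  def copy_to(self, params):
    # Overwrite the parameters in place with the EMA values.
    for s, p in zip(self.shadow, params):
      p[...] = s

  def restore(self, params):
    # Put the saved raw parameters back in place.
    for b, p in zip(self.backup, params):
      p[...] = b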



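def get_step_fn_regression(train, config, mask=None, loss_fn=None, optimize_fn=None):
  """Create a one-step training/evaluation function for undersampled k-space regression.

  Each step applies `fft2` to `batch`, multiplies the k-space by a mask from `get_mask`, and
  applies `ifft2` to form an aliased image (magnitude in training, real part in evaluation)
  that is passed to `model`. In training mode, the acceleration factor is drawn at random from
  `config.training.acc_factor` and the step returns the value of `loss_fn` between the estimate
  and `batch`. In evaluation mode, the step returns the image estimated with the EMA parameters.
  The `mask` argument is currently unused, because a new mask is generated inside every step.
  """
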
  def step_fn(state, batch):
    model = state['model']
    if train:
      optimizer = state['optimizer']
      optimizer.zero_grad()

      # fft
      kspace = fft2(batch)

      # sample mask
      acc_factor = np.random.choice(config.training.acc_factor)
      mask = get_mask(batch, config.data.image_size, config.training.batch_size,
                      type=config.training.mask_type,
                      acc_factor=acc_factor,
                      fix=True)

      # undersampling
      under_kspace = kspace * mask
      under_img = torch.abs(ifft2(under_kspace))

      est_img = model(under_img)
      loss = loss_fn(est_img, batch)
      loss.backward()
      optimize_fn(optimizer, model.parameters(), step=state['step'])
      state['step'] += 1
      state['ema'].update(model.parameters())
      return loss
    else:
      with torch.no_grad():
        ema = state['ema']
        ema.store(model.parameters())
        ema.copy_to(model.parameters())
        # fft
        kspace = fft2(batch)

        # sample mask
        mask = get_mask(batch, config.data.image_size, config.training.batch_size,
                        type=config.training.mask_type,
                        acc_factor=config.training.acc_factor)

        # undersampling
        under_kspace = kspace * mask
        under_img = torch.real(ifft2(under_kspace))

        est_img = model(under_img)
        ema.restore(model.parameters())
        return est_img
  return step_fn
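

# Illustrative sketch, not part of the original repository: the simulated acquisition inside
# `get_step_fn_regression`, written with NumPy. The repository's `fft2`, `ifft2`, and `get_mask`
# live in `utils.py` and are not shown here, so this sketch assumes a centered orthonormal 2-D
# FFT and a hypothetical 1-D Cartesian mask that keeps every `acc_factor`-th column plus a small
# fully sampled band at the center of k-space. As in the training branch, the magnitude of the
# zero-filled inverse FFT is the aliased image that the regression model learns to clean up.
import numpy as np


def column_mask_sketch(width, acc_factor, center_fraction=0.08):
  """Boolean mask over k-space columns (hypothetical sampling pattern)."""
  mask = np.zeros(width, dtype=bool)
  mask[::acc_factor] = True
  num_low = max(1, int(round(width * center_fraction)))
  start = (width - num_low) // 2
  mask[start:start + num_low] = True
  return mask


def zero_filled_recon_sketch(image, mask):
  """Centered FFT, drop unsampled columns, inverse FFT, and take the magnitude."""
  kspace = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(image), norm='ortho'))
  under_kspace = kspace * mask[None, :]
  under_img = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(under_kspace), norm='ortho'))
  return np.abs(under_img)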


================================================
FILE: main.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
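# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.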
SYMBOL INDEX (494 symbols across 85 files)

FILE: configs/default_celeba_configs.py
  function get_default_configs (line 5) | def get_default_configs():

FILE: configs/default_cifar10_configs.py
  function get_default_configs (line 5) | def get_default_configs():

FILE: configs/default_complex_configs.py
  function get_default_configs (line 5) | def get_default_configs():

FILE: configs/default_lsun_configs.py
  function get_default_configs (line 5) | def get_default_configs():

FILE: configs/subvp/cifar10_ddpm_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/subvp/cifar10_ddpmpp_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/subvp/cifar10_ddpmpp_deep_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/subvp/cifar10_ncsnpp_continuous.py
  function get_config (line 21) | def get_config():

FILE: configs/subvp/cifar10_ncsnpp_deep_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/AAPM_128_ncsnpp_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/AAPM_256_ncsnpp_continuous.py
  function get_config (line 4) | def get_config():

FILE: configs/ve/Object5_fast.py
  function get_config (line 4) | def get_config():

FILE: configs/ve/Object5_ncsnpp_continuous.py
  function get_config (line 4) | def get_config():

FILE: configs/ve/bedroom_ncsnpp_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/celeba_ncsnpp.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/celebahq_256_ncsnpp_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/celebahq_ncsnpp_continuous.py
  function get_config (line 23) | def get_config():

FILE: configs/ve/church_ncsnpp_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/cifar10_ddpm.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/cifar10_ncsnpp.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/cifar10_ncsnpp_continuous.py
  function get_config (line 21) | def get_config():

FILE: configs/ve/cifar10_ncsnpp_deep_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/fastmri_knee_128_ncsnpp_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/fastmri_knee_256_ncsnpp_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/fastmri_knee_320_ncsnpp_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/fastmri_knee_320_ncsnpp_continuous_complex.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/fastmri_knee_320_ncsnpp_continuous_complex_magpha.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/fastmri_knee_320_ncsnpp_continuous_multi.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ffhq_256_ncsnpp_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ffhq_ncsnpp_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ncsn/celeba.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ncsn/celeba_124.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ncsn/celeba_1245.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ncsn/celeba_5.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ncsn/cifar10.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ncsn/cifar10_124.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ncsn/cifar10_1245.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ncsn/cifar10_5.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ncsnv2/bedroom.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ncsnv2/celeba.py
  function get_config (line 22) | def get_config():

FILE: configs/ve/ncsnv2/cifar10.py
  function get_config (line 22) | def get_config():

FILE: configs/vp/cifar10_ddpmpp.py
  function get_config (line 22) | def get_config():

FILE: configs/vp/cifar10_ddpmpp_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/vp/cifar10_ddpmpp_deep_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/vp/cifar10_ncsnpp.py
  function get_config (line 22) | def get_config():

FILE: configs/vp/cifar10_ncsnpp_continuous.py
  function get_config (line 21) | def get_config():

FILE: configs/vp/cifar10_ncsnpp_deep_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/vp/ddpm/bedroom.py
  function get_config (line 22) | def get_config():

FILE: configs/vp/ddpm/celebahq.py
  function get_config (line 22) | def get_config():

FILE: configs/vp/ddpm/church.py
  function get_config (line 22) | def get_config():

FILE: configs/vp/ddpm/cifar10.py
  function get_config (line 22) | def get_config():

FILE: configs/vp/ddpm/cifar10_continuous.py
  function get_config (line 22) | def get_config():

FILE: configs/vp/ddpm/cifar10_unconditional.py
  function get_config (line 22) | def get_config():

FILE: controllable_generation_TV.py
  class lambda_schedule (line 18) | class lambda_schedule:
    method __init__ (line 19) | def __init__(self, total=2000):
    method get_current_lambda (line 22) | def get_current_lambda(self, i):
  class lambda_schedule_linear (line 24) | class lambda_schedule_linear(lambda_schedule):
    method __init__ (line 25) | def __init__(self, start_lamb=1.0, end_lamb=0.0):
    method get_current_lambda (line 30) | def get_current_lambda(self, i):
  class lambda_schedule_const (line 34) | class lambda_schedule_const(lambda_schedule):
    method __init__ (line 35) | def __init__(self, lamb=1.0):
    method get_current_lambda (line 39) | def get_current_lambda(self, i):
  function _Dz (line 43) | def _Dz(x): # Batch direction
  function _DzT (line 50) | def _DzT(x): # Batch direction
  function _Dx (line 62) | def _Dx(x):  # Batch direction
  function _DxT (line 69) | def _DxT(x):  # Batch direction
  function _Dy (line 80) | def _Dy(x):  # Batch direction
  function _DyT (line 87) | def _DyT(x):  # Batch direction
  function get_pc_radon_ADMM_TV (line 98) | def get_pc_radon_ADMM_TV(sde, predictor, corrector, inverse_scaler, snr,
  function get_pc_radon_ADMM_TV_vol (line 226) | def get_pc_radon_ADMM_TV_vol(sde, predictor, corrector, inverse_scaler, ...
  function get_pc_radon_ADMM_TV_all_vol (line 353) | def get_pc_radon_ADMM_TV_all_vol(sde, predictor, corrector, inverse_scal...
  function get_ADMM_TV (line 496) | def get_ADMM_TV(eps=1e-5, radon=None, save_progress=False, save_root=None,
  function get_ADMM_TV_isotropic (line 579) | def get_ADMM_TV_isotropic(eps=1e-5, radon=None, save_progress=False, sav...
  function prox_l21 (line 680) | def prox_l21(src, lamb, dim):
  function shrink (line 691) | def shrink(weight_src, lamb):
  function get_pc_radon_ADMM_TV_mri (line 695) | def get_pc_radon_ADMM_TV_mri(sde, predictor, corrector, inverse_scaler, ...

FILE: datasets.py
  function get_data_scaler (line 22) | def get_data_scaler(config):
  function get_data_inverse_scaler (line 31) | def get_data_inverse_scaler(config):
  function crop_resize (line 40) | def crop_resize(image, resolution):
  function resize_small (line 54) | def resize_small(image, resolution):
  function central_crop (line 63) | def central_crop(image, size):
  function get_dataset (line 70) | def get_dataset(config, uniform_dequantization=False, evaluation=False):
  class fastmri_knee (line 203) | class fastmri_knee(Dataset):
    method __init__ (line 205) | def __init__(self, root, is_complex=False):
    method __len__ (line 210) | def __len__(self):
    method __getitem__ (line 213) | def __getitem__(self, idx):
  class AAPM (line 223) | class AAPM(Dataset):
    method __init__ (line 224) | def __init__(self, root, sort):
    method __len__ (line 231) | def __len__(self):
    method __getitem__ (line 234) | def __getitem__(self, idx):
  class Object5 (line 241) | class Object5(Dataset):
    method __init__ (line 242) | def __init__(self, root, slice, fast=False):
    method __len__ (line 263) | def __len__(self):
    method __getitem__ (line 266) | def __getitem__(self, idx):
  class fastmri_knee_infer (line 274) | class fastmri_knee_infer(Dataset):
    method __init__ (line 276) | def __init__(self, root, sort=True, is_complex=False):
    method __len__ (line 283) | def __len__(self):
    method __getitem__ (line 286) | def __getitem__(self, idx):
  class fastmri_knee_magpha (line 296) | class fastmri_knee_magpha(Dataset):
    method __init__ (line 298) | def __init__(self, root):
    method __len__ (line 302) | def __len__(self):
    method __getitem__ (line 305) | def __getitem__(self, idx):
  class fastmri_knee_magpha_infer (line 311) | class fastmri_knee_magpha_infer(Dataset):
    method __init__ (line 313) | def __init__(self, root, sort=True):
    method __len__ (line 319) | def __len__(self):
    method __getitem__ (line 322) | def __getitem__(self, idx):
  function create_dataloader (line 328) | def create_dataloader(configs, evaluation=False, sort=True):
  function create_dataloader_regression (line 372) | def create_dataloader_regression(configs, evaluation=False):

FILE: evaluation.py
  function get_inception_model (line 31) | def get_inception_model(inceptionv3=False):
  function load_dataset_stats (line 39) | def load_dataset_stats(config):
  function classifier_fn_from_tfhub (line 55) | def classifier_fn_from_tfhub(output_fields, inception_model,
  function run_inception_jit (line 86) | def run_inception_jit(inputs,
  function run_inception_distributed (line 104) | def run_inception_distributed(input_tensor,

FILE: fastmri_utils.py
  function fft2c_old (line 16) | def fft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
  function ifft2c_old (line 41) | def ifft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
  function fft2c_new (line 67) | def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
  function ifft2c_new (line 92) | def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
  function roll_one_dim (line 120) | def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
  function roll (line 140) | def roll(
  function fftshift (line 163) | def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch....
  function ifftshift (line 186) | def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch...

FILE: likelihood.py
  function get_div_fn (line 26) | def get_div_fn(fn):
  function get_likelihood_fn (line 40) | def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher',

FILE: losses.py
  function get_optimizer (line 28) | def get_optimizer(config, params):
  function optimization_manager (line 40) | def optimization_manager(config):
  function get_sde_loss_fn (line 57) | def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likel...
  function get_smld_loss_fn (line 105) | def get_smld_loss_fn(vesde, train, reduce_mean=False):
  function get_ddpm_loss_fn (line 129) | def get_ddpm_loss_fn(vpsde, train, reduce_mean=True):
  function get_step_fn (line 152) | def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continu...
  function get_step_fn_regression (line 215) | def get_step_fn_regression(train, config, mask=None, loss_fn=None, optim...

FILE: main.py
  function main (line 38) | def main(argv):

FILE: models/ddpm.py
  class DDPM (line 40) | class DDPM(nn.Module):
    method __init__ (line 41) | def __init__(self, config):
    method forward (line 110) | def forward(self, x, labels):

FILE: models/ema.py
  class ExponentialMovingAverage (line 10) | class ExponentialMovingAverage:
    method __init__ (line 15) | def __init__(self, parameters, decay, use_num_updates=True):
    method update (line 32) | def update(self, parameters):
    method copy_to (line 53) | def copy_to(self, parameters):
    method store (line 66) | def store(self, parameters):
    method restore (line 76) | def restore(self, parameters):
    method state_dict (line 91) | def state_dict(self):
    method load_state_dict (line 95) | def load_state_dict(self, state_dict):

FILE: models/layers.py
  class SiLU (line 29) | class SiLU(nn.Module):
    method forward (line 30) | def forward(self, x):
  function get_act (line 33) | def get_act(config):
  function ncsn_conv1x1 (line 48) | def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1,...
  function variance_scaling (line 58) | def variance_scaling(scale, mode, distribution,
  function default_init (line 92) | def default_init(scale=1.):
  class Dense (line 98) | class Dense(nn.Module):
    method __init__ (line 100) | def __init__(self):
  function ddpm_conv1x1 (line 104) | def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=...
  function ncsn_conv3x3 (line 112) | def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1,...
  function ddpm_conv3x3 (line 122) | def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1,...
  class CRPBlock (line 137) | class CRPBlock(nn.Module):
    method __init__ (line 138) | def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
    method forward (line 151) | def forward(self, x):
  class CondCRPBlock (line 161) | class CondCRPBlock(nn.Module):
    method __init__ (line 162) | def __init__(self, features, n_stages, num_classes, normalizer, act=nn...
    method forward (line 175) | def forward(self, x, y):
  class RCUBlock (line 187) | class RCUBlock(nn.Module):
    method __init__ (line 188) | def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
    method forward (line 200) | def forward(self, x):
  class CondRCUBlock (line 211) | class CondRCUBlock(nn.Module):
    method __init__ (line 212) | def __init__(self, features, n_blocks, n_stages, num_classes, normaliz...
    method forward (line 226) | def forward(self, x, y):
  class MSFBlock (line 238) | class MSFBlock(nn.Module):
    method __init__ (line 239) | def __init__(self, in_planes, features):
    method forward (line 248) | def forward(self, xs, shape):
  class CondMSFBlock (line 257) | class CondMSFBlock(nn.Module):
    method __init__ (line 258) | def __init__(self, in_planes, features, num_classes, normalizer):
    method forward (line 271) | def forward(self, xs, y, shape):
  class RefineBlock (line 281) | class RefineBlock(nn.Module):
    method __init__ (line 282) | def __init__(self, in_planes, features, act=nn.ReLU(), start=False, en...
    method forward (line 299) | def forward(self, xs, output_shape):
  class CondRefineBlock (line 317) | class CondRefineBlock(nn.Module):
    method __init__ (line 318) | def __init__(self, in_planes, features, num_classes, normalizer, act=n...
    method forward (line 337) | def forward(self, xs, y, output_shape):
  class ConvMeanPool (line 355) | class ConvMeanPool(nn.Module):
    method __init__ (line 356) | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, ...
    method forward (line 369) | def forward(self, inputs):
  class MeanPoolConv (line 376) | class MeanPoolConv(nn.Module):
    method __init__ (line 377) | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
    method forward (line 381) | def forward(self, inputs):
  class UpsampleConv (line 388) | class UpsampleConv(nn.Module):
    method __init__ (line 389) | def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
    method forward (line 394) | def forward(self, inputs):
  class ConditionalResidualBlock (line 401) | class ConditionalResidualBlock(nn.Module):
    method __init__ (line 402) | def __init__(self, input_dim, output_dim, num_classes, resample=1, act...
    method forward (line 441) | def forward(self, x, y):
  class ResidualBlock (line 457) | class ResidualBlock(nn.Module):
    method __init__ (line 458) | def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
    method forward (line 498) | def forward(self, x):
  function get_timestep_embedding (line 519) | def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
  function _einsum (line 533) | def _einsum(a, b, c, x, y):
  function contract_inner (line 538) | def contract_inner(x, y):
  class NIN (line 547) | class NIN(nn.Module):
    method __init__ (line 548) | def __init__(self, in_dim, num_units, init_scale=0.1):
    method forward (line 553) | def forward(self, x):
  class AttnBlock (line 559) | class AttnBlock(nn.Module):
    method __init__ (line 561) | def __init__(self, channels):
    method forward (line 569) | def forward(self, x):
  class Upsample (line 585) | class Upsample(nn.Module):
    method __init__ (line 586) | def __init__(self, channels, with_conv=False):
    method forward (line 592) | def forward(self, x):
  class Downsample (line 600) | class Downsample(nn.Module):
    method __init__ (line 601) | def __init__(self, channels, with_conv=False):
    method forward (line 607) | def forward(self, x):
  class ResnetBlockDDPM (line 620) | class ResnetBlockDDPM(nn.Module):
    method __init__ (line 622) | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortc...
    method forward (line 646) | def forward(self, x, temb=None):

FILE: models/layerspp.py
  class GaussianFourierProjection (line 32) | class GaussianFourierProjection(nn.Module):
    method __init__ (line 35) | def __init__(self, embedding_size=256, scale=1.0):
    method forward (line 39) | def forward(self, x):
  class Combine (line 44) | class Combine(nn.Module):
    method __init__ (line 47) | def __init__(self, dim1, dim2, method='cat'):
    method forward (line 52) | def forward(self, x, y):
  class AttnBlockpp (line 62) | class AttnBlockpp(nn.Module):
    method __init__ (line 65) | def __init__(self, channels, skip_rescale=False, init_scale=0.):
    method forward (line 75) | def forward(self, x):
  class Upsample (line 94) | class Upsample(nn.Module):
    method __init__ (line 95) | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
    method forward (line 114) | def forward(self, x):
  class Downsample (line 129) | class Downsample(nn.Module):
    method __init__ (line 130) | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
    method forward (line 149) | def forward(self, x):
  class ResnetBlockDDPMpp (line 166) | class ResnetBlockDDPMpp(nn.Module):
    method __init__ (line 169) | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortc...
    method forward (line 193) | def forward(self, x, temb=None):
  class ResnetBlockBigGANpp (line 212) | class ResnetBlockBigGANpp(nn.Module):
    method __init__ (line 213) | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, d...
    method forward (line 242) | def forward(self, x, temb=None):
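`GaussianFourierProjection(embedding_size=256, scale=1.0)` is indexed above but not shown. The following is a minimal pure-Python sketch of random Fourier features. It assumes fixed (untrained) frequencies W ~ N(0, scale^2) and the output [sin(2πxW), cos(2πxW)]. The class name `GaussianFourierFeatures` and its `seed` argument are hypothetical. In score-SDE-style NCSN++ models, the scalar input is commonly log(sigma), but check `models/ncsnpp.py` to confirm.

```python
import math
import random


class GaussianFourierFeatures:
    """Random Fourier features for a scalar input (sketch, not the repo class)."""

    def __init__(self, embedding_size=256, scale=1.0, seed=0):
        rng = random.Random(seed)
        # Frequencies are drawn once and then kept fixed.
        self.W = [rng.gauss(0.0, 1.0) * scale for _ in range(embedding_size)]

    def __call__(self, x):
        proj = [x * w * 2 * math.pi for w in self.W]
        # The first half holds sines and the second half cosines: 2 * embedding_size values.
        return [math.sin(p) for p in proj] + [math.cos(p) for p in proj]
```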

FILE: models/ncsnpp.py
  class NCSNpp (line 35) | class NCSNpp(nn.Module):
    method __init__ (line 38) | def __init__(self, config):
    method forward (line 232) | def forward(self, x, time_cond):

FILE: models/ncsnv2.py
  function get_network (line 31) | def get_network(config):
  class NCSNv2 (line 44) | class NCSNv2(nn.Module):
    method __init__ (line 45) | def __init__(self, config):
    method _compute_cond_module (line 101) | def _compute_cond_module(self, module, x):
    method forward (line 106) | def forward(self, x, y):
  class NCSN (line 136) | class NCSN(nn.Module):
    method __init__ (line 137) | def __init__(self, config):
    method _compute_cond_module (line 191) | def _compute_cond_module(self, module, x, y):
    method forward (line 196) | def forward(self, x, y):
  class NCSNv2_128 (line 222) | class NCSNv2_128(nn.Module):
    method __init__ (line 224) | def __init__(self, config):
    method _compute_cond_module (line 279) | def _compute_cond_module(self, module, x):
    method forward (line 284) | def forward(self, x, y):
  class NCSNv2_256 (line 316) | class NCSNv2_256(nn.Module):
    method __init__ (line 318) | def __init__(self, config):
    method _compute_cond_module (line 381) | def _compute_cond_module(self, module, x):
    method forward (line 386) | def forward(self, x, y):

FILE: models/normalization.py
  function get_normalization (line 22) | def get_normalization(config, conditional=False):
  class ConditionalBatchNorm2d (line 43) | class ConditionalBatchNorm2d(nn.Module):
    method __init__ (line 44) | def __init__(self, num_features, num_classes, bias=True):
    method forward (line 57) | def forward(self, x, y):
  class ConditionalInstanceNorm2d (line 68) | class ConditionalInstanceNorm2d(nn.Module):
    method __init__ (line 69) | def __init__(self, num_features, num_classes, bias=True):
    method forward (line 82) | def forward(self, x, y):
  class ConditionalVarianceNorm2d (line 93) | class ConditionalVarianceNorm2d(nn.Module):
    method __init__ (line 94) | def __init__(self, num_features, num_classes, bias=False):
    method forward (line 101) | def forward(self, x, y):
  class VarianceNorm2d (line 110) | class VarianceNorm2d(nn.Module):
    method __init__ (line 111) | def __init__(self, num_features, bias=False):
    method forward (line 118) | def forward(self, x):
  class ConditionalNoneNorm2d (line 126) | class ConditionalNoneNorm2d(nn.Module):
    method __init__ (line 127) | def __init__(self, num_features, num_classes, bias=True):
    method forward (line 139) | def forward(self, x, y):
  class NoneNorm2d (line 149) | class NoneNorm2d(nn.Module):
    method __init__ (line 150) | def __init__(self, num_features, bias=True):
    method forward (line 153) | def forward(self, x):
  class InstanceNorm2dPlus (line 157) | class InstanceNorm2dPlus(nn.Module):
    method __init__ (line 158) | def __init__(self, num_features, bias=True):
    method forward (line 170) | def forward(self, x):
  class ConditionalInstanceNorm2dPlus (line 186) | class ConditionalInstanceNorm2dPlus(nn.Module):
    method __init__ (line 187) | def __init__(self, num_features, num_classes, bias=True):
    method forward (line 200) | def forward(self, x, y):

FILE: models/unet.py
  class ConvBlock (line 7) | class ConvBlock(nn.Module):
    method __init__ (line 13) | def __init__(self, in_chans, out_chans, stride=2):
    method forward (line 35) | def forward(self, tensor):
    method __repr__ (line 44) | def __repr__(self):
  class Unet (line 49) | class Unet(nn.Module):
    method __init__ (line 50) | def __init__(self, in_chans=1, out_chans=1, chans=64, num_pool_layers=...
    method forward (line 81) | def forward(self, tensor):

FILE: models/up_or_down_sampling.py
  function get_weight (line 14) | def get_weight(module,
  class Conv2d (line 23) | class Conv2d(nn.Module):
    method __init__ (line 26) | def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
    method forward (line 45) | def forward(self, x):
  function naive_upsample_2d (line 59) | def naive_upsample_2d(x, factor=2):
  function naive_downsample_2d (line 66) | def naive_downsample_2d(x, factor=2):
  function upsample_conv_2d (line 72) | def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
  function conv_downsample_2d (line 144) | def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
  function _setup_kernel (line 181) | def _setup_kernel(k):
  function _shape (line 191) | def _shape(x, dim):
  function upsample_2d (line 195) | def upsample_2d(x, k=None, factor=2, gain=1):
  function downsample_2d (line 227) | def downsample_2d(x, k=None, factor=2, gain=1):
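`naive_upsample_2d(x, factor=2)` and `naive_downsample_2d(x, factor=2)` appear above without bodies. This sketch assumes they perform nearest-neighbour repetition and non-overlapping block averaging, as their score-SDE counterparts do. It works on a single-channel image stored as nested lists. The names `upsample_nearest` and `downsample_mean` are hypothetical.

```python
def upsample_nearest(img, factor=2):
    """Repeat every pixel into a factor x factor block."""
    out = []
    for row in img:
        wide = [v for v in row for _ in range(factor)]
        out.extend([list(wide) for _ in range(factor)])
    return out


def downsample_mean(img, factor=2):
    """Average each non-overlapping factor x factor block."""
    h, w = len(img), len(img[0])
    if h % factor or w % factor:
        raise ValueError("height and width must be divisible by factor")
    return [
        [
            sum(img[r + dr][c + dc] for dr in range(factor) for dc in range(factor)) / factor ** 2
            for c in range(0, w, factor)
        ]
        for r in range(0, h, factor)
    ]
```

With this pairing, block averaging exactly undoes nearest-neighbour upsampling.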

FILE: models/utils.py
  function register_model (line 27) | def register_model(cls=None, *, name=None):
  function get_model (line 46) | def get_model(name):
  function get_sigmas (line 50) | def get_sigmas(config):
  function get_ddpm_params (line 63) | def get_ddpm_params(config):
  function create_model (line 88) | def create_model(config):
  function get_model_fn (line 97) | def get_model_fn(model, train=False):
  function get_score_fn (line 129) | def get_score_fn(sde, model, train=False, continuous=False):
  function to_flattened_numpy (line 181) | def to_flattened_numpy(x):
  function from_flattened_numpy (line 186) | def from_flattened_numpy(x, shape):
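`get_sigmas(config)` takes a config object and its body is not shown. The sketch below assumes it returns noise levels spaced evenly in log space, from sigma_max down to sigma_min, as in the score-SDE codebase. Here the config fields are passed as plain numbers, and the name `geometric_sigmas` is hypothetical.

```python
import math


def geometric_sigmas(sigma_max, sigma_min, num_scales):
    """num_scales noise levels, geometric from sigma_max down to sigma_min."""
    if num_scales < 2:
        raise ValueError("num_scales must be >= 2")
    log_max, log_min = math.log(sigma_max), math.log(sigma_min)
    step = (log_min - log_max) / (num_scales - 1)
    return [math.exp(log_max + i * step) for i in range(num_scales)]
```

Because the spacing is uniform in log space, consecutive sigmas share a constant ratio.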

FILE: op/fused_act.py
  class FusedLeakyReLUFunctionBackward (line 20) | class FusedLeakyReLUFunctionBackward(Function):
    method forward (line 22) | def forward(ctx, grad_output, out, negative_slope, scale):
    method backward (line 43) | def backward(ctx, gradgrad_input, gradgrad_bias):
  class FusedLeakyReLUFunction (line 52) | class FusedLeakyReLUFunction(Function):
    method forward (line 54) | def forward(ctx, input, bias, negative_slope, scale):
    method backward (line 64) | def backward(ctx, grad_output):
  class FusedLeakyReLU (line 74) | class FusedLeakyReLU(nn.Module):
    method __init__ (line 75) | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
    method forward (line 82) | def forward(self, input):
  function fused_leaky_relu (line 86) | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):

FILE: op/fused_bias_act.cpp
  function fused_bias_act (line 11) | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Te...
  function PYBIND11_MODULE (line 19) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: op/upfirdn2d.cpp
  function upfirdn2d (line 12) | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor&...
  function PYBIND11_MODULE (line 21) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: op/upfirdn2d.py
  class UpFirDn2dBackward (line 19) | class UpFirDn2dBackward(Function):
    method forward (line 21) | def forward(
    method backward (line 63) | def backward(ctx, gradgrad_input):
  class UpFirDn2d (line 88) | class UpFirDn2d(Function):
    method forward (line 90) | def forward(ctx, input, kernel, up, down, pad):
    method backward (line 127) | def backward(ctx, grad_output):
  function upfirdn2d (line 145) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
  function upfirdn2d_native (line 159) | def upfirdn2d_native(

FILE: physics/ct.py
  class CT (line 5) | class CT():
    method __init__ (line 6) | def __init__(self, img_width, radon_view, uniform=True, circle=False, ...
    method A (line 20) | def A(self, x):
    method A_all (line 23) | def A_all(self, x):
    method A_all_dagger (line 26) | def A_all_dagger(self, x):
    method A_dagger (line 29) | def A_dagger(self, y):
    method AT (line 32) | def AT(self, y):
  class CT_LA (line 36) | class CT_LA():
    method __init__ (line 40) | def __init__(self, img_width, radon_view, uniform=True, circle=False, ...
    method A (line 49) | def A(self, x):
    method A_dagger (line 52) | def A_dagger(self, y):
    method AT (line 55) | def AT(self, y):

FILE: physics/inpainting.py
  class Inpainting (line 4) | class Inpainting():
    method __init__ (line 5) | def __init__(self, img_heigth=512, img_width=512, mode='random', mask_...
    method A (line 14) | def A(self, x):
    method A_dagger (line 17) | def A_dagger(self, x):

FILE: physics/radon/filters.py
  class AbstractFilter (line 9) | class AbstractFilter(nn.Module):
    method __init__ (line 10) | def __init__(self):
    method forward (line 13) | def forward(self, x):
    method _get_fourier_filter (line 24) | def _get_fourier_filter(self, size):
    method create_filter (line 38) | def create_filter(self, f):
  class RampFilter (line 41) | class RampFilter(AbstractFilter):
    method __init__ (line 42) | def __init__(self):
    method create_filter (line 45) | def create_filter(self, f):
  class HannFilter (line 48) | class HannFilter(AbstractFilter):
    method __init__ (line 49) | def __init__(self):
    method create_filter (line 52) | def create_filter(self, f):
  class LearnableFilter (line 57) | class LearnableFilter(AbstractFilter):
    method __init__ (line 58) | def __init__(self, filter_size):
    method forward (line 62) | def forward(self, x):

FILE: physics/radon/radon.py
  class Radon (line 11) | class Radon(nn.Module):
    method __init__ (line 12) | def __init__(self, in_size=None, theta=None, circle=True, dtype=torch....
    method forward (line 23) | def forward(self, x):
    method _create_grids (line 48) | def _create_grids(self, angles, grid_size, circle):
  class IRadon (line 62) | class IRadon(nn.Module):
    method __init__ (line 63) | def __init__(self, in_size=None, theta=None, circle=True,
    method forward (line 77) | def forward(self, x):
    method _create_yxgrid (line 118) | def _create_yxgrid(self, in_size, circle):
    method _XYtoT (line 124) | def _XYtoT(self, theta):
    method _create_grids (line 128) | def _create_grids(self, angles, grid_size, circle):

FILE: physics/radon/stackgram.py
  class Stackgram (line 9) | class Stackgram(nn.Module):
    method __init__ (line 10) | def __init__(self, out_size, theta=None, circle=True, mode='nearest', ...
    method forward (line 22) | def forward(self, x):
    method _create_grids (line 33) | def _create_grids(self, angles, grid_size):
  class IStackgram (line 41) | class IStackgram(nn.Module):
    method __init__ (line 42) | def __init__(self, out_size, theta=None, circle=True, mode='bilinear',...
    method forward (line 54) | def forward(self, x):
    method _create_grids (line 65) | def _create_grids(self, angles, grid_size):

FILE: physics/radon/utils.py
  function fftfreq (line 17) | def fftfreq(n):
  function deg2rad (line 27) | def deg2rad(x):
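`fftfreq(n)` in `physics/radon/utils.py` is listed with only its signature. A plausible reading, assumed here, is that it returns the DFT sample frequencies for unit spacing: non-negative frequencies first, then negative ones. This is the ordering documented for `numpy.fft.fftfreq(n, d=1.0)`. The sketch below is pure Python, and `deg2rad` is included for completeness.

```python
import math


def fftfreq(n):
    """DFT sample frequencies for n points with unit spacing (numpy ordering)."""
    if n <= 0:
        raise ValueError("n must be positive")
    n_nonneg = (n - 1) // 2 + 1          # count of frequencies >= 0
    nonneg = list(range(0, n_nonneg))
    neg = list(range(-(n // 2), 0))      # negative frequencies, ascending
    return [k / n for k in nonneg + neg]


def deg2rad(x):
    """Convert degrees to radians."""
    return x * math.pi / 180.0
```

For example, `fftfreq(4)` gives `[0.0, 0.25, -0.5, -0.25]`. This is the frequency axis a ramp or Hann filter would be evaluated on in filtered back-projection.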

FILE: run_lib.py
  function train (line 47) | def train(config, workdir):
  function evaluate (line 168) | def evaluate(config,

FILE: sampling.py
  function register_predictor (line 35) | def register_predictor(cls=None, *, name=None):
  function register_corrector (line 54) | def register_corrector(cls=None, *, name=None):
  function get_predictor (line 73) | def get_predictor(name):
  function get_corrector (line 77) | def get_corrector(name):
  function get_sampling_fn (line 81) | def get_sampling_fn(config, sde, shape, inverse_scaler, eps):
  class Predictor (line 127) | class Predictor(abc.ABC):
    method __init__ (line 130) | def __init__(self, sde, score_fn, probability_flow=False):
    method update_fn (line 138) | def update_fn(self, x, t):
  class Corrector (line 152) | class Corrector(abc.ABC):
    method __init__ (line 155) | def __init__(self, sde, score_fn, snr, n_steps):
    method update_fn (line 163) | def update_fn(self, x, t):
  class EulerMaruyamaPredictor (line 178) | class EulerMaruyamaPredictor(Predictor):
    method __init__ (line 179) | def __init__(self, sde, score_fn, probability_flow=False):
    method update_fn (line 182) | def update_fn(self, x, t):
  class ReverseDiffusionPredictor (line 192) | class ReverseDiffusionPredictor(Predictor):
    method __init__ (line 193) | def __init__(self, sde, score_fn, probability_flow=False):
    method update_fn (line 196) | def update_fn(self, x, t):
  class AncestralSamplingPredictor (line 205) | class AncestralSamplingPredictor(Predictor):
    method __init__ (line 208) | def __init__(self, sde, score_fn, probability_flow=False):
    method vesde_update_fn (line 214) | def vesde_update_fn(self, x, t):
    method vpsde_update_fn (line 226) | def vpsde_update_fn(self, x, t):
    method update_fn (line 236) | def update_fn(self, x, t):
  class NonePredictor (line 244) | class NonePredictor(Predictor):
    method __init__ (line 247) | def __init__(self, sde, score_fn, probability_flow=False):
    method update_fn (line 250) | def update_fn(self, x, t):
  class LangevinCorrector (line 255) | class LangevinCorrector(Corrector):
    method __init__ (line 256) | def __init__(self, sde, score_fn, snr, n_steps):
    method update_fn (line 263) | def update_fn(self, x, t):
  class LangevinCorrectorCS (line 286) | class LangevinCorrectorCS(Corrector):
    method __init__ (line 288) | def __init__(self, sde, score_fn, snr, n_steps, sigma_min, sigma_max, N):
    method update_fn (line 295) | def update_fn(self, x, t, y, discrete_sigmas):
  class AnnealedLangevinDynamics (line 329) | class AnnealedLangevinDynamics(Corrector):
    method __init__ (line 335) | def __init__(self, sde, score_fn, snr, n_steps):
    method update_fn (line 342) | def update_fn(self, x, t):
  class NoneCorrector (line 366) | class NoneCorrector(Corrector):
    method __init__ (line 369) | def __init__(self, sde, score_fn, snr, n_steps):
    method update_fn (line 372) | def update_fn(self, x, t):
  function shared_predictor_update_fn (line 376) | def shared_predictor_update_fn(x, t, sde, model, predictor, probability_...
  function shared_corrector_update_fn (line 387) | def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, ...
  function get_pc_sampler (line 406) | def get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr,
  function get_ode_sampler (line 473) | def get_ode_sampler(sde, shape, inverse_scaler,
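`register_predictor(cls=None, *, name=None)` and `get_predictor(name)` follow a common decorator-registry signature. The body below is a minimal sketch of how such a decorator usually works, not code copied from `sampling.py`. The `_PREDICTORS` dict, the duplicate-name error, and the two example classes are assumptions made for illustration.

```python
_PREDICTORS = {}


def register_predictor(cls=None, *, name=None):
    """Register a class; usable bare (@register_predictor) or with a name."""

    def _register(cls):
        key = cls.__name__ if name is None else name
        if key in _PREDICTORS:
            raise ValueError(f"Already registered predictor with name: {key}")
        _PREDICTORS[key] = cls
        return cls

    # Bare decorator: cls is the class. Called with name=...: return the decorator.
    if cls is None:
        return _register
    return _register(cls)


def get_predictor(name):
    return _PREDICTORS[name]


# Hypothetical example registrations.
@register_predictor(name="euler_maruyama")
class EulerMaruyamaPredictor:
    pass


@register_predictor
class NonePredictor:
    pass
```

A string in a config can then select a class through `get_predictor`. `register_corrector` and `register_model` have the same signature shape.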

FILE: sde_lib.py
  class SDE (line 7) | class SDE(abc.ABC):
    method __init__ (line 10) | def __init__(self, N):
    method T (line 21) | def T(self):
    method sde (line 26) | def sde(self, x, t):
    method marginal_prob (line 30) | def marginal_prob(self, x, t):
    method prior_sampling (line 35) | def prior_sampling(self, shape):
    method prior_logp (line 40) | def prior_logp(self, z):
    method discretize (line 52) | def discretize(self, x, t):
    method reverse (line 71) | def reverse(self, score_fn, probability_flow=False):
  class VPSDE (line 112) | class VPSDE(SDE):
    method __init__ (line 113) | def __init__(self, beta_min=0.1, beta_max=20, N=1000):
    method T (line 132) | def T(self):
    method sde (line 135) | def sde(self, x, t):
    method marginal_prob (line 141) | def marginal_prob(self, x, t):
    method prior_sampling (line 147) | def prior_sampling(self, shape):
    method prior_logp (line 150) | def prior_logp(self, z):
    method discretize (line 156) | def discretize(self, x, t):
  class subVPSDE (line 167) | class subVPSDE(SDE):
    method __init__ (line 168) | def __init__(self, beta_min=0.1, beta_max=20, N=1000):
    method T (line 182) | def T(self):
    method sde (line 185) | def sde(self, x, t):
    method marginal_prob (line 192) | def marginal_prob(self, x, t):
    method prior_sampling (line 198) | def prior_sampling(self, shape):
    method prior_logp (line 201) | def prior_logp(self, z):
  class VESDE (line 207) | class VESDE(SDE):
    method __init__ (line 208) | def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):
    method T (line 223) | def T(self):
    method sde (line 226) | def sde(self, x, t):
    method marginal_prob (line 233) | def marginal_prob(self, x, t):
    method prior_sampling (line 238) | def prior_sampling(self, shape):
    method prior_logp (line 241) | def prior_logp(self, z):
    method discretize (line 246) | def discretize(self, x, t):
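`VESDE(sigma_min=0.01, sigma_max=50, N=1000)` is indexed with `sde`, `marginal_prob`, and `discretize`. The sketch below assumes the standard variance-exploding SDE with sigma(t) = sigma_min (sigma_max / sigma_min)^t. Under that assumption the drift is zero, the diffusion is sigma(t) sqrt(2 ln(sigma_max / sigma_min)), and the marginal keeps the mean x0 with std sigma(t). The discrete step uses G = sqrt(sigma_i^2 - sigma_{i-1}^2). The class name `VESDESketch` is hypothetical, and x is a plain float for illustration.

```python
import math


class VESDESketch:
    def __init__(self, sigma_min=0.01, sigma_max=50.0, N=1000):
        self.sigma_min, self.sigma_max, self.N = sigma_min, sigma_max, N
        log_min, log_max = math.log(sigma_min), math.log(sigma_max)
        # Discrete noise levels, geometric from sigma_min up to sigma_max.
        self.discrete_sigmas = [
            math.exp(log_min + i * (log_max - log_min) / (N - 1)) for i in range(N)
        ]

    def sigma(self, t):
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** t

    def sde(self, x, t):
        """(drift, diffusion) of dx = sqrt(d[sigma^2]/dt) dw; drift is zero."""
        diffusion = self.sigma(t) * math.sqrt(
            2 * (math.log(self.sigma_max) - math.log(self.sigma_min)))
        return 0.0, diffusion

    def marginal_prob(self, x0, t):
        """Mean and std of p_t(x | x0)."""
        return x0, self.sigma(t)

    def discretize(self, x, t):
        """Discrete (f, G) at step i = int(t * (N - 1))."""
        i = int(t * (self.N - 1))
        sigma = self.discrete_sigmas[i]
        adjacent = 0.0 if i == 0 else self.discrete_sigmas[i - 1]
        return 0.0, math.sqrt(sigma ** 2 - adjacent ** 2)
```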

FILE: test/test_TV.py
  function test_adjoint (line 22) | def test_adjoint(A, AT):
  function test_prox_l21 (line 31) | def test_prox_l21():
  class Identity (line 54) | class Identity:
    method A (line 56) | def A(x):
    method AT (line 60) | def AT(y):
  function test_ADMM_TV_isotropic (line 63) | def test_ADMM_TV_isotropic():
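`test_adjoint(A, AT)` pairs naturally with the operators above that expose `A` and `AT` (for example `CT` and `CT_LA`). A common way to check such a pair is the dot-product test: for random x and y, <Ax, y> should equal <x, A^T y> up to round-off. The sketch below uses a small dense matrix as a hypothetical stand-in for the Radon transform. The helper names (`matvec`, `transpose`, `dot`, `adjoint_gap`) are illustrative and do not come from the repository.

```python
import random


def matvec(M, v):
    return [sum(m * vi for m, vi in zip(row, v)) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def adjoint_gap(A, AT, n_in, n_out, seed=0):
    """Relative mismatch between <A x, y> and <x, AT y> for random x, y."""
    rng = random.Random(seed)
    x = [rng.gauss(0, 1) for _ in range(n_in)]
    y = [rng.gauss(0, 1) for _ in range(n_out)]
    lhs, rhs = dot(A(x), y), dot(x, AT(y))
    return abs(lhs - rhs) / max(abs(lhs), abs(rhs), 1e-30)
```

A correct pair gives a gap near machine precision. An AT that is off by a constant factor of 2 gives a gap of 0.5.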

FILE: utils.py
  function clear_color (line 16) | def clear_color(x):
  function clear (line 20) | def clear(x, normalize=True):
  function restore_checkpoint (line 27) | def restore_checkpoint(ckpt_dir, state, device, skip_sigma=False, skip_o...
  function save_checkpoint (line 50) | def save_checkpoint(ckpt_dir, state, name="checkpoint.pth"):
  function fft2 (line 65) | def fft2(x):
  function ifft2 (line 70) | def ifft2(x):
  function fft2_m (line 75) | def fft2_m(x):
  function ifft2_m (line 80) | def ifft2_m(x):
  function crop_center (line 85) | def crop_center(img, cropx, cropy):
  function normalize (line 92) | def normalize(img):
  function normalize_np (line 98) | def normalize_np(img):
  function normalize_np_kwarg (line 105) | def normalize_np_kwarg(img, maxv=1.0, minv=0.0):
  function normalize_complex (line 112) | def normalize_complex(img):
  function batchfy (line 120) | def batchfy(tensor, batch_size):
  function img_wise_min_max (line 126) | def img_wise_min_max(img):
  function patient_wise_min_max (line 134) | def patient_wise_min_max(img):
  function create_sphere (line 153) | def create_sphere(cx, cy, cz, r, resolution=256):
  class lambda_schedule (line 170) | class lambda_schedule:
    method __init__ (line 171) | def __init__(self, total=2000):
    method get_current_lambda (line 174) | def get_current_lambda(self, i):
  class lambda_schedule_linear (line 178) | class lambda_schedule_linear(lambda_schedule):
    method __init__ (line 179) | def __init__(self, start_lamb=1.0, end_lamb=0.0):
    method get_current_lambda (line 184) | def get_current_lambda(self, i):
  class lambda_schedule_const (line 188) | class lambda_schedule_const(lambda_schedule):
    method __init__ (line 189) | def __init__(self, lamb=1.0):
    method get_current_lambda (line 193) | def get_current_lambda(self, i):
  function image_grid (line 199) | def image_grid(x, sz=32):
  function show_samples (line 208) | def show_samples(x, sz=32):
  function image_grid_gray (line 217) | def image_grid_gray(x, size=32):
  function show_samples_gray (line 224) | def show_samples_gray(x, size=32, save=False, save_fname=None):
  function get_mask (line 235) | def get_mask(img, size, batch_size, type='gaussian2d', acc_factor=8, cen...
  function kspace_to_nchw (line 317) | def kspace_to_nchw(tensor):
  function nchw_to_kspace (line 342) | def nchw_to_kspace(tensor):
  function root_sum_of_squares (line 361) | def root_sum_of_squares(data, dim=0):
  function save_data (line 373) | def save_data(fname, arr):
  function mean_std (line 378) | def mean_std(vals: list):
  function cal_metric (line 381) | def cal_metric(comp, label):
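The `lambda_schedule` family in `utils.py` exposes `get_current_lambda(i)`, and the base class defaults to `total=2000`. This sketch assumes that `total` is the number of sampling steps. Under that assumption, the linear variant interpolates from `start_lamb` at i = 0 to `end_lamb` at i = total, and the constant variant always returns `lamb`. The class names and the flattened `total` argument are hypothetical.

```python
class LinearLambdaSchedule:
    """Weight moving linearly from start_lamb (i = 0) to end_lamb (i = total)."""

    def __init__(self, start_lamb=1.0, end_lamb=0.0, total=2000):
        self.start_lamb = start_lamb
        self.end_lamb = end_lamb
        self.total = total

    def get_current_lambda(self, i):
        return self.start_lamb + (self.end_lamb - self.start_lamb) * (i / self.total)


class ConstLambdaSchedule:
    """Same weight at every step."""

    def __init__(self, lamb=1.0):
        self.lamb = lamb

    def get_current_lambda(self, i):
        return self.lamb
```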
Condensed preview: 97 files, each showing path, character count, and a content snippet (410K chars of full structured content).
[
  {
    "path": ".gitignore",
    "chars": 255,
    "preview": "# Compiled source #\n###################\n*.o\n*.so\n*.pyc\n\n# Logs and temporaries #\n########################\n*.log\n*~\n.cove"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 2491,
    "preview": "# Solving 3D Inverse Problems using Pre-trained 2D Diffusion Models (CVPR 2023)\n\nOfficial PyTorch implementation of **Di"
  },
  {
    "path": "configs/default_celeba_configs.py",
    "chars": 1983,
    "preview": "import ml_collections\nimport torch\n\n\ndef get_default_configs():\n  config = ml_collections.ConfigDict()\n  # training\n  co"
  },
  {
    "path": "configs/default_cifar10_configs.py",
    "chars": 2009,
    "preview": "import ml_collections\nimport torch\n\n\ndef get_default_configs():\n  config = ml_collections.ConfigDict()\n  # training\n  co"
  },
  {
    "path": "configs/default_complex_configs.py",
    "chars": 2136,
    "preview": "import ml_collections\nimport torch\n\n\ndef get_default_configs():\n  config = ml_collections.ConfigDict()\n  # training\n  co"
  },
  {
    "path": "configs/default_lsun_configs.py",
    "chars": 2161,
    "preview": "import ml_collections\nimport torch\n\n\ndef get_default_configs():\n  config = ml_collections.ConfigDict()\n  # training\n  co"
  },
  {
    "path": "configs/subvp/cifar10_ddpm_continuous.py",
    "chars": 1468,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/subvp/cifar10_ddpmpp_continuous.py",
    "chars": 1848,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/subvp/cifar10_ddpmpp_deep_continuous.py",
    "chars": 1876,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/subvp/cifar10_ncsnpp_continuous.py",
    "chars": 1843,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/subvp/cifar10_ncsnpp_deep_continuous.py",
    "chars": 1857,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/AAPM_128_ncsnpp_continuous.py",
    "chars": 1898,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/AAPM_256_ncsnpp_continuous.py",
    "chars": 1218,
    "preview": "from configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # tr"
  },
  {
    "path": "configs/ve/Object5_fast.py",
    "chars": 1283,
    "preview": "from configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # tr"
  },
  {
    "path": "configs/ve/Object5_ncsnpp_continuous.py",
    "chars": 1205,
    "preview": "from configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # tr"
  },
  {
    "path": "configs/ve/bedroom_ncsnpp_continuous.py",
    "chars": 1793,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/celeba_ncsnpp.py",
    "chars": 1753,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/celebahq_256_ncsnpp_continuous.py",
    "chars": 1910,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/celebahq_ncsnpp_continuous.py",
    "chars": 3184,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/church_ncsnpp_continuous.py",
    "chars": 1823,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/cifar10_ddpm.py",
    "chars": 1418,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/cifar10_ncsnpp.py",
    "chars": 1731,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/cifar10_ncsnpp_continuous.py",
    "chars": 1719,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/cifar10_ncsnpp_deep_continuous.py",
    "chars": 1749,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/fastmri_knee_128_ncsnpp_continuous.py",
    "chars": 1995,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/fastmri_knee_256_ncsnpp_continuous.py",
    "chars": 1850,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/fastmri_knee_320_ncsnpp_continuous.py",
    "chars": 1900,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/fastmri_knee_320_ncsnpp_continuous_complex.py",
    "chars": 1902,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/fastmri_knee_320_ncsnpp_continuous_complex_magpha.py",
    "chars": 1923,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/fastmri_knee_320_ncsnpp_continuous_multi.py",
    "chars": 1899,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ffhq_256_ncsnpp_continuous.py",
    "chars": 1948,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ffhq_ncsnpp_continuous.py",
    "chars": 3229,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ncsn/celeba.py",
    "chars": 1575,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ncsn/celeba_124.py",
    "chars": 1562,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ncsn/celeba_1245.py",
    "chars": 1564,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ncsn/celeba_5.py",
    "chars": 1593,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ncsn/cifar10.py",
    "chars": 1577,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ncsn/cifar10_124.py",
    "chars": 1563,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ncsn/cifar10_1245.py",
    "chars": 1715,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ncsn/cifar10_5.py",
    "chars": 1585,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ncsnv2/bedroom.py",
    "chars": 1716,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ncsnv2/celeba.py",
    "chars": 1702,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/ve/ncsnv2/cifar10.py",
    "chars": 1558,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/vp/cifar10_ddpmpp.py",
    "chars": 1850,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/vp/cifar10_ddpmpp_continuous.py",
    "chars": 1845,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/vp/cifar10_ddpmpp_deep_continuous.py",
    "chars": 1873,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/vp/cifar10_ncsnpp.py",
    "chars": 1813,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/vp/cifar10_ncsnpp_continuous.py",
    "chars": 1836,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/vp/cifar10_ncsnpp_deep_continuous.py",
    "chars": 1854,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/vp/ddpm/bedroom.py",
    "chars": 1609,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/vp/ddpm/celebahq.py",
    "chars": 1702,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/vp/ddpm/church.py",
    "chars": 1622,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/vp/ddpm/cifar10.py",
    "chars": 1500,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/vp/ddpm/cifar10_continuous.py",
    "chars": 1461,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "configs/vp/ddpm/cifar10_unconditional.py",
    "chars": 1534,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "controllable_generation_TV.py",
    "chars": 30110,
    "preview": "import functools\r\nimport time\r\n\r\nimport torch\r\nfrom numpy.testing._private.utils import measure\r\nimport numpy as np\r\nimp"
  },
  {
    "path": "datasets.py",
    "chars": 13411,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "environment.yml",
    "chars": 277,
    "preview": "name: diffusion-mbir\nchannels:\n  - conda-forge\n  - defaults\ndependencies:\n  - python=3.8\n  - numpy\n  - matplotlib\n  - sc"
  },
  {
    "path": "evaluation.py",
    "chars": 4855,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "fastmri_utils.py",
    "chars": 6367,
    "preview": "\"\"\"\nCopyright (c) Facebook, Inc. and its affiliates.\nThis source code is licensed under the MIT license found in the\nLIC"
  },
  {
    "path": "inverse_problem_solver_AAPM_3d_total.py",
    "chars": 5726,
    "preview": "import torch\nfrom torch._C import device\nfrom losses import get_optimizer\nfrom models.ema import ExponentialMovingAverag"
  },
  {
    "path": "inverse_problem_solver_BRATS_MRI_3d_total.py",
    "chars": 6110,
    "preview": "from pathlib import Path\nfrom models import utils as mutils\nimport sampling\nfrom sde_lib import VESDE\nfrom sampling impo"
  },
  {
    "path": "likelihood.py",
    "chars": 4713,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "losses.py",
    "chars": 9950,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "main.py",
    "chars": 2291,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "models/__init__.py",
    "chars": 608,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "models/ddpm.py",
    "chars": 6082,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "models/ema.py",
    "chars": 3414,
    "preview": "# Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py\n\nfrom __future__ import divi"
  },
  {
    "path": "models/layers.py",
    "chars": 22577,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "models/layerspp.py",
    "chars": 9001,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "models/ncsnpp.py",
    "chars": 14116,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "models/ncsnv2.py",
    "chars": 16043,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "models/normalization.py",
    "chars": 7657,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "models/unet.py",
    "chars": 3834,
    "preview": "from . import utils\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass ConvBlock(nn.Module):"
  },
  {
    "path": "models/up_or_down_sampling.py",
    "chars": 8900,
    "preview": "\"\"\"Layers used for up-sampling or down-sampling images.\n\nMany functions are ported from https://github.com/NVlabs/styleg"
  },
  {
    "path": "models/utils.py",
    "chars": 5684,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "op/__init__.py",
    "chars": 89,
    "preview": "from .fused_act import FusedLeakyReLU, fused_leaky_relu\nfrom .upfirdn2d import upfirdn2d\n"
  },
  {
    "path": "op/fused_act.py",
    "chars": 2787,
    "preview": "import os\r\n\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import Functi"
  },
  {
    "path": "op/fused_bias_act.cpp",
    "chars": 846,
    "preview": "#include <torch/extension.h>\r\n\r\n\r\ntorch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias,"
  },
  {
    "path": "op/fused_bias_act_kernel.cu",
    "chars": 2875,
    "preview": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Sou"
  },
  {
    "path": "op/upfirdn2d.cpp",
    "chars": 988,
    "preview": "#include <torch/extension.h>\r\n\r\n\r\ntorch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,\r\n "
  },
  {
    "path": "op/upfirdn2d.py",
    "chars": 5872,
    "preview": "import os\r\n\r\nimport torch\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import Function\r\nfrom torch.utils.c"
  },
  {
    "path": "op/upfirdn2d_kernel.cu",
    "chars": 12079,
    "preview": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Sou"
  },
  {
    "path": "physics/ct.py",
    "chars": 1749,
    "preview": "import torch\nimport numpy as np\nfrom .radon import Radon, IRadon\n\nclass CT():\n    def __init__(self, img_width, radon_vi"
  },
  {
    "path": "physics/inpainting.py",
    "chars": 691,
    "preview": "import os\nimport torch\n\nclass Inpainting():\n    def __init__(self, img_heigth=512, img_width=512, mode='random', mask_ra"
  },
  {
    "path": "physics/radon/__init__.py",
    "chars": 77,
    "preview": "from .radon import Radon, IRadon\nfrom .stackgram import Stackgram, IStackgram"
  },
  {
    "path": "physics/radon/filters.py",
    "chars": 2352,
    "preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom .utils import PI, fftfreq\n\n'''source: https://gi"
  },
  {
    "path": "physics/radon/radon.py",
    "chars": 5770,
    "preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom physics.radon.filters import RampFilter\nfrom phy"
  },
  {
    "path": "physics/radon/stackgram.py",
    "chars": 2982,
    "preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom .utils import SQRT2, deg2rad, affine_grid, grid_"
  },
  {
    "path": "physics/radon/utils.py",
    "chars": 727,
    "preview": "import torch\nimport torch.nn.functional as F\n\n'''source: https://github.com/matteo-ronchetti/torch-radon'''\n\nif torch.__"
  },
  {
    "path": "run_lib.py",
    "chars": 17769,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "sampling.py",
    "chars": 20073,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "sde_lib.py",
    "chars": 7554,
    "preview": "\"\"\"Abstract SDE classes, Reverse SDE, and VE/VP SDEs.\"\"\"\nimport abc\nimport torch\nimport numpy as np\n\n\nclass SDE(abc.ABC)"
  },
  {
    "path": "test/test_TV.py",
    "chars": 1926,
    "preview": "\"\"\"\npython -m pytest\n\"\"\"\nimport sys\n\nimport pytest\nimport torch\nimport matplotlib.pyplot as plt\nimport skimage\n\n\nimport "
  },
  {
    "path": "train_AAPM256.sh",
    "chars": 161,
    "preview": "#!/bin/bash\n\npython main.py \\\n  --config=configs/ve/AAPM_256_ncsnpp_continuous.py \\\n  --eval_folder=eval/AAPM256 \\\n  --m"
  },
  {
    "path": "utils.py",
    "chars": 11732,
    "preview": "from pathlib import Path\n\nimport torch\nimport os\nimport logging\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom "
  }
]

About this extraction

This page contains the full source code of the HJ-harry/DiffusionMBIR GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 97 files (380.5 KB), approximately 107.9k tokens, and a symbol index with 494 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract.