[
  {
    "path": ".gitignore",
    "content": "# Compiled source #\n###################\n*.o\n*.so\n*.pyc\n\n# Logs and temporaries #\n########################\n*.log\n*~\n.coverage\n\n# Folders #\n###########\nbuild/\ndist/\n*.egg-info/\n__pycache__/\n.eggs/\n\ndata/\nexp/\nresults/\nresults_AAPM/\nresults_AAPM_tv/\nworkdir/"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# Solving 3D Inverse Problems using Pre-trained 2D Diffusion Models (CVPR 2023)\n\nOfficial PyTorch implementation of **DiffusionMBIR**, the CVPR 2023 paper \"[Solving 3D Inverse Problems using Pre-trained 2D Diffusion Models](https://arxiv.org/abs/2211.10655)\". Code modified from [score_sde_pytorch](https://github.com/yang-song/score_sde_pytorch).\n\n✅ If you would like to use an updated, faster version of DiffusionMBIR, you might want to use [DDS](https://github.com/hyungjin-chung/DDS)\n\n[![arXiv](https://img.shields.io/badge/arXiv-2211.10655-green)](https://arxiv.org/abs/2211.10655)\n[![arXiv](https://img.shields.io/badge/paper-CVPR2023-blue)](https://arxiv.org/abs/2211.10655)\n![concept](./figs/forward_model.jpg)\n![concept](./figs/cover_result.jpg)\n\n## Getting started\n\n### Download pre-trained model weights\n* **CT** experiments: [weights](https://drive.google.com/file/d/1-TaLbg3-4gLwKH2-Qf5VBFCBLG3RjY9j/view)\n\n### Download the data\n* **CT** experiments (in-distribution)\n```bash\nDATA_DIR=./data/CT/ind/256_sorted\nmkdir -p \"$DATA_DIR\"\nwget -O \"$DATA_DIR\"/256_sorted.zip https://www.dropbox.com/sh/ibjpgo5seksjera/AADlhYqCWq5C4K0uWSrCL_JUa?dl=1\nunzip -d \"$DATA_DIR\"/ \"$DATA_DIR\"/256_sorted.zip\n```\n* **CT** experiments (out-of-distribution)\n```bash\nDATA_DIR=./data/CT/ood/256_sorted\nmkdir -p \"$DATA_DIR\"\nwget -O \"$DATA_DIR\"/slice.zip https://www.dropbox.com/s/h3drrlx0pvutyoi/slice.zip?dl=0\nunzip -d \"$DATA_DIR\"/ \"$DATA_DIR\"/slice.zip\n```\n\n* Make a conda environment and install dependencies\n```bash\nconda env create --file environment.yml\n```\n\n## DiffusionMBIR (fast) reconstruction\nOnce you have the pre-trained weights and the test data set up properly, you may run the following scripts. 
Modify the parameters in the Python scripts directly to change experimental settings.\n\n```bash\nconda activate diffusion-mbir\npython inverse_problem_solver_AAPM_3d_total.py\npython inverse_problem_solver_BRATS_MRI_3d_total.py\n```\n\n## Training\nYou may train the diffusion model with your own data by using e.g.\n```bash\nbash train_AAPM256.sh\n```\nYou can modify the training config with the ```--config``` flag.\n\n## Citation\nIf you find our work interesting, please consider citing\n\n```\n@InProceedings{chung2023solving,\n  title={Solving 3D Inverse Problems using Pre-trained 2D Diffusion Models},\n  author={Chung, Hyungjin and Ryu, Dohoon and McCann, Michael T and Klasky, Marc L and Ye, Jong Chul},\n  booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition},\n  year={2023}\n}\n```\n"
  },
  {
    "path": "configs/default_celeba_configs.py",
    "content": "import ml_collections\nimport torch\n\n\ndef get_default_configs():\n  config = ml_collections.ConfigDict()\n  # training\n  config.training = training = ml_collections.ConfigDict()\n  # config.training.batch_size = 128\n  config.training.batch_size = 64\n  training.n_iters = 1300001\n  training.snapshot_freq = 50000\n  training.log_freq = 50\n  training.eval_freq = 100\n  ## store additional checkpoints for preemption in cloud computing environments\n  training.snapshot_freq_for_preemption = 10000\n  ## produce samples at each snapshot.\n  training.snapshot_sampling = True\n  training.likelihood_weighting = False\n  training.continuous = True\n  training.reduce_mean = False\n\n  # sampling\n  config.sampling = sampling = ml_collections.ConfigDict()\n  sampling.n_steps_each = 1\n  sampling.noise_removal = True\n  sampling.probability_flow = False\n  sampling.snr = 0.17\n\n  # evaluation\n  config.eval = evaluate = ml_collections.ConfigDict()\n  evaluate.begin_ckpt = 1\n  evaluate.end_ckpt = 26\n  evaluate.batch_size = 1024\n  evaluate.enable_sampling = True\n  evaluate.num_samples = 50000\n  evaluate.enable_loss = True\n  evaluate.enable_bpd = False\n  evaluate.bpd_dataset = 'test'\n\n  # data\n  config.data = data = ml_collections.ConfigDict()\n  data.dataset = 'CELEBA'\n  data.image_size = 64\n  data.random_flip = True\n  data.uniform_dequantization = False\n  data.centered = False\n  data.num_channels = 3\n\n  # model\n  config.model = model = ml_collections.ConfigDict()\n  model.sigma_max = 90.\n  model.sigma_min = 0.01\n  model.num_scales = 1000\n  model.beta_min = 0.1\n  model.beta_max = 20.\n  model.dropout = 0.1\n  model.embedding_type = 'fourier'\n\n  # optimization\n  config.optim = optim = ml_collections.ConfigDict()\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 2e-4\n  optim.beta1 = 0.9\n  optim.eps = 1e-8\n  optim.warmup = 5000\n  optim.grad_clip = 1.\n\n  config.seed = 42\n  config.device = torch.device('cuda:0') 
if torch.cuda.is_available() else torch.device('cpu')\n\n  return config"
  },
  {
    "path": "configs/default_cifar10_configs.py",
    "content": "import ml_collections\nimport torch\n\n\ndef get_default_configs():\n  config = ml_collections.ConfigDict()\n  # training\n  config.training = training = ml_collections.ConfigDict()\n  # config.training.batch_size = 128\n  config.training.batch_size = 4\n  training.n_iters = 1300001\n  training.snapshot_freq = 50000\n  training.log_freq = 50\n  training.eval_freq = 100\n  ## store additional checkpoints for preemption in cloud computing environments\n  training.snapshot_freq_for_preemption = 10000\n  ## produce samples at each snapshot.\n  training.snapshot_sampling = True\n  training.likelihood_weighting = False\n  training.continuous = True\n  training.reduce_mean = False\n\n  # sampling\n  config.sampling = sampling = ml_collections.ConfigDict()\n  sampling.n_steps_each = 1\n  sampling.noise_removal = True\n  sampling.probability_flow = False\n  sampling.snr = 0.16\n\n  # evaluation\n  config.eval = evaluate = ml_collections.ConfigDict()\n  evaluate.begin_ckpt = 9\n  evaluate.end_ckpt = 26\n  evaluate.batch_size = 1024\n  evaluate.enable_sampling = False\n  evaluate.num_samples = 50000\n  evaluate.enable_loss = True\n  evaluate.enable_bpd = False\n  evaluate.bpd_dataset = 'test'\n\n  # data\n  config.data = data = ml_collections.ConfigDict()\n  data.dataset = 'CIFAR10'\n  data.image_size = 32\n  data.random_flip = True\n  data.centered = False\n  data.uniform_dequantization = False\n  data.num_channels = 3\n  # data.num_channels = 1\n\n  # model\n  config.model = model = ml_collections.ConfigDict()\n  model.sigma_min = 0.01\n  model.sigma_max = 50\n  model.num_scales = 1000\n  model.beta_min = 0.1\n  model.beta_max = 20.\n  model.dropout = 0.1\n  model.embedding_type = 'fourier'\n\n  # optimization\n  config.optim = optim = ml_collections.ConfigDict()\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 2e-4\n  optim.beta1 = 0.9\n  optim.eps = 1e-8\n  optim.warmup = 5000\n  optim.grad_clip = 1.\n\n  config.seed = 42\n  
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n\n  return config"
  },
  {
    "path": "configs/default_complex_configs.py",
    "content": "import ml_collections\nimport torch\n\n\ndef get_default_configs():\n  config = ml_collections.ConfigDict()\n  # training\n  config.training = training = ml_collections.ConfigDict()\n  # config.training.batch_size = 64\n  # config.training.batch_size = 2  # seriously?\n  config.training.batch_size = 1  # When using single GPU\n  # training.n_iters = 2400001\n  training.epochs = 100\n  training.snapshot_freq = 50000\n  # training.log_freq = 50\n  training.log_freq = 25\n  training.eval_freq = 100\n  ## store additional checkpoints for preemption in cloud computing environments\n  training.snapshot_freq_for_preemption = 5000\n  ## produce samples at each snapshot.\n  training.snapshot_sampling = True\n  training.likelihood_weighting = False\n  training.continuous = True\n  training.reduce_mean = False\n\n  # sampling\n  config.sampling = sampling = ml_collections.ConfigDict()\n  sampling.n_steps_each = 1\n  sampling.noise_removal = True\n  sampling.probability_flow = False\n  sampling.snr = 0.075\n\n  # evaluation\n  config.eval = evaluate = ml_collections.ConfigDict()\n  evaluate.begin_ckpt = 50\n  evaluate.end_ckpt = 96\n  # evaluate.batch_size = 512\n  evaluate.batch_size = 8\n  evaluate.enable_sampling = True\n  evaluate.num_samples = 50000\n  evaluate.enable_loss = True\n  evaluate.enable_bpd = False\n  evaluate.bpd_dataset = 'test'\n\n  # data\n  config.data = data = ml_collections.ConfigDict()\n  # data.dataset = 'LSUN'\n  data.image_size = 320\n  data.random_flip = True\n  data.uniform_dequantization = False\n  data.centered = False\n  data.num_channels = 2\n\n  # model\n  config.model = model = ml_collections.ConfigDict()\n  model.sigma_max = 378\n  model.sigma_min = 0.01\n  model.num_scales = 2000\n  model.beta_min = 0.1\n  model.beta_max = 20.\n  model.dropout = 0.\n  model.embedding_type = 'fourier'\n\n  # optimization\n  config.optim = optim = ml_collections.ConfigDict()\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 
2e-4\n  optim.beta1 = 0.9\n  optim.eps = 1e-8\n  optim.warmup = 5000\n  optim.grad_clip = 1.\n\n  config.seed = 42\n  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n\n  return config"
  },
  {
    "path": "configs/default_lsun_configs.py",
    "content": "import ml_collections\nimport torch\n\n\ndef get_default_configs():\n  config = ml_collections.ConfigDict()\n  # training\n  config.training = training = ml_collections.ConfigDict()\n  # config.training.batch_size = 64\n  # config.training.batch_size = 2  # seriously?\n  config.training.batch_size = 1  # When using single GPU\n  # training.n_iters = 2400001\n  training.epochs = 1000\n  training.snapshot_freq = 50000\n  # training.log_freq = 50\n  training.log_freq = 25\n  training.eval_freq = 100\n  ## store additional checkpoints for preemption in cloud computing environments\n  training.snapshot_freq_for_preemption = 5000\n  ## produce samples at each snapshot.\n  training.snapshot_sampling = True\n  training.likelihood_weighting = False\n  training.continuous = True\n  training.reduce_mean = False\n\n  # sampling\n  config.sampling = sampling = ml_collections.ConfigDict()\n  sampling.n_steps_each = 1\n  sampling.noise_removal = True\n  sampling.probability_flow = False\n  sampling.snr = 0.075\n\n  # evaluation\n  config.eval = evaluate = ml_collections.ConfigDict()\n  evaluate.begin_ckpt = 50\n  evaluate.end_ckpt = 96\n  # evaluate.batch_size = 512\n  evaluate.batch_size = 8\n  evaluate.enable_sampling = True\n  evaluate.num_samples = 50000\n  evaluate.enable_loss = True\n  evaluate.enable_bpd = False\n  evaluate.bpd_dataset = 'test'\n\n  # data\n  config.data = data = ml_collections.ConfigDict()\n  data.dataset = 'LSUN'\n  data.image_size = 256\n  data.random_flip = True\n  data.uniform_dequantization = False\n  data.centered = False\n  # data.num_channels = 3\n  data.num_channels = 1\n\n  # model\n  config.model = model = ml_collections.ConfigDict()\n  model.sigma_max = 378\n  model.sigma_min = 0.01\n  model.num_scales = 2000\n  model.beta_min = 0.1\n  model.beta_max = 20.\n  model.dropout = 0.\n  model.embedding_type = 'fourier'\n\n  # optimization\n  config.optim = optim = ml_collections.ConfigDict()\n  optim.weight_decay = 0\n  
optim.optimizer = 'Adam'\n  optim.lr = 2e-4\n  optim.beta1 = 0.9\n  optim.eps = 1e-8\n  optim.warmup = 5000\n  optim.grad_clip = 1.\n\n  config.seed = 42\n  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n\n  return config"
  },
  {
    "path": "configs/subvp/cifar10_ddpm_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training DDPM with sub-VP SDE.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n\n  # training\n  training = config.training\n  training.sde = 'subvpsde'\n  training.continuous = True\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'euler_maruyama'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ddpm'\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 2\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n\n  return config\n"
  },
  {
    "path": "configs/subvp/cifar10_ddpmpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSNv3 on CIFAR-10 with continuous sigmas.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'subvpsde'\n  training.continuous = True\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'euler_maruyama'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = False\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'none'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.embedding_type = 'positional'\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/subvp/cifar10_ddpmpp_deep_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSNv3 on CIFAR-10 with continuous sigmas.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'subvpsde'\n  training.continuous = True\n  training.reduce_mean = True\n  training.n_iters = 950001\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'euler_maruyama'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 8\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = False\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'none'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.embedding_type = 'positional'\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/subvp/cifar10_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on CIFAR-10 with sub-VP SDE.\"\"\"\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'subvpsde'\n  training.continuous = True\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'euler_maruyama'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.embedding_type = 'positional'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/subvp/cifar10_ncsnpp_deep_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on CIFAR-10.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'subvpsde'\n  training.continuous = True\n  training.n_iters = 950001\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'euler_maruyama'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.fourier_scale = 16\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 8\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.embedding_type = 'positional'\n  model.init_scale = 0.0\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/AAPM_128_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on AAPM (128x128) with VE SDE.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.dataset = 'aapm'\n  data.root = '/media/harry/tomo/AAPM_data/128'\n  data.is_complex = False\n  data.is_multi = False\n  data.image_size = 128\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/AAPM_256_ncsnpp_continuous.py",
    "content": "from configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.dataset = 'AAPM'\n  data.root = '/media/harry/tomo/AAPM_data/256'\n  data.is_complex = False\n  data.is_multi = False\n  data.image_size = 256\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config"
  },
  {
    "path": "configs/ve/Object5_fast.py",
    "content": "from configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n  training.epochs = 3\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.dataset = 'Object5Fast'\n  data.root = './data/Object5/'\n  data.is_complex = False\n  data.is_multi = False\n  data.image_size = 256\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  model.num_scales = 3  # number of sampling steps\n\n  return config"
  },
  {
    "path": "configs/ve/Object5_ncsnpp_continuous.py",
    "content": "from configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.dataset = 'Object5'\n  data.root = './data/Object5/'\n  data.is_complex = False\n  data.is_multi = False\n  data.image_size = 256\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config"
  },
  {
    "path": "configs/ve/bedroom_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on bedroom with VE SDE.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.category = 'bedroom'\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 1, 2, 2, 2, 2, 2)\n  model.num_res_blocks = 2\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'output_skip'\n  model.progressive_input = 'input_skip'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/celeba_ncsnpp.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on CelebA with SMLD.\"\"\"\n\nfrom configs.default_celeba_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.sigma_begin = 90\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.0\n  model.conv_size = 3\n  model.embedding_type = 'positional'\n\n  return config\n"
  },
  {
    "path": "configs/ve/celebahq_256_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on CelebAHQ (256x256) with VE SDE.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.dataset = 'CelebAHQ'\n  data.image_size = 256\n  data.tfrecords_path = '/home/yangsong/ncsc/celebahq/r08.tfrecords'\n\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.sigma_max = 348\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 1, 2, 2, 2, 2, 2)\n  model.num_res_blocks = 2\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'output_skip'\n  model.progressive_input = 'input_skip'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return 
config\n"
  },
  {
    "path": "configs/ve/celebahq_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on CelebAHQ with VE SDE.\"\"\"\n\nimport ml_collections\nimport torch\n\n\ndef get_config():\n  config = ml_collections.ConfigDict()\n  # training\n  config.training = training = ml_collections.ConfigDict()\n  training.batch_size = 8\n  training.n_iters = 2400001\n  training.snapshot_freq = 50000\n  training.log_freq = 50\n  training.eval_freq = 100\n  training.snapshot_freq_for_preemption = 5000\n  training.snapshot_sampling = True\n  training.sde = 'vesde'\n  training.continuous = True\n  training.likelihood_weighting = False\n  training.reduce_mean = False\n\n  # sampling\n  config.sampling = sampling = ml_collections.ConfigDict()\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n  sampling.probability_flow = False\n  sampling.snr = 0.15\n  sampling.n_steps_each = 1\n  sampling.noise_removal = True\n\n  # eval\n  config.eval = evaluate = ml_collections.ConfigDict()\n  evaluate.batch_size = 1024\n  evaluate.num_samples = 50000\n  evaluate.begin_ckpt = 1\n  evaluate.end_ckpt = 96\n\n  # data\n  config.data = data = ml_collections.ConfigDict()\n  data.dataset = 'CelebAHQ'\n  data.image_size = 1024\n  data.centered = False\n  data.random_flip = True\n  data.uniform_dequantization = False\n  data.num_channels = 3\n  data.tfrecords_path = 
'/atlas/u/yangsong/celeba_hq/-r10.tfrecords'\n\n  # model\n  config.model = model = ml_collections.ConfigDict()\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.sigma_max = 1348\n  model.num_scales = 2000\n  model.ema_rate = 0.9999\n  model.sigma_min = 0.01\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 16\n  model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)\n  model.num_res_blocks = 1\n  model.attn_resolutions = (16,)\n  model.dropout = 0.\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'output_skip'\n  model.progressive_input = 'input_skip'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n  model.embedding_type = 'fourier'\n\n  # optim\n  config.optim = optim = ml_collections.ConfigDict()\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 2e-4\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 5000\n  optim.grad_clip = 1.\n\n  config.seed = 42\n  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n\n  return config\n"
  },
  {
    "path": "configs/ve/church_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on Church with VE SDE.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.category = 'church_outdoor'\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.sigma_max = 380\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 1, 2, 2, 2, 2, 2)\n  model.num_res_blocks = 2\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'output_skip'\n  model.progressive_input = 'input_skip'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/cifar10_ddpm.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Train the original DDPM model with SMLD.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # model\n  model = config.model\n  model.name = 'ddpm'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 2\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/cifar10_ncsnpp.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on CIFAR-10 with SMLD.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.0\n  model.embedding_type = 'positional'\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/cifar10_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on CIFAR-10 with VE SDE.\"\"\"\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/cifar10_ncsnpp_deep_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on CIFAR-10 with VE SDE.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n  training.n_iters = 950001\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.fourier_scale = 16\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 8\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.0\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/fastmri_knee_128_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on fastmri knee with VE SDE.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # training (regression)\n  training.mask_type = 'gaussian2d'\n  training.acc_factor = [8, 15]\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.dataset = 'fastmri_knee'\n  data.root = '/media/harry/tomo/fastmri'\n  data.is_complex = False\n  data.is_multi = False\n  data.image_size = 128\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  
model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/fastmri_knee_256_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on fastmri knee with VE SDE.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.dataset = 'fastmri_knee'\n  data.root = '/media/harry/tomo/fastmri'\n  data.image_size = 256\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/fastmri_knee_320_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on fastmri knee with VE SDE.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.dataset = 'fastmri_knee'\n  data.root = '/media/harry/tomo/fastmri'\n  data.image_size = 320\n  data.is_multi = False\n  data.is_complex = False\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/fastmri_knee_320_ncsnpp_continuous_complex.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on fastmri knee with VE SDE.\"\"\"\n\nfrom configs.default_complex_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.dataset = 'fastmri_knee'\n  data.is_multi = False\n  data.is_complex = True\n  data.root = '/media/harry/tomo/fastmri'\n  data.image_size = 320\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/fastmri_knee_320_ncsnpp_continuous_complex_magpha.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on fastmri knee with VE SDE.\"\"\"\n\nfrom configs.default_complex_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.dataset = 'fastmri_knee'\n  data.is_multi = False\n  data.is_complex = True\n  data.magpha = True\n  data.root = '/media/harry/tomo/fastmri'\n  data.image_size = 320\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 
3\n\n  return config\n"
  },
  {
    "path": "configs/ve/fastmri_knee_320_ncsnpp_continuous_multi.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on fastmri knee with VE SDE.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.dataset = 'fastmri_knee'\n  data.is_complex = False\n  data.is_multi = True\n  data.root = '/media/harry/tomo/fastmri'\n  data.image_size = 320\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/ffhq_256_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on FFHQ with VE SDE.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n\n  # data\n  data = config.data\n  data.dataset = 'FFHQ'\n  data.image_size = 256\n  data.tfrecords_path = '/media/harry/ExtDrive/PycharmProjects/score_sde_pytorch/dataset/FFHQ/ffhq-r08.tfrecords'\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.sigma_max = 348\n  model.scale_by_sigma = True\n  model.ema_rate = 0.999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 1, 2, 2, 2, 2, 2)\n  model.num_res_blocks = 2\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'output_skip'\n  model.progressive_input = 'input_skip'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n 
 model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/ve/ffhq_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on FFHQ with VE SDEs.\"\"\"\n\nimport ml_collections\nimport torch\n\ndef get_config():\n  config = ml_collections.ConfigDict()\n  # training\n  config.training = training = ml_collections.ConfigDict()\n  training.batch_size = 8\n  training.n_iters = 2400001\n  training.snapshot_freq = 50000\n  training.log_freq = 50\n  training.eval_freq = 100\n  training.snapshot_freq_for_preemption = 5000\n  training.snapshot_sampling = True\n  training.sde = 'vesde'\n  training.continuous = True\n  training.likelihood_weighting = False\n  training.reduce_mean = True\n\n  # sampling\n  config.sampling = sampling = ml_collections.ConfigDict()\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'langevin'\n  sampling.probability_flow = False\n  sampling.snr = 0.15\n  sampling.n_steps_each = 1\n  sampling.noise_removal = True\n\n  # eval\n  config.eval = evaluate = ml_collections.ConfigDict()\n  evaluate.batch_size = 1024\n  evaluate.num_samples = 50000\n  evaluate.begin_ckpt = 1\n  evaluate.end_ckpt = 96\n\n  # data\n  config.data = data = ml_collections.ConfigDict()\n  data.dataset = 'FFHQ'\n  data.image_size = 1024\n  data.centered = False\n  data.random_flip = True\n  data.uniform_dequantization = False\n  data.num_channels = 3\n  # Plug in your own path to the 
tfrecords file.\n  data.tfrecords_path = '/raid/song/ffhq-dataset/ffhq/ffhq-r10.tfrecords'\n\n  # model\n  config.model = model = ml_collections.ConfigDict()\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = True\n  model.sigma_max = 1348\n  model.num_scales = 2000\n  model.ema_rate = 0.9999\n  model.sigma_min = 0.01\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 16\n  model.ch_mult = (1, 2, 4, 8, 16, 32, 32, 32)\n  model.num_res_blocks = 1\n  model.attn_resolutions = (16,)\n  model.dropout = 0.\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'output_skip'\n  model.progressive_input = 'input_skip'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n  model.embedding_type = 'fourier'\n\n  # optim\n  config.optim = optim = ml_collections.ConfigDict()\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 2e-4\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 5000\n  optim.grad_clip = 1.\n\n  config.seed = 42\n  config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n\n  return config\n"
  },
  {
    "path": "configs/ve/ncsn/celeba.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for reproducing NCSNv1 on CelebA.\"\"\"\n\nfrom configs.default_celeba_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'none'\n  sampling.corrector = 'ald'\n  sampling.n_steps_each = 100\n  sampling.snr = 0.316\n  # model\n  model = config.model\n  model.name = 'ncsn'\n  model.scale_by_sigma = False\n  model.sigma_max = 1\n  model.num_scales = 10\n  model.ema_rate = 0.\n  model.normalization = 'InstanceNorm++'\n  model.nonlinearity = 'elu'\n  model.nf = 128\n  model.interpolation = 'bilinear'\n  # optim\n  optim = config.optim\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 1e-3\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 0\n  optim.grad_clip = -1.\n\n  return config\n"
  },
  {
    "path": "configs/ve/ncsn/celeba_124.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for training NCSN with technique 1,2,4 only.\"\"\"\n\nfrom configs.default_celeba_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'none'\n  sampling.corrector = 'ald'\n  sampling.n_steps_each = 5\n  sampling.snr = 0.128\n  # model\n  model = config.model\n  model.name = 'ncsn'\n  model.scale_by_sigma = False\n  model.num_scales = 500\n  model.ema_rate = 0.\n  model.normalization = 'InstanceNorm++'\n  model.nonlinearity = 'elu'\n  model.nf = 128\n  model.interpolation = 'bilinear'\n  # optim\n  optim = config.optim\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 1e-3\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 0\n  optim.grad_clip = -1.\n\n  return config\n"
  },
  {
    "path": "configs/ve/ncsn/celeba_1245.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for training NCSN with techniques 1,2,4,5 only.\"\"\"\n\nfrom configs.default_celeba_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'none'\n  sampling.corrector = 'ald'\n  sampling.n_steps_each = 5\n  sampling.snr = 0.128\n  # model\n  model = config.model\n  model.name = 'ncsn'\n  model.scale_by_sigma = False\n  model.num_scales = 500\n  model.ema_rate = 0.999\n  model.normalization = 'InstanceNorm++'\n  model.nonlinearity = 'elu'\n  model.nf = 128\n  model.interpolation = 'bilinear'\n  # optim\n  optim = config.optim\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 1e-3\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 0\n  optim.grad_clip = -1.\n\n  return config\n"
  },
  {
    "path": "configs/ve/ncsn/celeba_5.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for training NCSNv1 model with technique 5 only.\"\"\"\n\nfrom configs.default_celeba_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'none'\n  sampling.corrector = 'ald'\n  sampling.n_steps_each = 100\n  sampling.snr = 0.316\n  # model\n  model = config.model\n  model.name = 'ncsn'\n  model.scale_by_sigma = False\n  model.sigma_max = 1.\n  model.num_scales = 10\n  model.ema_rate = 0.999\n  model.normalization = 'InstanceNorm++'\n  model.nonlinearity = 'elu'\n  model.nf = 128\n  model.interpolation = 'bilinear'\n  # optim\n  optim = config.optim\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 1e-3\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 0\n  optim.grad_clip = -1.\n\n  return config\n"
  },
  {
    "path": "configs/ve/ncsn/cifar10.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for reproducing NCSNv1 on CIFAR-10.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'none'\n  sampling.corrector = 'ald'\n  sampling.n_steps_each = 100\n  sampling.snr = 0.316\n  # model\n  model = config.model\n  model.name = 'ncsn'\n  model.scale_by_sigma = False\n  model.sigma_max = 1\n  model.num_scales = 10\n  model.ema_rate = 0.\n  model.normalization = 'InstanceNorm++'\n  model.nonlinearity = 'elu'\n  model.nf = 128\n  model.interpolation = 'bilinear'\n  # optim\n  optim = config.optim\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 1e-3\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 0\n  optim.grad_clip = -1.\n\n  return config\n"
  },
  {
    "path": "configs/ve/ncsn/cifar10_124.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for training NCSN with technique 1,2,4 only.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'none'\n  sampling.corrector = 'ald'\n  sampling.n_steps_each = 5\n  sampling.snr = 0.176\n  # model\n  model = config.model\n  model.name = 'ncsn'\n  model.scale_by_sigma = False\n  model.num_scales = 232\n  model.ema_rate = 0.\n  model.normalization = 'InstanceNorm++'\n  model.nonlinearity = 'elu'\n  model.nf = 128\n  model.interpolation = 'bilinear'\n  # optim\n  optim = config.optim\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 1e-3\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 0\n  optim.grad_clip = -1.\n\n  return config\n"
  },
  {
    "path": "configs/ve/ncsn/cifar10_1245.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for training NCSN with technique 1,2,4,5 only.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n  # shared configs for sample generation\n  step_size = 0.0000062\n  n_steps_each = 5\n  ckpt_id = 300000\n  final_only = True\n  noise_removal = False\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'none'\n  sampling.corrector = 'ald'\n  sampling.n_steps_each = 5\n  sampling.snr = 0.176\n  # model\n  model = config.model\n  model.name = 'ncsn'\n  model.scale_by_sigma = False\n  model.num_scales = 232\n  model.ema_rate = 0.999\n  model.normalization = 'InstanceNorm++'\n  model.nonlinearity = 'elu'\n  model.nf = 128\n  model.interpolation = 'bilinear'\n  # optim\n  optim = config.optim\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 1e-3\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 0\n  optim.grad_clip = -1.\n\n  return config\n"
  },
  {
    "path": "configs/ve/ncsn/cifar10_5.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for training NCSN with technique 5 only.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'none'\n  sampling.corrector = 'ald'\n  sampling.snr = 0.316\n  sampling.n_steps_each = 100\n  # model\n  model = config.model\n  model.name = 'ncsn'\n  model.scale_by_sigma = False\n  model.sigma_max = 1\n  model.num_scales = 10\n  model.ema_rate = 0.999\n  model.normalization = 'InstanceNorm++'\n  model.nonlinearity = 'elu'\n  model.nf = 128\n  model.interpolation = 'bilinear'\n  # optim\n  optim = config.optim\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 1e-3\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 0\n  optim.grad_clip = -1.\n\n  return config\n"
  },
  {
    "path": "configs/ve/ncsnv2/bedroom.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for training NCSNv2 on bedroom.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.batch_size = 128\n  training.sde = 'vesde'\n  training.continuous = False\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'none'\n  sampling.corrector = 'ald'\n  sampling.n_steps_each = 3\n  sampling.snr = 0.095\n  # data\n  data = config.data\n  data.category = 'bedroom'\n  data.image_size = 128\n  # model\n  model = config.model\n  model.name = 'ncsnv2_128'\n  model.scale_by_sigma = True\n  model.sigma_max = 190\n  model.num_scales = 1086\n  model.ema_rate = 0.9999\n  model.sigma_min = 0.01\n  model.normalization = 'InstanceNorm++'\n  model.nonlinearity = 'elu'\n  model.nf = 128\n  model.interpolation = 'bilinear'\n  # optim\n  optim = config.optim\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 1e-4\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 0\n  optim.grad_clip = -1\n\n  return config\n"
  },
  {
    "path": "configs/ve/ncsnv2/celeba.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for training NCSNv2 on CelebA.\"\"\"\n\nfrom configs.default_celeba_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n  # shared configs for sample generation\n  step_size = 0.0000033\n  n_steps_each = 5\n  ckpt_id = 210000\n  final_only = True\n  noise_removal = False\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'none'\n  sampling.corrector = 'ald'\n  sampling.n_steps_each = 5\n  sampling.snr = 0.128\n  # model\n  model = config.model\n  model.name = 'ncsnv2_64'\n  model.scale_by_sigma = True\n  model.num_scales = 500\n  model.ema_rate = 0.999\n  model.normalization = 'InstanceNorm++'\n  model.nonlinearity = 'elu'\n  model.nf = 128\n  model.interpolation = 'bilinear'\n  # optim\n  optim = config.optim\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 1e-4\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 0\n  optim.grad_clip = -1.\n\n  return config\n"
  },
  {
    "path": "configs/ve/ncsnv2/cifar10.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for training NCSNv2 on CIFAR-10.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vesde'\n  training.continuous = False\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'none'\n  sampling.corrector = 'ald'\n  sampling.n_steps_each = 5\n  sampling.snr = 0.176\n  # model\n  model = config.model\n  model.name = 'ncsnv2_64'\n  model.scale_by_sigma = True\n  model.num_scales = 232\n  model.ema_rate = 0.999\n  model.normalization = 'InstanceNorm++'\n  model.nonlinearity = 'elu'\n  model.nf = 128\n  model.interpolation = 'bilinear'\n  # optim\n  optim = config.optim\n  optim.weight_decay = 0\n  optim.optimizer = 'Adam'\n  optim.lr = 1e-4\n  optim.beta1 = 0.9\n  optim.amsgrad = False\n  optim.eps = 1e-8\n  optim.warmup = 0\n  optim.grad_clip = -1.\n\n  return config\n"
  },
  {
    "path": "configs/vp/cifar10_ddpmpp.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training DDPM++ on CIFAR-10 with VP SDE (non-continuous training).\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vpsde'\n  training.continuous = False\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'ancestral_sampling'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = False\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'none'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.embedding_type = 'positional'\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/vp/cifar10_ddpmpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSNv3 on CIFAR-10 with continuous sigmas.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vpsde'\n  training.continuous = True\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'euler_maruyama'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = False\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'none'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.embedding_type = 'positional'\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/vp/cifar10_ddpmpp_deep_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSNv3 on CIFAR-10 with continuous sigmas.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vpsde'\n  training.continuous = True\n  training.reduce_mean = True\n  training.n_iters = 950001\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'euler_maruyama'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 8\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = False\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'none'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.\n  model.embedding_type = 'positional'\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/vp/cifar10_ncsnpp.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on CIFAR-10 with DDPM.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vpsde'\n  training.continuous = False\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'reverse_diffusion'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.init_scale = 0.0\n  model.embedding_type = 'positional'\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/vp/cifar10_ncsnpp_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on CIFAR-10 with VP SDE.\"\"\"\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vpsde'\n  training.continuous = True\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'euler_maruyama'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 4\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.embedding_type = 'positional'\n  model.init_scale = 0.\n  model.fourier_scale = 16\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/vp/cifar10_ncsnpp_deep_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training NCSN++ on CIFAR-10.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n  # training\n  training = config.training\n  training.sde = 'vpsde'\n  training.continuous = True\n  training.n_iters = 950001\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'euler_maruyama'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ncsnpp'\n  model.fourier_scale = 16\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 8\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n  model.fir = True\n  model.fir_kernel = [1, 3, 3, 1]\n  model.skip_rescale = True\n  model.resblock_type = 'biggan'\n  model.progressive = 'none'\n  model.progressive_input = 'residual'\n  model.progressive_combine = 'sum'\n  model.attention_type = 'ddpm'\n  model.embedding_type = 'positional'\n  model.init_scale = 0.0\n  model.conv_size = 3\n\n  return config\n"
  },
  {
    "path": "configs/vp/ddpm/bedroom.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for reproducing the results of DDPM on bedrooms.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n\n  # training\n  training = config.training\n  training.sde = 'vpsde'\n  training.continuous = False\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'ancestral_sampling'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.category = 'bedroom'\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ddpm'\n  model.scale_by_sigma = False\n  model.num_scales = 1000\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 1, 2, 2, 4, 4)\n  model.num_res_blocks = 2\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n\n  # optim\n  optim = config.optim\n  optim.lr = 2e-5\n\n  return config\n"
  },
  {
    "path": "configs/vp/ddpm/celebahq.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for reproducing the results of DDPM on CelebA-HQ.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n\n  # training\n  training = config.training\n  training.sde = 'vpsde'\n  training.continuous = False\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'ancestral_sampling'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.dataset = 'CelebAHQ'\n  data.centered = True\n  data.tfrecords_path = '/atlas/u/yangsong/celeba_hq/-r10.tfrecords'\n  data.image_size = 256\n\n  # model\n  model = config.model\n  model.name = 'ddpm'\n  model.scale_by_sigma = False\n  model.num_scales = 1000\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 1, 2, 2, 4, 4)\n  model.num_res_blocks = 2\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n\n  # optim\n  optim = config.optim\n  optim.lr = 2e-5\n\n  return config\n"
  },
  {
    "path": "configs/vp/ddpm/church.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for reproducing the results of DDPM on church_outdoor.\"\"\"\n\nfrom configs.default_lsun_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n\n  # training\n  training = config.training\n  training.sde = 'vpsde'\n  training.continuous = False\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'ancestral_sampling'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.category = 'church_outdoor'\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ddpm'\n  model.scale_by_sigma = False\n  model.num_scales = 1000\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 1, 2, 2, 4, 4)\n  model.num_res_blocks = 2\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n\n  # optim\n  optim = config.optim\n  optim.lr = 2e-5\n\n  return config\n"
  },
  {
    "path": "configs/vp/ddpm/cifar10.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Config file for reproducing the results of DDPM on cifar-10.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n\n  # training\n  training = config.training\n  training.sde = 'vpsde'\n  training.continuous = False\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'ancestral_sampling'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ddpm'\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 2\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n\n  return config\n"
  },
  {
    "path": "configs/vp/ddpm/cifar10_continuous.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training DDPM with VP SDE.\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n\n  # training\n  training = config.training\n  training.sde = 'vpsde'\n  training.continuous = True\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'euler_maruyama'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ddpm'\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 2\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = True\n\n  return config\n"
  },
  {
    "path": "configs/vp/ddpm/cifar10_unconditional.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Lint as: python3\n\"\"\"Training DDPM on CIFAR-10 without explicitly conditioning on time steps. (NCSNv2 technique 3)\"\"\"\n\nfrom configs.default_cifar10_configs import get_default_configs\n\n\ndef get_config():\n  config = get_default_configs()\n\n  # training\n  training = config.training\n  training.sde = 'vpsde'\n  training.continuous = False\n  training.reduce_mean = True\n\n  # sampling\n  sampling = config.sampling\n  sampling.method = 'pc'\n  sampling.predictor = 'ancestral_sampling'\n  sampling.corrector = 'none'\n\n  # data\n  data = config.data\n  data.centered = True\n\n  # model\n  model = config.model\n  model.name = 'ddpm'\n  model.scale_by_sigma = False\n  model.ema_rate = 0.9999\n  model.normalization = 'GroupNorm'\n  model.nonlinearity = 'swish'\n  model.nf = 128\n  model.ch_mult = (1, 2, 2, 2)\n  model.num_res_blocks = 2\n  model.attn_resolutions = (16,)\n  model.resamp_with_conv = True\n  model.conditional = False\n\n  return config\n"
  },
  {
    "path": "controllable_generation_TV.py",
    "content": "import functools\r\nimport time\r\n\r\nimport torch\r\nfrom numpy.testing._private.utils import measure\r\nimport numpy as np\r\nimport matplotlib.pyplot as plt\r\nfrom tqdm import tqdm\r\n\r\nfrom models import utils as mutils\r\nfrom sampling import NoneCorrector, NonePredictor, shared_corrector_update_fn, shared_predictor_update_fn\r\nfrom utils import fft2, ifft2, fft2_m, ifft2_m\r\nfrom physics.ct import *\r\nfrom utils import show_samples, show_samples_gray, clear, clear_color, batchfy\r\n\r\n\r\n\r\nclass lambda_schedule:\r\n  def __init__(self, total=2000):\r\n    self.total = total\r\n\r\n  def get_current_lambda(self, i):\r\n    pass\r\nclass lambda_schedule_linear(lambda_schedule):\r\n  def __init__(self, start_lamb=1.0, end_lamb=0.0):\r\n    super().__init__()\r\n    self.start_lamb = start_lamb\r\n    self.end_lamb = end_lamb\r\n\r\n  def get_current_lambda(self, i):\r\n    return self.start_lamb + (self.end_lamb - self.start_lamb) * (i / self.total)\r\n\r\n\r\nclass lambda_schedule_const(lambda_schedule):\r\n  def __init__(self, lamb=1.0):\r\n    super().__init__()\r\n    self.lamb = lamb\r\n\r\n  def get_current_lambda(self, i):\r\n    return self.lamb\r\n\r\n\r\ndef _Dz(x): # Batch direction\r\n    y = torch.zeros_like(x)\r\n    y[:-1] = x[1:]\r\n    y[-1] = x[0]\r\n    return y - x\r\n\r\n\r\ndef _DzT(x): # Batch direction\r\n    y = torch.zeros_like(x)\r\n    y[:-1] = x[1:]\r\n    y[-1] = x[0]\r\n\r\n    tempt = -(y-x)\r\n    difft = tempt[:-1]\r\n    y[1:] = difft\r\n    y[0] = x[-1] - x[0]\r\n\r\n    return y\r\n\r\ndef _Dx(x):  # Batch direction\r\n    y = torch.zeros_like(x)\r\n    y[:, :, :-1, :] = x[:, :, 1:, :]\r\n    y[:, :, -1, :] = x[:, :, 0, :]\r\n    return y - x\r\n\r\n\r\ndef _DxT(x):  # Batch direction\r\n    y = torch.zeros_like(x)\r\n    y[:, :, :-1, :] = x[:, :, 1:, :]\r\n    y[:, :, -1, :] = x[:, :, 0, :]\r\n    tempt = -(y - x)\r\n    difft = tempt[:, :, :-1, :]\r\n    y[:, :, 1:, :] = difft\r\n    y[:, :, 0, :] 
= x[:, :, -1, :] - x[:, :, 0, :]\r\n    return y\r\n\r\n\r\ndef _Dy(x):  # Batch direction\r\n    y = torch.zeros_like(x)\r\n    y[:, :, :, :-1] = x[:, :, :, 1:]\r\n    y[:, :, :, -1] = x[:, :, :, 0]\r\n    return y - x\r\n\r\n\r\ndef _DyT(x):  # Batch direction\r\n    y = torch.zeros_like(x)\r\n    y[:, :, :, :-1] = x[:, :, :, 1:]\r\n    y[:, :, :, -1] = x[:, :, :, 0]\r\n    tempt = -(y - x)\r\n    difft = tempt[:, :, :, :-1]\r\n    y[:, :, :, 1:] = difft\r\n    y[:, :, :, 0] = x[:, :, :, -1] - x[:, :, :, 0]\r\n    return y\r\n\r\n\r\ndef get_pc_radon_ADMM_TV(sde, predictor, corrector, inverse_scaler, snr,\r\n                         n_steps=1, probability_flow=False, continuous=False,\r\n                         denoise=True, eps=1e-5, radon=None, save_progress=False, save_root=None,\r\n                         final_consistency=False, img_cache=None, img_shape=None, lamb_1=5, rho=10):\r\n    \"\"\" Sparse application of measurement consistency \"\"\"\r\n    # Define predictor & corrector\r\n    predictor_update_fn = functools.partial(shared_predictor_update_fn,\r\n                                            sde=sde,\r\n                                            predictor=predictor,\r\n                                            probability_flow=probability_flow,\r\n                                            continuous=continuous)\r\n    corrector_update_fn = functools.partial(shared_corrector_update_fn,\r\n                                            sde=sde,\r\n                                            corrector=corrector,\r\n                                            continuous=continuous,\r\n                                            snr=snr,\r\n                                            n_steps=n_steps)\r\n\r\n    if img_cache != None :\r\n        img_shape[0] += 1\r\n    del_z = torch.zeros(img_shape)\r\n    udel_z = torch.zeros(img_shape)\r\n    eps = 1e-10\r\n\r\n    def _A(x):\r\n        return radon.A(x)\r\n\r\n    def _AT(sinogram):\r\n        
return radon.AT(sinogram)\r\n\r\n    def kaczmarz(x, x_mean, measurement=None, lamb=1.0, i=None,\r\n                 norm_const=None):\r\n        x = x + lamb * _AT(measurement - _A(x))/norm_const\r\n        x_mean = x\r\n        return x, x_mean\r\n    \r\n    def A_cg(x):\r\n        return _AT(_A(x)) + rho * _DzT(_Dz(x))\r\n\r\n    def CG(A_fn,b_cg,x,n_inner=10):\r\n        r = b_cg - A_fn(x)\r\n        p = r\r\n        rs_old = torch.matmul(r.view(1,-1),r.view(1,-1).T)\r\n\r\n        for i in range(n_inner):\r\n            Ap = A_fn(p)\r\n            a = rs_old/torch.matmul(p.view(1,-1),Ap.view(1,-1).T)\r\n    \r\n            x += a * p\r\n            r -= a * Ap\r\n\r\n            rs_new = torch.matmul(r.view(1,-1),r.view(1,-1).T)\r\n            if torch.sqrt(rs_new) < eps :\r\n                break\r\n            p = r + (rs_new/rs_old) * p\r\n            rs_old = rs_new\r\n        return x\r\n\r\n    def CS_routine(x,ATy, niter=20):\r\n        if img_cache != None :\r\n            x = torch.cat([img_cache,x],dim=0)\r\n            idx = list(range(len(x),0,-1))\r\n            x = x[idx]\r\n\r\n        nonlocal del_z, udel_z\r\n        if del_z.device != x.device :\r\n            del_z = del_z.to(x.device)\r\n            udel_z = del_z.to(x.device)\r\n        for i in range(niter):\r\n            b_cg = ATy + rho * (_DzT(del_z)-_DzT(udel_z))\r\n            x = CG(A_cg, b_cg, x, n_inner=1)\r\n\r\n            del_z = shrink(_Dz(x) + udel_z, lamb_1/rho)\r\n            udel_z = _Dz(x) - del_z + udel_z\r\n        if img_cache != None :\r\n            x = x[idx]\r\n            x = x[1:]\r\n            del_z[-1] = 0\r\n            udel_z[-1] = 0\r\n        x_mean = x\r\n        return x, x_mean\r\n\r\n    def get_update_fn(update_fn):\r\n        def radon_update_fn(model, data, x, t):\r\n            with torch.no_grad():\r\n                vec_t = torch.ones(data.shape[0], device=data.device) * t\r\n                x, x_mean = update_fn(x, vec_t, model=model)\r\n      
          return x, x_mean\r\n        return radon_update_fn\r\n\r\n    def get_corrector_update_fn(update_fn):\r\n        def radon_update_fn(model, data, x, t, measurement=None):\r\n            with torch.no_grad():\r\n                vec_t = torch.ones(data.shape[0], device=data.device) * t\r\n                x, x_mean = update_fn(x, vec_t, model=model)\r\n                ATy = _AT(measurement)\r\n                x, x_mean = CS_routine(x, ATy, niter=1)\r\n                return x, x_mean\r\n        return radon_update_fn\r\n\r\n    predictor_denoise_update_fn = get_update_fn(predictor_update_fn)\r\n    corrector_radon_update_fn = get_corrector_update_fn(corrector_update_fn)\r\n\r\n    def pc_radon(model, data, measurement=None):\r\n        with torch.no_grad():\r\n            x = sde.prior_sampling(data.shape).to(data.device)\r\n\r\n            ones = torch.ones_like(x).to(data.device)\r\n            norm_const = _AT(_A(ones))\r\n            timesteps = torch.linspace(sde.T, eps, sde.N)\r\n            for i in tqdm(range(sde.N)):\r\n                t = timesteps[i]\r\n                x, x_mean = predictor_denoise_update_fn(model, data, x, t)\r\n                x, x_mean = corrector_radon_update_fn(model, data, x, t, measurement=measurement)\r\n                if save_progress:\r\n                    if (i % 50) == 0:\r\n                        print(f'iter: {i}/{sde.N}')\r\n                        plt.imsave(save_root / 'recon' / f'progress{i}.png', clear(x_mean[0:1]), cmap='gray')\r\n            # Final step which coerces the data fidelity error term to be zero,\r\n            # and thereby satisfying Ax = y\r\n            if final_consistency:\r\n                x, x_mean = kaczmarz(x, x_mean, measurement, lamb=1.0, norm_const=norm_const)\r\n\r\n            return inverse_scaler(x_mean if denoise else x)\r\n\r\n    return pc_radon\r\n\r\n\r\ndef get_pc_radon_ADMM_TV_vol(sde, predictor, corrector, inverse_scaler, snr,\r\n                             n_steps=1, 
probability_flow=False, continuous=False,\r\n                             denoise=True, eps=1e-5, radon=None, save_progress=False, save_root=None,\r\n                             final_consistency=False, img_shape=None, lamb_1=5, rho=10):\r\n    \"\"\" Sparse application of measurement consistency \"\"\"\r\n    # Define predictor & corrector\r\n    predictor_update_fn = functools.partial(shared_predictor_update_fn,\r\n                                            sde=sde,\r\n                                            predictor=predictor,\r\n                                            probability_flow=probability_flow,\r\n                                            continuous=continuous)\r\n    corrector_update_fn = functools.partial(shared_corrector_update_fn,\r\n                                            sde=sde,\r\n                                            corrector=corrector,\r\n                                            continuous=continuous,\r\n                                            snr=snr,\r\n                                            n_steps=n_steps)\r\n\r\n    del_z = torch.zeros(img_shape)\r\n    udel_z = torch.zeros(img_shape)\r\n    eps = 1e-10\r\n\r\n    def _A(x):\r\n        return radon.A(x)\r\n\r\n    def _AT(sinogram):\r\n        return radon.AT(sinogram)\r\n\r\n    def kaczmarz(x, x_mean, measurement=None, lamb=1.0, i=None,\r\n                 norm_const=None):\r\n        x = x + lamb * _AT(measurement - _A(x)) / norm_const\r\n        x_mean = x\r\n        return x, x_mean\r\n\r\n    def A_cg(x):\r\n        return _AT(_A(x)) + rho * _DzT(_Dz(x))\r\n\r\n    def CG(A_fn, b_cg, x, n_inner=10):\r\n        r = b_cg - A_fn(x)\r\n        p = r\r\n        rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T)\r\n\r\n        for i in range(n_inner):\r\n            Ap = A_fn(p)\r\n            a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)\r\n\r\n            x += a * p\r\n            r -= a * Ap\r\n\r\n            rs_new = 
torch.matmul(r.view(1, -1), r.view(1, -1).T)\r\n            if torch.sqrt(rs_new) < eps:\r\n                break\r\n            p = r + (rs_new / rs_old) * p\r\n            rs_old = rs_new\r\n        return x\r\n\r\n    def CS_routine(x, ATy, niter=20):\r\n        nonlocal del_z, udel_z\r\n        if del_z.device != x.device:\r\n            del_z = del_z.to(x.device)\r\n            udel_z = del_z.to(x.device)\r\n        for i in range(niter):\r\n            b_cg = ATy + rho * (_DzT(del_z) - _DzT(udel_z))\r\n            x = CG(A_cg, b_cg, x, n_inner=1)\r\n\r\n            del_z = shrink(_Dz(x) + udel_z, lamb_1 / rho)\r\n            udel_z = _Dz(x) - del_z + udel_z\r\n        x_mean = x\r\n        return x, x_mean\r\n\r\n    def get_update_fn(update_fn):\r\n        def radon_update_fn(model, data, x, t):\r\n            with torch.no_grad():\r\n                vec_t = torch.ones(x.shape[0], device=x.device) * t\r\n                x, x_mean = update_fn(x, vec_t, model=model)\r\n                return x, x_mean\r\n\r\n        return radon_update_fn\r\n\r\n    def get_ADMM_TV_fn():\r\n        def ADMM_TV_fn(x, measurement=None):\r\n            with torch.no_grad():\r\n                ATy = _AT(measurement)\r\n                x, x_mean = CS_routine(x, ATy, niter=1)\r\n                return x, x_mean\r\n        return ADMM_TV_fn\r\n\r\n    predictor_denoise_update_fn = get_update_fn(predictor_update_fn)\r\n    corrector_denoise_update_fn = get_update_fn(corrector_update_fn)\r\n    mc_update_fn = get_ADMM_TV_fn()\r\n\r\n    def pc_radon(model, data, measurement=None):\r\n        with torch.no_grad():\r\n            x = sde.prior_sampling(data.shape).to(data.device)\r\n\r\n            ones = torch.ones_like(x).to(data.device)\r\n            norm_const = _AT(_A(ones))\r\n            timesteps = torch.linspace(sde.T, eps, sde.N)\r\n            for i in tqdm(range(sde.N)):\r\n                t = timesteps[i]\r\n                # 1. 
batchify into sizes that fit into the GPU\r\n                x_batch = batchfy(x, 12)\r\n                # 2. Run PC step for each batch\r\n                x_agg = list()\r\n                for idx, x_batch_sing in enumerate(x_batch):\r\n                    x_batch_sing, _ = predictor_denoise_update_fn(model, data, x_batch_sing, t)\r\n                    x_batch_sing, _ = corrector_denoise_update_fn(model, data, x_batch_sing, t)\r\n                    x_agg.append(x_batch_sing)\r\n                # 3. Aggregate to run ADMM TV\r\n                x = torch.cat(x_agg, dim=0)\r\n                # 4. Run ADMM TV\r\n                x, x_mean = mc_update_fn(x, measurement=measurement)\r\n\r\n                if save_progress:\r\n                    if (i % 50) == 0:\r\n                        print(f'iter: {i}/{sde.N}')\r\n                        plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x_mean[0:1]), cmap='gray')\r\n            # Final step which coerces the data fidelity error term to be zero,\r\n            # and thereby satisfying Ax = y\r\n            if final_consistency:\r\n                x, x_mean = kaczmarz(x, x, measurement, lamb=1.0, norm_const=norm_const)\r\n\r\n            return inverse_scaler(x_mean if denoise else x)\r\n\r\n    return pc_radon\r\n\r\n\r\ndef get_pc_radon_ADMM_TV_all_vol(sde, predictor, corrector, inverse_scaler, snr,\r\n                             n_steps=1, probability_flow=False, continuous=False,\r\n                             denoise=True, eps=1e-5, radon=None, save_progress=False, save_root=None,\r\n                             final_consistency=False, img_shape=None, lamb_1=5, rho=10):\r\n    \"\"\" Sparse application of measurement consistency \"\"\"\r\n    # Define predictor & corrector\r\n    predictor_update_fn = functools.partial(shared_predictor_update_fn,\r\n                                            sde=sde,\r\n                                            predictor=predictor,\r\n                
                            probability_flow=probability_flow,\r\n                                            continuous=continuous)\r\n    corrector_update_fn = functools.partial(shared_corrector_update_fn,\r\n                                            sde=sde,\r\n                                            corrector=corrector,\r\n                                            continuous=continuous,\r\n                                            snr=snr,\r\n                                            n_steps=n_steps)\r\n\r\n    del_x = torch.zeros(img_shape)\r\n    del_y = torch.zeros(img_shape)\r\n    del_z = torch.zeros(img_shape)\r\n    udel_x = torch.zeros(img_shape)\r\n    udel_y = torch.zeros(img_shape)\r\n    udel_z = torch.zeros(img_shape)\r\n    eps = 1e-10\r\n\r\n    def _A(x):\r\n        return radon.A(x)\r\n\r\n    def _AT(sinogram):\r\n        return radon.AT(sinogram)\r\n\r\n    def kaczmarz(x, x_mean, measurement=None, lamb=1.0, i=None,\r\n                 norm_const=None):\r\n        x = x + lamb * _AT(measurement - _A(x)) / norm_const\r\n        x_mean = x\r\n        return x, x_mean\r\n\r\n\r\n    def A_cg(x):\r\n        return _AT(_A(x)) + rho * (_DxT(_Dx(x)) + _DyT(_Dy(x)) + _DzT(_Dz(x)))\r\n\r\n    def CG(A_fn, b_cg, x, n_inner=10):\r\n        r = b_cg - A_fn(x)\r\n        p = r\r\n        rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T)\r\n\r\n        for i in range(n_inner):\r\n            Ap = A_fn(p)\r\n            a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)\r\n\r\n            x += a * p\r\n            r -= a * Ap\r\n\r\n            rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T)\r\n            if torch.sqrt(rs_new) < eps:\r\n                break\r\n            p = r + (rs_new / rs_old) * p\r\n            rs_old = rs_new\r\n        return x\r\n\r\n    def CS_routine(x, ATy, niter=20):\r\n        nonlocal del_x, del_y, del_z, udel_x, udel_y, udel_z\r\n        if del_z.device != x.device:\r\n            del_x = 
del_x.to(x.device)\r\n            del_y = del_y.to(x.device)\r\n            del_z = del_z.to(x.device)\r\n            udel_x = udel_x.to(x.device)\r\n            udel_y = udel_y.to(x.device)\r\n            udel_z = udel_z.to(x.device)\r\n        for i in range(niter):\r\n            b_cg = ATy + rho * ((_DxT(del_x) - _DxT(udel_x))\r\n                                + (_DyT(del_y) - _DyT(udel_y))\r\n                                + (_DzT(del_z) - _DzT(udel_z)))\r\n            x = CG(A_cg, b_cg, x, n_inner=1)\r\n\r\n            del_x = shrink(_Dx(x) + udel_x, lamb_1 / rho)\r\n            del_y = shrink(_Dy(x) + udel_y, lamb_1 / rho)\r\n            del_z = shrink(_Dz(x) + udel_z, lamb_1 / rho)\r\n            udel_x = _Dx(x) - del_x + udel_x\r\n            udel_y = _Dy(x) - del_y + udel_y\r\n            udel_z = _Dz(x) - del_z + udel_z\r\n        x_mean = x\r\n        return x, x_mean\r\n\r\n    def get_update_fn(update_fn):\r\n        def radon_update_fn(model, data, x, t):\r\n            with torch.no_grad():\r\n                vec_t = torch.ones(x.shape[0], device=x.device) * t\r\n                x, x_mean = update_fn(x, vec_t, model=model)\r\n                return x, x_mean\r\n\r\n        return radon_update_fn\r\n\r\n    def get_ADMM_TV_fn():\r\n        def ADMM_TV_fn(x, measurement=None):\r\n            with torch.no_grad():\r\n                ATy = _AT(measurement)\r\n                x, x_mean = CS_routine(x, ATy, niter=1)\r\n                return x, x_mean\r\n        return ADMM_TV_fn\r\n\r\n    predictor_denoise_update_fn = get_update_fn(predictor_update_fn)\r\n    corrector_denoise_update_fn = get_update_fn(corrector_update_fn)\r\n    mc_update_fn = get_ADMM_TV_fn()\r\n\r\n    def pc_radon(model, data, measurement=None):\r\n        with torch.no_grad():\r\n            x = sde.prior_sampling(data.shape).to(data.device)\r\n\r\n            ones = torch.ones_like(x).to(data.device)\r\n            norm_const = _AT(_A(ones))\r\n            timesteps = 
torch.linspace(sde.T, eps, sde.N)\r\n            for i in tqdm(range(sde.N)):\r\n                t = timesteps[i]\r\n                # 1. batchify into sizes that fit into the GPU\r\n                x_batch = batchfy(x, 12)\r\n                # 2. Run PC step for each batch\r\n                x_agg = list()\r\n                for idx, x_batch_sing in enumerate(x_batch):\r\n                    x_batch_sing, _ = predictor_denoise_update_fn(model, data, x_batch_sing, t)\r\n                    x_batch_sing, _ = corrector_denoise_update_fn(model, data, x_batch_sing, t)\r\n                    x_agg.append(x_batch_sing)\r\n                # 3. Aggregate to run ADMM TV\r\n                x = torch.cat(x_agg, dim=0)\r\n                # 4. Run ADMM TV\r\n                x, x_mean = mc_update_fn(x, measurement=measurement)\r\n\r\n                if save_progress:\r\n                    if (i % 50) == 0:\r\n                        print(f'iter: {i}/{sde.N}')\r\n                        plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x_mean[0:1]), cmap='gray')\r\n            # Final step which coerces the data fidelity error term to be zero,\r\n            # and thereby satisfying Ax = y\r\n            if final_consistency:\r\n                x, x_mean = kaczmarz(x, x, measurement, lamb=1.0, norm_const=norm_const)\r\n\r\n            return inverse_scaler(x_mean if denoise else x)\r\n\r\n    return pc_radon\r\n\r\n\r\n\r\ndef get_ADMM_TV(eps=1e-5, radon=None, save_progress=False, save_root=None,\r\n                img_shape=None, lamb_1=5, rho=10, outer_iter=30, inner_iter=20):\r\n\r\n    del_x = torch.zeros(img_shape)\r\n    del_y = torch.zeros(img_shape)\r\n    del_z = torch.zeros(img_shape)\r\n    udel_x = torch.zeros(img_shape)\r\n    udel_y = torch.zeros(img_shape)\r\n    udel_z = torch.zeros(img_shape)\r\n    eps = 1e-10\r\n\r\n    def _A(x):\r\n        return radon.A(x)\r\n\r\n    def _AT(sinogram):\r\n        return radon.AT(sinogram)\r\n\r\n    
def A_cg(x):\r\n        return _AT(_A(x)) + rho * (_DxT(_Dx(x)) + _DyT(_Dy(x)) + _DzT(_Dz(x)))\r\n\r\n    def CG(A_fn, b_cg, x, n_inner=20):\r\n        r = b_cg - A_fn(x)\r\n        p = r\r\n        rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T)\r\n\r\n        for i in range(n_inner):\r\n            Ap = A_fn(p)\r\n            a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)\r\n\r\n            x += a * p\r\n            r -= a * Ap\r\n\r\n            rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T)\r\n            if torch.sqrt(rs_new) < eps:\r\n                break\r\n            p = r + (rs_new / rs_old) * p\r\n            rs_old = rs_new\r\n        return x\r\n\r\n    def CS_routine(x, ATy, niter=30):\r\n        nonlocal del_x, del_y, del_z, udel_x, udel_y, udel_z\r\n        if del_z.device != x.device:\r\n            del_x = del_x.to(x.device)\r\n            del_y = del_y.to(x.device)\r\n            del_z = del_z.to(x.device)\r\n            udel_x = udel_x.to(x.device)\r\n            udel_y = udel_y.to(x.device)\r\n            udel_z = udel_z.to(x.device)\r\n        for i in tqdm(range(niter)):\r\n            b_cg = ATy + rho * ((_DxT(del_x) - _DxT(udel_x))\r\n                                + (_DyT(del_y) - _DyT(udel_y))\r\n                                + (_DzT(del_z) - _DzT(udel_z)))\r\n            x = CG(A_cg, b_cg, x, n_inner=inner_iter)\r\n            if save_progress:\r\n                plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x[0:1]), cmap='gray')\r\n\r\n            del_x = shrink(_Dx(x) + udel_x, lamb_1 / rho)\r\n            del_y = shrink(_Dy(x) + udel_y, lamb_1 / rho)\r\n            del_z = shrink(_Dz(x) + udel_z, lamb_1 / rho)\r\n            udel_x = _Dx(x) - del_x + udel_x\r\n            udel_y = _Dy(x) - del_y + udel_y\r\n            udel_z = _Dz(x) - del_z + udel_z\r\n        return x\r\n\r\n    def get_ADMM_TV_fn():\r\n        def ADMM_TV_fn(x, measurement=None):\r\n            with 
torch.no_grad():\r\n                ATy = _AT(measurement)\r\n                x, x_mean = CS_routine(x, ATy, niter=outer_iter)\r\n                return x, x_mean\r\n        return ADMM_TV_fn\r\n\r\n    mc_update_fn = get_ADMM_TV_fn()\r\n\r\n    def ADMM_TV(data, measurement=None):\r\n        with torch.no_grad():\r\n            x = torch.zeros(data.shape).to(data.device)\r\n            x = mc_update_fn(x, measurement=measurement)\r\n            return x\r\n\r\n    return ADMM_TV\r\n\r\n\r\ndef get_ADMM_TV_isotropic(eps=1e-5, radon=None, save_progress=False, save_root=None,\r\n                          img_shape=None, lamb_1=5, rho=10, outer_iter=30, inner_iter=20):\r\n    \"\"\"\r\n    (get_ADMM_TV): implements anisotropic TV-ADMM\r\n    In contrast, this function implements isotropic TV, which regularizes with |TV|_{1,2}\r\n    \"\"\"\r\n    del_x = torch.zeros(img_shape)\r\n    del_y = torch.zeros(img_shape)\r\n    del_z = torch.zeros(img_shape)\r\n    udel_x = torch.zeros(img_shape)\r\n    udel_y = torch.zeros(img_shape)\r\n    udel_z = torch.zeros(img_shape)\r\n    eps = 1e-10\r\n\r\n    def _A(x):\r\n        return radon.A(x)\r\n\r\n    def _AT(sinogram):\r\n        return radon.AT(sinogram)\r\n\r\n    def A_cg(x):\r\n        return _AT(_A(x)) + rho * (_DxT(_Dx(x)) + _DyT(_Dy(x)) + _DzT(_Dz(x)))\r\n\r\n    \r\n    def CG(A_fn, b_cg, x, n_inner=20):\r\n        r = b_cg - A_fn(x)\r\n        p = r\r\n        rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T)\r\n\r\n        for i in range(n_inner):\r\n            Ap = A_fn(p)\r\n            a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)\r\n\r\n            x += a * p\r\n            r -= a * Ap\r\n\r\n            rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T)\r\n            if torch.sqrt(rs_new) < eps:\r\n                break\r\n            p = r + (rs_new / rs_old) * p\r\n            rs_old = rs_new\r\n        return x\r\n\r\n    def CS_routine(x, ATy, niter=30):\r\n        nonlocal del_x, 
del_y, del_z, udel_x, udel_y, udel_z\r\n        if del_z.device != x.device:\r\n            del_x = del_x.to(x.device)\r\n            del_y = del_y.to(x.device)\r\n            del_z = del_z.to(x.device)\r\n            udel_x = udel_x.to(x.device)\r\n            udel_y = udel_y.to(x.device)\r\n            udel_z = udel_z.to(x.device)\r\n        for i in tqdm(range(niter)):\r\n            b_cg = ATy + rho * ((_DxT(del_x) - _DxT(udel_x))\r\n                                + (_DyT(del_y) - _DyT(udel_y))\r\n                                + (_DzT(del_z) - _DzT(udel_z)))\r\n            x = CG(A_cg, b_cg, x, n_inner=inner_iter)\r\n            if save_progress:\r\n                plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x[0:1]), cmap='gray')\r\n\r\n            # Each of shape [448, 1, 256, 256]\r\n            _Dxx = _Dx(x)\r\n            _Dyx = _Dy(x)\r\n            _Dzx = _Dz(x)\r\n            # shape [448, 3, 256, 256]. dim=1 gradient dimension\r\n            _Dxa = torch.cat((_Dxx, _Dyx, _Dzx), dim=1)\r\n            udel_a = torch.cat((udel_x, udel_y, udel_z), dim=1)\r\n\r\n            # prox\r\n            del_a = prox_l21(_Dxa + udel_a, lamb_1 / rho, dim=1)\r\n\r\n            # split\r\n            del_x, del_y, del_z = torch.split(del_a, 1, dim=1)\r\n\r\n            # del_x = prox_l21(_Dxx + udel_x, lamb_1 / rho, -2)\r\n            # del_y = prox_l21(_Dyx + udel_y, lamb_1 / rho, -1)\r\n            # del_z = prox_l21(_Dzx + udel_z, lamb_1 / rho, 0)\r\n\r\n            udel_x = _Dxx - del_x + udel_x\r\n            udel_y = _Dyx - del_y + udel_y\r\n            udel_z = _Dzx - del_z + udel_z\r\n        return x\r\n\r\n    def get_ADMM_TV_fn():\r\n        def ADMM_TV_fn(x, measurement=None):\r\n            with torch.no_grad():\r\n                ATy = _AT(measurement)\r\n                x = CS_routine(x, ATy, niter=outer_iter)\r\n                return x\r\n        return ADMM_TV_fn\r\n\r\n    mc_update_fn = get_ADMM_TV_fn()\r\n\r\n    def 
ADMM_TV(data, measurement=None):\r\n        with torch.no_grad():\r\n            x = torch.zeros(data.shape).to(data.device)\r\n            x = mc_update_fn(x, measurement=measurement)\r\n            return x\r\n\r\n    return ADMM_TV\r\n\r\ndef prox_l21(src, lamb, dim):\r\n    \"\"\"\r\n    src.shape = [448(z), 1, 256(x), 256(y)]\r\n    \"\"\"\r\n    weight_src = torch.linalg.norm(src, dim=dim, keepdim=True)\r\n    weight_src_shrink = shrink(weight_src, lamb)\r\n\r\n    weight = weight_src_shrink / weight_src\r\n    return src * weight\r\n\r\n\r\ndef shrink(weight_src, lamb):\r\n    return torch.sign(weight_src) * torch.max(torch.abs(weight_src) - lamb, torch.zeros_like(weight_src))\r\n\r\n\r\ndef get_pc_radon_ADMM_TV_mri(sde, predictor, corrector, inverse_scaler, snr, mask=None,\r\n                             n_steps=1, probability_flow=False, continuous=False,\r\n                             denoise=True, eps=1e-5, save_progress=False, save_root=None,\r\n                             img_shape=None, lamb_1=5, rho=10):\r\n    predictor_update_fn = functools.partial(shared_predictor_update_fn,\r\n                                            sde=sde,\r\n                                            predictor=predictor,\r\n                                            probability_flow=probability_flow,\r\n                                            continuous=continuous)\r\n    corrector_update_fn = functools.partial(shared_corrector_update_fn,\r\n                                            sde=sde,\r\n                                            corrector=corrector,\r\n                                            continuous=continuous,\r\n                                            snr=snr,\r\n                                            n_steps=n_steps)\r\n\r\n    del_z = torch.zeros(img_shape)\r\n    udel_z = torch.zeros(img_shape)\r\n    eps = 1e-10\r\n\r\n    def _A(x):\r\n        return fft2(x) * mask\r\n\r\n    def _AT(kspace):\r\n        return 
torch.real(ifft2(kspace))\r\n\r\n    def _Dz(x):  # Batch direction\r\n        y = torch.zeros_like(x)\r\n        y[:-1] = x[1:]\r\n        y[-1] = x[0]\r\n        return y - x\r\n\r\n    def _DzT(x):  # Batch direction\r\n        y = torch.zeros_like(x)\r\n        y[:-1] = x[1:]\r\n        y[-1] = x[0]\r\n\r\n        tempt = -(y - x)\r\n        difft = tempt[:-1]\r\n        y[1:] = difft\r\n        y[0] = x[-1] - x[0]\r\n\r\n        return y\r\n\r\n    def A_cg(x):\r\n        return _AT(_A(x)) + rho * _DzT(_Dz(x))\r\n\r\n    def shrink(src, lamb):\r\n        return torch.sign(src) * torch.max(torch.abs(src) - lamb, torch.zeros_like(src))\r\n\r\n    def CG(A_fn, b_cg, x, n_inner=10):\r\n        r = b_cg - A_fn(x)\r\n        p = r\r\n        rs_old = torch.matmul(r.view(1, -1), r.view(1, -1).T)\r\n\r\n        for i in range(n_inner):\r\n            Ap = A_fn(p)\r\n            a = rs_old / torch.matmul(p.view(1, -1), Ap.view(1, -1).T)\r\n\r\n            x += a * p\r\n            r -= a * Ap\r\n\r\n            rs_new = torch.matmul(r.view(1, -1), r.view(1, -1).T)\r\n            if torch.sqrt(rs_new) < eps:\r\n                break\r\n            p = r + (rs_new / rs_old) * p\r\n            rs_old = rs_new\r\n        return x\r\n\r\n    def CS_routine(x, ATy, niter=20):\r\n        nonlocal del_z, udel_z\r\n        if del_z.device != x.device:\r\n            del_z = del_z.to(x.device)\r\n            udel_z = del_z.to(x.device)\r\n        for i in range(niter):\r\n            b_cg = ATy + rho * (_DzT(del_z) - _DzT(udel_z))\r\n            x = CG(A_cg, b_cg, x, n_inner=1)\r\n\r\n            del_z = shrink(_Dz(x) + udel_z, lamb_1 / rho)\r\n            udel_z = _Dz(x) - del_z + udel_z\r\n        x_mean = x\r\n        return x, x_mean\r\n\r\n    def get_update_fn(update_fn):\r\n        def radon_update_fn(model, data, x, t):\r\n            with torch.no_grad():\r\n                vec_t = torch.ones(x.shape[0], device=x.device) * t\r\n                x, x_mean = update_fn(x, 
vec_t, model=model)\r\n                return x, x_mean\r\n\r\n        return radon_update_fn\r\n\r\n    def get_ADMM_TV_fn():\r\n        def ADMM_TV_fn(x, measurement=None):\r\n            with torch.no_grad():\r\n                ATy = _AT(measurement)\r\n                x, x_mean = CS_routine(x, ATy, niter=1)\r\n                return x, x_mean\r\n        return ADMM_TV_fn\r\n\r\n    predictor_denoise_update_fn = get_update_fn(predictor_update_fn)\r\n    corrector_denoise_update_fn = get_update_fn(corrector_update_fn)\r\n    mc_update_fn = get_ADMM_TV_fn()\r\n\r\n    def pc_radon(model, data, measurement=None):\r\n        with torch.no_grad():\r\n            x = sde.prior_sampling(data.shape).to(data.device)\r\n            timesteps = torch.linspace(sde.T, eps, sde.N)\r\n            for i in tqdm(range(sde.N)):\r\n                t = timesteps[i]\r\n                # 1. batchify into sizes that fit into the GPU\r\n                x_batch = batchfy(x, 20)\r\n                # 2. Run PC step for each batch\r\n                x_agg = list()\r\n                for idx, x_batch_sing in enumerate(x_batch):\r\n                    x_batch_sing, _ = predictor_denoise_update_fn(model, data, x_batch_sing, t)\r\n                    x_batch_sing, _ = corrector_denoise_update_fn(model, data, x_batch_sing, t)\r\n                    x_agg.append(x_batch_sing)\r\n                # 3. Aggregate to run ADMM TV\r\n                x = torch.cat(x_agg, dim=0)\r\n                # 4. Run ADMM TV\r\n                x, x_mean = mc_update_fn(x, measurement=measurement)\r\n\r\n                if save_progress:\r\n                    if (i % 50) == 0:\r\n                        print(f'iter: {i}/{sde.N}')\r\n                        plt.imsave(save_root / 'recon' / 'progress' / f'progress{i}.png', clear(x_mean[0:1]), cmap='gray')\r\n\r\n            return inverse_scaler(x_mean if denoise else x)\r\n\r\n    return pc_radon"
  },
  {
    "path": "datasets.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n\"\"\"Return training and evaluation/test datasets from config files.\"\"\"\nfrom torch.utils.data import Dataset, DataLoader\nimport numpy as np\n\n\ndef get_data_scaler(config):\n  \"\"\"Data normalizer. Assume data are always in [0, 1].\"\"\"\n  if config.data.centered:\n    # Rescale to [-1, 1]\n    return lambda x: x * 2. - 1.\n  else:\n    return lambda x: x\n\n\ndef get_data_inverse_scaler(config):\n  \"\"\"Inverse data normalizer.\"\"\"\n  if config.data.centered:\n    # Rescale [-1, 1] to [0, 1]\n    return lambda x: (x + 1.) 
/ 2.\n  else:\n    return lambda x: x\n\n\ndef crop_resize(image, resolution):\n  \"\"\"Crop and resize an image to the given resolution.\"\"\"\n  crop = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])\n  h, w = tf.shape(image)[0], tf.shape(image)[1]\n  image = image[(h - crop) // 2:(h + crop) // 2,\n          (w - crop) // 2:(w + crop) // 2]\n  image = tf.image.resize(\n    image,\n    size=(resolution, resolution),\n    antialias=True,\n    method=tf.image.ResizeMethod.BICUBIC)\n  return tf.cast(image, tf.uint8)\n\n\ndef resize_small(image, resolution):\n  \"\"\"Shrink an image to the given resolution.\"\"\"\n  h, w = image.shape[0], image.shape[1]\n  ratio = resolution / min(h, w)\n  h = tf.round(h * ratio, tf.int32)\n  w = tf.round(w * ratio, tf.int32)\n  return tf.image.resize(image, [h, w], antialias=True)\n\n\ndef central_crop(image, size):\n  \"\"\"Crop the center of an image to the given size.\"\"\"\n  top = (image.shape[0] - size) // 2\n  left = (image.shape[1] - size) // 2\n  return tf.image.crop_to_bounding_box(image, top, left, size, size)\n\n\ndef get_dataset(config, uniform_dequantization=False, evaluation=False):\n  \"\"\"Create data loaders for training and evaluation.\n\n  Args:\n    config: A ml_collection.ConfigDict parsed from config files.\n    uniform_dequantization: If `True`, add uniform dequantization to images.\n    evaluation: If `True`, fix number of epochs to 1.\n\n  Returns:\n    train_ds, eval_ds, dataset_builder.\n  \"\"\"\n  # Compute batch size for this worker.\n  batch_size = config.training.batch_size if not evaluation else config.eval.batch_size\n  if batch_size % jax.device_count() != 0:\n    raise ValueError(f'Batch sizes ({batch_size} must be divided by'\n                     f'the number of devices ({jax.device_count()})')\n\n  # Reduce this when image resolution is too large and data pointer is stored\n  shuffle_buffer_size = 10000\n  prefetch_size = tf.data.experimental.AUTOTUNE\n  num_epochs = None if not evaluation 
else 1\n\n  # Create dataset builders for each dataset.\n  if config.data.dataset == 'CIFAR10':\n    dataset_builder = tfds.builder('cifar10')\n    train_split_name = 'train'\n    eval_split_name = 'test'\n\n    def resize_op(img):\n      img = tf.image.convert_image_dtype(img, tf.float32)\n      # Added to train grayscale models\n      # img = tf.image.rgb_to_grayscale(img)\n      return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True)\n\n\n  elif config.data.dataset == 'SVHN':\n    dataset_builder = tfds.builder('svhn_cropped')\n    train_split_name = 'train'\n    eval_split_name = 'test'\n\n    def resize_op(img):\n      img = tf.image.convert_image_dtype(img, tf.float32)\n      return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True)\n\n  elif config.data.dataset == 'CELEBA':\n    dataset_builder = tfds.builder('celeb_a')\n    train_split_name = 'train'\n    eval_split_name = 'validation'\n\n    def resize_op(img):\n      img = tf.image.convert_image_dtype(img, tf.float32)\n      img = central_crop(img, 140)\n      img = resize_small(img, config.data.image_size)\n      return img\n\n  elif config.data.dataset == 'LSUN':\n    dataset_builder = tfds.builder(f'lsun/{config.data.category}')\n    train_split_name = 'train'\n    eval_split_name = 'validation'\n\n    if config.data.image_size == 128:\n      def resize_op(img):\n        img = tf.image.convert_image_dtype(img, tf.float32)\n        img = resize_small(img, config.data.image_size)\n        img = central_crop(img, config.data.image_size)\n        return img\n\n    else:\n      def resize_op(img):\n        img = crop_resize(img, config.data.image_size)\n        img = tf.image.convert_image_dtype(img, tf.float32)\n        return img\n\n  elif config.data.dataset in ['FFHQ', 'CelebAHQ']:\n    dataset_builder = tf.data.TFRecordDataset(config.data.tfrecords_path)\n    train_split_name = eval_split_name = 'train'\n\n  else:\n    raise 
NotImplementedError(\n      f'Dataset {config.data.dataset} not yet supported.')\n\n  # Customize preprocess functions for each dataset.\n  if config.data.dataset in ['FFHQ', 'CelebAHQ']:\n    def preprocess_fn(d):\n      sample = tf.io.parse_single_example(d, features={\n        'shape': tf.io.FixedLenFeature([3], tf.int64),\n        'data': tf.io.FixedLenFeature([], tf.string)})\n      data = tf.io.decode_raw(sample['data'], tf.uint8)\n      data = tf.reshape(data, sample['shape'])\n      data = tf.transpose(data, (1, 2, 0))\n      img = tf.image.convert_image_dtype(data, tf.float32)\n      if config.data.random_flip and not evaluation:\n        img = tf.image.random_flip_left_right(img)\n      if uniform_dequantization:\n        img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256.\n      return dict(image=img, label=None)\n\n  else:\n    def preprocess_fn(d):\n      \"\"\"Basic preprocessing function scales data to [0, 1) and randomly flips.\"\"\"\n      img = resize_op(d['image'])\n      if config.data.random_flip and not evaluation:\n        img = tf.image.random_flip_left_right(img)\n      if uniform_dequantization:\n        img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) 
/ 256.\n\n      return dict(image=img, label=d.get('label', None))\n\n  def create_dataset(dataset_builder, split):\n    dataset_options = tf.data.Options()\n    dataset_options.experimental_optimization.map_parallelization = True\n    dataset_options.experimental_threading.private_threadpool_size = 48\n    dataset_options.experimental_threading.max_intra_op_parallelism = 1\n    read_config = tfds.ReadConfig(options=dataset_options)\n    if isinstance(dataset_builder, tfds.core.DatasetBuilder):\n      dataset_builder.download_and_prepare()\n      ds = dataset_builder.as_dataset(\n        split=split, shuffle_files=True, read_config=read_config)\n    else:\n      ds = dataset_builder.with_options(dataset_options)\n    ds = ds.repeat(count=num_epochs)\n    ds = ds.shuffle(shuffle_buffer_size)\n    ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n    ds = ds.batch(batch_size, drop_remainder=True)\n    return ds.prefetch(prefetch_size)\n\n  train_ds = create_dataset(dataset_builder, train_split_name)\n  eval_ds = create_dataset(dataset_builder, eval_split_name)\n  return train_ds, eval_ds, dataset_builder\n\n\nfrom pathlib import Path\n\nclass fastmri_knee(Dataset):\n  \"\"\" Simple pytorch dataset for fastmri knee singlecoil dataset \"\"\"\n  def __init__(self, root, is_complex=False):\n    self.root = root\n    self.data_list = list(root.glob('*/*.npy'))\n    self.is_complex = is_complex\n\n  def __len__(self):\n    return len(self.data_list)\n\n  def __getitem__(self, idx):\n    fname = self.data_list[idx]\n    if not self.is_complex:\n      data = np.load(fname)\n    else:\n      data = np.load(fname).astype(np.complex64)\n    data = np.expand_dims(data, axis=0)\n    return data\n\n\nclass AAPM(Dataset):\n  def __init__(self, root, sort):\n    self.root = root\n    self.data_list = list(root.glob('full_dose/*.npy'))\n    self.sort = sort\n    if sort:\n      self.data_list = sorted(self.data_list)\n\n  def __len__(self):\n    return 
len(self.data_list)\n\n  def __getitem__(self, idx):\n    fname = self.data_list[idx]\n    data = np.load(fname)\n    data = np.expand_dims(data, axis=0)\n    return data\n\n\nclass Object5(Dataset):\n  def __init__(self, root, slice, fast=False):\n    \"\"\"\n    slice - range of the 2000 _volumes_ that you want,\n    but the dataset will return images, so will be 256 times longer\n\n    fast - set to true to get a tiny version of the dataset\n    \"\"\"\n    if fast:\n      self.NUM_SLICES = 10\n    else:\n      self.NUM_SLICES = 256\n\n\n    self.root = root\n    self.data_list = list(root.glob('*.npz'))\n\n    if len(self.data_list) == 0:\n      raise ValueError(f\"No npz files found in {root}\")\n\n    self.data_list = sorted(self.data_list)[slice]\n\n  def __len__(self):\n    return len(self.data_list) * self.NUM_SLICES\n\n  def __getitem__(self, idx):\n    vol_index = idx // self.NUM_SLICES\n    slice_index = idx % self.NUM_SLICES\n    fname = self.data_list[vol_index]\n    data = np.load(fname)['x'][slice_index]\n    data = np.expand_dims(data, axis=0)\n    return data\n\nclass fastmri_knee_infer(Dataset):\n  \"\"\" Simple pytorch dataset for fastmri knee singlecoil dataset \"\"\"\n  def __init__(self, root, sort=True, is_complex=False):\n    self.root = root\n    self.data_list = list(root.glob('*/*.npy'))\n    self.is_complex = is_complex\n    if sort:\n      self.data_list = sorted(self.data_list)\n\n  def __len__(self):\n    return len(self.data_list)\n\n  def __getitem__(self, idx):\n    fname = self.data_list[idx]\n    if not self.is_complex:\n      data = np.load(fname)\n    else:\n      data = np.load(fname).astype(np.complex64)\n    data = np.expand_dims(data, axis=0)\n    return data, str(fname)\n\n\nclass fastmri_knee_magpha(Dataset):\n  \"\"\" Simple pytorch dataset for fastmri knee singlecoil dataset \"\"\"\n  def __init__(self, root):\n    self.root = root\n    self.data_list = list(root.glob('*/*.npy'))\n\n  def __len__(self):\n    return 
len(self.data_list)\n\n  def __getitem__(self, idx):\n    fname = self.data_list[idx]\n    data = np.load(fname).astype(np.float32)\n    return data\n\n\nclass fastmri_knee_magpha_infer(Dataset):\n  \"\"\" Simple pytorch dataset for fastmri knee singlecoil dataset \"\"\"\n  def __init__(self, root, sort=True):\n    self.root = root\n    self.data_list = list(root.glob('*/*.npy'))\n    if sort:\n      self.data_list = sorted(self.data_list)\n\n  def __len__(self):\n    return len(self.data_list)\n\n  def __getitem__(self, idx):\n    fname = self.data_list[idx]\n    data = np.load(fname).astype(np.float32)\n    return data, str(fname)\n\n\ndef create_dataloader(configs, evaluation=False, sort=True):\n  shuffle = True if not evaluation else False\n  if configs.data.dataset == 'Object5':\n    train_dataset = Object5(Path(configs.data.root), slice(None,1800))  \n    val_dataset = Object5(Path(configs.data.root), slice(1800,None)) \n  elif configs.data.dataset == 'Object5Fast':\n    train_dataset = Object5(Path(configs.data.root), slice(None,1), fast=True)\n    val_dataset = Object5(Path(configs.data.root), slice(1,2), fast=True)\n  elif configs.data.dataset == 'AAPM':\n    train_dataset = AAPM(Path(configs.data.root) / f'train', sort=False)\n    val_dataset = AAPM(Path(configs.data.root) / f'test', sort=True)\n  elif configs.data.is_multi:\n    train_dataset = fastmri_knee(Path(configs.data.root) / f'knee_multicoil_{configs.data.image_size}_train')\n    val_dataset = fastmri_knee_infer(Path(configs.data.root) / f'knee_{configs.data.image_size}_val', sort=sort)\n  elif configs.data.is_complex:\n    if configs.data.magpha:\n      train_dataset = fastmri_knee_magpha(Path(configs.data.root) / f'knee_complex_magpha_{configs.data.image_size}_train')\n      val_dataset = fastmri_knee_magpha_infer(Path(configs.data.root) / f'knee_complex_magpha_{configs.data.image_size}_val')\n    else:\n      train_dataset = fastmri_knee(Path(configs.data.root) / 
f'knee_complex_{configs.data.image_size}_train', is_complex=True)\n      val_dataset = fastmri_knee_infer(Path(configs.data.root) / f'knee_complex_{configs.data.image_size}_val', is_complex=True)\n  elif configs.data.dataset == 'fastmri_knee':\n    train_dataset = fastmri_knee(Path(configs.data.root) / f'knee_{configs.data.image_size}_train')\n    val_dataset = fastmri_knee_infer(Path(configs.data.root) / f'knee_{configs.data.image_size}_val', sort=sort)\n  else:\n    raise ValueError(f'Dataset {configs.data.dataset} not recognized.')\n\n  train_loader = DataLoader(\n    dataset=train_dataset,\n    batch_size=configs.training.batch_size,\n    shuffle=shuffle,\n    drop_last=True\n  )\n  val_loader = DataLoader(\n    dataset=val_dataset,\n    batch_size=configs.training.batch_size,\n    # shuffle=False,\n    shuffle=True,\n    drop_last=True\n  )\n  return train_loader, val_loader\n\n\n\ndef create_dataloader_regression(configs, evaluation=False):\n  shuffle = True if not evaluation else False\n  train_dataset = fastmri_knee(Path(configs.root) / f'knee_{configs.image_size}_train')\n  val_dataset = fastmri_knee_infer(Path(configs.root) / f'knee_{configs.image_size}_val')\n\n  train_loader = DataLoader(\n    dataset=train_dataset,\n    batch_size=configs.batch_size,\n    shuffle=shuffle,\n    drop_last=True\n  )\n  val_loader = DataLoader(\n    dataset=val_dataset,\n    batch_size=configs.batch_size,\n    shuffle=False,\n    drop_last=True\n  )\n  return train_loader, val_loader\n"
  },
  {
    "path": "environment.yml",
    "content": "name: diffusion-mbir\nchannels:\n  - conda-forge\n  - defaults\ndependencies:\n  - python=3.8\n  - numpy\n  - matplotlib\n  - scikit-image\n  - sporco\n  - tqdm\n  - ninja\n  - pytorch::pytorch\n  - pytorch::torchvision\n  - tensorboard\n  - pip\n  - pip:\n      - ml_collections\n      - ninja\n"
  },
  {
    "path": "evaluation.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Utility functions for computing FID/Inception scores.\"\"\"\n\nimport numpy as np\nimport six\n\nINCEPTION_TFHUB = 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1'\nINCEPTION_OUTPUT = 'logits'\nINCEPTION_FINAL_POOL = 'pool_3'\n_DEFAULT_DTYPES = {\n  INCEPTION_OUTPUT: tf.float32,\n  INCEPTION_FINAL_POOL: tf.float32\n}\nINCEPTION_DEFAULT_IMAGE_SIZE = 299\n\n\ndef get_inception_model(inceptionv3=False):\n  if inceptionv3:\n    return tfhub.load(\n      'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/4')\n  else:\n    return tfhub.load(INCEPTION_TFHUB)\n\n\ndef load_dataset_stats(config):\n  \"\"\"Load the pre-computed dataset statistics.\"\"\"\n  if config.data.dataset == 'CIFAR10':\n    filename = 'assets/stats/cifar10_stats.npz'\n  elif config.data.dataset == 'CELEBA':\n    filename = 'assets/stats/celeba_stats.npz'\n  elif config.data.dataset == 'LSUN':\n    filename = f'assets/stats/lsun_{config.data.category}_{config.data.image_size}_stats.npz'\n  else:\n    raise ValueError(f'Dataset {config.data.dataset} stats not found.')\n\n  with tf.io.gfile.GFile(filename, 'rb') as fin:\n    stats = np.load(fin)\n    return stats\n\n\ndef classifier_fn_from_tfhub(output_fields, inception_model,\n                             return_tensor=False):\n  \"\"\"Returns a function that can be as a classifier function.\n\n  
Copied from tfgan but avoid loading the model each time calling _classifier_fn\n\n  Args:\n    output_fields: A string, list, or `None`. If present, assume the module\n      outputs a dictionary, and select this field.\n    inception_model: A model loaded from TFHub.\n    return_tensor: If `True`, return a single tensor instead of a dictionary.\n\n  Returns:\n    A one-argument function that takes an image Tensor and returns outputs.\n  \"\"\"\n  if isinstance(output_fields, six.string_types):\n    output_fields = [output_fields]\n\n  def _classifier_fn(images):\n    output = inception_model(images)\n    if output_fields is not None:\n      output = {x: output[x] for x in output_fields}\n    if return_tensor:\n      assert len(output) == 1\n      output = list(output.values())[0]\n    return tf.nest.map_structure(tf.compat.v1.layers.flatten, output)\n\n  return _classifier_fn\n\n\n@tf.function\ndef run_inception_jit(inputs,\n                      inception_model,\n                      num_batches=1,\n                      inceptionv3=False):\n  \"\"\"Running the inception network. Assuming input is within [0, 255].\"\"\"\n  if not inceptionv3:\n    inputs = (tf.cast(inputs, tf.float32) - 127.5) / 127.5\n  else:\n    inputs = tf.cast(inputs, tf.float32) / 255.\n\n  return tfgan.eval.run_classifier_fn(\n    inputs,\n    num_batches=num_batches,\n    classifier_fn=classifier_fn_from_tfhub(None, inception_model),\n    dtypes=_DEFAULT_DTYPES)\n\n\n@tf.function\ndef run_inception_distributed(input_tensor,\n                              inception_model,\n                              num_batches=1,\n                              inceptionv3=False):\n  \"\"\"Distribute the inception network computation to all available TPUs.\n\n  Args:\n    input_tensor: The input images. 
Assumed to be within [0, 255].\n    inception_model: The inception network model obtained from `tfhub`.\n    num_batches: The number of batches used for dividing the input.\n    inceptionv3: If `True`, use InceptionV3, otherwise use InceptionV1.\n\n  Returns:\n    A dictionary with key `pool_3` and `logits`, representing the pool_3 and\n      logits of the inception network respectively.\n  \"\"\"\n  num_tpus = jax.local_device_count()\n  input_tensors = tf.split(input_tensor, num_tpus, axis=0)\n  pool3 = []\n  logits = [] if not inceptionv3 else None\n  device_format = '/TPU:{}' if 'TPU' in str(jax.devices()[0]) else '/GPU:{}'\n  for i, tensor in enumerate(input_tensors):\n    with tf.device(device_format.format(i)):\n      tensor_on_device = tf.identity(tensor)\n      res = run_inception_jit(\n        tensor_on_device, inception_model, num_batches=num_batches,\n        inceptionv3=inceptionv3)\n\n      if not inceptionv3:\n        pool3.append(res['pool_3'])\n        logits.append(res['logits'])  # pytype: disable=attribute-error\n      else:\n        pool3.append(res)\n\n  with tf.device('/CPU'):\n    return {\n      'pool_3': tf.concat(pool3, axis=0),\n      'logits': tf.concat(logits, axis=0) if not inceptionv3 else None\n    }\n"
  },
  {
    "path": "fastmri_utils.py",
    "content": "\"\"\"\nCopyright (c) Facebook, Inc. and its affiliates.\nThis source code is licensed under the MIT license found in the\nLICENSE file in the root directory of this source tree.\n\"\"\"\n\nfrom typing import List, Optional\n\nimport torch\nfrom packaging import version\n\nif version.parse(torch.__version__) >= version.parse(\"1.7.0\"):\n    import torch.fft  # type: ignore\n\n\ndef fft2c_old(data: torch.Tensor, norm: str = \"ortho\") -> torch.Tensor:\n    \"\"\"\n    Apply centered 2 dimensional Fast Fourier Transform.\n    Args:\n        data: Complex valued input data containing at least 3 dimensions:\n            dimensions -3 & -2 are spatial dimensions and dimension -1 has size\n            2. All other dimensions are assumed to be batch dimensions.\n        norm: Whether to include normalization. Must be one of ``\"backward\"``\n            or ``\"ortho\"``. See ``torch.fft.fft`` on PyTorch 1.9.0 for details.\n    Returns:\n        The FFT of the input.\n    \"\"\"\n    if not data.shape[-1] == 2:\n        raise ValueError(\"Tensor does not have separate complex dim.\")\n    if norm not in (\"ortho\", \"backward\"):\n        raise ValueError(\"norm must be 'ortho' or 'backward'.\")\n    normalized = True if norm == \"ortho\" else False\n\n    data = ifftshift(data, dim=[-3, -2])\n    data = torch.fft(data, 2, normalized=normalized)\n    data = fftshift(data, dim=[-3, -2])\n\n    return data\n\n\ndef ifft2c_old(data: torch.Tensor, norm: str = \"ortho\") -> torch.Tensor:\n    \"\"\"\n    Apply centered 2-dimensional Inverse Fast Fourier Transform.\n    Args:\n        data: Complex valued input data containing at least 3 dimensions:\n            dimensions -3 & -2 are spatial dimensions and dimension -1 has size\n            2. All other dimensions are assumed to be batch dimensions.\n        norm: Whether to include normalization. Must be one of ``\"backward\"``\n            or ``\"ortho\"``. 
See ``torch.fft.ifft`` on PyTorch 1.9.0 for\n            details.\n    Returns:\n        The IFFT of the input.\n    \"\"\"\n    if not data.shape[-1] == 2:\n        raise ValueError(\"Tensor does not have separate complex dim.\")\n    if norm not in (\"ortho\", \"backward\"):\n        raise ValueError(\"norm must be 'ortho' or 'backward'.\")\n    normalized = True if norm == \"ortho\" else False\n\n    data = ifftshift(data, dim=[-3, -2])\n    data = torch.ifft(data, 2, normalized=normalized)\n    data = fftshift(data, dim=[-3, -2])\n\n    return data\n\n\ndef fft2c_new(data: torch.Tensor, norm: str = \"ortho\") -> torch.Tensor:\n    \"\"\"\n    Apply centered 2 dimensional Fast Fourier Transform.\n    Args:\n        data: Complex valued input data containing at least 3 dimensions:\n            dimensions -3 & -2 are spatial dimensions and dimension -1 has size\n            2. All other dimensions are assumed to be batch dimensions.\n        norm: Normalization mode. See ``torch.fft.fft``.\n    Returns:\n        The FFT of the input.\n    \"\"\"\n    if not data.shape[-1] == 2:\n        raise ValueError(\"Tensor does not have separate complex dim.\")\n\n    data = ifftshift(data, dim=[-3, -2])\n    data = torch.view_as_real(\n        torch.fft.fftn(  # type: ignore\n            torch.view_as_complex(data), dim=(-2, -1), norm=norm\n        )\n    )\n    data = fftshift(data, dim=[-3, -2])\n\n    return data\n\n\ndef ifft2c_new(data: torch.Tensor, norm: str = \"ortho\") -> torch.Tensor:\n    \"\"\"\n    Apply centered 2-dimensional Inverse Fast Fourier Transform.\n    Args:\n        data: Complex valued input data containing at least 3 dimensions:\n            dimensions -3 & -2 are spatial dimensions and dimension -1 has size\n            2. All other dimensions are assumed to be batch dimensions.\n        norm: Normalization mode. 
See ``torch.fft.ifft``.\n    Returns:\n        The IFFT of the input.\n    \"\"\"\n    if not data.shape[-1] == 2:\n        raise ValueError(\"Tensor does not have separate complex dim.\")\n\n    data = ifftshift(data, dim=[-3, -2])\n    data = torch.view_as_real(\n        torch.fft.ifftn(  # type: ignore\n            torch.view_as_complex(data), dim=(-2, -1), norm=norm\n        )\n    )\n    data = fftshift(data, dim=[-3, -2])\n\n    return data\n\n\n# Helper functions\n\n\ndef roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:\n    \"\"\"\n    Similar to roll but for only one dim.\n    Args:\n        x: A PyTorch tensor.\n        shift: Amount to roll.\n        dim: Which dimension to roll.\n    Returns:\n        Rolled version of x.\n    \"\"\"\n    shift = shift % x.size(dim)\n    if shift == 0:\n        return x\n\n    left = x.narrow(dim, 0, x.size(dim) - shift)\n    right = x.narrow(dim, x.size(dim) - shift, shift)\n\n    return torch.cat((right, left), dim=dim)\n\n\ndef roll(\n    x: torch.Tensor,\n    shift: List[int],\n    dim: List[int],\n) -> torch.Tensor:\n    \"\"\"\n    Similar to np.roll but applies to PyTorch Tensors.\n    Args:\n        x: A PyTorch tensor.\n        shift: Amount to roll.\n        dim: Which dimension to roll.\n    Returns:\n        Rolled version of x.\n    \"\"\"\n    if len(shift) != len(dim):\n        raise ValueError(\"len(shift) must match len(dim)\")\n\n    for (s, d) in zip(shift, dim):\n        x = roll_one_dim(x, s, d)\n\n    return x\n\n\ndef fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:\n    \"\"\"\n    Similar to np.fft.fftshift but applies to PyTorch Tensors\n    Args:\n        x: A PyTorch tensor.\n        dim: Which dimension to fftshift.\n    Returns:\n        fftshifted version of x.\n    \"\"\"\n    if dim is None:\n        # this weird code is necessary for toch.jit.script typing\n        dim = [0] * (x.dim())\n        for i in range(1, x.dim()):\n            
dim[i] = i\n\n    # also necessary for torch.jit.script\n    shift = [0] * len(dim)\n    for i, dim_num in enumerate(dim):\n        shift[i] = x.shape[dim_num] // 2\n\n    return roll(x, shift, dim)\n\n\ndef ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:\n    \"\"\"\n    Similar to np.fft.ifftshift but applies to PyTorch Tensors\n    Args:\n        x: A PyTorch tensor.\n        dim: Which dimension to ifftshift.\n    Returns:\n        ifftshifted version of x.\n    \"\"\"\n    if dim is None:\n        # this weird code is necessary for torch.jit.script typing\n        dim = [0] * (x.dim())\n        for i in range(1, x.dim()):\n            dim[i] = i\n\n    # also necessary for torch.jit.script\n    shift = [0] * len(dim)\n    for i, dim_num in enumerate(dim):\n        shift[i] = (x.shape[dim_num] + 1) // 2\n\n    return roll(x, shift, dim)"
  },
  {
    "path": "inverse_problem_solver_AAPM_3d_total.py",
    "content": "import torch\nfrom torch._C import device\nfrom losses import get_optimizer\nfrom models.ema import ExponentialMovingAverage\n\nimport numpy as np\nimport controllable_generation_TV\n\nfrom utils import restore_checkpoint, clear, batchfy, patient_wise_min_max, img_wise_min_max\nfrom pathlib import Path\nfrom models import utils as mutils\nfrom models import ncsnpp\nfrom sde_lib import VESDE\nfrom sampling import (ReverseDiffusionPredictor,\n                      LangevinCorrector)\nimport datasets\nimport time\n# for radon\nfrom physics.ct import CT\nimport matplotlib.pyplot as plt\nimport os\nfrom tqdm import tqdm\n\n###############################################\n# Configurations\n###############################################\nproblem = 'sparseview_CT_ADMM_TV_total'\nconfig_name = 'AAPM_256_ncsnpp_continuous'\nsde = 'VESDE'\nnum_scales = 2000\nckpt_num = 185\nN = num_scales\n\nvol_name = 'L067'\nroot = Path(f'./data/CT/ind/256_sorted/{vol_name}')\n\n# Parameters for the inverse problem\nNview = 8\ndet_spacing = 1.0\nsize = 256\ndet_count = int((size * (2 * torch.ones(1)).sqrt()).ceil())\nlamb = 0.04\nrho = 10\nfreq = 1\n\nif sde.lower() == 'vesde':\n    from configs.ve import AAPM_256_ncsnpp_continuous as configs\n    ckpt_filename = f\"exp/ve/{config_name}/checkpoint_{ckpt_num}.pth\"\n    config = configs.get_config()\n    config.model.num_scales = N\n    sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)\n    sde.N = N\n    sampling_eps = 1e-5\npredictor = ReverseDiffusionPredictor\ncorrector = LangevinCorrector\nprobability_flow = False\nsnr = 0.16\nn_steps = 1\n\nbatch_size = 12\nconfig.training.batch_size = batch_size\nconfig.eval.batch_size = batch_size\nrandom_seed = 0\n\nsigmas = mutils.get_sigmas(config)\nscaler = datasets.get_data_scaler(config)\ninverse_scaler = datasets.get_data_inverse_scaler(config)\nscore_model = mutils.create_model(config)  ## model\n\noptimizer = 
get_optimizer(config, score_model.parameters())\nema = ExponentialMovingAverage(score_model.parameters(),\n                               decay=config.model.ema_rate)\nstate = dict(step=0, optimizer=optimizer,\n             model=score_model, ema=ema)\n\nstate = restore_checkpoint(ckpt_filename, state, config.device, skip_sigma=True, skip_optimizer=True)\nema.copy_to(score_model.parameters())\n\n# Specify save directory for saving generated samples\nsave_root = Path(f'./results/{config_name}/{problem}/m{Nview}/rho{rho}/lambda{lamb}')\nsave_root.mkdir(parents=True, exist_ok=True)\n\nirl_types = ['input', 'recon', 'label', 'BP', 'sinogram']\nfor t in irl_types:\n    if t == 'recon':\n        save_root_f = save_root / t / 'progress'\n        save_root_f.mkdir(exist_ok=True, parents=True)\n    else:\n        save_root_f = save_root / t\n        save_root_f.mkdir(parents=True, exist_ok=True)\n\n# read all data\nfname_list = os.listdir(root)\nfname_list = sorted(fname_list, key=lambda x: float(x.split(\".\")[0]))\nprint(fname_list)\nall_img = []\n\nprint(\"Loading all data\")\nfor fname in tqdm(fname_list):\n    just_name = fname.split('.')[0]\n    img = torch.from_numpy(np.load(os.path.join(root, fname), allow_pickle=True))\n    h, w = img.shape\n    img = img.view(1, 1, h, w)\n    all_img.append(img)\n    plt.imsave(os.path.join(save_root, 'label', f'{just_name}.png'), clear(img), cmap='gray')\nall_img = torch.cat(all_img, dim=0)\nprint(f\"Data loaded shape : {all_img.shape}\")\n\n# full\nangles = np.linspace(0, np.pi, 180, endpoint=False)\nradon = CT(img_width=h, radon_view=Nview, circle=False, device=config.device)\n\npredicted_sinogram = []\nlabel_sinogram = []\nimg_cache = None\n\nimg = all_img.to(config.device)\npc_radon = controllable_generation_TV.get_pc_radon_ADMM_TV_vol(sde,\n                                                               predictor, corrector,\n                                                               inverse_scaler,\n                      
                                         snr=snr,\n                                                               n_steps=n_steps,\n                                                               probability_flow=probability_flow,\n                                                               continuous=config.training.continuous,\n                                                               denoise=True,\n                                                               radon=radon,\n                                                               save_progress=True,\n                                                               save_root=save_root,\n                                                               final_consistency=True,\n                                                               img_shape=img.shape,\n                                                               lamb_1=lamb,\n                                                               rho=rho)\n# Sparse by masking\nsinogram = radon.A(img)\n\n# A_dagger\nbp = radon.AT(sinogram)\n\n# Recon Image\nx = pc_radon(score_model, scaler(img), measurement=sinogram)\nimg_cahce = x[-1].unsqueeze(0)\n\ncount = 0\nfor i, recon_img in enumerate(x):\n    plt.imsave(save_root / 'BP' / f'{count}.png', clear(bp[i]), cmap='gray')\n    plt.imsave(save_root / 'label' / f'{count}.png', clear(img[i]), cmap='gray')\n    plt.imsave(save_root / 'recon' / f'{count}.png', clear(recon_img), cmap='gray')\n\n    count += 1\n\n# Recon and Save Sinogram\nlabel_sinogram.append(radon.A_all(img))\npredicted_sinogram.append(radon.A_all(x))\n\noriginal_sinogram = torch.cat(label_sinogram, dim=0).detach().cpu().numpy()\nrecon_sinogram = torch.cat(predicted_sinogram, dim=0).detach().cpu().numpy()\n\nnp.save(str(save_root / 'sinogram' / f'original_{count}.npy'), original_sinogram)\nnp.save(str(save_root / 'sinogram' / f'recon_{count}.npy'), recon_sinogram)"
  },
  {
    "path": "inverse_problem_solver_BRATS_MRI_3d_total.py",
    "content": "from pathlib import Path\nfrom models import utils as mutils\nimport sampling\nfrom sde_lib import VESDE\nfrom sampling import (ReverseDiffusionPredictor,\n                      LangevinCorrector,\n                      LangevinCorrectorCS)\nfrom models import ncsnpp\nfrom itertools import islice\nfrom losses import get_optimizer\nimport datasets\nimport time\nimport controllable_generation_TV\nfrom utils import restore_checkpoint, fft2, ifft2, show_samples_gray, get_mask, clear\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom models.ema import ExponentialMovingAverage\nfrom scipy.io import savemat, loadmat\nfrom tqdm import tqdm\nimport matplotlib.pyplot as plt\nimport importlib\n\n\n###############################################\n# Configurations\n###############################################\nproblem = 'Fourier_CS_3d_admm_tv'\nconfig_name = 'fastmri_knee_320_ncsnpp_continuous'\nsde = 'VESDE'\nnum_scales = 2000\nckpt_num = 95\nN = num_scales\n\nroot = './data/MRI/BRATS'\nvol = 'Brats18_CBICA_AAM_1'\n\nif sde.lower() == 'vesde':\n  # from configs.ve import fastmri_knee_320_ncsnpp_continuous as configs\n  configs = importlib.import_module(f\"configs.ve.{config_name}\")\n  if config_name == 'fastmri_knee_320_ncsnpp_continuous':\n    ckpt_filename = f\"./exp/ve/{config_name}/checkpoint_{ckpt_num}.pth\"\n  elif config_name == 'ffhq_256_ncsnpp_continuous':\n    ckpt_filename = f\"exp/ve/{config_name}/checkpoint_48.pth\"\n  config = configs.get_config()\n  config.model.num_scales = num_scales\n  sde = VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)\n  sde.N = N\n  sampling_eps = 1e-5\n\nimg_size = 240\nbatch_size = 1\nconfig.training.batch_size = batch_size\npredictor = ReverseDiffusionPredictor\ncorrector = LangevinCorrector\nprobability_flow = False\nsnr = 0.16\nn_steps = 1\n\n# parameters for Fourier CS recon\nmask_type = 'uniform1d'\nuse_measurement_noise = False\nacc_factor = 
2.0\ncenter_fraction = 0.15\n\n# ADMM TV parameters\nlamb_list = [0.005]\nrho_list = [0.01]\n\nrandom_seed = 0\n\nsigmas = mutils.get_sigmas(config)\nscaler = datasets.get_data_scaler(config)\ninverse_scaler = datasets.get_data_inverse_scaler(config)\nscore_model = mutils.create_model(config)\n\noptimizer = get_optimizer(config, score_model.parameters())\nema = ExponentialMovingAverage(score_model.parameters(),\n                               decay=config.model.ema_rate)\nstate = dict(step=0, optimizer=optimizer,\n             model=score_model, ema=ema)\nstate = restore_checkpoint(ckpt_filename, state, config.device, skip_sigma=True)\nema.copy_to(score_model.parameters())\n\nfname_list = sorted(list((Path(root) / vol).glob('*.npy')))\nall_img = []\nfor fname in tqdm(fname_list):\n    img = np.load(fname)\n    img = torch.from_numpy(img)\n    h, w = img.shape\n    img = img.view(1, 1, h, w)\n    all_img.append(img)\n\nall_img = torch.cat(all_img, dim=0)\n\n# normalize the volume to be in proper range\nvmax = all_img.max()\nall_img /= (vmax + 1e-5)\n\nimg = all_img.to(config.device)\nb = img.shape[0]\n\nfor lamb in lamb_list:\n    for rho in rho_list:\n        print(f'lambda: {lamb}')\n        print(f'rho:    {rho}')\n        # Specify save directory for saving generated samples\n        save_root = Path(f'./results/{config_name}/{problem}/{mask_type}/acc{acc_factor}/lamb{lamb}/rho{rho}/{vol}')\n        save_root.mkdir(parents=True, exist_ok=True)\n\n        irl_types = ['input', 'recon', 'label']\n        for t in irl_types:\n            save_root_f = save_root / t\n            save_root_f.mkdir(parents=True, exist_ok=True)\n\n        ###############################################\n        # Inference\n        ###############################################\n\n        # forward model\n        kspace = fft2(img)\n\n        # generate mask\n        mask = get_mask(torch.zeros(1, 1, h, w), img_size, batch_size,\n                        type=mask_type, 
acc_factor=acc_factor, center_fraction=center_fraction)\n        mask = mask.to(img.device)\n        mask = mask.repeat(b, 1, 1, 1)\n\n        pc_fouriercs = controllable_generation_TV.get_pc_radon_ADMM_TV_mri(sde,\n                                                                           predictor, corrector,\n                                                                           inverse_scaler,\n                                                                           mask=mask,\n                                                                           lamb_1=lamb,\n                                                                           rho=rho,\n                                                                           img_shape=img.shape,\n                                                                           snr=snr,\n                                                                           n_steps=n_steps,\n                                                                           probability_flow=probability_flow,\n                                                                           continuous=config.training.continuous)\n\n        # undersampling\n        under_kspace = kspace * mask\n        under_img = torch.real(ifft2(under_kspace))\n\n        count = 0\n        for i, recon_img in enumerate(under_img):\n            plt.imsave(save_root / 'input' / f'{count}.png', clear(under_img[i]), cmap='gray')\n            plt.imsave(save_root / 'label' / f'{count}.png', clear(img[i]), cmap='gray')\n            count += 1\n\n        x = pc_fouriercs(score_model, scaler(under_img), measurement=under_kspace)\n\n        count = 0\n        for i, recon_img in enumerate(x):\n            plt.imsave(save_root / 'input' / f'{count}.png', clear(under_img[i]), cmap='gray')\n            plt.imsave(save_root / 'label' / f'{count}.png', clear(img[i]), cmap='gray')\n            plt.imsave(save_root / 'recon' / f'{count}.png', clear(recon_img), cmap='gray')\n   
         np.save(str(save_root / 'input' / f'{count}.npy'), clear(under_img[i], normalize=False))\n            np.save(str(save_root / 'recon' / f'{count}.npy'), clear(x[i], normalize=False))\n            np.save(str(save_root / 'label' / f'{count}.npy'), clear(img[i], normalize=False))\n            count += 1\n\n"
  },
  {
    "path": "likelihood.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n# pytype: skip-file\n\"\"\"Various sampling methods.\"\"\"\n\nimport torch\nimport numpy as np\nfrom scipy import integrate\nfrom models import utils as mutils\n\n\ndef get_div_fn(fn):\n  \"\"\"Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator.\"\"\"\n\n  def div_fn(x, t, eps):\n    with torch.enable_grad():\n      x.requires_grad_(True)\n      fn_eps = torch.sum(fn(x, t) * eps)\n      grad_fn_eps = torch.autograd.grad(fn_eps, x)[0]\n    x.requires_grad_(False)\n    return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape))))\n\n  return div_fn\n\n\ndef get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher',\n                      rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5):\n  \"\"\"Create a function to compute the unbiased log-likelihood estimate of a given data point.\n\n  Args:\n    sde: A `sde_lib.SDE` object that represents the forward SDE.\n    inverse_scaler: The inverse data normalizer.\n    hutchinson_type: \"Rademacher\" or \"Gaussian\". The type of noise for Hutchinson-Skilling trace estimator.\n    rtol: A `float` number. The relative tolerance level of the black-box ODE solver.\n    atol: A `float` number. The absolute tolerance level of the black-box ODE solver.\n    method: A `str`. 
The algorithm for the black-box ODE solver.\n      See documentation for `scipy.integrate.solve_ivp`.\n    eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability.\n\n  Returns:\n    A function that takes a batch of data points and returns the log-likelihoods in bits/dim,\n      the latent code, and the number of function evaluations cost by computation.\n  \"\"\"\n\n  def drift_fn(model, x, t):\n    \"\"\"The drift function of the reverse-time SDE.\"\"\"\n    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True)\n    # Probability flow ODE is a special case of Reverse SDE\n    rsde = sde.reverse(score_fn, probability_flow=True)\n    return rsde.sde(x, t)[0]\n\n  def div_fn(model, x, t, noise):\n    return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise)\n\n  def likelihood_fn(model, data):\n    \"\"\"Compute an unbiased estimate to the log-likelihood in bits/dim.\n\n    Args:\n      model: A score model.\n      data: A PyTorch tensor.\n\n    Returns:\n      bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim.\n      z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the\n        probability flow ODE.\n      nfe: An integer. 
The number of function evaluations used for running the black-box ODE solver.\n    \"\"\"\n    with torch.no_grad():\n      shape = data.shape\n      if hutchinson_type == 'Gaussian':\n        epsilon = torch.randn_like(data)\n      elif hutchinson_type == 'Rademacher':\n        epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1.\n      else:\n        raise NotImplementedError(f\"Hutchinson type {hutchinson_type} unknown.\")\n\n      def ode_func(t, x):\n        sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32)\n        vec_t = torch.ones(sample.shape[0], device=sample.device) * t\n        drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t))\n        logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon))\n        return np.concatenate([drift, logp_grad], axis=0)\n\n      init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0)\n      solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)\n      nfe = solution.nfev\n      zp = solution.y[:, -1]\n      z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32)\n      delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32)\n      prior_logp = sde.prior_logp(z)\n      bpd = -(prior_logp + delta_logp) / np.log(2)\n      N = np.prod(shape[1:])\n      bpd = bpd / N\n      # A hack to convert log-likelihoods to bits/dim\n      offset = 7. - inverse_scaler(-1.)\n      bpd = bpd + offset\n      return bpd, z, nfe\n\n  return likelihood_fn\n"
  },
  {
    "path": "losses.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"All functions related to loss computation and optimization.\n\"\"\"\n\nimport torch\nimport torch.optim as optim\nimport numpy as np\nfrom models import utils as mutils\nfrom sde_lib import VESDE, VPSDE\nfrom utils import fft2, ifft2, get_mask\nimport numpy as np\n\n\ndef get_optimizer(config, params):\n  \"\"\"Returns a flax optimizer object based on `config`.\"\"\"\n  if config.optim.optimizer == 'Adam':\n    optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,\n                           weight_decay=config.optim.weight_decay)\n  else:\n    raise NotImplementedError(\n      f'Optimizer {config.optim.optimizer} not supported yet!')\n\n  return optimizer\n\n\ndef optimization_manager(config):\n  \"\"\"Returns an optimize_fn based on `config`.\"\"\"\n\n  def optimize_fn(optimizer, params, step, lr=config.optim.lr,\n                  warmup=config.optim.warmup,\n                  grad_clip=config.optim.grad_clip):\n    \"\"\"Optimizes with warmup and gradient clipping (disabled if negative).\"\"\"\n    if warmup > 0:\n      for g in optimizer.param_groups:\n        g['lr'] = lr * np.minimum(step / warmup, 1.0)\n    if grad_clip >= 0:\n      torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)\n    optimizer.step()\n\n  return optimize_fn\n\n\ndef get_sde_loss_fn(sde, 
train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5):\n  \"\"\"Create a loss function for training with arbitrary SDEs.\n\n  Args:\n    sde: An `sde_lib.SDE` object that represents the forward SDE.\n    train: `True` for training loss and `False` for evaluation loss.\n    reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.\n    continuous: `True` indicates that the model is defined to take continuous time steps. Otherwise it requires\n      ad-hoc interpolation to take continuous time steps.\n    likelihood_weighting: If `True`, weight the mixture of score matching losses\n      according to https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended in our paper.\n    eps: A `float` number. The smallest time step to sample from.\n\n  Returns:\n    A loss function.\n  \"\"\"\n  reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)\n\n  def loss_fn(model, batch):\n    \"\"\"Compute the loss function.\n    Args:\n      model: A score model.\n      batch: A mini-batch of training data.\n\n    Returns:\n      loss: A scalar that represents the average loss value across the mini-batch.\n    \"\"\"\n    score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)\n    t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps\n    z = torch.randn_like(batch)\n    mean, std = sde.marginal_prob(batch, t)\n    perturbed_data = mean + std[:, None, None, None] * z\n    score = score_fn(perturbed_data, t)\n\n    if not likelihood_weighting:\n      losses = torch.square(score * std[:, None, None, None] + z)\n      losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)\n    else:\n      g2 = sde.sde(torch.zeros_like(batch), t)[1] ** 2\n      losses = torch.square(score + z / std[:, None, None, None])\n      losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * g2\n\n    loss = 
torch.mean(losses)\n    return loss\n\n  return loss_fn\n\n\ndef get_smld_loss_fn(vesde, train, reduce_mean=False):\n  \"\"\"Legacy code to reproduce previous results on SMLD(NCSN). Not recommended for new work.\"\"\"\n  assert isinstance(vesde, VESDE), \"SMLD training only works for VESDEs.\"\n\n  # Previous SMLD models assume descending sigmas\n  smld_sigma_array = torch.flip(vesde.discrete_sigmas, dims=(0,))\n  reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)\n\n  def loss_fn(model, batch):\n    model_fn = mutils.get_model_fn(model, train=train)\n    labels = torch.randint(0, vesde.N, (batch.shape[0],), device=batch.device)\n    sigmas = smld_sigma_array.to(batch.device)[labels]\n    noise = torch.randn_like(batch) * sigmas[:, None, None, None]\n    perturbed_data = noise + batch\n    score = model_fn(perturbed_data, labels)\n    target = -noise / (sigmas ** 2)[:, None, None, None]\n    losses = torch.square(score - target)\n    losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1) * sigmas ** 2\n    loss = torch.mean(losses)\n    return loss\n\n  return loss_fn\n\n\ndef get_ddpm_loss_fn(vpsde, train, reduce_mean=True):\n  \"\"\"Legacy code to reproduce previous results on DDPM. 
Not recommended for new work.\"\"\"\n  assert isinstance(vpsde, VPSDE), \"DDPM training only works for VPSDEs.\"\n\n  reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)\n\n  def loss_fn(model, batch):\n    model_fn = mutils.get_model_fn(model, train=train)\n    labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device)\n    sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device)\n    sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device)\n    noise = torch.randn_like(batch)\n    perturbed_data = sqrt_alphas_cumprod[labels, None, None, None] * batch + \\\n                     sqrt_1m_alphas_cumprod[labels, None, None, None] * noise\n    score = model_fn(perturbed_data, labels)\n    losses = torch.square(score - noise)\n    losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)\n    loss = torch.mean(losses)\n    return loss\n\n  return loss_fn\n\n\ndef get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False):\n  \"\"\"Create a one-step training/evaluation function.\n\n  Args:\n    sde: An `sde_lib.SDE` object that represents the forward SDE.\n    optimize_fn: An optimization function.\n    reduce_mean: If `True`, average the loss across data dimensions. 
Otherwise sum the loss across data dimensions.\n    continuous: `True` indicates that the model is defined to take continuous time steps.\n    likelihood_weighting: If `True`, weight the mixture of score matching losses according to\n      https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper.\n\n  Returns:\n    A one-step function for training or evaluation.\n  \"\"\"\n  if continuous:\n    loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean,\n                              continuous=True, likelihood_weighting=likelihood_weighting)\n  else:\n    assert not likelihood_weighting, \"Likelihood weighting is not supported for original SMLD/DDPM training.\"\n    if isinstance(sde, VESDE):\n      loss_fn = get_smld_loss_fn(sde, train, reduce_mean=reduce_mean)\n    elif isinstance(sde, VPSDE):\n      loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean)\n    else:\n      raise ValueError(f\"Discrete training for {sde.__class__.__name__} is not recommended.\")\n\n  def step_fn(state, batch):\n    \"\"\"Running one step of training or evaluation.\n\n    This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together\n    for faster execution.\n\n    Args:\n      state: A dictionary of training information, containing the score model, optimizer,\n       EMA status, and number of optimization steps.\n      batch: A mini-batch of training/evaluation data.\n\n    Returns:\n      loss: The average loss value of this state.\n    \"\"\"\n    model = state['model']\n    if train:\n      optimizer = state['optimizer']\n      optimizer.zero_grad()\n      loss = loss_fn(model, batch)\n      loss.backward()\n      optimize_fn(optimizer, model.parameters(), step=state['step'])\n      state['step'] += 1\n      state['ema'].update(model.parameters())\n    else:\n      with torch.no_grad():\n        ema = state['ema']\n        ema.store(model.parameters())\n        
ema.copy_to(model.parameters())\n        loss = loss_fn(model, batch)\n        ema.restore(model.parameters())\n\n    return loss\n\n  return step_fn\n\n\n\ndef get_step_fn_regression(train, config, mask=None, loss_fn=None, optimize_fn=None):\n\n  def step_fn(state, batch):\n    model = state['model']\n    if train:\n      optimizer = state['optimizer']\n      optimizer.zero_grad()\n\n      # fft\n      kspace = fft2(batch)\n\n      # sample mask\n      acc_factor = np.random.choice(config.training.acc_factor)\n      mask = get_mask(batch, config.data.image_size, config.training.batch_size,\n                      type=config.training.mask_type,\n                      acc_factor=acc_factor,\n                      fix=True)\n\n      # undersampling\n      under_kspace = kspace * mask\n      under_img = torch.abs(ifft2(under_kspace))\n\n      est_img = model(under_img)\n      loss = loss_fn(est_img, batch)\n      loss.backward()\n      optimize_fn(optimizer, model.parameters(), step=state['step'])\n      state['step'] += 1\n      state['ema'].update(model.parameters())\n      return loss\n    else:\n      with torch.no_grad():\n        ema = state['ema']\n        ema.store(model.parameters())\n        ema.copy_to(model.parameters())\n        # fft\n        kspace = fft2(batch)\n\n        # sample mask\n        mask = get_mask(batch, config.data.image_size, config.training.batch_size,\n                        type=config.training.mask_type,\n                        acc_factor=config.training.acc_factor)\n\n        # undersampling\n        under_kspace = kspace * mask\n        under_img = torch.real(ifft2(under_kspace))\n\n        est_img = model(under_img)\n        ema.restore(model.parameters())\n        return est_img\n  return step_fn\n"
  },
  {
    "path": "main.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Training and evaluation\"\"\"\nimport os\nfrom pathlib import Path\n\nfrom absl import app\nfrom absl import flags\nfrom ml_collections.config_flags import config_flags\nimport logging\n\nimport run_lib\n\nFLAGS = flags.FLAGS\n\nconfig_flags.DEFINE_config_file(\n  \"config\", None, \"Training configuration.\", lock_config=True)\nflags.DEFINE_string(\"workdir\", None, \"Work directory.\")\nflags.DEFINE_enum(\"mode\", None, [\"train\", \"train_regression\", \"eval\"], \"Running mode: train, train_regression, or eval\")\nflags.DEFINE_string(\"eval_folder\", \"eval\",\n                    \"The folder name for storing evaluation results\")\nflags.mark_flags_as_required([\"workdir\", \"config\", \"mode\"])\n\n\ndef main(argv):\n  print(FLAGS.config)\n  if FLAGS.mode == \"train\" or FLAGS.mode == \"train_regression\":\n    # Create the working directory\n    Path(FLAGS.workdir).mkdir(parents=True, exist_ok=True)\n    # Set logger so that it outputs to both console and file\n    # Make logging work for both disk and Google Cloud Storage\n    gfile_stream = open(os.path.join(FLAGS.workdir, 'stdout.txt'), 'w')\n    handler = logging.StreamHandler(gfile_stream)\n    formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')\n    handler.setFormatter(formatter)\n    logger = logging.getLogger()\n    
logger.addHandler(handler)\n    logger.setLevel('INFO')\n    # Run the training pipeline\n    if FLAGS.mode == \"train\":\n      run_lib.train(FLAGS.config, FLAGS.workdir)\n  elif FLAGS.mode == \"eval\":\n    # Run the evaluation pipeline\n    run_lib.evaluate(FLAGS.config, FLAGS.workdir, FLAGS.eval_folder)\n  else:\n    raise ValueError(f\"Mode {FLAGS.mode} not recognized.\")\n\n\nif __name__ == \"__main__\":\n  app.run(main)\n"
  },
  {
    "path": "models/__init__.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "models/ddpm.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n\"\"\"DDPM model.\n\nThis code is the pytorch equivalent of:\nhttps://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport functools\n\nfrom . import utils, layers, normalization\n\nRefineBlock = layers.RefineBlock\nResidualBlock = layers.ResidualBlock\nResnetBlockDDPM = layers.ResnetBlockDDPM\nUpsample = layers.Upsample\nDownsample = layers.Downsample\nconv3x3 = layers.ddpm_conv3x3\nget_act = layers.get_act\nget_normalization = normalization.get_normalization\ndefault_initializer = layers.default_init\n\n\n@utils.register_model(name='ddpm')\nclass DDPM(nn.Module):\n  def __init__(self, config):\n    super().__init__()\n    self.act = act = get_act(config)\n    self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))\n\n    self.nf = nf = config.model.nf\n    ch_mult = config.model.ch_mult\n    self.num_res_blocks = num_res_blocks = config.model.num_res_blocks\n    self.attn_resolutions = attn_resolutions = config.model.attn_resolutions\n    dropout = config.model.dropout\n    resamp_with_conv = config.model.resamp_with_conv\n    self.num_resolutions = num_resolutions = len(ch_mult)\n    self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)]\n\n    AttnBlock = 
functools.partial(layers.AttnBlock)\n    self.conditional = conditional = config.model.conditional\n    ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout)\n    if conditional:\n      # Condition on noise levels.\n      modules = [nn.Linear(nf, nf * 4)]\n      modules[0].weight.data = default_initializer()(modules[0].weight.data.shape)\n      nn.init.zeros_(modules[0].bias)\n      modules.append(nn.Linear(nf * 4, nf * 4))\n      modules[1].weight.data = default_initializer()(modules[1].weight.data.shape)\n      nn.init.zeros_(modules[1].bias)\n\n    self.centered = config.data.centered\n    channels = config.data.num_channels\n\n    # Downsampling block\n    modules.append(conv3x3(channels, nf))\n    hs_c = [nf]\n    in_ch = nf\n    for i_level in range(num_resolutions):\n      # Residual blocks for this resolution\n      for i_block in range(num_res_blocks):\n        out_ch = nf * ch_mult[i_level]\n        modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))\n        in_ch = out_ch\n        if all_resolutions[i_level] in attn_resolutions:\n          modules.append(AttnBlock(channels=in_ch))\n        hs_c.append(in_ch)\n      if i_level != num_resolutions - 1:\n        modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))\n        hs_c.append(in_ch)\n\n    in_ch = hs_c[-1]\n    modules.append(ResnetBlock(in_ch=in_ch))\n    modules.append(AttnBlock(channels=in_ch))\n    modules.append(ResnetBlock(in_ch=in_ch))\n\n    # Upsampling block\n    for i_level in reversed(range(num_resolutions)):\n      for i_block in range(num_res_blocks + 1):\n        out_ch = nf * ch_mult[i_level]\n        modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))\n        in_ch = out_ch\n      if all_resolutions[i_level] in attn_resolutions:\n        modules.append(AttnBlock(channels=in_ch))\n      if i_level != 0:\n        modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))\n\n    assert not hs_c\n   
 modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6))\n    modules.append(conv3x3(in_ch, channels, init_scale=0.))\n    self.all_modules = nn.ModuleList(modules)\n\n    self.scale_by_sigma = config.model.scale_by_sigma\n\n  def forward(self, x, labels):\n    modules = self.all_modules\n    m_idx = 0\n    if self.conditional:\n      # timestep/scale embedding\n      timesteps = labels\n      temb = layers.get_timestep_embedding(timesteps, self.nf)\n      temb = modules[m_idx](temb)\n      m_idx += 1\n      temb = modules[m_idx](self.act(temb))\n      m_idx += 1\n    else:\n      temb = None\n\n    if self.centered:\n      # Input is in [-1, 1]\n      h = x\n    else:\n      # Input is in [0, 1]\n      h = 2 * x - 1.\n\n    # Downsampling block\n    hs = [modules[m_idx](h)]\n    m_idx += 1\n    for i_level in range(self.num_resolutions):\n      # Residual blocks for this resolution\n      for i_block in range(self.num_res_blocks):\n        h = modules[m_idx](hs[-1], temb)\n        m_idx += 1\n        if h.shape[-1] in self.attn_resolutions:\n          h = modules[m_idx](h)\n          m_idx += 1\n        hs.append(h)\n      if i_level != self.num_resolutions - 1:\n        hs.append(modules[m_idx](hs[-1]))\n        m_idx += 1\n\n    h = hs[-1]\n    h = modules[m_idx](h, temb)\n    m_idx += 1\n    h = modules[m_idx](h)\n    m_idx += 1\n    h = modules[m_idx](h, temb)\n    m_idx += 1\n\n    # Upsampling block\n    for i_level in reversed(range(self.num_resolutions)):\n      for i_block in range(self.num_res_blocks + 1):\n        h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)\n        m_idx += 1\n      if h.shape[-1] in self.attn_resolutions:\n        h = modules[m_idx](h)\n        m_idx += 1\n      if i_level != 0:\n        h = modules[m_idx](h)\n        m_idx += 1\n\n    assert not hs\n    h = self.act(modules[m_idx](h))\n    m_idx += 1\n    h = modules[m_idx](h)\n    m_idx += 1\n    assert m_idx == len(modules)\n\n    if 
self.scale_by_sigma:\n      # Divide the output by sigmas. Useful for training with the NCSN loss.\n      # The DDPM loss scales the network output by sigma in the loss function,\n      # so no need of doing it here.\n      used_sigmas = self.sigmas[labels, None, None, None]\n      h = h / used_sigmas\n\n    return h\n"
  },
  {
    "path": "models/ema.py",
    "content": "# Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py\n\nfrom __future__ import division\nfrom __future__ import unicode_literals\n\nimport torch\n\n\n# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py\nclass ExponentialMovingAverage:\n  \"\"\"\n  Maintains (exponential) moving average of a set of parameters.\n  \"\"\"\n\n  def __init__(self, parameters, decay, use_num_updates=True):\n    \"\"\"\n    Args:\n      parameters: Iterable of `torch.nn.Parameter`; usually the result of\n        `model.parameters()`.\n      decay: The exponential decay.\n      use_num_updates: Whether to use number of updates when computing\n        averages.\n    \"\"\"\n    if decay < 0.0 or decay > 1.0:\n      raise ValueError('Decay must be between 0 and 1')\n    self.decay = decay\n    self.num_updates = 0 if use_num_updates else None\n    self.shadow_params = [p.clone().detach()\n                          for p in parameters if p.requires_grad]\n    self.collected_params = []\n\n  def update(self, parameters):\n    \"\"\"\n    Update currently maintained parameters.\n\n    Call this every time the parameters are updated, such as the result of\n    the `optimizer.step()` call.\n\n    Args:\n      parameters: Iterable of `torch.nn.Parameter`; usually the same set of\n        parameters used to initialize this object.\n    \"\"\"\n    decay = self.decay\n    if self.num_updates is not None:\n      self.num_updates += 1\n      decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))\n    one_minus_decay = 1.0 - decay\n    with torch.no_grad():\n      parameters = [p for p in parameters if p.requires_grad]\n      for s_param, param in zip(self.shadow_params, parameters):\n        s_param.sub_(one_minus_decay * (s_param - param))\n\n  def copy_to(self, parameters):\n    \"\"\"\n    Copy current parameters into given collection of parameters.\n\n    Args:\n 
     parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n        updated with the stored moving averages.\n    \"\"\"\n    parameters = [p for p in parameters if p.requires_grad]\n    for s_param, param in zip(self.shadow_params, parameters):\n      if param.requires_grad:\n        param.data.copy_(s_param.data)\n\n  def store(self, parameters):\n    \"\"\"\n    Save the current parameters for restoring later.\n\n    Args:\n      parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n        temporarily stored.\n    \"\"\"\n    self.collected_params = [param.clone() for param in parameters]\n\n  def restore(self, parameters):\n    \"\"\"\n    Restore the parameters stored with the `store` method.\n    Useful to validate the model with EMA parameters without affecting the\n    original optimization process. Store the parameters before the\n    `copy_to` method. After validation (or model saving), use this to\n    restore the former parameters.\n\n    Args:\n      parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n        updated with the stored parameters.\n    \"\"\"\n    for c_param, param in zip(self.collected_params, parameters):\n      param.data.copy_(c_param.data)\n\n  def state_dict(self):\n    return dict(decay=self.decay, num_updates=self.num_updates,\n                shadow_params=self.shadow_params)\n\n  def load_state_dict(self, state_dict):\n    self.decay = state_dict['decay']\n    self.num_updates = state_dict['num_updates']\n    self.shadow_params = state_dict['shadow_params']"
  },
  {
    "path": "models/layers.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n\"\"\"Common layers for defining score networks.\n\"\"\"\nimport math\nimport string\nfrom functools import partial\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom .normalization import ConditionalInstanceNorm2dPlus\n\n\nclass SiLU(nn.Module):\n  def forward(self, x):\n    return x * torch.sigmoid(x)\n\ndef get_act(config):\n  \"\"\"Get activation functions from the config file.\"\"\"\n\n  if config.model.nonlinearity.lower() == 'elu':\n    return nn.ELU()\n  elif config.model.nonlinearity.lower() == 'relu':\n    return nn.ReLU()\n  elif config.model.nonlinearity.lower() == 'lrelu':\n    return nn.LeakyReLU(negative_slope=0.2)\n  elif config.model.nonlinearity.lower() == 'swish':\n    return nn.SiLU()\n  else:\n    raise NotImplementedError('activation function does not exist!')\n\n\ndef ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):\n  \"\"\"1x1 convolution. 
Same as NCSNv1/v2.\"\"\"\n  conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,\n                   padding=padding)\n  init_scale = 1e-10 if init_scale == 0 else init_scale\n  conv.weight.data *= init_scale\n  conv.bias.data *= init_scale\n  return conv\n\n\ndef variance_scaling(scale, mode, distribution,\n                     in_axis=1, out_axis=0,\n                     dtype=torch.float32,\n                     device='cpu'):\n  \"\"\"Ported from JAX. \"\"\"\n\n  def _compute_fans(shape, in_axis=1, out_axis=0):\n    receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]\n    fan_in = shape[in_axis] * receptive_field_size\n    fan_out = shape[out_axis] * receptive_field_size\n    return fan_in, fan_out\n\n  def init(shape, dtype=dtype, device=device):\n    fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)\n    if mode == \"fan_in\":\n      denominator = fan_in\n    elif mode == \"fan_out\":\n      denominator = fan_out\n    elif mode == \"fan_avg\":\n      denominator = (fan_in + fan_out) / 2\n    else:\n      raise ValueError(\n        \"invalid mode for variance scaling initializer: {}\".format(mode))\n    variance = scale / denominator\n    if distribution == \"normal\":\n      return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)\n    elif distribution == \"uniform\":\n      return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) 
* np.sqrt(3 * variance)\n    else:\n      raise ValueError(\"invalid distribution for variance scaling initializer\")\n\n  return init\n\n\ndef default_init(scale=1.):\n  \"\"\"The same initialization used in DDPM.\"\"\"\n  scale = 1e-10 if scale == 0 else scale\n  return variance_scaling(scale, 'fan_avg', 'uniform')\n\n\nclass Dense(nn.Module):\n  \"\"\"Linear layer with `default_init`.\"\"\"\n  def __init__(self):\n    super().__init__()\n\n\ndef ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):\n  \"\"\"1x1 convolution with DDPM initialization.\"\"\"\n  conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)\n  conv.weight.data = default_init(init_scale)(conv.weight.data.shape)\n  nn.init.zeros_(conv.bias)\n  return conv\n\n\ndef ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):\n  \"\"\"3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2.\"\"\"\n  init_scale = 1e-10 if init_scale == 0 else init_scale\n  conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,\n                   dilation=dilation, padding=padding, kernel_size=3)\n  conv.weight.data *= init_scale\n  conv.bias.data *= init_scale\n  return conv\n\n\ndef ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):\n  \"\"\"3x3 convolution with DDPM initialization.\"\"\"\n  conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,\n                   dilation=dilation, bias=bias)\n  conv.weight.data = default_init(init_scale)(conv.weight.data.shape)\n  nn.init.zeros_(conv.bias)\n  return conv\n\n  ###########################################################################\n  # Functions below are ported over from the NCSNv1/NCSNv2 codebase:\n  # https://github.com/ermongroup/ncsn\n  # https://github.com/ermongroup/ncsnv2\n  
###########################################################################\n\n\nclass CRPBlock(nn.Module):\n  def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):\n    super().__init__()\n    self.convs = nn.ModuleList()\n    for i in range(n_stages):\n      self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))\n    self.n_stages = n_stages\n    if maxpool:\n      self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)\n    else:\n      self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)\n\n    self.act = act\n\n  def forward(self, x):\n    x = self.act(x)\n    path = x\n    for i in range(self.n_stages):\n      path = self.pool(path)\n      path = self.convs[i](path)\n      x = path + x\n    return x\n\n\nclass CondCRPBlock(nn.Module):\n  def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):\n    super().__init__()\n    self.convs = nn.ModuleList()\n    self.norms = nn.ModuleList()\n    self.normalizer = normalizer\n    for i in range(n_stages):\n      self.norms.append(normalizer(features, num_classes, bias=True))\n      self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))\n\n    self.n_stages = n_stages\n    self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)\n    self.act = act\n\n  def forward(self, x, y):\n    x = self.act(x)\n    path = x\n    for i in range(self.n_stages):\n      path = self.norms[i](path, y)\n      path = self.pool(path)\n      path = self.convs[i](path)\n\n      x = path + x\n    return x\n\n\nclass RCUBlock(nn.Module):\n  def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):\n    super().__init__()\n\n    for i in range(n_blocks):\n      for j in range(n_stages):\n        setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))\n\n    self.stride = 1\n    self.n_blocks = n_blocks\n    self.n_stages = n_stages\n    self.act = act\n\n  def forward(self, x):\n    for i in 
range(self.n_blocks):\n      residual = x\n      for j in range(self.n_stages):\n        x = self.act(x)\n        x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)\n\n      x += residual\n    return x\n\n\nclass CondRCUBlock(nn.Module):\n  def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):\n    super().__init__()\n\n    for i in range(n_blocks):\n      for j in range(n_stages):\n        setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))\n        setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))\n\n    self.stride = 1\n    self.n_blocks = n_blocks\n    self.n_stages = n_stages\n    self.act = act\n    self.normalizer = normalizer\n\n  def forward(self, x, y):\n    for i in range(self.n_blocks):\n      residual = x\n      for j in range(self.n_stages):\n        x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)\n        x = self.act(x)\n        x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)\n\n      x += residual\n    return x\n\n\nclass MSFBlock(nn.Module):\n  def __init__(self, in_planes, features):\n    super().__init__()\n    assert isinstance(in_planes, list) or isinstance(in_planes, tuple)\n    self.convs = nn.ModuleList()\n    self.features = features\n\n    for i in range(len(in_planes)):\n      self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))\n\n  def forward(self, xs, shape):\n    sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)\n    for i in range(len(self.convs)):\n      h = self.convs[i](xs[i])\n      h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)\n      sums += h\n    return sums\n\n\nclass CondMSFBlock(nn.Module):\n  def __init__(self, in_planes, features, num_classes, normalizer):\n    super().__init__()\n    assert isinstance(in_planes, list) or isinstance(in_planes, tuple)\n\n    self.convs = nn.ModuleList()\n   
 self.norms = nn.ModuleList()\n    self.features = features\n    self.normalizer = normalizer\n\n    for i in range(len(in_planes)):\n      self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))\n      self.norms.append(normalizer(in_planes[i], num_classes, bias=True))\n\n  def forward(self, xs, y, shape):\n    sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)\n    for i in range(len(self.convs)):\n      h = self.norms[i](xs[i], y)\n      h = self.convs[i](h)\n      h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)\n      sums += h\n    return sums\n\n\nclass RefineBlock(nn.Module):\n  def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):\n    super().__init__()\n\n    assert isinstance(in_planes, tuple) or isinstance(in_planes, list)\n    self.n_blocks = n_blocks = len(in_planes)\n\n    self.adapt_convs = nn.ModuleList()\n    for i in range(n_blocks):\n      self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))\n\n    self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)\n\n    if not start:\n      self.msf = MSFBlock(in_planes, features)\n\n    self.crp = CRPBlock(features, 2, act, maxpool=maxpool)\n\n  def forward(self, xs, output_shape):\n    assert isinstance(xs, tuple) or isinstance(xs, list)\n    hs = []\n    for i in range(len(xs)):\n      h = self.adapt_convs[i](xs[i])\n      hs.append(h)\n\n    if self.n_blocks > 1:\n      h = self.msf(hs, output_shape)\n    else:\n      h = hs[0]\n\n    h = self.crp(h)\n    h = self.output_convs(h)\n\n    return h\n\n\nclass CondRefineBlock(nn.Module):\n  def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):\n    super().__init__()\n\n    assert isinstance(in_planes, tuple) or isinstance(in_planes, list)\n    self.n_blocks = n_blocks = len(in_planes)\n\n    self.adapt_convs = nn.ModuleList()\n    for i in range(n_blocks):\n      
self.adapt_convs.append(\n        CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act)\n      )\n\n    self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act)\n\n    if not start:\n      self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)\n\n    self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)\n\n  def forward(self, xs, y, output_shape):\n    assert isinstance(xs, tuple) or isinstance(xs, list)\n    hs = []\n    for i in range(len(xs)):\n      h = self.adapt_convs[i](xs[i], y)\n      hs.append(h)\n\n    if self.n_blocks > 1:\n      h = self.msf(hs, y, output_shape)\n    else:\n      h = hs[0]\n\n    h = self.crp(h, y)\n    h = self.output_convs(h, y)\n\n    return h\n\n\nclass ConvMeanPool(nn.Module):\n  def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):\n    super().__init__()\n    if not adjust_padding:\n      conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)\n      self.conv = conv\n    else:\n      conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)\n\n      self.conv = nn.Sequential(\n        nn.ZeroPad2d((1, 0, 1, 0)),\n        conv\n      )\n\n  def forward(self, inputs):\n    output = self.conv(inputs)\n    output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],\n                  output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.\n    return output\n\n\nclass MeanPoolConv(nn.Module):\n  def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):\n    super().__init__()\n    self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)\n\n  def forward(self, inputs):\n    output = inputs\n    output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],\n                  output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.\n    return 
self.conv(output)\n\n\nclass UpsampleConv(nn.Module):\n  def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):\n    super().__init__()\n    self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)\n    self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)\n\n  def forward(self, inputs):\n    output = inputs\n    output = torch.cat([output, output, output, output], dim=1)\n    output = self.pixelshuffle(output)\n    return self.conv(output)\n\n\nclass ConditionalResidualBlock(nn.Module):\n  def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),\n               normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None):\n    super().__init__()\n    self.non_linearity = act\n    self.input_dim = input_dim\n    self.output_dim = output_dim\n    self.resample = resample\n    self.normalization = normalization\n    if resample == 'down':\n      if dilation > 1:\n        self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)\n        self.normalize2 = normalization(input_dim, num_classes)\n        self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)\n        conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)\n      else:\n        self.conv1 = ncsn_conv3x3(input_dim, input_dim)\n        self.normalize2 = normalization(input_dim, num_classes)\n        self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)\n        conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)\n\n    elif resample is None:\n      if dilation > 1:\n        conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)\n        self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)\n        self.normalize2 = normalization(output_dim, num_classes)\n        self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)\n      else:\n        conv_shortcut = nn.Conv2d\n        self.conv1 = 
ncsn_conv3x3(input_dim, output_dim)\n        self.normalize2 = normalization(output_dim, num_classes)\n        self.conv2 = ncsn_conv3x3(output_dim, output_dim)\n    else:\n      raise Exception('invalid resample value')\n\n    if output_dim != input_dim or resample is not None:\n      self.shortcut = conv_shortcut(input_dim, output_dim)\n\n    self.normalize1 = normalization(input_dim, num_classes)\n\n  def forward(self, x, y):\n    output = self.normalize1(x, y)\n    output = self.non_linearity(output)\n    output = self.conv1(output)\n    output = self.normalize2(output, y)\n    output = self.non_linearity(output)\n    output = self.conv2(output)\n\n    if self.output_dim == self.input_dim and self.resample is None:\n      shortcut = x\n    else:\n      shortcut = self.shortcut(x)\n\n    return shortcut + output\n\n\nclass ResidualBlock(nn.Module):\n  def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),\n               normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):\n    super().__init__()\n    self.non_linearity = act\n    self.input_dim = input_dim\n    self.output_dim = output_dim\n    self.resample = resample\n    self.normalization = normalization\n    if resample == 'down':\n      if dilation > 1:\n        self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)\n        self.normalize2 = normalization(input_dim)\n        self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)\n        conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)\n      else:\n        self.conv1 = ncsn_conv3x3(input_dim, input_dim)\n        self.normalize2 = normalization(input_dim)\n        self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)\n        conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)\n\n    elif resample is None:\n      if dilation > 1:\n        conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)\n        self.conv1 = 
ncsn_conv3x3(input_dim, output_dim, dilation=dilation)\n        self.normalize2 = normalization(output_dim)\n        self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)\n      else:\n        # conv_shortcut = nn.Conv2d ### Something wierd here.\n        conv_shortcut = partial(ncsn_conv1x1)\n        self.conv1 = ncsn_conv3x3(input_dim, output_dim)\n        self.normalize2 = normalization(output_dim)\n        self.conv2 = ncsn_conv3x3(output_dim, output_dim)\n    else:\n      raise Exception('invalid resample value')\n\n    if output_dim != input_dim or resample is not None:\n      self.shortcut = conv_shortcut(input_dim, output_dim)\n\n    self.normalize1 = normalization(input_dim)\n\n  def forward(self, x):\n    output = self.normalize1(x)\n    output = self.non_linearity(output)\n    output = self.conv1(output)\n    output = self.normalize2(output)\n    output = self.non_linearity(output)\n    output = self.conv2(output)\n\n    if self.output_dim == self.input_dim and self.resample is None:\n      shortcut = x\n    else:\n      shortcut = self.shortcut(x)\n\n    return shortcut + output\n\n\n###########################################################################\n# Functions below are ported over from the DDPM codebase:\n#  https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py\n###########################################################################\n\ndef get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):\n  assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32\n  half_dim = embedding_dim // 2\n  # magic number 10000 is from transformers\n  emb = math.log(max_positions) / (half_dim - 1)\n  emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)\n  emb = timesteps.float()[:, None] * emb[None, :]\n  emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n  if embedding_dim % 2 == 1:  # zero pad\n    emb = F.pad(emb, (0, 1), mode='constant')\n  assert 
emb.shape == (timesteps.shape[0], embedding_dim)\n  return emb\n\n\ndef _einsum(a, b, c, x, y):\n  einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))\n  return torch.einsum(einsum_str, x, y)\n\n\ndef contract_inner(x, y):\n  \"\"\"tensordot(x, y, 1).\"\"\"\n  x_chars = list(string.ascii_lowercase[:len(x.shape)])\n  y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)])\n  y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed\n  out_chars = x_chars[:-1] + y_chars[1:]\n  return _einsum(x_chars, y_chars, out_chars, x, y)\n\n\nclass NIN(nn.Module):\n  def __init__(self, in_dim, num_units, init_scale=0.1):\n    super().__init__()\n    self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)\n    self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)\n\n  def forward(self, x):\n    x = x.permute(0, 2, 3, 1)\n    y = contract_inner(x, self.W) + self.b\n    return y.permute(0, 3, 1, 2)\n\n\nclass AttnBlock(nn.Module):\n  \"\"\"Channel-wise self-attention block.\"\"\"\n  def __init__(self, channels):\n    super().__init__()\n    self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)\n    self.NIN_0 = NIN(channels, channels)\n    self.NIN_1 = NIN(channels, channels)\n    self.NIN_2 = NIN(channels, channels)\n    self.NIN_3 = NIN(channels, channels, init_scale=0.)\n\n  def forward(self, x):\n    B, C, H, W = x.shape\n    h = self.GroupNorm_0(x)\n    q = self.NIN_0(h)\n    k = self.NIN_1(h)\n    v = self.NIN_2(h)\n\n    w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))\n    w = torch.reshape(w, (B, H, W, H * W))\n    w = F.softmax(w, dim=-1)\n    w = torch.reshape(w, (B, H, W, H, W))\n    h = torch.einsum('bhwij,bcij->bchw', w, v)\n    h = self.NIN_3(h)\n    return x + h\n\n\nclass Upsample(nn.Module):\n  def __init__(self, channels, with_conv=False):\n    super().__init__()\n    if with_conv:\n      self.Conv_0 = 
ddpm_conv3x3(channels, channels)\n    self.with_conv = with_conv\n\n  def forward(self, x):\n    B, C, H, W = x.shape\n    h = F.interpolate(x, (H * 2, W * 2), mode='nearest')\n    if self.with_conv:\n      h = self.Conv_0(h)\n    return h\n\n\nclass Downsample(nn.Module):\n  def __init__(self, channels, with_conv=False):\n    super().__init__()\n    if with_conv:\n      self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)\n    self.with_conv = with_conv\n\n  def forward(self, x):\n    B, C, H, W = x.shape\n    # Emulate 'SAME' padding\n    if self.with_conv:\n      x = F.pad(x, (0, 1, 0, 1))\n      x = self.Conv_0(x)\n    else:\n      x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)\n\n    assert x.shape == (B, C, H // 2, W // 2)\n    return x\n\n\nclass ResnetBlockDDPM(nn.Module):\n  \"\"\"The ResNet Blocks used in DDPM.\"\"\"\n  def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):\n    super().__init__()\n    if out_ch is None:\n      out_ch = in_ch\n    self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)\n    self.act = act\n    self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)\n    if temb_dim is not None:\n      self.Dense_0 = nn.Linear(temb_dim, out_ch)\n      self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)\n      nn.init.zeros_(self.Dense_0.bias)\n\n    self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)\n    self.Dropout_0 = nn.Dropout(dropout)\n    self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)\n    if in_ch != out_ch:\n      if conv_shortcut:\n        self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)\n      else:\n        self.NIN_0 = NIN(in_ch, out_ch)\n    self.out_ch = out_ch\n    self.in_ch = in_ch\n    self.conv_shortcut = conv_shortcut\n\n  def forward(self, x, temb=None):\n    B, C, H, W = x.shape\n    assert C == self.in_ch\n    out_ch = self.out_ch if self.out_ch else self.in_ch\n    h = 
self.act(self.GroupNorm_0(x))\n    h = self.Conv_0(h)\n    # Add bias to each feature map conditioned on the time embedding\n    if temb is not None:\n      h += self.Dense_0(self.act(temb))[:, :, None, None]\n    h = self.act(self.GroupNorm_1(h))\n    h = self.Dropout_0(h)\n    h = self.Conv_1(h)\n    if C != out_ch:\n      if self.conv_shortcut:\n        x = self.Conv_2(x)\n      else:\n        x = self.NIN_0(x)\n    return x + h"
  },
  {
    "path": "models/layerspp.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n\"\"\"Layers for defining NCSN++.\n\"\"\"\nfrom . import layers\nfrom . import up_or_down_sampling\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nconv1x1 = layers.ddpm_conv1x1\nconv3x3 = layers.ddpm_conv3x3\nNIN = layers.NIN\ndefault_init = layers.default_init\n\n\nclass GaussianFourierProjection(nn.Module):\n  \"\"\"Gaussian Fourier embeddings for noise levels.\"\"\"\n\n  def __init__(self, embedding_size=256, scale=1.0):\n    super().__init__()\n    self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)\n\n  def forward(self, x):\n    x_proj = x[:, None] * self.W[None, :] * 2 * np.pi\n    return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)\n\n\nclass Combine(nn.Module):\n  \"\"\"Combine information from skip connections.\"\"\"\n\n  def __init__(self, dim1, dim2, method='cat'):\n    super().__init__()\n    self.Conv_0 = conv1x1(dim1, dim2)\n    self.method = method\n\n  def forward(self, x, y):\n    h = self.Conv_0(x)\n    if self.method == 'cat':\n      return torch.cat([h, y], dim=1)\n    elif self.method == 'sum':\n      return h + y\n    else:\n      raise ValueError(f'Method {self.method} not recognized.')\n\n\nclass AttnBlockpp(nn.Module):\n  \"\"\"Channel-wise self-attention block. 
Modified from DDPM.\"\"\"\n\n  def __init__(self, channels, skip_rescale=False, init_scale=0.):\n    super().__init__()\n    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,\n                                  eps=1e-6)\n    self.NIN_0 = NIN(channels, channels)\n    self.NIN_1 = NIN(channels, channels)\n    self.NIN_2 = NIN(channels, channels)\n    self.NIN_3 = NIN(channels, channels, init_scale=init_scale)\n    self.skip_rescale = skip_rescale\n\n  def forward(self, x):\n    B, C, H, W = x.shape\n    h = self.GroupNorm_0(x)\n    q = self.NIN_0(h)\n    k = self.NIN_1(h)\n    v = self.NIN_2(h)\n\n    w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))\n    w = torch.reshape(w, (B, H, W, H * W))\n    w = F.softmax(w, dim=-1)\n    w = torch.reshape(w, (B, H, W, H, W))\n    h = torch.einsum('bhwij,bcij->bchw', w, v)\n    h = self.NIN_3(h)\n    if not self.skip_rescale:\n      return x + h\n    else:\n      return (x + h) / np.sqrt(2.)\n\n\nclass Upsample(nn.Module):\n  def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,\n               fir_kernel=(1, 3, 3, 1)):\n    super().__init__()\n    out_ch = out_ch if out_ch else in_ch\n    if not fir:\n      if with_conv:\n        self.Conv_0 = conv3x3(in_ch, out_ch)\n    else:\n      if with_conv:\n        self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,\n                                                 kernel=3, up=True,\n                                                 resample_kernel=fir_kernel,\n                                                 use_bias=True,\n                                                 kernel_init=default_init())\n    self.fir = fir\n    self.with_conv = with_conv\n    self.fir_kernel = fir_kernel\n    self.out_ch = out_ch\n\n  def forward(self, x):\n    B, C, H, W = x.shape\n    if not self.fir:\n      h = F.interpolate(x, (H * 2, W * 2), 'nearest')\n      if self.with_conv:\n        h = self.Conv_0(h)\n    else:\n      
if not self.with_conv:\n        h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)\n      else:\n        h = self.Conv2d_0(x)\n\n    return h\n\n\nclass Downsample(nn.Module):\n  def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,\n               fir_kernel=(1, 3, 3, 1)):\n    super().__init__()\n    out_ch = out_ch if out_ch else in_ch\n    if not fir:\n      if with_conv:\n        self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)\n    else:\n      if with_conv:\n        self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,\n                                                 kernel=3, down=True,\n                                                 resample_kernel=fir_kernel,\n                                                 use_bias=True,\n                                                 kernel_init=default_init())\n    self.fir = fir\n    self.fir_kernel = fir_kernel\n    self.with_conv = with_conv\n    self.out_ch = out_ch\n\n  def forward(self, x):\n    B, C, H, W = x.shape\n    if not self.fir:\n      if self.with_conv:\n        x = F.pad(x, (0, 1, 0, 1))\n        x = self.Conv_0(x)\n      else:\n        x = F.avg_pool2d(x, 2, stride=2)\n    else:\n      if not self.with_conv:\n        x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)\n      else:\n        x = self.Conv2d_0(x)\n\n    return x\n\n\nclass ResnetBlockDDPMpp(nn.Module):\n  \"\"\"ResBlock adapted from DDPM.\"\"\"\n\n  def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,\n               dropout=0.1, skip_rescale=False, init_scale=0.):\n    super().__init__()\n    out_ch = out_ch if out_ch else in_ch\n    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)\n    self.Conv_0 = conv3x3(in_ch, out_ch)\n    if temb_dim is not None:\n      self.Dense_0 = nn.Linear(temb_dim, out_ch)\n      self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)\n      
nn.init.zeros_(self.Dense_0.bias)\n    self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)\n    self.Dropout_0 = nn.Dropout(dropout)\n    self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)\n    if in_ch != out_ch:\n      if conv_shortcut:\n        self.Conv_2 = conv3x3(in_ch, out_ch)\n      else:\n        self.NIN_0 = NIN(in_ch, out_ch)\n\n    self.skip_rescale = skip_rescale\n    self.act = act\n    self.out_ch = out_ch\n    self.conv_shortcut = conv_shortcut\n\n  def forward(self, x, temb=None):\n    h = self.act(self.GroupNorm_0(x))\n    h = self.Conv_0(h)\n    if temb is not None:\n      h += self.Dense_0(self.act(temb))[:, :, None, None]\n    h = self.act(self.GroupNorm_1(h))\n    h = self.Dropout_0(h)\n    h = self.Conv_1(h)\n    if x.shape[1] != self.out_ch:\n      if self.conv_shortcut:\n        x = self.Conv_2(x)\n      else:\n        x = self.NIN_0(x)\n    if not self.skip_rescale:\n      return x + h\n    else:\n      return (x + h) / np.sqrt(2.)\n\n\nclass ResnetBlockBigGANpp(nn.Module):\n  def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,\n               dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),\n               skip_rescale=True, init_scale=0.):\n    super().__init__()\n\n    out_ch = out_ch if out_ch else in_ch\n    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)\n    self.up = up\n    self.down = down\n    self.fir = fir\n    self.fir_kernel = fir_kernel\n\n    self.Conv_0 = conv3x3(in_ch, out_ch)\n    if temb_dim is not None:\n      self.Dense_0 = nn.Linear(temb_dim, out_ch)\n      self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)\n      nn.init.zeros_(self.Dense_0.bias)\n\n    self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)\n    self.Dropout_0 = nn.Dropout(dropout)\n    self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)\n    if in_ch != 
out_ch or up or down:\n      self.Conv_2 = conv1x1(in_ch, out_ch)\n\n    self.skip_rescale = skip_rescale\n    self.act = act\n    self.in_ch = in_ch\n    self.out_ch = out_ch\n\n  def forward(self, x, temb=None):\n    h = self.act(self.GroupNorm_0(x))\n\n    if self.up:\n      if self.fir:\n        h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)\n        x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)\n      else:\n        h = up_or_down_sampling.naive_upsample_2d(h, factor=2)\n        x = up_or_down_sampling.naive_upsample_2d(x, factor=2)\n    elif self.down:\n      if self.fir:\n        h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)\n        x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)\n      else:\n        h = up_or_down_sampling.naive_downsample_2d(h, factor=2)\n        x = up_or_down_sampling.naive_downsample_2d(x, factor=2)\n\n    h = self.Conv_0(h)\n    # Add bias to each feature map conditioned on the time embedding\n    if temb is not None:\n      h += self.Dense_0(self.act(temb))[:, :, None, None]\n    h = self.act(self.GroupNorm_1(h))\n    h = self.Dropout_0(h)\n    h = self.Conv_1(h)\n\n    if self.in_ch != self.out_ch or self.up or self.down:\n      x = self.Conv_2(x)\n\n    if not self.skip_rescale:\n      return x + h\n    else:\n      return (x + h) / np.sqrt(2.)\n"
  },
  {
    "path": "models/ncsnpp.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n\nfrom . import utils, layers, layerspp, normalization\nimport torch.nn as nn\nimport functools\nimport torch\nimport numpy as np\n\nResnetBlockDDPM = layerspp.ResnetBlockDDPMpp\nResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp\nCombine = layerspp.Combine\nconv3x3 = layerspp.conv3x3\nconv1x1 = layerspp.conv1x1\nget_act = layers.get_act\nget_normalization = normalization.get_normalization\ndefault_initializer = layers.default_init\n\n\n@utils.register_model(name='ncsnpp')\nclass NCSNpp(nn.Module):\n  \"\"\"NCSN++ model\"\"\"\n\n  def __init__(self, config):\n    super().__init__()\n    self.config = config\n    self.act = act = get_act(config)\n    self.register_buffer('sigmas', torch.tensor(utils.get_sigmas(config)))\n\n    self.nf = nf = config.model.nf\n    ch_mult = config.model.ch_mult\n    self.num_res_blocks = num_res_blocks = config.model.num_res_blocks\n    self.attn_resolutions = attn_resolutions = config.model.attn_resolutions\n    dropout = config.model.dropout\n    resamp_with_conv = config.model.resamp_with_conv\n    self.num_resolutions = num_resolutions = len(ch_mult)\n    self.all_resolutions = all_resolutions = [config.data.image_size // (2 ** i) for i in range(num_resolutions)]\n\n    self.conditional = conditional = config.model.conditional  # noise-conditional\n    fir = config.model.fir\n    
fir_kernel = config.model.fir_kernel\n    self.skip_rescale = skip_rescale = config.model.skip_rescale\n    self.resblock_type = resblock_type = config.model.resblock_type.lower()\n    self.progressive = progressive = config.model.progressive.lower()\n    self.progressive_input = progressive_input = config.model.progressive_input.lower()\n    self.embedding_type = embedding_type = config.model.embedding_type.lower()\n    init_scale = config.model.init_scale\n    assert progressive in ['none', 'output_skip', 'residual']\n    assert progressive_input in ['none', 'input_skip', 'residual']\n    assert embedding_type in ['fourier', 'positional']\n    combine_method = config.model.progressive_combine.lower()\n    combiner = functools.partial(Combine, method=combine_method)\n\n    modules = []\n    # timestep/noise_level embedding; only for continuous training\n    if embedding_type == 'fourier':\n      # Gaussian Fourier features embeddings.\n      assert config.training.continuous, \"Fourier features are only used for continuous training.\"\n\n      modules.append(layerspp.GaussianFourierProjection(\n        embedding_size=nf, scale=config.model.fourier_scale\n      ))\n      embed_dim = 2 * nf\n\n    elif embedding_type == 'positional':\n      embed_dim = nf\n\n    else:\n      raise ValueError(f'embedding type {embedding_type} unknown.')\n\n    if conditional:\n      modules.append(nn.Linear(embed_dim, nf * 4))\n      modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)\n      nn.init.zeros_(modules[-1].bias)\n      modules.append(nn.Linear(nf * 4, nf * 4))\n      modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)\n      nn.init.zeros_(modules[-1].bias)\n\n    AttnBlock = functools.partial(layerspp.AttnBlockpp,\n                                  init_scale=init_scale,\n                                  skip_rescale=skip_rescale)\n\n    Upsample = functools.partial(layerspp.Upsample,\n                                 
with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)\n\n    if progressive == 'output_skip':\n      self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)\n    elif progressive == 'residual':\n      pyramid_upsample = functools.partial(layerspp.Upsample,\n                                           fir=fir, fir_kernel=fir_kernel, with_conv=True)\n\n    Downsample = functools.partial(layerspp.Downsample,\n                                   with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)\n\n    if progressive_input == 'input_skip':\n      self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)\n    elif progressive_input == 'residual':\n      pyramid_downsample = functools.partial(layerspp.Downsample,\n                                             fir=fir, fir_kernel=fir_kernel, with_conv=True)\n\n    if resblock_type == 'ddpm':\n      ResnetBlock = functools.partial(ResnetBlockDDPM,\n                                      act=act,\n                                      dropout=dropout,\n                                      init_scale=init_scale,\n                                      skip_rescale=skip_rescale,\n                                      temb_dim=nf * 4)\n\n    elif resblock_type == 'biggan':\n      ResnetBlock = functools.partial(ResnetBlockBigGAN,\n                                      act=act,\n                                      dropout=dropout,\n                                      fir=fir,\n                                      fir_kernel=fir_kernel,\n                                      init_scale=init_scale,\n                                      skip_rescale=skip_rescale,\n                                      temb_dim=nf * 4)\n\n    else:\n      raise ValueError(f'resblock type {resblock_type} unrecognized.')\n\n    # Downsampling block\n\n    channels = config.data.num_channels\n    if progressive_input != 'none':\n      input_pyramid_ch = 
channels\n\n    modules.append(conv3x3(channels, nf))\n    hs_c = [nf]\n\n    in_ch = nf\n    for i_level in range(num_resolutions):\n      # Residual blocks for this resolution\n      for i_block in range(num_res_blocks):\n        out_ch = nf * ch_mult[i_level]\n        modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))\n        in_ch = out_ch\n\n        if all_resolutions[i_level] in attn_resolutions:\n          modules.append(AttnBlock(channels=in_ch))\n        hs_c.append(in_ch)\n\n      if i_level != num_resolutions - 1:\n        if resblock_type == 'ddpm':\n          modules.append(Downsample(in_ch=in_ch))\n        else:\n          modules.append(ResnetBlock(down=True, in_ch=in_ch))\n\n        if progressive_input == 'input_skip':\n          modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))\n          if combine_method == 'cat':\n            in_ch *= 2\n\n        elif progressive_input == 'residual':\n          modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))\n          input_pyramid_ch = in_ch\n\n        hs_c.append(in_ch)\n\n    in_ch = hs_c[-1]\n    modules.append(ResnetBlock(in_ch=in_ch))\n    modules.append(AttnBlock(channels=in_ch))\n    modules.append(ResnetBlock(in_ch=in_ch))\n\n    pyramid_ch = 0\n    # Upsampling block\n    for i_level in reversed(range(num_resolutions)):\n      for i_block in range(num_res_blocks + 1):\n        out_ch = nf * ch_mult[i_level]\n        modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(),\n                                   out_ch=out_ch))\n        in_ch = out_ch\n\n      if all_resolutions[i_level] in attn_resolutions:\n        modules.append(AttnBlock(channels=in_ch))\n\n      if progressive != 'none':\n        if i_level == num_resolutions - 1:\n          if progressive == 'output_skip':\n            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),\n                                        num_channels=in_ch, eps=1e-6))\n            modules.append(conv3x3(in_ch, 
channels, init_scale=init_scale))\n            pyramid_ch = channels\n          elif progressive == 'residual':\n            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),\n                                        num_channels=in_ch, eps=1e-6))\n            modules.append(conv3x3(in_ch, in_ch, bias=True))\n            pyramid_ch = in_ch\n          else:\n            raise ValueError(f'{progressive} is not a valid name.')\n        else:\n          if progressive == 'output_skip':\n            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),\n                                        num_channels=in_ch, eps=1e-6))\n            modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))\n            pyramid_ch = channels\n          elif progressive == 'residual':\n            modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))\n            pyramid_ch = in_ch\n          else:\n            raise ValueError(f'{progressive} is not a valid name')\n\n      if i_level != 0:\n        if resblock_type == 'ddpm':\n          modules.append(Upsample(in_ch=in_ch))\n        else:\n          modules.append(ResnetBlock(in_ch=in_ch, up=True))\n\n    assert not hs_c\n\n    if progressive != 'output_skip':\n      modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),\n                                  num_channels=in_ch, eps=1e-6))\n      modules.append(conv3x3(in_ch, channels, init_scale=init_scale))\n\n    self.all_modules = nn.ModuleList(modules)\n\n  def forward(self, x, time_cond):\n    # timestep/noise_level embedding; only for continuous training\n    modules = self.all_modules\n    m_idx = 0\n    if self.embedding_type == 'fourier':\n      # Gaussian Fourier features embeddings.\n      used_sigmas = time_cond\n      temb = modules[m_idx](torch.log(used_sigmas))\n      m_idx += 1\n\n    elif self.embedding_type == 'positional':\n      # Sinusoidal positional embeddings.\n      timesteps = time_cond\n      used_sigmas = 
self.sigmas[time_cond.long()]\n      temb = layers.get_timestep_embedding(timesteps, self.nf)\n\n    else:\n      raise ValueError(f'embedding type {self.embedding_type} unknown.')\n\n    if self.conditional:\n      temb = modules[m_idx](temb)\n      m_idx += 1\n      temb = modules[m_idx](self.act(temb))\n      m_idx += 1\n    else:\n      temb = None\n\n    if not self.config.data.centered:\n      # If input data is in [0, 1]\n      x = 2 * x - 1.\n\n    # Downsampling block\n    input_pyramid = None\n    if self.progressive_input != 'none':\n      input_pyramid = x\n\n    hs = [modules[m_idx](x)]\n    m_idx += 1\n    for i_level in range(self.num_resolutions):\n      # Residual blocks for this resolution\n      for i_block in range(self.num_res_blocks):\n        h = modules[m_idx](hs[-1], temb)\n        m_idx += 1\n        if h.shape[-1] in self.attn_resolutions:\n          h = modules[m_idx](h)\n          m_idx += 1\n\n        hs.append(h)\n        # debug\n        # print(f'lv/block : {i_level}/{i_block}    shape: {h.shape}')\n\n      if i_level != self.num_resolutions - 1:\n        if self.resblock_type == 'ddpm':\n          h = modules[m_idx](hs[-1])\n          m_idx += 1\n        else:\n          h = modules[m_idx](hs[-1], temb)\n          m_idx += 1\n\n        # debug\n        if self.progressive_input == 'input_skip':\n          input_pyramid = self.pyramid_downsample(input_pyramid)\n          h = modules[m_idx](input_pyramid, h)\n          m_idx += 1\n\n        elif self.progressive_input == 'residual':\n          input_pyramid = modules[m_idx](input_pyramid)\n          m_idx += 1\n          if self.skip_rescale:\n            input_pyramid = (input_pyramid + h) / np.sqrt(2.)\n          else:\n            input_pyramid = input_pyramid + h\n          h = input_pyramid\n\n        hs.append(h)\n\n    h = hs[-1]\n    h = modules[m_idx](h, temb)\n    m_idx += 1\n    h = modules[m_idx](h)\n    m_idx += 1\n    h = modules[m_idx](h, temb)\n    m_idx += 1\n\n    
pyramid = None\n\n    # Upsampling block\n    for i_level in reversed(range(self.num_resolutions)):\n      for i_block in range(self.num_res_blocks + 1):\n        tmp = hs.pop()\n        h = modules[m_idx](torch.cat([h, tmp], dim=1), temb)\n        m_idx += 1\n\n        # debug\n        # print(f'lv/block : {i_level}/{i_block}    shape: {h.shape}')\n\n      if h.shape[-1] in self.attn_resolutions:\n        h = modules[m_idx](h)\n        m_idx += 1\n\n        # debug\n        # print(f'(ATTN) lv/block : {i_level}/{i_block}    shape: {h.shape}')\n\n      if self.progressive != 'none':\n        if i_level == self.num_resolutions - 1:\n          if self.progressive == 'output_skip':\n            pyramid = self.act(modules[m_idx](h))\n            m_idx += 1\n            pyramid = modules[m_idx](pyramid)\n            m_idx += 1\n          elif self.progressive == 'residual':\n            pyramid = self.act(modules[m_idx](h))\n            m_idx += 1\n            pyramid = modules[m_idx](pyramid)\n            m_idx += 1\n          else:\n            raise ValueError(f'{self.progressive} is not a valid name.')\n        else:\n          if self.progressive == 'output_skip':\n            pyramid = self.pyramid_upsample(pyramid)\n            pyramid_h = self.act(modules[m_idx](h))\n            m_idx += 1\n            pyramid_h = modules[m_idx](pyramid_h)\n            m_idx += 1\n            pyramid = pyramid + pyramid_h\n          elif self.progressive == 'residual':\n            pyramid = modules[m_idx](pyramid)\n            m_idx += 1\n            if self.skip_rescale:\n              pyramid = (pyramid + h) / np.sqrt(2.)\n            else:\n              pyramid = pyramid + h\n            h = pyramid\n          else:\n            raise ValueError(f'{self.progressive} is not a valid name')\n\n      if i_level != 0:\n        if self.resblock_type == 'ddpm':\n          h = modules[m_idx](h)\n          m_idx += 1\n        else:\n          h = modules[m_idx](h, temb)\n          
m_idx += 1\n\n        # debug\n\n    assert not hs\n\n    if self.progressive == 'output_skip':\n      h = pyramid\n    else:\n      h = self.act(modules[m_idx](h))\n      m_idx += 1\n      h = modules[m_idx](h)\n      m_idx += 1\n\n    # debug\n    # print(f'module : {modules[m_idx-1]}    shape: {h.shape}')\n\n    assert m_idx == len(modules)\n    if self.config.model.scale_by_sigma:\n      used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))\n      # debug\n      # print(f'used_sigmas: {used_sigmas.shape}')\n      h = h / used_sigmas\n\n    return h\n"
  },
  {
    "path": "models/ncsnv2.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n\"\"\"The NCSNv2 model.\"\"\"\nimport torch\nimport torch.nn as nn\nimport functools\n\nfrom .utils import get_sigmas, register_model\nfrom .layers import (CondRefineBlock, RefineBlock, ResidualBlock, ncsn_conv3x3,\n                     ConditionalResidualBlock, get_act)\nfrom .normalization import get_normalization\n\nCondResidualBlock = ConditionalResidualBlock\nconv3x3 = ncsn_conv3x3\n\n\ndef get_network(config):\n  if config.data.image_size < 96:\n    return functools.partial(NCSNv2, config=config)\n  elif 96 <= config.data.image_size <= 128:\n    return functools.partial(NCSNv2_128, config=config)\n  elif 128 < config.data.image_size <= 256:\n    return functools.partial(NCSNv2_256, config=config)\n  else:\n    raise NotImplementedError(\n      f'No network suitable for {config.data.image_size}px implemented yet.')\n\n\n@register_model(name='ncsnv2_64')\nclass NCSNv2(nn.Module):\n  def __init__(self, config):\n    super().__init__()\n    self.centered = config.data.centered\n    self.norm = get_normalization(config)\n    self.nf = nf = config.model.nf\n\n    self.act = act = get_act(config)\n    self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))\n    self.config = config\n\n    self.begin_conv = nn.Conv2d(config.data.channels, nf, 3, stride=1, padding=1)\n\n    self.normalizer = self.norm(nf, 
config.model.num_scales)\n    self.end_conv = nn.Conv2d(nf, config.data.channels, 3, stride=1, padding=1)\n\n    self.res1 = nn.ModuleList([\n      ResidualBlock(self.nf, self.nf, resample=None, act=act,\n                    normalization=self.norm),\n      ResidualBlock(self.nf, self.nf, resample=None, act=act,\n                    normalization=self.norm)]\n    )\n\n    self.res2 = nn.ModuleList([\n      ResidualBlock(self.nf, 2 * self.nf, resample='down', act=act,\n                    normalization=self.norm),\n      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,\n                    normalization=self.norm)]\n    )\n\n    self.res3 = nn.ModuleList([\n      ResidualBlock(2 * self.nf, 2 * self.nf, resample='down', act=act,\n                    normalization=self.norm, dilation=2),\n      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,\n                    normalization=self.norm, dilation=2)]\n    )\n\n    if config.data.image_size == 28:\n      self.res4 = nn.ModuleList([\n        ResidualBlock(2 * self.nf, 2 * self.nf, resample='down', act=act,\n                      normalization=self.norm, adjust_padding=True, dilation=4),\n        ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,\n                      normalization=self.norm, dilation=4)]\n      )\n    else:\n      self.res4 = nn.ModuleList([\n        ResidualBlock(2 * self.nf, 2 * self.nf, resample='down', act=act,\n                      normalization=self.norm, adjust_padding=False, dilation=4),\n        ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,\n                      normalization=self.norm, dilation=4)]\n      )\n\n    self.refine1 = RefineBlock([2 * self.nf], 2 * self.nf, act=act, start=True)\n    self.refine2 = RefineBlock([2 * self.nf, 2 * self.nf], 2 * self.nf, act=act)\n    self.refine3 = RefineBlock([2 * self.nf, 2 * self.nf], self.nf, act=act)\n    self.refine4 = RefineBlock([self.nf, self.nf], self.nf, act=act, end=True)\n\n  
def _compute_cond_module(self, module, x):\n    for m in module:\n      x = m(x)\n    return x\n\n  def forward(self, x, y):\n    if not self.centered:\n      h = 2 * x - 1.\n    else:\n      h = x\n\n    output = self.begin_conv(h)\n\n    layer1 = self._compute_cond_module(self.res1, output)\n    layer2 = self._compute_cond_module(self.res2, layer1)\n    layer3 = self._compute_cond_module(self.res3, layer2)\n    layer4 = self._compute_cond_module(self.res4, layer3)\n\n    ref1 = self.refine1([layer4], layer4.shape[2:])\n    ref2 = self.refine2([layer3, ref1], layer3.shape[2:])\n    ref3 = self.refine3([layer2, ref2], layer2.shape[2:])\n    output = self.refine4([layer1, ref3], layer1.shape[2:])\n\n    output = self.normalizer(output)\n    output = self.act(output)\n    output = self.end_conv(output)\n\n    used_sigmas = self.sigmas[y].view(x.shape[0], *([1] * len(x.shape[1:])))\n\n    output = output / used_sigmas\n\n    return output\n\n\n@register_model(name='ncsn')\nclass NCSN(nn.Module):\n  def __init__(self, config):\n    super().__init__()\n    self.centered = config.data.centered\n    self.norm = get_normalization(config)\n    self.nf = nf = config.model.nf\n    self.act = act = get_act(config)\n    self.config = config\n\n    self.begin_conv = nn.Conv2d(config.data.channels, nf, 3, stride=1, padding=1)\n\n    self.normalizer = self.norm(nf, config.model.num_scales)\n    self.end_conv = nn.Conv2d(nf, config.data.channels, 3, stride=1, padding=1)\n\n    self.res1 = nn.ModuleList([\n      ConditionalResidualBlock(self.nf, self.nf, config.model.num_scales, resample=None, act=act,\n                               normalization=self.norm),\n      ConditionalResidualBlock(self.nf, self.nf, config.model.num_scales, resample=None, act=act,\n                               normalization=self.norm)]\n    )\n\n    self.res2 = nn.ModuleList([\n      ConditionalResidualBlock(self.nf, 2 * self.nf, config.model.num_scales, resample='down', act=act,\n                         
      normalization=self.norm),\n      ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample=None, act=act,\n                               normalization=self.norm)]\n    )\n\n    self.res3 = nn.ModuleList([\n      ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample='down', act=act,\n                               normalization=self.norm, dilation=2),\n      ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample=None, act=act,\n                               normalization=self.norm, dilation=2)]\n    )\n\n    if config.data.image_size == 28:\n      self.res4 = nn.ModuleList([\n        ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample='down', act=act,\n                                 normalization=self.norm, adjust_padding=True, dilation=4),\n        ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample=None, act=act,\n                                 normalization=self.norm, dilation=4)]\n      )\n    else:\n      self.res4 = nn.ModuleList([\n        ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample='down', act=act,\n                                 normalization=self.norm, adjust_padding=False, dilation=4),\n        ConditionalResidualBlock(2 * self.nf, 2 * self.nf, config.model.num_scales, resample=None, act=act,\n                                 normalization=self.norm, dilation=4)]\n      )\n\n    self.refine1 = CondRefineBlock([2 * self.nf], 2 * self.nf, config.model.num_scales, self.norm, act=act, start=True)\n    self.refine2 = CondRefineBlock([2 * self.nf, 2 * self.nf], 2 * self.nf, config.model.num_scales, self.norm, act=act)\n    self.refine3 = CondRefineBlock([2 * self.nf, 2 * self.nf], self.nf, config.model.num_scales, self.norm, act=act)\n    self.refine4 = CondRefineBlock([self.nf, self.nf], self.nf, config.model.num_scales, self.norm, act=act, end=True)\n\n  
def _compute_cond_module(self, module, x, y):\n    for m in module:\n      x = m(x, y)\n    return x\n\n  def forward(self, x, y):\n    if not self.centered:\n      h = 2 * x - 1.\n    else:\n      h = x\n\n    output = self.begin_conv(h)\n\n    layer1 = self._compute_cond_module(self.res1, output, y)\n    layer2 = self._compute_cond_module(self.res2, layer1, y)\n    layer3 = self._compute_cond_module(self.res3, layer2, y)\n    layer4 = self._compute_cond_module(self.res4, layer3, y)\n\n    ref1 = self.refine1([layer4], y, layer4.shape[2:])\n    ref2 = self.refine2([layer3, ref1], y, layer3.shape[2:])\n    ref3 = self.refine3([layer2, ref2], y, layer2.shape[2:])\n    output = self.refine4([layer1, ref3], y, layer1.shape[2:])\n\n    output = self.normalizer(output, y)\n    output = self.act(output)\n    output = self.end_conv(output)\n\n    return output\n\n\n@register_model(name='ncsnv2_128')\nclass NCSNv2_128(nn.Module):\n  \"\"\"NCSNv2 model architecture for 128px images.\"\"\"\n  def __init__(self, config):\n    super().__init__()\n    self.centered = config.data.centered\n    self.norm = get_normalization(config)\n    self.nf = nf = config.model.nf\n    self.act = act = get_act(config)\n    self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))\n    self.config = config\n\n    self.begin_conv = nn.Conv2d(config.data.channels, nf, 3, stride=1, padding=1)\n    self.normalizer = self.norm(nf, config.model.num_scales)\n\n    self.end_conv = nn.Conv2d(nf, config.data.channels, 3, stride=1, padding=1)\n\n    self.res1 = nn.ModuleList([\n      ResidualBlock(self.nf, self.nf, resample=None, act=act,\n                    normalization=self.norm),\n      ResidualBlock(self.nf, self.nf, resample=None, act=act,\n                    normalization=self.norm)]\n    )\n\n    self.res2 = nn.ModuleList([\n      ResidualBlock(self.nf, 2 * self.nf, resample='down', act=act,\n                    normalization=self.norm),\n      ResidualBlock(2 * self.nf, 2 * self.nf, 
resample=None, act=act,\n                    normalization=self.norm)]\n    )\n\n    self.res3 = nn.ModuleList([\n      ResidualBlock(2 * self.nf, 2 * self.nf, resample='down', act=act,\n                    normalization=self.norm),\n      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,\n                    normalization=self.norm)]\n    )\n\n    self.res4 = nn.ModuleList([\n      ResidualBlock(2 * self.nf, 4 * self.nf, resample='down', act=act,\n                    normalization=self.norm, dilation=2),\n      ResidualBlock(4 * self.nf, 4 * self.nf, resample=None, act=act,\n                    normalization=self.norm, dilation=2)]\n    )\n\n    self.res5 = nn.ModuleList([\n      ResidualBlock(4 * self.nf, 4 * self.nf, resample='down', act=act,\n                    normalization=self.norm, dilation=4),\n      ResidualBlock(4 * self.nf, 4 * self.nf, resample=None, act=act,\n                    normalization=self.norm, dilation=4)]\n    )\n\n    self.refine1 = RefineBlock([4 * self.nf], 4 * self.nf, act=act, start=True)\n    self.refine2 = RefineBlock([4 * self.nf, 4 * self.nf], 2 * self.nf, act=act)\n    self.refine3 = RefineBlock([2 * self.nf, 2 * self.nf], 2 * self.nf, act=act)\n    self.refine4 = RefineBlock([2 * self.nf, 2 * self.nf], self.nf, act=act)\n    self.refine5 = RefineBlock([self.nf, self.nf], self.nf, act=act, end=True)\n\n  def _compute_cond_module(self, module, x):\n    for m in module:\n      x = m(x)\n    return x\n\n  def forward(self, x, y):\n    if not self.centered:\n      h = 2 * x - 1.\n    else:\n      h = x\n\n    output = self.begin_conv(h)\n\n    layer1 = self._compute_cond_module(self.res1, output)\n    layer2 = self._compute_cond_module(self.res2, layer1)\n    layer3 = self._compute_cond_module(self.res3, layer2)\n    layer4 = self._compute_cond_module(self.res4, layer3)\n    layer5 = self._compute_cond_module(self.res5, layer4)\n\n    ref1 = self.refine1([layer5], layer5.shape[2:])\n    ref2 = self.refine2([layer4, 
ref1], layer4.shape[2:])\n    ref3 = self.refine3([layer3, ref2], layer3.shape[2:])\n    ref4 = self.refine4([layer2, ref3], layer2.shape[2:])\n    output = self.refine5([layer1, ref4], layer1.shape[2:])\n\n    output = self.normalizer(output)\n    output = self.act(output)\n    output = self.end_conv(output)\n\n    used_sigmas = self.sigmas[y].view(x.shape[0], *([1] * len(x.shape[1:])))\n\n    output = output / used_sigmas\n\n    return output\n\n\n@register_model(name='ncsnv2_256')\nclass NCSNv2_256(nn.Module):\n  \"\"\"NCSNv2 model architecture for 256px images.\"\"\"\n  def __init__(self, config):\n    super().__init__()\n    self.centered = config.data.centered\n    self.norm = get_normalization(config)\n    self.nf = nf = config.model.nf\n    self.act = act = get_act(config)\n    self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))\n    self.config = config\n\n    self.begin_conv = nn.Conv2d(config.data.channels, nf, 3, stride=1, padding=1)\n    self.normalizer = self.norm(nf, config.model.num_scales)\n\n    self.end_conv = nn.Conv2d(nf, config.data.channels, 3, stride=1, padding=1)\n\n    self.res1 = nn.ModuleList([\n      ResidualBlock(self.nf, self.nf, resample=None, act=act,\n                    normalization=self.norm),\n      ResidualBlock(self.nf, self.nf, resample=None, act=act,\n                    normalization=self.norm)]\n    )\n\n    self.res2 = nn.ModuleList([\n      ResidualBlock(self.nf, 2 * self.nf, resample='down', act=act,\n                    normalization=self.norm),\n      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,\n                    normalization=self.norm)]\n    )\n\n    self.res3 = nn.ModuleList([\n      ResidualBlock(2 * self.nf, 2 * self.nf, resample='down', act=act,\n                    normalization=self.norm),\n      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,\n                    normalization=self.norm)]\n    )\n\n    self.res31 = nn.ModuleList([\n      ResidualBlock(2 * 
self.nf, 2 * self.nf, resample='down', act=act,\n                    normalization=self.norm),\n      ResidualBlock(2 * self.nf, 2 * self.nf, resample=None, act=act,\n                    normalization=self.norm)]\n    )\n\n    self.res4 = nn.ModuleList([\n      ResidualBlock(2 * self.nf, 4 * self.nf, resample='down', act=act,\n                    normalization=self.norm, dilation=2),\n      ResidualBlock(4 * self.nf, 4 * self.nf, resample=None, act=act,\n                    normalization=self.norm, dilation=2)]\n    )\n\n    self.res5 = nn.ModuleList([\n      ResidualBlock(4 * self.nf, 4 * self.nf, resample='down', act=act,\n                    normalization=self.norm, dilation=4),\n      ResidualBlock(4 * self.nf, 4 * self.nf, resample=None, act=act,\n                    normalization=self.norm, dilation=4)]\n    )\n\n    self.refine1 = RefineBlock([4 * self.nf], 4 * self.nf, act=act, start=True)\n    self.refine2 = RefineBlock([4 * self.nf, 4 * self.nf], 2 * self.nf, act=act)\n    self.refine3 = RefineBlock([2 * self.nf, 2 * self.nf], 2 * self.nf, act=act)\n    self.refine31 = RefineBlock([2 * self.nf, 2 * self.nf], 2 * self.nf, act=act)\n    self.refine4 = RefineBlock([2 * self.nf, 2 * self.nf], self.nf, act=act)\n    self.refine5 = RefineBlock([self.nf, self.nf], self.nf, act=act, end=True)\n\n  def _compute_cond_module(self, module, x):\n    for m in module:\n      x = m(x)\n    return x\n\n  def forward(self, x, y):\n    if not self.centered:\n      h = 2 * x - 1.\n    else:\n      h = x\n\n    output = self.begin_conv(h)\n\n    layer1 = self._compute_cond_module(self.res1, output)\n    layer2 = self._compute_cond_module(self.res2, layer1)\n    layer3 = self._compute_cond_module(self.res3, layer2)\n    layer31 = self._compute_cond_module(self.res31, layer3)\n    layer4 = self._compute_cond_module(self.res4, layer31)\n    layer5 = self._compute_cond_module(self.res5, layer4)\n\n    ref1 = self.refine1([layer5], layer5.shape[2:])\n    ref2 = 
self.refine2([layer4, ref1], layer4.shape[2:])\n    ref31 = self.refine31([layer31, ref2], layer31.shape[2:])\n    ref3 = self.refine3([layer3, ref31], layer3.shape[2:])\n    ref4 = self.refine4([layer2, ref3], layer2.shape[2:])\n    output = self.refine5([layer1, ref4], layer1.shape[2:])\n\n    output = self.normalizer(output)\n    output = self.act(output)\n    output = self.end_conv(output)\n\n    used_sigmas = self.sigmas[y].view(x.shape[0], *([1] * len(x.shape[1:])))\n\n    output = output / used_sigmas\n\n    return output"
  },
  {
    "path": "models/normalization.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Normalization layers.\"\"\"\nimport torch.nn as nn\nimport torch\nimport functools\n\n\ndef get_normalization(config, conditional=False):\n  \"\"\"Obtain normalization modules from the config file.\"\"\"\n  norm = config.model.normalization\n  if conditional:\n    if norm == 'InstanceNorm++':\n      return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)\n    else:\n      raise NotImplementedError(f'{norm} not implemented yet.')\n  else:\n    if norm == 'InstanceNorm':\n      return nn.InstanceNorm2d\n    elif norm == 'InstanceNorm++':\n      return InstanceNorm2dPlus\n    elif norm == 'VarianceNorm':\n      return VarianceNorm2d\n    elif norm == 'GroupNorm':\n      return nn.GroupNorm\n    else:\n      raise ValueError('Unknown normalization: %s' % norm)\n\n\nclass ConditionalBatchNorm2d(nn.Module):\n  def __init__(self, num_features, num_classes, bias=True):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    self.bn = nn.BatchNorm2d(num_features, affine=False)\n    if self.bias:\n      self.embed = nn.Embedding(num_classes, num_features * 2)\n      self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale from U(0, 1)\n      self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0\n    else:\n      self.embed = 
nn.Embedding(num_classes, num_features)\n      self.embed.weight.data.uniform_()\n\n  def forward(self, x, y):\n    out = self.bn(x)\n    if self.bias:\n      gamma, beta = self.embed(y).chunk(2, dim=1)\n      out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)\n    else:\n      gamma = self.embed(y)\n      out = gamma.view(-1, self.num_features, 1, 1) * out\n    return out\n\n\nclass ConditionalInstanceNorm2d(nn.Module):\n  def __init__(self, num_features, num_classes, bias=True):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)\n    if bias:\n      self.embed = nn.Embedding(num_classes, num_features * 2)\n      self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale from U(0, 1)\n      self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0\n    else:\n      self.embed = nn.Embedding(num_classes, num_features)\n      self.embed.weight.data.uniform_()\n\n  def forward(self, x, y):\n    h = self.instance_norm(x)\n    if self.bias:\n      gamma, beta = self.embed(y).chunk(2, dim=-1)\n      out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)\n    else:\n      gamma = self.embed(y)\n      out = gamma.view(-1, self.num_features, 1, 1) * h\n    return out\n\n\nclass ConditionalVarianceNorm2d(nn.Module):\n  def __init__(self, num_features, num_classes, bias=False):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    self.embed = nn.Embedding(num_classes, num_features)\n    self.embed.weight.data.normal_(1, 0.02)\n\n  def forward(self, x, y):\n    vars = torch.var(x, dim=(2, 3), keepdim=True)\n    h = x / torch.sqrt(vars + 1e-5)\n\n    gamma = self.embed(y)\n    out = gamma.view(-1, self.num_features, 1, 1) * h\n    return out\n\n\nclass VarianceNorm2d(nn.Module):\n  def __init__(self, 
num_features, bias=False):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    self.alpha = nn.Parameter(torch.zeros(num_features))\n    self.alpha.data.normal_(1, 0.02)\n\n  def forward(self, x):\n    vars = torch.var(x, dim=(2, 3), keepdim=True)\n    h = x / torch.sqrt(vars + 1e-5)\n\n    out = self.alpha.view(-1, self.num_features, 1, 1) * h\n    return out\n\n\nclass ConditionalNoneNorm2d(nn.Module):\n  def __init__(self, num_features, num_classes, bias=True):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    if bias:\n      self.embed = nn.Embedding(num_classes, num_features * 2)\n      self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale from U(0, 1)\n      self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0\n    else:\n      self.embed = nn.Embedding(num_classes, num_features)\n      self.embed.weight.data.uniform_()\n\n  def forward(self, x, y):\n    if self.bias:\n      gamma, beta = self.embed(y).chunk(2, dim=-1)\n      out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)\n    else:\n      gamma = self.embed(y)\n      out = gamma.view(-1, self.num_features, 1, 1) * x\n    return out\n\n\nclass NoneNorm2d(nn.Module):\n  def __init__(self, num_features, bias=True):\n    super().__init__()\n\n  def forward(self, x):\n    return x\n\n\nclass InstanceNorm2dPlus(nn.Module):\n  def __init__(self, num_features, bias=True):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)\n    self.alpha = nn.Parameter(torch.zeros(num_features))\n    self.gamma = nn.Parameter(torch.zeros(num_features))\n    self.alpha.data.normal_(1, 0.02)\n    self.gamma.data.normal_(1, 0.02)\n    if bias:\n      self.beta = nn.Parameter(torch.zeros(num_features))\n\n  def forward(self, x):\n    means = 
torch.mean(x, dim=(2, 3))\n    m = torch.mean(means, dim=-1, keepdim=True)\n    v = torch.var(means, dim=-1, keepdim=True)\n    means = (means - m) / (torch.sqrt(v + 1e-5))\n    h = self.instance_norm(x)\n\n    if self.bias:\n      h = h + means[..., None, None] * self.alpha[..., None, None]\n      out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)\n    else:\n      h = h + means[..., None, None] * self.alpha[..., None, None]\n      out = self.gamma.view(-1, self.num_features, 1, 1) * h\n    return out\n\n\nclass ConditionalInstanceNorm2dPlus(nn.Module):\n  def __init__(self, num_features, num_classes, bias=True):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)\n    if bias:\n      self.embed = nn.Embedding(num_classes, num_features * 3)\n      self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)\n      self.embed.weight.data[:, 2 * num_features:].zero_()  # Initialise bias at 0\n    else:\n      self.embed = nn.Embedding(num_classes, 2 * num_features)\n      self.embed.weight.data.normal_(1, 0.02)\n\n  def forward(self, x, y):\n    means = torch.mean(x, dim=(2, 3))\n    m = torch.mean(means, dim=-1, keepdim=True)\n    v = torch.var(means, dim=-1, keepdim=True)\n    means = (means - m) / (torch.sqrt(v + 1e-5))\n    h = self.instance_norm(x)\n\n    if self.bias:\n      gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)\n      h = h + means[..., None, None] * alpha[..., None, None]\n      out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)\n    else:\n      gamma, alpha = self.embed(y).chunk(2, dim=-1)\n      h = h + means[..., None, None] * alpha[..., None, None]\n      out = gamma.view(-1, self.num_features, 1, 1) * h\n    return out\n"
  },
  {
    "path": "models/unet.py",
    "content": "from . import utils\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass ConvBlock(nn.Module):\n    \"\"\"\n    A Convolutional Block that consists of two convolution layers each followed by\n    group normalization and LeakyReLU activation.\n    \"\"\"\n\n    def __init__(self, in_chans, out_chans, stride=2):\n        \"\"\"\n        Args:\n            in_chans (int): Number of channels in the input.\n            out_chans (int): Number of channels in the output.\n            stride (int): Stride of the first convolution.\n        \"\"\"\n        super().__init__()\n\n        self.in_chans = in_chans\n        self.out_chans = out_chans\n\n        self.layers = nn.Sequential(\n            nn.Conv2d(in_chans, out_chans, kernel_size=3, stride=stride, padding=1),\n            nn.GroupNorm(num_groups=8, num_channels=out_chans),\n            nn.LeakyReLU(),\n\n            nn.Conv2d(out_chans, out_chans, kernel_size=3, stride=1, padding=1),\n            nn.GroupNorm(num_groups=8, num_channels=out_chans),\n            nn.LeakyReLU(),\n        )\n\n    def forward(self, tensor):\n        \"\"\"\n        Args:\n            tensor (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]\n        Returns:\n            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]\n        \"\"\"\n        return self.layers(tensor)\n\n    def __repr__(self):\n        return f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans})'\n\n\n@utils.register_model(name='unet')\nclass Unet(nn.Module):\n    def __init__(self, in_chans=1, out_chans=1, chans=64, num_pool_layers=4, use_residual=True):\n        super().__init__()\n        # self.config = config\n        # self.in_chans = config.model.in_chans\n        # self.out_chans = config.model.out_chans\n        # self.chans = config.model.chans\n        # self.num_pool_layers = config.model.num_pool_layers\n        # 
self.use_residual = config.model.use_residual\n\n        self.in_chans = in_chans\n        self.out_chans = out_chans\n        self.chans = chans\n        self.num_pool_layers = num_pool_layers\n        self.use_residual = use_residual\n\n        ch = self.chans\n        self.down_sample_layers = nn.ModuleList([ConvBlock(self.in_chans, self.chans, stride=1)])\n        for i in range(self.num_pool_layers - 1):\n            self.down_sample_layers += [ConvBlock(ch, ch * 2, stride=2)]\n            ch *= 2\n\n        # Size reduction happens at the beginning of a block, hence the need for stride here\n        self.conv = ConvBlock(ch, ch, stride=2)\n\n        self.up_sample_layers = nn.ModuleList()\n        for i in range(self.num_pool_layers - 1):\n            self.up_sample_layers += [ConvBlock(ch * 2, ch // 2, stride=1)]\n            ch //= 2\n        self.up_sample_layers += [ConvBlock(ch * 2, ch, stride=1)]\n        self.conv2 = nn.Conv2d(ch, self.out_chans, kernel_size=1)\n\n    def forward(self, tensor):\n        \"\"\"\n        Args:\n            tensor (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]\n        Returns:\n            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]\n        \"\"\"\n        stack = list()\n        output = tensor\n        # Apply down-sampling layers\n        for layer in self.down_sample_layers:\n            output = layer(output)\n            stack.append(output)\n            # output = F.avg_pool2d(output, kernel_size=2)\n\n        output = self.conv(output)\n\n        # Apply up-sampling layers\n        for layer in self.up_sample_layers:\n            output = F.interpolate(output, scale_factor=2, mode='bilinear', align_corners=False)\n            output = torch.cat((output, stack.pop()), dim=1)\n            output = layer(output)\n\n        output = self.conv2(output)\n        if self.use_residual:\n            output = output + tensor\n\n        return output"
  },
  {
    "path": "models/up_or_down_sampling.py",
    "content": "\"\"\"Layers used for up-sampling or down-sampling images.\n\nMany functions are ported from https://github.com/NVlabs/stylegan2.\n\"\"\"\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom op import upfirdn2d\n\n\n# Function ported from StyleGAN2\ndef get_weight(module,\n               shape,\n               weight_var='weight',\n               kernel_init=None):\n  \"\"\"Get/create weight tensor for a convolution or fully-connected layer.\"\"\"\n\n  return module.param(weight_var, kernel_init, shape)\n\n\nclass Conv2d(nn.Module):\n  \"\"\"Conv2d layer with optimal upsampling and downsampling (StyleGAN2).\"\"\"\n\n  def __init__(self, in_ch, out_ch, kernel, up=False, down=False,\n               resample_kernel=(1, 3, 3, 1),\n               use_bias=True,\n               kernel_init=None):\n    super().__init__()\n    assert not (up and down)\n    assert kernel >= 1 and kernel % 2 == 1\n    self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))\n    if kernel_init is not None:\n      self.weight.data = kernel_init(self.weight.data.shape)\n    if use_bias:\n      self.bias = nn.Parameter(torch.zeros(out_ch))\n\n    self.up = up\n    self.down = down\n    self.resample_kernel = resample_kernel\n    self.kernel = kernel\n    self.use_bias = use_bias\n\n  def forward(self, x):\n    if self.up:\n      x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)\n    elif self.down:\n      x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)\n    else:\n      x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)\n\n    if self.use_bias:\n      x = x + self.bias.reshape(1, -1, 1, 1)\n\n    return x\n\n\ndef naive_upsample_2d(x, factor=2):\n  _N, C, H, W = x.shape\n  x = torch.reshape(x, (-1, C, H, 1, W, 1))\n  x = x.repeat(1, 1, 1, factor, 1, factor)\n  return torch.reshape(x, (-1, C, H * factor, W * factor))\n\n\ndef naive_downsample_2d(x, factor=2):\n  _N, C, H, W = 
x.shape\n  x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))\n  return torch.mean(x, dim=(3, 5))\n\n\ndef upsample_conv_2d(x, w, k=None, factor=2, gain=1):\n  \"\"\"Fused `upsample_2d()` followed by `tf.nn.conv2d()`.\n\n     Padding is performed only once at the beginning, not between the\n     operations.\n     The fused op is considerably more efficient than performing the same\n     calculation\n     using standard TensorFlow ops. It supports gradients of arbitrary order.\n     Args:\n       x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W,\n         C]`.\n       w:            Weight tensor of the shape `[outChannels, inChannels,\n         filterH, filterW]`. Grouped convolution can be performed by `inChannels =\n         x.shape[1] // numGroups`.\n       k:            FIR filter of the shape `[firH, firW]` or `[firN]`\n         (separable). The default is `[1] * factor`, which corresponds to\n         nearest-neighbor upsampling.\n       factor:       Integer upsampling factor (default: 2).\n       gain:         Scaling factor for signal magnitude (default: 1.0).\n\n     Returns:\n       Tensor of the shape `[N, C, H * factor, W * factor]` or\n       `[N, H * factor, W * factor, C]`, and same datatype as `x`.\n  \"\"\"\n\n  assert isinstance(factor, int) and factor >= 1\n\n  # Check weight shape.\n  assert len(w.shape) == 4\n  convH = w.shape[2]\n  convW = w.shape[3]\n  inC = w.shape[1]\n  outC = w.shape[0]\n\n  assert convW == convH\n\n  # Setup filter kernel.\n  if k is None:\n    k = [1] * factor\n  k = _setup_kernel(k) * (gain * (factor ** 2))\n  p = (k.shape[0] - factor) - (convW - 1)\n\n  stride = (factor, factor)\n\n  # Determine data dimensions.\n  stride = [1, 1, factor, factor]\n  output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)\n  output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,\n                    output_shape[1] - (_shape(x, 3) - 1) * 
stride[1] - convW)\n  assert output_padding[0] >= 0 and output_padding[1] >= 0\n  num_groups = _shape(x, 1) // inC\n\n  # Transpose weights.\n  w = torch.reshape(w, (num_groups, -1, inC, convH, convW))\n  w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)\n  w = torch.reshape(w, (num_groups * inC, -1, convH, convW))\n\n  x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)\n  ## Original TF code.\n  # x = tf.nn.conv2d_transpose(\n  #     x,\n  #     w,\n  #     output_shape=output_shape,\n  #     strides=stride,\n  #     padding='VALID',\n  #     data_format=data_format)\n  ## JAX equivalent\n\n  return upfirdn2d(x, torch.tensor(k, device=x.device),\n                   pad=((p + 1) // 2 + factor - 1, p // 2 + 1))\n\n\ndef conv_downsample_2d(x, w, k=None, factor=2, gain=1):\n  \"\"\"Fused `tf.nn.conv2d()` followed by `downsample_2d()`.\n\n    Padding is performed only once at the beginning, not between the operations.\n    The fused op is considerably more efficient than performing the same\n    calculation\n    using standard TensorFlow ops. It supports gradients of arbitrary order.\n    Args:\n        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W,\n          C]`.\n        w:            Weight tensor of the shape `[outChannels, inChannels,\n          filterH, filterW]`. Grouped convolution can be performed by `inChannels =\n          x.shape[0] // numGroups`.\n        k:            FIR filter of the shape `[firH, firW]` or `[firN]`\n          (separable). 
The default is `[1] * factor`, which corresponds to\n          average pooling.\n        factor:       Integer downsampling factor (default: 2).\n        gain:         Scaling factor for signal magnitude (default: 1.0).\n\n    Returns:\n        Tensor of the shape `[N, C, H // factor, W // factor]` or\n        `[N, H // factor, W // factor, C]`, and same datatype as `x`.\n  \"\"\"\n\n  assert isinstance(factor, int) and factor >= 1\n  _outC, _inC, convH, convW = w.shape\n  assert convW == convH\n  if k is None:\n    k = [1] * factor\n  k = _setup_kernel(k) * gain\n  p = (k.shape[0] - factor) + (convW - 1)\n  s = [factor, factor]\n  x = upfirdn2d(x, torch.tensor(k, device=x.device),\n                pad=((p + 1) // 2, p // 2))\n  return F.conv2d(x, w, stride=s, padding=0)\n\n\ndef _setup_kernel(k):\n  k = np.asarray(k, dtype=np.float32)\n  if k.ndim == 1:\n    k = np.outer(k, k)\n  k /= np.sum(k)\n  assert k.ndim == 2\n  assert k.shape[0] == k.shape[1]\n  return k\n\n\ndef _shape(x, dim):\n  return x.shape[dim]\n\n\ndef upsample_2d(x, k=None, factor=2, gain=1):\n  r\"\"\"Upsample a batch of 2D images with the given filter.\n\n    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`\n    and upsamples each image with the given filter. The filter is normalized so\n    that\n    if the input pixels are constant, they will be scaled by the specified\n    `gain`.\n    Pixels outside the image are assumed to be zero, and the filter is padded\n    with\n    zeros so that its shape is a multiple of the upsampling factor.\n    Args:\n        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W,\n          C]`.\n        k:            FIR filter of the shape `[firH, firW]` or `[firN]`\n          (separable). 
The default is `[1] * factor`, which corresponds to\n          nearest-neighbor upsampling.\n        factor:       Integer upsampling factor (default: 2).\n        gain:         Scaling factor for signal magnitude (default: 1.0).\n\n    Returns:\n        Tensor of the shape `[N, C, H * factor, W * factor]`\n  \"\"\"\n  assert isinstance(factor, int) and factor >= 1\n  if k is None:\n    k = [1] * factor\n  k = _setup_kernel(k) * (gain * (factor ** 2))\n  p = k.shape[0] - factor\n  return upfirdn2d(x, torch.tensor(k, device=x.device),\n                   up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))\n\n\ndef downsample_2d(x, k=None, factor=2, gain=1):\n  r\"\"\"Downsample a batch of 2D images with the given filter.\n\n    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`\n    and downsamples each image with the given filter. The filter is normalized\n    so that\n    if the input pixels are constant, they will be scaled by the specified\n    `gain`.\n    Pixels outside the image are assumed to be zero, and the filter is padded\n    with\n    zeros so that its shape is a multiple of the downsampling factor.\n    Args:\n        x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W,\n          C]`.\n        k:            FIR filter of the shape `[firH, firW]` or `[firN]`\n          (separable). The default is `[1] * factor`, which corresponds to\n          average pooling.\n        factor:       Integer downsampling factor (default: 2).\n        gain:         Scaling factor for signal magnitude (default: 1.0).\n\n    Returns:\n        Tensor of the shape `[N, C, H // factor, W // factor]`\n  \"\"\"\n\n  assert isinstance(factor, int) and factor >= 1\n  if k is None:\n    k = [1] * factor\n  k = _setup_kernel(k) * gain\n  p = k.shape[0] - factor\n  return upfirdn2d(x, torch.tensor(k, device=x.device),\n                   down=factor, pad=((p + 1) // 2, p // 2))\n"
  },
  {
    "path": "models/utils.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"All functions and modules related to model definition.\n\"\"\"\n\nimport torch\nimport sde_lib\nimport numpy as np\n\n\n_MODELS = {}\n\n\ndef register_model(cls=None, *, name=None):\n  \"\"\"A decorator for registering model classes.\"\"\"\n\n  def _register(cls):\n    if name is None:\n      local_name = cls.__name__\n    else:\n      local_name = name\n    if local_name in _MODELS:\n      raise ValueError(f'Already registered model with name: {local_name}')\n    _MODELS[local_name] = cls\n    return cls\n\n  if cls is None:\n    return _register\n  else:\n    return _register(cls)\n\n\ndef get_model(name):\n  return _MODELS[name]\n\n\ndef get_sigmas(config):\n  \"\"\"Get sigmas --- the set of noise levels for SMLD from config files.\n  Args:\n    config: A ConfigDict object parsed from the config file\n  Returns:\n    sigmas: a jax numpy arrary of noise levels\n  \"\"\"\n  sigmas = np.exp(\n    np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales))\n\n  return sigmas\n\n\ndef get_ddpm_params(config):\n  \"\"\"Get betas and alphas --- parameters used in the original DDPM paper.\"\"\"\n  num_diffusion_timesteps = 1000\n  # parameters need to be adapted if number of time steps differs from 1000\n  beta_start = config.model.beta_min / config.model.num_scales\n  beta_end = 
config.model.beta_max / config.model.num_scales\n  betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)\n\n  alphas = 1. - betas\n  alphas_cumprod = np.cumprod(alphas, axis=0)\n  sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)\n  sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)\n\n  return {\n    'betas': betas,\n    'alphas': alphas,\n    'alphas_cumprod': alphas_cumprod,\n    'sqrt_alphas_cumprod': sqrt_alphas_cumprod,\n    'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,\n    'beta_min': beta_start * (num_diffusion_timesteps - 1),\n    'beta_max': beta_end * (num_diffusion_timesteps - 1),\n    'num_diffusion_timesteps': num_diffusion_timesteps\n  }\n\n\ndef create_model(config):\n  \"\"\"Create the score model.\"\"\"\n  model_name = config.model.name\n  score_model = get_model(model_name)(config)\n  score_model = score_model.to(config.device)\n  score_model = torch.nn.DataParallel(score_model)\n  return score_model\n\n\ndef get_model_fn(model, train=False):\n  \"\"\"Create a function to give the output of the score-based model.\n\n  Args:\n    model: The score model.\n    train: `True` for training and `False` for evaluation.\n\n  Returns:\n    A model function.\n  \"\"\"\n\n  def model_fn(x, labels):\n    \"\"\"Compute the output of the score-based model.\n\n    Args:\n      x: A mini-batch of input data.\n      labels: A mini-batch of conditioning variables for time steps. 
Should be interpreted differently\n        for different models.\n\n    Returns:\n      The output of the score-based model for `x` conditioned on `labels`.\n    \"\"\"\n    if not train:\n      model.eval()\n      return model(x, labels)\n    else:\n      model.train()\n      return model(x, labels)\n\n  return model_fn\n\n\ndef get_score_fn(sde, model, train=False, continuous=False):\n  \"\"\"Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.\n\n  Args:\n    sde: An `sde_lib.SDE` object that represents the forward SDE.\n    model: A score model.\n    train: `True` for training and `False` for evaluation.\n    continuous: If `True`, the score-based model is expected to directly take continuous time steps.\n\n  Returns:\n    A score function.\n  \"\"\"\n  model_fn = get_model_fn(model, train=train)\n\n  if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):\n    def score_fn(x, t):\n      # Scale neural network output by standard deviation and flip sign\n      if continuous or isinstance(sde, sde_lib.subVPSDE):\n        # For VP-trained models, t=0 corresponds to the lowest noise level\n        # The maximum value of time embedding is assumed to be 999 for\n        # continuously-trained models.\n        labels = t * 999\n        score = model_fn(x, labels)\n        std = sde.marginal_prob(torch.zeros_like(x), t)[1]\n      else:\n        # For VP-trained models, t=0 corresponds to the lowest noise level\n        labels = t * (sde.N - 1)\n        score = model_fn(x, labels)\n        std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]\n\n      score = -score / std[:, None, None, None]\n      return score\n\n  elif isinstance(sde, sde_lib.VESDE):\n    def score_fn(x, t):\n      if continuous:\n        labels = sde.marginal_prob(torch.zeros_like(x), t)[1]\n      else:\n        # For VE-trained models, t=0 corresponds to the highest noise level\n        labels = sde.T - t\n        labels *= sde.N - 1\n        labels = 
torch.round(labels).long()\n\n      score = model_fn(x, labels)\n      return score\n\n  else:\n    raise NotImplementedError(f\"SDE class {sde.__class__.__name__} not yet supported.\")\n\n  return score_fn\n\n\ndef to_flattened_numpy(x):\n  \"\"\"Flatten a torch tensor `x` and convert it to numpy.\"\"\"\n  return x.detach().cpu().numpy().reshape((-1,))\n\n\ndef from_flattened_numpy(x, shape):\n  \"\"\"Form a torch tensor with the given `shape` from a flattened numpy array `x`.\"\"\"\n  return torch.from_numpy(x.reshape(shape))"
  },
  {
    "path": "op/__init__.py",
    "content": "from .fused_act import FusedLeakyReLU, fused_leaky_relu\nfrom .upfirdn2d import upfirdn2d\n"
  },
  {
    "path": "op/fused_act.py",
    "content": "import os\r\n\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import Function\r\nfrom torch.utils.cpp_extension import load\r\n\r\n\r\nmodule_path = os.path.dirname(__file__)\r\nfused = load(\r\n    \"fused\",\r\n    sources=[\r\n        os.path.join(module_path, \"fused_bias_act.cpp\"),\r\n        os.path.join(module_path, \"fused_bias_act_kernel.cu\"),\r\n    ],\r\n)\r\n\r\n\r\nclass FusedLeakyReLUFunctionBackward(Function):\r\n    @staticmethod\r\n    def forward(ctx, grad_output, out, negative_slope, scale):\r\n        ctx.save_for_backward(out)\r\n        ctx.negative_slope = negative_slope\r\n        ctx.scale = scale\r\n\r\n        empty = grad_output.new_empty(0)\r\n\r\n        grad_input = fused.fused_bias_act(\r\n            grad_output, empty, out, 3, 1, negative_slope, scale\r\n        )\r\n\r\n        dim = [0]\r\n\r\n        if grad_input.ndim > 2:\r\n            dim += list(range(2, grad_input.ndim))\r\n\r\n        grad_bias = grad_input.sum(dim).detach()\r\n\r\n        return grad_input, grad_bias\r\n\r\n    @staticmethod\r\n    def backward(ctx, gradgrad_input, gradgrad_bias):\r\n        out, = ctx.saved_tensors\r\n        gradgrad_out = fused.fused_bias_act(\r\n            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale\r\n        )\r\n\r\n        return gradgrad_out, None, None, None\r\n\r\n\r\nclass FusedLeakyReLUFunction(Function):\r\n    @staticmethod\r\n    def forward(ctx, input, bias, negative_slope, scale):\r\n        empty = input.new_empty(0)\r\n        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)\r\n        ctx.save_for_backward(out)\r\n        ctx.negative_slope = negative_slope\r\n        ctx.scale = scale\r\n\r\n        return out\r\n\r\n    @staticmethod\r\n    def backward(ctx, grad_output):\r\n        out, = ctx.saved_tensors\r\n\r\n        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(\r\n   
         grad_output, out, ctx.negative_slope, ctx.scale\r\n        )\r\n\r\n        return grad_input, grad_bias, None, None\r\n\r\n\r\nclass FusedLeakyReLU(nn.Module):\r\n    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):\r\n        super().__init__()\r\n\r\n        self.bias = nn.Parameter(torch.zeros(channel))\r\n        self.negative_slope = negative_slope\r\n        self.scale = scale\r\n\r\n    def forward(self, input):\r\n        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)\r\n\r\n\r\ndef fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):\r\n    if input.device.type == \"cpu\":\r\n        rest_dim = [1] * (input.ndim - bias.ndim - 1)\r\n        return (\r\n            F.leaky_relu(\r\n                input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2\r\n            )\r\n            * scale\r\n        )\r\n\r\n    else:\r\n        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)\r\n"
  },
  {
    "path": "op/fused_bias_act.cpp",
    "content": "#include <torch/extension.h>\r\n\r\n\r\ntorch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,\r\n    int act, int grad, float alpha, float scale);\r\n\r\n#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\r\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\r\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\r\n\r\ntorch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,\r\n    int act, int grad, float alpha, float scale) {\r\n    CHECK_CUDA(input);\r\n    CHECK_CUDA(bias);\r\n\r\n    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);\r\n}\r\n\r\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\r\n    m.def(\"fused_bias_act\", &fused_bias_act, \"fused bias act (CUDA)\");\r\n}"
  },
  {
    "path": "op/fused_bias_act_kernel.cu",
    "content": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Source Code License-NC.\r\n// To view a copy of this license, visit\r\n// https://nvlabs.github.io/stylegan2/license.html\r\n\r\n#include <torch/types.h>\r\n\r\n#include <ATen/ATen.h>\r\n#include <ATen/AccumulateType.h>\r\n#include <ATen/cuda/CUDAContext.h>\r\n#include <ATen/cuda/CUDAApplyUtils.cuh>\r\n\r\n#include <cuda.h>\r\n#include <cuda_runtime.h>\r\n\r\n\r\ntemplate <typename scalar_t>\r\nstatic __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,\r\n    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {\r\n    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;\r\n\r\n    scalar_t zero = 0.0;\r\n\r\n    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {\r\n        scalar_t x = p_x[xi];\r\n\r\n        if (use_bias) {\r\n            x += p_b[(xi / step_b) % size_b];\r\n        }\r\n\r\n        scalar_t ref = use_ref ? p_ref[xi] : zero;\r\n\r\n        scalar_t y;\r\n\r\n        switch (act * 10 + grad) {\r\n            default:\r\n            case 10: y = x; break;\r\n            case 11: y = x; break;\r\n            case 12: y = 0.0; break;\r\n\r\n            case 30: y = (x > 0.0) ? x : x * alpha; break;\r\n            case 31: y = (ref > 0.0) ? 
x : x * alpha; break;\r\n            case 32: y = 0.0; break;\r\n        }\r\n\r\n        out[xi] = y * scale;\r\n    }\r\n}\r\n\r\n\r\ntorch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,\r\n    int act, int grad, float alpha, float scale) {\r\n    int curDevice = -1;\r\n    cudaGetDevice(&curDevice);\r\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);\r\n\r\n    auto x = input.contiguous();\r\n    auto b = bias.contiguous();\r\n    auto ref = refer.contiguous();\r\n\r\n    int use_bias = b.numel() ? 1 : 0;\r\n    int use_ref = ref.numel() ? 1 : 0;\r\n\r\n    int size_x = x.numel();\r\n    int size_b = b.numel();\r\n    int step_b = 1;\r\n\r\n    for (int i = 1 + 1; i < x.dim(); i++) {\r\n        step_b *= x.size(i);\r\n    }\r\n\r\n    int loop_x = 4;\r\n    int block_size = 4 * 32;\r\n    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;\r\n\r\n    auto y = torch::empty_like(x);\r\n\r\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"fused_bias_act_kernel\", [&] {\r\n        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(\r\n            y.data_ptr<scalar_t>(),\r\n            x.data_ptr<scalar_t>(),\r\n            b.data_ptr<scalar_t>(),\r\n            ref.data_ptr<scalar_t>(),\r\n            act,\r\n            grad,\r\n            alpha,\r\n            scale,\r\n            loop_x,\r\n            size_x,\r\n            step_b,\r\n            size_b,\r\n            use_bias,\r\n            use_ref\r\n        );\r\n    });\r\n\r\n    return y;\r\n}"
  },
  {
    "path": "op/upfirdn2d.cpp",
    "content": "#include <torch/extension.h>\r\n\r\n\r\ntorch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,\r\n                            int up_x, int up_y, int down_x, int down_y,\r\n                            int pad_x0, int pad_x1, int pad_y0, int pad_y1);\r\n\r\n#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\r\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\r\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\r\n\r\ntorch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,\r\n                        int up_x, int up_y, int down_x, int down_y,\r\n                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {\r\n    CHECK_CUDA(input);\r\n    CHECK_CUDA(kernel);\r\n\r\n    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);\r\n}\r\n\r\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\r\n    m.def(\"upfirdn2d\", &upfirdn2d, \"upfirdn2d (CUDA)\");\r\n}"
  },
  {
    "path": "op/upfirdn2d.py",
    "content": "import os\r\n\r\nimport torch\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import Function\r\nfrom torch.utils.cpp_extension import load\r\n\r\n\r\nmodule_path = os.path.dirname(__file__)\r\nupfirdn2d_op = load(\r\n    \"upfirdn2d\",\r\n    sources=[\r\n        os.path.join(module_path, \"upfirdn2d.cpp\"),\r\n        os.path.join(module_path, \"upfirdn2d_kernel.cu\"),\r\n    ],\r\n)\r\n\r\n\r\nclass UpFirDn2dBackward(Function):\r\n    @staticmethod\r\n    def forward(\r\n        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size\r\n    ):\r\n\r\n        up_x, up_y = up\r\n        down_x, down_y = down\r\n        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad\r\n\r\n        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)\r\n\r\n        grad_input = upfirdn2d_op.upfirdn2d(\r\n            grad_output,\r\n            grad_kernel,\r\n            down_x,\r\n            down_y,\r\n            up_x,\r\n            up_y,\r\n            g_pad_x0,\r\n            g_pad_x1,\r\n            g_pad_y0,\r\n            g_pad_y1,\r\n        )\r\n        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])\r\n\r\n        ctx.save_for_backward(kernel)\r\n\r\n        pad_x0, pad_x1, pad_y0, pad_y1 = pad\r\n\r\n        ctx.up_x = up_x\r\n        ctx.up_y = up_y\r\n        ctx.down_x = down_x\r\n        ctx.down_y = down_y\r\n        ctx.pad_x0 = pad_x0\r\n        ctx.pad_x1 = pad_x1\r\n        ctx.pad_y0 = pad_y0\r\n        ctx.pad_y1 = pad_y1\r\n        ctx.in_size = in_size\r\n        ctx.out_size = out_size\r\n\r\n        return grad_input\r\n\r\n    @staticmethod\r\n    def backward(ctx, gradgrad_input):\r\n        kernel, = ctx.saved_tensors\r\n\r\n        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)\r\n\r\n        gradgrad_out = upfirdn2d_op.upfirdn2d(\r\n            gradgrad_input,\r\n            kernel,\r\n            ctx.up_x,\r\n            
ctx.up_y,\r\n            ctx.down_x,\r\n            ctx.down_y,\r\n            ctx.pad_x0,\r\n            ctx.pad_x1,\r\n            ctx.pad_y0,\r\n            ctx.pad_y1,\r\n        )\r\n        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])\r\n        gradgrad_out = gradgrad_out.view(\r\n            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]\r\n        )\r\n\r\n        return gradgrad_out, None, None, None, None, None, None, None, None\r\n\r\n\r\nclass UpFirDn2d(Function):\r\n    @staticmethod\r\n    def forward(ctx, input, kernel, up, down, pad):\r\n        up_x, up_y = up\r\n        down_x, down_y = down\r\n        pad_x0, pad_x1, pad_y0, pad_y1 = pad\r\n\r\n        kernel_h, kernel_w = kernel.shape\r\n        batch, channel, in_h, in_w = input.shape\r\n        ctx.in_size = input.shape\r\n\r\n        input = input.reshape(-1, in_h, in_w, 1)\r\n\r\n        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))\r\n\r\n        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1\r\n        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1\r\n        ctx.out_size = (out_h, out_w)\r\n\r\n        ctx.up = (up_x, up_y)\r\n        ctx.down = (down_x, down_y)\r\n        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)\r\n\r\n        g_pad_x0 = kernel_w - pad_x0 - 1\r\n        g_pad_y0 = kernel_h - pad_y0 - 1\r\n        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1\r\n        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1\r\n\r\n        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)\r\n\r\n        out = upfirdn2d_op.upfirdn2d(\r\n            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1\r\n        )\r\n        # out = out.view(major, out_h, out_w, minor)\r\n        out = out.view(-1, channel, out_h, out_w)\r\n\r\n        return out\r\n\r\n    @staticmethod\r\n    def backward(ctx, grad_output):\r\n        
kernel, grad_kernel = ctx.saved_tensors\r\n\r\n        grad_input = UpFirDn2dBackward.apply(\r\n            grad_output,\r\n            kernel,\r\n            grad_kernel,\r\n            ctx.up,\r\n            ctx.down,\r\n            ctx.pad,\r\n            ctx.g_pad,\r\n            ctx.in_size,\r\n            ctx.out_size,\r\n        )\r\n\r\n        return grad_input, None, None, None, None\r\n\r\n\r\ndef upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):\r\n    if input.device.type == \"cpu\":\r\n        out = upfirdn2d_native(\r\n            input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]\r\n        )\r\n\r\n    else:\r\n        out = UpFirDn2d.apply(\r\n            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])\r\n        )\r\n\r\n    return out\r\n\r\n\r\ndef upfirdn2d_native(\r\n    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1\r\n):\r\n    _, channel, in_h, in_w = input.shape\r\n    input = input.reshape(-1, in_h, in_w, 1)\r\n\r\n    _, in_h, in_w, minor = input.shape\r\n    kernel_h, kernel_w = kernel.shape\r\n\r\n    out = input.view(-1, in_h, 1, in_w, 1, minor)\r\n    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])\r\n    out = out.view(-1, in_h * up_y, in_w * up_x, minor)\r\n\r\n    out = F.pad(\r\n        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]\r\n    )\r\n    out = out[\r\n        :,\r\n        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),\r\n        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),\r\n        :,\r\n    ]\r\n\r\n    out = out.permute(0, 3, 1, 2)\r\n    out = out.reshape(\r\n        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]\r\n    )\r\n    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)\r\n    out = F.conv2d(out, w)\r\n    out = out.reshape(\r\n        -1,\r\n        minor,\r\n        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,\r\n        in_w * up_x + pad_x0 + pad_x1 - 
kernel_w + 1,\r\n    )\r\n    out = out.permute(0, 2, 3, 1)\r\n    out = out[:, ::down_y, ::down_x, :]\r\n\r\n    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1\r\n    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1\r\n\r\n    return out.view(-1, channel, out_h, out_w)\r\n"
  },
  {
    "path": "op/upfirdn2d_kernel.cu",
    "content": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Source Code License-NC.\r\n// To view a copy of this license, visit\r\n// https://nvlabs.github.io/stylegan2/license.html\r\n\r\n#include <torch/types.h>\r\n\r\n#include <ATen/ATen.h>\r\n#include <ATen/AccumulateType.h>\r\n#include <ATen/cuda/CUDAApplyUtils.cuh>\r\n#include <ATen/cuda/CUDAContext.h>\r\n\r\n#include <cuda.h>\r\n#include <cuda_runtime.h>\r\n\r\nstatic __host__ __device__ __forceinline__ int floor_div(int a, int b) {\r\n  int c = a / b;\r\n\r\n  if (c * b > a) {\r\n    c--;\r\n  }\r\n\r\n  return c;\r\n}\r\n\r\nstruct UpFirDn2DKernelParams {\r\n  int up_x;\r\n  int up_y;\r\n  int down_x;\r\n  int down_y;\r\n  int pad_x0;\r\n  int pad_x1;\r\n  int pad_y0;\r\n  int pad_y1;\r\n\r\n  int major_dim;\r\n  int in_h;\r\n  int in_w;\r\n  int minor_dim;\r\n  int kernel_h;\r\n  int kernel_w;\r\n  int out_h;\r\n  int out_w;\r\n  int loop_major;\r\n  int loop_x;\r\n};\r\n\r\ntemplate <typename scalar_t>\r\n__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,\r\n                                       const scalar_t *kernel,\r\n                                       const UpFirDn2DKernelParams p) {\r\n  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;\r\n  int out_y = minor_idx / p.minor_dim;\r\n  minor_idx -= out_y * p.minor_dim;\r\n  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;\r\n  int major_idx_base = blockIdx.z * p.loop_major;\r\n\r\n  if (out_x_base >= p.out_w || out_y >= p.out_h ||\r\n      major_idx_base >= p.major_dim) {\r\n    return;\r\n  }\r\n\r\n  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;\r\n  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);\r\n  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;\r\n  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;\r\n\r\n  for (int loop_major = 0, major_idx = major_idx_base;\r\n   
    loop_major < p.loop_major && major_idx < p.major_dim;\r\n       loop_major++, major_idx++) {\r\n    for (int loop_x = 0, out_x = out_x_base;\r\n         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {\r\n      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;\r\n      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);\r\n      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;\r\n      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;\r\n\r\n      const scalar_t *x_p =\r\n          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +\r\n                 minor_idx];\r\n      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];\r\n      int x_px = p.minor_dim;\r\n      int k_px = -p.up_x;\r\n      int x_py = p.in_w * p.minor_dim;\r\n      int k_py = -p.up_y * p.kernel_w;\r\n\r\n      scalar_t v = 0.0f;\r\n\r\n      for (int y = 0; y < h; y++) {\r\n        for (int x = 0; x < w; x++) {\r\n          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);\r\n          x_p += x_px;\r\n          k_p += k_px;\r\n        }\r\n\r\n        x_p += x_py - w * x_px;\r\n        k_p += k_py - w * k_px;\r\n      }\r\n\r\n      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +\r\n          minor_idx] = v;\r\n    }\r\n  }\r\n}\r\n\r\ntemplate <typename scalar_t, int up_x, int up_y, int down_x, int down_y,\r\n          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>\r\n__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,\r\n                                 const scalar_t *kernel,\r\n                                 const UpFirDn2DKernelParams p) {\r\n  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;\r\n  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;\r\n\r\n  __shared__ volatile float sk[kernel_h][kernel_w];\r\n  __shared__ volatile float sx[tile_in_h][tile_in_w];\r\n\r\n  
int minor_idx = blockIdx.x;\r\n  int tile_out_y = minor_idx / p.minor_dim;\r\n  minor_idx -= tile_out_y * p.minor_dim;\r\n  tile_out_y *= tile_out_h;\r\n  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;\r\n  int major_idx_base = blockIdx.z * p.loop_major;\r\n\r\n  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |\r\n      major_idx_base >= p.major_dim) {\r\n    return;\r\n  }\r\n\r\n  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;\r\n       tap_idx += blockDim.x) {\r\n    int ky = tap_idx / kernel_w;\r\n    int kx = tap_idx - ky * kernel_w;\r\n    scalar_t v = 0.0;\r\n\r\n    if (kx < p.kernel_w & ky < p.kernel_h) {\r\n      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];\r\n    }\r\n\r\n    sk[ky][kx] = v;\r\n  }\r\n\r\n  for (int loop_major = 0, major_idx = major_idx_base;\r\n       loop_major < p.loop_major & major_idx < p.major_dim;\r\n       loop_major++, major_idx++) {\r\n    for (int loop_x = 0, tile_out_x = tile_out_x_base;\r\n         loop_x < p.loop_x & tile_out_x < p.out_w;\r\n         loop_x++, tile_out_x += tile_out_w) {\r\n      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;\r\n      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;\r\n      int tile_in_x = floor_div(tile_mid_x, up_x);\r\n      int tile_in_y = floor_div(tile_mid_y, up_y);\r\n\r\n      __syncthreads();\r\n\r\n      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;\r\n           in_idx += blockDim.x) {\r\n        int rel_in_y = in_idx / tile_in_w;\r\n        int rel_in_x = in_idx - rel_in_y * tile_in_w;\r\n        int in_x = rel_in_x + tile_in_x;\r\n        int in_y = rel_in_y + tile_in_y;\r\n\r\n        scalar_t v = 0.0;\r\n\r\n        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {\r\n          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *\r\n                        p.minor_dim +\r\n                    minor_idx];\r\n        }\r\n\r\n        sx[rel_in_y][rel_in_x] = 
v;\r\n      }\r\n\r\n      __syncthreads();\r\n      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;\r\n           out_idx += blockDim.x) {\r\n        int rel_out_y = out_idx / tile_out_w;\r\n        int rel_out_x = out_idx - rel_out_y * tile_out_w;\r\n        int out_x = rel_out_x + tile_out_x;\r\n        int out_y = rel_out_y + tile_out_y;\r\n\r\n        int mid_x = tile_mid_x + rel_out_x * down_x;\r\n        int mid_y = tile_mid_y + rel_out_y * down_y;\r\n        int in_x = floor_div(mid_x, up_x);\r\n        int in_y = floor_div(mid_y, up_y);\r\n        int rel_in_x = in_x - tile_in_x;\r\n        int rel_in_y = in_y - tile_in_y;\r\n        int kernel_x = (in_x + 1) * up_x - mid_x - 1;\r\n        int kernel_y = (in_y + 1) * up_y - mid_y - 1;\r\n\r\n        scalar_t v = 0.0;\r\n\r\n#pragma unroll\r\n        for (int y = 0; y < kernel_h / up_y; y++)\r\n#pragma unroll\r\n          for (int x = 0; x < kernel_w / up_x; x++)\r\n            v += sx[rel_in_y + y][rel_in_x + x] *\r\n                 sk[kernel_y + y * up_y][kernel_x + x * up_x];\r\n\r\n        if (out_x < p.out_w & out_y < p.out_h) {\r\n          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +\r\n              minor_idx] = v;\r\n        }\r\n      }\r\n    }\r\n  }\r\n}\r\n\r\ntorch::Tensor upfirdn2d_op(const torch::Tensor &input,\r\n                           const torch::Tensor &kernel, int up_x, int up_y,\r\n                           int down_x, int down_y, int pad_x0, int pad_x1,\r\n                           int pad_y0, int pad_y1) {\r\n  int curDevice = -1;\r\n  cudaGetDevice(&curDevice);\r\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);\r\n\r\n  UpFirDn2DKernelParams p;\r\n\r\n  auto x = input.contiguous();\r\n  auto k = kernel.contiguous();\r\n\r\n  p.major_dim = x.size(0);\r\n  p.in_h = x.size(1);\r\n  p.in_w = x.size(2);\r\n  p.minor_dim = x.size(3);\r\n  p.kernel_h = k.size(0);\r\n  p.kernel_w = k.size(1);\r\n  p.up_x = up_x;\r\n  
p.up_y = up_y;\r\n  p.down_x = down_x;\r\n  p.down_y = down_y;\r\n  p.pad_x0 = pad_x0;\r\n  p.pad_x1 = pad_x1;\r\n  p.pad_y0 = pad_y0;\r\n  p.pad_y1 = pad_y1;\r\n\r\n  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /\r\n            p.down_y;\r\n  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /\r\n            p.down_x;\r\n\r\n  auto out =\r\n      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());\r\n\r\n  int mode = -1;\r\n\r\n  int tile_out_h = -1;\r\n  int tile_out_w = -1;\r\n\r\n  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&\r\n      p.kernel_h <= 4 && p.kernel_w <= 4) {\r\n    mode = 1;\r\n    tile_out_h = 16;\r\n    tile_out_w = 64;\r\n  }\r\n\r\n  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&\r\n      p.kernel_h <= 3 && p.kernel_w <= 3) {\r\n    mode = 2;\r\n    tile_out_h = 16;\r\n    tile_out_w = 64;\r\n  }\r\n\r\n  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&\r\n      p.kernel_h <= 4 && p.kernel_w <= 4) {\r\n    mode = 3;\r\n    tile_out_h = 16;\r\n    tile_out_w = 64;\r\n  }\r\n\r\n  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&\r\n      p.kernel_h <= 2 && p.kernel_w <= 2) {\r\n    mode = 4;\r\n    tile_out_h = 16;\r\n    tile_out_w = 64;\r\n  }\r\n\r\n  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&\r\n      p.kernel_h <= 4 && p.kernel_w <= 4) {\r\n    mode = 5;\r\n    tile_out_h = 8;\r\n    tile_out_w = 32;\r\n  }\r\n\r\n  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&\r\n      p.kernel_h <= 2 && p.kernel_w <= 2) {\r\n    mode = 6;\r\n    tile_out_h = 8;\r\n    tile_out_w = 32;\r\n  }\r\n\r\n  dim3 block_size;\r\n  dim3 grid_size;\r\n\r\n  if (tile_out_h > 0 && tile_out_w > 0) {\r\n    p.loop_major = (p.major_dim - 1) / 16384 + 1;\r\n    p.loop_x = 1;\r\n    block_size = dim3(32 * 8, 1, 1);\r\n    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * 
p.minor_dim,\r\n                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,\r\n                     (p.major_dim - 1) / p.loop_major + 1);\r\n  } else {\r\n    p.loop_major = (p.major_dim - 1) / 16384 + 1;\r\n    p.loop_x = 4;\r\n    block_size = dim3(4, 32, 1);\r\n    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,\r\n                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,\r\n                     (p.major_dim - 1) / p.loop_major + 1);\r\n  }\r\n\r\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"upfirdn2d_cuda\", [&] {\r\n    switch (mode) {\r\n    case 1:\r\n      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 2:\r\n      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 3:\r\n      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 4:\r\n      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 5:\r\n      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 
4, 8, 32>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 6:\r\n      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    default:\r\n      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(\r\n          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),\r\n          k.data_ptr<scalar_t>(), p);\r\n    }\r\n  });\r\n\r\n  return out;\r\n}"
  },
  {
    "path": "physics/ct.py",
    "content": "import torch\nimport numpy as np\nfrom .radon import Radon, IRadon\n\nclass CT():\n    def __init__(self, img_width, radon_view, uniform=True, circle=False, device='cuda:0'):\n        if uniform:\n            theta = np.linspace(0, 180, radon_view, endpoint=False)\n            theta_all = np.linspace(0, 180, 180, endpoint=False)\n        else:\n            theta = torch.arange(radon_view)\n            theta_all = torch.arange(radon_view)\n\n        self.radon = Radon(img_width, theta, circle).to(device)\n        self.radon_all = Radon(img_width, theta_all, circle).to(device)\n        self.iradon_all = IRadon(img_width, theta_all, circle).to(device)\n        self.iradon = IRadon(img_width, theta, circle).to(device)\n        self.radont = IRadon(img_width, theta, circle, use_filter=None).to(device)\n\n    def A(self, x):\n        return self.radon(x)\n\n    def A_all(self, x):\n        return self.radon_all(x)\n\n    def A_all_dagger(self, x):\n        return self.iradon_all(x)\n\n    def A_dagger(self, y):\n        return self.iradon(y)\n\n    def AT(self, y):\n        return self.radont(y)\n\n\nclass CT_LA():\n    \"\"\"\n    Limited Angle tomography\n    \"\"\"\n    def __init__(self, img_width, radon_view, uniform=True, circle=False, device='cuda:0'):\n        if uniform:\n            theta = np.linspace(0, 180, radon_view, endpoint=False)\n        else:\n            theta = torch.arange(radon_view)\n        self.radon = Radon(img_width, theta, circle).to(device)\n        self.iradon = IRadon(img_width, theta, circle).to(device)\n        self.radont = IRadon(img_width, theta, circle, use_filter=None).to(device)\n\n    def A(self, x):\n        return self.radon(x)\n\n    def A_dagger(self, y):\n        return self.iradon(y)\n\n    def AT(self, y):\n        return self.radont(y)\n"
  },
  {
    "path": "physics/inpainting.py",
    "content": "import os\nimport torch\n\nclass Inpainting():\n    def __init__(self, img_heigth=512, img_width=512, mode='random', mask_rate=0.3, resize=False, device='cuda:0'):\n        mask_path = './physics/mask_random{}.pt'.format(mask_rate)\n        if os.path.exists(mask_path):\n            self.mask = torch.load(mask_path).to(device)\n        else:\n            self.mask = torch.ones(img_heigth, img_width, device=device)\n            self.mask[torch.rand_like(self.mask) > 1 - mask_rate] = 0\n            torch.save(self.mask, mask_path)\n\n    def A(self, x):\n        return torch.einsum('kl,ijkl->ijkl', self.mask, x)\n\n    def A_dagger(self, x):\n        return torch.einsum('kl,ijkl->ijkl', self.mask, x)\n"
  },
  {
    "path": "physics/radon/__init__.py",
    "content": "from .radon import Radon, IRadon\nfrom .stackgram import Stackgram, IStackgram"
  },
  {
    "path": "physics/radon/filters.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom .utils import PI, fftfreq\n\n'''source: https://github.com/matteo-ronchetti/torch-radon'''\n\nclass AbstractFilter(nn.Module):\n    def __init__(self):\n        super(AbstractFilter, self).__init__()\n\n    def forward(self, x):\n        input_size = x.shape[2]\n        projection_size_padded = \\\n            max(64, int(2 ** (2 * torch.tensor(input_size)).float().log2().ceil()))\n        pad_width = projection_size_padded - input_size\n        padded_tensor = F.pad(x, (0,0,0,pad_width))\n        f = self._get_fourier_filter(padded_tensor.shape[2]).to(x.device)\n        fourier_filter = self.create_filter(f)[..., None]\n        projection = torch.fft.fft(padded_tensor, dim=2) * fourier_filter\n        return torch.real(torch.fft.ifft(projection, dim=2)[:,:,:input_size,:])\n\n    def _get_fourier_filter(self, size):\n        n = torch.cat([\n            torch.arange(1, size / 2 + 1, 2),\n            torch.arange(size / 2 - 1, 0, -2)\n        ])\n\n        f = torch.zeros(size)\n        f[0] = 0.25\n        f[1::2] = -1 / (PI * n) ** 2\n\n        fourier_filter = torch.fft.fft(f)\n\n        return 2*fourier_filter\n\n    def create_filter(self, f):\n        raise NotImplementedError\n\nclass RampFilter(AbstractFilter):\n    def __init__(self):\n        super(RampFilter, self).__init__()\n\n    def create_filter(self, f):\n        return f\n\nclass HannFilter(AbstractFilter):\n    def __init__(self):\n        super(HannFilter, self).__init__()\n\n    def create_filter(self, f):\n        n = torch.arange(0, f.shape[0])\n        hann = 0.5 - 0.5*(2.0*PI*n/(f.shape[0]-1)).cos()\n        return f*hann.roll(hann.shape[0]//2,0).unsqueeze(-1)\n\nclass LearnableFilter(AbstractFilter):\n    def __init__(self, filter_size):\n        super(LearnableFilter, self).__init__()\n        self.filter = nn.Parameter(2*fftfreq(filter_size).abs().view(-1, 1))\n\n    def forward(self, x):\n        
fourier_filter = self.filter.unsqueeze(-1).repeat(1,1,2).to(x.device)\n        projection = torch.rfft(x.transpose(2,3), 1, onesided=False).transpose(2,3) * fourier_filter\n        return torch.irfft(projection.transpose(2,3), 1, onesided=False).transpose(2,3)\n\n        # projection = torch.fft.rfft(x.transpose(2, 3), 1).transpose(2, 3) * fourier_filter\n        # return torch.fft.irfft(projection.transpose(2, 3), 1).transpose(2, 3)"
  },
  {
    "path": "physics/radon/radon.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom physics.radon.filters import RampFilter\nfrom physics.radon.utils import PI, SQRT2, deg2rad, affine_grid, grid_sample\n\n'''source: https://github.com/matteo-ronchetti/torch-radon'''\n\n\nclass Radon(nn.Module):\n    def __init__(self, in_size=None, theta=None, circle=True, dtype=torch.float):\n        super(Radon, self).__init__()\n        self.circle = circle\n        self.theta = theta\n        if theta is None:\n            self.theta = torch.arange(180)\n        self.dtype = dtype\n        self.all_grids = None\n        if in_size is not None:\n            self.all_grids = self._create_grids(self.theta, in_size, circle)\n\n    def forward(self, x):\n        N, C, W, H = x.shape\n        assert (W == H)\n\n        if self.all_grids is None:\n            self.all_grids = self._create_grids(self.theta, W, self.circle)\n\n        if not self.circle:\n            diagonal = SQRT2 * W\n            pad = int((diagonal - W).ceil())\n            new_center = (W + pad) // 2\n            old_center = W // 2\n            pad_before = new_center - old_center\n            pad_width = (pad_before, pad - pad_before)\n            x = F.pad(x, (pad_width[0], pad_width[1], pad_width[0], pad_width[1]))\n\n        N, C, W, _ = x.shape\n        out = torch.zeros(N, C, W, len(self.theta), device=x.device, dtype=self.dtype)\n\n        for i in range(len(self.theta)):\n            rotated = grid_sample(x, self.all_grids[i].repeat(N, 1, 1, 1).to(x.device))\n            out[..., i] = rotated.sum(2)\n\n        return out\n\n    def _create_grids(self, angles, grid_size, circle):\n        if not circle:\n            grid_size = int((SQRT2 * grid_size).ceil())\n        all_grids = []\n        for theta in angles:\n            theta = deg2rad(theta)\n            R = torch.tensor([[\n                [theta.cos(), theta.sin(), 0],\n                [-theta.sin(), theta.cos(), 0],\n            ]], 
dtype=self.dtype)\n            all_grids.append(affine_grid(R, torch.Size([1, 1, grid_size, grid_size])))\n        return all_grids\n\n\nclass IRadon(nn.Module):\n    def __init__(self, in_size=None, theta=None, circle=True,\n                 use_filter=RampFilter(), out_size=None, dtype=torch.float):\n        super(IRadon, self).__init__()\n        self.circle = circle\n        self.theta = theta if theta is not None else torch.arange(180)\n        self.out_size = out_size\n        self.in_size = in_size\n        self.dtype = dtype\n        self.ygrid, self.xgrid, self.all_grids = None, None, None\n        if in_size is not None:\n            self.ygrid, self.xgrid = self._create_yxgrid(in_size, circle)\n            self.all_grids = self._create_grids(self.theta, in_size, circle)\n        self.filter = use_filter if use_filter is not None else lambda x: x\n\n    def forward(self, x):\n        it_size = x.shape[2]\n        ch_size = x.shape[1]\n\n        if self.in_size is None:\n            self.in_size = int((it_size / SQRT2).floor()) if not self.circle else it_size\n        # if None in [self.ygrid, self.xgrid, self.all_grids]:\n        if self.ygrid is None or self.xgrid is None or self.all_grids is None :\n            self.ygrid, self.xgrid = self._create_yxgrid(self.in_size, self.circle)\n            self.all_grids = self._create_grids(self.theta, self.in_size, self.circle)\n\n        # sinogram\n        x = self.filter(x)\n\n        reco = torch.zeros(x.shape[0], ch_size, it_size, it_size, device=x.device, dtype=self.dtype)\n        for i_theta in range(len(self.theta)):\n            reco += grid_sample(x, self.all_grids[i_theta].repeat(reco.shape[0], 1, 1, 1).to(x.device))\n\n        if not self.circle:\n            W = self.in_size\n            diagonal = it_size\n            pad = int(torch.tensor(diagonal - W, dtype=torch.float).ceil())\n            new_center = (W + pad) // 2\n            old_center = W // 2\n            pad_before = new_center - 
old_center\n            pad_width = (pad_before, pad - pad_before)\n            reco = F.pad(reco, (-pad_width[0], -pad_width[1], -pad_width[0], -pad_width[1]))\n\n        if self.circle:\n            reconstruction_circle = (self.xgrid ** 2 + self.ygrid ** 2) <= 1\n            reconstruction_circle = reconstruction_circle.repeat(x.shape[0], ch_size, 1, 1)\n            reco[~reconstruction_circle] = 0.\n\n        reco = reco * PI.item() / (2 * len(self.theta))\n\n        if self.out_size is not None:\n            pad = (self.out_size - self.in_size) // 2\n            reco = F.pad(reco, (pad, pad, pad, pad))\n\n        return reco\n\n    def _create_yxgrid(self, in_size, circle):\n        if not circle:\n            in_size = int((SQRT2 * in_size).ceil())\n        unitrange = torch.linspace(-1, 1, in_size, dtype=self.dtype)\n        return torch.meshgrid(unitrange, unitrange)\n\n    def _XYtoT(self, theta):\n        T = self.xgrid * (deg2rad(theta)).cos() - self.ygrid * (deg2rad(theta)).sin()\n        return T\n\n    def _create_grids(self, angles, grid_size, circle):\n        if not circle:\n            grid_size = int((SQRT2 * grid_size).ceil())\n        all_grids = []\n        for i_theta in range(len(angles)):\n            X = torch.ones(grid_size, dtype=self.dtype).view(-1, 1).repeat(1, grid_size) * i_theta * 2. / (\n                        len(angles) - 1) - 1.\n            Y = self._XYtoT(angles[i_theta])\n            all_grids.append(torch.cat((X.unsqueeze(-1), Y.unsqueeze(-1)), dim=-1).unsqueeze(0))\n        return all_grids\n\n\nif __name__ == '__main__':\n    img_width = 2\n    num_proj = 180\n    device = 'cuda:0'\n    radon = Radon(in_size=img_width, theta=torch.arange(num_proj), circle=False).to(device)\n    iradon = IRadon(in_size=img_width, theta=torch.arange(num_proj), circle=False).to(device)\n\n    img = torch.randn([1, 1, 2, 2]).to(device)\n    sinogram = radon(img)\n    b_img = iradon(sinogram)\n"
  },
  {
    "path": "physics/radon/stackgram.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom .utils import SQRT2, deg2rad, affine_grid, grid_sample\n\n'''source: https://github.com/matteo-ronchetti/torch-radon'''\n\nclass Stackgram(nn.Module):\n    def __init__(self, out_size, theta=None, circle=True, mode='nearest', dtype=torch.float):\n        super(Stackgram, self).__init__()\n        self.circle = circle\n        self.theta = theta\n        if theta is None:\n            self.theta = torch.arange(180)\n        self.out_size = out_size\n        self.in_size = in_size = out_size if circle else int((SQRT2*out_size).ceil())\n        self.dtype = dtype\n        self.all_grids = self._create_grids(self.theta, in_size)\n        self.mode = mode\n\n    def forward(self, x):\n        stackgram = torch.zeros(x.shape[0], len(self.theta), self.in_size, self.in_size, device=x.device, dtype=self.dtype)\n\n        for i_theta in range(len(self.theta)):\n            repline = x[...,i_theta]\n            repline = repline.unsqueeze(-1).repeat(1,1,1,repline.shape[2])\n            linogram = grid_sample(repline, self.all_grids[i_theta].repeat(x.shape[0],1,1,1).to(x.device), mode=self.mode)\n            stackgram[:,i_theta] = linogram\n\n        return stackgram\n\n    def _create_grids(self, angles, grid_size):\n        all_grids = []\n        for i_theta in range(len(angles)):\n            t = deg2rad(angles[i_theta])\n            R = torch.tensor([[t.sin(), t.cos(), 0.],[t.cos(), -t.sin(), 0.]], dtype=self.dtype).unsqueeze(0)\n            all_grids.append(affine_grid(R, torch.Size([1,1,grid_size,grid_size])))\n        return all_grids\n\nclass IStackgram(nn.Module):\n    def __init__(self, out_size, theta=None, circle=True, mode='bilinear', dtype=torch.float):\n        super(IStackgram, self).__init__()\n        self.circle = circle\n        self.theta = theta\n        if theta is None:\n            self.theta = torch.arange(180)\n        self.out_size = out_size\n        
self.in_size = in_size = out_size if circle else int((SQRT2*out_size).ceil())\n        self.dtype = dtype\n        self.all_grids = self._create_grids(self.theta, in_size)\n        self.mode = mode\n\n    def forward(self, x):\n        sinogram = torch.zeros(x.shape[0], 1, self.in_size, len(self.theta), device=x.device, dtype=self.dtype)\n\n        for i_theta in range(len(self.theta)):\n            linogram = x[:,i_theta].unsqueeze(1)\n            repline = grid_sample(linogram, self.all_grids[i_theta].repeat(x.shape[0],1,1,1).to(x.device), mode=self.mode)\n            repline = repline[...,repline.shape[-1]//2]\n            sinogram[...,i_theta] = repline\n\n        return sinogram\n\n    def _create_grids(self, angles, grid_size):\n        all_grids = []\n        for i_theta in range(len(angles)):\n            t = deg2rad(angles[i_theta])\n            R = torch.tensor([[t.sin(), t.cos(), 0.],[t.cos(), -t.sin(), 0.]], dtype=self.dtype).unsqueeze(0)\n            all_grids.append(affine_grid(R, torch.Size([1,1,grid_size,grid_size])))\n        return all_grids"
  },
  {
    "path": "physics/radon/utils.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n'''source: https://github.com/matteo-ronchetti/torch-radon'''\n\nif torch.__version__>'1.2.0':\n    affine_grid = lambda theta, size: F.affine_grid(theta, size, align_corners=True)\n    grid_sample = lambda input, grid, mode='bilinear': F.grid_sample(input, grid, align_corners=True, mode=mode)\nelse:\n    affine_grid = F.affine_grid\n    grid_sample = F.grid_sample\n\n# constants\nPI = 4*torch.ones(1).atan()\nSQRT2 = (2*torch.ones(1)).sqrt()\n\ndef fftfreq(n):\n    val = 1.0/n\n    results = torch.zeros(n)\n    N = (n-1)//2 + 1\n    p1 = torch.arange(0, N)\n    results[:N] = p1\n    p2 = torch.arange(-(n//2), 0)\n    results[N:] = p2\n    return results*val\n\ndef deg2rad(x):\n    return x*PI/180"
  },
  {
    "path": "run_lib.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n\"\"\"Training and evaluation for score-based generative models. \"\"\"\n\nimport gc\nimport io\nimport os\nimport time\nfrom pathlib import Path\n\nimport numpy as np\nimport logging\n# Keep the import below for registering all model definitions\nfrom models import ddpm, ncsnv2, ncsnpp, unet\nimport losses\nimport sampling\nfrom models import utils as mutils\nfrom models.ema import ExponentialMovingAverage\nimport datasets\n#import evaluation\nimport likelihood\nimport sde_lib\nfrom absl import flags\nimport torch\nfrom torch import nn\nfrom torch.utils import tensorboard\nfrom torchvision.utils import make_grid, save_image\nfrom utils import save_checkpoint, restore_checkpoint, get_mask, kspace_to_nchw, root_sum_of_squares\n\nFLAGS = flags.FLAGS\n\n\ndef train(config, workdir):\n  \"\"\"Runs the training pipeline.\n\n  Args:\n    config: Configuration to use.\n    workdir: Working directory for checkpoints and TF summaries. 
If this\n      contains checkpoint training will be resumed from the latest checkpoint.\n  \"\"\"\n\n  # Create directories for experimental logs\n  sample_dir = os.path.join(workdir, \"samples\")\n  Path(sample_dir).mkdir(parents=True, exist_ok=True)\n\n  tb_dir = os.path.join(workdir, \"tensorboard\")\n  Path(tb_dir).mkdir(parents=True, exist_ok=True)\n  writer = tensorboard.SummaryWriter(tb_dir)\n\n  # Initialize model.\n  score_model = mutils.create_model(config)\n  ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)\n  optimizer = losses.get_optimizer(config, score_model.parameters())\n  state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)\n\n  # Create checkpoints directory\n  checkpoint_dir = os.path.join(workdir, \"checkpoints\")\n  checkpoint_meta_dir = os.path.join(workdir, \"checkpoints-meta\")\n  Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)\n  Path(checkpoint_meta_dir).mkdir(parents=True, exist_ok=True)\n\n  # Resume training when intermediate checkpoints are detected\n  state = restore_checkpoint(checkpoint_meta_dir, state, config.device)\n  initial_step = int(state['step'])\n\n  # Build pytorch dataloader for training\n  train_dl, eval_dl = datasets.create_dataloader(config)\n  num_data = len(train_dl.dataset)\n\n  # Create data normalizer and its inverse\n  scaler = datasets.get_data_scaler(config)\n  inverse_scaler = datasets.get_data_inverse_scaler(config)\n\n  # Setup SDEs\n  if config.training.sde.lower() == 'vpsde':\n    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)\n    sampling_eps = 1e-3\n  elif config.training.sde.lower() == 'subvpsde':\n    sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)\n    sampling_eps = 1e-3\n  elif config.training.sde.lower() == 'vesde':\n    sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, 
N=config.model.num_scales)\n    sampling_eps = 1e-5\n  else:\n    raise NotImplementedError(f\"SDE {config.training.sde} unknown.\")\n\n  # Build one-step training and evaluation functions\n  optimize_fn = losses.optimization_manager(config)\n  continuous = config.training.continuous\n  reduce_mean = config.training.reduce_mean\n  likelihood_weighting = config.training.likelihood_weighting\n  train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn,\n                                     reduce_mean=reduce_mean, continuous=continuous,\n                                     likelihood_weighting=likelihood_weighting)\n  eval_step_fn = losses.get_step_fn(sde, train=False, optimize_fn=optimize_fn,\n                                    reduce_mean=reduce_mean, continuous=continuous,\n                                    likelihood_weighting=likelihood_weighting)\n\n  # Building sampling functions\n  if config.training.snapshot_sampling:\n    sampling_shape = (config.training.batch_size, config.data.num_channels,\n                      config.data.image_size, config.data.image_size)\n    sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)\n\n  # In case there are multiple hosts (e.g., TPU pods), only log to host 0\n  logging.info(\"Starting training loop at step %d.\" % (initial_step,))\n\n  for epoch in range(1, config.training.epochs):\n    print('=================================================')\n    print(f'Epoch: {epoch}')\n    print('=================================================')\n\n    for step, batch in enumerate(train_dl, start=1):\n      batch = scaler(batch.to(config.device))\n      # (b, 1, 320, 320, 2) --> (b, 2, 320, 320)\n      # batch = kspace_to_nchw(torch.view_as_real(batch))\n      # Execute one training step\n      loss = train_step_fn(state, batch)\n      if step % config.training.log_freq == 0:\n        logging.info(\"step: %d, training_loss: %.5e\" % (step, loss.item()))\n        
global_step = num_data * epoch + step\n        writer.add_scalar(\"training_loss\", scalar_value=loss, global_step=global_step)\n      if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:\n        save_checkpoint(checkpoint_meta_dir, state)\n      # Report the loss on an evaluation dataset periodically\n      # if step % config.training.eval_freq == 0:\n      #   eval_batch = scaler(next(iter(eval_dl)).to(config.device))\n      #   eval_loss = eval_step_fn(state, eval_batch)\n      #   logging.info(\"step: %d, eval_loss: %.5e\" % (step, eval_loss.item()))\n      #   global_step = num_data * epoch + step\n      #   writer.add_scalar(\"eval_loss\", scalar_value=eval_loss.item(), global_step=global_step)\n\n    # Save a checkpoint for every epoch\n    save_checkpoint(checkpoint_dir, state, name=f'checkpoint_{epoch}.pth')\n\n    # Generate and save samples for every epoch\n    if config.training.snapshot_sampling:\n      print('sampling')\n      ema.store(score_model.parameters())\n      ema.copy_to(score_model.parameters())\n      sample, n = sampling_fn(score_model)\n      if config.data.is_complex:\n        sample = root_sum_of_squares(sample, dim=1).unsqueeze(dim=0)\n      ema.restore(score_model.parameters())\n      this_sample_dir = os.path.join(sample_dir, \"iter_{}\".format(epoch))\n      Path(this_sample_dir).mkdir(parents=True, exist_ok=True)\n      nrow = int(np.sqrt(sample.shape[0]))\n      image_grid = make_grid(sample, nrow, padding=2)\n      sample = np.clip(sample.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)\n      np.save(os.path.join(this_sample_dir, \"sample\"), sample)\n      save_image(image_grid, os.path.join(this_sample_dir, \"sample.png\"))\n\n\ndef evaluate(config,\n             workdir,\n             eval_folder=\"eval\"):\n  \"\"\"Evaluate trained models.\n\n  Args:\n    config: Configuration to use.\n    workdir: Working directory for checkpoints.\n    eval_folder: The subfolder for storing 
evaluation results. Default to\n      \"eval\".\n  \"\"\"\n  # Create directory to eval_folder\n  eval_dir = os.path.join(workdir, eval_folder)\n  Path(eval_dir).mkdir(parents=True, exist_ok=True)\n\n  # Build pytorch dataloader for training\n  train_dl, eval_dl = datasets.create_dataloader(config)\n  num_data = len(train_dl.dataset)\n\n  # Create data normalizer and its inverse\n  scaler = datasets.get_data_scaler(config)\n  inverse_scaler = datasets.get_data_inverse_scaler(config)\n\n  # Initialize model\n  score_model = mutils.create_model(config)\n  optimizer = losses.get_optimizer(config, score_model.parameters())\n  ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)\n  state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)\n\n  checkpoint_dir = os.path.join(workdir, \"checkpoints\")\n\n  # Setup SDEs\n  if config.training.sde.lower() == 'vpsde':\n    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)\n    sampling_eps = 1e-3\n  elif config.training.sde.lower() == 'subvpsde':\n    sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)\n    sampling_eps = 1e-3\n  elif config.training.sde.lower() == 'vesde':\n    sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)\n    sampling_eps = 1e-5\n  else:\n    raise NotImplementedError(f\"SDE {config.training.sde} unknown.\")\n\n  # Create the one-step evaluation function when loss computation is enabled\n  if config.eval.enable_loss:\n    optimize_fn = losses.optimization_manager(config)\n    continuous = config.training.continuous\n    likelihood_weighting = config.training.likelihood_weighting\n\n    reduce_mean = config.training.reduce_mean\n    eval_step = losses.get_step_fn(sde, train=False, optimize_fn=optimize_fn,\n                                   reduce_mean=reduce_mean,\n          
                         continuous=continuous,\n                                   likelihood_weighting=likelihood_weighting)\n\n\n  # Create data loaders for likelihood evaluation. Only evaluate on uniformly dequantized data\n  train_ds_bpd, eval_ds_bpd, _ = datasets.get_dataset(config,\n                                                      uniform_dequantization=True, evaluation=True)\n  if config.eval.bpd_dataset.lower() == 'train':\n    ds_bpd = train_ds_bpd\n    bpd_num_repeats = 1\n  elif config.eval.bpd_dataset.lower() == 'test':\n    # Go over the dataset 5 times when computing likelihood on the test dataset\n    ds_bpd = eval_ds_bpd\n    bpd_num_repeats = 5\n  else:\n    raise ValueError(f\"No bpd dataset {config.eval.bpd_dataset} recognized.\")\n\n  # Build the likelihood computation function when likelihood is enabled\n  if config.eval.enable_bpd:\n    likelihood_fn = likelihood.get_likelihood_fn(sde, inverse_scaler)\n\n  # Build the sampling function when sampling is enabled\n  if config.eval.enable_sampling:\n    sampling_shape = (config.eval.batch_size,\n                      config.data.num_channels,\n                      config.data.image_size, config.data.image_size)\n    sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)\n\n  # Use inceptionV3 for images with resolution higher than 256.\n  inceptionv3 = config.data.image_size >= 256\n  inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)\n\n  begin_ckpt = config.eval.begin_ckpt\n  logging.info(\"begin checkpoint: %d\" % (begin_ckpt,))\n  for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):\n    # Wait if the target checkpoint doesn't exist yet\n    waiting_message_printed = False\n    ckpt_filename = os.path.join(checkpoint_dir, \"checkpoint_{}.pth\".format(ckpt))\n    while not tf.io.gfile.exists(ckpt_filename):\n      if not waiting_message_printed:\n        logging.warning(\"Waiting for the arrival of checkpoint_%d\" % 
(ckpt,))\n        waiting_message_printed = True\n      time.sleep(60)\n\n    # Wait for 2 additional mins in case the file exists but is not ready for reading\n    ckpt_path = os.path.join(checkpoint_dir, f'checkpoint_{ckpt}.pth')\n    try:\n      state = restore_checkpoint(ckpt_path, state, device=config.device)\n    except:\n      time.sleep(60)\n      try:\n        state = restore_checkpoint(ckpt_path, state, device=config.device)\n      except:\n        time.sleep(120)\n        state = restore_checkpoint(ckpt_path, state, device=config.device)\n    ema.copy_to(score_model.parameters())\n    # Compute the loss function on the full evaluation dataset if loss computation is enabled\n    if config.eval.enable_loss:\n      all_losses = []\n      eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types\n      for i, batch in enumerate(eval_iter):\n        eval_batch = torch.from_numpy(batch['image']._numpy()).to(config.device).float()\n        eval_batch = eval_batch.permute(0, 3, 1, 2)\n        eval_batch = scaler(eval_batch)\n        eval_loss = eval_step(state, eval_batch)\n        all_losses.append(eval_loss.item())\n        if (i + 1) % 1000 == 0:\n          logging.info(\"Finished %dth step loss evaluation\" % (i + 1))\n\n      # Save loss values to disk or Google Cloud Storage\n      all_losses = np.asarray(all_losses)\n      with tf.io.gfile.GFile(os.path.join(eval_dir, f\"ckpt_{ckpt}_loss.npz\"), \"wb\") as fout:\n        io_buffer = io.BytesIO()\n        np.savez_compressed(io_buffer, all_losses=all_losses, mean_loss=all_losses.mean())\n        fout.write(io_buffer.getvalue())\n\n    # Compute log-likelihoods (bits/dim) if enabled\n    if config.eval.enable_bpd:\n      bpds = []\n      for repeat in range(bpd_num_repeats):\n        bpd_iter = iter(ds_bpd)  # pytype: disable=wrong-arg-types\n        for batch_id in range(len(ds_bpd)):\n          batch = next(bpd_iter)\n          eval_batch = 
torch.from_numpy(batch['image']._numpy()).to(config.device).float()\n          eval_batch = eval_batch.permute(0, 3, 1, 2)\n          eval_batch = scaler(eval_batch)\n          bpd = likelihood_fn(score_model, eval_batch)[0]\n          bpd = bpd.detach().cpu().numpy().reshape(-1)\n          bpds.extend(bpd)\n          logging.info(\n            \"ckpt: %d, repeat: %d, batch: %d, mean bpd: %6f\" % (ckpt, repeat, batch_id, np.mean(np.asarray(bpds))))\n          bpd_round_id = batch_id + len(ds_bpd) * repeat\n          # Save bits/dim to disk or Google Cloud Storage\n          with tf.io.gfile.GFile(os.path.join(eval_dir,\n                                              f\"{config.eval.bpd_dataset}_ckpt_{ckpt}_bpd_{bpd_round_id}.npz\"),\n                                 \"wb\") as fout:\n            io_buffer = io.BytesIO()\n            np.savez_compressed(io_buffer, bpd)\n            fout.write(io_buffer.getvalue())\n\n    # Generate samples and compute IS/FID/KID when enabled\n    if config.eval.enable_sampling:\n      num_sampling_rounds = config.eval.num_samples // config.eval.batch_size + 1\n      for r in range(num_sampling_rounds):\n        logging.info(\"sampling -- ckpt: %d, round: %d\" % (ckpt, r))\n\n        # Directory to save samples. 
Different for each host to avoid writing conflicts\n        this_sample_dir = os.path.join(\n          eval_dir, f\"ckpt_{ckpt}\")\n        tf.io.gfile.makedirs(this_sample_dir)\n        samples, n = sampling_fn(score_model)\n        samples = np.clip(samples.permute(0, 2, 3, 1).cpu().numpy() * 255., 0, 255).astype(np.uint8)\n        samples = samples.reshape(\n          (-1, config.data.image_size, config.data.image_size, config.data.num_channels))\n        # Write samples to disk or Google Cloud Storage\n        with tf.io.gfile.GFile(\n            os.path.join(this_sample_dir, f\"samples_{r}.npz\"), \"wb\") as fout:\n          io_buffer = io.BytesIO()\n          np.savez_compressed(io_buffer, samples=samples)\n          fout.write(io_buffer.getvalue())\n\n        # Force garbage collection before calling TensorFlow code for Inception network\n        gc.collect()\n        latents = evaluation.run_inception_distributed(samples, inception_model,\n                                                       inceptionv3=inceptionv3)\n        # Force garbage collection again before returning to JAX code\n        gc.collect()\n        # Save latent represents of the Inception network to disk or Google Cloud Storage\n        with tf.io.gfile.GFile(\n            os.path.join(this_sample_dir, f\"statistics_{r}.npz\"), \"wb\") as fout:\n          io_buffer = io.BytesIO()\n          np.savez_compressed(\n            io_buffer, pool_3=latents[\"pool_3\"], logits=latents[\"logits\"])\n          fout.write(io_buffer.getvalue())\n\n      # Compute inception scores, FIDs and KIDs.\n      # Load all statistics that have been previously computed and saved for each host\n      all_logits = []\n      all_pools = []\n      this_sample_dir = os.path.join(eval_dir, f\"ckpt_{ckpt}\")\n      stats = tf.io.gfile.glob(os.path.join(this_sample_dir, \"statistics_*.npz\"))\n      for stat_file in stats:\n        with tf.io.gfile.GFile(stat_file, \"rb\") as fin:\n          stat = np.load(fin)\n     
     if not inceptionv3:\n            all_logits.append(stat[\"logits\"])\n          all_pools.append(stat[\"pool_3\"])\n\n      if not inceptionv3:\n        all_logits = np.concatenate(all_logits, axis=0)[:config.eval.num_samples]\n      all_pools = np.concatenate(all_pools, axis=0)[:config.eval.num_samples]\n\n      # Load pre-computed dataset statistics.\n      data_stats = evaluation.load_dataset_stats(config)\n      data_pools = data_stats[\"pool_3\"]\n\n      # Compute FID/KID/IS on all samples together.\n      if not inceptionv3:\n        inception_score = tfgan.eval.classifier_score_from_logits(all_logits)\n      else:\n        inception_score = -1\n\n      fid = tfgan.eval.frechet_classifier_distance_from_activations(\n        data_pools, all_pools)\n      # Hack to get tfgan KID work for eager execution.\n      tf_data_pools = tf.convert_to_tensor(data_pools)\n      tf_all_pools = tf.convert_to_tensor(all_pools)\n      kid = tfgan.eval.kernel_classifier_distance_from_activations(\n        tf_data_pools, tf_all_pools).numpy()\n      del tf_data_pools, tf_all_pools\n\n      logging.info(\n        \"ckpt-%d --- inception_score: %.6e, FID: %.6e, KID: %.6e\" % (\n          ckpt, inception_score, fid, kid))\n\n      with tf.io.gfile.GFile(os.path.join(eval_dir, f\"report_{ckpt}.npz\"),\n                             \"wb\") as f:\n        io_buffer = io.BytesIO()\n        np.savez_compressed(io_buffer, IS=inception_score, fid=fid, kid=kid)\n        f.write(io_buffer.getvalue())"
  },
  {
    "path": "sampling.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n# pytype: skip-file\n\"\"\"Various sampling methods.\"\"\"\nimport functools\nimport time\n\nimport torch\nimport numpy as np\nimport abc\n\nfrom models.utils import from_flattened_numpy, to_flattened_numpy, get_score_fn\nfrom scipy import integrate\nimport sde_lib\nfrom models import utils as mutils\n\n_CORRECTORS = {}\n_PREDICTORS = {}\n\n\ndef register_predictor(cls=None, *, name=None):\n  \"\"\"A decorator for registering predictor classes.\"\"\"\n\n  def _register(cls):\n    if name is None:\n      local_name = cls.__name__\n    else:\n      local_name = name\n    if local_name in _PREDICTORS:\n      raise ValueError(f'Already registered model with name: {local_name}')\n    _PREDICTORS[local_name] = cls\n    return cls\n\n  if cls is None:\n    return _register\n  else:\n    return _register(cls)\n\n\ndef register_corrector(cls=None, *, name=None):\n  \"\"\"A decorator for registering corrector classes.\"\"\"\n\n  def _register(cls):\n    if name is None:\n      local_name = cls.__name__\n    else:\n      local_name = name\n    if local_name in _CORRECTORS:\n      raise ValueError(f'Already registered model with name: {local_name}')\n    _CORRECTORS[local_name] = cls\n    return cls\n\n  if cls is None:\n    return _register\n  else:\n    return _register(cls)\n\n\ndef get_predictor(name):\n  return 
_PREDICTORS[name]\n\n\ndef get_corrector(name):\n  return _CORRECTORS[name]\n\n\ndef get_sampling_fn(config, sde, shape, inverse_scaler, eps):\n  \"\"\"Create a sampling function.\n\n  Args:\n    config: A `ml_collections.ConfigDict` object that contains all configuration information.\n    sde: A `sde_lib.SDE` object that represents the forward SDE.\n    shape: A sequence of integers representing the expected shape of a single sample.\n    inverse_scaler: The inverse data normalizer function.\n    eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability.\n\n  Returns:\n    A function that takes random states and a replicated training state and outputs samples with the\n      trailing dimensions matching `shape`.\n  \"\"\"\n\n  sampler_name = config.sampling.method\n  # Probability flow ODE sampling with black-box ODE solvers\n  if sampler_name.lower() == 'ode':\n    sampling_fn = get_ode_sampler(sde=sde,\n                                  shape=shape,\n                                  inverse_scaler=inverse_scaler,\n                                  denoise=config.sampling.noise_removal,\n                                  eps=eps,\n                                  device=config.device)\n  # Predictor-Corrector sampling. 
Predictor-only and Corrector-only samplers are special cases.\n  elif sampler_name.lower() == 'pc':\n    predictor = get_predictor(config.sampling.predictor.lower())\n    corrector = get_corrector(config.sampling.corrector.lower())\n    sampling_fn = get_pc_sampler(sde=sde,\n                                 shape=shape,\n                                 predictor=predictor,\n                                 corrector=corrector,\n                                 inverse_scaler=inverse_scaler,\n                                 snr=config.sampling.snr,\n                                 n_steps=config.sampling.n_steps_each,\n                                 probability_flow=config.sampling.probability_flow,\n                                 continuous=config.training.continuous,\n                                 denoise=config.sampling.noise_removal,\n                                 eps=eps,\n                                 device=config.device)\n  else:\n    raise ValueError(f\"Sampler name {sampler_name} unknown.\")\n\n  return sampling_fn\n\n\nclass Predictor(abc.ABC):\n  \"\"\"The abstract class for a predictor algorithm.\"\"\"\n\n  def __init__(self, sde, score_fn, probability_flow=False):\n    super().__init__()\n    self.sde = sde\n    # Compute the reverse SDE/ODE\n    self.rsde = sde.reverse(score_fn, probability_flow)\n    self.score_fn = score_fn\n\n  @abc.abstractmethod\n  def update_fn(self, x, t):\n    \"\"\"One update of the predictor.\n\n    Args:\n      x: A PyTorch tensor representing the current state\n      t: A Pytorch tensor representing the current time step.\n\n    Returns:\n      x: A PyTorch tensor of the next state.\n      x_mean: A PyTorch tensor. The next state without random noise. 
Useful for denoising.\n    \"\"\"\n    pass\n\n\nclass Corrector(abc.ABC):\n  \"\"\"The abstract class for a corrector algorithm.\"\"\"\n\n  def __init__(self, sde, score_fn, snr, n_steps):\n    super().__init__()\n    self.sde = sde\n    self.score_fn = score_fn\n    self.snr = snr\n    self.n_steps = n_steps\n\n  @abc.abstractmethod\n  def update_fn(self, x, t):\n    \"\"\"One update of the corrector.\n\n    Args:\n      x: A PyTorch tensor representing the current state\n      t: A PyTorch tensor representing the current time step.\n\n    Returns:\n      x: A PyTorch tensor of the next state.\n      x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.\n    \"\"\"\n    pass\n\n\n@register_predictor(name='euler_maruyama')\nclass EulerMaruyamaPredictor(Predictor):\n  def __init__(self, sde, score_fn, probability_flow=False):\n    super().__init__(sde, score_fn, probability_flow)\n\n  def update_fn(self, x, t):\n    dt = -1. / self.rsde.N\n    z = torch.randn_like(x)\n    drift, diffusion = self.rsde.sde(x, t)\n    x_mean = x + drift * dt\n    x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z\n    return x, x_mean\n\n\n@register_predictor(name='reverse_diffusion')\nclass ReverseDiffusionPredictor(Predictor):\n  def __init__(self, sde, score_fn, probability_flow=False):\n    super().__init__(sde, score_fn, probability_flow)\n\n  def update_fn(self, x, t):\n    f, G = self.rsde.discretize(x, t)\n    z = torch.randn_like(x)\n    x_mean = x - f\n    x = x_mean + G[:, None, None, None] * z\n    return x, x_mean\n\n\n@register_predictor(name='ancestral_sampling')\nclass AncestralSamplingPredictor(Predictor):\n  \"\"\"The ancestral sampling predictor. 
Currently only supports VE/VP SDEs.\"\"\"\n\n  def __init__(self, sde, score_fn, probability_flow=False):\n    super().__init__(sde, score_fn, probability_flow)\n    if not isinstance(sde, sde_lib.VPSDE) and not isinstance(sde, sde_lib.VESDE):\n      raise NotImplementedError(f\"SDE class {sde.__class__.__name__} not yet supported.\")\n    assert not probability_flow, \"Probability flow not supported by ancestral sampling\"\n\n  def vesde_update_fn(self, x, t):\n    sde = self.sde\n    timestep = (t * (sde.N - 1) / sde.T).long()\n    sigma = sde.discrete_sigmas[timestep]\n    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), sde.discrete_sigmas.to(t.device)[timestep - 1])\n    score = self.score_fn(x, t)\n    x_mean = x + score * (sigma ** 2 - adjacent_sigma ** 2)[:, None, None, None]\n    std = torch.sqrt((adjacent_sigma ** 2 * (sigma ** 2 - adjacent_sigma ** 2)) / (sigma ** 2))\n    noise = torch.randn_like(x)\n    x = x_mean + std[:, None, None, None] * noise\n    return x, x_mean\n\n  def vpsde_update_fn(self, x, t):\n    sde = self.sde\n    timestep = (t * (sde.N - 1) / sde.T).long()\n    beta = sde.discrete_betas.to(t.device)[timestep]\n    score = self.score_fn(x, t)\n    x_mean = (x + beta[:, None, None, None] * score) / torch.sqrt(1. 
- beta)[:, None, None, None]\n    noise = torch.randn_like(x)\n    x = x_mean + torch.sqrt(beta)[:, None, None, None] * noise\n    return x, x_mean\n\n  def update_fn(self, x, t):\n    if isinstance(self.sde, sde_lib.VESDE):\n      return self.vesde_update_fn(x, t)\n    elif isinstance(self.sde, sde_lib.VPSDE):\n      return self.vpsde_update_fn(x, t)\n\n\n@register_predictor(name='none')\nclass NonePredictor(Predictor):\n  \"\"\"An empty predictor that does nothing.\"\"\"\n\n  def __init__(self, sde, score_fn, probability_flow=False):\n    pass\n\n  def update_fn(self, x, t):\n    return x, x\n\n\n@register_corrector(name='langevin')\nclass LangevinCorrector(Corrector):\n  def __init__(self, sde, score_fn, snr, n_steps):\n    super().__init__(sde, score_fn, snr, n_steps)\n    if not isinstance(sde, sde_lib.VPSDE) \\\n        and not isinstance(sde, sde_lib.VESDE) \\\n        and not isinstance(sde, sde_lib.subVPSDE):\n      raise NotImplementedError(f\"SDE class {sde.__class__.__name__} not yet supported.\")\n\n  def update_fn(self, x, t):\n    sde = self.sde\n    score_fn = self.score_fn\n    n_steps = self.n_steps\n    target_snr = self.snr\n    if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):\n      timestep = (t * (sde.N - 1) / sde.T).long()\n      alpha = sde.alphas.to(t.device)[timestep]\n    else:\n      alpha = torch.ones_like(t)\n\n    for i in range(n_steps):\n      grad = score_fn(x, t)\n      noise = torch.randn_like(x)\n      grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()\n      noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()\n      step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha\n      x_mean = x + step_size[:, None, None, None] * grad\n      x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise\n\n    return x, x_mean\n\n\nclass LangevinCorrectorCS(Corrector):\n  \"\"\" Modified Langevin Corrector to solve for p(x|y) \"\"\"\n  def __init__(self, 
sde, score_fn, snr, n_steps, sigma_min, sigma_max, N):\n    super().__init__(sde, score_fn, snr, n_steps)\n    self.N = N\n    self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), N))\n    if not isinstance(sde, sde_lib.VESDE):\n      raise NotImplementedError(f\"SDE class {sde.__class__.__name__} not yet supported.\")\n\n  def update_fn(self, x, t, y, discrete_sigmas):\n    \"\"\"\n    Args:\n      x: current estimate x_i\n      t: current time step\n      y: measurement in the image domain\n      discrete_sigmas: list of values of \\sigma that are indexable with t\n    \"\"\"\n    sde = self.sde\n    score_fn = self.score_fn\n    n_steps = self.n_steps\n    target_snr = self.snr\n    if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):\n      timestep = (t * (sde.N - 1) / sde.T).long()\n      alpha = sde.alphas.to(t.device)[timestep]\n    else:\n      alpha = torch.ones_like(t)\n\n    for i in range(n_steps):\n      timestep = (t * (self.N - 1) / 1).long()\n      sigma = self.discrete_sigmas.to(t.device)[timestep]\n      grad = score_fn(x, t)\n      grad_likelihood = (x - y) / (sigma[0] ** 2)\n      noise = torch.randn_like(x)\n      grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()\n      noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()\n      step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha\n      x_mean = x + step_size[:, None, None, None] * (grad + grad_likelihood)\n      x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise\n\n    return x, x_mean\n\n\n@register_corrector(name='ald')\nclass AnnealedLangevinDynamics(Corrector):\n  \"\"\"The original annealed Langevin dynamics corrector in NCSN/NCSNv2.\n\n  We include this corrector only for completeness. 
It was not directly used in our paper.\n  \"\"\"\n\n  def __init__(self, sde, score_fn, snr, n_steps):\n    super().__init__(sde, score_fn, snr, n_steps)\n    if not isinstance(sde, sde_lib.VPSDE) \\\n        and not isinstance(sde, sde_lib.VESDE) \\\n        and not isinstance(sde, sde_lib.subVPSDE):\n      raise NotImplementedError(f\"SDE class {sde.__class__.__name__} not yet supported.\")\n\n  def update_fn(self, x, t):\n    sde = self.sde\n    score_fn = self.score_fn\n    n_steps = self.n_steps\n    target_snr = self.snr\n    if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE):\n      timestep = (t * (sde.N - 1) / sde.T).long()\n      alpha = sde.alphas.to(t.device)[timestep]\n    else:\n      alpha = torch.ones_like(t)\n\n    std = self.sde.marginal_prob(x, t)[1]\n\n    for i in range(n_steps):\n      grad = score_fn(x, t)\n      noise = torch.randn_like(x)\n      step_size = (target_snr * std) ** 2 * 2 * alpha\n      x_mean = x + step_size[:, None, None, None] * grad\n      x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]\n\n    return x, x_mean\n\n\n@register_corrector(name='none')\nclass NoneCorrector(Corrector):\n  \"\"\"An empty corrector that does nothing.\"\"\"\n\n  def __init__(self, sde, score_fn, snr, n_steps):\n    pass\n\n  def update_fn(self, x, t):\n    return x, x\n\n\ndef shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous):\n  \"\"\"A wrapper that configures and returns the update function of predictors.\"\"\"\n  score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)\n  if predictor is None:\n    # Corrector-only sampler\n    predictor_obj = NonePredictor(sde, score_fn, probability_flow)\n  else:\n    predictor_obj = predictor(sde, score_fn, probability_flow)\n  return predictor_obj.update_fn(x, t)\n\n\ndef shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps, cs=False,\n                               sigma_min=None, 
sigma_max=None, N=None, y=None, discrete_sigmas=None):\n  \"\"\"A wrapper that configures and returns the update function of correctors.\"\"\"\n  score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous)\n  if corrector is None:\n    # Predictor-only sampler\n    corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps)\n    fn = corrector_obj.update_fn(x, t)\n  else:\n    if cs:\n      corrector_obj = corrector(sde, score_fn, snr, n_steps, sigma_min, sigma_max, N)\n      fn = corrector_obj.update_fn(x, t, y, discrete_sigmas)\n    else:\n      corrector_obj = corrector(sde, score_fn, snr, n_steps)\n      fn = corrector_obj.update_fn(x, t)\n\n  return fn\n\n\ndef get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr,\n                   n_steps=1, probability_flow=False, continuous=False,\n                   denoise=True, eps=1e-3, device='cuda'):\n  \"\"\"Create a Predictor-Corrector (PC) sampler.\n\n  Args:\n    sde: An `sde_lib.SDE` object representing the forward SDE.\n    shape: A sequence of integers. The expected shape of a single sample.\n    predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.\n    corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.\n    inverse_scaler: The inverse data normalizer.\n    snr: A `float` number. The signal-to-noise ratio for configuring correctors.\n    n_steps: An integer. The number of corrector steps per predictor update.\n    probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.\n    continuous: `True` indicates that the score model was continuously trained.\n    denoise: If `True`, add one-step denoising to the final samples.\n    eps: A `float` number. 
The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.\n    device: PyTorch device.\n\n  Returns:\n    A sampling function that returns samples and the number of function evaluations during sampling.\n  \"\"\"\n  # Create predictor & corrector update functions\n  predictor_update_fn = functools.partial(shared_predictor_update_fn,\n                                          sde=sde,\n                                          predictor=predictor,\n                                          probability_flow=probability_flow,\n                                          continuous=continuous)\n  corrector_update_fn = functools.partial(shared_corrector_update_fn,\n                                          sde=sde,\n                                          corrector=corrector,\n                                          continuous=continuous,\n                                          snr=snr,\n                                          n_steps=n_steps)\n\n  def pc_sampler(model):\n    \"\"\" The PC sampler funciton.\n\n    Args:\n      model: A score model.\n    Returns:\n      Samples, number of function evaluations.\n    \"\"\"\n    with torch.no_grad():\n      # Initial sample\n      x = sde.prior_sampling(shape).to(device)\n      timesteps = torch.linspace(sde.T, eps, sde.N, device=device)\n\n      time_corrector_tot = 0\n      time_predictor_tot = 0\n      for i in range(sde.N):\n        t = timesteps[i]\n        vec_t = torch.ones(shape[0], device=t.device) * t\n        tic_corrector = time.time()\n        x, x_mean = corrector_update_fn(x, vec_t, model=model)\n        time_corrector_tot += time.time() - tic_corrector\n        tic_predictor = time.time()\n        x, x_mean = predictor_update_fn(x, vec_t, model=model)\n        time_predictor_tot += time.time() - tic_predictor\n      print(f'Average time for corrector step: {time_corrector_tot / sde.N} sec.')\n      print(f'Average time for predictor step: {time_predictor_tot / sde.N} 
sec.')\n\n      return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)\n\n  return pc_sampler\n\n\ndef get_ode_sampler(sde, shape, inverse_scaler,\n                    denoise=False, rtol=1e-5, atol=1e-5,\n                    method='RK45', eps=1e-3, device='cuda'):\n  \"\"\"Probability flow ODE sampler with the black-box ODE solver.\n\n  Args:\n    sde: An `sde_lib.SDE` object that represents the forward SDE.\n    shape: A sequence of integers. The expected shape of a single sample.\n    inverse_scaler: The inverse data normalizer.\n    denoise: If `True`, add one-step denoising to final samples.\n    rtol: A `float` number. The relative tolerance level of the ODE solver.\n    atol: A `float` number. The absolute tolerance level of the ODE solver.\n    method: A `str`. The algorithm used for the black-box ODE solver.\n      See the documentation of `scipy.integrate.solve_ivp`.\n    eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.\n    device: PyTorch device.\n\n  Returns:\n    A sampling function that returns samples and the number of function evaluations during sampling.\n  \"\"\"\n\n  def denoise_update_fn(model, x):\n    score_fn = get_score_fn(sde, model, train=False, continuous=True)\n    # Reverse diffusion predictor for denoising\n    predictor_obj = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)\n    vec_eps = torch.ones(x.shape[0], device=x.device) * eps\n    _, x = predictor_obj.update_fn(x, vec_eps)\n    return x\n\n  def drift_fn(model, x, t):\n    \"\"\"Get the drift function of the reverse-time SDE.\"\"\"\n    score_fn = get_score_fn(sde, model, train=False, continuous=True)\n    rsde = sde.reverse(score_fn, probability_flow=True)\n    return rsde.sde(x, t)[0]  # returns only the drift term because diffusion = 0 for probability_flow\n\n  def ode_sampler(model, z=None):\n    \"\"\"The probability flow ODE sampler with black-box ODE solver.\n\n    Args:\n      model: 
A score model.\n      z: If present, generate samples from latent code `z`.\n    Returns:\n      samples, number of function evaluations.\n    \"\"\"\n    with torch.no_grad():\n      # Initial sample\n      if z is None:\n        # If not present, sample the latent code from the prior distribution of the SDE.\n        x = sde.prior_sampling(shape).to(device)\n      else:\n        x = z\n\n      def ode_func(t, x):\n        x = from_flattened_numpy(x, shape).to(device).type(torch.float32)\n        vec_t = torch.ones(shape[0], device=x.device) * t\n        drift = drift_fn(model, x, vec_t)\n        return to_flattened_numpy(drift)\n\n      # Black-box ODE solver for the probability flow ODE\n      solution = integrate.solve_ivp(ode_func, (sde.T, eps), to_flattened_numpy(x),\n                                     rtol=rtol, atol=atol, method=method)\n      nfe = solution.nfev\n      x = torch.tensor(solution.y[:, -1]).reshape(shape).to(device).type(torch.float32)\n\n      # Denoising is equivalent to running one predictor step without adding noise\n      if denoise:\n        x = denoise_update_fn(model, x)\n\n      x = inverse_scaler(x)\n      return x, nfe\n\n  return ode_sampler\n"
  },
  {
    "path": "sde_lib.py",
    "content": "\"\"\"Abstract SDE classes, Reverse SDE, and VE/VP SDEs.\"\"\"\nimport abc\nimport torch\nimport numpy as np\n\n\nclass SDE(abc.ABC):\n  \"\"\"SDE abstract class. Functions are designed for a mini-batch of inputs.\"\"\"\n\n  def __init__(self, N):\n    \"\"\"Construct an SDE.\n\n    Args:\n      N: number of discretization time steps.\n    \"\"\"\n    super().__init__()\n    self.N = N\n\n  @property\n  @abc.abstractmethod\n  def T(self):\n    \"\"\"End time of the SDE.\"\"\"\n    pass\n\n  @abc.abstractmethod\n  def sde(self, x, t):\n    pass\n\n  @abc.abstractmethod\n  def marginal_prob(self, x, t):\n    \"\"\"Parameters to determine the marginal distribution of the SDE, $p_t(x)$.\"\"\"\n    pass\n\n  @abc.abstractmethod\n  def prior_sampling(self, shape):\n    \"\"\"Generate one sample from the prior distribution, $p_T(x)$.\"\"\"\n    pass\n\n  @abc.abstractmethod\n  def prior_logp(self, z):\n    \"\"\"Compute log-density of the prior distribution.\n\n    Useful for computing the log-likelihood via probability flow ODE.\n\n    Args:\n      z: latent code\n    Returns:\n      log probability density\n    \"\"\"\n    pass\n\n  def discretize(self, x, t):\n    \"\"\"Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.\n\n    Useful for reverse diffusion sampling and probabiliy flow sampling.\n    Defaults to Euler-Maruyama discretization.\n\n    Args:\n      x: a torch tensor\n      t: a torch float representing the time step (from 0 to `self.T`)\n\n    Returns:\n      f, G\n    \"\"\"\n    dt = 1 / self.N\n    drift, diffusion = self.sde(x, t)\n    f = drift * dt\n    G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))\n    return f, G\n\n  def reverse(self, score_fn, probability_flow=False):\n    \"\"\"Create the reverse-time SDE/ODE.\n\n    Args:\n      score_fn: A time-dependent score-based model that takes x and t and returns the score.\n      probability_flow: If `True`, create the reverse-time ODE used for probability 
flow sampling.\n    \"\"\"\n    N = self.N\n    T = self.T\n    sde_fn = self.sde\n    discretize_fn = self.discretize\n\n    # Build the class for reverse-time SDE.\n    class RSDE(self.__class__):\n      def __init__(self):\n        self.N = N\n        self.probability_flow = probability_flow\n\n      @property\n      def T(self):\n        return T\n\n      def sde(self, x, t):\n        \"\"\"Create the drift and diffusion functions for the reverse SDE/ODE.\"\"\"\n        drift, diffusion = sde_fn(x, t)\n        score = score_fn(x, t)\n        drift = drift - diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)\n        # Set the diffusion function to zero for ODEs.\n        diffusion = 0. if self.probability_flow else diffusion\n        return drift, diffusion\n\n      def discretize(self, x, t):\n        \"\"\"Create discretized iteration rules for the reverse diffusion sampler.\"\"\"\n        f, G = discretize_fn(x, t)\n        rev_f = f - G[:, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)\n        rev_G = torch.zeros_like(G) if self.probability_flow else G\n        return rev_f, rev_G\n\n    return RSDE()\n\n\nclass VPSDE(SDE):\n  def __init__(self, beta_min=0.1, beta_max=20, N=1000):\n    \"\"\"Construct a Variance Preserving SDE.\n\n    Args:\n      beta_min: value of beta(0)\n      beta_max: value of beta(1)\n      N: number of discretization steps\n    \"\"\"\n    super().__init__(N)\n    self.beta_0 = beta_min\n    self.beta_1 = beta_max\n    self.N = N\n    self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)\n    self.alphas = 1. - self.discrete_betas\n    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)\n    self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)\n    self.sqrt_1m_alphas_cumprod = torch.sqrt(1. 
- self.alphas_cumprod)\n\n  @property\n  def T(self):\n    return 1\n\n  def sde(self, x, t):\n    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)\n    drift = -0.5 * beta_t[:, None, None, None] * x\n    diffusion = torch.sqrt(beta_t)\n    return drift, diffusion\n\n  def marginal_prob(self, x, t):\n    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0\n    mean = torch.exp(log_mean_coeff[:, None, None, None]) * x\n    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))\n    return mean, std\n\n  def prior_sampling(self, shape):\n    return torch.randn(*shape)\n\n  def prior_logp(self, z):\n    shape = z.shape\n    N = np.prod(shape[1:])\n    logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2.\n    return logps\n\n  def discretize(self, x, t):\n    \"\"\"DDPM discretization.\"\"\"\n    timestep = (t * (self.N - 1) / self.T).long()\n    beta = self.discrete_betas.to(x.device)[timestep]\n    alpha = self.alphas.to(x.device)[timestep]\n    sqrt_beta = torch.sqrt(beta)\n    f = torch.sqrt(alpha)[:, None, None, None] * x - x\n    G = sqrt_beta\n    return f, G\n\n\nclass subVPSDE(SDE):\n  def __init__(self, beta_min=0.1, beta_max=20, N=1000):\n    \"\"\"Construct the sub-VP SDE that excels at likelihoods.\n\n    Args:\n      beta_min: value of beta(0)\n      beta_max: value of beta(1)\n      N: number of discretization steps\n    \"\"\"\n    super().__init__(N)\n    self.beta_0 = beta_min\n    self.beta_1 = beta_max\n    self.N = N\n\n  @property\n  def T(self):\n    return 1\n\n  def sde(self, x, t):\n    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)\n    drift = -0.5 * beta_t[:, None, None, None] * x\n    discount = 1. 
- torch.exp(-2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2)\n    diffusion = torch.sqrt(beta_t * discount)\n    return drift, diffusion\n\n  def marginal_prob(self, x, t):\n    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0\n    mean = torch.exp(log_mean_coeff)[:, None, None, None] * x\n    std = 1 - torch.exp(2. * log_mean_coeff)\n    return mean, std\n\n  def prior_sampling(self, shape):\n    return torch.randn(*shape)\n\n  def prior_logp(self, z):\n    shape = z.shape\n    N = np.prod(shape[1:])\n    return -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3)) / 2.\n\n\nclass VESDE(SDE):\n  def __init__(self, sigma_min=0.01, sigma_max=50, N=1000):\n    \"\"\"Construct a Variance Exploding SDE.\n\n    Args:\n      sigma_min: smallest sigma.\n      sigma_max: largest sigma.\n      N: number of discretization steps\n    \"\"\"\n    super().__init__(N)\n    self.sigma_min = sigma_min\n    self.sigma_max = sigma_max\n    self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))\n    self.N = N\n\n  @property\n  def T(self):\n    return 1\n\n  def sde(self, x, t):\n    sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t\n    drift = torch.zeros_like(x)\n    diffusion = sigma * torch.sqrt(torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)),\n                                                device=t.device))\n    return drift, diffusion\n\n  def marginal_prob(self, x, t):\n    std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t\n    mean = x\n    return mean, std\n\n  def prior_sampling(self, shape):\n    return torch.randn(*shape) * self.sigma_max\n\n  def prior_logp(self, z):\n    shape = z.shape\n    N = np.prod(shape[1:])\n    return -N / 2. 
* np.log(2 * np.pi * self.sigma_max ** 2) - torch.sum(z ** 2, dim=(1, 2, 3)) / (2 * self.sigma_max ** 2)\n\n  def discretize(self, x, t):\n    \"\"\"SMLD(NCSN) discretization.\"\"\"\n    timestep = (t * (self.N - 1) / self.T).long()\n    sigma = self.discrete_sigmas.to(t.device)[timestep]\n    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),\n                                 self.discrete_sigmas[timestep - 1].to(t.device))\n    f = torch.zeros_like(x)\n    G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)\n    return f, G"
  },
  {
    "path": "test/test_TV.py",
    "content": "\"\"\"\npython -m pytest\n\"\"\"\nimport sys\n\nimport pytest\nimport torch\nimport matplotlib.pyplot as plt\nimport skimage\n\n\nimport controllable_generation_TV as TV\n\n@pytest.mark.parametrize(\n    [\"A\", \"AT\"],\n    [\n        [TV._Dz, TV._DzT],\n        [TV._Dx, TV._DxT],\n        [TV._Dy, TV._DyT],\n    ]\n)\ndef test_adjoint(A, AT):\n    x = torch.randn(10, 10, 10, 10)\n    y = torch.randn(10, 10, 10, 10)\n\n    torch.testing.assert_allclose(\n        torch.dot(A(x).ravel(), y.ravel()),\n        torch.dot(x.ravel(), AT(y).ravel())\n    )\n\ndef test_prox_l21():\n    prox_val = .75\n\n    Dx = torch.randn(1, 1, 1, 1)\n    Dy = torch.randn(1, 1, 1, 1)\n    Dz = torch.randn(1, 1, 1, 1)\n\n    Dq = torch.cat((Dx, Dy, Dz), dim=1)\n    Dq_norm = torch.linalg.norm(Dq)\n\n    Dq_prox = TV.prox_l21(Dq, prox_val, dim=1)\n    Dq_prox_norm = torch.linalg.norm(Dq_prox)\n\n    torch.testing.assert_allclose(\n        max(Dq_norm, 0) - prox_val,\n        Dq_prox_norm,\n    )\n    torch.testing.assert_allclose(\n        Dq / Dq_norm,\n        Dq_prox / Dq_prox_norm,\n    )\n\n\nclass Identity:\n    @staticmethod\n    def A(x):\n        return x\n\n    @staticmethod    \n    def AT(y):\n        return y\n\ndef test_ADMM_TV_isotropic():\n    x_gt = skimage.data.astronaut().mean(axis=2) / 255\n\n    x_gt = torch.tensor(x_gt).reshape((1, 1) + x_gt.shape)\n    y = x_gt + 0.5 * torch.randn_like(x_gt)\n\n    x0 = torch.zeros_like(y)\n\n    ADMM_TV = TV.get_ADMM_TV_isotropic(\n        radon=Identity(), img_shape=y.shape,\n        lamb_1 = 1e0, rho=1e2)\n\n    x_recon = ADMM_TV(x0, y)\n\n    args = dict(vmin=-0.2, vmax=1.2)\n    \n    fig, ax = plt.subplots()\n    im = ax.imshow(x_gt.squeeze(), **args)\n    fig.colorbar(im)\n    fig.savefig('x_gt.png')\n\n    fig, ax = plt.subplots()\n    im = ax.imshow(y.squeeze(), **args)\n    fig.colorbar(im)\n    fig.savefig('y.png')\n\n    fig, ax = plt.subplots()\n    im = ax.imshow(x_recon.squeeze(), **args)\n    
fig.colorbar(im)\n    fig.savefig('x_recon.png')    "
  },
  {
    "path": "train_AAPM256.sh",
    "content": "#!/bin/bash\n\npython main.py \\\n  --config=configs/ve/AAPM_256_ncsnpp_continuous.py \\\n  --eval_folder=eval/AAPM256 \\\n  --mode='train' \\\n  --workdir=workdir/AAPM256"
  },
  {
    "path": "utils.py",
    "content": "from pathlib import Path\n\nimport torch\nimport os\nimport logging\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom fastmri_utils import fft2c_new, ifft2c_new\nfrom statistics import mean, stdev\nfrom skimage.metrics import peak_signal_noise_ratio, structural_similarity\nfrom sporco.metric import gmsd, mse\nfrom scipy.ndimage import gaussian_laplace\nimport functools\n\n\ndef clear_color(x):\n  x = x.detach().cpu().squeeze().numpy()\n  return np.transpose(x, (1, 2, 0))\n\ndef clear(x, normalize=True):\n  x = x.detach().cpu().squeeze().numpy()\n  if normalize:\n    x = normalize_np(x)\n  return x\n\n\ndef restore_checkpoint(ckpt_dir, state, device, skip_sigma=False, skip_optimizer=False):\n  ckpt_dir = Path(ckpt_dir)\n  # import ipdb; ipdb.set_trace()\n  # ckpt = ckpt_dir / \"checkpoint.pth\"\n  if not ckpt_dir.exists():\n    logging.warning(f\"No checkpoint found at {ckpt_dir}. \"\n                  f\"Returned the same state as input\")\n    return state\n  else:\n    loaded_state = torch.load(ckpt_dir, map_location=device)\n    if not skip_optimizer:\n      state['optimizer'].load_state_dict(loaded_state['optimizer'])\n    loaded_model_state = loaded_state['model']\n    if skip_sigma:\n      loaded_model_state.pop('module.sigmas')\n\n    state['model'].load_state_dict(loaded_model_state, strict=False)\n    state['ema'].load_state_dict(loaded_state['ema'])\n    state['step'] = loaded_state['step']\n    print(f'loaded checkpoint dir from {ckpt_dir}')\n    return state\n\n\ndef save_checkpoint(ckpt_dir, state, name=\"checkpoint.pth\"):\n  ckpt_dir = Path(ckpt_dir)\n  saved_state = {\n    'optimizer': state['optimizer'].state_dict(),\n    'model': state['model'].state_dict(),\n    'ema': state['ema'].state_dict(),\n    'step': state['step']\n  }\n  torch.save(saved_state, ckpt_dir / name)\n\n\n\"\"\"\nHelper functions for new types of inverse problems\n\"\"\"\n\ndef fft2(x):\n  \"\"\" FFT with shifting DC to the center of the image\"\"\"\n  
return torch.fft.fftshift(torch.fft.fft2(x), dim=[-1, -2])\n\n\ndef ifft2(x):\n  \"\"\" IFFT with shifting DC to the corner of the image prior to transform\"\"\"\n  return torch.fft.ifft2(torch.fft.ifftshift(x, dim=[-1, -2]))\n\n\ndef fft2_m(x):\n  \"\"\" FFT for multi-coil \"\"\"\n  return torch.view_as_complex(fft2c_new(torch.view_as_real(x)))\n\n\ndef ifft2_m(x):\n  \"\"\" IFFT for multi-coil \"\"\"\n  return torch.view_as_complex(ifft2c_new(torch.view_as_real(x)))\n\n\ndef crop_center(img, cropx, cropy):\n  c, y, x = img.shape\n  startx = x // 2 - (cropx // 2)\n  starty = y // 2 - (cropy // 2)\n  return img[:, starty:starty + cropy, startx:startx + cropx]\n\n\ndef normalize(img):\n  \"\"\" Normalize img in arbitrary range to [0, 1] \"\"\"\n  img -= torch.min(img)\n  img /= torch.max(img)\n  return img\n\ndef normalize_np(img):\n  \"\"\" Normalize img in arbitrary range to [0, 1] \"\"\"\n  img -= np.min(img)\n  img /= np.max(img)\n  return img\n\n\ndef normalize_np_kwarg(img, maxv=1.0, minv=0.0):\n  \"\"\" Subtract minv from img, then divide by maxv (in place) \"\"\"\n  img -= minv\n  img /= maxv\n  return img\n\n\ndef normalize_complex(img):\n  \"\"\" normalizes the magnitude and the phase angle of a complex-valued image to range [0, 1] each, then recombines them \"\"\"\n  abs_img = normalize(torch.abs(img))\n  # ang_img = torch.angle(img)\n  ang_img = normalize(torch.angle(img))\n  return abs_img * torch.exp(1j * ang_img)\n\n\ndef batchfy(tensor, batch_size):\n  n = len(tensor)\n  num_batches = n // batch_size + 1\n  return tensor.chunk(num_batches, dim=0)\n\n\ndef img_wise_min_max(img):\n  img_flatten = img.view(img.shape[0], -1)\n  img_min = torch.min(img_flatten, dim=-1)[0].view(-1, 1, 1, 1)\n  img_max = torch.max(img_flatten, dim=-1)[0].view(-1, 1, 1, 1)\n\n  return (img - img_min) / (img_max - img_min)\n\n\ndef patient_wise_min_max(img):\n  std_upper = 3\n  img_flatten = img.view(img.shape[0], -1)\n\n  std = torch.std(img)\n  mean = torch.mean(img)\n\n  img_min = torch.min(img_flatten, dim=-1)[0].view(-1, 1, 1, 
1)\n  img_max = torch.max(img_flatten, dim=-1)[0].view(-1, 1, 1, 1)\n\n  min_max_scaled = (img - img_min) / (img_max - img_min)\n  min_max_scaled_std = (std - img_min) / (img_max - img_min)\n  min_max_scaled_mean = (mean - img_min) / (img_max - img_min)\n\n  min_max_scaled[min_max_scaled > min_max_scaled_mean + std_upper * min_max_scaled_std] = 1\n\n  return min_max_scaled\n\n\ndef create_sphere(cx, cy, cz, r, resolution=256):\n  '''\n  create sphere with center (cx, cy, cz) and radius r\n  '''\n  phi = np.linspace(0, 2 * np.pi, 2 * resolution)\n  theta = np.linspace(0, np.pi, resolution)\n\n  theta, phi = np.meshgrid(theta, phi)\n\n  r_xy = r * np.sin(theta)\n  x = cx + np.cos(phi) * r_xy\n  y = cy + np.sin(phi) * r_xy\n  z = cz + r * np.cos(theta)\n\n  return np.stack([x, y, z])\n\n\nclass lambda_schedule:\n  def __init__(self, total=2000):\n    self.total = total\n\n  def get_current_lambda(self, i):\n    pass\n\n\nclass lambda_schedule_linear(lambda_schedule):\n  def __init__(self, start_lamb=1.0, end_lamb=0.0):\n    super().__init__()\n    self.start_lamb = start_lamb\n    self.end_lamb = end_lamb\n\n  def get_current_lambda(self, i):\n    return self.start_lamb + (self.end_lamb - self.start_lamb) * (i / self.total)\n\n\nclass lambda_schedule_const(lambda_schedule):\n  def __init__(self, lamb=1.0):\n    super().__init__()\n    self.lamb = lamb\n\n  def get_current_lambda(self, i):\n    return self.lamb\n\n\n\n\ndef image_grid(x, sz=32):\n  size = sz\n  channels = 3\n  img = x.reshape(-1, size, size, channels)\n  w = int(np.sqrt(img.shape[0]))\n  img = img.reshape((w, w, size, size, channels)).transpose((0, 2, 1, 3, 4)).reshape((w * size, w * size, channels))\n  return img\n\n\ndef show_samples(x, sz=32):\n  x = x.permute(0, 2, 3, 1).detach().cpu().numpy()\n  img = image_grid(x, sz)\n  plt.figure(figsize=(8, 8))\n  plt.axis('off')\n  plt.imshow(img)\n  plt.show()\n\n\ndef image_grid_gray(x, size=32):\n  img = x.reshape(-1, size, size)\n  w = 
int(np.sqrt(img.shape[0]))\n  img = img.reshape((w, w, size, size)).transpose((0, 2, 1, 3)).reshape((w * size, w * size))\n  return img\n\n\ndef show_samples_gray(x, size=32, save=False, save_fname=None):\n  x = x.detach().cpu().numpy()\n  img = image_grid_gray(x, size=size)\n  plt.figure(figsize=(8, 8))\n  plt.axis('off')\n  plt.imshow(img, cmap='gray')\n  plt.show()\n  if save:\n    plt.imsave(save_fname, img, cmap='gray')\n\n\ndef get_mask(img, size, batch_size, type='gaussian2d', acc_factor=8, center_fraction=0.04, fix=False):\n  mux_in = size ** 2\n  if type.endswith('2d'):\n    Nsamp = mux_in // acc_factor\n  elif type.endswith('1d'):\n    Nsamp = size // acc_factor\n  if type == 'gaussian2d':\n    mask = torch.zeros_like(img)\n    cov_factor = size * (1.5 / 128)\n    mean = [size // 2, size // 2]\n    cov = [[size * cov_factor, 0], [0, size * cov_factor]]\n    if fix:\n      samples = np.random.multivariate_normal(mean, cov, int(Nsamp))\n      int_samples = samples.astype(int)\n      int_samples = np.clip(int_samples, 0, size - 1)\n      mask[..., int_samples[:, 0], int_samples[:, 1]] = 1\n    else:\n      for i in range(batch_size):\n        # sample different masks for batch\n        samples = np.random.multivariate_normal(mean, cov, int(Nsamp))\n        int_samples = samples.astype(int)\n        int_samples = np.clip(int_samples, 0, size - 1)\n        mask[i, :, int_samples[:, 0], int_samples[:, 1]] = 1\n  elif type == 'uniformrandom2d':\n    mask = torch.zeros_like(img)\n    if fix:\n      mask_vec = torch.zeros([1, size * size])\n      samples = np.random.choice(size * size, int(Nsamp))\n      mask_vec[:, samples] = 1\n      mask_b = mask_vec.view(size, size)\n      mask[:, ...] 
= mask_b\n    else:\n      for i in range(batch_size):\n        # sample different masks for batch\n        mask_vec = torch.zeros([1, size * size])\n        samples = np.random.choice(size * size, int(Nsamp))\n        mask_vec[:, samples] = 1\n        mask_b = mask_vec.view(size, size)\n        mask[i, ...] = mask_b\n  elif type == 'gaussian1d':\n    mask = torch.zeros_like(img)\n    mean = size // 2\n    std = size * (15.0 / 128)\n    Nsamp_center = int(size * center_fraction)\n    if fix:\n      samples = np.random.normal(loc=mean, scale=std, size=int(Nsamp * 1.2))\n      int_samples = samples.astype(int)\n      int_samples = np.clip(int_samples, 0, size - 1)\n      mask[... , int_samples] = 1\n      c_from = size // 2 - Nsamp_center // 2\n      mask[... , c_from:c_from + Nsamp_center] = 1\n    else:\n      for i in range(batch_size):\n        samples = np.random.normal(loc=mean, scale=std, size=int(Nsamp*1.2))\n        int_samples = samples.astype(int)\n        int_samples = np.clip(int_samples, 0, size - 1)\n        mask[i, :, :, int_samples] = 1\n        c_from = size // 2 - Nsamp_center // 2\n        mask[i, :, :, c_from:c_from + Nsamp_center] = 1\n  elif type == 'uniform1d':\n    mask = torch.zeros_like(img)\n    if fix:\n      Nsamp_center = int(size * center_fraction)\n      samples = np.random.choice(size, int(Nsamp - Nsamp_center))\n      mask[..., samples] = 1\n      # ACS region\n      c_from = size // 2 - Nsamp_center // 2\n      mask[..., c_from:c_from + Nsamp_center] = 1\n    else:\n      for i in range(batch_size):\n        Nsamp_center = int(size * center_fraction)\n        samples = np.random.choice(size, int(Nsamp - Nsamp_center))\n        mask[i, :, :, samples] = 1\n        # ACS region\n        c_from = size // 2 - Nsamp_center // 2\n        mask[i, :, :, c_from:c_from+Nsamp_center] = 1\n  else:\n    NotImplementedError(f'Mask type {type} is currently not supported.')\n\n  return mask\n\n\ndef kspace_to_nchw(tensor):\n    \"\"\"\n    Convert 
torch tensor in (Slice, Coil, Height, Width, Complex) 5D format to\n    (N, C, H, W) 4D format for processing by 2D CNNs.\n\n    Complex indicates (real, imag) as 2 channels, the complex data format for Pytorch.\n\n    C is the coils interleaved with real and imaginary values as separate channels.\n    C is therefore always 2 * Coil.\n\n    Singlecoil data is assumed to be in the 5D format with Coil = 1\n\n    Args:\n        tensor (torch.Tensor): Input data in 5D kspace tensor format.\n    Returns:\n        tensor (torch.Tensor): tensor in 4D NCHW format to be fed into a CNN.\n    \"\"\"\n    assert isinstance(tensor, torch.Tensor)\n    assert tensor.dim() == 5\n    s = tensor.shape\n    assert s[-1] == 2\n    tensor = tensor.permute(dims=(0, 1, 4, 2, 3)).reshape(shape=(s[0], 2 * s[1], s[2], s[3]))\n    return tensor\n\n\ndef nchw_to_kspace(tensor):\n  \"\"\"\n  Convert a torch tensor in (N, C, H, W) format to the (Slice, Coil, Height, Width, Complex) format.\n\n  This function assumes that the real and imaginary values of a coil are always adjacent to one another in C.\n  If the coil dimension is not divisible by 2, the function assumes that the input data is 'real' data,\n  and thus pads the imaginary dimension as 0.\n  \"\"\"\n  assert isinstance(tensor, torch.Tensor)\n  assert tensor.dim() == 4\n  s = tensor.shape\n  if tensor.shape[1] == 1:\n    imag_tensor = torch.zeros(s, device=tensor.device)\n    tensor = torch.cat((tensor, imag_tensor), dim=1)\n    s = tensor.shape\n  tensor = tensor.view(size=(s[0], s[1] // 2, 2, s[2], s[3])).permute(dims=(0, 1, 3, 4, 2))\n  return tensor\n\n\ndef root_sum_of_squares(data, dim=0):\n    \"\"\"\n    Compute the Root Sum of Squares (RSS) transform along a given dimension of a tensor.\n    Args:\n        data (torch.Tensor): The input tensor\n        dim (int): The dimensions along which to apply the RSS transform\n    Returns:\n        torch.Tensor: The RSS value\n    \"\"\"\n    return torch.sqrt((data ** 
2).sum(dim))\n\n\ndef save_data(fname, arr):\n  \"\"\" Save data as .npy and .png \"\"\"\n  np.save(fname + '.npy', arr)\n  plt.imsave(fname + '.png', arr, cmap='gray')\n\ndef mean_std(vals: list):\n  return mean(vals), stdev(vals)\n\ndef cal_metric(comp, label):\n  LoG = functools.partial(gaussian_laplace, sigma=1.5)\n  psnr_val = peak_signal_noise_ratio(comp, label)\n  ssim_val = structural_similarity(comp, label)\n  hfen_val = mse(LoG(comp), LoG(label))\n  gmsd_val = gmsd(label, comp)\n  return psnr_val, ssim_val, hfen_val, gmsd_val\n"
  }
]