[
  {
    "path": ".gitignore",
    "content": "# Data, checkpoints, logs\ndata\ncheckpoints\n.neptune\n\n# Files generated by setuptools_scm\n__version.py\n\n# MacOS\n.DS_Store\n\n# Visual Studio Code\n.vscode/\n*.code-workspace\n.history/\n\n# Created by https://www.gitignore.io/api/python\n# Edit at https://www.gitignore.io/?templates=python\n\n### Python ###\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# PyCharm\n.idea/\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# End of https://www.gitignore.io/api/python\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n"
  },
  {
    "path": "README.md",
    "content": "# Bayesian Flow Networks\n\nThis is the official code release for [Bayesian Flow Networks](https://arxiv.org/abs/2308.07037) by Alex Graves, Rupesh Kumar Srivastava, Timothy Atkinson and Faustino Gomez.\n\n<img src=\"bfn.gif\" alt=\"Overview of BFN process\" style=\"width:600px;\"/>\n\n## Reading Guide\n\n- `model.py` contains all the main contributions of the paper. These include definitions, for both continuous and discrete data, of Bayesian Flows as well as loss functions for both continuous-time and discrete-time. See comments in the base classes in that file for details.\n- `probability.py` defines the probability distributions used by the models.\n- `train.py`, `test.py` and `sample.py` are scripts for training, testing and sampling (see below for usage).\n- `data.py` contains utilities related to data loading and processing.\n- `networks/` contains implementations of the network architectures used by the models. \n\n## Setup\n\n```shell\n# Create a new conda env with all dependencies including pytorch and CUDA\nconda env create -f env.yml\nconda activate bfn\n\n# Or, install additional dependencies into an existing pytorch env\npip install accelerate==0.19.0 matplotlib omegaconf rich\n\n# Optional, if you want to enable logging to neptune.ai\npip install neptune \n```\n\n## Training\n\nThe models in the paper can be trained using the configs provided in the `configs` dir as follows:\n\n```shell\n# mnist experiment on 1 GPU\naccelerate launch train.py config_file=configs/mnist_discrete.yaml\n# cifar10 experiment on 1 GPU (A100)\naccelerate launch train.py config_file=configs/cifar10_discretized_256bins.yaml\n# text8 experiment on 8 GPUs (A100)\naccelerate launch --multi_gpu --num_processes=8 --num_machines=1 --dynamo_backend=no --mixed_precision=fp16 train.py config_file=configs/text8_discrete.yaml \n```\n\n## Testing\n> [!NOTE]\n> Depending on your GPU, you may wish to adjust the batch size used for testing in `test.py`.\n```shell\n# Optional: Download pretrained checkpoints (make sure you have git-lfs installed: https://git-lfs.com/)\ngit clone git@hf.co:rupspace/pretrained-BFNs\n# Compute 784-step loss on MNIST\npython test.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt n_steps=784 n_repeats=2000\n# Compute 10-step loss on CIFAR-10\npython test.py seed=1 config_file=./configs/cifar10_discretized_256bins.yaml load_model=./pretrained-BFNs/cifar10_256d_ema.pt n_steps=10 n_repeats=100\n# Compute continuous-time loss on text8\npython test.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt n_steps=0 n_repeats=1\n```\n> [!IMPORTANT]\n> All computed results will be in nats-per-data-dimension. 
To convert to bits, divide by ln(2).\n\n## Sampling\n\nYou can sample from a pre-trained model as follows (change options as desired):\n\n```shell\n# Sample 4 binarized MNIST images using 100 steps\npython sample.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt samples_shape=\"[4, 28, 28, 1]\" n_steps=100 save_file=./samples_mnist.pt\n# Sample 4 CIFAR-10 16-bin images modeled as discretized data using 1000 steps\npython sample.py seed=1 config_file=./configs/cifar10_discretized_16bins.yaml load_model=./pretrained-BFNs/cifar10_16d_ema.pt samples_shape=\"[4, 32, 32, 3]\" n_steps=1000 save_file=./samples_cifar.pt\n# Sample 2 text8 sequences of length 256 using 100 steps\npython sample.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt samples_shape=\"[2, 256]\" n_steps=100 save_file=./samples_text8.pt\n```\n\nThe samples are stored as PyTorch tensors in the `save_file`, and can be visualized by loading them and then using the utilities `batch_to_images` and `batch_to_str` in `data.py`.\nFor example:\n```shell\n# batch_to_images returns a matplotlib Figure object\npython -c \"import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_mnist.pt')).savefig('mnist.png')\"\npython -c \"import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_cifar.pt')).savefig('cifar.png')\"\n# batch_to_str returns a list of str\npython -c \"import torch; from data import batch_to_str; print(batch_to_str(torch.load('./samples_text8.pt')))\"\n```\n\n## Reproducibility\n\nIf a high degree of reproducibility is desired (e.g. during sampling), set the following:\n\n```python\ntorch.set_float32_matmul_precision(\"highest\")\ntorch.use_deterministic_algorithms(True)\ntorch.backends.cudnn.benchmark = False\n```\n\n## Acknowledgements\n\nWe are grateful to [@Higgcz](https://github.com/Higgcz) for generous support with the experiment infrastructure and code release.\n"
  },
  {
    "path": "configs/cifar10_continuous_16bins.yaml",
    "content": "meta:\n  neptune: \n  debug: False\ndata:\n  dataset: \"cifar10\"\n  horizontal_flip: False\n  num_bins: 16\ntrain_loader:\n  batch_size: 32\n  shuffle: True\n  num_workers: 8\n  pin_memory: True\n  drop_last: True\n  persistent_workers: True\nval_loader:\n  batch_size: 500\n  shuffle: False\n  num_workers: 8\n  pin_memory: True\nmodel:\n  net:\n    class_name: \"UNetVDM\"\n    parameters:\n      embedding_dim: 128\n      n_blocks: 32\n      n_attention_heads: 1\n      dropout_prob: 0.1\n      norm_groups: 32\n      input_channels: 3\n      use_fourier_features: True\n      attention_everywhere: False\n      image_size: 32\n  input_adapter:\n    class_name: \"FourierImageInputAdapter\"\n    parameters:\n      input_channels: 3\n      input_shape: [32, 32]\n      output_height: 3\n      add_pos_feats: False\n      add_mask: False\n  output_adapter:\n    class_name: \"OutputAdapter\"\n    parameters:\n      input_height: 131\n      output_channels: 3 # (r,g,b)\n      output_height: 1\n  bayesian_flow:\n    class_name: \"CtsBayesianFlow\"\n    parameters:\n      min_variance: 1e-3\n  loss:\n    class_name: \"CtsBayesianFlowLoss\"\n    parameters:\n      noise_pred: True\n  distribution_factory:\n    class_name: \"DeltaFactory\"\n    parameters: {}\noptimizer:\n  lr: 2e-4\n  betas: [0.9,0.99]\n  weight_decay: 0.01\n  eps: 1e-8\ntraining:\n  checkpoint_interval: 10_000\n  ema_decay: 0.9999\n  grad_clip_norm: 5.0\n  log_interval: 1\n  n_training_steps: 1_000_000\n  val_interval: 50_000\n  val_repeats: 100\n"
  },
  {
    "path": "configs/cifar10_continuous_256bins.yaml",
    "content": "meta:\n  neptune: \n  debug: False\ndata:\n  dataset: \"cifar10\"\n  horizontal_flip: False\n  num_bins: 256\ntrain_loader:\n  batch_size: 32\n  shuffle: True\n  num_workers: 8\n  pin_memory: True\n  drop_last: True\n  persistent_workers: True\nval_loader:\n  batch_size: 500\n  shuffle: False\n  num_workers: 8\n  pin_memory: True\nmodel:\n  net:\n    class_name: \"UNetVDM\"\n    parameters:\n      embedding_dim: 128\n      n_blocks: 32\n      n_attention_heads: 1\n      dropout_prob: 0.1\n      norm_groups: 32\n      input_channels: 3\n      use_fourier_features: True\n      attention_everywhere: False\n      image_size: 32\n  input_adapter:\n    class_name: \"FourierImageInputAdapter\"\n    parameters:\n      input_channels: 3\n      input_shape: [32, 32]\n      output_height: 3\n      add_pos_feats: False\n      add_mask: False\n  output_adapter:\n    class_name: \"OutputAdapter\"\n    parameters:\n      input_height: 131\n      output_channels: 3 # (r,g,b)\n      output_height: 1\n  bayesian_flow:\n    class_name: \"CtsBayesianFlow\"\n    parameters:\n      min_variance: 1e-6\n  loss:\n    class_name: \"CtsBayesianFlowLoss\"\n    parameters:\n      noise_pred: True\n  distribution_factory:\n    class_name: \"DeltaFactory\"\n    parameters: {}\noptimizer:\n  lr: 2e-4\n  betas: [0.9,0.99]\n  weight_decay: 0.01\n  eps: 1e-8\ntraining:\n  checkpoint_interval: 10_000\n  ema_decay: 0.9999\n  grad_clip_norm: 5.0\n  log_interval: 1\n  n_training_steps: 1_000_000\n  val_interval: 50_000\n  val_repeats: 100\n"
  },
  {
    "path": "configs/cifar10_discretized_16bins.yaml",
    "content": "meta:\n  neptune: \n  debug: False\ndata:\n  dataset: \"cifar10\"\n  horizontal_flip: False\n  num_bins: 16\ntrain_loader:\n  batch_size: 32\n  shuffle: True\n  num_workers: 8\n  pin_memory: True\n  drop_last: True\n  persistent_workers: True\nval_loader:\n  batch_size: 1000\n  shuffle: False\n  num_workers: 8\n  pin_memory: True\nmodel:\n  net:\n    class_name: \"UNetVDM\"\n    parameters:\n      embedding_dim: 128\n      n_blocks: 32\n      n_attention_heads: 1\n      dropout_prob: 0.1\n      norm_groups: 32\n      input_channels: 3\n      use_fourier_features: True\n      attention_everywhere: False\n      image_size: 32\n  input_adapter:\n    class_name: \"FourierImageInputAdapter\"\n    parameters:\n      input_channels: 3\n      input_shape: [32, 32]\n      output_height: 3\n      add_pos_feats: False\n      add_mask: False\n  output_adapter:\n    class_name: \"OutputAdapter\"\n    parameters:\n      input_height: 131\n      output_channels: 3 # (r,g,b)\n      output_height: 2 # mean, std\n  bayesian_flow:\n    class_name: \"CtsBayesianFlow\"\n    parameters:\n      min_variance: 1e-3\n  loss:\n    class_name: \"CtsBayesianFlowLoss\"\n    parameters:\n      noise_pred: True\n  distribution_factory:\n    class_name: \"DiscretizedNormalFactory\"\n    parameters:\n      num_bins: 16\n      clip: True\noptimizer:\n  lr: 2e-4\n  betas: [0.9,0.99]\n  weight_decay: 0.01\n  eps: 1e-8\ntraining:\n  checkpoint_interval: 10_000\n  ema_decay: 0.9999\n  grad_clip_norm: 5.0\n  log_interval: 1\n  n_training_steps: 1_000_000\n  val_interval: 50_000\n  val_repeats: 100\n"
  },
  {
    "path": "configs/cifar10_discretized_256bins.yaml",
    "content": "meta:\n  neptune: \n  debug: False\ndata:\n  dataset: \"cifar10\"\n  horizontal_flip: False\n  num_bins: 256\ntrain_loader:\n  batch_size: 32\n  shuffle: True\n  num_workers: 8\n  pin_memory: True\n  drop_last: True\n  persistent_workers: True\nval_loader:\n  batch_size: 1000\n  shuffle: False\n  num_workers: 8\n  pin_memory: True\nmodel:\n  net:\n    class_name: \"UNetVDM\"\n    parameters:\n      embedding_dim: 128\n      n_blocks: 32\n      n_attention_heads: 1\n      dropout_prob: 0.1\n      norm_groups: 32\n      input_channels: 3\n      use_fourier_features: True\n      attention_everywhere: False\n      image_size: 32\n  input_adapter:\n    class_name: \"FourierImageInputAdapter\"\n    parameters:\n      input_channels: 3\n      input_shape: [32, 32]\n      output_height: 3\n      add_pos_feats: False\n      add_mask: False\n  output_adapter:\n    class_name: \"OutputAdapter\"\n    parameters:\n      input_height: 131\n      output_channels: 3 # (r,g,b)\n      output_height: 2 # mean, std\n  bayesian_flow:\n    class_name: \"CtsBayesianFlow\"\n    parameters:\n      min_variance: 1e-6\n  loss:\n    class_name: \"CtsBayesianFlowLoss\"\n    parameters:\n      noise_pred: True\n  distribution_factory:\n    class_name: \"DiscretizedNormalFactory\"\n    parameters:\n      num_bins: 256\n      clip: True\noptimizer:\n  lr: 2e-4\n  betas: [0.9,0.99]\n  weight_decay: 0.01\n  eps: 1e-8\ntraining:\n  checkpoint_interval: 10_000\n  ema_decay: 0.9999\n  grad_clip_norm: 5.0\n  log_interval: 1\n  n_training_steps: 1_000_000\n  val_interval: 50_000\n  val_repeats: 100\n"
  },
  {
    "path": "configs/mnist_discrete.yaml",
    "content": "meta:\n  neptune:\n  debug: False\ndata:\n  dataset: \"bin_mnist\"\ntrain_loader:\n  batch_size: 512\n  shuffle: True\n  num_workers: 8\n  pin_memory: True\n  drop_last: True\nval_loader:\n  batch_size: 1000\n  shuffle: False\n  num_workers: 8\n  pin_memory: True\nmodel:\n  net:\n    class_name: \"UNetModel\"\n    parameters:\n      image_size: 28\n      in_channels: 2\n      model_channels: 128\n      out_channels: 128\n      num_res_blocks: 2\n      attention_resolutions: [8,16]\n      dropout: 0.5\n      channel_mult: [1, 2, 2]\n      conv_resample: True\n      dims: 2\n      num_heads: 4\n      num_heads_upsample: -1\n      project_input: True\n      skip: True\n  input_adapter:\n    class_name: \"FourierImageInputAdapter\"\n    parameters:\n      input_channels: 1\n      input_shape: [28, 28]\n      output_height: 2\n      add_pos_feats: False\n  output_adapter:\n    class_name: \"OutputAdapter\"\n    parameters:\n      input_height: 256\n      output_channels: 1\n      output_height: 1\n  bayesian_flow:\n    class_name: \"DiscreteBayesianFlow\"\n    parameters:\n      n_classes: 2\n      max_sqrt_beta: 3\n      discretize: False\n  loss:\n    class_name: \"DiscreteBayesianFlowLoss\"\n    parameters: {}\n  distribution_factory:\n    class_name: \"BernoulliFactory\"\n    parameters: {}\noptimizer:\n  lr: 1e-4\n  betas: [0.9,0.98]\ntraining:\n  checkpoint_interval: 10_000\n  ema_decay: 0.9999\n  grad_clip_norm: 5.0\n  log_interval: 1\n  n_training_steps: 1_000_000\n  val_interval: 50_000\n  val_repeats: 1000"
  },
  {
    "path": "configs/text8_discrete.yaml",
    "content": "meta:\n  neptune:\n  debug: False\ndata:\n  dataset: \"text8\"\n  seq_len: 256\ntrain_loader:\n  batch_size: 416\n  shuffle: True\n  num_workers: 8\n  pin_memory: True\n  drop_last: True\nval_loader:\n  batch_size: 200\n  shuffle: True\n  num_workers: 8\n  pin_memory: True\nmodel:\n  net:\n    class_name: \"GPT\"\n    parameters:\n      vocab_size: 27\n      n_layer: 24\n      n_head: 12\n      n_embd: 768\n      dropout: 0.0\n      skip: True\n      bias: True\n  input_adapter:\n    class_name: \"TextInputAdapter\"\n    parameters:\n      vocab_size: 27\n      seq_len: 256\n      output_size: 768\n      learn_pos_embedding: False\n  output_adapter: null\n  bayesian_flow:\n    class_name: \"DiscreteBayesianFlow\"\n    parameters:\n      n_classes: 27\n      max_sqrt_beta: 0.75\n  loss:\n    class_name: \"DiscreteBayesianFlowLoss\"\n    parameters: {}\n  distribution_factory:\n    class_name: \"CategoricalFactory\"\n    parameters: {}\noptimizer:\n  lr: 1e-4\n  betas: [0.9, 0.98]\n  weight_decay: 0.01\ntraining:\n  accumulate: 1\n  checkpoint_interval: 10_000\n  ema_decay: 0.9999\n  grad_clip_norm: 5\n  log_interval: 1\n  max_val_batches: 5_000\n  n_training_steps: 10_000_000\n  val_interval: 100_000\n  val_repeats: 1"
  },
  {
    "path": "data.py",
    "content": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nimport os\nimport pathlib\nimport pickle\nimport zipfile\nfrom typing import Union\n\nimport numpy as np\nimport requests\nimport torch\nimport torchvision\nfrom matplotlib import pyplot as plt\nfrom omegaconf import DictConfig\nfrom torch.utils.data import Dataset, random_split\nfrom torchvision import transforms\nfrom torchvision.utils import make_grid\n\nfrom utils_model import quantize\n\nTEXT8_CHARS = list(\"_abcdefghijklmnopqrstuvwxyz\")\n\n\ndef bin_mnist_transform(x):\n    return torch.bernoulli(x.permute(1, 2, 0).contiguous()).int()\n\n\ndef bin_mnist_cts_transform(x):\n    return torch.bernoulli(x.permute(1, 2, 0).contiguous()) - 0.5\n\n\ndef rgb_image_transform(x, num_bins=256):\n    return quantize((x * 2) - 1, num_bins).permute(1, 2, 0).contiguous()\n\n\nclass MyLambda(torchvision.transforms.Lambda):\n    def __init__(self, lambd, arg1):\n        super().__init__(lambd)\n        self.arg1 = arg1\n\n    def __call__(self, x):\n        return self.lambd(x, self.arg1)\n\n\nclass CIFAR10(torchvision.datasets.CIFAR10):\n    def __getitem__(self, idx):\n        return super().__getitem__(idx)[0]\n\n\nclass MNIST(torchvision.datasets.MNIST):\n    def __getitem__(self, idx):\n        return super().__getitem__(idx)[0]\n\n\ndef make_datasets(cfg: DictConfig) -> tuple[Dataset, Dataset, Dataset]:\n    \"\"\"\n    Mandatory keys: dataset (must be cifar10, mnist, bin_mnist, bin_mnist_cts or text8), data_dir\n    Optional for vision: num_bins (default 256), val_frac (default 0.01), horizontal_flip (default: False)\n    Mandatory for text: seq_len\n    \"\"\"\n    num_bins = cfg.get(\"num_bins\", 256)\n    if cfg.dataset == \"cifar10\":\n        train_transform_list = [transforms.ToTensor()]\n        if cfg.get(\"horizontal_flip\", False):\n            train_transform_list.append(transforms.RandomHorizontalFlip())\n        train_transform_list.append(MyLambda(rgb_image_transform, num_bins))\n        train_transform = transforms.Compose(train_transform_list)\n        test_transform = transforms.Compose([transforms.ToTensor(), MyLambda(rgb_image_transform, num_bins)])\n        train_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=train_transform)\n        val_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=test_transform)\n        test_set = CIFAR10(root=cfg.data_dir, train=False, download=True, transform=test_transform)\n\n    elif cfg.dataset == \"mnist\":\n        transform = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                MyLambda(rgb_image_transform, num_bins),\n            ]\n        )\n        train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)\n        val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)\n        test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)\n\n    elif cfg.dataset == 
\"bin_mnist\":\n        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_transform)])\n        train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)\n        val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)\n        test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)\n\n    elif cfg.dataset == \"bin_mnist_cts\":\n        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_cts_transform)])\n        train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)\n        val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)\n        test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)\n\n    elif cfg.dataset == \"text8\":\n        train_set = Text8Dataset(cfg.data_dir, \"train\", download=True, seq_len=cfg.seq_len)\n        val_set = Text8Dataset(cfg.data_dir, \"val\", download=True, seq_len=cfg.seq_len)\n        test_set = Text8Dataset(cfg.data_dir, \"test\", download=True, seq_len=cfg.seq_len)\n    else:\n        raise NotImplementedError(cfg.dataset)\n\n    if cfg.dataset != \"text8\":\n        # For vision datasets we split the train set into train and val\n        val_frac = cfg.get(\"val_frac\", 0.01)\n        train_val_split = [1.0 - val_frac, val_frac]\n        seed = 2147483647\n        train_set = random_split(train_set, train_val_split, generator=torch.Generator().manual_seed(seed))[0]\n        val_set = random_split(val_set, train_val_split, generator=torch.Generator().manual_seed(seed))[1]\n\n    return train_set, val_set, test_set\n\n\ndef prepare_text8(data_dir: pathlib.Path):\n    data_dir.mkdir(parents=True, exist_ok=True)\n    data_url = \"http://mattmahoney.net/dc/text8.zip\"\n    with open(data_dir / \"text8.zip\", \"wb\") as f:\n        print(\"Downloading text8\")\n        f.write(requests.get(data_url).content)\n        print(\"Done\")\n    with zipfile.ZipFile(data_dir / \"text8.zip\") as f:\n        f.extractall(data_dir)\n    os.remove(data_dir / \"text8.zip\")\n    data = (data_dir / \"text8\").read_text()\n\n    # get all the unique characters that occur in this text\n    chars = sorted(list(set(data)))\n    vocab_size = len(chars)\n    print(\"all the unique characters:\", \"\".join(chars))\n    print(f\"vocab size: {vocab_size:,}\")\n\n    # create a mapping from characters to integers\n    stoi = {ch: i for i, ch in enumerate(chars)}\n    itos = {i: ch for i, ch in enumerate(chars)}\n\n    def encode(s):\n        return [stoi[c] for c in s]  # encoder: take a string, output a list of integers\n\n    # encode both to integers\n    n = len(data)\n    train_data = data[: int(n * 0.9)]\n    val_data = data[int(n * 0.9) : int(n * 0.95)]\n    test_data = data[int(n * 0.95) :]\n    train_ids = encode(train_data)\n    val_ids = encode(val_data)\n    test_ids = encode(test_data)\n    print(f\"train has {len(train_ids):,} tokens\")\n    print(f\"val has {len(val_ids):,} tokens\")\n    print(f\"test has {len(test_ids):,} tokens\")\n\n    # export to bin files\n    train_ids = np.array(train_ids, dtype=np.uint16)\n    val_ids = np.array(val_ids, dtype=np.uint16)\n    test_ids = np.array(test_ids, dtype=np.uint16)\n    train_ids.tofile(data_dir / \"train.bin\")\n    val_ids.tofile(data_dir / \"val.bin\")\n    test_ids.tofile(data_dir / \"test.bin\")\n    print(f\"Saved to {data_dir / 'train.bin'}, {data_dir / 'val.bin'}, {data_dir / 
'test.bin'}\")\n\n    # save the meta information as well, to help us encode/decode later\n    meta = {\n        \"vocab_size\": vocab_size,\n        \"itos\": itos,\n        \"stoi\": stoi,\n    }\n    with open(data_dir / \"meta.pkl\", \"wb\") as f:\n        pickle.dump(meta, f)\n\n    print(f\"text8 dataset downloaded and prepared in dir {data_dir}\")\n\n\nclass Text8Dataset(Dataset):\n    def __init__(self, data_dir: Union[str, pathlib.Path], split: str, download: bool, seq_len: int):\n        \"\"\"\n        seq_len should include context length. Example: seq_len=512 for modeling 256 chars with 256 chars of context.\n        Context is only used for correct preparation of val/test sets.\n        \"\"\"\n        self.root_dir = pathlib.Path(data_dir)\n        self.split = split\n        self.seq_len = seq_len\n        assert self.split in [\"train\", \"val\", \"test\"]\n        fname = {\"train\": \"train.bin\", \"val\": \"val.bin\", \"test\": \"test.bin\"}[self.split]\n        data_dir = self.root_dir / \"text8\"\n        if not os.path.exists(data_dir):\n            if download:\n                prepare_text8(data_dir)\n            else:\n                raise NotADirectoryError(f\"dir {data_dir} does not exist and download is False\")\n        self.data = np.memmap(data_dir / fname, np.uint16, \"r\")\n\n    def __getitem__(self, index) -> torch.Tensor:\n        seq = torch.from_numpy(self.data[index : index + self.seq_len].astype(np.int64))\n        return seq\n\n    def __len__(self):\n        return self.data.size - self.seq_len\n\n\ndef char_ids_to_str(char_ids: Union[list[int], np.ndarray, torch.Tensor]) -> str:\n    \"\"\"Decode a 1D sequence of character IDs to a string.\"\"\"\n    return \"\".join([TEXT8_CHARS[i] for i in char_ids])\n\n\ndef batch_to_str(text_batch: Union[list[list], np.ndarray, torch.Tensor]) -> list[str]:\n    \"\"\"Decode a batch of character IDs to a list of strings.\"\"\"\n    return [char_ids_to_str(row_char_ids) for row_char_ids in text_batch]\n\n\ndef batch_to_images(image_batch: torch.Tensor, ncols: Union[int, None] = None) -> plt.Figure:\n    if ncols is None:\n        ncols = math.ceil(math.sqrt(len(image_batch)))\n    if image_batch.size(-1) == 3:  # for color images (CIFAR-10)\n        image_batch = (image_batch + 1) / 2\n    grid = make_grid(image_batch.permute(0, 3, 1, 2), ncols, pad_value=1).permute(1, 2, 0)\n    fig = plt.figure(figsize=(grid.size(1) / 30, grid.size(0) / 30))\n    plt.imshow(grid.cpu().clip(min=0, max=1), interpolation=\"nearest\")\n    plt.grid(False)\n    plt.axis(\"off\")\n    return fig\n
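\n\nif __name__ == \"__main__\":\n    # Minimal smoke test of the decoding utilities above (an addition, not part\n    # of the training pipeline): TEXT8_CHARS maps 0 -> \"_\" and 1..26 -> \"a\"..\"z\",\n    # so this prints [\"hello\"].\n    print(batch_to_str(torch.tensor([[8, 5, 12, 12, 15]])))\n"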
  },
  {
    "path": "env.yml",
    "content": "name: bfn\nchannels:\n  - pytorch\n  - nvidia\ndependencies:\n  - python=3.9\n  - pytorch=2.0.0\n  - pytorch-cuda=11.8\n  - torchvision=0.15.0\n  - pip\n  - pip:\n    - accelerate==0.19.0\n    - matplotlib\n    - omegaconf\n    - rich\n"
  },
  {
    "path": "model.py",
    "content": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"\nThis file implements the Bayesian Flow and BFN loss for continuous and discrete variables.\nFinally it implements the BFN using these objects.\nFor consistency we use always use a tuple to store input parameters.\nIt has just one element for discrete data (the probabilities) and two for continuous/discretized (mean & variance).\nThe probability distributions and network architectures are defined in probability.py and networks dir.\n\"Cts\" is an abbreviation of \"Continuous\".\n\"\"\"\n\nimport math\nfrom abc import abstractmethod, ABC\nfrom typing import Union, Optional\n\nimport torch\nimport torch.distributions as D\nimport torch.nn.functional as F\nfrom torch import nn, Tensor\n\nfrom probability import (\n    DiscreteDistributionFactory,\n    CtsDistributionFactory,\n    PredDistToDataDistFactory,\n    DiscretizedCtsDistribution,\n)\nfrom utils_model import sandwich, float_to_idx\n\n\nclass BayesianFlow(nn.Module, ABC):\n    def __init__(self):\n        super().__init__()\n\n    @abstractmethod\n    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, ...]:\n        \"\"\"Returns the initial input params (for a batch) at t=0. Used during sampling.\n        For discrete data, the tuple has length 1 and contains the initial class probabilities.\n        For continuous data, the tuple has length 2 and contains the mean and precision.\"\"\"\n        pass\n\n    @abstractmethod\n    def params_to_net_inputs(self, params: tuple[Tensor, ...]) -> Tensor:\n        \"\"\"Utility method to convert input distribution params to network inputs if needed.\"\"\"\n        pass\n\n    @abstractmethod\n    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> float:\n        \"\"\"Returns the alpha at step i of total n_steps according to the flow schedule. Used:\n        a) during sampling, when i and alpha are the same for all samples in the batch.\n        b) during discrete time loss computation, when i and alpha are different for samples in the batch.\"\"\"\n        pass\n\n    @abstractmethod\n    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:\n        \"\"\"Returns the sender distribution with accuracy alpha obtained by adding appropriate noise to the data x. 
Used:\n        a) during sampling (same alpha for whole batch) to sample from the output distribution produced by the net.\n        b) during discrete time loss computation when alpha are different for samples in the batch.\"\"\"\n        pass\n\n    @abstractmethod\n    def update_input_params(self, input_params: tuple[Tensor, ...], y: Tensor, alpha: float) -> tuple[Tensor, ...]:\n        \"\"\"Updates the distribution parameters using Bayes' theorem in light of noisy sample y.\n        Used during sampling when alpha is the same for the whole batch.\"\"\"\n        pass\n\n    @abstractmethod\n    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, ...]:\n        \"\"\"Returns a sample from the Bayesian Flow distribution over input parameters at time t conditioned on data.\n        Used during training when t (and thus accuracies) are different for different samples in the batch.\n        For discrete data, the returned tuple has length 1 and contains the class probabilities.\n        For continuous data, the returned tuple has length 2 and contains the mean and precision.\"\"\"\n        pass\n\n\nclass Loss(nn.Module, ABC):\n    def __init__(self):\n        super().__init__()\n\n    @abstractmethod\n    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor) -> Tensor:\n        \"\"\"Returns the continuous time KL loss (and any other losses) at time t (between 0 and 1).\n        The input params are only used when the network is parameterized to predict the noise for continuous data.\"\"\"\n        pass\n\n    @abstractmethod\n    def discrete_time_loss(\n        self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples: int = 20\n    ) -> Tensor:\n        \"\"\"Returns the discrete time KL loss for n_steps total of communication at time t (between 0 and 1) using\n        n_samples for Monte Carlo estimation of the discrete loss.\n        The input params are only used when the network is parameterized to predict the noise for continuous data.\"\"\"\n        pass\n\n    @abstractmethod\n    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:\n        \"\"\"Returns the reconstruction loss, i.e. 
the final cost of transmitting clean data.\n        The input params are only used when the network is parameterized to predict the noise for continuous data.\"\"\"\n        pass\n\n\n# Continuous or Discretized data\n\n\nclass CtsBayesianFlow(BayesianFlow):\n    def __init__(\n        self,\n        min_variance: float = 1e-6,\n    ):\n        super().__init__()\n        self.min_variance = min_variance\n\n    @torch.no_grad()\n    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, None]:\n        post_var = torch.pow(self.min_variance, t)\n        alpha_t = 1 - post_var\n        mean_mean = alpha_t * data\n        mean_var = alpha_t * post_var\n        mean_std_dev = mean_var.sqrt()\n        noise = torch.randn(mean_mean.shape, device=mean_mean.device)\n        mean = mean_mean + (mean_std_dev * noise)\n        # We don't need to compute the variance because it is not needed by the network, so set it to None\n        input_params = (mean, None)\n        return input_params\n\n    def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:\n        return params[0]  # Only the mean is used by the network\n\n    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, float]:\n        return torch.zeros(*data_shape, device=device), 1.0\n\n    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:\n        sigma_1 = math.sqrt(self.min_variance)\n        return (sigma_1 ** (-2 * i / n_steps)) * (1 - sigma_1 ** (2 / n_steps))\n\n    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:\n        dist = D.Normal(x, 1.0 / alpha**0.5)\n        return dist\n\n    def update_input_params(self, input_params: tuple[Tensor, float], y: Tensor, alpha: float) -> tuple[Tensor, float]:\n        input_mean, input_precision = input_params\n        new_precision = input_precision + alpha\n        new_mean = ((input_precision * input_mean) + (alpha * y)) / new_precision\n        return new_mean, new_precision\n\n\nclass CtsBayesianFlowLoss(Loss):\n    def __init__(\n        self,\n        bayesian_flow: CtsBayesianFlow,\n        distribution_factory: Union[CtsDistributionFactory, DiscreteDistributionFactory],\n        min_loss_variance: float = -1,\n        noise_pred: bool = True,\n    ):\n        super().__init__()\n        self.bayesian_flow = bayesian_flow\n        self.distribution_factory = distribution_factory\n        self.min_loss_variance = min_loss_variance\n        self.C = -0.5 * math.log(bayesian_flow.min_variance)\n        self.noise_pred = noise_pred\n        if self.noise_pred:\n            self.distribution_factory.log_dev = False\n            self.distribution_factory = PredDistToDataDistFactory(\n                self.distribution_factory, self.bayesian_flow.min_variance\n            )\n\n    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:\n        output_params = sandwich(output_params)\n        t = t.flatten(start_dim=1).float()\n        posterior_var = torch.pow(self.bayesian_flow.min_variance, t)\n        flat_target = data.flatten(start_dim=1)\n        pred_dist = self.distribution_factory.get_dist(output_params, input_params, t)\n        pred_mean = pred_dist.mean\n        mse_loss = (pred_mean - flat_target).square()\n        if self.min_loss_variance > 0:\n            posterior_var = posterior_var.clamp(min=self.min_loss_variance)\n        loss = self.C * mse_loss / posterior_var\n        return loss\n\n    
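# NOTE: cts_time_loss above implements the paper's continuous-time loss for\n    # continuous data, L_inf(x) = -ln(sigma_1) * E_t[ ||x - x_hat||^2 / sigma_1^(2t) ]:\n    # since min_variance = sigma_1^2, self.C = -0.5 * ln(min_variance) = -ln(sigma_1)\n    # and posterior_var = min_variance**t = sigma_1^(2t).\n\n    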
def discrete_time_loss(\n        self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10\n    ) -> Tensor:\n        output_params = sandwich(output_params)\n        t = t.flatten(start_dim=1).float()\n        output_dist = self.distribution_factory.get_dist(output_params, input_params, t)\n        if hasattr(output_dist, \"probs\"):  # output distribution is discretized normal\n            flat_target = data.flatten(start_dim=1)\n            t = t.flatten(start_dim=1)\n            i = t * n_steps + 1  # since t = (i - 1) / n\n            alpha = self.bayesian_flow.get_alpha(i, n_steps)\n            sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)\n            receiver_mix_wts = sandwich(output_dist.probs)\n            receiver_mix_dist = D.Categorical(probs=receiver_mix_wts, validate_args=False)\n            receiver_components = D.Normal(\n                output_dist.class_centres, (1.0 / alpha.sqrt()).unsqueeze(-1), validate_args=False\n            )\n            receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components, validate_args=False)\n            y = sender_dist.sample(torch.Size([n_samples]))\n            loss = (\n                (sender_dist.log_prob(y) - receiver_dist.log_prob(y))\n                .mean(0)\n                .flatten(start_dim=1)\n                .mean(1, keepdims=True)\n            )\n        else:  # output distribution is normal\n            pred_mean = output_dist.mean\n            flat_target = data.flatten(start_dim=1)\n            mse_loss = (pred_mean - flat_target).square()\n            i = t * n_steps + 1\n            alpha = self.bayesian_flow.get_alpha(i, n_steps)\n            loss = alpha * mse_loss / 2\n        return n_steps * loss\n\n    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:\n        output_params = sandwich(output_params)\n        flat_data = data.flatten(start_dim=1)\n        t = torch.ones_like(data).flatten(start_dim=1).float()\n        output_dist = self.distribution_factory.get_dist(output_params, input_params, t)\n\n        if hasattr(output_dist, \"probs\"):  # output distribution is discretized normal\n            reconstruction_loss = -output_dist.log_prob(flat_data)\n        else:  # output distribution is normal, but we use discretized normal to make results comparable (see Sec. 
7.2)\n            if self.bayesian_flow.min_variance == 1e-3:  # used for 16 bin CIFAR10\n                noise_dev = 0.7 * math.sqrt(self.bayesian_flow.min_variance)\n                num_bins = 16\n            else:\n                noise_dev = math.sqrt(self.bayesian_flow.min_variance)\n                num_bins = 256\n            mean = output_dist.mean.flatten(start_dim=1)\n            final_dist = D.Normal(mean, noise_dev)\n            final_dist = DiscretizedCtsDistribution(final_dist, num_bins, device=t.device, batch_dims=mean.ndim - 1)\n            reconstruction_loss = -final_dist.log_prob(flat_data)\n        return reconstruction_loss\n\n\n# Discrete Data\n\n\nclass DiscreteBayesianFlow(BayesianFlow):\n    def __init__(\n        self,\n        n_classes: int,\n        min_sqrt_beta: float = 1e-10,\n        discretize: bool = False,\n        epsilon: float = 1e-6,\n        max_sqrt_beta: float = 1,\n    ):\n        super().__init__()\n        self.n_classes = n_classes\n        self.min_sqrt_beta = min_sqrt_beta\n        self.discretize = discretize\n        self.epsilon = epsilon\n        self.max_sqrt_beta = max_sqrt_beta\n        self.uniform_entropy = math.log(self.n_classes)\n\n    def t_to_sqrt_beta(self, t):\n        return t * self.max_sqrt_beta\n\n    def count_dist(self, x, beta=None):\n        mean = (self.n_classes * F.one_hot(x.long(), self.n_classes)) - 1\n        std_dev = math.sqrt(self.n_classes)\n        if beta is not None:\n            mean = mean * beta\n            std_dev = std_dev * beta.sqrt()\n        return D.Normal(mean, std_dev, validate_args=False)\n\n    def count_sample(self, x, beta):\n        return self.count_dist(x, beta).rsample()\n\n    @torch.no_grad()\n    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor]:\n        return (torch.ones(*data_shape, self.n_classes, device=device) / self.n_classes,)\n\n    @torch.no_grad()\n    def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:\n        params = params[0]\n        if self.n_classes == 2:\n            params = params * 2 - 1  # We scale-shift here for MNIST instead of in the network like for text\n            params = params[..., :1]\n        return params\n\n    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:\n        return ((self.max_sqrt_beta / n_steps) ** 2) * (2 * i - 1)\n\n    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:\n        e_x = F.one_hot(x.long(), self.n_classes)\n        alpha = alpha.unsqueeze(-1) if isinstance(alpha, Tensor) else alpha\n        dist = D.Normal(alpha * ((self.n_classes * e_x) - 1), (self.n_classes * alpha) ** 0.5)\n        return dist\n\n    def update_input_params(self, input_params: tuple[Tensor], y: Tensor, alpha: float) -> tuple[Tensor]:\n        new_input_params = input_params[0] * y.exp()\n        new_input_params /= new_input_params.sum(-1, keepdims=True)\n        return (new_input_params,)\n\n    @torch.no_grad()\n    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor]:\n        if self.discretize:\n            data = float_to_idx(data, self.n_classes)\n        sqrt_beta = self.t_to_sqrt_beta(t.clamp(max=1 - self.epsilon))\n        lo_beta = sqrt_beta < self.min_sqrt_beta\n        sqrt_beta = sqrt_beta.clamp(min=self.min_sqrt_beta)\n        beta = sqrt_beta.square().unsqueeze(-1)\n        logits = self.count_sample(data, beta)\n        probs = F.softmax(logits, -1)\n        probs = 
torch.where(lo_beta.unsqueeze(-1), torch.ones_like(probs) / self.n_classes, probs)\n        if self.n_classes == 2:\n            probs = probs[..., :1]\n            probs = probs.reshape_as(data)\n        input_params = (probs,)\n        return input_params\n\n\nclass DiscreteBayesianFlowLoss(Loss):\n    def __init__(\n        self,\n        bayesian_flow: DiscreteBayesianFlow,\n        distribution_factory: DiscreteDistributionFactory,\n    ):\n        super().__init__()\n        self.bayesian_flow = bayesian_flow\n        self.distribution_factory = distribution_factory\n        self.K = self.bayesian_flow.n_classes\n\n    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:\n        flat_output = sandwich(output_params)\n        pred_probs = self.distribution_factory.get_dist(flat_output).probs\n        flat_target = data.flatten(start_dim=1)\n        if self.bayesian_flow.discretize:\n            flat_target = float_to_idx(flat_target, self.K)\n        tgt_mean = torch.nn.functional.one_hot(flat_target.long(), self.K)\n        kl = self.K * ((tgt_mean - pred_probs).square()).sum(-1)\n        t = t.flatten(start_dim=1).float()\n        loss = t * (self.bayesian_flow.max_sqrt_beta**2) * kl\n        return loss\n\n    def discrete_time_loss(\n        self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10\n    ) -> Tensor:\n        flat_target = data.flatten(start_dim=1)\n        if self.bayesian_flow.discretize:\n            flat_target = float_to_idx(flat_target, self.K)\n        i = t * n_steps + 1\n        alpha = self.bayesian_flow.get_alpha(i, n_steps).flatten(start_dim=1)\n        sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)\n\n        flat_output = sandwich(output_params)\n        receiver_mix_wts = self.distribution_factory.get_dist(flat_output).probs\n        receiver_mix_dist = D.Categorical(probs=receiver_mix_wts.unsqueeze(-2))\n        classes = torch.arange(self.K, device=flat_target.device).long().unsqueeze(0).unsqueeze(0)\n        receiver_components = self.bayesian_flow.get_sender_dist(classes, alpha.unsqueeze(-1))\n        receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components)\n\n        y = sender_dist.sample(torch.Size([n_samples]))\n        loss = n_steps * (sender_dist.log_prob(y) - receiver_dist.log_prob(y)).mean(0).sum(-1).mean(1, keepdims=True)\n        return loss\n\n    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:\n        flat_outputs = sandwich(output_params)\n        flat_data = data.flatten(start_dim=1)\n        output_dist = self.distribution_factory.get_dist(flat_outputs)\n        return -output_dist.log_prob(flat_data)\n\n\nclass BFN(nn.Module):\n    def __init__(self, net: nn.Module, bayesian_flow: BayesianFlow, loss: Loss):\n        super().__init__()\n        self.net = net\n        self.bayesian_flow = bayesian_flow\n        self.loss = loss\n\n    @staticmethod\n    @torch.no_grad()\n    def sample_t(data: Tensor, n_steps: Optional[int]) -> Tensor:\n        if n_steps == 0 or n_steps is None:\n            t = torch.rand(data.size(0), device=data.device).unsqueeze(-1)\n        else:\n            t = torch.randint(0, n_steps, (data.size(0),), device=data.device).unsqueeze(-1) / n_steps\n        t = (torch.ones_like(data).flatten(start_dim=1) * t).reshape_as(data)\n        return t\n\n    def forward(\n        self, data: Tensor, t: Optional[Tensor] = None, n_steps: 
Optional[int] = None\n    ) -> Tensor:\n        \"\"\"\n        Compute an MC estimate of the continuous (when n_steps=None or 0) or discrete time KL loss.\n        t is sampled randomly if None. If t is not None, expect t.shape == data.shape.\n        \"\"\"\n\n        t = self.sample_t(data, n_steps) if t is None else t\n        # sample input parameter flow\n        input_params = self.bayesian_flow(data, t)\n        net_inputs = self.bayesian_flow.params_to_net_inputs(input_params)\n\n        # compute output distribution parameters\n        output_params: Tensor = self.net(net_inputs, t)\n\n        # compute KL loss in float32\n        with torch.autocast(device_type=data.device.type if data.device.type != \"mps\" else \"cpu\", enabled=False):\n            if n_steps == 0 or n_steps is None:\n                loss = self.loss.cts_time_loss(data, output_params.float(), input_params, t)\n            else:\n                loss = self.loss.discrete_time_loss(data, output_params.float(), input_params, t, n_steps)\n\n        # loss has shape (batch_size, 1); return its mean\n        return loss.mean()\n\n    @torch.inference_mode()\n    def compute_reconstruction_loss(self, data: Tensor) -> Tensor:\n        t = torch.ones_like(data).float()\n        input_params = self.bayesian_flow(data, t)\n        net_inputs = self.bayesian_flow.params_to_net_inputs(input_params)\n        output_params: Tensor = self.net(net_inputs, t)\n        return self.loss.reconstruction_loss(data, output_params, input_params).flatten(start_dim=1).mean()\n\n    @torch.inference_mode()\n    def sample(self, data_shape: tuple, n_steps: int) -> Tensor:\n        device = next(self.parameters()).device\n        input_params = self.bayesian_flow.get_prior_input_params(data_shape, device)\n        distribution_factory = self.loss.distribution_factory\n\n        for i in range(1, n_steps + 1):\n            t = torch.ones(*data_shape, device=device) * (i - 1) / n_steps\n            output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)\n            output_sample = distribution_factory.get_dist(output_params, input_params, t).sample()\n            output_sample = output_sample.reshape(*data_shape)\n            alpha = self.bayesian_flow.get_alpha(i, n_steps)\n            y = self.bayesian_flow.get_sender_dist(output_sample, alpha).sample()\n            input_params = self.bayesian_flow.update_input_params(input_params, y, alpha)\n\n        t = torch.ones(*data_shape, device=device)\n        output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)\n        output_sample = distribution_factory.get_dist(output_params, input_params, t).mode\n        output_sample = output_sample.reshape(*data_shape)\n        return output_sample\n
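\n\n# How the pieces fit together (a sketch inferred from the configs in this repo;\n# the actual assembly happens in train.py, not shown in this file). For text8:\n#   bayesian_flow = DiscreteBayesianFlow(n_classes=27, max_sqrt_beta=0.75)\n#   loss = DiscreteBayesianFlowLoss(bayesian_flow, CategoricalFactory())  # factory from probability.py\n#   model = BFN(net=GPT(...), bayesian_flow=bayesian_flow, loss=loss)\n# model(data) then returns the continuous-time loss, model(data, n_steps=n) the\n# n-step discrete-time loss, and model.sample(shape, n_steps) draws samples.\n"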
  },
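  {
    "path": "examples/bfn_contract_sketch.py",
    "content": "# NOTE: illustrative sketch, not part of the original repository.\n# It shows the minimal contract BFN.forward expects from its `bayesian_flow`\n# and `loss` collaborators, using toy stand-ins rather than the real\n# DiscreteBayesianFlow / DiscreteBayesianFlowLoss. The import path `model`\n# is an assumption about where BFN is defined in this repository.\n\nimport torch\nfrom torch import Tensor, nn\n\nfrom model import BFN\n\n\nclass ToyFlow(nn.Module):\n    # The \"input parameters\" are just the data scaled by t, wrapped in a tuple.\n    def forward(self, data: Tensor, t: Tensor) -> tuple:\n        return (data * t,)\n\n    def params_to_net_inputs(self, params: tuple) -> Tensor:\n        return params[0]\n\n\nclass ToyNet(nn.Module):\n    # The network receives the flow's net inputs together with the time tensor.\n    def forward(self, x: Tensor, t: Tensor) -> Tensor:\n        return x\n\n\nclass ToyLoss(nn.Module):\n    # Continuous-time loss stub: one value per batch element, shape (batch_size, 1).\n    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: tuple, t: Tensor) -> Tensor:\n        return (output_params - data).square().flatten(start_dim=1).mean(-1, keepdim=True)\n\n\nbfn = BFN(net=ToyNet(), bayesian_flow=ToyFlow(), loss=ToyLoss())\ndata = torch.randn(4, 8)\nloss = bfn(data)  # n_steps=None -> continuous-time KL path\nprint(loss)\n"
  },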
  {
    "path": "networks/__init__.py",
    "content": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__all__ = (\n    \"GPT\",\n    \"UNetVDM\",\n    \"UNetModel\",\n    \"adapters\",\n)\n\nfrom .transformer import GPT\nfrom .unet_vdm import UNetVDM\nfrom .unet_improved import UNetModel\nfrom . import adapters\n"
  },
  {
    "path": "networks/adapters.py",
    "content": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import Tuple\n\nimport torch\nfrom torch import Tensor\nfrom torch import nn\n\nfrom utils_model import sandwich, pe_encode, pe_encode_float\n\n\nclass TextInputAdapter(nn.Module):\n    \"\"\"\n    A module to convert sequences of text class tokens to embedding tokens with learned positional embeddings.\n    \"\"\"\n\n    def __init__(\n        self,\n        vocab_size: int,\n        seq_len: int,\n        output_size: int = 256,\n        learn_pos_embedding: bool = False,\n    ):\n        super().__init__()\n        self.learn_pos_embedding = learn_pos_embedding\n        if learn_pos_embedding:\n            self.pos_embedding = nn.Embedding(seq_len, output_size)\n        else:\n            self.register_buffer(\"pos_embedding\", pe_encode(seq_len, output_size))\n        self.inp_embedding = nn.Linear(vocab_size, output_size)\n        self.t_embedding = nn.Linear(1, output_size)\n\n    def forward(self, probs: torch.Tensor, t: torch.Tensor) -> Tensor:\n        inp_emb = self.inp_embedding(2 * probs - 1)\n        if self.learn_pos_embedding:\n            pos_emb = self.pos_embedding(\n                torch.arange(0, probs.size(1)).to(probs.device)\n            )\n        else:\n            pos_emb = self.pos_embedding\n        pos_emb = pos_emb.unsqueeze(0).expand(inp_emb.size(0), -1, -1)\n        t_emb = self.t_embedding((2 * t - 1).unsqueeze(-1))\n        output = inp_emb + pos_emb + t_emb\n\n        return output\n\n\nclass FourierImageInputAdapter(nn.Module):\n    \"\"\"\n    A module to convert 2D image coordinates into a set of vectors represented as a matrix, with fourier position codes.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_channels: int = 3,\n        input_shape: Tuple[int, int] = (224, 224),\n        n_freq_bands: int = 64,\n        output_height: int = 256,\n        value_res: int = -1,\n        mask_res: int = -1,\n        add_pos_feats: bool = True,\n        add_mask: bool = True,\n        learn_pos_feats: bool = False,\n        pos_embed_size: int = 32,\n        init_scale: float = 0.02,\n    ):\n        super().__init__()\n        self.input_shape = input_shape\n        self.n_freq_bands = n_freq_bands\n        self.value_res = value_res\n        self.mask_res = mask_res\n        self.add_pos_feats = add_pos_feats\n        self.add_mask = add_mask\n        if learn_pos_feats:\n            pos_feats = nn.Parameter(\n                init_scale\n                * torch.randn(1, input_shape[0] * input_shape[1], pos_embed_size)\n            )\n            self.register_parameter(\"pos_feats\", pos_feats)\n        else:\n            x = torch.linspace(-1.0, 1.0, steps=input_shape[0])\n            y = torch.linspace(-1.0, 1.0, steps=input_shape[1])\n            x_pos, y_pos = torch.meshgrid(x, y, indexing=\"ij\")\n            pos = torch.stack((x_pos, y_pos), dim=-1)\n            pos = pos.reshape(-1, 2)\n            x_bands = 
torch.linspace(1.0, input_shape[0] / 2, steps=n_freq_bands)\n            y_bands = torch.linspace(1.0, input_shape[1] / 2, steps=n_freq_bands)\n            bands = torch.stack((x_bands, y_bands), dim=0)\n            vals = pos[:, :, None] * bands[None, :, :]\n            vals = math.pi * vals.reshape(vals.shape[0], -1)\n            pos_feats = torch.cat([vals.sin(), vals.cos()], dim=-1)\n            pos_feats = torch.cat([pos_feats, pos], dim=-1)\n            self.register_buffer(\"pos_feats\", pos_feats)\n        img_feat_height = input_channels\n        pos_feat_height = pos_feats.size(-1)\n        if self.mask_res > 0:\n            mask_feat_height = (n_freq_bands * 2) + 1\n        else:\n            mask_feat_height = 1\n        all_feat_height = img_feat_height\n        if add_mask:\n            all_feat_height += mask_feat_height\n        if add_pos_feats:\n            all_feat_height += pos_feat_height\n        self.output_projection = None\n        if output_height != all_feat_height:\n            self.output_projection = nn.Linear(all_feat_height, output_height)\n\n    def forward(self, img: Tensor, t: Tensor) -> Tensor:\n        flat_img = sandwich(img)\n        flat_t = sandwich(t)\n        t_feats = (flat_t.float()[..., :1] * 2) - 1\n        if self.mask_res > 0:\n            t_feats = torch.cat(\n                [\n                    t_feats,\n                    pe_encode_float(\n                        t_feats, self.mask_res, self.n_freq_bands * 2\n                    ).flatten(start_dim=2),\n                ],\n                -1,\n            )\n        fourier_feats = self.pos_feats.expand(img.size(0), -1, -1)\n        all_feat_list = [flat_img]\n        if self.add_mask:\n            all_feat_list.append(t_feats)\n        if self.add_pos_feats:\n            all_feat_list.append(fourier_feats)\n        all_feats = torch.cat(all_feat_list, dim=-1)\n        if self.output_projection is None:\n            output = all_feats\n        else:\n            output = self.output_projection(all_feats)\n        return output\n\n\nclass OutputAdapter(nn.Module):\n    def __init__(self, input_height: int, output_channels: int, output_height: int):\n        super().__init__()\n        self.output_channels = output_channels\n        self.output_height = output_height\n        self.output_projection = nn.Linear(\n            input_height, output_channels * output_height\n        )\n\n    def forward(self, inp: torch.Tensor) -> torch.Tensor:\n        output = self.output_projection(inp)\n        return output.reshape(\n            output.size(0), -1, self.output_channels, self.output_height\n        )\n"
  },
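  {
    "path": "examples/adapter_shapes_sketch.py",
    "content": "# NOTE: illustrative shape check, not part of the original repository.\n# TextInputAdapter maps per-token class probabilities plus a per-token time\n# to transformer-ready embedding tokens; the sizes below are arbitrary.\n\nimport torch\n\nfrom networks.adapters import TextInputAdapter\n\nadapter = TextInputAdapter(vocab_size=27, seq_len=256, output_size=128)\nprobs = torch.full((2, 256, 27), 1.0 / 27)  # uniform prior over 27 classes\nt = torch.rand(2, 256)  # one time value per token, in [0, 1]\ntokens = adapter(probs, t)\nassert tokens.shape == (2, 256, 128)\n"
  },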
  {
    "path": "networks/transformer.py",
    "content": "# Source: https://github.com/karpathy/nanoGPT\n#\n# MIT License\n#\n# Copyright (c) 2022 Andrej Karpathy\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n#\n# Modifications:\n# - Added data_adapters to GPT to preprocess the inputs and (optionally) postprocess the outputs\n# - Added the `skip` option to concat the input and output of the network before the final projection\n# - Added time `t` as an input to `forward()`\n\nimport math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef gelu(x):\n    return F.gelu(x, approximate=\"tanh\")\n\n\nclass LayerNorm(nn.Module):\n    \"\"\"LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False\"\"\"\n\n    def __init__(self, ndim, bias):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(ndim))\n        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None\n\n    def forward(self, input):\n        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)\n\n\nclass SelfAttention(nn.Module):\n    def __init__(self, n_head, n_embd, dropout, bias, is_causal):\n        super().__init__()\n        assert n_embd % n_head == 0\n\n        # key, query, value projections for all heads, but in a batch\n        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)\n\n        # output projection\n        self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)\n\n        # regularization\n        self.attn_dropout = nn.Dropout(dropout)\n        self.resid_dropout = nn.Dropout(dropout)\n        self.n_head = n_head\n        self.n_embd = n_embd\n        self.dropout = dropout\n        self.is_causal = is_causal\n\n    def forward(self, x):\n        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)\n\n        # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)\n        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)\n        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)\n        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)\n\n        # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)\n        y = torch.nn.functional.scaled_dot_product_attention(\n            q, k, v, dropout_p=self.dropout if self.training else 0, is_causal=self.is_causal\n        )\n        y = 
y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side\n\n        # output projection\n        y = self.resid_dropout(self.c_proj(y))\n        return y\n\n\nclass MLP(nn.Module):\n    def __init__(self, n_embd, dropout, bias):\n        super().__init__()\n        self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=bias)\n        self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=bias)\n        self.dropout = nn.Dropout(dropout)\n\n    def forward(self, x):\n        x = self.c_fc(x)\n        x = gelu(x)\n        x = self.c_proj(x)\n        x = self.dropout(x)\n        return x\n\n\nclass Block(nn.Module):\n    def __init__(self, n_head, n_embd, dropout, bias, is_causal):\n        super().__init__()\n        self.ln_1 = LayerNorm(n_embd, bias=bias)\n        self.attn = SelfAttention(n_head, n_embd, dropout, bias, is_causal)\n        self.ln_2 = LayerNorm(n_embd, bias=bias)\n        self.mlp = MLP(n_embd, dropout, bias)\n\n    def forward(self, x):\n        x = x + self.attn(self.ln_1(x))\n        x = x + self.mlp(self.ln_2(x))\n        return x\n\n\nclass GPT(nn.Module):\n    def __init__(\n        self,\n        data_adapters: dict,\n        vocab_size: int,\n        n_layer: int = 12,\n        n_head: int = 12,\n        n_embd: int = 768,\n        dropout: float = 0.0,\n        bias: bool = True,\n        skip: bool = False,\n        is_causal: bool = False,\n    ):\n        super().__init__()\n        self.n_layer = n_layer\n        self.n_head = n_head\n        self.n_embd = n_embd\n\n        self.input_adapter = data_adapters[\"input_adapter\"]\n        self.output_adapter = data_adapters[\"output_adapter\"]\n        self.transformer = nn.ModuleDict(\n            dict(\n                drop=nn.Dropout(dropout),\n                h=nn.ModuleList([Block(n_head, n_embd, dropout, bias, is_causal) for _ in range(n_layer)]),\n                ln_f=LayerNorm(n_embd, bias=bias),\n            )\n        )\n        self.is_causal = is_causal\n        if self.is_causal:\n            self.skip = False\n        else:\n            self.skip = skip\n        if self.skip:\n            self.lm_head = nn.Linear(2 * n_embd, vocab_size, bias=bias)\n        else:\n            self.lm_head = nn.Linear(n_embd, vocab_size, bias=bias)\n\n        # init all weights\n        self.apply(self._init_weights)\n\n        # apply special scaled init to the residual projections, per GPT-2 paper\n        for pn, p in self.named_parameters():\n            if pn.endswith(\"c_proj.weight\"):\n                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layer))\n\n        # report number of parameters\n        print(f\"number of parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6:.2f}M\")\n\n    def _init_weights(self, module):\n        if isinstance(module, nn.Linear):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n            if module.bias is not None:\n                torch.nn.init.zeros_(module.bias)\n        elif isinstance(module, nn.Embedding):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n\n    def forward(self, data: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n        x_in = self.input_adapter(data, t)\n        x = self.transformer.drop(x_in)\n        for block in self.transformer.h:\n            x = block(x)\n        x = self.transformer.ln_f(x)\n        if self.skip:\n            x = torch.cat([x, x_in], -1)\n        logits = self.output_adapter(self.lm_head(x)) if self.output_adapter else 
self.lm_head(x)\n        return logits\n\n    def get_optim_groups(self, weight_decay: float):\n        decay = set()\n        no_decay = set()\n        whitelist_weight_modules = (torch.nn.Linear,)\n        blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)\n        for mn, m in self.named_modules():\n            for pn, p in m.named_parameters():\n                fpn = \"%s.%s\" % (mn, pn) if mn else pn  # full param name\n                # random note: because named_modules and named_parameters are recursive\n                # we will see the same tensors p many many times. but doing it this way\n                # allows us to know which parent module any tensor p belongs to...\n                if pn.endswith(\"bias\"):\n                    # all biases will not be decayed\n                    no_decay.add(fpn)\n                elif pn.endswith(\"weight\") and isinstance(m, whitelist_weight_modules):\n                    # weights of whitelist modules will be weight decayed\n                    decay.add(fpn)\n                elif pn.endswith(\"weight\") and isinstance(m, blacklist_weight_modules):\n                    # weights of blacklist modules will NOT be weight decayed\n                    no_decay.add(fpn)\n\n        # We don't use weight tying so comment this out\n        # decay.remove('lm_head.weight')\n\n        # validate that we considered every parameter\n        param_dict = {pn: p for pn, p in self.named_parameters()}\n        inter_params = decay & no_decay\n        union_params = decay | no_decay\n        assert len(inter_params) == 0, \"parameters %s made it into both decay/no_decay sets!\" % (str(inter_params),)\n        assert (\n            len(param_dict.keys() - union_params) == 0\n        ), \"parameters %s were not separated into either decay/no_decay set!\" % (str(param_dict.keys() - union_params),)\n\n        # create the pytorch optimizer groups\n        optim_groups = [\n            {\"params\": [param_dict[pn] for pn in sorted(list(decay))], \"weight_decay\": weight_decay},\n            {\"params\": [param_dict[pn] for pn in sorted(list(no_decay))], \"weight_decay\": 0.0},\n        ]\n        return optim_groups\n"
  },
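  {
    "path": "examples/gpt_forward_sketch.py",
    "content": "# NOTE: illustrative sketch, not part of the original repository.\n# Wires the GPT torso to a TextInputAdapter and runs one forward pass;\n# output_adapter=None falls back to the raw lm_head logits. Sizes are arbitrary.\n\nimport torch\n\nfrom networks import GPT\nfrom networks.adapters import TextInputAdapter\n\nvocab_size, seq_len, n_embd = 27, 64, 128\ndata_adapters = {\n    \"input_adapter\": TextInputAdapter(vocab_size=vocab_size, seq_len=seq_len, output_size=n_embd),\n    \"output_adapter\": None,\n}\nnet = GPT(data_adapters, vocab_size=vocab_size, n_layer=2, n_head=4, n_embd=n_embd)\n\nprobs = torch.full((2, seq_len, vocab_size), 1.0 / vocab_size)\nt = torch.rand(2, seq_len)\nlogits = net(probs, t)\nassert logits.shape == (2, seq_len, vocab_size)\n\n# The decay/no-decay split plugs straight into a standard PyTorch optimizer.\noptimizer = torch.optim.AdamW(net.get_optim_groups(weight_decay=0.1), lr=1e-4)\n"
  },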
  {
    "path": "networks/unet_improved.py",
    "content": "# Source: https://github.com/openai/improved-diffusion\n#\n# MIT License\n#\n# Copyright (c) 2021 OpenAI\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n#\n# Modifications:\n# - Added data_adapters to UNetModel to preprocess the inputs and postprocess the outputs\n# - Added the `skip` option to concat the input and output of the network before the final projection\n# - Replaced `timesteps` argument of `UNetModel.forward()` with time `t`, which is used to compute the `timesteps`\n\nfrom abc import abstractmethod\n\nimport math\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom utils_model import sandwich\n\nfrom torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors\n\n\"\"\"\nHelpers to train with 16-bit precision.\n\"\"\"\n\n\ndef convert_module_to_f16(module):\n    \"\"\"\n    Convert primitive modules to float16.\n    \"\"\"\n    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n        module.weight.data = module.weight.data.half()\n        module.bias.data = module.bias.data.half()\n\n\ndef convert_module_to_f32(module):\n    \"\"\"\n    Convert primitive modules to float32, undoing convert_module_to_f16().\n    \"\"\"\n    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n        module.weight.data = module.weight.data.float()\n        module.bias.data = module.bias.data.float()\n\n\ndef make_master_params(model_params):\n    \"\"\"\n    Copy model parameters into a (differently-shaped) list of full-precision\n    parameters.\n    \"\"\"\n    master_params = _flatten_dense_tensors([param.detach().float() for param in model_params])\n    master_params = nn.Parameter(master_params)\n    master_params.requires_grad = True\n    return [master_params]\n\n\ndef model_grads_to_master_grads(model_params, master_params):\n    \"\"\"\n    Copy the gradients from the model parameters into the master parameters\n    from make_master_params().\n    \"\"\"\n    master_params[0].grad = _flatten_dense_tensors([param.grad.data.detach().float() for param in model_params])\n\n\ndef master_params_to_model_params(model_params, master_params):\n    \"\"\"\n    Copy the master parameter data back into the model parameters.\n    \"\"\"\n    # Without copying to a list, if a generator is passed, this will\n    # silently not copy any parameters.\n    model_params = list(model_params)\n\n    for param, master_param in zip(model_params, unflatten_master_params(model_params, master_params)):\n      
  param.detach().copy_(master_param)\n\n\ndef unflatten_master_params(model_params, master_params):\n    \"\"\"\n    Unflatten the master parameters to look like model_params.\n    \"\"\"\n    return _unflatten_dense_tensors(master_params[0].detach(), model_params)\n\n\ndef zero_grad(model_params):\n    for param in model_params:\n        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group\n        if param.grad is not None:\n            param.grad.detach_()\n            param.grad.zero_()\n\n\n# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.\nclass SiLU(nn.Module):\n    def forward(self, x):\n        return x * th.sigmoid(x)\n\n\nclass GroupNorm32(nn.GroupNorm):\n    def forward(self, x):\n        return super().forward(x.float()).type(x.dtype)\n\n\ndef conv_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D convolution module.\n    \"\"\"\n    if dims == 1:\n        return nn.Conv1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.Conv2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.Conv3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\ndef linear(*args, **kwargs):\n    \"\"\"\n    Create a linear module.\n    \"\"\"\n    return nn.Linear(*args, **kwargs)\n\n\ndef avg_pool_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D average pooling module.\n    \"\"\"\n    if dims == 1:\n        return nn.AvgPool1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.AvgPool2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.AvgPool3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\ndef update_ema(target_params, source_params, rate=0.99):\n    \"\"\"\n    Update target parameters to be closer to those of source parameters using\n    an exponential moving average.\n\n    :param target_params: the target parameter sequence.\n    :param source_params: the source parameter sequence.\n    :param rate: the EMA rate (closer to 1 means slower).\n    \"\"\"\n    for targ, src in zip(target_params, source_params):\n        targ.detach().mul_(rate).add_(src, alpha=1 - rate)\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef scale_module(module, scale):\n    \"\"\"\n    Scale the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().mul_(scale)\n    return module\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef normalization(channels):\n    \"\"\"\n    Make a standard normalization layer.\n\n    :param channels: number of input channels.\n    :return: an nn.Module for normalization.\n    \"\"\"\n    return GroupNorm32(32, channels)\n\n\ndef timestep_embedding(timesteps, dim, max_period=10000):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    half = dim // 2\n    freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(\n        
device=timesteps.device\n    )\n    args = timesteps[:, None].float() * freqs[None]\n    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)\n    if dim % 2:\n        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)\n    return embedding\n\n\ndef checkpoint(func, inputs, params, flag):\n    \"\"\"\n    Evaluate a function without caching intermediate activations, allowing for\n    reduced memory at the expense of extra compute in the backward pass.\n\n    :param func: the function to evaluate.\n    :param inputs: the argument sequence to pass to `func`.\n    :param params: a sequence of parameters `func` depends on but does not\n                   explicitly take as arguments.\n    :param flag: if False, disable gradient checkpointing.\n    \"\"\"\n    if flag:\n        args = tuple(inputs) + tuple(params)\n        return CheckpointFunction.apply(func, len(inputs), *args)\n    else:\n        return func(*inputs)\n\n\nclass CheckpointFunction(th.autograd.Function):\n    @staticmethod\n    def forward(ctx, run_function, length, *args):\n        ctx.run_function = run_function\n        ctx.input_tensors = list(args[:length])\n        ctx.input_params = list(args[length:])\n        with th.no_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        return output_tensors\n\n    @staticmethod\n    def backward(ctx, *output_grads):\n        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]\n        with th.enable_grad():\n            # Fixes a bug where the first op in run_function modifies the\n            # Tensor storage in place, which is not allowed for detach()'d\n            # Tensors.\n            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]\n            output_tensors = ctx.run_function(*shallow_copies)\n        input_grads = th.autograd.grad(\n            output_tensors,\n            ctx.input_tensors + ctx.input_params,\n            output_grads,\n            allow_unused=True,\n        )\n        del ctx.input_tensors\n        del ctx.input_params\n        del output_tensors\n        return (None, None) + input_grads\n\n\nclass TimestepBlock(nn.Module):\n    \"\"\"\n    Any module where forward() takes timestep embeddings as a second argument.\n    \"\"\"\n\n    @abstractmethod\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the module to `x` given `emb` timestep embeddings.\n        \"\"\"\n\n\nclass TimestepEmbedSequential(nn.Sequential, TimestepBlock):\n    \"\"\"\n    A sequential module that passes timestep embeddings to the children that\n    support it as an extra input.\n    \"\"\"\n\n    def forward(self, x, emb):\n        for layer in self:\n            if isinstance(layer, TimestepBlock):\n                x = layer(x, emb)\n            else:\n                x = layer(x)\n        return x\n\n\nclass Upsample(nn.Module):\n    \"\"\"\n    An upsampling layer with an optional convolution.\n\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. 
If 3D, then\n                 upsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(self, channels, use_conv, dims=2):\n        super().__init__()\n        self.channels = channels\n        self.use_conv = use_conv\n        self.dims = dims\n        if use_conv:\n            self.conv = conv_nd(dims, channels, channels, 3, padding=1)\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        if self.dims == 3:\n            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode=\"nearest\")\n        else:\n            x = F.interpolate(x, scale_factor=2, mode=\"nearest\")\n        if self.use_conv:\n            x = self.conv(x)\n        return x\n\n\nclass Downsample(nn.Module):\n    \"\"\"\n    A downsampling layer with an optional convolution.\n\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then\n                 downsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(self, channels, use_conv, dims=2):\n        super().__init__()\n        self.channels = channels\n        self.use_conv = use_conv\n        self.dims = dims\n        stride = 2 if dims != 3 else (1, 2, 2)\n        if use_conv:\n            self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1)\n        else:\n            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        return self.op(x)\n\n\nclass ResBlock(TimestepBlock):\n    \"\"\"\n    A residual block that can optionally change the number of channels.\n\n    :param channels: the number of input channels.\n    :param emb_channels: the number of timestep embedding channels.\n    :param dropout: the rate of dropout.\n    :param out_channels: if specified, the number of out channels.\n    :param use_conv: if True and out_channels is specified, use a spatial\n        convolution instead of a smaller 1x1 convolution to change the\n        channels in the skip connection.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param use_checkpoint: if True, use gradient checkpointing on this module.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels,\n        emb_channels,\n        dropout,\n        out_channels=None,\n        use_conv=False,\n        use_scale_shift_norm=False,\n        dims=2,\n        use_checkpoint=False,\n    ):\n        super().__init__()\n        self.channels = channels\n        self.emb_channels = emb_channels\n        self.dropout = dropout\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.use_checkpoint = use_checkpoint\n        self.use_scale_shift_norm = use_scale_shift_norm\n\n        self.in_layers = nn.Sequential(\n            normalization(channels),\n            SiLU(),\n            conv_nd(dims, channels, self.out_channels, 3, padding=1),\n        )\n        self.emb_layers = nn.Sequential(\n            SiLU(),\n            linear(\n                emb_channels,\n                2 * self.out_channels if use_scale_shift_norm else self.out_channels,\n            ),\n        )\n        self.out_layers = nn.Sequential(\n            normalization(self.out_channels),\n            SiLU(),\n            nn.Dropout(p=dropout),\n            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),\n        )\n\n        if self.out_channels 
== channels:\n            self.skip_connection = nn.Identity()\n        elif use_conv:\n            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)\n        else:\n            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)\n\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the block to a Tensor, conditioned on a timestep embedding.\n\n        :param x: an [N x C x ...] Tensor of features.\n        :param emb: an [N x emb_channels] Tensor of timestep embeddings.\n        :return: an [N x C x ...] Tensor of outputs.\n        \"\"\"\n        return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)\n\n    def _forward(self, x, emb):\n        h = self.in_layers(x)\n        emb_out = self.emb_layers(emb).type(h.dtype)\n        while len(emb_out.shape) < len(h.shape):\n            emb_out = emb_out[..., None]\n        if self.use_scale_shift_norm:\n            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]\n            scale, shift = th.chunk(emb_out, 2, dim=1)\n            h = out_norm(h) * (1 + scale) + shift\n            h = out_rest(h)\n        else:\n            h = h + emb_out\n            h = self.out_layers(h)\n        return self.skip_connection(x) + h\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"\n    An attention block that allows spatial positions to attend to each other.\n\n    Originally ported from here, but adapted to the N-d case.\n    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.\n    \"\"\"\n\n    def __init__(self, channels, num_heads=1, use_checkpoint=False):\n        super().__init__()\n        self.channels = channels\n        self.num_heads = num_heads\n        self.use_checkpoint = use_checkpoint\n\n        self.norm = normalization(channels)\n        self.qkv = conv_nd(1, channels, channels * 3, 1)\n        self.attention = QKVAttention()\n        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))\n\n    def forward(self, x):\n        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)\n\n    def _forward(self, x):\n        b, c, *spatial = x.shape\n        x = x.reshape(b, c, -1)\n        qkv = self.qkv(self.norm(x))\n        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])\n        h = self.attention(qkv)\n        h = h.reshape(b, -1, h.shape[-1])\n        h = self.proj_out(h)\n        return (x + h).reshape(b, c, *spatial)\n\n\nclass QKVAttention(nn.Module):\n    \"\"\"\n    A module which performs QKV attention.\n    \"\"\"\n\n    def forward(self, qkv):\n        \"\"\"\n        Apply QKV attention.\n\n        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.\n        :return: an [N x C x T] tensor after attention.\n        \"\"\"\n        ch = qkv.shape[1] // 3\n        q, k, v = th.split(qkv, ch, dim=1)\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        weight = th.einsum(\"bct,bcs->bts\", q * scale, k * scale)  # More stable with f16 than dividing afterwards\n        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)\n        return th.einsum(\"bts,bcs->bct\", weight, v)\n\n    @staticmethod\n    def count_flops(model, _x, y):\n        \"\"\"\n        A counter for the `thop` package to count the operations in an\n        attention operation.\n\n        Meant to be used like:\n\n            macs, params = thop.profile(\n                model,\n                inputs=(inputs, timestamps),\n                
custom_ops={QKVAttention: QKVAttention.count_flops},\n            )\n\n        \"\"\"\n        b, c, *spatial = y[0].shape\n        num_spatial = int(np.prod(spatial))\n        # We perform two matmuls with the same number of ops.\n        # The first computes the weight matrix, the second computes\n        # the combination of the value vectors.\n        matmul_ops = 2 * b * (num_spatial**2) * c\n        model.total_ops += th.DoubleTensor([matmul_ops])\n\n\nclass UNetModel(nn.Module):\n    \"\"\"\n    The full UNet model with attention and timestep embedding.\n\n    :param in_channels: channels in the input Tensor.\n    :param model_channels: base channel count for the model.\n    :param out_channels: channels in the output Tensor.\n    :param num_res_blocks: number of residual blocks per downsample.\n    :param attention_resolutions: a collection of downsample rates at which\n        attention will take place. May be a set, list, or tuple.\n        For example, if this contains 4, then at 4x downsampling, attention\n        will be used.\n    :param dropout: the dropout probability.\n    :param channel_mult: channel multiplier for each level of the UNet.\n    :param conv_resample: if True, use learned convolutions for upsampling and\n        downsampling.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param num_classes: if specified (as an int), then this model will be\n        class-conditional with `num_classes` classes.\n    :param use_checkpoint: use gradient checkpointing to reduce memory usage.\n    :param num_heads: the number of attention heads in each attention layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_adapters,\n        image_size=32,\n        in_channels=3,\n        model_channels=128,\n        out_channels=128,\n        num_res_blocks=3,\n        attention_resolutions=[8, 16],\n        dropout=0,\n        channel_mult=(1, 2, 2, 2),\n        conv_resample=True,\n        dims=2,\n        skip=True,\n        num_classes=None,\n        use_checkpoint=False,\n        num_heads=4,\n        num_heads_upsample=-1,\n        use_scale_shift_norm=False,\n        project_input=False,\n    ):\n        super().__init__()\n        self.input_adapter = data_adapters[\"input_adapter\"]\n        self.output_adapter = data_adapters[\"output_adapter\"]\n\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        self.image_size = image_size\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.num_classes = num_classes\n        self.use_checkpoint = use_checkpoint\n        self.num_heads = num_heads\n        self.num_heads_upsample = num_heads_upsample\n        self.skip = skip\n        self.project_input = project_input\n        if project_input:\n            self.input_projection = nn.Linear(self.in_channels, self.model_channels)\n            in_channels = self.model_channels\n\n        time_embed_dim = model_channels * 4\n        self.time_embed = nn.Sequential(\n            linear(model_channels, time_embed_dim),\n            SiLU(),\n            linear(time_embed_dim, time_embed_dim),\n        )\n\n        if self.num_classes is not None:\n            self.label_emb = nn.Embedding(num_classes, 
time_embed_dim)\n\n        self.input_blocks = nn.ModuleList(\n            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]\n        )\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n        for level, mult in enumerate(channel_mult):\n            for _ in range(num_res_blocks):\n                layers = [\n                    ResBlock(\n                        ch,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=mult * model_channels,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    layers.append(AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads))\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                self.input_blocks.append(TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims)))\n                input_block_chans.append(ch)\n                ds *= 2\n\n        self.middle_block = TimestepEmbedSequential(\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n            AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads),\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n        )\n\n        self.output_blocks = nn.ModuleList([])\n        for level, mult in list(enumerate(channel_mult))[::-1]:\n            for i in range(num_res_blocks + 1):\n                layers = [\n                    ResBlock(\n                        ch + input_block_chans.pop(),\n                        time_embed_dim,\n                        dropout,\n                        out_channels=model_channels * mult,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = model_channels * mult\n                if ds in attention_resolutions:\n                    layers.append(\n                        AttentionBlock(\n                            ch,\n                            use_checkpoint=use_checkpoint,\n                            num_heads=num_heads_upsample,\n                        )\n                    )\n                if level and i == num_res_blocks:\n                    layers.append(Upsample(ch, conv_resample, dims=dims))\n                    ds //= 2\n                self.output_blocks.append(TimestepEmbedSequential(*layers))\n\n        self.out = nn.Sequential(\n            normalization(ch),\n            SiLU(),\n            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),\n        )\n\n    def convert_to_fp16(self):\n        \"\"\"\n        Convert the torso of the model to float16.\n        \"\"\"\n        
self.input_blocks.apply(convert_module_to_f16)\n        self.middle_block.apply(convert_module_to_f16)\n        self.output_blocks.apply(convert_module_to_f16)\n\n    def convert_to_fp32(self):\n        \"\"\"\n        Convert the torso of the model to float32.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f32)\n        self.middle_block.apply(convert_module_to_f32)\n        self.output_blocks.apply(convert_module_to_f32)\n\n    @property\n    def inner_dtype(self):\n        \"\"\"\n        Get the dtype used by the torso of the model.\n        \"\"\"\n        return next(self.input_blocks.parameters()).dtype\n\n    def forward(\n        self,\n        data: th.Tensor,\n        t: th.Tensor,\n    ) -> th.Tensor:\n        \"\"\"\n        Apply the model to an input batch.\n\n        :param data: a Tensor of inputs, in the form expected by the input adapter.\n        :param t: a Tensor of times in [0, 1] matching the data shape; rescaled\n            internally to the [0, 4000] timestep range.\n        :return: a Tensor of outputs, as produced by the output adapter.\n        \"\"\"\n        y = None\n        flat_x = self.input_adapter(data, t)\n        x = flat_x.reshape(flat_x.size(0), self.image_size, self.image_size, self.in_channels)\n        if self.project_input:\n            x = self.input_projection(x)\n        x_perm = x.permute(0, 3, 1, 2).contiguous()\n        timesteps = t.flatten(start_dim=1)[:, 0] * 4000\n        assert (y is not None) == (\n            self.num_classes is not None\n        ), \"must specify y if and only if the model is class-conditional\"\n\n        hs = []\n        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))\n\n        if self.num_classes is not None:\n            assert y.shape == (x.shape[0],)\n            emb = emb + self.label_emb(y)\n\n        h = x_perm.type(self.inner_dtype)\n        for module in self.input_blocks:\n            h = module(h, emb)\n            hs.append(h)\n        h = self.middle_block(h, emb)\n        for module in self.output_blocks:\n            cat_in = th.cat([h, hs.pop()], dim=1)\n            h = module(cat_in, emb)\n        h = h.type(x.dtype)\n        out = sandwich(self.out(h).permute(0, 2, 3, 1).contiguous())\n        if self.skip:\n            out = th.cat([sandwich(x), out], -1)\n        out = self.output_adapter(out)\n        return out\n\n    def get_feature_vectors(self, x, timesteps, y=None):\n        \"\"\"\n        Apply the model and return all of the intermediate tensors.\n\n        :param x: an [N x C x ...] 
Tensor of inputs.\n        :param timesteps: a 1-D batch of timesteps.\n        :param y: an [N] Tensor of labels, if class-conditional.\n        :return: a dict with the following keys:\n                 - 'down': a list of hidden state tensors from downsampling.\n                 - 'middle': the tensor of the output of the lowest-resolution\n                             block in the model.\n                 - 'up': a list of hidden state tensors from upsampling.\n        \"\"\"\n        hs = []\n        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))\n        if self.num_classes is not None:\n            assert y.shape == (x.shape[0],)\n            emb = emb + self.label_emb(y)\n        result = dict(down=[], up=[])\n        h = x.type(self.inner_dtype)\n        for module in self.input_blocks:\n            h = module(h, emb)\n            hs.append(h)\n            result[\"down\"].append(h.type(x.dtype))\n        h = self.middle_block(h, emb)\n        result[\"middle\"] = h.type(x.dtype)\n        for module in self.output_blocks:\n            cat_in = th.cat([h, hs.pop()], dim=1)\n            h = module(cat_in, emb)\n            result[\"up\"].append(h.type(x.dtype))\n        return result\n"
  },
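  {
    "path": "examples/timestep_embedding_sketch.py",
    "content": "# NOTE: illustrative sketch, not part of the original repository.\n# timestep_embedding builds the sinusoidal embedding that conditions the\n# UNet; UNetModel.forward rescales t in [0, 1] to timesteps in [0, 4000].\n\nimport torch\n\nfrom networks.unet_improved import timestep_embedding\n\ntimesteps = torch.linspace(0.0, 4000.0, steps=8)  # fractional values are allowed\nemb = timestep_embedding(timesteps, dim=128)\nassert emb.shape == (8, 128)\n"
  },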
  {
    "path": "networks/unet_vdm.py",
    "content": "# Source: https://github.com/addtt/variational-diffusion-models\n#\n# MIT License\n#\n# Copyright (c) 2022 Andrea Dittadi\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n#\n# Modifications:\n# - Added data_adapters to UNetVDM to preprocess the inputs and postprocess the outputs\n# - Replaced `timesteps` argument of `UNetModel.forward()` with time `t`, which is used to compute the `timesteps`\n# - Added 1/1000 to t before computing timesteps embeddings so t isn't 0\n# - Added concatenation of input and output of the network before the final projection\n\nimport numpy as np\nimport torch\nfrom torch import einsum, nn, pi, softmax\n\nfrom utils_model import sandwich\n\n\n@torch.no_grad()\ndef zero_init(module: nn.Module) -> nn.Module:\n    \"\"\"Sets to zero all the parameters of a module, and returns the module.\"\"\"\n    for p in module.parameters():\n        nn.init.zeros_(p.data)\n    return module\n\n\nclass UNetVDM(nn.Module):\n    def __init__(\n        self,\n        data_adapters,\n        embedding_dim: int = 128,\n        n_blocks: int = 32,\n        n_attention_heads: int = 1,\n        dropout_prob: float = 0.1,\n        norm_groups: int = 32,\n        input_channels: int = 3,\n        use_fourier_features: bool = True,\n        attention_everywhere: bool = False,\n        image_size: int = 32,\n    ):\n        super().__init__()\n        self.input_adapter = data_adapters[\"input_adapter\"]\n        self.output_adapter = data_adapters[\"output_adapter\"]\n        attention_params = dict(\n            n_heads=n_attention_heads,\n            n_channels=embedding_dim,\n            norm_groups=norm_groups,\n        )\n        resnet_params = dict(\n            ch_in=embedding_dim,\n            ch_out=embedding_dim,\n            condition_dim=4 * embedding_dim,\n            dropout_prob=dropout_prob,\n            norm_groups=norm_groups,\n        )\n        if use_fourier_features:\n            self.fourier_features = FourierFeatures()\n        self.embed_conditioning = nn.Sequential(\n            nn.Linear(embedding_dim, embedding_dim * 4),\n            nn.SiLU(),\n            nn.Linear(embedding_dim * 4, embedding_dim * 4),\n            nn.SiLU(),\n        )\n        total_input_ch = input_channels\n        if use_fourier_features:\n            total_input_ch *= 1 + self.fourier_features.num_features\n        self.conv_in = nn.Conv2d(total_input_ch, embedding_dim, 3, padding=1)\n\n        # Down path: n_blocks blocks with a resnet block and maybe 
attention.\n        self.down_blocks = nn.ModuleList(\n            UpDownBlock(\n                resnet_block=ResnetBlock(**resnet_params),\n                attention_block=AttentionBlock(**attention_params) if attention_everywhere else None,\n            )\n            for _ in range(n_blocks)\n        )\n\n        self.mid_resnet_block_1 = ResnetBlock(**resnet_params)\n        self.mid_attn_block = AttentionBlock(**attention_params)\n        self.mid_resnet_block_2 = ResnetBlock(**resnet_params)\n\n        # Up path: n_blocks+1 blocks with a resnet block and maybe attention.\n        resnet_params[\"ch_in\"] *= 2  # double input channels due to skip connections\n        self.up_blocks = nn.ModuleList(\n            UpDownBlock(\n                resnet_block=ResnetBlock(**resnet_params),\n                attention_block=AttentionBlock(**attention_params) if attention_everywhere else None,\n            )\n            for _ in range(n_blocks + 1)\n        )\n\n        self.conv_out = nn.Sequential(\n            nn.GroupNorm(num_groups=norm_groups, num_channels=embedding_dim),\n            nn.SiLU(),\n            zero_init(nn.Conv2d(embedding_dim, embedding_dim, 3, padding=1)),\n        )\n        self.embedding_dim = embedding_dim\n        self.input_channels = input_channels\n        self.image_size = image_size\n        self.use_fourier_features = use_fourier_features\n\n    def forward(\n        self,\n        data: torch.Tensor,\n        t: torch.Tensor,\n    ) -> torch.Tensor:\n        flat_x = self.input_adapter(data, t)\n        x = flat_x.reshape(flat_x.size(0), self.image_size, self.image_size, self.input_channels)\n        x_perm = x.permute(0, 3, 1, 2).contiguous()\n        t = t.float().flatten(start_dim=1)[:, 0]\n        t_embedding = get_timestep_embedding(t + 0.001, self.embedding_dim)\n        # We will condition on time embedding.\n        cond = self.embed_conditioning(t_embedding)\n\n        h = self.maybe_concat_fourier(x_perm)\n        h = self.conv_in(h)  # (B, embedding_dim, H, W)\n        hs = []\n        for down_block in self.down_blocks:  # n_blocks times\n            hs.append(h)\n            h = down_block(h, cond)\n        hs.append(h)\n        h = self.mid_resnet_block_1(h, cond)\n        h = self.mid_attn_block(h)\n        h = self.mid_resnet_block_2(h, cond)\n        for up_block in self.up_blocks:  # n_blocks+1 times\n            h = torch.cat([h, hs.pop()], dim=1)\n            h = up_block(h, cond)\n        out = sandwich(self.conv_out(h).permute(0, 2, 3, 1).contiguous())\n        out = torch.cat([sandwich(x), out], -1)\n        out = self.output_adapter(out)\n        return out\n\n    def maybe_concat_fourier(self, z):\n        if self.use_fourier_features:\n            return torch.cat([z, self.fourier_features(z)], dim=1)\n        return z\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(\n        self,\n        ch_in,\n        ch_out=None,\n        condition_dim=None,\n        dropout_prob=0.0,\n        norm_groups=32,\n    ):\n        super().__init__()\n        ch_out = ch_in if ch_out is None else ch_out\n        self.ch_out = ch_out\n        self.condition_dim = condition_dim\n        self.net1 = nn.Sequential(\n            nn.GroupNorm(num_groups=norm_groups, num_channels=ch_in),\n            nn.SiLU(),\n            nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1),\n        )\n        if condition_dim is not None:\n            self.cond_proj = zero_init(nn.Linear(condition_dim, ch_out, bias=False))\n        self.net2 = nn.Sequential(\n        
    nn.GroupNorm(num_groups=norm_groups, num_channels=ch_out),\n            nn.SiLU(),\n            nn.Dropout(dropout_prob),\n            zero_init(nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1)),\n        )\n        if ch_in != ch_out:\n            self.skip_conv = nn.Conv2d(ch_in, ch_out, kernel_size=1)\n\n    def forward(self, x, condition):\n        h = self.net1(x)\n        if condition is not None:\n            assert condition.shape == (x.shape[0], self.condition_dim)\n            condition = self.cond_proj(condition)\n            condition = condition[:, :, None, None]\n            h = h + condition\n        h = self.net2(h)\n        if x.shape[1] != self.ch_out:\n            x = self.skip_conv(x)\n        assert x.shape == h.shape\n        return x + h\n\n\ndef get_timestep_embedding(\n    timesteps,\n    embedding_dim: int,\n    dtype=torch.float32,\n    max_timescale=10_000,\n    min_timescale=1,\n):\n    # Adapted from tensor2tensor and VDM codebase.\n    assert timesteps.ndim == 1\n    assert embedding_dim % 2 == 0\n    timesteps *= 1000.0  # In DDPM the time step is in [0, 1000], here [0, 1]\n    num_timescales = embedding_dim // 2\n    inv_timescales = torch.logspace(  # or exp(-linspace(log(min), log(max), n))\n        -np.log10(min_timescale),\n        -np.log10(max_timescale),\n        num_timescales,\n        device=timesteps.device,\n    )\n    emb = timesteps.to(dtype)[:, None] * inv_timescales[None, :]  # (T, D/2)\n    return torch.cat([emb.sin(), emb.cos()], dim=1)  # (T, D)\n\n\nclass FourierFeatures(nn.Module):\n    def __init__(self, first=5.0, last=6.0, step=1.0):\n        super().__init__()\n        self.freqs_exponent = torch.arange(first, last + 1e-8, step)\n\n    @property\n    def num_features(self):\n        return len(self.freqs_exponent) * 2\n\n    def forward(self, x):\n        assert len(x.shape) >= 2\n\n        # Compute (2pi * 2^n) for n in freqs.\n        freqs_exponent = self.freqs_exponent.to(dtype=x.dtype, device=x.device)  # (F, )\n        freqs = 2.0**freqs_exponent * 2 * pi  # (F, )\n        freqs = freqs.view(-1, *([1] * (x.dim() - 1)))  # (F, 1, 1, ...)\n\n        # Compute (2pi * 2^n * x) for n in freqs.\n        features = freqs * x.unsqueeze(1)  # (B, F, X1, X2, ...)\n        features = features.flatten(1, 2)  # (B, F * C, X1, X2, ...)\n\n        # Output features are cos and sin of above. Shape (B, 2 * F * C, H, W).\n        return torch.cat([features.sin(), features.cos()], dim=1)\n\n\ndef attention_inner_heads(qkv, num_heads):\n    \"\"\"Computes attention with heads inside of qkv in the channel dimension.\n\n    Args:\n        qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where:\n            H = number of heads,\n            C = number of channels per head.\n        num_heads: number of heads.\n\n    Returns:\n        Attention output of shape (B, H*C, T).\n    \"\"\"\n\n    bs, width, length = qkv.shape\n    ch = width // (3 * num_heads)\n\n    # Split into (q, k, v) of shape (B, H*C, T).\n    q, k, v = qkv.chunk(3, dim=1)\n\n    # Rescale q and k. 
This makes them contiguous in memory.\n    scale = ch ** (-1 / 4)  # scale with 4th root = scaling output by sqrt\n    q = q * scale\n    k = k * scale\n\n    # Reshape qkv to (B*H, C, T).\n    new_shape = (bs * num_heads, ch, length)\n    q = q.view(*new_shape)\n    k = k.view(*new_shape)\n    v = v.reshape(*new_shape)\n\n    # Compute attention.\n    weight = einsum(\"bct,bcs->bts\", q, k)  # (B*H, T, T)\n    weight = softmax(weight.float(), dim=-1).to(weight.dtype)  # (B*H, T, T)\n    out = einsum(\"bts,bcs->bct\", weight, v)  # (B*H, C, T)\n    return out.reshape(bs, num_heads * ch, length)  # (B, H*C, T)\n\n\nclass Attention(nn.Module):\n    \"\"\"Based on https://github.com/openai/guided-diffusion.\"\"\"\n\n    def __init__(self, n_heads):\n        super().__init__()\n        self.n_heads = n_heads\n\n    def forward(self, qkv):\n        assert qkv.dim() >= 3, qkv.dim()\n        assert qkv.shape[1] % (3 * self.n_heads) == 0\n        spatial_dims = qkv.shape[2:]\n        qkv = qkv.view(*qkv.shape[:2], -1)  # (B, 3*H*C, T)\n        out = attention_inner_heads(qkv, self.n_heads)  # (B, H*C, T)\n        return out.view(*out.shape[:2], *spatial_dims).contiguous()\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"Self-attention residual block.\"\"\"\n\n    def __init__(self, n_heads, n_channels, norm_groups):\n        super().__init__()\n        assert n_channels % n_heads == 0\n        self.layers = nn.Sequential(\n            nn.GroupNorm(num_groups=norm_groups, num_channels=n_channels),\n            nn.Conv2d(n_channels, 3 * n_channels, kernel_size=1),  # (B, 3 * C, H, W)\n            Attention(n_heads),\n            zero_init(nn.Conv2d(n_channels, n_channels, kernel_size=1)),\n        )\n\n    def forward(self, x):\n        return self.layers(x) + x\n\n\nclass UpDownBlock(nn.Module):\n    def __init__(self, resnet_block, attention_block=None):\n        super().__init__()\n        self.resnet_block = resnet_block\n        self.attention_block = attention_block\n\n    def forward(self, x, cond):\n        x = self.resnet_block(x, cond)\n        if self.attention_block is not None:\n            x = self.attention_block(x)\n        return x\n"
  },
  {
    "path": "probability.py",
    "content": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport functools\nfrom abc import abstractmethod\n\nfrom torch.distributions.normal import Normal\nfrom torch.distributions.categorical import Categorical as torch_Categorical\nfrom torch.distributions.bernoulli import Bernoulli as torch_Bernoulli\nfrom torch.distributions.mixture_same_family import MixtureSameFamily\nfrom torch.distributions.uniform import Uniform\n\nfrom math import log\n\nfrom utils_model import (\n    safe_exp,\n    safe_log,\n    idx_to_float,\n    float_to_idx,\n    quantize, sandwich,\n)\n\n\nclass CtsDistribution:\n    @abstractmethod\n    def log_prob(self, x):\n        pass\n\n    @abstractmethod\n    def sample(self):\n        pass\n\n\nclass DiscreteDistribution:\n    @property\n    @abstractmethod\n    def probs(self):\n        pass\n\n    @functools.cached_property\n    def log_probs(self):\n        return safe_log(self.probs)\n\n    @functools.cached_property\n    def mean(self):\n        pass\n\n    @functools.cached_property\n    def mode(self):\n        pass\n\n    @abstractmethod\n    def log_prob(self, x):\n        pass\n\n    @abstractmethod\n    def sample(self):\n        pass\n\n\nclass DiscretizedDistribution(DiscreteDistribution):\n    def __init__(self, num_bins, device):\n        self.num_bins = num_bins\n        self.bin_width = 2.0 / num_bins\n        self.half_bin_width = self.bin_width / 2.0\n        self.device = device\n\n    @functools.cached_property\n    def class_centres(self):\n        return torch.arange(self.half_bin_width - 1, 1, self.bin_width, device=self.device)\n\n    @functools.cached_property\n    def class_boundaries(self):\n        return torch.arange(self.bin_width - 1, 1 - self.half_bin_width, self.bin_width, device=self.device)\n\n    @functools.cached_property\n    def mean(self):\n        return (self.probs * self.class_centres).sum(-1)\n\n    @functools.cached_property\n    def mode(self):\n        mode_idx = self.probs.argmax(-1).flatten()\n        return self.class_centres[mode_idx].reshape(self.probs.shape[:-1])\n\n\nclass DiscretizedCtsDistribution(DiscretizedDistribution):\n    def __init__(self, cts_dist, num_bins, device, batch_dims, clip=True, min_prob=1e-5):\n        super().__init__(num_bins, device)\n        self.cts_dist = cts_dist\n        self.log_bin_width = log(self.bin_width)\n        self.batch_dims = batch_dims\n        self.clip = clip\n        self.min_prob = min_prob\n\n    @functools.cached_property\n    def probs(self):\n        bdry_cdfs = self.cts_dist.cdf(self.class_boundaries.reshape([-1] + ([1] * self.batch_dims)))\n        bdry_slice = bdry_cdfs[:1]\n        if self.clip:\n            cdf_min = torch.zeros_like(bdry_slice)\n            cdf_max = torch.ones_like(bdry_slice)\n            bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0)\n            return (bdry_cdfs[1:] - bdry_cdfs[:-1]).moveaxis(0, -1)\n        else:\n            cdf_min = 
self.cts_dist.cdf(torch.zeros_like(bdry_slice) - 1)\n            cdf_max = self.cts_dist.cdf(torch.ones_like(bdry_slice))\n            bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0)\n            cdf_range = cdf_max - cdf_min\n            cdf_mask = cdf_range < self.min_prob\n            cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range)\n            probs = (bdry_cdfs[1:] - bdry_cdfs[:-1]) / cdf_range\n            probs = torch.where(cdf_mask, (probs * 0) + (1 / self.num_bins), probs)\n            return probs.moveaxis(0, -1)\n\n    def prob(self, x):\n        class_idx = float_to_idx(x, self.num_bins)\n        centre = idx_to_float(class_idx, self.num_bins)\n        cdf_lo = self.cts_dist.cdf(centre - self.half_bin_width)\n        cdf_hi = self.cts_dist.cdf(centre + self.half_bin_width)\n        if self.clip:\n            cdf_lo = torch.where(class_idx <= 0, torch.zeros_like(centre), cdf_lo)\n            cdf_hi = torch.where(class_idx >= (self.num_bins - 1), torch.ones_like(centre), cdf_hi)\n            return cdf_hi - cdf_lo\n        else:\n            cdf_min = self.cts_dist.cdf(torch.zeros_like(centre) - 1)\n            cdf_max = self.cts_dist.cdf(torch.ones_like(centre))\n            cdf_range = cdf_max - cdf_min\n            cdf_mask = cdf_range < self.min_prob\n            cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range)\n            prob = (cdf_hi - cdf_lo) / cdf_range\n            return torch.where(cdf_mask, (prob * 0) + (1 / self.num_bins), prob)\n\n    def log_prob(self, x):\n        prob = self.prob(x)\n        return torch.where(\n            prob < self.min_prob,\n            self.cts_dist.log_prob(quantize(x, self.num_bins)) + self.log_bin_width,\n            safe_log(prob),\n        )\n\n    def sample(self, sample_shape=torch.Size([])):\n        if self.clip:\n            return quantize(self.cts_dist.sample(sample_shape), self.num_bins)\n        else:\n            assert hasattr(self.cts_dist, \"icdf\")\n            cdf_min = self.cts_dist.cdf(torch.zeros_like(self.cts_dist.mean) - 1)\n            cdf_max = self.cts_dist.cdf(torch.ones_like(cdf_min))\n            u = Uniform(cdf_min, cdf_max, validate_args=False).sample(sample_shape)\n            cts_samp = self.cts_dist.icdf(u)\n            return quantize(cts_samp, self.num_bins)\n\n\nclass GMM(MixtureSameFamily):\n    def __init__(self, mix_wt_logits, means, std_devs):\n        mix_wts = torch_Categorical(logits=mix_wt_logits, validate_args=False)\n        components = Normal(means, std_devs, validate_args=False)\n        super().__init__(mix_wts, components, validate_args=False)\n\n\nclass DiscretizedGMM(DiscretizedCtsDistribution):\n    def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):\n        assert params.size(-1) % 3 == 0\n        if min_std_dev < 0:\n            min_std_dev = 1.0 / (num_bins * 5)\n        mix_wt_logits, means, std_devs = params.chunk(3, -1)\n        if log_dev:\n            std_devs = safe_exp(std_devs)\n        std_devs = std_devs.clamp(min=min_std_dev, max=max_std_dev)\n        super().__init__(\n            cts_dist=GMM(mix_wt_logits, means, std_devs),\n            num_bins=num_bins,\n            device=params.device,\n            batch_dims=params.ndim - 1,\n            clip=clip,\n            min_prob=min_prob,\n        )\n\n\nclass DiscretizedNormal(DiscretizedCtsDistribution):\n    def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, 
log_dev=True):\n        assert params.size(-1) == 2\n        if min_std_dev < 0:\n            min_std_dev = 1.0 / (num_bins * 5)\n        mean, std_dev = params.split(1, -1)[:2]\n        if log_dev:\n            std_dev = safe_exp(std_dev)\n        std_dev = std_dev.clamp(min=min_std_dev, max=max_std_dev)\n        super().__init__(\n            cts_dist=Normal(mean.squeeze(-1), std_dev.squeeze(-1), validate_args=False),\n            num_bins=num_bins,\n            device=params.device,\n            batch_dims=params.ndim - 1,\n            clip=clip,\n            min_prob=min_prob,\n        )\n\n\nclass Bernoulli(DiscreteDistribution):\n    def __init__(self, logits):\n        self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False)\n\n    @functools.cached_property\n    def probs(self):\n        p = self.bernoulli.probs.unsqueeze(-1)\n        return torch.cat([1 - p, p], -1)\n\n    @functools.cached_property\n    def mode(self):\n        return self.bernoulli.mode\n\n    def log_prob(self, x):\n        return self.bernoulli.log_prob(x.float())\n\n    def sample(self, sample_shape=torch.Size([])):\n        return self.bernoulli.sample(sample_shape)\n\n\nclass DiscretizedBernoulli(DiscretizedDistribution):\n    def __init__(self, logits):\n        super().__init__(2, logits.device)\n        self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False)\n\n    @functools.cached_property\n    def probs(self):\n        p = self.bernoulli.probs.unsqueeze(-1)\n        return torch.cat([1 - p, p], -1)\n\n    @functools.cached_property\n    def mode(self):\n        return idx_to_float(self.bernoulli.mode, 2)\n\n    def log_prob(self, x):\n        return self.bernoulli.log_prob(float_to_idx(x, 2).float())\n\n    def sample(self, sample_shape=torch.Size([])):\n        return idx_to_float(self.bernoulli.sample(sample_shape), 2)\n\n\nclass DeltaDistribution(CtsDistribution):\n    def __init__(self, mean, clip_range=1.0):\n        if clip_range > 0:\n            mean = mean.clip(min=-clip_range, max=clip_range)\n        self.mean = mean\n\n    @functools.cached_property\n    def mode(self):\n        return self.mean\n\n    def sample(self, sample_shape=torch.Size([])):\n        return self.mean\n\n\nclass Categorical(DiscreteDistribution):\n    def __init__(self, logits):\n        self.categorical = torch_Categorical(logits=logits, validate_args=False)\n        self.n_classes = logits.size(-1)\n\n    @functools.cached_property\n    def probs(self):\n        return self.categorical.probs\n\n    @functools.cached_property\n    def mode(self):\n        return self.categorical.mode\n\n    def log_prob(self, x):\n        return self.categorical.log_prob(x)\n\n    def sample(self, sample_shape=torch.Size([])):\n        return self.categorical.sample(sample_shape)\n\n\nclass DiscretizedCategorical(DiscretizedDistribution):\n    def __init__(self, logits=None, probs=None):\n        assert (logits is not None) or (probs is not None)\n        if logits is not None:\n            super().__init__(logits.size(-1), logits.device)\n            self.categorical = torch_Categorical(logits=logits, validate_args=False)\n        else:\n            super().__init__(probs.size(-1), probs.device)\n            self.categorical = torch_Categorical(probs=probs, validate_args=False)\n\n    @functools.cached_property\n    def probs(self):\n        return self.categorical.probs\n\n    @functools.cached_property\n    def mode(self):\n        return idx_to_float(self.categorical.mode, self.num_bins)\n\n    def log_prob(self, x):\n        return self.categorical.log_prob(float_to_idx(x, self.num_bins))\n\n    def sample(self, sample_shape=torch.Size([])):\n        return idx_to_float(self.categorical.sample(sample_shape), self.num_bins)\n\n\nclass CtsDistributionFactory:\n    @abstractmethod\n    def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> CtsDistribution:\n        \"\"\"Note: input_params and t are not used but kept here for consistency with DiscreteDistributionFactory.\"\"\"\n        pass\n\n\nclass GMMFactory(CtsDistributionFactory):\n    def __init__(self, min_std_dev=1e-3, max_std_dev=10, log_dev=True):\n        self.min_std_dev = min_std_dev\n        self.max_std_dev = max_std_dev\n        self.log_dev = log_dev\n\n    def get_dist(self, params, input_params=None, t=None):\n        mix_wt_logits, means, std_devs = params.chunk(3, -1)\n        if self.log_dev:\n            std_devs = safe_exp(std_devs)\n        std_devs = std_devs.clamp(min=self.min_std_dev, max=self.max_std_dev)\n        return GMM(mix_wt_logits, means, std_devs)\n\n\nclass NormalFactory(CtsDistributionFactory):\n    def __init__(self, min_std_dev=1e-3, max_std_dev=10):\n        self.min_std_dev = min_std_dev\n        self.max_std_dev = max_std_dev\n\n    def get_dist(self, params, input_params=None, t=None):\n        mean, log_std_dev = params.split(1, -1)[:2]\n        std_dev = safe_exp(log_std_dev).clamp(min=self.min_std_dev, max=self.max_std_dev)\n        return Normal(mean.squeeze(-1), std_dev.squeeze(-1), validate_args=False)\n\n\nclass DeltaFactory(CtsDistributionFactory):\n    def __init__(self, clip_range=1.0):\n        self.clip_range = clip_range\n\n    def get_dist(self, params, input_params=None, t=None):\n        return DeltaDistribution(params.squeeze(-1), self.clip_range)\n\n\nclass DiscreteDistributionFactory:\n    @abstractmethod\n    def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> DiscreteDistribution:\n        \"\"\"Note: input_params and t are only required by PredDistToDataDistFactory.\"\"\"\n        pass\n\n\nclass BernoulliFactory(DiscreteDistributionFactory):\n    def get_dist(self, params, input_params=None, t=None):\n        return Bernoulli(logits=params.squeeze(-1))\n\n\nclass CategoricalFactory(DiscreteDistributionFactory):\n    def get_dist(self, params, input_params=None, t=None):\n        return Categorical(logits=params)\n\n\nclass DiscretizedBernoulliFactory(DiscreteDistributionFactory):\n    def get_dist(self, params, input_params=None, t=None):\n        return DiscretizedBernoulli(logits=params.squeeze(-1))\n\n\nclass DiscretizedCategoricalFactory(DiscreteDistributionFactory):\n    def get_dist(self, params, input_params=None, t=None):\n        return DiscretizedCategorical(logits=params)\n\n\nclass DiscretizedGMMFactory(DiscreteDistributionFactory):\n    def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):\n        self.num_bins = num_bins\n        self.clip = clip\n        self.min_std_dev = min_std_dev\n        self.max_std_dev = max_std_dev\n        self.min_prob = min_prob\n        self.log_dev = log_dev\n\n    def get_dist(self, params, input_params=None, t=None):\n        return DiscretizedGMM(\n            params,\n            num_bins=self.num_bins,\n            clip=self.clip,\n            min_std_dev=self.min_std_dev,\n            max_std_dev=self.max_std_dev,\n            min_prob=self.min_prob,\n            log_dev=self.log_dev,\n        )\n\n\nclass DiscretizedNormalFactory(DiscreteDistributionFactory):\n    def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):\n        self.num_bins = num_bins\n        self.clip = clip\n        self.min_std_dev = min_std_dev\n        self.max_std_dev = max_std_dev\n        self.min_prob = min_prob\n        self.log_dev = log_dev\n\n    def get_dist(self, params, input_params=None, t=None):\n        return DiscretizedNormal(\n            params,\n            num_bins=self.num_bins,\n            clip=self.clip,\n            min_std_dev=self.min_std_dev,\n            max_std_dev=self.max_std_dev,\n            min_prob=self.min_prob,\n            log_dev=self.log_dev,\n        )\n\n\ndef noise_pred_params_to_data_pred_params(noise_pred_params: torch.Tensor, input_mean: torch.Tensor, t: torch.Tensor, min_variance: float, min_t=1e-6):\n    \"\"\"Convert output parameters that predict the noise added to the data into parameters that predict the data itself.\n\n    With gamma = 1 - min_variance**t, the predicted mean is input_mean / gamma - sqrt((1 - gamma) / gamma) * noise_pred_mean,\n    i.e. A - B * noise_pred_mean below; predicted std devs are rescaled by the same factor B.\n    \"\"\"\n    data_shape = list(noise_pred_params.shape)[:-1]\n    noise_pred_params = sandwich(noise_pred_params)\n    input_mean = input_mean.flatten(start_dim=1)\n    if torch.is_tensor(t):\n        t = t.flatten(start_dim=1)\n    else:\n        t = (input_mean * 0) + t\n    alpha_mask = (t < min_t).unsqueeze(-1)\n    posterior_var = torch.pow(min_variance, t.clamp(min=min_t))\n    gamma = 1 - posterior_var\n    A = (input_mean / gamma).unsqueeze(-1)\n    B = (posterior_var / gamma).sqrt().unsqueeze(-1)\n    data_pred_params = []\n    if noise_pred_params.size(-1) == 1:\n        noise_pred_mean = noise_pred_params\n    elif noise_pred_params.size(-1) == 2:\n        noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(2, -1)\n    else:\n        assert noise_pred_params.size(-1) % 3 == 0\n        mix_wt_logits, noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(3, -1)\n        data_pred_params.append(mix_wt_logits)\n    data_pred_mean = A - (B * noise_pred_mean)\n    data_pred_mean = torch.where(alpha_mask, 0 * data_pred_mean, data_pred_mean)\n    data_pred_params.append(data_pred_mean)\n    if noise_pred_params.size(-1) >= 2:\n        noise_pred_dev = safe_exp(noise_pred_log_dev)\n        data_pred_dev = B * noise_pred_dev\n        data_pred_dev = torch.where(alpha_mask, 1 + (0 * data_pred_dev), data_pred_dev)\n        data_pred_params.append(data_pred_dev)\n    data_pred_params = torch.cat(data_pred_params, -1)\n    data_pred_params = data_pred_params.reshape(data_shape + [-1])\n    return data_pred_params\n\n\nclass PredDistToDataDistFactory(DiscreteDistributionFactory):\n    def __init__(self, data_dist_factory, min_variance, min_t=1e-6):\n        self.data_dist_factory = data_dist_factory\n        self.data_dist_factory.log_dev = False\n        self.min_variance = min_variance\n        self.min_t = min_t\n\n    def get_dist(self, params, input_params, t):\n        data_pred_params = noise_pred_params_to_data_pred_params(params, input_params[0], t, self.min_variance, self.min_t)\n        return self.data_dist_factory.get_dist(data_pred_params)\n
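\n\n# Usage sketch (shapes and values are illustrative, not from the original file): params packs\n# (mean, log_std_dev) along the last axis, so for a batch of two 32x32 RGB images:\n#\n#     params = torch.randn(2, 32, 32, 3, 2)\n#     dist = DiscretizedNormalFactory(num_bins=256).get_dist(params)\n#     x = dist.sample()      # bin centres in [-1, 1], shape (2, 32, 32, 3)\n#     ll = dist.log_prob(x)  # elementwise log-probabilities, same shape as x\n"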
  },
  {
    "path": "sample.py",
    "content": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom omegaconf import OmegaConf, DictConfig\n\nfrom utils_train import seed_everything, make_config, make_bfn\n\ntorch.set_float32_matmul_precision(\"high\")\ntorch.backends.cudnn.benchmark = True\n\n\ndef main(cfg: DictConfig) -> torch.Tensor:\n    \"\"\"\n    Config entries:\n        seed (int): Optional\n        config_file (str): Name of config file containing model and data config for a saved checkpoint\n        load_model (str): Path to a saved checkpoint to be tested\n        sample_shape (list): Shape of sample batch, e.g.:\n            (3, 256) for sampling 3 sequences of length 256 from the text8 model.\n            (2, 32, 32, 3) for sampling 2 images from the CIFAR10 model.\n            (4, 28, 28, 1) for sampling 4 images from the MNIST model.\n        n_steps (int): Number of sampling steps (positive integer).\n        save_file (str): File path to save the generated sample tensor. Skip saving if None.\n    \"\"\"\n    seed_everything(cfg.seed)\n    print(f\"Seeded everything with seed {cfg.seed}\")\n\n    # Get model config from the training config file\n    train_cfg = make_config(cfg.config_file)\n    bfn = make_bfn(train_cfg.model)\n\n    bfn.load_state_dict(torch.load(cfg.load_model, weights_only=True, map_location=\"cpu\"))\n    if torch.cuda.is_available():\n        bfn.to(\"cuda\")\n    samples = bfn.sample(cfg.samples_shape, cfg.n_steps)\n\n    if cfg.save_file is not None:\n        torch.save(samples.to(\"cpu\"), cfg.save_file)\n\n    return samples\n\n\nif __name__ == \"__main__\":\n    main(OmegaConf.from_cli())\n"
  },
  {
    "path": "test.py",
    "content": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import Tuple\n\nimport torch\nfrom omegaconf import OmegaConf, DictConfig\nfrom rich import print\nfrom torch import nn\nfrom torch.utils.data import DataLoader\n\nfrom data import make_datasets\nfrom model import BFN\nfrom utils_train import seed_everything, make_config, make_bfn, worker_init_function, make_progress_bar\n\ntorch.set_float32_matmul_precision(\"high\")\ntorch.backends.cudnn.benchmark = True\n\n\ndef setup(cfg: DictConfig) -> Tuple[nn.Module, DataLoader]:\n    test_ds = make_datasets(cfg.data)[-1]\n    test_dl = DataLoader(\n        dataset=test_ds,\n        worker_init_fn=worker_init_function,\n        batch_size=100,\n        shuffle=False,\n        num_workers=8,\n        pin_memory=True,\n    )\n    model = make_bfn(cfg.model)\n    return model, test_dl\n\n\n@torch.inference_mode()\ndef test(model: BFN, dataloader: DataLoader, n_steps: int, n_repeats: int) -> tuple[float, float, float, float]:\n    if torch.cuda.is_available():\n        model.to(\"cuda\")\n    model.eval()\n    losses, recon_losses = [], []\n    pbar = make_progress_bar(True, \"[red]loss: {task.fields[loss]:.4f} repeat: {task.fields[r]}\")\n    with pbar:\n        task_id = pbar.add_task(\"Test\", visible=True, total=n_repeats * len(dataloader), loss=math.nan, r=0)\n        for r in range(n_repeats):\n            _losses, _recon_losses = [], []\n            for eval_batch in dataloader:\n                eval_batch = eval_batch.to(\"cuda\") if torch.cuda.is_available() else eval_batch\n                loss = model(eval_batch, n_steps=n_steps).item()\n                recon_loss = model.compute_reconstruction_loss(eval_batch).item()\n                _losses.append(loss)\n                _recon_losses.append(recon_loss)\n                pbar.update(task_id, advance=1, loss=torch.tensor(_losses).mean() + torch.tensor(_recon_losses).mean(), r=r+1)\n            losses.append(torch.tensor(_losses).mean())\n            recon_losses.append(torch.tensor(_recon_losses).mean())\n    losses = torch.stack(losses)\n    loss_mean, loss_err = losses.mean(), losses.std(correction=0).item() / math.sqrt(len(losses))\n    recon_losses = torch.stack(recon_losses)\n    recon_mean, recon_err = recon_losses.mean(), recon_losses.std(correction=0).item() / math.sqrt(len(recon_losses))\n    return loss_mean, loss_err, recon_mean, recon_err\n\n\ndef main(cfg: DictConfig) -> tuple[float, float, float, float]:\n    \"\"\"\n    Config entries:\n        seed (int): Optional\n        config_file (str): Name of config file containing model and data config for a saved checkpoint\n        load_model (str): Path to a saved checkpoint to be tested\n        n_steps (int): Number of Bayesian flow steps. 
Set to None for continuous time Bayesian flow loss.\n        n_repeats (int): Number of times to iterate through the dataset.\n    \"\"\"\n    seed_everything(cfg.seed)\n    print(f\"Seeded everything with seed {cfg.seed}\")\n\n    # Get model and data config from the training config file\n    train_cfg = make_config(cfg.config_file)\n    model, dataloader = setup(train_cfg)\n\n    model.load_state_dict(torch.load(cfg.load_model, weights_only=True, map_location=\"cpu\"))\n    loss_mean, loss_err, recon_mean, recon_err = test(model, dataloader, cfg.n_steps, cfg.n_repeats)\n    print(f\"For {cfg.n_steps} steps with {cfg.n_repeats} repeats:\")\n    print(f\"Loss is {loss_mean:.6f} +- {loss_err:.6f}\")\n    print(f\"Reconstruction Loss is {recon_mean:.6f} +- {recon_err:.6f}\")\n    print(f\"Total loss mean = {loss_mean + recon_mean}\")\n    return loss_mean, loss_err, recon_mean, recon_err\n\n\nif __name__ == \"__main__\":\n    main(OmegaConf.from_cli())\n"
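\n\n# Example invocation (a sketch; the paths are placeholders). Passing n_steps=null selects the\n# continuous-time loss described in the docstring:\n#\n#     python test.py seed=0 config_file=./config.yaml load_model=./checkpoints/<run_id>/best/ema_model.pt n_steps=null n_repeats=2\n"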
  },
  {
    "path": "train.py",
    "content": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport copy\nimport logging\nimport math\nfrom collections import defaultdict\nfrom pathlib import Path\nfrom typing import Optional, Tuple\n\nimport torch\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom omegaconf import OmegaConf\nfrom rich.logging import RichHandler\nfrom rich.progress import Progress\nfrom torch import nn, optim\nfrom torch.utils.data import DataLoader\n\nfrom model import BFN\nfrom utils_train import (\n    seed_everything, log_cfg,\n    checkpoint_training_state,\n    init_checkpointing,\n    log,\n    update_ema,\n    ddict,\n    make_infinite,\n    make_progress_bar, make_config, make_dataloaders, make_bfn,\n)\n\ntorch.set_float32_matmul_precision(\"high\")\ntorch.backends.cudnn.benchmark = True\n\nlogging.basicConfig(\n    level=logging.INFO,\n    format=\"%(message)s\",\n    datefmt=\"[%X]\",\n    handlers=[RichHandler(rich_tracebacks=True, show_time=False)],\n)\n\nlogger = get_logger(__name__)\n\n\ndef setup(cfg) -> Tuple[nn.Module, dict, optim.Optimizer]:\n    \"\"\"Create the model, dataloader and optimizer\"\"\"\n    dataloaders = make_dataloaders(cfg)\n    model = make_bfn(cfg.model)\n    if \"weight_decay\" in cfg.optimizer.keys() and hasattr(model.net, \"get_optim_groups\"):\n        params = model.net.get_optim_groups(cfg.optimizer.weight_decay)\n    else:\n        params = model.net.parameters()\n    # Instantiate the optimizer using the hyper-parameters in the config\n    optimizer = optim.AdamW(params=params, **cfg.optimizer)\n    return model, dataloaders, optimizer\n\n\n@torch.no_grad()\ndef validate(\n        cfg,\n        model: BFN,\n        ema_model: nn.Module,\n        val_dataloader: DataLoader,\n        step: int,\n        run: \"neptune.Run\",\n        pbar: Optional[Progress],\n        best_val_loss: float,\n        checkpoint_root_dir: Optional[Path],\n        accelerator: Accelerator,\n) -> float:\n    \"\"\"Evaluate model on validation data and save checkpoint if loss improves\"\"\"\n    dtype = {\"no\": torch.float32, \"fp16\": torch.float16, \"bf16\": torch.bfloat16}[accelerator.mixed_precision]\n    model_to_eval = ema_model if ema_model is not None else model\n    model_to_eval.eval()\n    pbar = pbar or Progress()\n    max_steps = cfg.max_val_batches if cfg.max_val_batches > 0 else len(val_dataloader)\n    val_id = pbar.add_task(\"Validating\", visible=True, total=cfg.val_repeats * max_steps, transient=True, loss=math.nan)\n\n    loss, count = 0.0, 0\n    for i in range(cfg.val_repeats):\n        for idx, eval_batch in enumerate(val_dataloader):\n            enabled = True if dtype in [torch.float16, torch.bfloat16] else False\n            with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):\n                loss += model_to_eval(eval_batch.to(accelerator.device)).item()\n                count += 1\n            pbar.update(val_id, advance=1, loss=loss / count)\n       
     if (idx + 1) >= max_steps:\n                break\n    loss /= count\n    pbar.remove_task(val_id)\n    log(run[\"metrics\"][\"val\"][\"loss\"], loss, step)\n\n    if checkpoint_root_dir is not None and (loss < best_val_loss or math.isinf(best_val_loss)):\n        logger.info(f\"loss improved: new value is {loss}\")\n        step_checkpoint_path = checkpoint_root_dir / \"best\"\n        run_id = \"BFN\" if isinstance(run, defaultdict) else run[\"sys\"][\"id\"].fetch()\n        checkpoint_training_state(step_checkpoint_path, accelerator, ema_model, step, run_id)\n        run[\"metrics/best/loss/metric\"] = loss\n        run[\"metrics/best/loss/step\"] = step\n\n    model.train()\n    return loss\n\n\ndef train(\n        cfg,\n        accelerator: Accelerator,\n        model: BFN,\n        ema_model: Optional[nn.Module],\n        dataloaders: dict,\n        optimizer: optim.Optimizer,\n        run: \"neptune.Run\",\n):\n    is_main = accelerator.is_main_process\n    pbar = make_progress_bar(is_main)\n    run_id = \"BFN\" if isinstance(run, defaultdict) else run[\"sys\"][\"id\"].fetch()\n    train_id = pbar.add_task(f\"Training {run_id}\", start=cfg.start_step, total=cfg.n_training_steps, loss=math.nan)\n    checkpoint_root_dir = init_checkpointing(cfg.checkpoint_dir, run_id) if is_main else None\n    best_val_loss = math.inf\n\n    train_iter = make_infinite(dataloaders[\"train\"])\n    model.train()\n    with pbar:\n        for step in range(cfg.start_step, cfg.n_training_steps + 1):\n            step_loss = 0.0\n            for _ in range(cfg.accumulate):\n                with accelerator.accumulate(model):\n                    train_batch = next(train_iter)\n\n                    loss = model(train_batch)\n                    accelerator.backward(loss)\n\n                    if accelerator.sync_gradients and cfg.grad_clip_norm > 0:\n                        accelerator.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)\n                    optimizer.step()\n                    optimizer.zero_grad(set_to_none=True)\n\n                step_loss += loss.item()\n\n            update_ema(ema_model, model, cfg.ema_decay)\n\n            if is_main and (step % cfg.checkpoint_interval == 0):\n                checkpoint_training_state(checkpoint_root_dir / \"last\", accelerator, ema_model, step, run_id)\n                run[\"checkpoints/last\"].track_files(str(checkpoint_root_dir / \"last\"))\n\n            log(run[\"metrics\"][\"train\"][\"loss\"], step_loss / cfg.accumulate, step, is_main and step % cfg.log_interval == 0)\n            log(run[\"metrics\"][\"epoch\"], step // len(dataloaders[\"train\"]), step, is_main)\n\n            if is_main and (step % cfg.val_interval == 0) and \"val\" in dataloaders:\n                val_loss = validate(\n                    cfg=cfg,\n                    model=model,\n                    ema_model=ema_model,\n                    val_dataloader=dataloaders[\"val\"],\n                    step=step,\n                    run=run,\n                    pbar=pbar,\n                    best_val_loss=best_val_loss,\n                    checkpoint_root_dir=checkpoint_root_dir,\n                    accelerator=accelerator,\n                )\n                best_val_loss = min(val_loss, best_val_loss)\n\n            pbar.update(train_id, advance=1, loss=loss.item())\n\n\ndef main(cfg):\n    acc = Accelerator(gradient_accumulation_steps=cfg.training.accumulate)\n\n    seed_everything(cfg.training.seed)\n    logger.info(f\"Seeded everything with seed 
{cfg.training.seed}\", main_process_only=True)\n\n    with acc.main_process_first():\n        model, dataloaders, optimizer = setup(cfg)\n    ema = copy.deepcopy(model) if acc.is_main_process and cfg.training.ema_decay > 0 else None  # EMA on main proc only\n    model, optimizer, dataloaders[\"train\"] = acc.prepare(model, optimizer, dataloaders[\"train\"])\n    run = ddict()\n    if acc.is_main_process:\n        if ema is not None:  # ema is None when ema_decay <= 0\n            ema.to(acc.device)\n        try:\n            if cfg.meta.neptune:\n                import neptune\n                run = neptune.init_run(project=cfg.meta.neptune, mode=\"debug\" if cfg.meta.debug else None)\n                run[\"accelerate\"] = dict(amp=acc.mixed_precision, nproc=acc.num_processes)\n                log_cfg(cfg, run)\n        except ImportError:\n            logger.info(\"Did not find neptune installed. Logging will be disabled.\")\n\n    train(cfg.training, acc, model, ema, dataloaders, optimizer, run)\n\n\nif __name__ == \"__main__\":\n    cfg_file = OmegaConf.from_cli()['config_file']\n    main(make_config(cfg_file))\n
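\n\n# Example invocation (a sketch; the YAML path is a placeholder):\n#\n#     python train.py config_file=./configs/my_experiment.yaml\n#\n# The file is merged over utils_train.default_train_config by make_config, so it only needs to\n# set the keys that differ from (or are missing in) the defaults.\n"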
  },
  {
    "path": "utils_model.py",
    "content": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nCONST_log_range = 20\nCONST_log_min = 1e-10\nCONST_summary_rescale = 10\nCONST_exp_range = 10\nCONST_min_std_dev = math.exp(-CONST_exp_range)\n\n\ndef sandwich(x: Tensor):\n    return x.reshape(x.size(0), -1, x.size(-1))\n\n\ndef safe_log(data: Tensor):\n    return data.clamp(min=CONST_log_min).log()\n\n\ndef safe_exp(data: Tensor):\n    return data.clamp(min=-CONST_exp_range, max=CONST_exp_range).exp()\n\n\ndef idx_to_float(idx: np.ndarray, num_bins: int):\n    flt_zero_one = (idx + 0.5) / num_bins\n    return (2.0 * flt_zero_one) - 1.0\n\n\ndef float_to_idx(flt: np.ndarray, num_bins: int):\n    flt_zero_one = (flt / 2.0) + 0.5\n    return torch.clamp(torch.floor(flt_zero_one * num_bins), min=0, max=num_bins - 1).long()\n\n\ndef quantize(flt, num_bins: int):\n    return idx_to_float(float_to_idx(flt, num_bins), num_bins)\n\n\ndef pe_encode(sequence_length: int, embedding_size: int) -> Tensor:\n    \"\"\"Positional encoding as described in original attention is all you need paper\"\"\"\n\n    pe = torch.zeros((sequence_length, embedding_size))\n    pos = torch.arange(sequence_length).unsqueeze(1)\n    pe[:, 0::2] = torch.sin(\n        pos / torch.pow(1000, torch.arange(0, embedding_size, 2, dtype=torch.float32) / embedding_size)\n    )\n    pe[:, 1::2] = torch.cos(\n        pos / torch.pow(1000, torch.arange(1, embedding_size, 2, dtype=torch.float32) / embedding_size)\n    )\n\n    return pe\n\n\ndef pe_encode_float(x: Tensor, max_freq: float, embedding_size: int) -> Tensor:\n    pe = torch.zeros(list(x.shape) + [embedding_size], device=x.device)\n    pos = (((x + 1) / 2) * max_freq).unsqueeze(-1)\n    pe[..., 0::2] = torch.sin(\n        pos\n        / torch.pow(10000, torch.arange(0, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size)\n    )\n    pe[..., 1::2] = torch.cos(\n        pos\n        / torch.pow(10000, torch.arange(1, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size)\n    )\n    return pe\n"
  },
  {
    "path": "utils_train.py",
    "content": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport math\nimport random\nimport tempfile\nfrom collections import defaultdict\nfrom pathlib import Path\nfrom typing import Optional, Generator, Union\n\ntry:\n    import neptune\n    from neptune.utils import stringify_unsupported\nexcept ImportError:\n    neptune = None\n\n    def stringify_unsupported(x):\n        return x\n\n\nimport numpy as np\nimport torch\nfrom accelerate.logging import get_logger\nfrom omegaconf import OmegaConf, DictConfig\nfrom rich.progress import Progress, SpinnerColumn, MofNCompleteColumn, TimeElapsedColumn, TextColumn\nfrom torch.utils.data import DataLoader\n\nimport model\nimport networks\nimport probability\nfrom data import make_datasets\nfrom networks import adapters\n\nlogger = get_logger(__name__)\n\n\ndef seed_everything(seed: Optional[int]):\n    assert seed is not None\n    seed += torch.distributed.get_rank() if torch.distributed.is_initialized() else 0\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n\ndef worker_init_function(worker_id: int) -> None:\n    \"\"\"https://pytorch.org/docs/stable/notes/randomness.html#dataloader\"\"\"\n    worker_seed = torch.initial_seed() % 2**32\n    np.random.seed(worker_seed)\n    random.seed(worker_seed)\n\n\ndef init_checkpointing(checkpoint_dir: Union[str, Path, None], run_id: str) -> Optional[Path]:\n    if checkpoint_dir is None:\n        return None\n    checkpoint_dir = Path(checkpoint_dir) / run_id\n    checkpoint_dir.mkdir(parents=True, exist_ok=True)\n    last_dir = checkpoint_dir / \"last\"\n    last_dir.mkdir(parents=True, exist_ok=True)\n    best_dir = checkpoint_dir / \"best\"\n    best_dir.mkdir(parents=True, exist_ok=True)\n    return checkpoint_dir\n\n\ndef checkpoint_training_state(checkpoint_dir, accelerator, ema_model, step: int, run_id: str):\n    if checkpoint_dir is None:\n        return\n    logger.info(f\"Checkpointing training state to {checkpoint_dir} at step {step}\")\n    accelerator.save_state(checkpoint_dir)\n    with open(checkpoint_dir / \"info.json\", \"w\") as f:\n        json.dump({\"step\": step, \"run_id\": run_id}, f)\n    if ema_model is not None:\n        ema_checkpoint_path = checkpoint_dir / \"ema_model.pt\"\n        torch.save(ema_model.state_dict(), ema_checkpoint_path)\n\n\ndef log(key_handler, value, step, cond=True):\n    \"\"\"Log series to neptune only if cond is True. 
Helps with distributed training and conditional logging.\"\"\"\n    if not isinstance(key_handler, defaultdict) and cond and math.isfinite(value):\n        key_handler.log(value, step=step)\n\n\ndef log_cfg(cfg, run: \"neptune.Run\"):\n    with tempfile.TemporaryDirectory() as tmpdir:\n        cfg_temp_filename: Path = Path(tmpdir) / \"cfg.yaml\"\n        cfg_temp_filename.write_text(OmegaConf.to_yaml(cfg, resolve=True))\n        run[\"cfg\"].upload(str(cfg_temp_filename), wait=True)\n    run[\"hyperparameters\"] = stringify_unsupported(OmegaConf.to_container(cfg, resolve=True))\n\n\n@torch.no_grad()\ndef update_ema(ema_model, model, ema_decay):\n    if ema_model is not None and ema_decay > 0:\n        for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):\n            ema_param.sub_((1 - ema_decay) * (ema_param - model_param))\n\n\ndef ddict():\n    \"\"\"Infinite default dict to fake neptune run on non-main processes\"\"\"\n    return defaultdict(ddict)\n\n\ndef make_infinite(dataloader: DataLoader) -> Generator[dict, None, None]:\n    while True:\n        for data in dataloader:\n            yield data\n\n\ndef make_progress_bar(is_main: bool, text=\"[red]loss: {task.fields[loss]:.3f}\"):\n    return Progress(\n        SpinnerColumn(),\n        MofNCompleteColumn(),\n        *Progress.get_default_columns(),\n        TimeElapsedColumn(),\n        TextColumn(text),\n        disable=not is_main,\n    )\n\n\ndef make_dataloaders(cfg: DictConfig):\n    train_set, val_set, _ = make_datasets(cfg.data)\n    dataloaders = {\n        \"train\": DataLoader(\n            dataset=train_set,\n            worker_init_fn=worker_init_function,\n            **cfg.train_loader,\n        ),\n        \"val\": DataLoader(\n            dataset=val_set,\n            worker_init_fn=worker_init_function,\n            **cfg.val_loader,\n        ),\n    }\n    return dataloaders\n\n\ndef make_from_cfg(module, cfg, **parameters):\n    return getattr(module, cfg.class_name)(**cfg.parameters, **parameters) if cfg is not None else None\n\n\ndef make_bfn(cfg: DictConfig):\n    data_adapters = {\n        \"input_adapter\": make_from_cfg(adapters, cfg.input_adapter),\n        \"output_adapter\": make_from_cfg(adapters, cfg.output_adapter),\n    }\n    net = make_from_cfg(networks, cfg.net, data_adapters=data_adapters)\n    bayesian_flow = make_from_cfg(model, cfg.bayesian_flow)\n    distribution_factory = make_from_cfg(probability, cfg.distribution_factory)\n    loss = make_from_cfg(model, cfg.loss, bayesian_flow=bayesian_flow, distribution_factory=distribution_factory)\n    bfn = model.BFN(net=net, bayesian_flow=bayesian_flow, loss=loss)\n    return bfn\n\n\ndefault_train_config = {\n    \"meta\": {\n        \"neptune\": None,\n        \"debug\": False,\n        \"root_dir\": \".\",\n    },\n    \"data\": {\n        \"dataset\": \"\",\n        \"data_dir\": \"./data\",\n    },\n    \"train_loader\": {\n        \"batch_size\": 1,\n        \"shuffle\": True,\n        \"num_workers\": 0,\n        \"pin_memory\": True,\n        \"drop_last\": True,\n    },\n    \"val_loader\": {\n        \"batch_size\": 1,\n        \"shuffle\": False,\n        \"num_workers\": 0,\n        \"pin_memory\": True,\n        \"drop_last\": False,\n    },\n    \"training\": {\n        \"accumulate\": 1,\n        \"checkpoint_dir\": \"./checkpoints\",\n        \"checkpoint_interval\": None,\n        \"ema_decay\": -1,\n        \"grad_clip_norm\": -1,\n        \"log_interval\": 50,\n        \"max_val_batches\": -1,\n        
\"seed\": 666,\n        \"start_step\": 1,\n        \"val_repeats\": 1,\n    },\n}\n\n\ndef make_config(cfg_file: str):\n    cli_conf = OmegaConf.load(cfg_file)\n    # Start with default config\n    cfg = OmegaConf.create(default_train_config)\n    # Merge into default config\n    cfg = OmegaConf.merge(cfg, cli_conf)\n    return cfg\n"
  }
]