main b62568e5d064 cached
23 files
146.9 KB
37.4k tokens
288 symbols
1 requests
Download .txt
Repository: nnaisense/bayesian-flow-networks
Branch: main
Commit: b62568e5d064
Files: 23
Total size: 146.9 KB

Directory structure:
gitextract__0riu0_z/

├── .gitignore
├── LICENSE
├── README.md
├── configs/
│   ├── cifar10_continuous_16bins.yaml
│   ├── cifar10_continuous_256bins.yaml
│   ├── cifar10_discretized_16bins.yaml
│   ├── cifar10_discretized_256bins.yaml
│   ├── mnist_discrete.yaml
│   └── text8_discrete.yaml
├── data.py
├── env.yml
├── model.py
├── networks/
│   ├── __init__.py
│   ├── adapters.py
│   ├── transformer.py
│   ├── unet_improved.py
│   └── unet_vdm.py
├── probability.py
├── sample.py
├── test.py
├── train.py
├── utils_model.py
└── utils_train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Data, checkpoints, logs
data
checkpoints
.neptune

# Files generated by setuptools_scm
__version.py

# MacOS
.DS_Store

# Visual Studio Code
.vscode/
*.code-workspace
.history/

# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# PyCharm
.idea/

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# End of https://www.gitignore.io/api/python


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS


================================================
FILE: README.md
================================================
# Bayesian Flow Networks

This is the official code release for [Bayesian Flow Networks](https://arxiv.org/abs/2308.07037) by Alex Graves, Rupesh Kumar Srivastava, Timothy Atkinson and Faustino Gomez.

<img src="bfn.gif" alt="Overview of BFN process" style="width:600px;"/>

## Reading Guide

- `model.py` contains all the main contributions of the paper. These include definitions, for both continuous and discrete data, of Bayesian Flows as well as loss functions for both continuous-time and discrete-time. See comments in the base classes in that file for details.
- `probability.py` defines the probability distributions used by the models.
- `train.py`, `test.py` and `sample.py` are scripts for training, testing and sampling (see below for usage).
- `data.py` contains utilities related to data loading and processing.
- `networks/` contains implementations of the network architectures used by the models. 

## Setup

```shell
# Create a new conda env with all dependencies including pytorch and CUDA
conda env create -f env.yml
conda activate bfn

# Or, install additional dependencies into an existing pytorch env
pip install accelerate==0.19.0 matplotlib omegaconf rich

# Optional, if you want to enable logging to neptune.ai
pip install neptune 
```

## Training

The models in the paper can be trained using the configs provided in the `configs` dir as follows:

```shell
# mnist experiment on 1 GPU
accelerate launch train.py config_file=configs/mnist_discrete.yaml
# cifar10 experiment on 1 GPU (A100)
accelerate launch train.py config_file=configs/cifar10_discretized_256bins.yaml
# text8 experiment on 8 GPUs (A100)
accelerate launch --multi_gpu --num_processes=8 --num_machines=1 --dynamo_backend=no --mixed_precision=fp16 train.py config_file=configs/text8_discrete.yaml 
```

## Testing
> [!NOTE]
> Depending on your GPU, you may wish to adjust the batch size used for testing in `test.py`.
```shell
# Optional: Download pretrained checkpoints (make sure you have git-lfs installed: https://git-lfs.com/)
git clone git@hf.co:rupspace/pretrained-BFNs
# Compute 784-step loss on MNIST
python test.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt n_steps=784 n_repeats=2000
# Compute 10-step loss on CIFAR-10
python test.py seed=1 config_file=./configs/cifar10_discretized_256bins.yaml load_model=./pretrained-BFNs/cifar10_256d_ema.pt n_steps=10 n_repeats=100
# Compute continuous-time loss on text8
python test.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt n_steps=0 n_repeats=1
```
> [!IMPORTANT]
> All computed results will be in nats-per-data-dimension. To convert to bits, divide by ln(2).

## Sampling

You can sample from a pre-trained model as follows (change options as desired):

```shell
# Sample 4 binarized MNIST images using 100 steps
python sample.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt samples_shape="[4, 28, 28, 1]" n_steps=100 save_file=./samples_mnist.pt
# Sample 4 CIFAR-10 16-bit images modeled as discretized data using 1000 steps
python sample.py seed=1 config_file=./configs/cifar10_discretized_16bins.yaml load_model=./pretrained-BFNs/cifar10_16d_ema.pt samples_shape="[4, 32, 32, 3]" n_steps=1000 save_file=./samples_cifar.pt
# Sample 2 text8 sequences of length 256 using 100 steps
python sample.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt samples_shape="[2, 256]" n_steps=100 save_file=./samples_text8.pt
```

The samples are stored as PyTorch tensors in the `save_file`, and can be visualized by loading them and then using the utilities `batch_to_images` and `batch_to_str` in `data.py`.
For example: 
```shell
# batch_to_images returns a matplotlib Figure object
python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_mnist.pt')).savefig('mnist.png')"
python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_cifar.pt')).savefig('cifar.png')"
# batch_to_str returns a list of str
python -c "import torch; from data import batch_to_str; print(batch_to_str(torch.load('./samples_text8.pt')))"
```

## Reproducibility 

If a high degree of reproducibility is desired (e.g. during sampling), set the following:

```python
torch.set_float32_matmul_precision("highest")
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
```

## Acknowledgements

We are grateful to [@Higgcz](https://github.com/Higgcz) for generous support with the experiment infrastructure and code release.


================================================
FILE: configs/cifar10_continuous_16bins.yaml
================================================
meta:
  neptune: 
  debug: False
data:
  dataset: "cifar10"
  horizontal_flip: False
  num_bins: 16
train_loader:
  batch_size: 32
  shuffle: True
  num_workers: 8
  pin_memory: True
  drop_last: True
  persistent_workers: True
val_loader:
  batch_size: 500
  shuffle: False
  num_workers: 8
  pin_memory: True
model:
  net:
    class_name: "UNetVDM"
    parameters:
      embedding_dim: 128
      n_blocks: 32
      n_attention_heads: 1
      dropout_prob: 0.1
      norm_groups: 32
      input_channels: 3
      use_fourier_features: True
      attention_everywhere: False
      image_size: 32
  input_adapter:
    class_name: "FourierImageInputAdapter"
    parameters:
      input_channels: 3
      input_shape: [32, 32]
      output_height: 3
      add_pos_feats: False
      add_mask: False
  output_adapter:
    class_name: "OutputAdapter"
    parameters:
      input_height: 131
      output_channels: 3 # (r,g,b)
      output_height: 1
  bayesian_flow:
    class_name: "CtsBayesianFlow"
    parameters:
      min_variance: 1e-3
  loss:
    class_name: "CtsBayesianFlowLoss"
    parameters:
      noise_pred: True
  distribution_factory:
    class_name: "DeltaFactory"
    parameters: {}
optimizer:
  lr: 2e-4
  betas: [0.9,0.99]
  weight_decay: 0.01
  eps: 1e-8
training:
  checkpoint_interval: 10_000
  ema_decay: 0.9999
  grad_clip_norm: 5.0
  log_interval: 1
  n_training_steps: 1_000_000
  val_interval: 50_000
  val_repeats: 100


================================================
FILE: configs/cifar10_continuous_256bins.yaml
================================================
meta:
  neptune: 
  debug: False
data:
  dataset: "cifar10"
  horizontal_flip: False
  num_bins: 256
train_loader:
  batch_size: 32
  shuffle: True
  num_workers: 8
  pin_memory: True
  drop_last: True
  persistent_workers: True
val_loader:
  batch_size: 500
  shuffle: False
  num_workers: 8
  pin_memory: True
model:
  net:
    class_name: "UNetVDM"
    parameters:
      embedding_dim: 128
      n_blocks: 32
      n_attention_heads: 1
      dropout_prob: 0.1
      norm_groups: 32
      input_channels: 3
      use_fourier_features: True
      attention_everywhere: False
      image_size: 32
  input_adapter:
    class_name: "FourierImageInputAdapter"
    parameters:
      input_channels: 3
      input_shape: [32, 32]
      output_height: 3
      add_pos_feats: False
      add_mask: False
  output_adapter:
    class_name: "OutputAdapter"
    parameters:
      input_height: 131
      output_channels: 3 # (r,g,b)
      output_height: 1
  bayesian_flow:
    class_name: "CtsBayesianFlow"
    parameters:
      min_variance: 1e-6
  loss:
    class_name: "CtsBayesianFlowLoss"
    parameters:
      noise_pred: True
  distribution_factory:
    class_name: "DeltaFactory"
    parameters: {}
optimizer:
  lr: 2e-4
  betas: [0.9,0.99]
  weight_decay: 0.01
  eps: 1e-8
training:
  checkpoint_interval: 10_000
  ema_decay: 0.9999
  grad_clip_norm: 5.0
  log_interval: 1
  n_training_steps: 1_000_000
  val_interval: 50_000
  val_repeats: 100


================================================
FILE: configs/cifar10_discretized_16bins.yaml
================================================
meta:
  neptune: 
  debug: False
data:
  dataset: "cifar10"
  horizontal_flip: False
  num_bins: 16
train_loader:
  batch_size: 32
  shuffle: True
  num_workers: 8
  pin_memory: True
  drop_last: True
  persistent_workers: True
val_loader:
  batch_size: 1000
  shuffle: False
  num_workers: 8
  pin_memory: True
model:
  net:
    class_name: "UNetVDM"
    parameters:
      embedding_dim: 128
      n_blocks: 32
      n_attention_heads: 1
      dropout_prob: 0.1
      norm_groups: 32
      input_channels: 3
      use_fourier_features: True
      attention_everywhere: False
      image_size: 32
  input_adapter:
    class_name: "FourierImageInputAdapter"
    parameters:
      input_channels: 3
      input_shape: [32, 32]
      output_height: 3
      add_pos_feats: False
      add_mask: False
  output_adapter:
    class_name: "OutputAdapter"
    parameters:
      input_height: 131
      output_channels: 3 # (r,g,b)
      output_height: 2 # mean, std
  bayesian_flow:
    class_name: "CtsBayesianFlow"
    parameters:
      min_variance: 1e-3
  loss:
    class_name: "CtsBayesianFlowLoss"
    parameters:
      noise_pred: True
  distribution_factory:
    class_name: "DiscretizedNormalFactory"
    parameters:
      num_bins: 16
      clip: True
optimizer:
  lr: 2e-4
  betas: [0.9,0.99]
  weight_decay: 0.01
  eps: 1e-8
training:
  checkpoint_interval: 10_000
  ema_decay: 0.9999
  grad_clip_norm: 5.0
  log_interval: 1
  n_training_steps: 1_000_000
  val_interval: 50_000
  val_repeats: 100


================================================
FILE: configs/cifar10_discretized_256bins.yaml
================================================
meta:
  neptune: 
  debug: False
data:
  dataset: "cifar10"
  horizontal_flip: False
  num_bins: 256
train_loader:
  batch_size: 32
  shuffle: True
  num_workers: 8
  pin_memory: True
  drop_last: True
  persistent_workers: True
val_loader:
  batch_size: 1000
  shuffle: False
  num_workers: 8
  pin_memory: True
model:
  net:
    class_name: "UNetVDM"
    parameters:
      embedding_dim: 128
      n_blocks: 32
      n_attention_heads: 1
      dropout_prob: 0.1
      norm_groups: 32
      input_channels: 3
      use_fourier_features: True
      attention_everywhere: False
      image_size: 32
  input_adapter:
    class_name: "FourierImageInputAdapter"
    parameters:
      input_channels: 3
      input_shape: [32, 32]
      output_height: 3
      add_pos_feats: False
      add_mask: False
  output_adapter:
    class_name: "OutputAdapter"
    parameters:
      input_height: 131
      output_channels: 3 # (r,g,b)
      output_height: 2 # mean, std
  bayesian_flow:
    class_name: "CtsBayesianFlow"
    parameters:
      min_variance: 1e-6
  loss:
    class_name: "CtsBayesianFlowLoss"
    parameters:
      noise_pred: True
  distribution_factory:
    class_name: "DiscretizedNormalFactory"
    parameters:
      num_bins: 256
      clip: True
optimizer:
  lr: 2e-4
  betas: [0.9,0.99]
  weight_decay: 0.01
  eps: 1e-8
training:
  checkpoint_interval: 10_000
  ema_decay: 0.9999
  grad_clip_norm: 5.0
  log_interval: 1
  n_training_steps: 1_000_000
  val_interval: 50_000
  val_repeats: 100


================================================
FILE: configs/mnist_discrete.yaml
================================================
meta:
  neptune:
  debug: False
data:
  dataset: "bin_mnist"
train_loader:
  batch_size: 512
  shuffle: True
  num_workers: 8
  pin_memory: True
  drop_last: True
val_loader:
  batch_size: 1000
  shuffle: False
  num_workers: 8
  pin_memory: True
model:
  net:
    class_name: "UNetModel"
    parameters:
      image_size: 28
      in_channels: 2
      model_channels: 128
      out_channels: 128
      num_res_blocks: 2
      attention_resolutions: [8,16]
      dropout: 0.5
      channel_mult: [1, 2, 2]
      conv_resample: True
      dims: 2
      num_heads: 4
      num_heads_upsample: -1
      project_input: True
      skip: True
  input_adapter:
    class_name: "FourierImageInputAdapter"
    parameters:
      input_channels: 1
      input_shape: [28, 28]
      output_height: 2
      add_pos_feats: False
  output_adapter:
    class_name: "OutputAdapter"
    parameters:
      input_height: 256
      output_channels: 1
      output_height: 1
  bayesian_flow:
    class_name: "DiscreteBayesianFlow"
    parameters:
      n_classes: 2
      max_sqrt_beta: 3
      discretize: False
  loss:
    class_name: "DiscreteBayesianFlowLoss"
    parameters: {}
  distribution_factory:
    class_name: "BernoulliFactory"
    parameters: {}
optimizer:
  lr: 1e-4
  betas: [0.9,0.98]
training:
  checkpoint_interval: 10_000
  ema_decay: 0.9999
  grad_clip_norm: 5.0
  log_interval: 1
  n_training_steps: 1_000_000
  val_interval: 50_000
  val_repeats: 1000

================================================
FILE: configs/text8_discrete.yaml
================================================
meta:
  neptune:
  debug: False
data:
  dataset: "text8"
  seq_len: 256
train_loader:
  batch_size: 416
  shuffle: True
  num_workers: 8
  pin_memory: True
  drop_last: True
val_loader:
  batch_size: 200
  shuffle: True
  num_workers: 8
  pin_memory: True
model:
  net:
    class_name: "GPT"
    parameters:
      vocab_size: 27
      n_layer: 24
      n_head: 12
      n_embd: 768
      dropout: 0.0
      skip: True
      bias: True
  input_adapter:
    class_name: "TextInputAdapter"
    parameters:
      vocab_size: 27
      seq_len: 256
      output_size: 768
      learn_pos_embedding: False
  output_adapter: null
  bayesian_flow:
    class_name: "DiscreteBayesianFlow"
    parameters:
      n_classes: 27
      max_sqrt_beta: 0.75
  loss:
    class_name: "DiscreteBayesianFlowLoss"
    parameters: {}
  distribution_factory:
    class_name: "CategoricalFactory"
    parameters: {}
optimizer:
  lr: 1e-4
  betas: [0.9, 0.98]
  weight_decay: 0.01
training:
  accumulate: 1
  checkpoint_interval: 10_000
  ema_decay: 0.9999
  grad_clip_norm: 5
  log_interval: 1
  max_val_batches: 5_000
  n_training_steps: 10_000_000
  val_interval: 100_000
  val_repeats: 1

================================================
FILE: data.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import pathlib
import pickle
import zipfile
from typing import Union

import numpy as np
import requests
import torch
import torchvision
from matplotlib import pyplot as plt
from omegaconf import DictConfig
from torch.utils.data import Dataset, random_split
from torchvision import transforms
from torchvision.utils import make_grid

from utils_model import quantize

TEXT8_CHARS = list("_abcdefghijklmnopqrstuvwxyz")


def bin_mnist_transform(x):
    return torch.bernoulli(x.permute(1, 2, 0).contiguous()).int()


def bin_mnist_cts_transform(x):
    return torch.bernoulli(x.permute(1, 2, 0).contiguous()) - 0.5


def rgb_image_transform(x, num_bins=256):
    return quantize((x * 2) - 1, num_bins).permute(1, 2, 0).contiguous()


class MyLambda(torchvision.transforms.Lambda):
    def __init__(self, lambd, arg1):
        super().__init__(lambd)
        self.arg1 = arg1

    def __call__(self, x):
        return self.lambd(x, self.arg1)


class CIFAR10(torchvision.datasets.CIFAR10):
    def __getitem__(self, idx):
        return super().__getitem__(idx)[0]


class MNIST(torchvision.datasets.MNIST):
    def __getitem__(self, idx):
        return super().__getitem__(idx)[0]


def make_datasets(cfg: DictConfig) -> tuple[Dataset, Dataset, Dataset]:
    """
    Mandatory keys: dataset (must be cifar10, mnist, bin_mnist, bin_mnist_cts or text8), data_dir
    Optional for vision: num_bins (default 256), val_frac (default 0.01), horizontal_flip (default: False)
    Mandatory for text: seq_len
    """
    num_bins = cfg.get("num_bins", 256)
    if cfg.dataset == "cifar10":
        train_transform_list = [transforms.ToTensor()]
        if cfg.get("horizontal_flip", False):
            train_transform_list.append(transforms.RandomHorizontalFlip())
        train_transform_list.append(MyLambda(rgb_image_transform, num_bins))
        train_transform = transforms.Compose(train_transform_list)
        test_transform = transforms.Compose([transforms.ToTensor(), MyLambda(rgb_image_transform, num_bins)])
        train_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=train_transform)
        val_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=test_transform)
        test_set = CIFAR10(root=cfg.data_dir, train=False, download=True, transform=test_transform)

    elif cfg.dataset == "mnist":
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                MyLambda(rgb_image_transform, num_bins),
            ]
        )
        train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)

    elif cfg.dataset == "bin_mnist":
        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_transform)])
        train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)

    elif cfg.dataset == "bin_mnist_cts":
        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_cts_transform)])
        train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)

    elif cfg.dataset == "text8":
        train_set = Text8Dataset(cfg.data_dir, "train", download=True, seq_len=cfg.seq_len)
        val_set = Text8Dataset(cfg.data_dir, "val", download=True, seq_len=cfg.seq_len)
        test_set = Text8Dataset(cfg.data_dir, "test", download=True, seq_len=cfg.seq_len)
    else:
        raise NotImplementedError(cfg.dataset)

    if cfg.dataset != "text8":
        # For vision datasets we split the train set into train and val
        val_frac = cfg.get("val_frac", 0.01)
        train_val_split = [1.0 - val_frac, val_frac]
        seed = 2147483647
        train_set = random_split(train_set, train_val_split, generator=torch.Generator().manual_seed(seed))[0]
        val_set = random_split(val_set, train_val_split, generator=torch.Generator().manual_seed(seed))[1]

    return train_set, val_set, test_set


def prepare_text8(data_dir: pathlib.Path):
    data_dir.mkdir(parents=True, exist_ok=True)
    data_url = "http://mattmahoney.net/dc/text8.zip"
    with open(data_dir / "text8.zip", "wb") as f:
        print("Downloading text8")
        f.write(requests.get(data_url).content)
        print("Done")
    with zipfile.ZipFile(data_dir / "text8.zip") as f:
        f.extractall(data_dir)
    os.remove(data_dir / "text8.zip")
    data = (data_dir / "text8").read_text()

    # get all the unique characters that occur in this text
    chars = sorted(list(set(data)))
    vocab_size = len(chars)
    print("all the unique characters:", "".join(chars))
    print(f"vocab size: {vocab_size:,}")

    # create a mapping from characters to integers
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}

    def encode(s):
        return [stoi[c] for c in s]  # encoder: take a string, output a list of integers

    # encode both to integers
    n = len(data)
    train_data = data[: int(n * 0.9)]
    val_data = data[int(n * 0.9) : int(n * 0.95)]
    test_data = data[int(n * 0.95) :]
    train_ids = encode(train_data)
    val_ids = encode(val_data)
    test_ids = encode(test_data)
    print(f"train has {len(train_ids):,} tokens")
    print(f"val has {len(val_ids):,} tokens")
    print(f"test has {len(test_ids):,} tokens")

    # export to bin files
    train_ids = np.array(train_ids, dtype=np.uint16)
    val_ids = np.array(val_ids, dtype=np.uint16)
    test_ids = np.array(test_ids, dtype=np.uint16)
    train_ids.tofile(data_dir / "train.bin")
    val_ids.tofile(data_dir / "val.bin")
    test_ids.tofile(data_dir / "test.bin")
    print(f"Saved to {data_dir / 'train.bin'}, {data_dir / 'val.bin'}, {data_dir / 'test.bin'}")

    # save the meta information as well, to help us encode/decode later
    meta = {
        "vocab_size": vocab_size,
        "itos": itos,
        "stoi": stoi,
    }
    with open(os.path.join(data_dir / "meta.pkl"), "wb") as f:
        pickle.dump(meta, f)

    print(f"text8 dataset downloaded and prepared in dir {data_dir}")


class Text8Dataset(Dataset):
    def __init__(self, data_dir: Union[str, pathlib.Path], split: str, download: bool, seq_len: int):
        """
        seq_len should include context length. Example: seq_len=512 for modeling 256 chars with 256 char of context.
        context is only used for correct preparation of val/test sets.
        """
        self.root_dir = pathlib.Path(data_dir)
        self.split = split
        self.seq_len = seq_len
        fname = {"train": "train.bin", "val": "val.bin", "test": "test.bin"}[self.split]
        assert self.split in ["train", "val", "test"]
        data_dir = self.root_dir / "text8"
        if not os.path.exists(data_dir):
            if download:
                prepare_text8(data_dir)
            else:
                raise NotADirectoryError(f"dir {data_dir} does not exist and download is False")
        self.data = np.memmap(data_dir / fname, np.uint16, "r")

    def __getitem__(self, index) -> torch.Tensor:
        seq = torch.from_numpy(self.data[index : index + self.seq_len].astype(np.int64))
        return seq

    def __len__(self):
        return self.data.size - self.seq_len


def char_ids_to_str(char_ids: Union[list[int], np.array, torch.Tensor]) -> str:
    """Decode a 1D sequence of character IDs to a string."""
    return "".join([TEXT8_CHARS[i] for i in char_ids])


def batch_to_str(text_batch: Union[list[list], np.array, torch.Tensor]) -> list[str]:
    """Decode a batch of character IDs to a list of strings."""
    return [char_ids_to_str(row_char_ids) for row_char_ids in text_batch]


def batch_to_images(image_batch: torch.Tensor, ncols: int = None) -> plt.Figure:
    if ncols is None:
        ncols = math.ceil(math.sqrt(len(image_batch)))
    if image_batch.size(-1) == 3:  # for color images (CIFAR-10)
        image_batch = (image_batch + 1) / 2
    grid = make_grid(image_batch.permute(0, 3, 1, 2), ncols, pad_value=1).permute(1, 2, 0)
    fig = plt.figure(figsize=(grid.size(1) / 30, grid.size(0) / 30))
    plt.imshow(grid.cpu().clip(min=0, max=1), interpolation="nearest")
    plt.grid(False)
    plt.axis("off")
    return fig


================================================
FILE: env.yml
================================================
name: bfn
channels:
  - pytorch
  - nvidia
dependencies:
  - python=3.9
  - pytorch=2.0.0
  - pytorch-cuda=11.8
  - torchvision=0.15.0
  - pip
  - pip:
    - accelerate==0.19.0
    - matplotlib
    - omegaconf
    - rich


================================================
FILE: model.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file implements the Bayesian Flow and BFN loss for continuous and discrete variables.
Finally it implements the BFN using these objects.
For consistency we use always use a tuple to store input parameters.
It has just one element for discrete data (the probabilities) and two for continuous/discretized (mean & variance).
The probability distributions and network architectures are defined in probability.py and networks dir.
"Cts" is an abbreviation of "Continuous".
"""

import math
from abc import abstractmethod, ABC
from typing import Union, Optional

import torch
import torch.distributions as D
import torch.nn.functional as F
from torch import nn, Tensor

from probability import (
    DiscreteDistributionFactory,
    CtsDistributionFactory,
    PredDistToDataDistFactory,
    DiscretizedCtsDistribution,
)
from utils_model import sandwich, float_to_idx


class BayesianFlow(nn.Module, ABC):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, ...]:
        """Returns the initial input params (for a batch) at t=0. Used during sampling.
        For discrete data, the tuple has length 1 and contains the initial class probabilities.
        For continuous data, the tuple has length 2 and contains the mean and precision."""
        pass

    @abstractmethod
    def params_to_net_inputs(self, params: tuple[Tensor, ...]) -> Tensor:
        """Utility method to convert input distribution params to network inputs if needed."""
        pass

    @abstractmethod
    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> float:
        """Returns the alpha at step i of total n_steps according to the flow schedule. Used:
        a) during sampling, when i and alpha are the same for all samples in the batch.
        b) during discrete time loss computation, when i and alpha are different for samples in the batch."""
        pass

    @abstractmethod
    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
        """Returns the sender distribution with accuracy alpha obtained by adding appropriate noise to the data x. Used:
        a) during sampling (same alpha for whole batch) to sample from the output distribution produced by the net.
        b) during discrete time loss computation when alpha are different for samples in the batch."""
        pass

    @abstractmethod
    def update_input_params(self, input_params: tuple[Tensor, ...], y: Tensor, alpha: float) -> tuple[Tensor, ...]:
        """Updates the distribution parameters using Bayes' theorem in light of noisy sample y.
        Used during sampling when alpha is the same for the whole batch."""
        pass

    @abstractmethod
    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, ...]:
        """Returns a sample from the Bayesian Flow distribution over input parameters at time t conditioned on data.
        Used during training when t (and thus accuracies) are different for different samples in the batch.
        For discrete data, the returned tuple has length 1 and contains the class probabilities.
        For continuous data, the returned tuple has length 2 and contains the mean and precision."""
        pass


class Loss(nn.Module, ABC):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor) -> Tensor:
        """Returns the continuous time KL loss (and any other losses) at time t (between 0 and 1).
        The input params are only used when the network is parameterized to predict the noise for continuous data."""
        pass

    @abstractmethod
    def discrete_time_loss(
        self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples: int = 20
    ) -> Tensor:
        """Returns the discrete time KL loss for n_steps total of communication at time t (between 0 and 1) using
        n_samples for Monte Carlo estimation of the discrete loss.
        The input params are only used when the network is parameterized to predict the noise for continuous data."""
        pass

    @abstractmethod
    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
        """Returns the reconstruction loss, i.e. the final cost of transmitting clean data.
        The input params are only used when the network is parameterized to predict the noise for continuous data."""
        pass


# Continuous or Discretized data


class CtsBayesianFlow(BayesianFlow):
    def __init__(
        self,
        min_variance: float = 1e-6,
    ):
        super().__init__()
        self.min_variance = min_variance

    @torch.no_grad()
    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, None]:
        post_var = torch.pow(self.min_variance, t)
        alpha_t = 1 - post_var
        mean_mean = alpha_t * data
        mean_var = alpha_t * post_var
        mean_std_dev = mean_var.sqrt()
        noise = torch.randn(mean_mean.shape, device=mean_mean.device)
        mean = mean_mean + (mean_std_dev * noise)
        # We don't need to compute the variance because it is not needed by the network, so set it to None
        input_params = (mean, None)
        return input_params

    def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
        return params[0]  # Only the mean is used by the network

    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, float]:
        return torch.zeros(*data_shape, device=device), 1.0

    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:
        sigma_1 = math.sqrt(self.min_variance)
        return (sigma_1 ** (-2 * i / n_steps)) * (1 - sigma_1 ** (2 / n_steps))

    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
        dist = D.Normal(x, 1.0 / alpha**0.5)
        return dist

    def update_input_params(self, input_params: tuple[Tensor, float], y: Tensor, alpha: float) -> tuple[Tensor, float]:
        input_mean, input_precision = input_params
        new_precision = input_precision + alpha
        new_mean = ((input_precision * input_mean) + (alpha * y)) / new_precision
        return new_mean, new_precision


class CtsBayesianFlowLoss(Loss):
    def __init__(
        self,
        bayesian_flow: CtsBayesianFlow,
        distribution_factory: Union[CtsDistributionFactory, DiscreteDistributionFactory],
        min_loss_variance: float = -1,
        noise_pred: bool = True,
    ):
        super().__init__()
        self.bayesian_flow = bayesian_flow
        self.distribution_factory = distribution_factory
        self.min_loss_variance = min_loss_variance
        self.C = -0.5 * math.log(bayesian_flow.min_variance)
        self.noise_pred = noise_pred
        if self.noise_pred:
            self.distribution_factory.log_dev = False
            self.distribution_factory = PredDistToDataDistFactory(
                self.distribution_factory, self.bayesian_flow.min_variance
            )

    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:
        output_params = sandwich(output_params)
        t = t.flatten(start_dim=1).float()
        posterior_var = torch.pow(self.bayesian_flow.min_variance, t)
        flat_target = data.flatten(start_dim=1)
        pred_dist = self.distribution_factory.get_dist(output_params, input_params, t)
        pred_mean = pred_dist.mean
        mse_loss = (pred_mean - flat_target).square()
        if self.min_loss_variance > 0:
            posterior_var = posterior_var.clamp(min=self.min_loss_variance)
        loss = self.C * mse_loss / posterior_var
        return loss

    def discrete_time_loss(
        self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10
    ) -> Tensor:
        output_params = sandwich(output_params)
        t = t.flatten(start_dim=1).float()
        output_dist = self.distribution_factory.get_dist(output_params, input_params, t)
        if hasattr(output_dist, "probs"):  # output distribution is discretized normal
            flat_target = data.flatten(start_dim=1)
            t = t.flatten(start_dim=1)
            i = t * n_steps + 1  # since t = (i - 1) / n
            alpha = self.bayesian_flow.get_alpha(i, n_steps)
            sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)
            receiver_mix_wts = sandwich(output_dist.probs)
            receiver_mix_dist = D.Categorical(probs=receiver_mix_wts, validate_args=False)
            receiver_components = D.Normal(
                output_dist.class_centres, (1.0 / alpha.sqrt()).unsqueeze(-1), validate_args=False
            )
            receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components, validate_args=False)
            y = sender_dist.sample(torch.Size([n_samples]))
            loss = (
                (sender_dist.log_prob(y) - receiver_dist.log_prob(y))
                .mean(0)
                .flatten(start_dim=1)
                .mean(1, keepdims=True)
            )
        else:  # output distribution is normal
            pred_mean = output_dist.mean
            flat_target = data.flatten(start_dim=1)
            mse_loss = (pred_mean - flat_target).square()
            i = t * n_steps + 1
            alpha = self.bayesian_flow.get_alpha(i, n_steps)
            loss = alpha * mse_loss / 2
        return n_steps * loss

    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
        output_params = sandwich(output_params)
        flat_data = data.flatten(start_dim=1)
        t = torch.ones_like(data).flatten(start_dim=1).float()
        output_dist = self.distribution_factory.get_dist(output_params, input_params, t)

        if hasattr(output_dist, "probs"):  # output distribution is discretized normal
            reconstruction_loss = -output_dist.log_prob(flat_data)
        else:  # output distribution is normal, but we use discretized normal to make results comparable (see Sec. 7.2)
            if self.bayesian_flow.min_variance == 1e-3:  # used for 16 bin CIFAR10
                noise_dev = 0.7 * math.sqrt(self.bayesian_flow.min_variance)
                num_bins = 16
            else:
                noise_dev = math.sqrt(self.bayesian_flow.min_variance)
                num_bins = 256
            mean = output_dist.mean.flatten(start_dim=1)
            final_dist = D.Normal(mean, noise_dev)
            final_dist = DiscretizedCtsDistribution(final_dist, num_bins, device=t.device, batch_dims=mean.ndim - 1)
            reconstruction_loss = -final_dist.log_prob(flat_data)
        return reconstruction_loss


# Discrete Data


class DiscreteBayesianFlow(BayesianFlow):
    def __init__(
        self,
        n_classes: int,
        min_sqrt_beta: float = 1e-10,
        discretize: bool = False,
        epsilon: float = 1e-6,
        max_sqrt_beta: float = 1,
    ):
        super().__init__()
        self.n_classes = n_classes
        self.min_sqrt_beta = min_sqrt_beta
        self.discretize = discretize
        self.epsilon = epsilon
        self.max_sqrt_beta = max_sqrt_beta
        self.uniform_entropy = math.log(self.n_classes)

    def t_to_sqrt_beta(self, t):
        return t * self.max_sqrt_beta

    def count_dist(self, x, beta=None):
        mean = (self.n_classes * F.one_hot(x.long(), self.n_classes)) - 1
        std_dev = math.sqrt(self.n_classes)
        if beta is not None:
            mean = mean * beta
            std_dev = std_dev * beta.sqrt()
        return D.Normal(mean, std_dev, validate_args=False)

    def count_sample(self, x, beta):
        return self.count_dist(x, beta).rsample()

    @torch.no_grad()
    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor]:
        return (torch.ones(*data_shape, self.n_classes, device=device) / self.n_classes,)

    @torch.no_grad()
    def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
        params = params[0]
        if self.n_classes == 2:
            params = params * 2 - 1  # We scale-shift here for MNIST instead of in the network like for text
            params = params[..., :1]
        return params

    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:
        return ((self.max_sqrt_beta / n_steps) ** 2) * (2 * i - 1)

    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
        e_x = F.one_hot(x.long(), self.n_classes)
        alpha = alpha.unsqueeze(-1) if isinstance(alpha, Tensor) else alpha
        dist = D.Normal(alpha * ((self.n_classes * e_x) - 1), (self.n_classes * alpha) ** 0.5)
        return dist

    def update_input_params(self, input_params: tuple[Tensor], y: Tensor, alpha: float) -> tuple[Tensor]:
        new_input_params = input_params[0] * y.exp()
        new_input_params /= new_input_params.sum(-1, keepdims=True)
        return (new_input_params,)

    @torch.no_grad()
    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor]:
        if self.discretize:
            data = float_to_idx(data, self.n_classes)
        sqrt_beta = self.t_to_sqrt_beta(t.clamp(max=1 - self.epsilon))
        lo_beta = sqrt_beta < self.min_sqrt_beta
        sqrt_beta = sqrt_beta.clamp(min=self.min_sqrt_beta)
        beta = sqrt_beta.square().unsqueeze(-1)
        logits = self.count_sample(data, beta)
        probs = F.softmax(logits, -1)
        probs = torch.where(lo_beta.unsqueeze(-1), torch.ones_like(probs) / self.n_classes, probs)
        if self.n_classes == 2:
            probs = probs[..., :1]
            probs = probs.reshape_as(data)
        input_params = (probs,)
        return input_params


class DiscreteBayesianFlowLoss(Loss):
    def __init__(
        self,
        bayesian_flow: DiscreteBayesianFlow,
        distribution_factory: DiscreteDistributionFactory,
    ):
        super().__init__()
        self.bayesian_flow = bayesian_flow
        self.distribution_factory = distribution_factory
        self.K = self.bayesian_flow.n_classes

    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:
        flat_output = sandwich(output_params)
        pred_probs = self.distribution_factory.get_dist(flat_output).probs
        flat_target = data.flatten(start_dim=1)
        if self.bayesian_flow.discretize:
            flat_target = float_to_idx(flat_target, self.K)
        tgt_mean = torch.nn.functional.one_hot(flat_target.long(), self.K)
        kl = self.K * ((tgt_mean - pred_probs).square()).sum(-1)
        t = t.flatten(start_dim=1).float()
        loss = t * (self.bayesian_flow.max_sqrt_beta**2) * kl
        return loss

    def discrete_time_loss(
        self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10
    ) -> Tensor:
        flat_target = data.flatten(start_dim=1)
        if self.bayesian_flow.discretize:
            flat_target = float_to_idx(flat_target, self.K)
        i = t * n_steps + 1
        alpha = self.bayesian_flow.get_alpha(i, n_steps).flatten(start_dim=1)
        sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)

        flat_output = sandwich(output_params)
        receiver_mix_wts = self.distribution_factory.get_dist(flat_output).probs
        receiver_mix_dist = D.Categorical(probs=receiver_mix_wts.unsqueeze(-2))
        classes = torch.arange(self.K, device=flat_target.device).long().unsqueeze(0).unsqueeze(0)
        receiver_components = self.bayesian_flow.get_sender_dist(classes, alpha.unsqueeze(-1))
        receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components)

        y = sender_dist.sample(torch.Size([n_samples]))
        loss = n_steps * (sender_dist.log_prob(y) - receiver_dist.log_prob(y)).mean(0).sum(-1).mean(1, keepdims=True)
        return loss

    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
        flat_outputs = sandwich(output_params)
        flat_data = data.flatten(start_dim=1)
        output_dist = self.distribution_factory.get_dist(flat_outputs)
        return -output_dist.log_prob(flat_data)


class BFN(nn.Module):
    def __init__(self, net: nn.Module, bayesian_flow: BayesianFlow, loss: Loss):
        super().__init__()
        self.net = net
        self.bayesian_flow = bayesian_flow
        self.loss = loss

    @staticmethod
    @torch.no_grad()
    def sample_t(data: Tensor, n_steps: Optional[int]) -> Tensor:
        if n_steps == 0 or n_steps is None:
            t = torch.rand(data.size(0), device=data.device).unsqueeze(-1)
        else:
            t = torch.randint(0, n_steps, (data.size(0),), device=data.device).unsqueeze(-1) / n_steps
        t = (torch.ones_like(data).flatten(start_dim=1) * t).reshape_as(data)
        return t

    def forward(
        self, data: Tensor, t: Optional[Tensor] = None, n_steps: Optional[int] = None
    ) -> tuple[Tensor, dict[str, Tensor], Tensor, Tensor]:
        """
        Compute an MC estimate of the continuous (when n_steps=None or 0) or discrete time KL loss.
        t is sampled randomly if None. If t is not None, expect t.shape == data.shape.
        """

        t = self.sample_t(data, n_steps) if t is None else t
        # sample input parameter flow
        input_params = self.bayesian_flow(data, t)
        net_inputs = self.bayesian_flow.params_to_net_inputs(input_params)

        # compute output distribution parameters
        output_params: Tensor = self.net(net_inputs, t)

        # compute KL loss in float32
        with torch.autocast(device_type=data.device.type if data.device.type != "mps" else "cpu", enabled=False):
            if n_steps == 0 or n_steps is None:
                loss = self.loss.cts_time_loss(data, output_params.float(), input_params, t)
            else:
                loss = self.loss.discrete_time_loss(data, output_params.float(), input_params, t, n_steps)

        # loss shape is (batch_size, 1)
        return loss.mean()

    @torch.inference_mode()
    def compute_reconstruction_loss(self, data: Tensor) -> Tensor:
        t = torch.ones_like(data).float()
        input_params = self.bayesian_flow(data, t)
        net_inputs = self.bayesian_flow.params_to_net_inputs(input_params)
        output_params: Tensor = self.net(net_inputs, t)
        return self.loss.reconstruction_loss(data, output_params, input_params).flatten(start_dim=1).mean()

    @torch.inference_mode()
    def sample(self, data_shape: tuple, n_steps: int) -> Tensor:
        device = next(self.parameters()).device
        input_params = self.bayesian_flow.get_prior_input_params(data_shape, device)
        distribution_factory = self.loss.distribution_factory

        for i in range(1, n_steps + 1):
            t = torch.ones(*data_shape, device=device) * (i - 1) / n_steps
            output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)
            output_sample = distribution_factory.get_dist(output_params, input_params, t).sample()
            output_sample = output_sample.reshape(*data_shape)
            alpha = self.bayesian_flow.get_alpha(i, n_steps)
            y = self.bayesian_flow.get_sender_dist(output_sample, alpha).sample()
            input_params = self.bayesian_flow.update_input_params(input_params, y, alpha)

        t = torch.ones(*data_shape, device=device)
        output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)
        output_sample = distribution_factory.get_dist(output_params, input_params, t).mode
        output_sample = output_sample.reshape(*data_shape)
        return output_sample


================================================
FILE: networks/__init__.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = (
    "GPT",
    "UNetVDM",
    "UNetModel",
    "adapters",
)

from .transformer import GPT
from .unet_vdm import UNetVDM
from .unet_improved import UNetModel
from . import adapters


================================================
FILE: networks/adapters.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Tuple

import torch
from torch import Tensor
from torch import nn

from utils_model import sandwich, pe_encode, pe_encode_float


class TextInputAdapter(nn.Module):
    """
    A module to convert sequences of text class tokens to embedding tokens with learned positional embeddings.
    """

    def __init__(
        self,
        vocab_size: int,
        seq_len: int,
        output_size: int = 256,
        learn_pos_embedding: bool = False,
    ):
        super().__init__()
        self.learn_pos_embedding = learn_pos_embedding
        if learn_pos_embedding:
            self.pos_embedding = nn.Embedding(seq_len, output_size)
        else:
            self.register_buffer("pos_embedding", pe_encode(seq_len, output_size))
        self.inp_embedding = nn.Linear(vocab_size, output_size)
        self.t_embedding = nn.Linear(1, output_size)

    def forward(self, probs: torch.Tensor, t: torch.Tensor) -> Tensor:
        inp_emb = self.inp_embedding(2 * probs - 1)
        if self.learn_pos_embedding:
            pos_emb = self.pos_embedding(
                torch.arange(0, probs.size(1)).to(probs.device)
            )
        else:
            pos_emb = self.pos_embedding
        pos_emb = pos_emb.unsqueeze(0).expand(inp_emb.size(0), -1, -1)
        t_emb = self.t_embedding((2 * t - 1).unsqueeze(-1))
        output = inp_emb + pos_emb + t_emb

        return output


class FourierImageInputAdapter(nn.Module):
    """
    A module to convert 2D image coordinates into a set of vectors represented as a matrix, with fourier position codes.
    """

    def __init__(
        self,
        input_channels: int = 3,
        input_shape: Tuple[int, int] = (224, 224),
        n_freq_bands: int = 64,
        output_height: int = 256,
        value_res: int = -1,
        mask_res: int = -1,
        add_pos_feats: bool = True,
        add_mask: bool = True,
        learn_pos_feats: bool = False,
        pos_embed_size: int = 32,
        init_scale: float = 0.02,
    ):
        super().__init__()
        self.input_shape = input_shape
        self.n_freq_bands = n_freq_bands
        self.value_res = value_res
        self.mask_res = mask_res
        self.add_pos_feats = add_pos_feats
        self.add_mask = add_mask
        if learn_pos_feats:
            pos_feats = nn.Parameter(
                init_scale
                * torch.randn(1, input_shape[0] * input_shape[1], pos_embed_size)
            )
            self.register_parameter("pos_feats", pos_feats)
        else:
            x = torch.linspace(-1.0, 1.0, steps=input_shape[0])
            y = torch.linspace(-1.0, 1.0, steps=input_shape[1])
            x_pos, y_pos = torch.meshgrid(x, y, indexing="ij")
            pos = torch.stack((x_pos, y_pos), dim=-1)
            pos = pos.reshape(-1, 2)
            x_bands = torch.linspace(1.0, input_shape[0] / 2, steps=n_freq_bands)
            y_bands = torch.linspace(1.0, input_shape[1] / 2, steps=n_freq_bands)
            bands = torch.stack((x_bands, y_bands), dim=0)
            vals = pos[:, :, None] * bands[None, :, :]
            vals = math.pi * vals.reshape(vals.shape[0], -1)
            pos_feats = torch.cat([vals.sin(), vals.cos()], dim=-1)
            pos_feats = torch.cat([pos_feats, pos], dim=-1)
            self.register_buffer("pos_feats", pos_feats)
        img_feat_height = input_channels
        pos_feat_height = pos_feats.size(-1)
        if self.mask_res > 0:
            mask_feat_height = (n_freq_bands * 2) + 1
        else:
            mask_feat_height = 1
        all_feat_height = img_feat_height
        if add_mask:
            all_feat_height += mask_feat_height
        if add_pos_feats:
            all_feat_height += pos_feat_height
        self.output_projection = None
        if output_height != all_feat_height:
            self.output_projection = nn.Linear(all_feat_height, output_height)

    def forward(self, img: Tensor, t: Tensor) -> Tensor:
        flat_img = sandwich(img)
        flat_t = sandwich(t)
        t_feats = (flat_t.float()[..., :1] * 2) - 1
        if self.mask_res > 0:
            t_feats = torch.cat(
                [
                    t_feats,
                    pe_encode_float(
                        t_feats, self.mask_res, self.n_freq_bands * 2
                    ).flatten(start_dim=2),
                ],
                -1,
            )
        fourier_feats = self.pos_feats.expand(img.size(0), -1, -1)
        all_feat_list = [flat_img]
        if self.add_mask:
            all_feat_list.append(t_feats)
        if self.add_pos_feats:
            all_feat_list.append(fourier_feats)
        all_feats = torch.cat(all_feat_list, dim=-1)
        if self.output_projection is None:
            output = all_feats
        else:
            output = self.output_projection(all_feats)
        return output


class OutputAdapter(nn.Module):
    def __init__(self, input_height: int, output_channels: int, output_height: int):
        super().__init__()
        self.output_channels = output_channels
        self.output_height = output_height
        self.output_projection = nn.Linear(
            input_height, output_channels * output_height
        )

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        output = self.output_projection(inp)
        return output.reshape(
            output.size(0), -1, self.output_channels, self.output_height
        )


================================================
FILE: networks/transformer.py
================================================
# Source: https://github.com/karpathy/nanoGPT
#
# MIT License
#
# Copyright (c) 2022 Andrej Karpathy
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Modifications:
# - Added data_adapters to GPT to preprocess the inputs and (optionally) postprocess the outputs
# - Added the `skip` option to concat the input and output of the network before the final projection
# - Added time `t` as an input to `forward()`

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def gelu(x):
    return F.gelu(x, approximate="tanh")


class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class SelfAttention(nn.Module):
    def __init__(self, n_head, n_embd, dropout, bias, is_causal):
        super().__init__()
        assert n_embd % n_head == 0

        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)

        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)

        # regularization
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        self.is_causal = is_causal

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0, is_causal=self.is_causal
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


class MLP(nn.Module):
    def __init__(self, n_embd, dropout, bias):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=bias)
        self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Block(nn.Module):
    def __init__(self, n_head, n_embd, dropout, bias, is_causal):
        super().__init__()
        self.ln_1 = LayerNorm(n_embd, bias=bias)
        self.attn = SelfAttention(n_head, n_embd, dropout, bias, is_causal)
        self.ln_2 = LayerNorm(n_embd, bias=bias)
        self.mlp = MLP(n_embd, dropout, bias)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class GPT(nn.Module):
    def __init__(
        self,
        data_adapters: dict,
        vocab_size: int,
        n_layer: int = 12,
        n_head: int = 12,
        n_embd: int = 768,
        dropout: float = 0.0,
        bias: bool = True,
        skip: bool = False,
        is_causal: bool = False,
    ):
        super().__init__()
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd

        self.input_adapter = data_adapters["input_adapter"]
        self.output_adapter = data_adapters["output_adapter"]
        self.transformer = nn.ModuleDict(
            dict(
                drop=nn.Dropout(dropout),
                h=nn.ModuleList([Block(n_head, n_embd, dropout, bias, is_causal) for _ in range(n_layer)]),
                ln_f=LayerNorm(n_embd, bias=bias),
            )
        )
        self.is_causal = is_causal
        if self.is_causal:
            self.skip = False
        else:
            self.skip = skip
        if skip:
            self.lm_head = nn.Linear(2 * n_embd, vocab_size, bias=bias)
        else:
            self.lm_head = nn.Linear(n_embd, vocab_size, bias=bias)

        # init all weights
        self.apply(self._init_weights)

        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layer))

        # report number of parameters
        print(f"number of parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6:.2f}M")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, data: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        x_in = self.input_adapter(data, t)
        x = self.transformer.drop(x_in)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if self.skip:
            x = torch.cat([x, x_in], -1)
        logits = self.output_adapter(self.lm_head(x)) if self.output_adapter else self.lm_head(x)
        return logits

    def get_optim_groups(self, weight_decay: float):
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)
        blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to...
                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # We don't use weight tying so comment this out
        # decay.remove('lm_head.weight')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
        assert (
            len(param_dict.keys() - union_params) == 0
        ), "parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params),)

        # create the pytorch optimizer groups
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        return optim_groups


================================================
FILE: networks/unet_improved.py
================================================
# Source: https://github.com/openai/improved-diffusion
#
# MIT License
#
# Copyright (c) 2021 OpenAI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Modifications:
# - Added data_adapters to UNetModel to preprocess the inputs and postprocess the outputs
# - Added the `skip` option to concat the input and output of the network before the final projection
# - Replaced `timesteps` argument of `UNetModel.forward()` with time `t`, which is used to compute the `timesteps`

from abc import abstractmethod

import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from utils_model import sandwich

from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

"""
Helpers to train with 16-bit precision.
"""


def convert_module_to_f16(module):
    """
    Convert primitive modules to float16.
    """
    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        module.weight.data = module.weight.data.half()
        module.bias.data = module.bias.data.half()


def convert_module_to_f32(module):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        module.weight.data = module.weight.data.float()
        module.bias.data = module.bias.data.float()


def make_master_params(model_params):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = _flatten_dense_tensors([param.detach().float() for param in model_params])
    master_params = nn.Parameter(master_params)
    master_params.requires_grad = True
    return [master_params]


def model_grads_to_master_grads(model_params, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    master_params[0].grad = _flatten_dense_tensors([param.grad.data.detach().float() for param in model_params])


def master_params_to_model_params(model_params, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    model_params = list(model_params)

    for param, master_param in zip(model_params, unflatten_master_params(model_params, master_params)):
        param.detach().copy_(master_param)


def unflatten_master_params(model_params, master_params):
    """
    Unflatten the master parameters to look like model_params.
    """
    return _unflatten_dense_tensors(master_params[0].detach(), model_params)


def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * th.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(
        device=timesteps.device
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(th.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2):
        super().__init__()
        self.channels = channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, channels, channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2):
        super().__init__()
        self.channels = channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1)
        else:
            self.op = avg_pool_nd(stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )
        self.emb_layers = nn.Sequential(
            SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            SiLU(),
            nn.Dropout(p=dropout),
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)

    def _forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(self, channels, num_heads=1, use_checkpoint=False):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint

        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention()
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention.
    """

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x C x T] tensor after attention.
        """
        ch = qkv.shape[1] // 3
        q, k, v = th.split(qkv, ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        return th.einsum("bts,bcs->bct", weight, v)

    @staticmethod
    def count_flops(model, _x, y):
        """
        A counter for the `thop` package to count the operations in an
        attention operation.

        Meant to be used like:

            macs, params = thop.profile(
                model,
                inputs=(inputs, timestamps),
                custom_ops={QKVAttention: QKVAttention.count_flops},
            )

        """
        b, c, *spatial = y[0].shape
        num_spatial = int(np.prod(spatial))
        # We perform two matmuls with the same number of ops.
        # The first computes the weight matrix, the second computes
        # the combination of the value vectors.
        matmul_ops = 2 * b * (num_spatial**2) * c
        model.total_ops += th.DoubleTensor([matmul_ops])


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    """

    def __init__(
        self,
        data_adapters,
        image_size=32,
        in_channels=3,
        model_channels=128,
        out_channels=128,
        num_res_blocks=3,
        attention_resolutions=[8, 16],
        dropout=0,
        channel_mult=(1, 2, 2, 2),
        conv_resample=True,
        dims=2,
        skip=True,
        num_classes=None,
        use_checkpoint=False,
        num_heads=4,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        project_input=False,
    ):
        super().__init__()
        self.input_adapter = data_adapters["input_adapter"]
        self.output_adapter = data_adapters["output_adapter"]

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.num_heads_upsample = num_heads_upsample
        self.skip = skip
        self.project_input = project_input
        if project_input:
            self.input_projection = nn.Linear(self.in_channels, self.model_channels)
            in_channels = self.model_channels

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                self.input_blocks.append(TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims)))
                input_block_chans.append(ch)
                ds *= 2

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                layers = [
                    ResBlock(
                        ch + input_block_chans.pop(),
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                        )
                    )
                if level and i == num_res_blocks:
                    layers.append(Upsample(ch, conv_resample, dims=dims))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        self.out = nn.Sequential(
            normalization(ch),
            SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    @property
    def inner_dtype(self):
        """
        Get the dtype used by the torso of the model.
        """
        return next(self.input_blocks.parameters()).dtype

    def forward(
        self,
        data: th.Tensor,
        t: th.Tensor,
    ) -> th.Tensor:
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        y = None
        flat_x = self.input_adapter(data, t)
        x = flat_x.reshape(flat_x.size(0), self.image_size, self.image_size, self.in_channels)
        if self.project_input:
            x = self.input_projection(x)
        x_perm = x.permute(0, 3, 1, 2).contiguous()
        timesteps = t.flatten(start_dim=1)[:, 0] * 4000
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x_perm.type(self.inner_dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            cat_in = th.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
        h = h.type(x.dtype)
        out = sandwich(self.out(h).permute(0, 2, 3, 1).contiguous())
        if self.skip:
            out = th.cat([sandwich(x), out], -1)
        out = self.output_adapter(out)
        return out

    def get_feature_vectors(self, x, timesteps, y=None):
        """
        Apply the model and return all of the intermediate tensors.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: a dict with the following keys:
                 - 'down': a list of hidden state tensors from downsampling.
                 - 'middle': the tensor of the output of the lowest-resolution
                             block in the model.
                 - 'up': a list of hidden state tensors from upsampling.
        """
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)
        result = dict(down=[], up=[])
        h = x.type(self.inner_dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
            result["down"].append(h.type(x.dtype))
        h = self.middle_block(h, emb)
        result["middle"] = h.type(x.dtype)
        for module in self.output_blocks:
            cat_in = th.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
            result["up"].append(h.type(x.dtype))
        return result


================================================
FILE: networks/unet_vdm.py
================================================
# Source: https://github.com/addtt/variational-diffusion-models
#
# MIT License
#
# Copyright (c) 2022 Andrea Dittadi
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Modifications:
# - Added data_adapters to UNetVDM to preprocess the inputs and postprocess the outputs
# - Replaced `timesteps` argument of `UNetModel.forward()` with time `t`, which is used to compute the `timesteps`
# - Added 1/1000 to t before computing timesteps embeddings so t isn't 0
# - Added concatenation of input and output of the network before the final projection

import numpy as np
import torch
from torch import einsum, nn, pi, softmax

from utils_model import sandwich


@torch.no_grad()
def zero_init(module: nn.Module) -> nn.Module:
    """Sets to zero all the parameters of a module, and returns the module."""
    for p in module.parameters():
        nn.init.zeros_(p.data)
    return module


class UNetVDM(nn.Module):
    def __init__(
        self,
        data_adapters,
        embedding_dim: int = 128,
        n_blocks: int = 32,
        n_attention_heads: int = 1,
        dropout_prob: float = 0.1,
        norm_groups: int = 32,
        input_channels: int = 3,
        use_fourier_features: bool = True,
        attention_everywhere: bool = False,
        image_size: int = 32,
    ):
        super().__init__()
        self.input_adapter = data_adapters["input_adapter"]
        self.output_adapter = data_adapters["output_adapter"]
        attention_params = dict(
            n_heads=n_attention_heads,
            n_channels=embedding_dim,
            norm_groups=norm_groups,
        )
        resnet_params = dict(
            ch_in=embedding_dim,
            ch_out=embedding_dim,
            condition_dim=4 * embedding_dim,
            dropout_prob=dropout_prob,
            norm_groups=norm_groups,
        )
        if use_fourier_features:
            self.fourier_features = FourierFeatures()
        self.embed_conditioning = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 4),
            nn.SiLU(),
            nn.Linear(embedding_dim * 4, embedding_dim * 4),
            nn.SiLU(),
        )
        total_input_ch = input_channels
        if use_fourier_features:
            total_input_ch *= 1 + self.fourier_features.num_features
        self.conv_in = nn.Conv2d(total_input_ch, embedding_dim, 3, padding=1)

        # Down path: n_blocks blocks with a resnet block and maybe attention.
        self.down_blocks = nn.ModuleList(
            UpDownBlock(
                resnet_block=ResnetBlock(**resnet_params),
                attention_block=AttentionBlock(**attention_params) if attention_everywhere else None,
            )
            for _ in range(n_blocks)
        )

        self.mid_resnet_block_1 = ResnetBlock(**resnet_params)
        self.mid_attn_block = AttentionBlock(**attention_params)
        self.mid_resnet_block_2 = ResnetBlock(**resnet_params)

        # Up path: n_blocks+1 blocks with a resnet block and maybe attention.
        resnet_params["ch_in"] *= 2  # double input channels due to skip connections
        self.up_blocks = nn.ModuleList(
            UpDownBlock(
                resnet_block=ResnetBlock(**resnet_params),
                attention_block=AttentionBlock(**attention_params) if attention_everywhere else None,
            )
            for _ in range(n_blocks + 1)
        )

        self.conv_out = nn.Sequential(
            nn.GroupNorm(num_groups=norm_groups, num_channels=embedding_dim),
            nn.SiLU(),
            zero_init(nn.Conv2d(embedding_dim, embedding_dim, 3, padding=1)),
        )
        self.embedding_dim = embedding_dim
        self.input_channels = input_channels
        self.image_size = image_size
        self.use_fourier_features = use_fourier_features

    def forward(
        self,
        data: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        flat_x = self.input_adapter(data, t)
        x = flat_x.reshape(flat_x.size(0), self.image_size, self.image_size, self.input_channels)
        x_perm = x.permute(0, 3, 1, 2).contiguous()
        t = t.float().flatten(start_dim=1)[:, 0]
        t_embedding = get_timestep_embedding(t + 0.001, self.embedding_dim)
        # We will condition on time embedding.
        cond = self.embed_conditioning(t_embedding)

        h = self.maybe_concat_fourier(x_perm)
        h = self.conv_in(h)  # (B, embedding_dim, H, W)
        hs = []
        for down_block in self.down_blocks:  # n_blocks times
            hs.append(h)
            h = down_block(h, cond)
        hs.append(h)
        h = self.mid_resnet_block_1(h, cond)
        h = self.mid_attn_block(h)
        h = self.mid_resnet_block_2(h, cond)
        for up_block in self.up_blocks:  # n_blocks+1 times
            h = torch.cat([h, hs.pop()], dim=1)
            h = up_block(h, cond)
        out = sandwich(self.conv_out(h).permute(0, 2, 3, 1).contiguous())
        out = torch.cat([sandwich(x), out], -1)
        out = self.output_adapter(out)
        return out

    def maybe_concat_fourier(self, z):
        if self.use_fourier_features:
            return torch.cat([z, self.fourier_features(z)], dim=1)
        return z


class ResnetBlock(nn.Module):
    def __init__(
        self,
        ch_in,
        ch_out=None,
        condition_dim=None,
        dropout_prob=0.0,
        norm_groups=32,
    ):
        super().__init__()
        ch_out = ch_in if ch_out is None else ch_out
        self.ch_out = ch_out
        self.condition_dim = condition_dim
        self.net1 = nn.Sequential(
            nn.GroupNorm(num_groups=norm_groups, num_channels=ch_in),
            nn.SiLU(),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1),
        )
        if condition_dim is not None:
            self.cond_proj = zero_init(nn.Linear(condition_dim, ch_out, bias=False))
        self.net2 = nn.Sequential(
            nn.GroupNorm(num_groups=norm_groups, num_channels=ch_out),
            nn.SiLU(),
            nn.Dropout(dropout_prob),
            zero_init(nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1)),
        )
        if ch_in != ch_out:
            self.skip_conv = nn.Conv2d(ch_in, ch_out, kernel_size=1)

    def forward(self, x, condition):
        h = self.net1(x)
        if condition is not None:
            assert condition.shape == (x.shape[0], self.condition_dim)
            condition = self.cond_proj(condition)
            condition = condition[:, :, None, None]
            h = h + condition
        h = self.net2(h)
        if x.shape[1] != self.ch_out:
            x = self.skip_conv(x)
        assert x.shape == h.shape
        return x + h


def get_timestep_embedding(
    timesteps,
    embedding_dim: int,
    dtype=torch.float32,
    max_timescale=10_000,
    min_timescale=1,
):
    # Adapted from tensor2tensor and VDM codebase.
    assert timesteps.ndim == 1
    assert embedding_dim % 2 == 0
    timesteps *= 1000.0  # In DDPM the time step is in [0, 1000], here [0, 1]
    num_timescales = embedding_dim // 2
    inv_timescales = torch.logspace(  # or exp(-linspace(log(min), log(max), n))
        -np.log10(min_timescale),
        -np.log10(max_timescale),
        num_timescales,
        device=timesteps.device,
    )
    emb = timesteps.to(dtype)[:, None] * inv_timescales[None, :]  # (T, D/2)
    return torch.cat([emb.sin(), emb.cos()], dim=1)  # (T, D)


class FourierFeatures(nn.Module):
    def __init__(self, first=5.0, last=6.0, step=1.0):
        super().__init__()
        self.freqs_exponent = torch.arange(first, last + 1e-8, step)

    @property
    def num_features(self):
        return len(self.freqs_exponent) * 2

    def forward(self, x):
        assert len(x.shape) >= 2

        # Compute (2pi * 2^n) for n in freqs.
        freqs_exponent = self.freqs_exponent.to(dtype=x.dtype, device=x.device)  # (F, )
        freqs = 2.0**freqs_exponent * 2 * pi  # (F, )
        freqs = freqs.view(-1, *([1] * (x.dim() - 1)))  # (F, 1, 1, ...)

        # Compute (2pi * 2^n * x) for n in freqs.
        features = freqs * x.unsqueeze(1)  # (B, F, X1, X2, ...)
        features = features.flatten(1, 2)  # (B, F * C, X1, X2, ...)

        # Output features are cos and sin of above. Shape (B, 2 * F * C, H, W).
        return torch.cat([features.sin(), features.cos()], dim=1)


def attention_inner_heads(qkv, num_heads):
    """Computes attention with heads inside of qkv in the channel dimension.

    Args:
        qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where:
            H = number of heads,
            C = number of channels per head.
        num_heads: number of heads.

    Returns:
        Attention output of shape (B, H*C, T).
    """

    bs, width, length = qkv.shape
    ch = width // (3 * num_heads)

    # Split into (q, k, v) of shape (B, H*C, T).
    q, k, v = qkv.chunk(3, dim=1)

    # Rescale q and k. This makes them contiguous in memory.
    scale = ch ** (-1 / 4)  # scale with 4th root = scaling output by sqrt
    q = q * scale
    k = k * scale

    # Reshape qkv to (B*H, C, T).
    new_shape = (bs * num_heads, ch, length)
    q = q.view(*new_shape)
    k = k.view(*new_shape)
    v = v.reshape(*new_shape)

    # Compute attention.
    weight = einsum("bct,bcs->bts", q, k)  # (B*H, T, T)
    weight = softmax(weight.float(), dim=-1).to(weight.dtype)  # (B*H, T, T)
    out = einsum("bts,bcs->bct", weight, v)  # (B*H, C, T)
    return out.reshape(bs, num_heads * ch, length)  # (B, H*C, T)


class Attention(nn.Module):
    """Based on https://github.com/openai/guided-diffusion."""

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        assert qkv.dim() >= 3, qkv.dim()
        assert qkv.shape[1] % (3 * self.n_heads) == 0
        spatial_dims = qkv.shape[2:]
        qkv = qkv.view(*qkv.shape[:2], -1)  # (B, 3*H*C, T)
        out = attention_inner_heads(qkv, self.n_heads)  # (B, H*C, T)
        return out.view(*out.shape[:2], *spatial_dims).contiguous()


class AttentionBlock(nn.Module):
    """Self-attention residual block."""

    def __init__(self, n_heads, n_channels, norm_groups):
        super().__init__()
        assert n_channels % n_heads == 0
        self.layers = nn.Sequential(
            nn.GroupNorm(num_groups=norm_groups, num_channels=n_channels),
            nn.Conv2d(n_channels, 3 * n_channels, kernel_size=1),  # (B, 3 * C, H, W)
            Attention(n_heads),
            zero_init(nn.Conv2d(n_channels, n_channels, kernel_size=1)),
        )

    def forward(self, x):
        return self.layers(x) + x


class UpDownBlock(nn.Module):
    def __init__(self, resnet_block, attention_block=None):
        super().__init__()
        self.resnet_block = resnet_block
        self.attention_block = attention_block

    def forward(self, x, cond):
        x = self.resnet_block(x, cond)
        if self.attention_block is not None:
            x = self.attention_block(x)
        return x


================================================
FILE: probability.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import functools
from abc import abstractmethod

from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical as torch_Categorical
from torch.distributions.bernoulli import Bernoulli as torch_Bernoulli
from torch.distributions.mixture_same_family import MixtureSameFamily
from torch.distributions.uniform import Uniform

from math import log

from utils_model import (
    safe_exp,
    safe_log,
    idx_to_float,
    float_to_idx,
    quantize, sandwich,
)


class CtsDistribution:
    @abstractmethod
    def log_prob(self, x):
        pass

    @abstractmethod
    def sample(self):
        pass


class DiscreteDistribution:
    @property
    @abstractmethod
    def probs(self):
        pass

    @functools.cached_property
    def log_probs(self):
        return safe_log(self.probs)

    @functools.cached_property
    def mean(self):
        pass

    @functools.cached_property
    def mode(self):
        pass

    @abstractmethod
    def log_prob(self, x):
        pass

    @abstractmethod
    def sample(self):
        pass


class DiscretizedDistribution(DiscreteDistribution):
    def __init__(self, num_bins, device):
        self.num_bins = num_bins
        self.bin_width = 2.0 / num_bins
        self.half_bin_width = self.bin_width / 2.0
        self.device = device

    @functools.cached_property
    def class_centres(self):
        return torch.arange(self.half_bin_width - 1, 1, self.bin_width, device=self.device)

    @functools.cached_property
    def class_boundaries(self):
        return torch.arange(self.bin_width - 1, 1 - self.half_bin_width, self.bin_width, device=self.device)

    @functools.cached_property
    def mean(self):
        return (self.probs * self.class_centres).sum(-1)

    @functools.cached_property
    def mode(self):
        mode_idx = self.probs.argmax(-1).flatten()
        return self.class_centres[mode_idx].reshape(self.probs.shape[:-1])


class DiscretizedCtsDistribution(DiscretizedDistribution):
    def __init__(self, cts_dist, num_bins, device, batch_dims, clip=True, min_prob=1e-5):
        super().__init__(num_bins, device)
        self.cts_dist = cts_dist
        self.log_bin_width = log(self.bin_width)
        self.batch_dims = batch_dims
        self.clip = clip
        self.min_prob = min_prob

    @functools.cached_property
    def probs(self):
        bdry_cdfs = self.cts_dist.cdf(self.class_boundaries.reshape([-1] + ([1] * self.batch_dims)))
        bdry_slice = bdry_cdfs[:1]
        if self.clip:
            cdf_min = torch.zeros_like(bdry_slice)
            cdf_max = torch.ones_like(bdry_slice)
            bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0)
            return (bdry_cdfs[1:] - bdry_cdfs[:-1]).moveaxis(0, -1)
        else:
            cdf_min = self.cts_dist.cdf(torch.zeros_like(bdry_slice) - 1)
            cdf_max = self.cts_dist.cdf(torch.ones_like(bdry_slice))
            bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0)
            cdf_range = cdf_max - cdf_min
            cdf_mask = cdf_range < self.min_prob
            cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range)
            probs = (bdry_cdfs[1:] - bdry_cdfs[:-1]) / cdf_range
            probs = torch.where(cdf_mask, (probs * 0) + (1 / self.num_bins), probs)
            return probs.moveaxis(0, -1)

    def prob(self, x):
        class_idx = float_to_idx(x, self.num_bins)
        centre = idx_to_float(class_idx, self.num_bins)
        cdf_lo = self.cts_dist.cdf(centre - self.half_bin_width)
        cdf_hi = self.cts_dist.cdf(centre + self.half_bin_width)
        if self.clip:
            cdf_lo = torch.where(class_idx <= 0, torch.zeros_like(centre), cdf_lo)
            cdf_hi = torch.where(class_idx >= (self.num_bins - 1), torch.ones_like(centre), cdf_hi)
            return cdf_hi - cdf_lo
        else:
            cdf_min = self.cts_dist.cdf(torch.zeros_like(centre) - 1)
            cdf_max = self.cts_dist.cdf(torch.ones_like(centre))
            cdf_range = cdf_max - cdf_min
            cdf_mask = cdf_range < self.min_prob
            cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range)
            prob = (cdf_hi - cdf_lo) / cdf_range
            return torch.where(cdf_mask, (prob * 0) + (1 / self.num_bins), prob)

    def log_prob(self, x):
        prob = self.prob(x)
        return torch.where(
            prob < self.min_prob,
            self.cts_dist.log_prob(quantize(x, self.num_bins)) + self.log_bin_width,
            safe_log(prob),
        )

    def sample(self, sample_shape=torch.Size([])):
        if self.clip:
            return quantize(self.cts_dist.sample(sample_shape), self.num_bins)
        else:
            assert hasattr(self.cts_dist, "icdf")
            cdf_min = self.cts_dist.cdf(torch.zeros_like(self.cts_dist.mean) - 1)
            cdf_max = self.cts_dist.cdf(torch.ones_like(cdf_min))
            u = Uniform(cdf_min, cdf_max, validate_args=False).sample(sample_shape)
            cts_samp = self.cts_dist.icdf(u)
            return quantize(cts_samp, self.num_bins)


class GMM(MixtureSameFamily):
    def __init__(self, mix_wt_logits, means, std_devs):
        mix_wts = torch_Categorical(logits=mix_wt_logits, validate_args=False)
        components = Normal(means, std_devs, validate_args=False)
        super().__init__(mix_wts, components, validate_args=False)


class DiscretizedGMM(DiscretizedCtsDistribution):
    def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):
        assert params.size(-1) % 3 == 0
        if min_std_dev < 0:
            min_std_dev = 1.0 / (num_bins * 5)
        mix_wt_logits, means, std_devs = params.chunk(3, -1)
        if log_dev:
            std_devs = safe_exp(std_devs)
        std_devs = std_devs.clamp(min=min_std_dev, max=max_std_dev)
        super().__init__(
            cts_dist=GMM(mix_wt_logits, means, std_devs),
            num_bins=num_bins,
            device=params.device,
            batch_dims=params.ndim - 1,
            clip=clip,
            min_prob=min_prob,
        )


class DiscretizedNormal(DiscretizedCtsDistribution):
    def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):
        assert params.size(-1) == 2
        if min_std_dev < 0:
            min_std_dev = 1.0 / (num_bins * 5)
        mean, std_dev = params.split(1, -1)[:2]
        if log_dev:
            std_dev = safe_exp(std_dev)
        std_dev = std_dev.clamp(min=min_std_dev, max=max_std_dev)
        super().__init__(
            cts_dist=Normal(mean.squeeze(-1), std_dev.squeeze(-1), validate_args=False),
            num_bins=num_bins,
            device=params.device,
            batch_dims=params.ndim - 1,
            clip=clip,
            min_prob=min_prob,
        )


class Bernoulli(DiscreteDistribution):
    def __init__(self, logits):
        self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False)

    @functools.cached_property
    def probs(self):
        p = self.bernoulli.probs.unsqueeze(-1)
        return torch.cat([1 - p, p], -1)

    @functools.cached_property
    def mode(self):
        return self.bernoulli.mode

    def log_prob(self, x):
        return self.bernoulli.log_prob(x.float())

    def sample(self, sample_shape=torch.Size([])):
        return self.bernoulli.sample(sample_shape)


class DiscretizedBernoulli(DiscretizedDistribution):
    def __init__(self, logits):
        super().__init__(2, logits.device)
        self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False)

    @functools.cached_property
    def probs(self):
        p = self.bernoulli.probs.unsqueeze(-1)
        return torch.cat([1 - p, p], -1)

    @functools.cached_property
    def mode(self):
        return idx_to_float(self.bernoulli.mode, 2)

    def log_prob(self, x):
        return self.bernoulli.log_prob(float_to_idx(x, 2).float())

    def sample(self, sample_shape=torch.Size([])):
        return idx_to_float(self.bernoulli.sample(sample_shape), 2)


class DeltaDistribution(CtsDistribution):
    def __init__(self, mean, clip_range=1.0):
        if clip_range > 0:
            mean = mean.clip(min=-clip_range, max=clip_range)
        self.mean = mean

    @functools.cached_property
    def mode(self):
        return self.mean

    @functools.cached_property
    def mean(self):
        return self.mean

    def sample(self, sample_shape=torch.Size([])):
        return self.mean


class Categorical(DiscreteDistribution):
    def __init__(self, logits):
        self.categorical = torch_Categorical(logits=logits, validate_args=False)
        self.n_classes = logits.size(-1)

    @functools.cached_property
    def probs(self):
        return self.categorical.probs

    @functools.cached_property
    def mode(self):
        return self.categorical.mode

    def log_prob(self, x):
        return self.categorical.log_prob(x)

    def sample(self, sample_shape=torch.Size([])):
        return self.categorical.sample(sample_shape)


class DiscretizedCategorical(DiscretizedDistribution):
    def __init__(self, logits=None, probs=None):
        assert (logits is not None) or (probs is not None)
        if logits is not None:
            super().__init__(logits.size(-1), logits.device)
            self.categorical = torch_Categorical(logits=logits, validate_args=False)
        else:
            super().__init__(probs.size(-1), probs.device)
            self.categorical = torch_Categorical(probs=probs, validate_args=False)

    @functools.cached_property
    def probs(self):
        return self.categorical.probs

    @functools.cached_property
    def mode(self):
        return idx_to_float(self.categorical.mode, self.num_bins)

    def log_prob(self, x):
        return self.categorical.log_prob(float_to_idx(x, self.num_bins))

    def sample(self, sample_shape=torch.Size([])):
        return idx_to_float(self.categorical.sample(sample_shape), self.num_bins)


class CtsDistributionFactory:
    @abstractmethod
    def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> CtsDistribution:
        """Note: input_params and t are not used but kept here to be consistency with DiscreteDistributionFactory."""
        pass


class GMMFactory(CtsDistributionFactory):
    def __init__(self, min_std_dev=1e-3, max_std_dev=10, log_dev=True):
        self.min_std_dev = min_std_dev
        self.max_std_dev = max_std_dev
        self.log_dev = log_dev

    def get_dist(self, params, input_params=None, t=None):
        mix_wt_logits, means, std_devs = params.chunk(3, -1)
        if self.log_dev:
            std_devs = safe_exp(std_devs)
        std_devs = std_devs.clamp(min=self.min_std_dev, max=self.max_std_dev)
        return GMM(mix_wt_logits, means, std_devs)


class NormalFactory(CtsDistributionFactory):
    def __init__(self, min_std_dev=1e-3, max_std_dev=10):
        self.min_std_dev = min_std_dev
        self.max_std_dev = max_std_dev

    def get_dist(self, params, input_params=None, t=None):
        mean, log_std_dev = params.split(1, -1)[:2]
        std_dev = safe_exp(log_std_dev).clamp(min=self.min_std_dev, max=self.max_std_dev)
        return Normal(mean.squeeze(-1), std_dev.squeeze(-1), validate_args=False)


class DeltaFactory(CtsDistributionFactory):
    def __init__(self, clip_range=1.0):
        self.clip_range = clip_range

    def get_dist(self, params, input_params=None, t=None):
        return DeltaDistribution(params.squeeze(-1), self.clip_range)


class DiscreteDistributionFactory:
    @abstractmethod
    def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> DiscreteDistribution:
        """Note: input_params and t are only required by PredDistToDataDistFactory."""
        pass


class BernoulliFactory(DiscreteDistributionFactory):
    def get_dist(self, params, input_params=None, t=None):
        return Bernoulli(logits=params.squeeze(-1))


class CategoricalFactory(DiscreteDistributionFactory):
    def get_dist(self, params, input_params=None, t=None):
        return Categorical(logits=params)


class DiscretizedBernoulliFactory(DiscreteDistributionFactory):
    def get_dist(self, params, input_params=None, t=None):
        return DiscretizedBernoulli(logits=params.squeeze(-1))


class DiscretizedCategoricalFactory(DiscreteDistributionFactory):
    def get_dist(self, params, input_params=None, t=None):
        return DiscretizedCategorical(logits=params)


class DiscretizedGMMFactory(DiscreteDistributionFactory):
    def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):
        self.num_bins = num_bins
        self.clip = clip
        self.min_std_dev = min_std_dev
        self.max_std_dev = max_std_dev
        self.min_prob = min_prob
        self.log_dev = log_dev

    def get_dist(self, params, input_params=None, t=None):
        return DiscretizedGMM(
            params,
            num_bins=self.num_bins,
            clip=self.clip,
            min_std_dev=self.min_std_dev,
            max_std_dev=self.max_std_dev,
            min_prob=self.min_prob,
            log_dev=self.log_dev,
        )


class DiscretizedNormalFactory(DiscreteDistributionFactory):
    def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):
        self.num_bins = num_bins
        self.clip = clip
        self.min_std_dev = min_std_dev
        self.max_std_dev = max_std_dev
        self.min_prob = min_prob
        self.log_dev = log_dev

    def get_dist(self, params, input_params=None, t=None):
        return DiscretizedNormal(
            params,
            num_bins=self.num_bins,
            clip=self.clip,
            min_std_dev=self.min_std_dev,
            max_std_dev=self.max_std_dev,
            min_prob=self.min_prob,
            log_dev=self.log_dev,
        )


def noise_pred_params_to_data_pred_params(noise_pred_params: torch.Tensor, input_mean: torch.Tensor, t: torch.Tensor, min_variance: float, min_t=1e-6):
    """Convert output parameters that predict the noise added to data, to parameters that predict the data."""
    data_shape = list(noise_pred_params.shape)[:-1]
    noise_pred_params = sandwich(noise_pred_params)
    input_mean = input_mean.flatten(start_dim=1)
    if torch.is_tensor(t):
        t = t.flatten(start_dim=1)
    else:
        t = (input_mean * 0) + t
    alpha_mask = (t < min_t).unsqueeze(-1)
    posterior_var = torch.pow(min_variance, t.clamp(min=min_t))
    gamma = 1 - posterior_var
    A = (input_mean / gamma).unsqueeze(-1)
    B = (posterior_var / gamma).sqrt().unsqueeze(-1)
    data_pred_params = []
    if noise_pred_params.size(-1) == 1:
        noise_pred_mean = noise_pred_params
    elif noise_pred_params.size(-1) == 2:
        noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(2, -1)
    else:
        assert noise_pred_params.size(-1) % 3 == 0
        mix_wt_logits, noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(3, -1)
        data_pred_params.append(mix_wt_logits)
    data_pred_mean = A - (B * noise_pred_mean)
    data_pred_mean = torch.where(alpha_mask, 0 * data_pred_mean, data_pred_mean)
    data_pred_params.append(data_pred_mean)
    if noise_pred_params.size(-1) >= 2:
        noise_pred_dev = safe_exp(noise_pred_log_dev)
        data_pred_dev = B * noise_pred_dev
        data_pred_dev = torch.where(alpha_mask, 1 + (0 * data_pred_dev), data_pred_dev)
        data_pred_params.append(data_pred_dev)
    data_pred_params = torch.cat(data_pred_params, -1)
    data_pred_params = data_pred_params.reshape(data_shape + [-1])
    return data_pred_params


class PredDistToDataDistFactory(DiscreteDistributionFactory):
    def __init__(self, data_dist_factory, min_variance, min_t=1e-6):
        self.data_dist_factory = data_dist_factory
        self.data_dist_factory.log_dev = False
        self.min_variance = min_variance
        self.min_t = min_t

    def get_dist(self, params, input_params, t):
        data_pred_params = noise_pred_params_to_data_pred_params(params, input_params[0], t, self.min_variance, self.min_t)
        return self.data_dist_factory.get_dist(data_pred_params)


================================================
FILE: sample.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from omegaconf import OmegaConf, DictConfig

from utils_train import seed_everything, make_config, make_bfn

torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True


def main(cfg: DictConfig) -> torch.Tensor:
    """
    Config entries:
        seed (int): Optional
        config_file (str): Name of config file containing model and data config for a saved checkpoint
        load_model (str): Path to a saved checkpoint to be tested
        sample_shape (list): Shape of sample batch, e.g.:
            (3, 256) for sampling 3 sequences of length 256 from the text8 model.
            (2, 32, 32, 3) for sampling 2 images from the CIFAR10 model.
            (4, 28, 28, 1) for sampling 4 images from the MNIST model.
        n_steps (int): Number of sampling steps (positive integer).
        save_file (str): File path to save the generated sample tensor. Skip saving if None.
    """
    seed_everything(cfg.seed)
    print(f"Seeded everything with seed {cfg.seed}")

    # Get model config from the training config file
    train_cfg = make_config(cfg.config_file)
    bfn = make_bfn(train_cfg.model)

    bfn.load_state_dict(torch.load(cfg.load_model, weights_only=True, map_location="cpu"))
    if torch.cuda.is_available():
        bfn.to("cuda")
    samples = bfn.sample(cfg.samples_shape, cfg.n_steps)

    if cfg.save_file is not None:
        torch.save(samples.to("cpu"), cfg.save_file)

    return samples


if __name__ == "__main__":
    main(OmegaConf.from_cli())


================================================
FILE: test.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Tuple

import torch
from omegaconf import OmegaConf, DictConfig
from rich import print
from torch import nn
from torch.utils.data import DataLoader

from data import make_datasets
from model import BFN
from utils_train import seed_everything, make_config, make_bfn, worker_init_function, make_progress_bar

torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True


def setup(cfg: DictConfig) -> Tuple[nn.Module, DataLoader]:
    test_ds = make_datasets(cfg.data)[-1]
    test_dl = DataLoader(
        dataset=test_ds,
        worker_init_fn=worker_init_function,
        batch_size=100,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )
    model = make_bfn(cfg.model)
    return model, test_dl


@torch.inference_mode()
def test(model: BFN, dataloader: DataLoader, n_steps: int, n_repeats: int) -> tuple[float, float, float, float]:
    if torch.cuda.is_available():
        model.to("cuda")
    model.eval()
    losses, recon_losses = [], []
    pbar = make_progress_bar(True, "[red]loss: {task.fields[loss]:.4f} repeat: {task.fields[r]}")
    with pbar:
        task_id = pbar.add_task("Test", visible=True, total=n_repeats * len(dataloader), loss=math.nan, r=0)
        for r in range(n_repeats):
            _losses, _recon_losses = [], []
            for eval_batch in dataloader:
                eval_batch = eval_batch.to("cuda") if torch.cuda.is_available() else eval_batch
                loss = model(eval_batch, n_steps=n_steps).item()
                recon_loss = model.compute_reconstruction_loss(eval_batch).item()
                _losses.append(loss)
                _recon_losses.append(recon_loss)
                pbar.update(task_id, advance=1, loss=torch.tensor(_losses).mean() + torch.tensor(_recon_losses).mean(), r=r+1)
            losses.append(torch.tensor(_losses).mean())
            recon_losses.append(torch.tensor(_recon_losses).mean())
    losses = torch.stack(losses)
    loss_mean, loss_err = losses.mean(), losses.std(correction=0).item() / math.sqrt(len(losses))
    recon_losses = torch.stack(recon_losses)
    recon_mean, recon_err = recon_losses.mean(), recon_losses.std(correction=0).item() / math.sqrt(len(recon_losses))
    return loss_mean, loss_err, recon_mean, recon_err


def main(cfg: DictConfig) -> tuple[float, float, float, float]:
    """
    Config entries:
        seed (int): Optional
        config_file (str): Name of config file containing model and data config for a saved checkpoint
        load_model (str): Path to a saved checkpoint to be tested
        n_steps (int): Number of Bayesian flow steps. Set to None for continuous time Bayesian flow loss.
        n_repeats (int): Number of times to iterate through the dataset.
    """
    seed_everything(cfg.seed)
    print(f"Seeded everything with seed {cfg.seed}")

    # Get model and data config from the training config file
    train_cfg = make_config(cfg.config_file)
    model, dataloader = setup(train_cfg)

    model.load_state_dict(torch.load(cfg.load_model, weights_only=True, map_location="cpu"))
    loss_mean, loss_err, recon_mean, recon_err = test(model, dataloader, cfg.n_steps, cfg.n_repeats)
    print(f"For {cfg.n_steps} steps with {cfg.n_repeats} repeats:")
    print(f"Loss is {loss_mean:.6f} +- {loss_err:.6f}")
    print(f"Reconstruction Loss is {recon_mean:.6f} +- {recon_err:.6f}")
    print(f"Total loss mean = {loss_mean + recon_mean}")
    return loss_mean, loss_err, recon_mean, recon_err


if __name__ == "__main__":
    main(OmegaConf.from_cli())


================================================
FILE: train.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Optional, Tuple

import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from omegaconf import OmegaConf
from rich.logging import RichHandler
from rich.progress import Progress
from torch import nn, optim
from torch.utils.data import DataLoader

from model import BFN
from utils_train import (
    seed_everything, log_cfg,
    checkpoint_training_state,
    init_checkpointing,
    log,
    update_ema,
    ddict,
    make_infinite,
    make_progress_bar, make_config, make_dataloaders, make_bfn,
)

torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True

logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(rich_tracebacks=True, show_time=False)],
)

logger = get_logger(__name__)


def setup(cfg) -> Tuple[nn.Module, dict, optim.Optimizer]:
    """Create the model, dataloader and optimizer"""
    dataloaders = make_dataloaders(cfg)
    model = make_bfn(cfg.model)
    if "weight_decay" in cfg.optimizer.keys() and hasattr(model.net, "get_optim_groups"):
        params = model.net.get_optim_groups(cfg.optimizer.weight_decay)
    else:
        params = model.net.parameters()
    # Instantiate the optimizer using the hyper-parameters in the config
    optimizer = optim.AdamW(params=params, **cfg.optimizer)
    return model, dataloaders, optimizer


@torch.no_grad()
def validate(
        cfg,
        model: BFN,
        ema_model: nn.Module,
        val_dataloader: DataLoader,
        step: int,
        run: "neptune.Run",
        pbar: Optional[Progress],
        best_val_loss: float,
        checkpoint_root_dir: Optional[Path],
        accelerator: Accelerator,
) -> float:
    """Evaluate model on validation data and save checkpoint if loss improves"""
    dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[accelerator.mixed_precision]
    model_to_eval = ema_model if ema_model is not None else model
    model_to_eval.eval()
    pbar = pbar or Progress()
    max_steps = cfg.max_val_batches if cfg.max_val_batches > 0 else len(val_dataloader)
    val_id = pbar.add_task("Validating", visible=True, total=cfg.val_repeats * max_steps, transient=True, loss=math.nan)

    loss, count = 0.0, 0
    for i in range(cfg.val_repeats):
        for idx, eval_batch in enumerate(val_dataloader):
            enabled = True if dtype in [torch.float16, torch.bfloat16] else False
            with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):
                loss += model_to_eval(eval_batch.to(accelerator.device)).item()
                count += 1
            pbar.update(val_id, advance=1, loss=loss / count)
            if (idx + 1) >= max_steps:
                break
    loss /= count
    pbar.remove_task(val_id)
    log(run["metrics"]["val"]["loss"], loss, step)

    if checkpoint_root_dir is not None and (loss < best_val_loss or math.isinf(best_val_loss)):
        logger.info(f"loss improved: new value is {loss}")
        step_checkpoint_path = checkpoint_root_dir / "best"
        run_id = "BFN" if isinstance(run, defaultdict) else run["sys"]["id"].fetch()
        checkpoint_training_state(step_checkpoint_path, accelerator, ema_model, step, run_id)
        run["metrics/best/loss/metric"] = loss
        run["metrics/best/loss/step"] = step

    model.train()
    return loss


def train(
        cfg,
        accelerator: Accelerator,
        model: BFN,
        ema_model: Optional[nn.Module],
        dataloaders: dict,
        optimizer: optim.Optimizer,
        run: "neptune.Run",
):
    is_main = accelerator.is_main_process
    pbar = make_progress_bar(is_main)
    run_id = "BFN" if isinstance(run, defaultdict) else run["sys"]["id"].fetch()
    train_id = pbar.add_task(f"Training {run_id}", start=cfg.start_step, total=cfg.n_training_steps, loss=math.nan)
    checkpoint_root_dir = init_checkpointing(cfg.checkpoint_dir, run_id) if is_main else None
    best_val_loss = math.inf

    train_iter = make_infinite(dataloaders["train"])
    model.train()
    with pbar:
        for step in range(cfg.start_step, cfg.n_training_steps + 1):
            step_loss = 0.0
            for _ in range(cfg.accumulate):
                with accelerator.accumulate(model):
                    train_batch = next(train_iter)

                    loss = model(train_batch)
                    accelerator.backward(loss)

                    if accelerator.sync_gradients and cfg.grad_clip_norm > 0:
                        accelerator.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)

                step_loss += loss.item()

            update_ema(ema_model, model, cfg.ema_decay)

            if is_main and (step % cfg.checkpoint_interval == 0):
                checkpoint_training_state(checkpoint_root_dir / "last", accelerator, ema_model, step, run_id)
                run["checkpoints/last"].track_files(str(checkpoint_root_dir / "last"))

            log(run["metrics"]["train"]["loss"], step_loss / cfg.accumulate, step, is_main and step % cfg.log_interval == 0)
            log(run["metrics"]["epoch"], step // len(dataloaders["train"]), step, is_main)

            if is_main and (step % cfg.val_interval == 0) and "val" in dataloaders:
                val_loss = validate(
                    cfg=cfg,
                    model=model,
                    ema_model=ema_model,
                    val_dataloader=dataloaders["val"],
                    step=step,
                    run=run,
                    pbar=pbar,
                    best_val_loss=best_val_loss,
                    checkpoint_root_dir=checkpoint_root_dir,
                    accelerator=accelerator,
                )
                best_val_loss = min(val_loss, best_val_loss)

            pbar.update(train_id, advance=1, loss=loss.item())


def main(cfg):
    acc = Accelerator(gradient_accumulation_steps=cfg.training.accumulate)

    seed_everything(cfg.training.seed)
    logger.info(f"Seeded everything with seed {cfg.training.seed}", main_process_only=True)

    with acc.main_process_first():
        model, dataloaders, optimizer = setup(cfg)
    ema = copy.deepcopy(model) if acc.is_main_process and cfg.training.ema_decay > 0 else None  # EMA on main proc only
    model, optimizer, dataloaders["train"] = acc.prepare(model, optimizer, dataloaders["train"])
    run = ddict()
    if acc.is_main_process:
        ema.to(acc.device)
        try:
            if cfg.meta.neptune:
                import neptune
                run = neptune.init_run(project=cfg.meta.neptune, mode="debug" if cfg.meta.debug else None)
                run["accelerate"] = dict(amp=acc.mixed_precision, nproc=acc.num_processes)
                log_cfg(cfg, run)
        except ImportError:
            logger.info("Did not find neptune installed. Logging will be disabled.")

    train(cfg.training, acc, model, ema, dataloaders, optimizer, run)


if __name__ == "__main__":
    cfg_file = OmegaConf.from_cli()['config_file']
    main(make_config(cfg_file))


================================================
FILE: utils_model.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import torch
from torch import Tensor

CONST_log_range = 20
CONST_log_min = 1e-10
CONST_summary_rescale = 10
CONST_exp_range = 10
CONST_min_std_dev = math.exp(-CONST_exp_range)


def sandwich(x: Tensor):
    return x.reshape(x.size(0), -1, x.size(-1))


def safe_log(data: Tensor):
    return data.clamp(min=CONST_log_min).log()


def safe_exp(data: Tensor):
    return data.clamp(min=-CONST_exp_range, max=CONST_exp_range).exp()


def idx_to_float(idx: np.ndarray, num_bins: int):
    flt_zero_one = (idx + 0.5) / num_bins
    return (2.0 * flt_zero_one) - 1.0


def float_to_idx(flt: np.ndarray, num_bins: int):
    flt_zero_one = (flt / 2.0) + 0.5
    return torch.clamp(torch.floor(flt_zero_one * num_bins), min=0, max=num_bins - 1).long()


def quantize(flt, num_bins: int):
    return idx_to_float(float_to_idx(flt, num_bins), num_bins)


def pe_encode(sequence_length: int, embedding_size: int) -> Tensor:
    """Positional encoding as described in original attention is all you need paper"""

    pe = torch.zeros((sequence_length, embedding_size))
    pos = torch.arange(sequence_length).unsqueeze(1)
    pe[:, 0::2] = torch.sin(
        pos / torch.pow(1000, torch.arange(0, embedding_size, 2, dtype=torch.float32) / embedding_size)
    )
    pe[:, 1::2] = torch.cos(
        pos / torch.pow(1000, torch.arange(1, embedding_size, 2, dtype=torch.float32) / embedding_size)
    )

    return pe


def pe_encode_float(x: Tensor, max_freq: float, embedding_size: int) -> Tensor:
    pe = torch.zeros(list(x.shape) + [embedding_size], device=x.device)
    pos = (((x + 1) / 2) * max_freq).unsqueeze(-1)
    pe[..., 0::2] = torch.sin(
        pos
        / torch.pow(10000, torch.arange(0, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size)
    )
    pe[..., 1::2] = torch.cos(
        pos
        / torch.pow(10000, torch.arange(1, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size)
    )
    return pe


================================================
FILE: utils_train.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
import random
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Optional, Generator, Union

try:
    import neptune
    from neptune.utils import stringify_unsupported
except ImportError:
    neptune = None

    def stringify_unsupported(x):
        return x


import numpy as np
import torch
from accelerate.logging import get_logger
from omegaconf import OmegaConf, DictConfig
from rich.progress import Progress, SpinnerColumn, MofNCompleteColumn, TimeElapsedColumn, TextColumn
from torch.utils.data import DataLoader

import model
import networks
import probability
from data import make_datasets
from networks import adapters

logger = get_logger(__name__)


def seed_everything(seed: Optional[int]):
    assert seed is not None
    seed += torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def worker_init_function(worker_id: int) -> None:
    """https://pytorch.org/docs/stable/notes/randomness.html#dataloader"""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def init_checkpointing(checkpoint_dir: Union[str, Path, None], run_id: str) -> Optional[Path]:
    if checkpoint_dir is None:
        return None
    checkpoint_dir = Path(checkpoint_dir) / run_id
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    last_dir = checkpoint_dir / "last"
    last_dir.mkdir(parents=True, exist_ok=True)
    best_dir = checkpoint_dir / "best"
    best_dir.mkdir(parents=True, exist_ok=True)
    return checkpoint_dir


def checkpoint_training_state(checkpoint_dir, accelerator, ema_model, step: int, run_id: str):
    if checkpoint_dir is None:
        return
    logger.info(f"Checkpointing training state to {checkpoint_dir} at step {step}")
    accelerator.save_state(checkpoint_dir)
    with open(checkpoint_dir / "info.json", "w") as f:
        json.dump({"step": step, "run_id": run_id}, f)
    if ema_model is not None:
        ema_checkpoint_path = checkpoint_dir / "ema_model.pt"
        torch.save(ema_model.state_dict(), ema_checkpoint_path)


def log(key_handler, value, step, cond=True):
    """Log series to neptune only if cond is True. Helps with distributed training and conditional logging."""
    if not isinstance(key_handler, defaultdict) and cond and math.isfinite(value):
        key_handler.log(value, step=step)


def log_cfg(cfg, run: "neptune.Run"):
    with tempfile.TemporaryDirectory() as tmpdir:
        cfg_temp_filename: Path = Path(tmpdir) / "cfg.yaml"
        cfg_temp_filename.write_text(OmegaConf.to_yaml(cfg, resolve=True))
        run["cfg"].upload(str(cfg_temp_filename), wait=True)
    run["hyperparameters"] = stringify_unsupported(OmegaConf.to_container(cfg, resolve=True))


@torch.no_grad()
def update_ema(ema_model, model, ema_decay):
    if ema_model is not None and ema_decay > 0:
        for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):
            ema_param.sub_((1 - ema_decay) * (ema_param - model_param))


def ddict():
    """Infinite default dict to fake neptune run on non-main processes"""
    return defaultdict(ddict)


def make_infinite(dataloader: DataLoader) -> Generator[dict, None, None]:
    while True:
        for data in dataloader:
            yield data


def make_progress_bar(is_main: bool, text="[red]loss: {task.fields[loss]:.3f}"):
    return Progress(
        SpinnerColumn(),
        MofNCompleteColumn(),
        *Progress.get_default_columns(),
        TimeElapsedColumn(),
        TextColumn(text),
        disable=not is_main,
    )


def make_dataloaders(cfg: DictConfig):
    train_set, val_set, _ = make_datasets(cfg.data)
    dataloaders = {
        "train": DataLoader(
            dataset=train_set,
            worker_init_fn=worker_init_function,
            **cfg.train_loader,
        ),
        "val": DataLoader(
            dataset=val_set,
            worker_init_fn=worker_init_function,
            **cfg.val_loader,
        ),
    }
    return dataloaders


def make_from_cfg(module, cfg, **parameters):
    return getattr(module, cfg.class_name)(**cfg.parameters, **parameters) if cfg is not None else None


def make_bfn(cfg: DictConfig):
    data_adapters = {
        "input_adapter": make_from_cfg(adapters, cfg.input_adapter),
        "output_adapter": make_from_cfg(adapters, cfg.output_adapter),
    }
    net = make_from_cfg(networks, cfg.net, data_adapters=data_adapters)
    bayesian_flow = make_from_cfg(model, cfg.bayesian_flow)
    distribution_factory = make_from_cfg(probability, cfg.distribution_factory)
    loss = make_from_cfg(model, cfg.loss, bayesian_flow=bayesian_flow, distribution_factory=distribution_factory)
    bfn = model.BFN(net=net, bayesian_flow=bayesian_flow, loss=loss)
    return bfn


default_train_config = {
    "meta": {
        "neptune": None,
        "debug": False,
        "root_dir": ".",
    },
    "data": {
        "dataset": "",
        "data_dir": "./data",
    },
    "train_loader": {
        "batch_size": 1,
        "shuffle": True,
        "num_workers": 0,
        "pin_memory": True,
        "drop_last": True,
    },
    "val_loader": {
        "batch_size": 1,
        "shuffle": False,
        "num_workers": 0,
        "pin_memory": True,
        "drop_last": False,
    },
    "training": {
        "accumulate": 1,
        "checkpoint_dir": "./checkpoints",
        "checkpoint_interval": None,
        "ema_decay": -1,
        "grad_clip_norm": -1,
        "log_interval": 50,
        "max_val_batches": -1,
        "seed": 666,
        "start_step": 1,
        "val_repeats": 1,
    },
}


def make_config(cfg_file: str):
    cli_conf = OmegaConf.load(cfg_file)
    # Start with default config
    cfg = OmegaConf.create(default_train_config)
    # Merge into default config
    cfg = OmegaConf.merge(cfg, cli_conf)
    return cfg
Download .txt
gitextract__0riu0_z/

├── .gitignore
├── LICENSE
├── README.md
├── configs/
│   ├── cifar10_continuous_16bins.yaml
│   ├── cifar10_continuous_256bins.yaml
│   ├── cifar10_discretized_16bins.yaml
│   ├── cifar10_discretized_256bins.yaml
│   ├── mnist_discrete.yaml
│   └── text8_discrete.yaml
├── data.py
├── env.yml
├── model.py
├── networks/
│   ├── __init__.py
│   ├── adapters.py
│   ├── transformer.py
│   ├── unet_improved.py
│   └── unet_vdm.py
├── probability.py
├── sample.py
├── test.py
├── train.py
├── utils_model.py
└── utils_train.py
Download .txt
SYMBOL INDEX (288 symbols across 12 files)

FILE: data.py
  function bin_mnist_transform (line 37) | def bin_mnist_transform(x):
  function bin_mnist_cts_transform (line 41) | def bin_mnist_cts_transform(x):
  function rgb_image_transform (line 45) | def rgb_image_transform(x, num_bins=256):
  class MyLambda (line 49) | class MyLambda(torchvision.transforms.Lambda):
    method __init__ (line 50) | def __init__(self, lambd, arg1):
    method __call__ (line 54) | def __call__(self, x):
  class CIFAR10 (line 58) | class CIFAR10(torchvision.datasets.CIFAR10):
    method __getitem__ (line 59) | def __getitem__(self, idx):
  class MNIST (line 63) | class MNIST(torchvision.datasets.MNIST):
    method __getitem__ (line 64) | def __getitem__(self, idx):
  function make_datasets (line 68) | def make_datasets(cfg: DictConfig) -> tuple[Dataset, Dataset, Dataset]:
  function prepare_text8 (line 127) | def prepare_text8(data_dir: pathlib.Path):
  class Text8Dataset (line 185) | class Text8Dataset(Dataset):
    method __init__ (line 186) | def __init__(self, data_dir: Union[str, pathlib.Path], split: str, dow...
    method __getitem__ (line 204) | def __getitem__(self, index) -> torch.Tensor:
    method __len__ (line 208) | def __len__(self):
  function char_ids_to_str (line 212) | def char_ids_to_str(char_ids: Union[list[int], np.array, torch.Tensor]) ...
  function batch_to_str (line 217) | def batch_to_str(text_batch: Union[list[list], np.array, torch.Tensor]) ...
  function batch_to_images (line 222) | def batch_to_images(image_batch: torch.Tensor, ncols: int = None) -> plt...

FILE: model.py
  class BayesianFlow (line 42) | class BayesianFlow(nn.Module, ABC):
    method __init__ (line 43) | def __init__(self):
    method get_prior_input_params (line 47) | def get_prior_input_params(self, data_shape: tuple, device: torch.devi...
    method params_to_net_inputs (line 54) | def params_to_net_inputs(self, params: tuple[Tensor, ...]) -> Tensor:
    method get_alpha (line 59) | def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> float:
    method get_sender_dist (line 66) | def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shap...
    method update_input_params (line 73) | def update_input_params(self, input_params: tuple[Tensor, ...], y: Ten...
    method forward (line 79) | def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, ...]:
  class Loss (line 87) | class Loss(nn.Module, ABC):
    method __init__ (line 88) | def __init__(self):
    method cts_time_loss (line 92) | def cts_time_loss(self, data: Tensor, output_params: Tensor, input_par...
    method discrete_time_loss (line 98) | def discrete_time_loss(
    method reconstruction_loss (line 107) | def reconstruction_loss(self, data: Tensor, output_params: Tensor, inp...
  class CtsBayesianFlow (line 116) | class CtsBayesianFlow(BayesianFlow):
    method __init__ (line 117) | def __init__(
    method forward (line 125) | def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, None]:
    method params_to_net_inputs (line 137) | def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
    method get_prior_input_params (line 140) | def get_prior_input_params(self, data_shape: tuple, device: torch.devi...
    method get_alpha (line 143) | def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[floa...
    method get_sender_dist (line 147) | def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shap...
    method update_input_params (line 151) | def update_input_params(self, input_params: tuple[Tensor, float], y: T...
  class CtsBayesianFlowLoss (line 158) | class CtsBayesianFlowLoss(Loss):
    method __init__ (line 159) | def __init__(
    method cts_time_loss (line 178) | def cts_time_loss(self, data: Tensor, output_params: Tensor, input_par...
    method discrete_time_loss (line 191) | def discrete_time_loss(
    method reconstruction_loss (line 225) | def reconstruction_loss(self, data: Tensor, output_params: Tensor, inp...
  class DiscreteBayesianFlow (line 250) | class DiscreteBayesianFlow(BayesianFlow):
    method __init__ (line 251) | def __init__(
    method t_to_sqrt_beta (line 267) | def t_to_sqrt_beta(self, t):
    method count_dist (line 270) | def count_dist(self, x, beta=None):
    method count_sample (line 278) | def count_sample(self, x, beta):
    method get_prior_input_params (line 282) | def get_prior_input_params(self, data_shape: tuple, device: torch.devi...
    method params_to_net_inputs (line 286) | def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
    method get_alpha (line 293) | def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[floa...
    method get_sender_dist (line 296) | def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shap...
    method update_input_params (line 302) | def update_input_params(self, input_params: tuple[Tensor], y: Tensor, ...
    method forward (line 308) | def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor]:
  class DiscreteBayesianFlowLoss (line 325) | class DiscreteBayesianFlowLoss(Loss):
    method __init__ (line 326) | def __init__(
    method cts_time_loss (line 336) | def cts_time_loss(self, data: Tensor, output_params: Tensor, input_par...
    method discrete_time_loss (line 348) | def discrete_time_loss(
    method reconstruction_loss (line 369) | def reconstruction_loss(self, data: Tensor, output_params: Tensor, inp...
  class BFN (line 376) | class BFN(nn.Module):
    method __init__ (line 377) | def __init__(self, net: nn.Module, bayesian_flow: BayesianFlow, loss: ...
    method sample_t (line 385) | def sample_t(data: Tensor, n_steps: Optional[int]) -> Tensor:
    method forward (line 393) | def forward(
    method compute_reconstruction_loss (line 420) | def compute_reconstruction_loss(self, data: Tensor) -> Tensor:
    method sample (line 428) | def sample(self, data_shape: tuple, n_steps: int) -> Tensor:

FILE: networks/adapters.py
  class TextInputAdapter (line 25) | class TextInputAdapter(nn.Module):
    method __init__ (line 30) | def __init__(
    method forward (line 46) | def forward(self, probs: torch.Tensor, t: torch.Tensor) -> Tensor:
  class FourierImageInputAdapter (line 61) | class FourierImageInputAdapter(nn.Module):
    method __init__ (line 66) | def __init__(
    method forward (line 122) | def forward(self, img: Tensor, t: Tensor) -> Tensor:
  class OutputAdapter (line 150) | class OutputAdapter(nn.Module):
    method __init__ (line 151) | def __init__(self, input_height: int, output_channels: int, output_hei...
    method forward (line 159) | def forward(self, inp: torch.Tensor) -> torch.Tensor:

FILE: networks/transformer.py
  function gelu (line 37) | def gelu(x):
  class LayerNorm (line 41) | class LayerNorm(nn.Module):
    method __init__ (line 44) | def __init__(self, ndim, bias):
    method forward (line 49) | def forward(self, input):
  class SelfAttention (line 53) | class SelfAttention(nn.Module):
    method __init__ (line 54) | def __init__(self, n_head, n_embd, dropout, bias, is_causal):
    method forward (line 72) | def forward(self, x):
  class MLP (line 92) | class MLP(nn.Module):
    method __init__ (line 93) | def __init__(self, n_embd, dropout, bias):
    method forward (line 99) | def forward(self, x):
  class Block (line 107) | class Block(nn.Module):
    method __init__ (line 108) | def __init__(self, n_head, n_embd, dropout, bias, is_causal):
    method forward (line 115) | def forward(self, x):
  class GPT (line 121) | class GPT(nn.Module):
    method __init__ (line 122) | def __init__(
    method _init_weights (line 169) | def _init_weights(self, module):
    method forward (line 177) | def forward(self, data: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    method get_optim_groups (line 188) | def get_optim_groups(self, weight_decay: float):

FILE: networks/unet_improved.py
  function convert_module_to_f16 (line 48) | def convert_module_to_f16(module):
  function convert_module_to_f32 (line 57) | def convert_module_to_f32(module):
  function make_master_params (line 66) | def make_master_params(model_params):
  function model_grads_to_master_grads (line 77) | def model_grads_to_master_grads(model_params, master_params):
  function master_params_to_model_params (line 85) | def master_params_to_model_params(model_params, master_params):
  function unflatten_master_params (line 97) | def unflatten_master_params(model_params, master_params):
  function zero_grad (line 104) | def zero_grad(model_params):
  class SiLU (line 113) | class SiLU(nn.Module):
    method forward (line 114) | def forward(self, x):
  class GroupNorm32 (line 118) | class GroupNorm32(nn.GroupNorm):
    method forward (line 119) | def forward(self, x):
  function conv_nd (line 123) | def conv_nd(dims, *args, **kwargs):
  function linear (line 136) | def linear(*args, **kwargs):
  function avg_pool_nd (line 143) | def avg_pool_nd(dims, *args, **kwargs):
  function update_ema (line 156) | def update_ema(target_params, source_params, rate=0.99):
  function zero_module (line 169) | def zero_module(module):
  function scale_module (line 178) | def scale_module(module, scale):
  function mean_flat (line 187) | def mean_flat(tensor):
  function normalization (line 194) | def normalization(channels):
  function timestep_embedding (line 204) | def timestep_embedding(timesteps, dim, max_period=10000):
  function checkpoint (line 225) | def checkpoint(func, inputs, params, flag):
  class CheckpointFunction (line 243) | class CheckpointFunction(th.autograd.Function):
    method forward (line 245) | def forward(ctx, run_function, length, *args):
    method backward (line 254) | def backward(ctx, *output_grads):
  class TimestepBlock (line 274) | class TimestepBlock(nn.Module):
    method forward (line 280) | def forward(self, x, emb):
  class TimestepEmbedSequential (line 286) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    method forward (line 292) | def forward(self, x, emb):
  class Upsample (line 301) | class Upsample(nn.Module):
    method __init__ (line 311) | def __init__(self, channels, use_conv, dims=2):
    method forward (line 319) | def forward(self, x):
  class Downsample (line 330) | class Downsample(nn.Module):
    method __init__ (line 340) | def __init__(self, channels, use_conv, dims=2):
    method forward (line 351) | def forward(self, x):
  class ResBlock (line 356) | class ResBlock(TimestepBlock):
    method __init__ (line 371) | def __init__(
    method forward (line 417) | def forward(self, x, emb):
    method _forward (line 427) | def _forward(self, x, emb):
  class AttentionBlock (line 443) | class AttentionBlock(nn.Module):
    method __init__ (line 451) | def __init__(self, channels, num_heads=1, use_checkpoint=False):
    method forward (line 462) | def forward(self, x):
    method _forward (line 465) | def _forward(self, x):
  class QKVAttention (line 476) | class QKVAttention(nn.Module):
    method forward (line 481) | def forward(self, qkv):
    method count_flops (line 496) | def count_flops(model, _x, y):
  class UNetModel (line 519) | class UNetModel(nn.Module):
    method __init__ (line 542) | def __init__(
    method convert_to_fp16 (line 682) | def convert_to_fp16(self):
    method convert_to_fp32 (line 690) | def convert_to_fp32(self):
    method inner_dtype (line 699) | def inner_dtype(self):
    method forward (line 705) | def forward(
    method get_feature_vectors (line 751) | def get_feature_vectors(self, x, timesteps, y=None):

FILE: networks/unet_vdm.py
  function zero_init (line 39) | def zero_init(module: nn.Module) -> nn.Module:
  class UNetVDM (line 46) | class UNetVDM(nn.Module):
    method __init__ (line 47) | def __init__(
    method forward (line 121) | def forward(
    method maybe_concat_fourier (line 152) | def maybe_concat_fourier(self, z):
  class ResnetBlock (line 158) | class ResnetBlock(nn.Module):
    method __init__ (line 159) | def __init__(
    method forward (line 187) | def forward(self, x, condition):
  function get_timestep_embedding (line 201) | def get_timestep_embedding(
  class FourierFeatures (line 223) | class FourierFeatures(nn.Module):
    method __init__ (line 224) | def __init__(self, first=5.0, last=6.0, step=1.0):
    method num_features (line 229) | def num_features(self):
    method forward (line 232) | def forward(self, x):
  function attention_inner_heads (line 248) | def attention_inner_heads(qkv, num_heads):
  class Attention (line 285) | class Attention(nn.Module):
    method __init__ (line 288) | def __init__(self, n_heads):
    method forward (line 292) | def forward(self, qkv):
  class AttentionBlock (line 301) | class AttentionBlock(nn.Module):
    method __init__ (line 304) | def __init__(self, n_heads, n_channels, norm_groups):
    method forward (line 314) | def forward(self, x):
  class UpDownBlock (line 318) | class UpDownBlock(nn.Module):
    method __init__ (line 319) | def __init__(self, resnet_block, attention_block=None):
    method forward (line 324) | def forward(self, x, cond):

FILE: probability.py
  class CtsDistribution (line 36) | class CtsDistribution:
    method log_prob (line 38) | def log_prob(self, x):
    method sample (line 42) | def sample(self):
  class DiscreteDistribution (line 46) | class DiscreteDistribution:
    method probs (line 49) | def probs(self):
    method log_probs (line 53) | def log_probs(self):
    method mean (line 57) | def mean(self):
    method mode (line 61) | def mode(self):
    method log_prob (line 65) | def log_prob(self, x):
    method sample (line 69) | def sample(self):
  class DiscretizedDistribution (line 73) | class DiscretizedDistribution(DiscreteDistribution):
    method __init__ (line 74) | def __init__(self, num_bins, device):
    method class_centres (line 81) | def class_centres(self):
    method class_boundaries (line 85) | def class_boundaries(self):
    method mean (line 89) | def mean(self):
    method mode (line 93) | def mode(self):
  class DiscretizedCtsDistribution (line 98) | class DiscretizedCtsDistribution(DiscretizedDistribution):
    method __init__ (line 99) | def __init__(self, cts_dist, num_bins, device, batch_dims, clip=True, ...
    method probs (line 108) | def probs(self):
    method prob (line 127) | def prob(self, x):
    method log_prob (line 145) | def log_prob(self, x):
    method sample (line 153) | def sample(self, sample_shape=torch.Size([])):
  class GMM (line 165) | class GMM(MixtureSameFamily):
    method __init__ (line 166) | def __init__(self, mix_wt_logits, means, std_devs):
  class DiscretizedGMM (line 172) | class DiscretizedGMM(DiscretizedCtsDistribution):
    method __init__ (line 173) | def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max...
  class DiscretizedNormal (line 191) | class DiscretizedNormal(DiscretizedCtsDistribution):
    method __init__ (line 192) | def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max...
  class Bernoulli (line 210) | class Bernoulli(DiscreteDistribution):
    method __init__ (line 211) | def __init__(self, logits):
    method probs (line 215) | def probs(self):
    method mode (line 220) | def mode(self):
    method log_prob (line 223) | def log_prob(self, x):
    method sample (line 226) | def sample(self, sample_shape=torch.Size([])):
  class DiscretizedBernoulli (line 230) | class DiscretizedBernoulli(DiscretizedDistribution):
    method __init__ (line 231) | def __init__(self, logits):
    method probs (line 236) | def probs(self):
    method mode (line 241) | def mode(self):
    method log_prob (line 244) | def log_prob(self, x):
    method sample (line 247) | def sample(self, sample_shape=torch.Size([])):
  class DeltaDistribution (line 251) | class DeltaDistribution(CtsDistribution):
    method __init__ (line 252) | def __init__(self, mean, clip_range=1.0):
    method mode (line 258) | def mode(self):
    method mean (line 262) | def mean(self):
    method sample (line 265) | def sample(self, sample_shape=torch.Size([])):
  class Categorical (line 269) | class Categorical(DiscreteDistribution):
    method __init__ (line 270) | def __init__(self, logits):
    method probs (line 275) | def probs(self):
    method mode (line 279) | def mode(self):
    method log_prob (line 282) | def log_prob(self, x):
    method sample (line 285) | def sample(self, sample_shape=torch.Size([])):
  class DiscretizedCategorical (line 289) | class DiscretizedCategorical(DiscretizedDistribution):
    method __init__ (line 290) | def __init__(self, logits=None, probs=None):
    method probs (line 300) | def probs(self):
    method mode (line 304) | def mode(self):
    method log_prob (line 307) | def log_prob(self, x):
    method sample (line 310) | def sample(self, sample_shape=torch.Size([])):
  class CtsDistributionFactory (line 314) | class CtsDistributionFactory:
    method get_dist (line 316) | def get_dist(self, params: torch.Tensor, input_params=None, t=None) ->...
  class GMMFactory (line 321) | class GMMFactory(CtsDistributionFactory):
    method __init__ (line 322) | def __init__(self, min_std_dev=1e-3, max_std_dev=10, log_dev=True):
    method get_dist (line 327) | def get_dist(self, params, input_params=None, t=None):
  class NormalFactory (line 335) | class NormalFactory(CtsDistributionFactory):
    method __init__ (line 336) | def __init__(self, min_std_dev=1e-3, max_std_dev=10):
    method get_dist (line 340) | def get_dist(self, params, input_params=None, t=None):
  class DeltaFactory (line 346) | class DeltaFactory(CtsDistributionFactory):
    method __init__ (line 347) | def __init__(self, clip_range=1.0):
    method get_dist (line 350) | def get_dist(self, params, input_params=None, t=None):
  class DiscreteDistributionFactory (line 354) | class DiscreteDistributionFactory:
    method get_dist (line 356) | def get_dist(self, params: torch.Tensor, input_params=None, t=None) ->...
  class BernoulliFactory (line 361) | class BernoulliFactory(DiscreteDistributionFactory):
    method get_dist (line 362) | def get_dist(self, params, input_params=None, t=None):
  class CategoricalFactory (line 366) | class CategoricalFactory(DiscreteDistributionFactory):
    method get_dist (line 367) | def get_dist(self, params, input_params=None, t=None):
  class DiscretizedBernoulliFactory (line 371) | class DiscretizedBernoulliFactory(DiscreteDistributionFactory):
    method get_dist (line 372) | def get_dist(self, params, input_params=None, t=None):
  class DiscretizedCategoricalFactory (line 376) | class DiscretizedCategoricalFactory(DiscreteDistributionFactory):
    method get_dist (line 377) | def get_dist(self, params, input_params=None, t=None):
  class DiscretizedGMMFactory (line 381) | class DiscretizedGMMFactory(DiscreteDistributionFactory):
    method __init__ (line 382) | def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=...
    method get_dist (line 390) | def get_dist(self, params, input_params=None, t=None):
  class DiscretizedNormalFactory (line 402) | class DiscretizedNormalFactory(DiscreteDistributionFactory):
    method __init__ (line 403) | def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=...
    method get_dist (line 411) | def get_dist(self, params, input_params=None, t=None):
  function noise_pred_params_to_data_pred_params (line 423) | def noise_pred_params_to_data_pred_params(noise_pred_params: torch.Tenso...
  class PredDistToDataDistFactory (line 459) | class PredDistToDataDistFactory(DiscreteDistributionFactory):
    method __init__ (line 460) | def __init__(self, data_dist_factory, min_variance, min_t=1e-6):
    method get_dist (line 466) | def get_dist(self, params, input_params, t):

FILE: sample.py
  function main (line 24) | def main(cfg: DictConfig) -> torch.Tensor:

FILE: test.py
  function setup (line 32) | def setup(cfg: DictConfig) -> Tuple[nn.Module, DataLoader]:
  function test (line 47) | def test(model: BFN, dataloader: DataLoader, n_steps: int, n_repeats: in...
  function main (line 73) | def main(cfg: DictConfig) -> tuple[float, float, float, float]:

FILE: train.py
  function setup (line 56) | def setup(cfg) -> Tuple[nn.Module, dict, optim.Optimizer]:
  function validate (line 70) | def validate(
  function train (line 116) | def train(
  function main (line 178) | def main(cfg):

FILE: utils_model.py
  function sandwich (line 28) | def sandwich(x: Tensor):
  function safe_log (line 32) | def safe_log(data: Tensor):
  function safe_exp (line 36) | def safe_exp(data: Tensor):
  function idx_to_float (line 40) | def idx_to_float(idx: np.ndarray, num_bins: int):
  function float_to_idx (line 45) | def float_to_idx(flt: np.ndarray, num_bins: int):
  function quantize (line 50) | def quantize(flt, num_bins: int):
  function pe_encode (line 54) | def pe_encode(sequence_length: int, embedding_size: int) -> Tensor:
  function pe_encode_float (line 69) | def pe_encode_float(x: Tensor, max_freq: float, embedding_size: int) -> ...

FILE: utils_train.py
  function stringify_unsupported (line 29) | def stringify_unsupported(x):
  function seed_everything (line 49) | def seed_everything(seed: Optional[int]):
  function worker_init_function (line 58) | def worker_init_function(worker_id: int) -> None:
  function init_checkpointing (line 65) | def init_checkpointing(checkpoint_dir: Union[str, Path, None], run_id: s...
  function checkpoint_training_state (line 77) | def checkpoint_training_state(checkpoint_dir, accelerator, ema_model, st...
  function log (line 89) | def log(key_handler, value, step, cond=True):
  function log_cfg (line 95) | def log_cfg(cfg, run: "neptune.Run"):
  function update_ema (line 104) | def update_ema(ema_model, model, ema_decay):
  function ddict (line 110) | def ddict():
  function make_infinite (line 115) | def make_infinite(dataloader: DataLoader) -> Generator[dict, None, None]:
  function make_progress_bar (line 121) | def make_progress_bar(is_main: bool, text="[red]loss: {task.fields[loss]...
  function make_dataloaders (line 132) | def make_dataloaders(cfg: DictConfig):
  function make_from_cfg (line 149) | def make_from_cfg(module, cfg, **parameters):
  function make_bfn (line 153) | def make_bfn(cfg: DictConfig):
  function make_config (line 205) | def make_config(cfg_file: str):
Condensed preview — 23 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (157K chars).
[
  {
    "path": ".gitignore",
    "chars": 2075,
    "preview": "# Data, checkpoints, logs\ndata\ncheckpoints\n.neptune\n\n# Files generated by setuptools_scm\n__version.py\n\n# MacOS\n.DS_Store"
  },
  {
    "path": "LICENSE",
    "chars": 10173,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 4630,
    "preview": "# Bayesian Flow Networks\n\nThis is the official code release for [Bayesian Flow Networks](https://arxiv.org/abs/2308.0703"
  },
  {
    "path": "configs/cifar10_continuous_16bins.yaml",
    "chars": 1442,
    "preview": "meta:\n  neptune: \n  debug: False\ndata:\n  dataset: \"cifar10\"\n  horizontal_flip: False\n  num_bins: 16\ntrain_loader:\n  batc"
  },
  {
    "path": "configs/cifar10_continuous_256bins.yaml",
    "chars": 1443,
    "preview": "meta:\n  neptune: \n  debug: False\ndata:\n  dataset: \"cifar10\"\n  horizontal_flip: False\n  num_bins: 256\ntrain_loader:\n  bat"
  },
  {
    "path": "configs/cifar10_discretized_16bins.yaml",
    "chars": 1500,
    "preview": "meta:\n  neptune: \n  debug: False\ndata:\n  dataset: \"cifar10\"\n  horizontal_flip: False\n  num_bins: 16\ntrain_loader:\n  batc"
  },
  {
    "path": "configs/cifar10_discretized_256bins.yaml",
    "chars": 1502,
    "preview": "meta:\n  neptune: \n  debug: False\ndata:\n  dataset: \"cifar10\"\n  horizontal_flip: False\n  num_bins: 256\ntrain_loader:\n  bat"
  },
  {
    "path": "configs/mnist_discrete.yaml",
    "chars": 1453,
    "preview": "meta:\n  neptune:\n  debug: False\ndata:\n  dataset: \"bin_mnist\"\ntrain_loader:\n  batch_size: 512\n  shuffle: True\n  num_worke"
  },
  {
    "path": "configs/text8_discrete.yaml",
    "chars": 1164,
    "preview": "meta:\n  neptune:\n  debug: False\ndata:\n  dataset: \"text8\"\n  seq_len: 256\ntrain_loader:\n  batch_size: 416\n  shuffle: True\n"
  },
  {
    "path": "data.py",
    "chars": 9324,
    "preview": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this"
  },
  {
    "path": "env.yml",
    "chars": 221,
    "preview": "name: bfn\nchannels:\n  - pytorch\n  - nvidia\ndependencies:\n  - python=3.9\n  - pytorch=2.0.0\n  - pytorch-cuda=11.8\n  - torc"
  },
  {
    "path": "model.py",
    "chars": 20519,
    "preview": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this"
  },
  {
    "path": "networks/__init__.py",
    "chars": 770,
    "preview": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this"
  },
  {
    "path": "networks/adapters.py",
    "chars": 5996,
    "preview": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this"
  },
  {
    "path": "networks/transformer.py",
    "chars": 9026,
    "preview": "# Source: https://github.com/karpathy/nanoGPT\n#\n# MIT License\n#\n# Copyright (c) 2022 Andrej Karpathy\n#\n# Permission is h"
  },
  {
    "path": "networks/unet_improved.py",
    "chars": 27147,
    "preview": "# Source: https://github.com/openai/improved-diffusion\n#\n# MIT License\n#\n# Copyright (c) 2021 OpenAI\n#\n# Permission is h"
  },
  {
    "path": "networks/unet_vdm.py",
    "chars": 11973,
    "preview": "# Source: https://github.com/addtt/variational-diffusion-models\n#\n# MIT License\n#\n# Copyright (c) 2022 Andrea Dittadi\n#\n"
  },
  {
    "path": "probability.py",
    "chars": 16818,
    "preview": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this"
  },
  {
    "path": "sample.py",
    "chars": 2096,
    "preview": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this"
  },
  {
    "path": "test.py",
    "chars": 4149,
    "preview": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this"
  },
  {
    "path": "train.py",
    "chars": 7840,
    "preview": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this"
  },
  {
    "path": "utils_model.py",
    "chars": 2565,
    "preview": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this"
  },
  {
    "path": "utils_train.py",
    "chars": 6573,
    "preview": "# Copyright 2023 NNAISENSE SA\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this"
  }
]

About this extraction

This page contains the full source code of the nnaisense/bayesian-flow-networks GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 23 files (146.9 KB), approximately 37.4k tokens, and a symbol index with 288 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!