Repository: nnaisense/bayesian-flow-networks
Branch: main
Commit: b62568e5d064
Files: 23
Total size: 146.9 KB
Directory structure:
gitextract__0riu0_z/
├── .gitignore
├── LICENSE
├── README.md
├── configs/
│ ├── cifar10_continuous_16bins.yaml
│ ├── cifar10_continuous_256bins.yaml
│ ├── cifar10_discretized_16bins.yaml
│ ├── cifar10_discretized_256bins.yaml
│ ├── mnist_discrete.yaml
│ └── text8_discrete.yaml
├── data.py
├── env.yml
├── model.py
├── networks/
│ ├── __init__.py
│ ├── adapters.py
│ ├── transformer.py
│ ├── unet_improved.py
│ └── unet_vdm.py
├── probability.py
├── sample.py
├── test.py
├── train.py
├── utils_model.py
└── utils_train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Data, checkpoints, logs
data
checkpoints
.neptune
# Files generated by setuptools_scm
__version.py
# MacOS
.DS_Store
# Visual Studio Code
.vscode/
*.code-workspace
.history/
# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# PyCharm
.idea/
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# End of https://www.gitignore.io/api/python
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
================================================
FILE: README.md
================================================
# Bayesian Flow Networks
This is the official code release for [Bayesian Flow Networks](https://arxiv.org/abs/2308.07037) by Alex Graves, Rupesh Kumar Srivastava, Timothy Atkinson and Faustino Gomez.
<img src="bfn.gif" alt="Overview of BFN process" style="width:600px;"/>
## Reading Guide
- `model.py` contains all the main contributions of the paper: definitions of the Bayesian Flows for both continuous and discrete data, together with the corresponding continuous-time and discrete-time loss functions. See the comments in the base classes in that file for details.
- `probability.py` defines the probability distributions used by the models.
- `train.py`, `test.py` and `sample.py` are scripts for training, testing and sampling (see below for usage).
- `data.py` contains utilities related to data loading and processing.
- `networks/` contains implementations of the network architectures used by the models.
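To make the reading guide concrete, here is a toy, stdlib-only sketch of the discrete-data Bayesian update at the heart of the paper: the receiver draws a noisy "sender" sample of the true class and multiplies it into its current categorical parameters. The function names and the fixed per-step `beta` used here are purely illustrative and do not correspond to the repo's API (see `DiscreteBayesianFlow` in `model.py` for the real thing).

```python
import math
import random

def sender_sample(x, n_classes, beta, rng):
    # y ~ N(beta * (K * e_x - 1), beta * K * I), where x is the true class index
    K = n_classes
    return [beta * (K * float(k == x) - 1.0) + math.sqrt(beta * K) * rng.gauss(0.0, 1.0)
            for k in range(K)]

def bayesian_update(theta, y):
    # Conjugate update of the categorical parameters: theta'_k is proportional
    # to theta_k * exp(y_k), then renormalized
    weights = [t * math.exp(y_k) for t, y_k in zip(theta, y)]
    total = sum(weights)
    return [w / total for w in weights]

rng = random.Random(0)
theta = [1.0 / 3.0] * 3  # uniform prior over 3 classes
for _ in range(200):     # accumulate evidence that the true class is x = 2
    theta = bayesian_update(theta, sender_sample(x=2, n_classes=3, beta=0.1, rng=rng))
print([round(t, 3) for t in theta])  # the distribution concentrates on class 2
```

In the actual model the accuracy `beta` follows a schedule in time and the update runs over every dimension of the data in parallel, but the multiplicative structure is the same.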
## Setup
```shell
# Create a new conda env with all dependencies including pytorch and CUDA
conda env create -f env.yml
conda activate bfn
# Or, install additional dependencies into an existing pytorch env
pip install accelerate==0.19.0 matplotlib omegaconf rich
# Optional, if you want to enable logging to neptune.ai
pip install neptune
```
## Training
The models in the paper can be trained using the configs provided in the `configs` dir as follows:
```shell
# mnist experiment on 1 GPU
accelerate launch train.py config_file=configs/mnist_discrete.yaml
# cifar10 experiment on 1 GPU (A100)
accelerate launch train.py config_file=configs/cifar10_discretized_256bins.yaml
# text8 experiment on 8 GPUs (A100)
accelerate launch --multi_gpu --num_processes=8 --num_machines=1 --dynamo_backend=no --mixed_precision=fp16 train.py config_file=configs/text8_discrete.yaml
```
## Testing
> [!NOTE]
> Depending on your GPU, you may wish to adjust the batch size used for testing in `test.py`.
```shell
# Optional: Download pretrained checkpoints (make sure you have git-lfs installed: https://git-lfs.com/)
git clone git@hf.co:rupspace/pretrained-BFNs
# Compute 784-step loss on MNIST
python test.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt n_steps=784 n_repeats=2000
# Compute 10-step loss on CIFAR-10
python test.py seed=1 config_file=./configs/cifar10_discretized_256bins.yaml load_model=./pretrained-BFNs/cifar10_256d_ema.pt n_steps=10 n_repeats=100
# Compute continuous-time loss on text8
python test.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt n_steps=0 n_repeats=1
```
> [!IMPORTANT]
> All computed results will be in nats-per-data-dimension. To convert to bits, divide by ln(2).
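For instance, a reported loss in nats/dim can be converted like this (the numeric value below is made up for illustration):

```python
import math

loss_nats = 0.102                    # illustrative value, nats per data dimension
loss_bits = loss_nats / math.log(2)  # nats -> bits: divide by ln(2)
print(f"{loss_bits:.4f} bits/dim")   # prints "0.1472 bits/dim"
```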
## Sampling
You can sample from a pre-trained model as follows (change options as desired):
```shell
# Sample 4 binarized MNIST images using 100 steps
python sample.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt samples_shape="[4, 28, 28, 1]" n_steps=100 save_file=./samples_mnist.pt
# Sample 4 CIFAR-10 16-bit images modeled as discretized data using 1000 steps
python sample.py seed=1 config_file=./configs/cifar10_discretized_16bins.yaml load_model=./pretrained-BFNs/cifar10_16d_ema.pt samples_shape="[4, 32, 32, 3]" n_steps=1000 save_file=./samples_cifar.pt
# Sample 2 text8 sequences of length 256 using 100 steps
python sample.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt samples_shape="[2, 256]" n_steps=100 save_file=./samples_text8.pt
```
The samples are stored as PyTorch tensors in the `save_file`, and can be visualized by loading them and then using the utilities `batch_to_images` and `batch_to_str` in `data.py`.
For example:
```shell
# batch_to_images returns a matplotlib Figure object
python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_mnist.pt')).savefig('mnist.png')"
python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_cifar.pt')).savefig('cifar.png')"
# batch_to_str returns a list of str
python -c "import torch; from data import batch_to_str; print(batch_to_str(torch.load('./samples_text8.pt')))"
```
## Reproducibility
If a high degree of reproducibility is desired (e.g. during sampling), set the following:
```python
torch.set_float32_matmul_precision("highest")
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
```
## Acknowledgements
We are grateful to [@Higgcz](https://github.com/Higgcz) for generous support with the experiment infrastructure and code release.
================================================
FILE: configs/cifar10_continuous_16bins.yaml
================================================
meta:
neptune:
debug: False
data:
dataset: "cifar10"
horizontal_flip: False
num_bins: 16
train_loader:
batch_size: 32
shuffle: True
num_workers: 8
pin_memory: True
drop_last: True
persistent_workers: True
val_loader:
batch_size: 500
shuffle: False
num_workers: 8
pin_memory: True
model:
net:
class_name: "UNetVDM"
parameters:
embedding_dim: 128
n_blocks: 32
n_attention_heads: 1
dropout_prob: 0.1
norm_groups: 32
input_channels: 3
use_fourier_features: True
attention_everywhere: False
image_size: 32
input_adapter:
class_name: "FourierImageInputAdapter"
parameters:
input_channels: 3
input_shape: [32, 32]
output_height: 3
add_pos_feats: False
add_mask: False
output_adapter:
class_name: "OutputAdapter"
parameters:
input_height: 131
output_channels: 3 # (r,g,b)
output_height: 1
bayesian_flow:
class_name: "CtsBayesianFlow"
parameters:
min_variance: 1e-3
loss:
class_name: "CtsBayesianFlowLoss"
parameters:
noise_pred: True
distribution_factory:
class_name: "DeltaFactory"
parameters: {}
optimizer:
lr: 2e-4
betas: [0.9,0.99]
weight_decay: 0.01
eps: 1e-8
training:
checkpoint_interval: 10_000
ema_decay: 0.9999
grad_clip_norm: 5.0
log_interval: 1
n_training_steps: 1_000_000
val_interval: 50_000
val_repeats: 100
================================================
FILE: configs/cifar10_continuous_256bins.yaml
================================================
meta:
neptune:
debug: False
data:
dataset: "cifar10"
horizontal_flip: False
num_bins: 256
train_loader:
batch_size: 32
shuffle: True
num_workers: 8
pin_memory: True
drop_last: True
persistent_workers: True
val_loader:
batch_size: 500
shuffle: False
num_workers: 8
pin_memory: True
model:
net:
class_name: "UNetVDM"
parameters:
embedding_dim: 128
n_blocks: 32
n_attention_heads: 1
dropout_prob: 0.1
norm_groups: 32
input_channels: 3
use_fourier_features: True
attention_everywhere: False
image_size: 32
input_adapter:
class_name: "FourierImageInputAdapter"
parameters:
input_channels: 3
input_shape: [32, 32]
output_height: 3
add_pos_feats: False
add_mask: False
output_adapter:
class_name: "OutputAdapter"
parameters:
input_height: 131
output_channels: 3 # (r,g,b)
output_height: 1
bayesian_flow:
class_name: "CtsBayesianFlow"
parameters:
min_variance: 1e-6
loss:
class_name: "CtsBayesianFlowLoss"
parameters:
noise_pred: True
distribution_factory:
class_name: "DeltaFactory"
parameters: {}
optimizer:
lr: 2e-4
betas: [0.9,0.99]
weight_decay: 0.01
eps: 1e-8
training:
checkpoint_interval: 10_000
ema_decay: 0.9999
grad_clip_norm: 5.0
log_interval: 1
n_training_steps: 1_000_000
val_interval: 50_000
val_repeats: 100
================================================
FILE: configs/cifar10_discretized_16bins.yaml
================================================
meta:
neptune:
debug: False
data:
dataset: "cifar10"
horizontal_flip: False
num_bins: 16
train_loader:
batch_size: 32
shuffle: True
num_workers: 8
pin_memory: True
drop_last: True
persistent_workers: True
val_loader:
batch_size: 1000
shuffle: False
num_workers: 8
pin_memory: True
model:
net:
class_name: "UNetVDM"
parameters:
embedding_dim: 128
n_blocks: 32
n_attention_heads: 1
dropout_prob: 0.1
norm_groups: 32
input_channels: 3
use_fourier_features: True
attention_everywhere: False
image_size: 32
input_adapter:
class_name: "FourierImageInputAdapter"
parameters:
input_channels: 3
input_shape: [32, 32]
output_height: 3
add_pos_feats: False
add_mask: False
output_adapter:
class_name: "OutputAdapter"
parameters:
input_height: 131
output_channels: 3 # (r,g,b)
output_height: 2 # mean, std
bayesian_flow:
class_name: "CtsBayesianFlow"
parameters:
min_variance: 1e-3
loss:
class_name: "CtsBayesianFlowLoss"
parameters:
noise_pred: True
distribution_factory:
class_name: "DiscretizedNormalFactory"
parameters:
num_bins: 16
clip: True
optimizer:
lr: 2e-4
betas: [0.9,0.99]
weight_decay: 0.01
eps: 1e-8
training:
checkpoint_interval: 10_000
ema_decay: 0.9999
grad_clip_norm: 5.0
log_interval: 1
n_training_steps: 1_000_000
val_interval: 50_000
val_repeats: 100
================================================
FILE: configs/cifar10_discretized_256bins.yaml
================================================
meta:
neptune:
debug: False
data:
dataset: "cifar10"
horizontal_flip: False
num_bins: 256
train_loader:
batch_size: 32
shuffle: True
num_workers: 8
pin_memory: True
drop_last: True
persistent_workers: True
val_loader:
batch_size: 1000
shuffle: False
num_workers: 8
pin_memory: True
model:
net:
class_name: "UNetVDM"
parameters:
embedding_dim: 128
n_blocks: 32
n_attention_heads: 1
dropout_prob: 0.1
norm_groups: 32
input_channels: 3
use_fourier_features: True
attention_everywhere: False
image_size: 32
input_adapter:
class_name: "FourierImageInputAdapter"
parameters:
input_channels: 3
input_shape: [32, 32]
output_height: 3
add_pos_feats: False
add_mask: False
output_adapter:
class_name: "OutputAdapter"
parameters:
input_height: 131
output_channels: 3 # (r,g,b)
output_height: 2 # mean, std
bayesian_flow:
class_name: "CtsBayesianFlow"
parameters:
min_variance: 1e-6
loss:
class_name: "CtsBayesianFlowLoss"
parameters:
noise_pred: True
distribution_factory:
class_name: "DiscretizedNormalFactory"
parameters:
num_bins: 256
clip: True
optimizer:
lr: 2e-4
betas: [0.9,0.99]
weight_decay: 0.01
eps: 1e-8
training:
checkpoint_interval: 10_000
ema_decay: 0.9999
grad_clip_norm: 5.0
log_interval: 1
n_training_steps: 1_000_000
val_interval: 50_000
val_repeats: 100
================================================
FILE: configs/mnist_discrete.yaml
================================================
meta:
neptune:
debug: False
data:
dataset: "bin_mnist"
train_loader:
batch_size: 512
shuffle: True
num_workers: 8
pin_memory: True
drop_last: True
val_loader:
batch_size: 1000
shuffle: False
num_workers: 8
pin_memory: True
model:
net:
class_name: "UNetModel"
parameters:
image_size: 28
in_channels: 2
model_channels: 128
out_channels: 128
num_res_blocks: 2
attention_resolutions: [8,16]
dropout: 0.5
channel_mult: [1, 2, 2]
conv_resample: True
dims: 2
num_heads: 4
num_heads_upsample: -1
project_input: True
skip: True
input_adapter:
class_name: "FourierImageInputAdapter"
parameters:
input_channels: 1
input_shape: [28, 28]
output_height: 2
add_pos_feats: False
output_adapter:
class_name: "OutputAdapter"
parameters:
input_height: 256
output_channels: 1
output_height: 1
bayesian_flow:
class_name: "DiscreteBayesianFlow"
parameters:
n_classes: 2
max_sqrt_beta: 3
discretize: False
loss:
class_name: "DiscreteBayesianFlowLoss"
parameters: {}
distribution_factory:
class_name: "BernoulliFactory"
parameters: {}
optimizer:
lr: 1e-4
betas: [0.9,0.98]
training:
checkpoint_interval: 10_000
ema_decay: 0.9999
grad_clip_norm: 5.0
log_interval: 1
n_training_steps: 1_000_000
val_interval: 50_000
val_repeats: 1000
================================================
FILE: configs/text8_discrete.yaml
================================================
meta:
neptune:
debug: False
data:
dataset: "text8"
seq_len: 256
train_loader:
batch_size: 416
shuffle: True
num_workers: 8
pin_memory: True
drop_last: True
val_loader:
batch_size: 200
shuffle: True
num_workers: 8
pin_memory: True
model:
net:
class_name: "GPT"
parameters:
vocab_size: 27
n_layer: 24
n_head: 12
n_embd: 768
dropout: 0.0
skip: True
bias: True
input_adapter:
class_name: "TextInputAdapter"
parameters:
vocab_size: 27
seq_len: 256
output_size: 768
learn_pos_embedding: False
output_adapter: null
bayesian_flow:
class_name: "DiscreteBayesianFlow"
parameters:
n_classes: 27
max_sqrt_beta: 0.75
loss:
class_name: "DiscreteBayesianFlowLoss"
parameters: {}
distribution_factory:
class_name: "CategoricalFactory"
parameters: {}
optimizer:
lr: 1e-4
betas: [0.9, 0.98]
weight_decay: 0.01
training:
accumulate: 1
checkpoint_interval: 10_000
ema_decay: 0.9999
grad_clip_norm: 5
log_interval: 1
max_val_batches: 5_000
n_training_steps: 10_000_000
val_interval: 100_000
val_repeats: 1
================================================
FILE: data.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import pathlib
import pickle
import zipfile
from typing import Union
import numpy as np
import requests
import torch
import torchvision
from matplotlib import pyplot as plt
from omegaconf import DictConfig
from torch.utils.data import Dataset, random_split
from torchvision import transforms
from torchvision.utils import make_grid
from utils_model import quantize
TEXT8_CHARS = list("_abcdefghijklmnopqrstuvwxyz")
def bin_mnist_transform(x):
return torch.bernoulli(x.permute(1, 2, 0).contiguous()).int()
def bin_mnist_cts_transform(x):
return torch.bernoulli(x.permute(1, 2, 0).contiguous()) - 0.5
def rgb_image_transform(x, num_bins=256):
return quantize((x * 2) - 1, num_bins).permute(1, 2, 0).contiguous()
class MyLambda(torchvision.transforms.Lambda):
def __init__(self, lambd, arg1):
super().__init__(lambd)
self.arg1 = arg1
def __call__(self, x):
return self.lambd(x, self.arg1)
class CIFAR10(torchvision.datasets.CIFAR10):
def __getitem__(self, idx):
return super().__getitem__(idx)[0]
class MNIST(torchvision.datasets.MNIST):
def __getitem__(self, idx):
return super().__getitem__(idx)[0]
def make_datasets(cfg: DictConfig) -> tuple[Dataset, Dataset, Dataset]:
"""
Mandatory keys: dataset (must be cifar10, mnist, bin_mnist, bin_mnist_cts or text8), data_dir
Optional for vision: num_bins (default 256), val_frac (default 0.01), horizontal_flip (default: False)
Mandatory for text: seq_len
"""
num_bins = cfg.get("num_bins", 256)
if cfg.dataset == "cifar10":
train_transform_list = [transforms.ToTensor()]
if cfg.get("horizontal_flip", False):
train_transform_list.append(transforms.RandomHorizontalFlip())
train_transform_list.append(MyLambda(rgb_image_transform, num_bins))
train_transform = transforms.Compose(train_transform_list)
test_transform = transforms.Compose([transforms.ToTensor(), MyLambda(rgb_image_transform, num_bins)])
train_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=train_transform)
val_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=test_transform)
test_set = CIFAR10(root=cfg.data_dir, train=False, download=True, transform=test_transform)
elif cfg.dataset == "mnist":
transform = transforms.Compose(
[
transforms.ToTensor(),
MyLambda(rgb_image_transform, num_bins),
]
)
train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)
elif cfg.dataset == "bin_mnist":
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_transform)])
train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)
elif cfg.dataset == "bin_mnist_cts":
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_cts_transform)])
train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)
elif cfg.dataset == "text8":
train_set = Text8Dataset(cfg.data_dir, "train", download=True, seq_len=cfg.seq_len)
val_set = Text8Dataset(cfg.data_dir, "val", download=True, seq_len=cfg.seq_len)
test_set = Text8Dataset(cfg.data_dir, "test", download=True, seq_len=cfg.seq_len)
else:
raise NotImplementedError(cfg.dataset)
if cfg.dataset != "text8":
# For vision datasets we split the train set into train and val
val_frac = cfg.get("val_frac", 0.01)
train_val_split = [1.0 - val_frac, val_frac]
seed = 2147483647
train_set = random_split(train_set, train_val_split, generator=torch.Generator().manual_seed(seed))[0]
val_set = random_split(val_set, train_val_split, generator=torch.Generator().manual_seed(seed))[1]
return train_set, val_set, test_set
def prepare_text8(data_dir: pathlib.Path):
data_dir.mkdir(parents=True, exist_ok=True)
data_url = "http://mattmahoney.net/dc/text8.zip"
with open(data_dir / "text8.zip", "wb") as f:
print("Downloading text8")
f.write(requests.get(data_url).content)
print("Done")
with zipfile.ZipFile(data_dir / "text8.zip") as f:
f.extractall(data_dir)
os.remove(data_dir / "text8.zip")
data = (data_dir / "text8").read_text()
# get all the unique characters that occur in this text
chars = sorted(list(set(data)))
vocab_size = len(chars)
print("all the unique characters:", "".join(chars))
print(f"vocab size: {vocab_size:,}")
# create a mapping from characters to integers
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
def encode(s):
return [stoi[c] for c in s] # encoder: take a string, output a list of integers
# split the data and encode each split to integers
n = len(data)
train_data = data[: int(n * 0.9)]
val_data = data[int(n * 0.9) : int(n * 0.95)]
test_data = data[int(n * 0.95) :]
train_ids = encode(train_data)
val_ids = encode(val_data)
test_ids = encode(test_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")
print(f"test has {len(test_ids):,} tokens")
# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
test_ids = np.array(test_ids, dtype=np.uint16)
train_ids.tofile(data_dir / "train.bin")
val_ids.tofile(data_dir / "val.bin")
test_ids.tofile(data_dir / "test.bin")
print(f"Saved to {data_dir / 'train.bin'}, {data_dir / 'val.bin'}, {data_dir / 'test.bin'}")
# save the meta information as well, to help us encode/decode later
meta = {
"vocab_size": vocab_size,
"itos": itos,
"stoi": stoi,
}
with open(data_dir / "meta.pkl", "wb") as f:
pickle.dump(meta, f)
print(f"text8 dataset downloaded and prepared in dir {data_dir}")
class Text8Dataset(Dataset):
def __init__(self, data_dir: Union[str, pathlib.Path], split: str, download: bool, seq_len: int):
"""
seq_len should include the context length. Example: seq_len=512 for modeling 256 chars with 256 chars of context.
Context is only used for correct preparation of the val/test sets.
"""
self.root_dir = pathlib.Path(data_dir)
self.split = split
self.seq_len = seq_len
assert self.split in ["train", "val", "test"]  # validate before using split as a dict key
fname = {"train": "train.bin", "val": "val.bin", "test": "test.bin"}[self.split]
data_dir = self.root_dir / "text8"
if not os.path.exists(data_dir):
if download:
prepare_text8(data_dir)
else:
raise NotADirectoryError(f"dir {data_dir} does not exist and download is False")
self.data = np.memmap(data_dir / fname, np.uint16, "r")
def __getitem__(self, index) -> torch.Tensor:
seq = torch.from_numpy(self.data[index : index + self.seq_len].astype(np.int64))
return seq
def __len__(self):
return self.data.size - self.seq_len
def char_ids_to_str(char_ids: Union[list[int], np.ndarray, torch.Tensor]) -> str:
"""Decode a 1D sequence of character IDs to a string."""
return "".join([TEXT8_CHARS[i] for i in char_ids])
def batch_to_str(text_batch: Union[list[list], np.ndarray, torch.Tensor]) -> list[str]:
"""Decode a batch of character IDs to a list of strings."""
return [char_ids_to_str(row_char_ids) for row_char_ids in text_batch]
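A minimal, self-contained sketch of the round trip between `encode` above and `char_ids_to_str`, assuming `TEXT8_CHARS` is the 27-character text8 alphabet (space plus 'a'-'z'; the real constant is defined earlier in data.py):

```python
# Assumed alphabet: text8 uses lowercase letters and space only.
TEXT8_CHARS = " abcdefghijklmnopqrstuvwxyz"

# character <-> integer mappings, as built in prepare_text8
stoi = {ch: i for i, ch in enumerate(TEXT8_CHARS)}
itos = {i: ch for i, ch in enumerate(TEXT8_CHARS)}

def encode(s):
    # string -> list of character IDs
    return [stoi[c] for c in s]

def decode(ids):
    # list of character IDs -> string (what char_ids_to_str does)
    return "".join(itos[i] for i in ids)

ids = encode("hello world")
assert decode(ids) == "hello world"
```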
def batch_to_images(image_batch: torch.Tensor, ncols: int = None) -> plt.Figure:
if ncols is None:
ncols = math.ceil(math.sqrt(len(image_batch)))
if image_batch.size(-1) == 3: # for color images (CIFAR-10)
image_batch = (image_batch + 1) / 2
grid = make_grid(image_batch.permute(0, 3, 1, 2), ncols, pad_value=1).permute(1, 2, 0)
fig = plt.figure(figsize=(grid.size(1) / 30, grid.size(0) / 30))
plt.imshow(grid.cpu().clip(min=0, max=1), interpolation="nearest")
plt.grid(False)
plt.axis("off")
return fig
================================================
FILE: env.yml
================================================
name: bfn
channels:
- pytorch
- nvidia
dependencies:
- python=3.9
- pytorch=2.0.0
- pytorch-cuda=11.8
- torchvision=0.15.0
- pip
- pip:
- accelerate==0.19.0
- matplotlib
- omegaconf
- rich
================================================
FILE: model.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements the Bayesian Flow and BFN loss for continuous and discrete variables.
Finally it implements the BFN using these objects.
For consistency we always use a tuple to store input parameters.
It has just one element for discrete data (the probabilities) and two for continuous/discretized data (mean & precision).
The probability distributions and network architectures are defined in probability.py and networks dir.
"Cts" is an abbreviation of "Continuous".
"""
import math
from abc import abstractmethod, ABC
from typing import Union, Optional
import torch
import torch.distributions as D
import torch.nn.functional as F
from torch import nn, Tensor
from probability import (
DiscreteDistributionFactory,
CtsDistributionFactory,
PredDistToDataDistFactory,
DiscretizedCtsDistribution,
)
from utils_model import sandwich, float_to_idx
class BayesianFlow(nn.Module, ABC):
def __init__(self):
super().__init__()
@abstractmethod
def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, ...]:
"""Returns the initial input params (for a batch) at t=0. Used during sampling.
For discrete data, the tuple has length 1 and contains the initial class probabilities.
For continuous data, the tuple has length 2 and contains the mean and precision."""
pass
@abstractmethod
def params_to_net_inputs(self, params: tuple[Tensor, ...]) -> Tensor:
"""Utility method to convert input distribution params to network inputs if needed."""
pass
@abstractmethod
def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> float:
"""Returns the alpha at step i of total n_steps according to the flow schedule. Used:
a) during sampling, when i and alpha are the same for all samples in the batch.
b) during discrete time loss computation, when i and alpha differ between samples in the batch."""
pass
@abstractmethod
def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
"""Returns the sender distribution with accuracy alpha obtained by adding appropriate noise to the data x. Used:
a) during sampling (same alpha for whole batch) to sample from the output distribution produced by the net.
b) during discrete time loss computation when alpha differs between samples in the batch."""
pass
@abstractmethod
def update_input_params(self, input_params: tuple[Tensor, ...], y: Tensor, alpha: float) -> tuple[Tensor, ...]:
"""Updates the distribution parameters using Bayes' theorem in light of noisy sample y.
Used during sampling when alpha is the same for the whole batch."""
pass
@abstractmethod
def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, ...]:
"""Returns a sample from the Bayesian Flow distribution over input parameters at time t conditioned on data.
Used during training when t (and thus accuracies) are different for different samples in the batch.
For discrete data, the returned tuple has length 1 and contains the class probabilities.
For continuous data, the returned tuple has length 2 and contains the mean and precision."""
pass
class Loss(nn.Module, ABC):
def __init__(self):
super().__init__()
@abstractmethod
def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor) -> Tensor:
"""Returns the continuous time KL loss (and any other losses) at time t (between 0 and 1).
The input params are only used when the network is parameterized to predict the noise for continuous data."""
pass
@abstractmethod
def discrete_time_loss(
self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples: int = 20
) -> Tensor:
"""Returns the discrete time KL loss for n_steps total of communication at time t (between 0 and 1) using
n_samples for Monte Carlo estimation of the discrete loss.
The input params are only used when the network is parameterized to predict the noise for continuous data."""
pass
@abstractmethod
def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
"""Returns the reconstruction loss, i.e. the final cost of transmitting clean data.
The input params are only used when the network is parameterized to predict the noise for continuous data."""
pass
# Continuous or Discretized data
class CtsBayesianFlow(BayesianFlow):
def __init__(
self,
min_variance: float = 1e-6,
):
super().__init__()
self.min_variance = min_variance
@torch.no_grad()
def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, None]:
post_var = torch.pow(self.min_variance, t)
alpha_t = 1 - post_var
mean_mean = alpha_t * data
mean_var = alpha_t * post_var
mean_std_dev = mean_var.sqrt()
noise = torch.randn(mean_mean.shape, device=mean_mean.device)
mean = mean_mean + (mean_std_dev * noise)
# We don't need to compute the variance because it is not needed by the network, so set it to None
input_params = (mean, None)
return input_params
def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
return params[0] # Only the mean is used by the network
def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, float]:
return torch.zeros(*data_shape, device=device), 1.0
def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:
sigma_1 = math.sqrt(self.min_variance)
return (sigma_1 ** (-2 * i / n_steps)) * (1 - sigma_1 ** (2 / n_steps))
def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
dist = D.Normal(x, 1.0 / alpha**0.5)
return dist
def update_input_params(self, input_params: tuple[Tensor, float], y: Tensor, alpha: float) -> tuple[Tensor, float]:
input_mean, input_precision = input_params
new_precision = input_precision + alpha
new_mean = ((input_precision * input_mean) + (alpha * y)) / new_precision
return new_mean, new_precision
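As a quick sanity check (pure Python, illustrative only): the schedule in `get_alpha` telescopes, so applying `update_input_params` for all n steps takes the prior precision of 1.0 to exactly 1/min_variance, independent of n:

```python
import math

min_variance = 1e-6                  # sigma_1 ** 2, as in CtsBayesianFlow
sigma_1 = math.sqrt(min_variance)

def get_alpha(i, n_steps):
    # same formula as CtsBayesianFlow.get_alpha
    return (sigma_1 ** (-2 * i / n_steps)) * (1 - sigma_1 ** (2 / n_steps))

n = 50
# update_input_params adds alpha to the precision at every step,
# starting from the prior precision of 1.0; the sum telescopes.
precision = 1.0 + sum(get_alpha(i, n) for i in range(1, n + 1))
assert math.isclose(precision, 1.0 / min_variance, rel_tol=1e-9)
```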
class CtsBayesianFlowLoss(Loss):
def __init__(
self,
bayesian_flow: CtsBayesianFlow,
distribution_factory: Union[CtsDistributionFactory, DiscreteDistributionFactory],
min_loss_variance: float = -1,
noise_pred: bool = True,
):
super().__init__()
self.bayesian_flow = bayesian_flow
self.distribution_factory = distribution_factory
self.min_loss_variance = min_loss_variance
self.C = -0.5 * math.log(bayesian_flow.min_variance)
self.noise_pred = noise_pred
if self.noise_pred:
self.distribution_factory.log_dev = False
self.distribution_factory = PredDistToDataDistFactory(
self.distribution_factory, self.bayesian_flow.min_variance
)
def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:
output_params = sandwich(output_params)
t = t.flatten(start_dim=1).float()
posterior_var = torch.pow(self.bayesian_flow.min_variance, t)
flat_target = data.flatten(start_dim=1)
pred_dist = self.distribution_factory.get_dist(output_params, input_params, t)
pred_mean = pred_dist.mean
mse_loss = (pred_mean - flat_target).square()
if self.min_loss_variance > 0:
posterior_var = posterior_var.clamp(min=self.min_loss_variance)
loss = self.C * mse_loss / posterior_var
return loss
def discrete_time_loss(
self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10
) -> Tensor:
output_params = sandwich(output_params)
t = t.flatten(start_dim=1).float()
output_dist = self.distribution_factory.get_dist(output_params, input_params, t)
if hasattr(output_dist, "probs"): # output distribution is discretized normal
flat_target = data.flatten(start_dim=1)
t = t.flatten(start_dim=1)
i = t * n_steps + 1 # since t = (i - 1) / n
alpha = self.bayesian_flow.get_alpha(i, n_steps)
sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)
receiver_mix_wts = sandwich(output_dist.probs)
receiver_mix_dist = D.Categorical(probs=receiver_mix_wts, validate_args=False)
receiver_components = D.Normal(
output_dist.class_centres, (1.0 / alpha.sqrt()).unsqueeze(-1), validate_args=False
)
receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components, validate_args=False)
y = sender_dist.sample(torch.Size([n_samples]))
loss = (
(sender_dist.log_prob(y) - receiver_dist.log_prob(y))
.mean(0)
.flatten(start_dim=1)
.mean(1, keepdims=True)
)
else: # output distribution is normal
pred_mean = output_dist.mean
flat_target = data.flatten(start_dim=1)
mse_loss = (pred_mean - flat_target).square()
i = t * n_steps + 1
alpha = self.bayesian_flow.get_alpha(i, n_steps)
loss = alpha * mse_loss / 2
return n_steps * loss
def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
output_params = sandwich(output_params)
flat_data = data.flatten(start_dim=1)
t = torch.ones_like(data).flatten(start_dim=1).float()
output_dist = self.distribution_factory.get_dist(output_params, input_params, t)
if hasattr(output_dist, "probs"): # output distribution is discretized normal
reconstruction_loss = -output_dist.log_prob(flat_data)
else: # output distribution is normal, but we use discretized normal to make results comparable (see Sec. 7.2)
if self.bayesian_flow.min_variance == 1e-3: # used for 16 bin CIFAR10
noise_dev = 0.7 * math.sqrt(self.bayesian_flow.min_variance)
num_bins = 16
else:
noise_dev = math.sqrt(self.bayesian_flow.min_variance)
num_bins = 256
mean = output_dist.mean.flatten(start_dim=1)
final_dist = D.Normal(mean, noise_dev)
final_dist = DiscretizedCtsDistribution(final_dist, num_bins, device=t.device, batch_dims=mean.ndim - 1)
reconstruction_loss = -final_dist.log_prob(flat_data)
return reconstruction_loss
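A scalar sketch of the weighting used in `cts_time_loss` above: the squared error is scaled by C / sigma_1^(2t), so the same prediction error is penalised more heavily at larger t, where the input distribution is expected to be more accurate. Values here are illustrative, not the tensor implementation:

```python
import math

min_variance = 1e-6                     # sigma_1 ** 2
C = -0.5 * math.log(min_variance)       # same constant as in __init__

def cts_loss(pred, target, t):
    # per-dimension continuous-time loss: C * mse / sigma_1^(2t)
    posterior_var = min_variance ** t
    return C * (pred - target) ** 2 / posterior_var

# identical error, heavier penalty closer to t = 1
assert cts_loss(0.1, 0.0, 0.9) > cts_loss(0.1, 0.0, 0.1)
```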
# Discrete Data
class DiscreteBayesianFlow(BayesianFlow):
def __init__(
self,
n_classes: int,
min_sqrt_beta: float = 1e-10,
discretize: bool = False,
epsilon: float = 1e-6,
max_sqrt_beta: float = 1,
):
super().__init__()
self.n_classes = n_classes
self.min_sqrt_beta = min_sqrt_beta
self.discretize = discretize
self.epsilon = epsilon
self.max_sqrt_beta = max_sqrt_beta
self.uniform_entropy = math.log(self.n_classes)
def t_to_sqrt_beta(self, t):
return t * self.max_sqrt_beta
def count_dist(self, x, beta=None):
mean = (self.n_classes * F.one_hot(x.long(), self.n_classes)) - 1
std_dev = math.sqrt(self.n_classes)
if beta is not None:
mean = mean * beta
std_dev = std_dev * beta.sqrt()
return D.Normal(mean, std_dev, validate_args=False)
def count_sample(self, x, beta):
return self.count_dist(x, beta).rsample()
@torch.no_grad()
def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor]:
return (torch.ones(*data_shape, self.n_classes, device=device) / self.n_classes,)
@torch.no_grad()
def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
params = params[0]
if self.n_classes == 2:
params = params * 2 - 1 # We scale-shift here for MNIST instead of in the network like for text
params = params[..., :1]
return params
def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:
return ((self.max_sqrt_beta / n_steps) ** 2) * (2 * i - 1)
def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
e_x = F.one_hot(x.long(), self.n_classes)
alpha = alpha.unsqueeze(-1) if isinstance(alpha, Tensor) else alpha
dist = D.Normal(alpha * ((self.n_classes * e_x) - 1), (self.n_classes * alpha) ** 0.5)
return dist
def update_input_params(self, input_params: tuple[Tensor], y: Tensor, alpha: float) -> tuple[Tensor]:
new_input_params = input_params[0] * y.exp()
new_input_params /= new_input_params.sum(-1, keepdims=True)
return (new_input_params,)
@torch.no_grad()
def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor]:
if self.discretize:
data = float_to_idx(data, self.n_classes)
sqrt_beta = self.t_to_sqrt_beta(t.clamp(max=1 - self.epsilon))
lo_beta = sqrt_beta < self.min_sqrt_beta
sqrt_beta = sqrt_beta.clamp(min=self.min_sqrt_beta)
beta = sqrt_beta.square().unsqueeze(-1)
logits = self.count_sample(data, beta)
probs = F.softmax(logits, -1)
probs = torch.where(lo_beta.unsqueeze(-1), torch.ones_like(probs) / self.n_classes, probs)
if self.n_classes == 2:
probs = probs[..., :1]
probs = probs.reshape_as(data)
input_params = (probs,)
return input_params
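A scalar sketch of `update_input_params` above: the posterior over classes is the prior multiplied elementwise by exp(y) and renormalised. This toy version uses plain lists instead of tensors:

```python
import math

def update_probs(prior, y):
    # posterior ∝ prior * exp(y), renormalised over the classes,
    # mirroring DiscreteBayesianFlow.update_input_params
    unnorm = [p * math.exp(v) for p, v in zip(prior, y)]
    z = sum(unnorm)
    return [u / z for u in unnorm]

posterior = update_probs([0.5, 0.5], [1.0, 0.0])
assert math.isclose(posterior[0], math.e / (math.e + 1.0))
assert math.isclose(sum(posterior), 1.0)
```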
class DiscreteBayesianFlowLoss(Loss):
def __init__(
self,
bayesian_flow: DiscreteBayesianFlow,
distribution_factory: DiscreteDistributionFactory,
):
super().__init__()
self.bayesian_flow = bayesian_flow
self.distribution_factory = distribution_factory
self.K = self.bayesian_flow.n_classes
def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:
flat_output = sandwich(output_params)
pred_probs = self.distribution_factory.get_dist(flat_output).probs
flat_target = data.flatten(start_dim=1)
if self.bayesian_flow.discretize:
flat_target = float_to_idx(flat_target, self.K)
tgt_mean = torch.nn.functional.one_hot(flat_target.long(), self.K)
kl = self.K * ((tgt_mean - pred_probs).square()).sum(-1)
t = t.flatten(start_dim=1).float()
loss = t * (self.bayesian_flow.max_sqrt_beta**2) * kl
return loss
def discrete_time_loss(
self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10
) -> Tensor:
flat_target = data.flatten(start_dim=1)
if self.bayesian_flow.discretize:
flat_target = float_to_idx(flat_target, self.K)
i = t * n_steps + 1
alpha = self.bayesian_flow.get_alpha(i, n_steps).flatten(start_dim=1)
sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)
flat_output = sandwich(output_params)
receiver_mix_wts = self.distribution_factory.get_dist(flat_output).probs
receiver_mix_dist = D.Categorical(probs=receiver_mix_wts.unsqueeze(-2))
classes = torch.arange(self.K, device=flat_target.device).long().unsqueeze(0).unsqueeze(0)
receiver_components = self.bayesian_flow.get_sender_dist(classes, alpha.unsqueeze(-1))
receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components)
y = sender_dist.sample(torch.Size([n_samples]))
loss = n_steps * (sender_dist.log_prob(y) - receiver_dist.log_prob(y)).mean(0).sum(-1).mean(1, keepdims=True)
return loss
def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
flat_outputs = sandwich(output_params)
flat_data = data.flatten(start_dim=1)
output_dist = self.distribution_factory.get_dist(flat_outputs)
return -output_dist.log_prob(flat_data)
class BFN(nn.Module):
def __init__(self, net: nn.Module, bayesian_flow: BayesianFlow, loss: Loss):
super().__init__()
self.net = net
self.bayesian_flow = bayesian_flow
self.loss = loss
@staticmethod
@torch.no_grad()
def sample_t(data: Tensor, n_steps: Optional[int]) -> Tensor:
if n_steps == 0 or n_steps is None:
t = torch.rand(data.size(0), device=data.device).unsqueeze(-1)
else:
t = torch.randint(0, n_steps, (data.size(0),), device=data.device).unsqueeze(-1) / n_steps
t = (torch.ones_like(data).flatten(start_dim=1) * t).reshape_as(data)
return t
def forward(
self, data: Tensor, t: Optional[Tensor] = None, n_steps: Optional[int] = None
) -> tuple[Tensor, dict[str, Tensor], Tensor, Tensor]:
"""
Compute an MC estimate of the continuous (when n_steps=None or 0) or discrete time KL loss.
t is sampled randomly if None. If t is not None, expect t.shape == data.shape.
"""
t = self.sample_t(data, n_steps) if t is None else t
# sample input parameter flow
input_params = self.bayesian_flow(data, t)
net_inputs = self.bayesian_flow.params_to_net_inputs(input_params)
# compute output distribution parameters
output_params: Tensor = self.net(net_inputs, t)
# compute KL loss in float32
with torch.autocast(device_type=data.device.type if data.device.type != "mps" else "cpu", enabled=False):
if n_steps == 0 or n_steps is None:
loss = self.loss.cts_time_loss(data, output_params.float(), input_params, t)
else:
loss = self.loss.discrete_time_loss(data, output_params.float(), input_params, t, n_steps)
# loss shape is (batch_size, 1)
return loss.mean()
@torch.inference_mode()
def compute_reconstruction_loss(self, data: Tensor) -> Tensor:
t = torch.ones_like(data).float()
input_params = self.bayesian_flow(data, t)
net_inputs = self.bayesian_flow.params_to_net_inputs(input_params)
output_params: Tensor = self.net(net_inputs, t)
return self.loss.reconstruction_loss(data, output_params, input_params).flatten(start_dim=1).mean()
@torch.inference_mode()
def sample(self, data_shape: tuple, n_steps: int) -> Tensor:
device = next(self.parameters()).device
input_params = self.bayesian_flow.get_prior_input_params(data_shape, device)
distribution_factory = self.loss.distribution_factory
for i in range(1, n_steps + 1):
t = torch.ones(*data_shape, device=device) * (i - 1) / n_steps
output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)
output_sample = distribution_factory.get_dist(output_params, input_params, t).sample()
output_sample = output_sample.reshape(*data_shape)
alpha = self.bayesian_flow.get_alpha(i, n_steps)
y = self.bayesian_flow.get_sender_dist(output_sample, alpha).sample()
input_params = self.bayesian_flow.update_input_params(input_params, y, alpha)
t = torch.ones(*data_shape, device=device)
output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)
output_sample = distribution_factory.get_dist(output_params, input_params, t).mode
output_sample = output_sample.reshape(*data_shape)
return output_sample
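The loop in `BFN.sample` can be illustrated with a toy scalar version for continuous data, assuming a hypothetical "oracle" network whose output sample is always the true value x. The input mean then converges to x as the precision accumulates to 1/min_variance:

```python
import math
import random

random.seed(0)
min_variance = 1e-6                      # sigma_1 ** 2
sigma_1 = math.sqrt(min_variance)

def get_alpha(i, n_steps):
    # as in CtsBayesianFlow.get_alpha
    return (sigma_1 ** (-2 * i / n_steps)) * (1 - sigma_1 ** (2 / n_steps))

x = 0.3                                  # the oracle's output every step
mean, precision = 0.0, 1.0               # prior input params at t = 0
n_steps = 100
for i in range(1, n_steps + 1):
    alpha = get_alpha(i, n_steps)
    y = random.gauss(x, 1.0 / math.sqrt(alpha))   # sender sample
    new_precision = precision + alpha             # Bayesian update, as in
    mean = (precision * mean + alpha * y) / new_precision  # update_input_params
    precision = new_precision
assert abs(mean - x) < 0.01              # final posterior std is ~1e-3
```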
================================================
FILE: networks/__init__.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = (
"GPT",
"UNetVDM",
"UNetModel",
"adapters",
)
from .transformer import GPT
from .unet_vdm import UNetVDM
from .unet_improved import UNetModel
from . import adapters
================================================
FILE: networks/adapters.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple
import torch
from torch import Tensor
from torch import nn
from utils_model import sandwich, pe_encode, pe_encode_float
class TextInputAdapter(nn.Module):
"""
A module to convert sequences of text class probabilities to embedding tokens, with learned or fixed sinusoidal positional embeddings.
"""
def __init__(
self,
vocab_size: int,
seq_len: int,
output_size: int = 256,
learn_pos_embedding: bool = False,
):
super().__init__()
self.learn_pos_embedding = learn_pos_embedding
if learn_pos_embedding:
self.pos_embedding = nn.Embedding(seq_len, output_size)
else:
self.register_buffer("pos_embedding", pe_encode(seq_len, output_size))
self.inp_embedding = nn.Linear(vocab_size, output_size)
self.t_embedding = nn.Linear(1, output_size)
def forward(self, probs: torch.Tensor, t: torch.Tensor) -> Tensor:
inp_emb = self.inp_embedding(2 * probs - 1)
if self.learn_pos_embedding:
pos_emb = self.pos_embedding(
torch.arange(0, probs.size(1)).to(probs.device)
)
else:
pos_emb = self.pos_embedding
pos_emb = pos_emb.unsqueeze(0).expand(inp_emb.size(0), -1, -1)
t_emb = self.t_embedding((2 * t - 1).unsqueeze(-1))
output = inp_emb + pos_emb + t_emb
return output
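`pe_encode` is defined in utils_model; the fixed branch above presumably builds standard sinusoidal position encodings. A guess at a minimal pure-Python equivalent (the exact layout in utils_model may differ):

```python
import math

def pe_encode(seq_len, dim):
    # Standard Transformer sinusoidal encoding: even columns sin,
    # odd columns cos, with geometrically spaced frequencies.
    # This is an assumption about utils_model.pe_encode, not its source.
    pe = [[0.0] * dim for _ in range(seq_len)]
    for pos in range(seq_len):
        for i in range(0, dim, 2):
            angle = pos / (10000 ** (i / dim))
            pe[pos][i] = math.sin(angle)
            pe[pos][i + 1] = math.cos(angle)
    return pe

pe = pe_encode(8, 16)
assert pe[0][0] == 0.0 and pe[0][1] == 1.0   # sin(0), cos(0)
```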
class FourierImageInputAdapter(nn.Module):
"""
A module to flatten a 2D image into a sequence of per-pixel feature vectors, augmented with Fourier position codes.
"""
def __init__(
self,
input_channels: int = 3,
input_shape: Tuple[int, int] = (224, 224),
n_freq_bands: int = 64,
output_height: int = 256,
value_res: int = -1,
mask_res: int = -1,
add_pos_feats: bool = True,
add_mask: bool = True,
learn_pos_feats: bool = False,
pos_embed_size: int = 32,
init_scale: float = 0.02,
):
super().__init__()
self.input_shape = input_shape
self.n_freq_bands = n_freq_bands
self.value_res = value_res
self.mask_res = mask_res
self.add_pos_feats = add_pos_feats
self.add_mask = add_mask
if learn_pos_feats:
pos_feats = nn.Parameter(
init_scale
* torch.randn(1, input_shape[0] * input_shape[1], pos_embed_size)
)
self.register_parameter("pos_feats", pos_feats)
else:
x = torch.linspace(-1.0, 1.0, steps=input_shape[0])
y = torch.linspace(-1.0, 1.0, steps=input_shape[1])
x_pos, y_pos = torch.meshgrid(x, y, indexing="ij")
pos = torch.stack((x_pos, y_pos), dim=-1)
pos = pos.reshape(-1, 2)
x_bands = torch.linspace(1.0, input_shape[0] / 2, steps=n_freq_bands)
y_bands = torch.linspace(1.0, input_shape[1] / 2, steps=n_freq_bands)
bands = torch.stack((x_bands, y_bands), dim=0)
vals = pos[:, :, None] * bands[None, :, :]
vals = math.pi * vals.reshape(vals.shape[0], -1)
pos_feats = torch.cat([vals.sin(), vals.cos()], dim=-1)
pos_feats = torch.cat([pos_feats, pos], dim=-1)
self.register_buffer("pos_feats", pos_feats)
img_feat_height = input_channels
pos_feat_height = pos_feats.size(-1)
if self.mask_res > 0:
mask_feat_height = (n_freq_bands * 2) + 1
else:
mask_feat_height = 1
all_feat_height = img_feat_height
if add_mask:
all_feat_height += mask_feat_height
if add_pos_feats:
all_feat_height += pos_feat_height
self.output_projection = None
if output_height != all_feat_height:
self.output_projection = nn.Linear(all_feat_height, output_height)
def forward(self, img: Tensor, t: Tensor) -> Tensor:
flat_img = sandwich(img)
flat_t = sandwich(t)
t_feats = (flat_t.float()[..., :1] * 2) - 1
if self.mask_res > 0:
t_feats = torch.cat(
[
t_feats,
pe_encode_float(
t_feats, self.mask_res, self.n_freq_bands * 2
).flatten(start_dim=2),
],
-1,
)
fourier_feats = self.pos_feats.expand(img.size(0), -1, -1)
all_feat_list = [flat_img]
if self.add_mask:
all_feat_list.append(t_feats)
if self.add_pos_feats:
all_feat_list.append(fourier_feats)
all_feats = torch.cat(all_feat_list, dim=-1)
if self.output_projection is None:
output = all_feats
else:
output = self.output_projection(all_feats)
return output
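A toy 1D version of the fixed Fourier position features built in `__init__` above: positions in [-1, 1], frequency bands from 1 to size/2, and features [sin(pi * pos * band), cos(pi * pos * band)] plus the raw position. Only shapes are checked; this is illustrative, not the module's code:

```python
import math

size, n_bands = 8, 4
# evenly spaced positions in [-1, 1] (like torch.linspace)
positions = [-1.0 + 2.0 * i / (size - 1) for i in range(size)]
# evenly spaced frequency bands in [1, size / 2]
bands = [1.0 + (size / 2 - 1.0) * j / (n_bands - 1) for j in range(n_bands)]
feats = []
for pos in positions:
    vals = [math.pi * pos * b for b in bands]
    # sin features, cos features, then the raw position itself
    feats.append([math.sin(v) for v in vals]
                 + [math.cos(v) for v in vals]
                 + [pos])
assert len(feats) == size and len(feats[0]) == 2 * n_bands + 1
```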
class OutputAdapter(nn.Module):
def __init__(self, input_height: int, output_channels: int, output_height: int):
super().__init__()
self.output_channels = output_channels
self.output_height = output_height
self.output_projection = nn.Linear(
input_height, output_channels * output_height
)
def forward(self, inp: torch.Tensor) -> torch.Tensor:
output = self.output_projection(inp)
return output.reshape(
output.size(0), -1, self.output_channels, self.output_height
)
================================================
FILE: networks/transformer.py
================================================
# Source: https://github.com/karpathy/nanoGPT
#
# MIT License
#
# Copyright (c) 2022 Andrej Karpathy
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Modifications:
# - Added data_adapters to GPT to preprocess the inputs and (optionally) postprocess the outputs
# - Added the `skip` option to concat the input and output of the network before the final projection
# - Added time `t` as an input to `forward()`
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def gelu(x):
return F.gelu(x, approximate="tanh")
class LayerNorm(nn.Module):
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
def __init__(self, ndim, bias):
super().__init__()
self.weight = nn.Parameter(torch.ones(ndim))
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
def forward(self, input):
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
class SelfAttention(nn.Module):
def __init__(self, n_head, n_embd, dropout, bias, is_causal):
super().__init__()
assert n_embd % n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
# output projection
self.c_proj = nn.Linear(n_embd, n_embd, bias=bias)
# regularization
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
self.n_head = n_head
self.n_embd = n_embd
self.dropout = dropout
self.is_causal = is_causal
def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
# self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
y = torch.nn.functional.scaled_dot_product_attention(
q, k, v, dropout_p=self.dropout if self.training else 0, is_causal=self.is_causal
)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_dropout(self.c_proj(y))
return y
class MLP(nn.Module):
def __init__(self, n_embd, dropout, bias):
super().__init__()
self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=bias)
self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=bias)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.c_fc(x)
x = gelu(x)
x = self.c_proj(x)
x = self.dropout(x)
return x
class Block(nn.Module):
def __init__(self, n_head, n_embd, dropout, bias, is_causal):
super().__init__()
self.ln_1 = LayerNorm(n_embd, bias=bias)
self.attn = SelfAttention(n_head, n_embd, dropout, bias, is_causal)
self.ln_2 = LayerNorm(n_embd, bias=bias)
self.mlp = MLP(n_embd, dropout, bias)
def forward(self, x):
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class GPT(nn.Module):
def __init__(
self,
data_adapters: dict,
vocab_size: int,
n_layer: int = 12,
n_head: int = 12,
n_embd: int = 768,
dropout: float = 0.0,
bias: bool = True,
skip: bool = False,
is_causal: bool = False,
):
super().__init__()
self.n_layer = n_layer
self.n_head = n_head
self.n_embd = n_embd
self.input_adapter = data_adapters["input_adapter"]
self.output_adapter = data_adapters["output_adapter"]
self.transformer = nn.ModuleDict(
dict(
drop=nn.Dropout(dropout),
h=nn.ModuleList([Block(n_head, n_embd, dropout, bias, is_causal) for _ in range(n_layer)]),
ln_f=LayerNorm(n_embd, bias=bias),
)
)
self.is_causal = is_causal
if self.is_causal:
self.skip = False
else:
self.skip = skip
if self.skip:
self.lm_head = nn.Linear(2 * n_embd, vocab_size, bias=bias)
else:
self.lm_head = nn.Linear(n_embd, vocab_size, bias=bias)
# init all weights
self.apply(self._init_weights)
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith("c_proj.weight"):
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layer))
# report number of parameters
print(f"number of parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6:.2f}M")
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, data: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
x_in = self.input_adapter(data, t)
x = self.transformer.drop(x_in)
for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
if self.skip:
x = torch.cat([x, x_in], -1)
logits = self.output_adapter(self.lm_head(x)) if self.output_adapter else self.lm_head(x)
return logits
def get_optim_groups(self, weight_decay: float):
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear,)
blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = "%s.%s" % (mn, pn) if mn else pn # full param name
# random note: because named_modules and named_parameters are recursive
# we will see the same tensors p many many times. but doing it this way
# allows us to know which parent module any tensor p belongs to...
if pn.endswith("bias"):
# all biases will not be decayed
no_decay.add(fpn)
elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# We don't use weight tying so comment this out
# decay.remove('lm_head.weight')
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
assert (
len(param_dict.keys() - union_params) == 0
), "parameters %s were not separated into either decay/no_decay set!" % (str(param_dict.keys() - union_params),)
# create the pytorch optimizer groups
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
return optim_groups
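The two groups returned above plug directly into a PyTorch optimizer. A minimal sketch of how they are typically consumed (the `nn.Linear` stand-in and hyperparameters here are illustrative, not taken from this repo's training config):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # stand-in for the transformer above
optim_groups = [
    {"params": [model.weight], "weight_decay": 0.1},  # decayed weights
    {"params": [model.bias], "weight_decay": 0.0},    # biases: never decayed
]
optimizer = torch.optim.AdamW(optim_groups, lr=1e-4, betas=(0.9, 0.95))
loss = model(torch.randn(2, 8)).pow(2).mean()
loss.backward()
optimizer.step()  # weight decay applied only to the first group
```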
================================================
FILE: networks/unet_improved.py
================================================
# Source: https://github.com/openai/improved-diffusion
#
# MIT License
#
# Copyright (c) 2021 OpenAI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Modifications:
# - Added data_adapters to UNetModel to preprocess the inputs and postprocess the outputs
# - Added the `skip` option to concat the input and output of the network before the final projection
# - Replaced `timesteps` argument of `UNetModel.forward()` with time `t`, which is used to compute the `timesteps`
from abc import abstractmethod
import math
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from utils_model import sandwich
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
"""
Helpers to train with 16-bit precision.
"""
def convert_module_to_f16(module):
"""
Convert primitive modules to float16.
"""
if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
module.weight.data = module.weight.data.half()
module.bias.data = module.bias.data.half()
def convert_module_to_f32(module):
"""
Convert primitive modules to float32, undoing convert_module_to_f16().
"""
if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
module.weight.data = module.weight.data.float()
module.bias.data = module.bias.data.float()
def make_master_params(model_params):
"""
Copy model parameters into a (differently-shaped) list of full-precision
parameters.
"""
master_params = _flatten_dense_tensors([param.detach().float() for param in model_params])
master_params = nn.Parameter(master_params)
master_params.requires_grad = True
return [master_params]
def model_grads_to_master_grads(model_params, master_params):
"""
Copy the gradients from the model parameters into the master parameters
from make_master_params().
"""
master_params[0].grad = _flatten_dense_tensors([param.grad.data.detach().float() for param in model_params])
def master_params_to_model_params(model_params, master_params):
"""
Copy the master parameter data back into the model parameters.
"""
# Without copying to a list, if a generator is passed, this will
# silently not copy any parameters.
model_params = list(model_params)
for param, master_param in zip(model_params, unflatten_master_params(model_params, master_params)):
param.detach().copy_(master_param)
def unflatten_master_params(model_params, master_params):
"""
Unflatten the master parameters to look like model_params.
"""
return _unflatten_dense_tensors(master_params[0].detach(), model_params)
def zero_grad(model_params):
for param in model_params:
# Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
if param.grad is not None:
param.grad.detach_()
param.grad.zero_()
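The helpers above implement a "master params" loop for fp16 training: keep one flat fp32 copy of the weights, step the optimizer on it, and copy back. A self-contained sketch of that cycle (restating the helpers inline so the snippet runs standalone; real code would call the functions above):

```python
import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

model = nn.Linear(4, 4)
model_params = list(model.parameters())
# make_master_params: one flat fp32 copy of all (possibly fp16) weights
master = nn.Parameter(_flatten_dense_tensors([p.detach().float() for p in model_params]))
master.requires_grad = True

loss = model(torch.randn(2, 4)).sum()
loss.backward()
# model_grads_to_master_grads: flatten the model grads into the master grad
master.grad = _flatten_dense_tensors([p.grad.detach().float() for p in model_params])
# (an fp32 optimizer would step on [master] here)
# master_params_to_model_params: copy the updated master values back
for p, m in zip(model_params, _unflatten_dense_tensors(master.detach(), model_params)):
    p.detach().copy_(m)
```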
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * th.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
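A one-step numerical illustration of the EMA update above: each call moves the target a fraction `(1 - rate)` of the way toward the source.

```python
import torch

target = [torch.zeros(3)]
source = [torch.ones(3)]
rate = 0.9
for targ, src in zip(target, source):
    targ.detach().mul_(rate).add_(src, alpha=1 - rate)
# after one update: 0 * 0.9 + 1 * 0.1 = 0.1 everywhere
```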
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
if dim % 2:
embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
return embedding
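A quick shape and value check of the sinusoidal embedding above (the function is restated inline, for the even-`dim` case, so the snippet is self-contained):

```python
import math
import torch as th

def _timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = th.exp(-math.log(max_period) * th.arange(half, dtype=th.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return th.cat([th.cos(args), th.sin(args)], dim=-1)

emb = _timestep_embedding(th.tensor([0.0, 250.0, 999.0]), 128)
# emb has shape [N, dim]; at t=0 the cosine half is all ones, the sine half zeros
```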
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(th.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with th.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with th.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = th.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
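The `checkpoint`/`CheckpointFunction` pair above trades compute for memory: the forward pass caches no activations, and backward reruns the function under `enable_grad`. The same trade is available through PyTorch's built-in `torch.utils.checkpoint` (sketch; the `use_reentrant=False` flag assumes PyTorch >= 1.11):

```python
import torch as th
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as th_checkpoint

net = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
x = th.randn(2, 8, requires_grad=True)
# no activations cached in forward; net is re-evaluated during backward
y = th_checkpoint(net, x, use_reentrant=False)
y.sum().backward()
```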
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2):
super().__init__()
self.channels = channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, channels, channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2):
super().__init__()
self.channels = channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1)
else:
# Note: upstream improved-diffusion calls avg_pool_nd(stride), which
# misroutes `stride` into the `dims` argument; pass both explicitly.
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
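A shape check of the two resampling paths above for the 2D `use_conv=True` case: a stride-2 3x3 convolution halves the spatial dimensions, and nearest-neighbor interpolation (the `Upsample` path) doubles them back.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 8, 16, 16)
down = nn.Conv2d(8, 8, 3, stride=2, padding=1)        # Downsample, use_conv=True
y = down(x)                                           # spatial dims halve
z = F.interpolate(y, scale_factor=2, mode="nearest")  # Upsample path
```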
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels),
SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.emb_layers = nn.Sequential(
SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
SiLU(),
nn.Dropout(p=dropout),
zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)
def _forward(self, x, emb):
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(self, channels, num_heads=1, use_checkpoint=False):
super().__init__()
self.channels = channels
self.num_heads = num_heads
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
self.attention = QKVAttention()
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
h = self.attention(qkv)
h = h.reshape(b, -1, h.shape[-1])
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention.
"""
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
:return: an [N x C x T] tensor after attention.
"""
ch = qkv.shape[1] // 3
q, k, v = th.split(qkv, ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
return th.einsum("bts,bcs->bct", weight, v)
@staticmethod
def count_flops(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
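The attention math above, restated standalone as a shape check: an `[N x 3C x T]` QKV tensor maps to an `[N x C x T]` output, with each attention row summing to one after the softmax.

```python
import math
import torch as th

N, C, T = 2, 8, 5
qkv = th.randn(N, 3 * C, T)
q, k, v = th.split(qkv, C, dim=1)
scale = 1 / math.sqrt(math.sqrt(C))  # ch**-0.25 applied to both q and k
weight = th.einsum("bct,bcs->bts", q * scale, k * scale)   # (N, T, T) logits
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
out = th.einsum("bts,bcs->bct", weight, v)                 # (N, C, T)
```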
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
"""
def __init__(
self,
data_adapters,
image_size=32,
in_channels=3,
model_channels=128,
out_channels=128,
num_res_blocks=3,
attention_resolutions=[8, 16],
dropout=0,
channel_mult=(1, 2, 2, 2),
conv_resample=True,
dims=2,
skip=True,
num_classes=None,
use_checkpoint=False,
num_heads=4,
num_heads_upsample=-1,
use_scale_shift_norm=False,
project_input=False,
):
super().__init__()
self.input_adapter = data_adapters["input_adapter"]
self.output_adapter = data_adapters["output_adapter"]
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.num_heads = num_heads
self.num_heads_upsample = num_heads_upsample
self.skip = skip
self.project_input = project_input
if project_input:
self.input_projection = nn.Linear(self.in_channels, self.model_channels)
in_channels = self.model_channels
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList(
[TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
)
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads))
self.input_blocks.append(TimestepEmbedSequential(*layers))
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
self.input_blocks.append(TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims)))
input_block_chans.append(ch)
ds *= 2
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
layers = [
ResBlock(
ch + input_block_chans.pop(),
time_embed_dim,
dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = model_channels * mult
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
)
)
if level and i == num_res_blocks:
layers.append(Upsample(ch, conv_resample, dims=dims))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self.out = nn.Sequential(
normalization(ch),
SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
@property
def inner_dtype(self):
"""
Get the dtype used by the torso of the model.
"""
return next(self.input_blocks.parameters()).dtype
def forward(
self,
data: th.Tensor,
t: th.Tensor,
) -> th.Tensor:
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
y = None
flat_x = self.input_adapter(data, t)
x = flat_x.reshape(flat_x.size(0), self.image_size, self.image_size, self.in_channels)
if self.project_input:
x = self.input_projection(x)
x_perm = x.permute(0, 3, 1, 2).contiguous()
timesteps = t.flatten(start_dim=1)[:, 0] * 4000
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
h = x_perm.type(self.inner_dtype)
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
h = self.middle_block(h, emb)
for module in self.output_blocks:
cat_in = th.cat([h, hs.pop()], dim=1)
h = module(cat_in, emb)
h = h.type(x.dtype)
out = sandwich(self.out(h).permute(0, 2, 3, 1).contiguous())
if self.skip:
out = th.cat([sandwich(x), out], -1)
out = self.output_adapter(out)
return out
def get_feature_vectors(self, x, timesteps, y=None):
"""
Apply the model and return all of the intermediate tensors.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: a dict with the following keys:
- 'down': a list of hidden state tensors from downsampling.
- 'middle': the tensor of the output of the lowest-resolution
block in the model.
- 'up': a list of hidden state tensors from upsampling.
"""
hs = []
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
result = dict(down=[], up=[])
h = x.type(self.inner_dtype)
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
result["down"].append(h.type(x.dtype))
h = self.middle_block(h, emb)
result["middle"] = h.type(x.dtype)
for module in self.output_blocks:
cat_in = th.cat([h, hs.pop()], dim=1)
h = module(cat_in, emb)
result["up"].append(h.type(x.dtype))
return result
================================================
FILE: networks/unet_vdm.py
================================================
# Source: https://github.com/addtt/variational-diffusion-models
#
# MIT License
#
# Copyright (c) 2022 Andrea Dittadi
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Modifications:
# - Added data_adapters to UNetVDM to preprocess the inputs and postprocess the outputs
# - Replaced `timesteps` argument of `UNetVDM.forward()` with time `t`, which is used to compute the `timesteps`
# - Added 1/1000 to t before computing timesteps embeddings so t isn't 0
# - Added concatenation of input and output of the network before the final projection
import numpy as np
import torch
from torch import einsum, nn, pi, softmax
from utils_model import sandwich
@torch.no_grad()
def zero_init(module: nn.Module) -> nn.Module:
"""Sets to zero all the parameters of a module, and returns the module."""
for p in module.parameters():
nn.init.zeros_(p.data)
return module
class UNetVDM(nn.Module):
def __init__(
self,
data_adapters,
embedding_dim: int = 128,
n_blocks: int = 32,
n_attention_heads: int = 1,
dropout_prob: float = 0.1,
norm_groups: int = 32,
input_channels: int = 3,
use_fourier_features: bool = True,
attention_everywhere: bool = False,
image_size: int = 32,
):
super().__init__()
self.input_adapter = data_adapters["input_adapter"]
self.output_adapter = data_adapters["output_adapter"]
attention_params = dict(
n_heads=n_attention_heads,
n_channels=embedding_dim,
norm_groups=norm_groups,
)
resnet_params = dict(
ch_in=embedding_dim,
ch_out=embedding_dim,
condition_dim=4 * embedding_dim,
dropout_prob=dropout_prob,
norm_groups=norm_groups,
)
if use_fourier_features:
self.fourier_features = FourierFeatures()
self.embed_conditioning = nn.Sequential(
nn.Linear(embedding_dim, embedding_dim * 4),
nn.SiLU(),
nn.Linear(embedding_dim * 4, embedding_dim * 4),
nn.SiLU(),
)
total_input_ch = input_channels
if use_fourier_features:
total_input_ch *= 1 + self.fourier_features.num_features
self.conv_in = nn.Conv2d(total_input_ch, embedding_dim, 3, padding=1)
# Down path: n_blocks blocks with a resnet block and maybe attention.
self.down_blocks = nn.ModuleList(
UpDownBlock(
resnet_block=ResnetBlock(**resnet_params),
attention_block=AttentionBlock(**attention_params) if attention_everywhere else None,
)
for _ in range(n_blocks)
)
self.mid_resnet_block_1 = ResnetBlock(**resnet_params)
self.mid_attn_block = AttentionBlock(**attention_params)
self.mid_resnet_block_2 = ResnetBlock(**resnet_params)
# Up path: n_blocks+1 blocks with a resnet block and maybe attention.
resnet_params["ch_in"] *= 2 # double input channels due to skip connections
self.up_blocks = nn.ModuleList(
UpDownBlock(
resnet_block=ResnetBlock(**resnet_params),
attention_block=AttentionBlock(**attention_params) if attention_everywhere else None,
)
for _ in range(n_blocks + 1)
)
self.conv_out = nn.Sequential(
nn.GroupNorm(num_groups=norm_groups, num_channels=embedding_dim),
nn.SiLU(),
zero_init(nn.Conv2d(embedding_dim, embedding_dim, 3, padding=1)),
)
self.embedding_dim = embedding_dim
self.input_channels = input_channels
self.image_size = image_size
self.use_fourier_features = use_fourier_features
def forward(
self,
data: torch.Tensor,
t: torch.Tensor,
) -> torch.Tensor:
flat_x = self.input_adapter(data, t)
x = flat_x.reshape(flat_x.size(0), self.image_size, self.image_size, self.input_channels)
x_perm = x.permute(0, 3, 1, 2).contiguous()
t = t.float().flatten(start_dim=1)[:, 0]
t_embedding = get_timestep_embedding(t + 0.001, self.embedding_dim)
# We will condition on time embedding.
cond = self.embed_conditioning(t_embedding)
h = self.maybe_concat_fourier(x_perm)
h = self.conv_in(h) # (B, embedding_dim, H, W)
hs = []
for down_block in self.down_blocks: # n_blocks times
hs.append(h)
h = down_block(h, cond)
hs.append(h)
h = self.mid_resnet_block_1(h, cond)
h = self.mid_attn_block(h)
h = self.mid_resnet_block_2(h, cond)
for up_block in self.up_blocks: # n_blocks+1 times
h = torch.cat([h, hs.pop()], dim=1)
h = up_block(h, cond)
out = sandwich(self.conv_out(h).permute(0, 2, 3, 1).contiguous())
out = torch.cat([sandwich(x), out], -1)
out = self.output_adapter(out)
return out
def maybe_concat_fourier(self, z):
if self.use_fourier_features:
return torch.cat([z, self.fourier_features(z)], dim=1)
return z
class ResnetBlock(nn.Module):
def __init__(
self,
ch_in,
ch_out=None,
condition_dim=None,
dropout_prob=0.0,
norm_groups=32,
):
super().__init__()
ch_out = ch_in if ch_out is None else ch_out
self.ch_out = ch_out
self.condition_dim = condition_dim
self.net1 = nn.Sequential(
nn.GroupNorm(num_groups=norm_groups, num_channels=ch_in),
nn.SiLU(),
nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1),
)
if condition_dim is not None:
self.cond_proj = zero_init(nn.Linear(condition_dim, ch_out, bias=False))
self.net2 = nn.Sequential(
nn.GroupNorm(num_groups=norm_groups, num_channels=ch_out),
nn.SiLU(),
nn.Dropout(dropout_prob),
zero_init(nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1)),
)
if ch_in != ch_out:
self.skip_conv = nn.Conv2d(ch_in, ch_out, kernel_size=1)
def forward(self, x, condition):
h = self.net1(x)
if condition is not None:
assert condition.shape == (x.shape[0], self.condition_dim)
condition = self.cond_proj(condition)
condition = condition[:, :, None, None]
h = h + condition
h = self.net2(h)
if x.shape[1] != self.ch_out:
x = self.skip_conv(x)
assert x.shape == h.shape
return x + h
def get_timestep_embedding(
timesteps,
embedding_dim: int,
dtype=torch.float32,
max_timescale=10_000,
min_timescale=1,
):
# Adapted from tensor2tensor and VDM codebase.
assert timesteps.ndim == 1
assert embedding_dim % 2 == 0
timesteps = timesteps * 1000.0  # In DDPM the time step is in [0, 1000], here [0, 1]; avoid mutating the caller's tensor
num_timescales = embedding_dim // 2
inv_timescales = torch.logspace( # or exp(-linspace(log(min), log(max), n))
-np.log10(min_timescale),
-np.log10(max_timescale),
num_timescales,
device=timesteps.device,
)
emb = timesteps.to(dtype)[:, None] * inv_timescales[None, :] # (T, D/2)
return torch.cat([emb.sin(), emb.cos()], dim=1) # (T, D)
class FourierFeatures(nn.Module):
def __init__(self, first=5.0, last=6.0, step=1.0):
super().__init__()
self.freqs_exponent = torch.arange(first, last + 1e-8, step)
@property
def num_features(self):
return len(self.freqs_exponent) * 2
def forward(self, x):
assert len(x.shape) >= 2
# Compute (2pi * 2^n) for n in freqs.
freqs_exponent = self.freqs_exponent.to(dtype=x.dtype, device=x.device) # (F, )
freqs = 2.0**freqs_exponent * 2 * pi # (F, )
freqs = freqs.view(-1, *([1] * (x.dim() - 1))) # (F, 1, 1, ...)
# Compute (2pi * 2^n * x) for n in freqs.
features = freqs * x.unsqueeze(1) # (B, F, X1, X2, ...)
features = features.flatten(1, 2) # (B, F * C, X1, X2, ...)
# Output features are cos and sin of above. Shape (B, 2 * F * C, H, W).
return torch.cat([features.sin(), features.cos()], dim=1)
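A shape check of the Fourier features above, restated inline with the default exponents 5..6 (so F = 2 frequencies): a `(B, C, H, W)` input yields `(B, 2*F*C, H, W)` features.

```python
import math
import torch

x = torch.randn(2, 3, 4, 4)                        # (B, C, H, W)
freqs_exponent = torch.arange(5.0, 6.0 + 1e-8)     # defaults first=5, last=6 -> F=2
freqs = 2.0**freqs_exponent * 2 * math.pi          # (F,)
freqs = freqs.view(-1, 1, 1, 1)                    # broadcast over (C, H, W)
features = (freqs * x.unsqueeze(1)).flatten(1, 2)  # (B, F*C, H, W)
feats = torch.cat([features.sin(), features.cos()], dim=1)  # (B, 2*F*C, H, W)
```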
def attention_inner_heads(qkv, num_heads):
"""Computes attention with heads inside of qkv in the channel dimension.
Args:
qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where:
H = number of heads,
C = number of channels per head.
num_heads: number of heads.
Returns:
Attention output of shape (B, H*C, T).
"""
bs, width, length = qkv.shape
ch = width // (3 * num_heads)
# Split into (q, k, v) of shape (B, H*C, T).
q, k, v = qkv.chunk(3, dim=1)
# Rescale q and k: ch**-0.25 on each is equivalent to dividing the attention
# logits by sqrt(ch); the multiply also makes them contiguous in memory.
scale = ch ** (-1 / 4)
q = q * scale
k = k * scale
# Reshape qkv to (B*H, C, T).
new_shape = (bs * num_heads, ch, length)
q = q.view(*new_shape)
k = k.view(*new_shape)
v = v.reshape(*new_shape)
# Compute attention.
weight = einsum("bct,bcs->bts", q, k) # (B*H, T, T)
weight = softmax(weight.float(), dim=-1).to(weight.dtype) # (B*H, T, T)
out = einsum("bts,bcs->bct", weight, v) # (B*H, C, T)
return out.reshape(bs, num_heads * ch, length) # (B, H*C, T)
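The head bookkeeping above, restated standalone as a shape check: with heads folded into the channel dimension, `(B, 3*H*C, T)` maps to `(B, H*C, T)`.

```python
import torch
from torch import einsum, softmax

B, H, C, T = 2, 4, 8, 5
qkv = torch.randn(B, 3 * H * C, T)
q, k, v = qkv.chunk(3, dim=1)           # each (B, H*C, T)
scale = C ** (-1 / 4)
q = (q * scale).view(B * H, C, T)       # the multiply makes q contiguous
k = (k * scale).view(B * H, C, T)
v = v.reshape(B * H, C, T)              # reshape handles the non-contiguous chunk
weight = softmax(einsum("bct,bcs->bts", q, k).float(), dim=-1)  # (B*H, T, T)
out = einsum("bts,bcs->bct", weight, v).reshape(B, H * C, T)
```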
class Attention(nn.Module):
"""Based on https://github.com/openai/guided-diffusion."""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
assert qkv.dim() >= 3, qkv.dim()
assert qkv.shape[1] % (3 * self.n_heads) == 0
spatial_dims = qkv.shape[2:]
qkv = qkv.view(*qkv.shape[:2], -1) # (B, 3*H*C, T)
out = attention_inner_heads(qkv, self.n_heads) # (B, H*C, T)
return out.view(*out.shape[:2], *spatial_dims).contiguous()
class AttentionBlock(nn.Module):
"""Self-attention residual block."""
def __init__(self, n_heads, n_channels, norm_groups):
super().__init__()
assert n_channels % n_heads == 0
self.layers = nn.Sequential(
nn.GroupNorm(num_groups=norm_groups, num_channels=n_channels),
nn.Conv2d(n_channels, 3 * n_channels, kernel_size=1), # (B, 3 * C, H, W)
Attention(n_heads),
zero_init(nn.Conv2d(n_channels, n_channels, kernel_size=1)),
)
def forward(self, x):
return self.layers(x) + x
class UpDownBlock(nn.Module):
def __init__(self, resnet_block, attention_block=None):
super().__init__()
self.resnet_block = resnet_block
self.attention_block = attention_block
def forward(self, x, cond):
x = self.resnet_block(x, cond)
if self.attention_block is not None:
x = self.attention_block(x)
return x
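The `FourierFeatures` module earlier in this file doubles the channel count per frequency. A functional sketch with hypothetical input sizes (the defaults `first=5.0, last=6.0` give F=2 exponents, so a C-channel input yields 2·F·C output channels):

```python
import torch
from math import pi

# Functional sketch of FourierFeatures: frequencies 2^n * 2*pi for
# n in {first, ..., last}, with sin and cos concatenated over channels.
def fourier_features(x, first=5.0, last=6.0, step=1.0):
    exponents = torch.arange(first, last + 1e-8, step)      # (F,)
    freqs = (2.0 ** exponents) * 2 * pi                     # (F,)
    freqs = freqs.view(-1, *([1] * (x.dim() - 1)))          # (F, 1, 1, ...)
    feats = (freqs * x.unsqueeze(1)).flatten(1, 2)          # (B, F*C, H, W)
    return torch.cat([feats.sin(), feats.cos()], dim=1)     # (B, 2*F*C, H, W)

x = torch.randn(2, 3, 8, 8)  # F = 2 exponents (5.0 and 6.0)
print(fourier_features(x).shape)  # torch.Size([2, 12, 8, 8])
```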
================================================
FILE: probability.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import functools
from abc import abstractmethod
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical as torch_Categorical
from torch.distributions.bernoulli import Bernoulli as torch_Bernoulli
from torch.distributions.mixture_same_family import MixtureSameFamily
from torch.distributions.uniform import Uniform
from math import log
from utils_model import (
safe_exp,
safe_log,
idx_to_float,
float_to_idx,
quantize, sandwich,
)
class CtsDistribution:
@abstractmethod
def log_prob(self, x):
pass
@abstractmethod
def sample(self):
pass
class DiscreteDistribution:
@property
@abstractmethod
def probs(self):
pass
@functools.cached_property
def log_probs(self):
return safe_log(self.probs)
@functools.cached_property
def mean(self):
pass
@functools.cached_property
def mode(self):
pass
@abstractmethod
def log_prob(self, x):
pass
@abstractmethod
def sample(self):
pass
class DiscretizedDistribution(DiscreteDistribution):
def __init__(self, num_bins, device):
self.num_bins = num_bins
self.bin_width = 2.0 / num_bins
self.half_bin_width = self.bin_width / 2.0
self.device = device
@functools.cached_property
def class_centres(self):
return torch.arange(self.half_bin_width - 1, 1, self.bin_width, device=self.device)
@functools.cached_property
def class_boundaries(self):
return torch.arange(self.bin_width - 1, 1 - self.half_bin_width, self.bin_width, device=self.device)
@functools.cached_property
def mean(self):
return (self.probs * self.class_centres).sum(-1)
@functools.cached_property
def mode(self):
mode_idx = self.probs.argmax(-1).flatten()
return self.class_centres[mode_idx].reshape(self.probs.shape[:-1])
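For concreteness, a small sketch of the bin geometry above: data is scaled to [-1, 1], so with a hypothetical `num_bins=4` the class centres sit at the middle of each bin and the boundaries are the three interior edges:

```python
import torch

# Bin geometry for data in [-1, 1] with num_bins = 4 (bin_width = 0.5).
num_bins = 4
bin_width = 2.0 / num_bins
half_bin_width = bin_width / 2.0
centres = torch.arange(half_bin_width - 1, 1, bin_width)
boundaries = torch.arange(bin_width - 1, 1 - half_bin_width, bin_width)
print(centres)     # tensor([-0.7500, -0.2500,  0.2500,  0.7500])
print(boundaries)  # tensor([-0.5000,  0.0000,  0.5000])
```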
class DiscretizedCtsDistribution(DiscretizedDistribution):
def __init__(self, cts_dist, num_bins, device, batch_dims, clip=True, min_prob=1e-5):
super().__init__(num_bins, device)
self.cts_dist = cts_dist
self.log_bin_width = log(self.bin_width)
self.batch_dims = batch_dims
self.clip = clip
self.min_prob = min_prob
@functools.cached_property
def probs(self):
bdry_cdfs = self.cts_dist.cdf(self.class_boundaries.reshape([-1] + ([1] * self.batch_dims)))
bdry_slice = bdry_cdfs[:1]
if self.clip:
cdf_min = torch.zeros_like(bdry_slice)
cdf_max = torch.ones_like(bdry_slice)
bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0)
return (bdry_cdfs[1:] - bdry_cdfs[:-1]).moveaxis(0, -1)
else:
cdf_min = self.cts_dist.cdf(torch.zeros_like(bdry_slice) - 1)
cdf_max = self.cts_dist.cdf(torch.ones_like(bdry_slice))
bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0)
cdf_range = cdf_max - cdf_min
cdf_mask = cdf_range < self.min_prob
cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range)
probs = (bdry_cdfs[1:] - bdry_cdfs[:-1]) / cdf_range
probs = torch.where(cdf_mask, (probs * 0) + (1 / self.num_bins), probs)
return probs.moveaxis(0, -1)
def prob(self, x):
class_idx = float_to_idx(x, self.num_bins)
centre = idx_to_float(class_idx, self.num_bins)
cdf_lo = self.cts_dist.cdf(centre - self.half_bin_width)
cdf_hi = self.cts_dist.cdf(centre + self.half_bin_width)
if self.clip:
cdf_lo = torch.where(class_idx <= 0, torch.zeros_like(centre), cdf_lo)
cdf_hi = torch.where(class_idx >= (self.num_bins - 1), torch.ones_like(centre), cdf_hi)
return cdf_hi - cdf_lo
else:
cdf_min = self.cts_dist.cdf(torch.zeros_like(centre) - 1)
cdf_max = self.cts_dist.cdf(torch.ones_like(centre))
cdf_range = cdf_max - cdf_min
cdf_mask = cdf_range < self.min_prob
cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range)
prob = (cdf_hi - cdf_lo) / cdf_range
return torch.where(cdf_mask, (prob * 0) + (1 / self.num_bins), prob)
def log_prob(self, x):
prob = self.prob(x)
return torch.where(
prob < self.min_prob,
self.cts_dist.log_prob(quantize(x, self.num_bins)) + self.log_bin_width,
safe_log(prob),
)
def sample(self, sample_shape=torch.Size([])):
if self.clip:
return quantize(self.cts_dist.sample(sample_shape), self.num_bins)
else:
assert hasattr(self.cts_dist, "icdf")
cdf_min = self.cts_dist.cdf(torch.zeros_like(self.cts_dist.mean) - 1)
cdf_max = self.cts_dist.cdf(torch.ones_like(cdf_min))
u = Uniform(cdf_min, cdf_max, validate_args=False).sample(sample_shape)
cts_samp = self.cts_dist.icdf(u)
return quantize(cts_samp, self.num_bins)
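A minimal sketch of the clipped discretization performed in `probs` above: CDF differences at the interior boundaries, with the outer CDF values pinned to 0 and 1 so any tail mass is assigned to the edge bins (Normal parameters here are hypothetical):

```python
import torch
from torch.distributions.normal import Normal

num_bins = 4
bin_width = 2.0 / num_bins
boundaries = torch.arange(bin_width - 1, 1 - bin_width / 2, bin_width)  # interior edges
cts_dist = Normal(0.0, 0.5)
# Pin the outer CDFs to 0 and 1, then difference to get one probability per bin.
bdry_cdfs = torch.cat([torch.zeros(1), cts_dist.cdf(boundaries), torch.ones(1)])
probs = bdry_cdfs[1:] - bdry_cdfs[:-1]
print(float(probs.sum()))  # sums to 1 by construction
```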
class GMM(MixtureSameFamily):
def __init__(self, mix_wt_logits, means, std_devs):
mix_wts = torch_Categorical(logits=mix_wt_logits, validate_args=False)
components = Normal(means, std_devs, validate_args=False)
super().__init__(mix_wts, components, validate_args=False)
class DiscretizedGMM(DiscretizedCtsDistribution):
def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):
assert params.size(-1) % 3 == 0
if min_std_dev < 0:
min_std_dev = 1.0 / (num_bins * 5)
mix_wt_logits, means, std_devs = params.chunk(3, -1)
if log_dev:
std_devs = safe_exp(std_devs)
std_devs = std_devs.clamp(min=min_std_dev, max=max_std_dev)
super().__init__(
cts_dist=GMM(mix_wt_logits, means, std_devs),
num_bins=num_bins,
device=params.device,
batch_dims=params.ndim - 1,
clip=clip,
min_prob=min_prob,
)
class DiscretizedNormal(DiscretizedCtsDistribution):
def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):
assert params.size(-1) == 2
if min_std_dev < 0:
min_std_dev = 1.0 / (num_bins * 5)
mean, std_dev = params.split(1, -1)[:2]
if log_dev:
std_dev = safe_exp(std_dev)
std_dev = std_dev.clamp(min=min_std_dev, max=max_std_dev)
super().__init__(
cts_dist=Normal(mean.squeeze(-1), std_dev.squeeze(-1), validate_args=False),
num_bins=num_bins,
device=params.device,
batch_dims=params.ndim - 1,
clip=clip,
min_prob=min_prob,
)
class Bernoulli(DiscreteDistribution):
def __init__(self, logits):
self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False)
@functools.cached_property
def probs(self):
p = self.bernoulli.probs.unsqueeze(-1)
return torch.cat([1 - p, p], -1)
@functools.cached_property
def mode(self):
return self.bernoulli.mode
def log_prob(self, x):
return self.bernoulli.log_prob(x.float())
def sample(self, sample_shape=torch.Size([])):
return self.bernoulli.sample(sample_shape)
class DiscretizedBernoulli(DiscretizedDistribution):
def __init__(self, logits):
super().__init__(2, logits.device)
self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False)
@functools.cached_property
def probs(self):
p = self.bernoulli.probs.unsqueeze(-1)
return torch.cat([1 - p, p], -1)
@functools.cached_property
def mode(self):
return idx_to_float(self.bernoulli.mode, 2)
def log_prob(self, x):
return self.bernoulli.log_prob(float_to_idx(x, 2).float())
def sample(self, sample_shape=torch.Size([])):
return idx_to_float(self.bernoulli.sample(sample_shape), 2)
class DeltaDistribution(CtsDistribution):
def __init__(self, mean, clip_range=1.0):
if clip_range > 0:
mean = mean.clip(min=-clip_range, max=clip_range)
self.mean = mean
@functools.cached_property
def mode(self):
return self.mean
def sample(self, sample_shape=torch.Size([])):
return self.mean
class Categorical(DiscreteDistribution):
def __init__(self, logits):
self.categorical = torch_Categorical(logits=logits, validate_args=False)
self.n_classes = logits.size(-1)
@functools.cached_property
def probs(self):
return self.categorical.probs
@functools.cached_property
def mode(self):
return self.categorical.mode
def log_prob(self, x):
return self.categorical.log_prob(x)
def sample(self, sample_shape=torch.Size([])):
return self.categorical.sample(sample_shape)
class DiscretizedCategorical(DiscretizedDistribution):
def __init__(self, logits=None, probs=None):
assert (logits is not None) or (probs is not None)
if logits is not None:
super().__init__(logits.size(-1), logits.device)
self.categorical = torch_Categorical(logits=logits, validate_args=False)
else:
super().__init__(probs.size(-1), probs.device)
self.categorical = torch_Categorical(probs=probs, validate_args=False)
@functools.cached_property
def probs(self):
return self.categorical.probs
@functools.cached_property
def mode(self):
return idx_to_float(self.categorical.mode, self.num_bins)
def log_prob(self, x):
return self.categorical.log_prob(float_to_idx(x, self.num_bins))
def sample(self, sample_shape=torch.Size([])):
return idx_to_float(self.categorical.sample(sample_shape), self.num_bins)
class CtsDistributionFactory:
@abstractmethod
def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> CtsDistribution:
"""Note: input_params and t are not used but kept here to be consistency with DiscreteDistributionFactory."""
pass
class GMMFactory(CtsDistributionFactory):
def __init__(self, min_std_dev=1e-3, max_std_dev=10, log_dev=True):
self.min_std_dev = min_std_dev
self.max_std_dev = max_std_dev
self.log_dev = log_dev
def get_dist(self, params, input_params=None, t=None):
mix_wt_logits, means, std_devs = params.chunk(3, -1)
if self.log_dev:
std_devs = safe_exp(std_devs)
std_devs = std_devs.clamp(min=self.min_std_dev, max=self.max_std_dev)
return GMM(mix_wt_logits, means, std_devs)
class NormalFactory(CtsDistributionFactory):
def __init__(self, min_std_dev=1e-3, max_std_dev=10):
self.min_std_dev = min_std_dev
self.max_std_dev = max_std_dev
def get_dist(self, params, input_params=None, t=None):
mean, log_std_dev = params.split(1, -1)[:2]
std_dev = safe_exp(log_std_dev).clamp(min=self.min_std_dev, max=self.max_std_dev)
return Normal(mean.squeeze(-1), std_dev.squeeze(-1), validate_args=False)
class DeltaFactory(CtsDistributionFactory):
def __init__(self, clip_range=1.0):
self.clip_range = clip_range
def get_dist(self, params, input_params=None, t=None):
return DeltaDistribution(params.squeeze(-1), self.clip_range)
class DiscreteDistributionFactory:
@abstractmethod
def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> DiscreteDistribution:
"""Note: input_params and t are only required by PredDistToDataDistFactory."""
pass
class BernoulliFactory(DiscreteDistributionFactory):
def get_dist(self, params, input_params=None, t=None):
return Bernoulli(logits=params.squeeze(-1))
class CategoricalFactory(DiscreteDistributionFactory):
def get_dist(self, params, input_params=None, t=None):
return Categorical(logits=params)
class DiscretizedBernoulliFactory(DiscreteDistributionFactory):
def get_dist(self, params, input_params=None, t=None):
return DiscretizedBernoulli(logits=params.squeeze(-1))
class DiscretizedCategoricalFactory(DiscreteDistributionFactory):
def get_dist(self, params, input_params=None, t=None):
return DiscretizedCategorical(logits=params)
class DiscretizedGMMFactory(DiscreteDistributionFactory):
def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):
self.num_bins = num_bins
self.clip = clip
self.min_std_dev = min_std_dev
self.max_std_dev = max_std_dev
self.min_prob = min_prob
self.log_dev = log_dev
def get_dist(self, params, input_params=None, t=None):
return DiscretizedGMM(
params,
num_bins=self.num_bins,
clip=self.clip,
min_std_dev=self.min_std_dev,
max_std_dev=self.max_std_dev,
min_prob=self.min_prob,
log_dev=self.log_dev,
)
class DiscretizedNormalFactory(DiscreteDistributionFactory):
def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):
self.num_bins = num_bins
self.clip = clip
self.min_std_dev = min_std_dev
self.max_std_dev = max_std_dev
self.min_prob = min_prob
self.log_dev = log_dev
def get_dist(self, params, input_params=None, t=None):
return DiscretizedNormal(
params,
num_bins=self.num_bins,
clip=self.clip,
min_std_dev=self.min_std_dev,
max_std_dev=self.max_std_dev,
min_prob=self.min_prob,
log_dev=self.log_dev,
)
def noise_pred_params_to_data_pred_params(noise_pred_params: torch.Tensor, input_mean: torch.Tensor, t: torch.Tensor, min_variance: float, min_t=1e-6):
"""Convert output parameters that predict the noise added to data, to parameters that predict the data."""
data_shape = list(noise_pred_params.shape)[:-1]
noise_pred_params = sandwich(noise_pred_params)
input_mean = input_mean.flatten(start_dim=1)
if torch.is_tensor(t):
t = t.flatten(start_dim=1)
else:
t = (input_mean * 0) + t
alpha_mask = (t < min_t).unsqueeze(-1)
posterior_var = torch.pow(min_variance, t.clamp(min=min_t))
gamma = 1 - posterior_var
A = (input_mean / gamma).unsqueeze(-1)
B = (posterior_var / gamma).sqrt().unsqueeze(-1)
data_pred_params = []
if noise_pred_params.size(-1) == 1:
noise_pred_mean = noise_pred_params
elif noise_pred_params.size(-1) == 2:
noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(2, -1)
else:
assert noise_pred_params.size(-1) % 3 == 0
mix_wt_logits, noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(3, -1)
data_pred_params.append(mix_wt_logits)
data_pred_mean = A - (B * noise_pred_mean)
data_pred_mean = torch.where(alpha_mask, 0 * data_pred_mean, data_pred_mean)
data_pred_params.append(data_pred_mean)
if noise_pred_params.size(-1) >= 2:
noise_pred_dev = safe_exp(noise_pred_log_dev)
data_pred_dev = B * noise_pred_dev
data_pred_dev = torch.where(alpha_mask, 1 + (0 * data_pred_dev), data_pred_dev)
data_pred_params.append(data_pred_dev)
data_pred_params = torch.cat(data_pred_params, -1)
data_pred_params = data_pred_params.reshape(data_shape + [-1])
return data_pred_params
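The `A` and `B` terms above follow from the Bayesian-flow input distribution for continuous data, where the input mean is a noised version of the data. A sketch of the inversion, writing $\sigma_1^2$ for `min_variance` and $\gamma(t) = 1 - \sigma_1^{2t}$ (so `posterior_var` $= \sigma_1^{2t} = 1 - \gamma$):

```latex
% Input mean as a noised version of the data:
\mu = \gamma(t)\,x + \sqrt{\gamma(t)\bigl(1 - \gamma(t)\bigr)}\,\epsilon
% Inverting for the data given a predicted noise \hat{\epsilon}:
\hat{x} = \underbrace{\frac{\mu}{\gamma(t)}}_{A}
        - \underbrace{\sqrt{\frac{1 - \gamma(t)}{\gamma(t)}}}_{B}\,\hat{\epsilon}
```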
class PredDistToDataDistFactory(DiscreteDistributionFactory):
def __init__(self, data_dist_factory, min_variance, min_t=1e-6):
self.data_dist_factory = data_dist_factory
self.data_dist_factory.log_dev = False
self.min_variance = min_variance
self.min_t = min_t
def get_dist(self, params, input_params, t):
data_pred_params = noise_pred_params_to_data_pred_params(params, input_params[0], t, self.min_variance, self.min_t)
return self.data_dist_factory.get_dist(data_pred_params)
================================================
FILE: sample.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from omegaconf import OmegaConf, DictConfig
from utils_train import seed_everything, make_config, make_bfn
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True
def main(cfg: DictConfig) -> torch.Tensor:
"""
Config entries:
seed (int): Optional
config_file (str): Name of config file containing model and data config for a saved checkpoint
load_model (str): Path to a saved checkpoint to be tested
samples_shape (list): Shape of sample batch, e.g.:
(3, 256) for sampling 3 sequences of length 256 from the text8 model.
(2, 32, 32, 3) for sampling 2 images from the CIFAR10 model.
(4, 28, 28, 1) for sampling 4 images from the MNIST model.
n_steps (int): Number of sampling steps (positive integer).
save_file (str): File path to save the generated sample tensor. Skip saving if None.
"""
seed_everything(cfg.seed)
print(f"Seeded everything with seed {cfg.seed}")
# Get model config from the training config file
train_cfg = make_config(cfg.config_file)
bfn = make_bfn(train_cfg.model)
bfn.load_state_dict(torch.load(cfg.load_model, weights_only=True, map_location="cpu"))
if torch.cuda.is_available():
bfn.to("cuda")
samples = bfn.sample(cfg.samples_shape, cfg.n_steps)
if cfg.save_file is not None:
torch.save(samples.to("cpu"), cfg.save_file)
return samples
if __name__ == "__main__":
main(OmegaConf.from_cli())
================================================
FILE: test.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple
import torch
from omegaconf import OmegaConf, DictConfig
from rich import print
from torch import nn
from torch.utils.data import DataLoader
from data import make_datasets
from model import BFN
from utils_train import seed_everything, make_config, make_bfn, worker_init_function, make_progress_bar
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True
def setup(cfg: DictConfig) -> Tuple[nn.Module, DataLoader]:
test_ds = make_datasets(cfg.data)[-1]
test_dl = DataLoader(
dataset=test_ds,
worker_init_fn=worker_init_function,
batch_size=100,
shuffle=False,
num_workers=8,
pin_memory=True,
)
model = make_bfn(cfg.model)
return model, test_dl
@torch.inference_mode()
def test(model: BFN, dataloader: DataLoader, n_steps: int, n_repeats: int) -> tuple[float, float, float, float]:
if torch.cuda.is_available():
model.to("cuda")
model.eval()
losses, recon_losses = [], []
pbar = make_progress_bar(True, "[red]loss: {task.fields[loss]:.4f} repeat: {task.fields[r]}")
with pbar:
task_id = pbar.add_task("Test", visible=True, total=n_repeats * len(dataloader), loss=math.nan, r=0)
for r in range(n_repeats):
_losses, _recon_losses = [], []
for eval_batch in dataloader:
eval_batch = eval_batch.to("cuda") if torch.cuda.is_available() else eval_batch
loss = model(eval_batch, n_steps=n_steps).item()
recon_loss = model.compute_reconstruction_loss(eval_batch).item()
_losses.append(loss)
_recon_losses.append(recon_loss)
pbar.update(task_id, advance=1, loss=torch.tensor(_losses).mean() + torch.tensor(_recon_losses).mean(), r=r+1)
losses.append(torch.tensor(_losses).mean())
recon_losses.append(torch.tensor(_recon_losses).mean())
losses = torch.stack(losses)
loss_mean, loss_err = losses.mean(), losses.std(correction=0).item() / math.sqrt(len(losses))
recon_losses = torch.stack(recon_losses)
recon_mean, recon_err = recon_losses.mean(), recon_losses.std(correction=0).item() / math.sqrt(len(recon_losses))
return loss_mean, loss_err, recon_mean, recon_err
def main(cfg: DictConfig) -> tuple[float, float, float, float]:
"""
Config entries:
seed (int): Optional
config_file (str): Name of config file containing model and data config for a saved checkpoint
load_model (str): Path to a saved checkpoint to be tested
n_steps (int): Number of Bayesian flow steps. Set to None for continuous time Bayesian flow loss.
n_repeats (int): Number of times to iterate through the dataset.
"""
seed_everything(cfg.seed)
print(f"Seeded everything with seed {cfg.seed}")
# Get model and data config from the training config file
train_cfg = make_config(cfg.config_file)
model, dataloader = setup(train_cfg)
model.load_state_dict(torch.load(cfg.load_model, weights_only=True, map_location="cpu"))
loss_mean, loss_err, recon_mean, recon_err = test(model, dataloader, cfg.n_steps, cfg.n_repeats)
print(f"For {cfg.n_steps} steps with {cfg.n_repeats} repeats:")
print(f"Loss is {loss_mean:.6f} +- {loss_err:.6f}")
print(f"Reconstruction Loss is {recon_mean:.6f} +- {recon_err:.6f}")
print(f"Total loss mean = {loss_mean + recon_mean}")
return loss_mean, loss_err, recon_mean, recon_err
if __name__ == "__main__":
main(OmegaConf.from_cli())
================================================
FILE: train.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Optional, Tuple
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from omegaconf import OmegaConf
from rich.logging import RichHandler
from rich.progress import Progress
from torch import nn, optim
from torch.utils.data import DataLoader
from model import BFN
from utils_train import (
seed_everything, log_cfg,
checkpoint_training_state,
init_checkpointing,
log,
update_ema,
ddict,
make_infinite,
make_progress_bar, make_config, make_dataloaders, make_bfn,
)
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True
logging.basicConfig(
level=logging.INFO,
format="%(message)s",
datefmt="[%X]",
handlers=[RichHandler(rich_tracebacks=True, show_time=False)],
)
logger = get_logger(__name__)
def setup(cfg) -> Tuple[nn.Module, dict, optim.Optimizer]:
"""Create the model, dataloader and optimizer"""
dataloaders = make_dataloaders(cfg)
model = make_bfn(cfg.model)
if "weight_decay" in cfg.optimizer.keys() and hasattr(model.net, "get_optim_groups"):
params = model.net.get_optim_groups(cfg.optimizer.weight_decay)
else:
params = model.net.parameters()
# Instantiate the optimizer using the hyper-parameters in the config
optimizer = optim.AdamW(params=params, **cfg.optimizer)
return model, dataloaders, optimizer
@torch.no_grad()
def validate(
cfg,
model: BFN,
ema_model: nn.Module,
val_dataloader: DataLoader,
step: int,
run: "neptune.Run",
pbar: Optional[Progress],
best_val_loss: float,
checkpoint_root_dir: Optional[Path],
accelerator: Accelerator,
) -> float:
"""Evaluate model on validation data and save checkpoint if loss improves"""
dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[accelerator.mixed_precision]
model_to_eval = ema_model if ema_model is not None else model
model_to_eval.eval()
pbar = pbar or Progress()
max_steps = cfg.max_val_batches if cfg.max_val_batches > 0 else len(val_dataloader)
val_id = pbar.add_task("Validating", visible=True, total=cfg.val_repeats * max_steps, transient=True, loss=math.nan)
loss, count = 0.0, 0
for i in range(cfg.val_repeats):
for idx, eval_batch in enumerate(val_dataloader):
enabled = dtype in (torch.float16, torch.bfloat16)
with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):
loss += model_to_eval(eval_batch.to(accelerator.device)).item()
count += 1
pbar.update(val_id, advance=1, loss=loss / count)
if (idx + 1) >= max_steps:
break
loss /= count
pbar.remove_task(val_id)
log(run["metrics"]["val"]["loss"], loss, step)
if checkpoint_root_dir is not None and (loss < best_val_loss or math.isinf(best_val_loss)):
logger.info(f"loss improved: new value is {loss}")
step_checkpoint_path = checkpoint_root_dir / "best"
run_id = "BFN" if isinstance(run, defaultdict) else run["sys"]["id"].fetch()
checkpoint_training_state(step_checkpoint_path, accelerator, ema_model, step, run_id)
run["metrics/best/loss/metric"] = loss
run["metrics/best/loss/step"] = step
model.train()
return loss
def train(
cfg,
accelerator: Accelerator,
model: BFN,
ema_model: Optional[nn.Module],
dataloaders: dict,
optimizer: optim.Optimizer,
run: "neptune.Run",
):
is_main = accelerator.is_main_process
pbar = make_progress_bar(is_main)
run_id = "BFN" if isinstance(run, defaultdict) else run["sys"]["id"].fetch()
train_id = pbar.add_task(f"Training {run_id}", start=cfg.start_step, total=cfg.n_training_steps, loss=math.nan)
checkpoint_root_dir = init_checkpointing(cfg.checkpoint_dir, run_id) if is_main else None
best_val_loss = math.inf
train_iter = make_infinite(dataloaders["train"])
model.train()
with pbar:
for step in range(cfg.start_step, cfg.n_training_steps + 1):
step_loss = 0.0
for _ in range(cfg.accumulate):
with accelerator.accumulate(model):
train_batch = next(train_iter)
loss = model(train_batch)
accelerator.backward(loss)
if accelerator.sync_gradients and cfg.grad_clip_norm > 0:
accelerator.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)
optimizer.step()
optimizer.zero_grad(set_to_none=True)
step_loss += loss.item()
update_ema(ema_model, model, cfg.ema_decay)
if is_main and (step % cfg.checkpoint_interval == 0):
checkpoint_training_state(checkpoint_root_dir / "last", accelerator, ema_model, step, run_id)
run["checkpoints/last"].track_files(str(checkpoint_root_dir / "last"))
log(run["metrics"]["train"]["loss"], step_loss / cfg.accumulate, step, is_main and step % cfg.log_interval == 0)
log(run["metrics"]["epoch"], step // len(dataloaders["train"]), step, is_main)
if is_main and (step % cfg.val_interval == 0) and "val" in dataloaders:
val_loss = validate(
cfg=cfg,
model=model,
ema_model=ema_model,
val_dataloader=dataloaders["val"],
step=step,
run=run,
pbar=pbar,
best_val_loss=best_val_loss,
checkpoint_root_dir=checkpoint_root_dir,
accelerator=accelerator,
)
best_val_loss = min(val_loss, best_val_loss)
pbar.update(train_id, advance=1, loss=step_loss / cfg.accumulate)
def main(cfg):
acc = Accelerator(gradient_accumulation_steps=cfg.training.accumulate)
seed_everything(cfg.training.seed)
logger.info(f"Seeded everything with seed {cfg.training.seed}", main_process_only=True)
with acc.main_process_first():
model, dataloaders, optimizer = setup(cfg)
ema = copy.deepcopy(model) if acc.is_main_process and cfg.training.ema_decay > 0 else None # EMA on main proc only
model, optimizer, dataloaders["train"] = acc.prepare(model, optimizer, dataloaders["train"])
run = ddict()
if acc.is_main_process:
ema.to(acc.device)
try:
if cfg.meta.neptune:
import neptune
run = neptune.init_run(project=cfg.meta.neptune, mode="debug" if cfg.meta.debug else None)
run["accelerate"] = dict(amp=acc.mixed_precision, nproc=acc.num_processes)
log_cfg(cfg, run)
except ImportError:
logger.info("Did not find neptune installed. Logging will be disabled.")
train(cfg.training, acc, model, ema, dataloaders, optimizer, run)
if __name__ == "__main__":
cfg_file = OmegaConf.from_cli()['config_file']
main(make_config(cfg_file))
================================================
FILE: utils_model.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import torch
from torch import Tensor
CONST_log_range = 20
CONST_log_min = 1e-10
CONST_summary_rescale = 10
CONST_exp_range = 10
CONST_min_std_dev = math.exp(-CONST_exp_range)
def sandwich(x: Tensor):
return x.reshape(x.size(0), -1, x.size(-1))
def safe_log(data: Tensor):
return data.clamp(min=CONST_log_min).log()
def safe_exp(data: Tensor):
return data.clamp(min=-CONST_exp_range, max=CONST_exp_range).exp()
def idx_to_float(idx: Tensor, num_bins: int):
flt_zero_one = (idx + 0.5) / num_bins
return (2.0 * flt_zero_one) - 1.0
def float_to_idx(flt: Tensor, num_bins: int):
flt_zero_one = (flt / 2.0) + 0.5
return torch.clamp(torch.floor(flt_zero_one * num_bins), min=0, max=num_bins - 1).long()
def quantize(flt, num_bins: int):
return idx_to_float(float_to_idx(flt, num_bins), num_bins)
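A quick round-trip sketch of the mapping above with a hypothetical `num_bins=4`: floats in [-1, 1] map to bin indices, and indices map back to bin centres (which is exactly what `quantize` composes):

```python
import torch

def float_to_idx(flt, num_bins):
    # Shift [-1, 1] to [0, 1], scale to bin units, floor and clamp to valid indices.
    flt_zero_one = (flt / 2.0) + 0.5
    return torch.clamp(torch.floor(flt_zero_one * num_bins), min=0, max=num_bins - 1).long()

def idx_to_float(idx, num_bins):
    # Map an index to the centre of its bin in [-1, 1].
    return 2.0 * ((idx + 0.5) / num_bins) - 1.0

x = torch.tensor([-1.0, -0.3, 0.0, 0.999])
idx = float_to_idx(x, 4)
print(idx.tolist())                   # [0, 1, 2, 3]
print(idx_to_float(idx, 4).tolist())  # [-0.75, -0.25, 0.25, 0.75]
```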
def pe_encode(sequence_length: int, embedding_size: int) -> Tensor:
"""Positional encoding as described in original attention is all you need paper"""
pe = torch.zeros((sequence_length, embedding_size))
pos = torch.arange(sequence_length).unsqueeze(1)
pe[:, 0::2] = torch.sin(
pos / torch.pow(10000, torch.arange(0, embedding_size, 2, dtype=torch.float32) / embedding_size)
)
pe[:, 1::2] = torch.cos(
pos / torch.pow(10000, torch.arange(1, embedding_size, 2, dtype=torch.float32) / embedding_size)
)
return pe
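A standalone sketch of the standard sinusoidal encoding (base 10000, as in the original Transformer), checking that position 0 encodes as alternating 0/1 since sin(0)=0 and cos(0)=1; sizes here are hypothetical:

```python
import torch

def pe_encode(sequence_length, embedding_size):
    # Even columns get sin, odd columns get cos, with geometrically spaced frequencies.
    pe = torch.zeros((sequence_length, embedding_size))
    pos = torch.arange(sequence_length).unsqueeze(1)
    div = torch.pow(10000, torch.arange(0, embedding_size, 2, dtype=torch.float32) / embedding_size)
    pe[:, 0::2] = torch.sin(pos / div)
    pe[:, 1::2] = torch.cos(pos / div)
    return pe

pe = pe_encode(16, 8)
print(pe.shape)        # torch.Size([16, 8])
print(pe[0].tolist())  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```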
def pe_encode_float(x: Tensor, max_freq: float, embedding_size: int) -> Tensor:
pe = torch.zeros(list(x.shape) + [embedding_size], device=x.device)
pos = (((x + 1) / 2) * max_freq).unsqueeze(-1)
pe[..., 0::2] = torch.sin(
pos
/ torch.pow(10000, torch.arange(0, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size)
)
pe[..., 1::2] = torch.cos(
pos
/ torch.pow(10000, torch.arange(1, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size)
)
return pe
================================================
FILE: utils_train.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import random
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Optional, Generator, Union
try:
import neptune
from neptune.utils import stringify_unsupported
except ImportError:
neptune = None
def stringify_unsupported(x):
return x
import numpy as np
import torch
from accelerate.logging import get_logger
from omegaconf import OmegaConf, DictConfig
from rich.progress import Progress, SpinnerColumn, MofNCompleteColumn, TimeElapsedColumn, TextColumn
from torch.utils.data import DataLoader
import model
import networks
import probability
from data import make_datasets
from networks import adapters
logger = get_logger(__name__)

def seed_everything(seed: Optional[int]):
    assert seed is not None
    seed += torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def worker_init_function(worker_id: int) -> None:
    """https://pytorch.org/docs/stable/notes/randomness.html#dataloader"""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def init_checkpointing(checkpoint_dir: Union[str, Path, None], run_id: str) -> Optional[Path]:
    if checkpoint_dir is None:
        return None
    checkpoint_dir = Path(checkpoint_dir) / run_id
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    last_dir = checkpoint_dir / "last"
    last_dir.mkdir(parents=True, exist_ok=True)
    best_dir = checkpoint_dir / "best"
    best_dir.mkdir(parents=True, exist_ok=True)
    return checkpoint_dir
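init_checkpointing lays out <checkpoint_dir>/<run_id>/ with "last" and "best" subdirectories. A quick sanity check of that layout, using a condensed standalone copy of the function against a temporary directory:

```python
import tempfile
from pathlib import Path


def init_checkpointing(checkpoint_dir, run_id):
    # Condensed copy of the helper above: create <checkpoint_dir>/<run_id>/{last,best}.
    if checkpoint_dir is None:
        return None
    checkpoint_dir = Path(checkpoint_dir) / run_id
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    (checkpoint_dir / "last").mkdir(parents=True, exist_ok=True)
    (checkpoint_dir / "best").mkdir(parents=True, exist_ok=True)
    return checkpoint_dir


with tempfile.TemporaryDirectory() as tmp:
    root = init_checkpointing(tmp, "run-001")
    assert root == Path(tmp) / "run-001"
    assert (root / "last").is_dir() and (root / "best").is_dir()

# Checkpointing is disabled by passing None for the directory.
assert init_checkpointing(None, "run-001") is None
```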

def checkpoint_training_state(checkpoint_dir, accelerator, ema_model, step: int, run_id: str):
    if checkpoint_dir is None:
        return
    logger.info(f"Checkpointing training state to {checkpoint_dir} at step {step}")
    accelerator.save_state(checkpoint_dir)
    with open(checkpoint_dir / "info.json", "w") as f:
        json.dump({"step": step, "run_id": run_id}, f)
    if ema_model is not None:
        ema_checkpoint_path = checkpoint_dir / "ema_model.pt"
        torch.save(ema_model.state_dict(), ema_checkpoint_path)

def log(key_handler, value, step, cond=True):
    """Log series to neptune only if cond is True. Helps with distributed training and conditional logging."""
    if not isinstance(key_handler, defaultdict) and cond and math.isfinite(value):
        key_handler.log(value, step=step)


def log_cfg(cfg, run: "neptune.Run"):
    with tempfile.TemporaryDirectory() as tmpdir:
        cfg_temp_filename: Path = Path(tmpdir) / "cfg.yaml"
        cfg_temp_filename.write_text(OmegaConf.to_yaml(cfg, resolve=True))
        run["cfg"].upload(str(cfg_temp_filename), wait=True)
        run["hyperparameters"] = stringify_unsupported(OmegaConf.to_container(cfg, resolve=True))

@torch.no_grad()
def update_ema(ema_model, model, ema_decay):
    if ema_model is not None and ema_decay > 0:
        for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):
            ema_param.sub_((1 - ema_decay) * (ema_param - model_param))
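The in-place update above is algebraically the usual exponential moving average: subtracting (1 - decay) * (ema - current) is the same as blending decay * ema + (1 - decay) * current. A scalar illustration (helper name hypothetical):

```python
def ema_update_scalar(ema: float, current: float, decay: float) -> float:
    # Same arithmetic as update_ema's ema_param.sub_((1 - decay) * (ema_param - model_param)).
    return ema - (1 - decay) * (ema - current)


# Rearranges to the familiar blend: decay * ema + (1 - decay) * current.
assert ema_update_scalar(0.0, 1.0, 0.5) == 0.5 * 0.0 + 0.5 * 1.0
```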

def ddict():
    """Infinite default dict to fake neptune run on non-main processes"""
    return defaultdict(ddict)
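ddict builds an infinitely nested defaultdict, so non-main processes can index a fake "run" object to any depth without KeyErrors; the log() helper then recognizes such handlers by their defaultdict type and skips them. A standalone demonstration:

```python
from collections import defaultdict


def ddict():
    # Standalone copy: every missing key yields another infinitely nested default dict.
    return defaultdict(ddict)


fake_run = ddict()
handler = fake_run["val"]["loss"]        # arbitrary nesting, no KeyError raised
assert isinstance(handler, defaultdict)  # log() uses this check to skip fake handlers
```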

def make_infinite(dataloader: DataLoader) -> Generator[dict, None, None]:
    while True:
        for data in dataloader:
            yield data
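make_infinite simply restarts iteration whenever the dataloader is exhausted, turning an epoch-based loader into an endless stream; any re-iterable stands in for a DataLoader here, e.g. a plain list:

```python
from itertools import islice


def make_infinite(dataloader):
    # Standalone copy: cycle through the loader forever; a shuffling DataLoader
    # re-shuffles at the start of each pass.
    while True:
        for data in dataloader:
            yield data


stream = make_infinite([1, 2, 3])
assert list(islice(stream, 7)) == [1, 2, 3, 1, 2, 3, 1]
```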

def make_progress_bar(is_main: bool, text="[red]loss: {task.fields[loss]:.3f}"):
    return Progress(
        SpinnerColumn(),
        MofNCompleteColumn(),
        *Progress.get_default_columns(),
        TimeElapsedColumn(),
        TextColumn(text),
        disable=not is_main,
    )

def make_dataloaders(cfg: DictConfig):
    train_set, val_set, _ = make_datasets(cfg.data)
    dataloaders = {
        "train": DataLoader(
            dataset=train_set,
            worker_init_fn=worker_init_function,
            **cfg.train_loader,
        ),
        "val": DataLoader(
            dataset=val_set,
            worker_init_fn=worker_init_function,
            **cfg.val_loader,
        ),
    }
    return dataloaders

def make_from_cfg(module, cfg, **parameters):
    return getattr(module, cfg.class_name)(**cfg.parameters, **parameters) if cfg is not None else None
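make_from_cfg is reflection-based instantiation: it looks up cfg.class_name on the given module and calls it with cfg.parameters plus any extra keyword arguments, returning None when the config section is absent. A sketch with stand-in objects (the fake module and config names are hypothetical):

```python
import types


def make_from_cfg(module, cfg, **parameters):
    # Standalone copy of the factory helper above.
    return getattr(module, cfg.class_name)(**cfg.parameters, **parameters) if cfg is not None else None


fake_module = types.SimpleNamespace(Adder=lambda a, b: a + b)
fake_cfg = types.SimpleNamespace(class_name="Adder", parameters={"a": 1})

assert make_from_cfg(fake_module, fake_cfg, b=2) == 3  # config kwargs merged with extras
assert make_from_cfg(fake_module, None) is None        # missing config sections yield None
```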

def make_bfn(cfg: DictConfig):
    data_adapters = {
        "input_adapter": make_from_cfg(adapters, cfg.input_adapter),
        "output_adapter": make_from_cfg(adapters, cfg.output_adapter),
    }
    net = make_from_cfg(networks, cfg.net, data_adapters=data_adapters)
    bayesian_flow = make_from_cfg(model, cfg.bayesian_flow)
    distribution_factory = make_from_cfg(probability, cfg.distribution_factory)
    loss = make_from_cfg(model, cfg.loss, bayesian_flow=bayesian_flow, distribution_factory=distribution_factory)
    bfn = model.BFN(net=net, bayesian_flow=bayesian_flow, loss=loss)
    return bfn

default_train_config = {
    "meta": {
        "neptune": None,
        "debug": False,
        "root_dir": ".",
    },
    "data": {
        "dataset": "",
        "data_dir": "./data",
    },
    "train_loader": {
        "batch_size": 1,
        "shuffle": True,
        "num_workers": 0,
        "pin_memory": True,
        "drop_last": True,
    },
    "val_loader": {
        "batch_size": 1,
        "shuffle": False,
        "num_workers": 0,
        "pin_memory": True,
        "drop_last": False,
    },
    "training": {
        "accumulate": 1,
        "checkpoint_dir": "./checkpoints",
        "checkpoint_interval": None,
        "ema_decay": -1,
        "grad_clip_norm": -1,
        "log_interval": 50,
        "max_val_batches": -1,
        "seed": 666,
        "start_step": 1,
        "val_repeats": 1,
    },
}

def make_config(cfg_file: str):
    cli_conf = OmegaConf.load(cfg_file)
    # Start with default config
    cfg = OmegaConf.create(default_train_config)
    # Merge into default config
    cfg = OmegaConf.merge(cfg, cli_conf)
    return cfg
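make_config layers the values from a YAML file over default_train_config; OmegaConf.merge recurses into nested sections, overriding only the keys the file actually sets rather than replacing whole sections. A plain-dict approximation of that merge behavior (this is an illustrative sketch, not the OmegaConf implementation):

```python
def deep_merge(base: dict, override: dict) -> dict:
    # Recursive merge approximating OmegaConf.merge for plain mappings:
    # nested dicts are merged key by key, and scalars from override win.
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_merge(out[key], value)
        else:
            out[key] = value
    return out


defaults = {"training": {"seed": 666, "accumulate": 1}, "meta": {"debug": False}}
from_file = {"training": {"seed": 0}}
merged = deep_merge(defaults, from_file)

assert merged["training"] == {"seed": 0, "accumulate": 1}  # seed overridden, accumulate kept
assert merged["meta"] == {"debug": False}                  # untouched sections survive
```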
================================================
SYMBOL INDEX (288 symbols across 12 files)
================================================
FILE: data.py
function bin_mnist_transform (line 37) | def bin_mnist_transform(x):
function bin_mnist_cts_transform (line 41) | def bin_mnist_cts_transform(x):
function rgb_image_transform (line 45) | def rgb_image_transform(x, num_bins=256):
class MyLambda (line 49) | class MyLambda(torchvision.transforms.Lambda):
method __init__ (line 50) | def __init__(self, lambd, arg1):
method __call__ (line 54) | def __call__(self, x):
class CIFAR10 (line 58) | class CIFAR10(torchvision.datasets.CIFAR10):
method __getitem__ (line 59) | def __getitem__(self, idx):
class MNIST (line 63) | class MNIST(torchvision.datasets.MNIST):
method __getitem__ (line 64) | def __getitem__(self, idx):
function make_datasets (line 68) | def make_datasets(cfg: DictConfig) -> tuple[Dataset, Dataset, Dataset]:
function prepare_text8 (line 127) | def prepare_text8(data_dir: pathlib.Path):
class Text8Dataset (line 185) | class Text8Dataset(Dataset):
method __init__ (line 186) | def __init__(self, data_dir: Union[str, pathlib.Path], split: str, dow...
method __getitem__ (line 204) | def __getitem__(self, index) -> torch.Tensor:
method __len__ (line 208) | def __len__(self):
function char_ids_to_str (line 212) | def char_ids_to_str(char_ids: Union[list[int], np.array, torch.Tensor]) ...
function batch_to_str (line 217) | def batch_to_str(text_batch: Union[list[list], np.array, torch.Tensor]) ...
function batch_to_images (line 222) | def batch_to_images(image_batch: torch.Tensor, ncols: int = None) -> plt...
FILE: model.py
class BayesianFlow (line 42) | class BayesianFlow(nn.Module, ABC):
method __init__ (line 43) | def __init__(self):
method get_prior_input_params (line 47) | def get_prior_input_params(self, data_shape: tuple, device: torch.devi...
method params_to_net_inputs (line 54) | def params_to_net_inputs(self, params: tuple[Tensor, ...]) -> Tensor:
method get_alpha (line 59) | def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> float:
method get_sender_dist (line 66) | def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shap...
method update_input_params (line 73) | def update_input_params(self, input_params: tuple[Tensor, ...], y: Ten...
method forward (line 79) | def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, ...]:
class Loss (line 87) | class Loss(nn.Module, ABC):
method __init__ (line 88) | def __init__(self):
method cts_time_loss (line 92) | def cts_time_loss(self, data: Tensor, output_params: Tensor, input_par...
method discrete_time_loss (line 98) | def discrete_time_loss(
method reconstruction_loss (line 107) | def reconstruction_loss(self, data: Tensor, output_params: Tensor, inp...
class CtsBayesianFlow (line 116) | class CtsBayesianFlow(BayesianFlow):
method __init__ (line 117) | def __init__(
method forward (line 125) | def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, None]:
method params_to_net_inputs (line 137) | def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
method get_prior_input_params (line 140) | def get_prior_input_params(self, data_shape: tuple, device: torch.devi...
method get_alpha (line 143) | def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[floa...
method get_sender_dist (line 147) | def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shap...
method update_input_params (line 151) | def update_input_params(self, input_params: tuple[Tensor, float], y: T...
class CtsBayesianFlowLoss (line 158) | class CtsBayesianFlowLoss(Loss):
method __init__ (line 159) | def __init__(
method cts_time_loss (line 178) | def cts_time_loss(self, data: Tensor, output_params: Tensor, input_par...
method discrete_time_loss (line 191) | def discrete_time_loss(
method reconstruction_loss (line 225) | def reconstruction_loss(self, data: Tensor, output_params: Tensor, inp...
class DiscreteBayesianFlow (line 250) | class DiscreteBayesianFlow(BayesianFlow):
method __init__ (line 251) | def __init__(
method t_to_sqrt_beta (line 267) | def t_to_sqrt_beta(self, t):
method count_dist (line 270) | def count_dist(self, x, beta=None):
method count_sample (line 278) | def count_sample(self, x, beta):
method get_prior_input_params (line 282) | def get_prior_input_params(self, data_shape: tuple, device: torch.devi...
method params_to_net_inputs (line 286) | def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
method get_alpha (line 293) | def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[floa...
method get_sender_dist (line 296) | def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shap...
method update_input_params (line 302) | def update_input_params(self, input_params: tuple[Tensor], y: Tensor, ...
method forward (line 308) | def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor]:
class DiscreteBayesianFlowLoss (line 325) | class DiscreteBayesianFlowLoss(Loss):
method __init__ (line 326) | def __init__(
method cts_time_loss (line 336) | def cts_time_loss(self, data: Tensor, output_params: Tensor, input_par...
method discrete_time_loss (line 348) | def discrete_time_loss(
method reconstruction_loss (line 369) | def reconstruction_loss(self, data: Tensor, output_params: Tensor, inp...
class BFN (line 376) | class BFN(nn.Module):
method __init__ (line 377) | def __init__(self, net: nn.Module, bayesian_flow: BayesianFlow, loss: ...
method sample_t (line 385) | def sample_t(data: Tensor, n_steps: Optional[int]) -> Tensor:
method forward (line 393) | def forward(
method compute_reconstruction_loss (line 420) | def compute_reconstruction_loss(self, data: Tensor) -> Tensor:
method sample (line 428) | def sample(self, data_shape: tuple, n_steps: int) -> Tensor:
FILE: networks/adapters.py
class TextInputAdapter (line 25) | class TextInputAdapter(nn.Module):
method __init__ (line 30) | def __init__(
method forward (line 46) | def forward(self, probs: torch.Tensor, t: torch.Tensor) -> Tensor:
class FourierImageInputAdapter (line 61) | class FourierImageInputAdapter(nn.Module):
method __init__ (line 66) | def __init__(
method forward (line 122) | def forward(self, img: Tensor, t: Tensor) -> Tensor:
class OutputAdapter (line 150) | class OutputAdapter(nn.Module):
method __init__ (line 151) | def __init__(self, input_height: int, output_channels: int, output_hei...
method forward (line 159) | def forward(self, inp: torch.Tensor) -> torch.Tensor:
FILE: networks/transformer.py
function gelu (line 37) | def gelu(x):
class LayerNorm (line 41) | class LayerNorm(nn.Module):
method __init__ (line 44) | def __init__(self, ndim, bias):
method forward (line 49) | def forward(self, input):
class SelfAttention (line 53) | class SelfAttention(nn.Module):
method __init__ (line 54) | def __init__(self, n_head, n_embd, dropout, bias, is_causal):
method forward (line 72) | def forward(self, x):
class MLP (line 92) | class MLP(nn.Module):
method __init__ (line 93) | def __init__(self, n_embd, dropout, bias):
method forward (line 99) | def forward(self, x):
class Block (line 107) | class Block(nn.Module):
method __init__ (line 108) | def __init__(self, n_head, n_embd, dropout, bias, is_causal):
method forward (line 115) | def forward(self, x):
class GPT (line 121) | class GPT(nn.Module):
method __init__ (line 122) | def __init__(
method _init_weights (line 169) | def _init_weights(self, module):
method forward (line 177) | def forward(self, data: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
method get_optim_groups (line 188) | def get_optim_groups(self, weight_decay: float):
FILE: networks/unet_improved.py
function convert_module_to_f16 (line 48) | def convert_module_to_f16(module):
function convert_module_to_f32 (line 57) | def convert_module_to_f32(module):
function make_master_params (line 66) | def make_master_params(model_params):
function model_grads_to_master_grads (line 77) | def model_grads_to_master_grads(model_params, master_params):
function master_params_to_model_params (line 85) | def master_params_to_model_params(model_params, master_params):
function unflatten_master_params (line 97) | def unflatten_master_params(model_params, master_params):
function zero_grad (line 104) | def zero_grad(model_params):
class SiLU (line 113) | class SiLU(nn.Module):
method forward (line 114) | def forward(self, x):
class GroupNorm32 (line 118) | class GroupNorm32(nn.GroupNorm):
method forward (line 119) | def forward(self, x):
function conv_nd (line 123) | def conv_nd(dims, *args, **kwargs):
function linear (line 136) | def linear(*args, **kwargs):
function avg_pool_nd (line 143) | def avg_pool_nd(dims, *args, **kwargs):
function update_ema (line 156) | def update_ema(target_params, source_params, rate=0.99):
function zero_module (line 169) | def zero_module(module):
function scale_module (line 178) | def scale_module(module, scale):
function mean_flat (line 187) | def mean_flat(tensor):
function normalization (line 194) | def normalization(channels):
function timestep_embedding (line 204) | def timestep_embedding(timesteps, dim, max_period=10000):
function checkpoint (line 225) | def checkpoint(func, inputs, params, flag):
class CheckpointFunction (line 243) | class CheckpointFunction(th.autograd.Function):
method forward (line 245) | def forward(ctx, run_function, length, *args):
method backward (line 254) | def backward(ctx, *output_grads):
class TimestepBlock (line 274) | class TimestepBlock(nn.Module):
method forward (line 280) | def forward(self, x, emb):
class TimestepEmbedSequential (line 286) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
method forward (line 292) | def forward(self, x, emb):
class Upsample (line 301) | class Upsample(nn.Module):
method __init__ (line 311) | def __init__(self, channels, use_conv, dims=2):
method forward (line 319) | def forward(self, x):
class Downsample (line 330) | class Downsample(nn.Module):
method __init__ (line 340) | def __init__(self, channels, use_conv, dims=2):
method forward (line 351) | def forward(self, x):
class ResBlock (line 356) | class ResBlock(TimestepBlock):
method __init__ (line 371) | def __init__(
method forward (line 417) | def forward(self, x, emb):
method _forward (line 427) | def _forward(self, x, emb):
class AttentionBlock (line 443) | class AttentionBlock(nn.Module):
method __init__ (line 451) | def __init__(self, channels, num_heads=1, use_checkpoint=False):
method forward (line 462) | def forward(self, x):
method _forward (line 465) | def _forward(self, x):
class QKVAttention (line 476) | class QKVAttention(nn.Module):
method forward (line 481) | def forward(self, qkv):
method count_flops (line 496) | def count_flops(model, _x, y):
class UNetModel (line 519) | class UNetModel(nn.Module):
method __init__ (line 542) | def __init__(
method convert_to_fp16 (line 682) | def convert_to_fp16(self):
method convert_to_fp32 (line 690) | def convert_to_fp32(self):
method inner_dtype (line 699) | def inner_dtype(self):
method forward (line 705) | def forward(
method get_feature_vectors (line 751) | def get_feature_vectors(self, x, timesteps, y=None):
FILE: networks/unet_vdm.py
function zero_init (line 39) | def zero_init(module: nn.Module) -> nn.Module:
class UNetVDM (line 46) | class UNetVDM(nn.Module):
method __init__ (line 47) | def __init__(
method forward (line 121) | def forward(
method maybe_concat_fourier (line 152) | def maybe_concat_fourier(self, z):
class ResnetBlock (line 158) | class ResnetBlock(nn.Module):
method __init__ (line 159) | def __init__(
method forward (line 187) | def forward(self, x, condition):
function get_timestep_embedding (line 201) | def get_timestep_embedding(
class FourierFeatures (line 223) | class FourierFeatures(nn.Module):
method __init__ (line 224) | def __init__(self, first=5.0, last=6.0, step=1.0):
method num_features (line 229) | def num_features(self):
method forward (line 232) | def forward(self, x):
function attention_inner_heads (line 248) | def attention_inner_heads(qkv, num_heads):
class Attention (line 285) | class Attention(nn.Module):
method __init__ (line 288) | def __init__(self, n_heads):
method forward (line 292) | def forward(self, qkv):
class AttentionBlock (line 301) | class AttentionBlock(nn.Module):
method __init__ (line 304) | def __init__(self, n_heads, n_channels, norm_groups):
method forward (line 314) | def forward(self, x):
class UpDownBlock (line 318) | class UpDownBlock(nn.Module):
method __init__ (line 319) | def __init__(self, resnet_block, attention_block=None):
method forward (line 324) | def forward(self, x, cond):
FILE: probability.py
class CtsDistribution (line 36) | class CtsDistribution:
method log_prob (line 38) | def log_prob(self, x):
method sample (line 42) | def sample(self):
class DiscreteDistribution (line 46) | class DiscreteDistribution:
method probs (line 49) | def probs(self):
method log_probs (line 53) | def log_probs(self):
method mean (line 57) | def mean(self):
method mode (line 61) | def mode(self):
method log_prob (line 65) | def log_prob(self, x):
method sample (line 69) | def sample(self):
class DiscretizedDistribution (line 73) | class DiscretizedDistribution(DiscreteDistribution):
method __init__ (line 74) | def __init__(self, num_bins, device):
method class_centres (line 81) | def class_centres(self):
method class_boundaries (line 85) | def class_boundaries(self):
method mean (line 89) | def mean(self):
method mode (line 93) | def mode(self):
class DiscretizedCtsDistribution (line 98) | class DiscretizedCtsDistribution(DiscretizedDistribution):
method __init__ (line 99) | def __init__(self, cts_dist, num_bins, device, batch_dims, clip=True, ...
method probs (line 108) | def probs(self):
method prob (line 127) | def prob(self, x):
method log_prob (line 145) | def log_prob(self, x):
method sample (line 153) | def sample(self, sample_shape=torch.Size([])):
class GMM (line 165) | class GMM(MixtureSameFamily):
method __init__ (line 166) | def __init__(self, mix_wt_logits, means, std_devs):
class DiscretizedGMM (line 172) | class DiscretizedGMM(DiscretizedCtsDistribution):
method __init__ (line 173) | def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max...
class DiscretizedNormal (line 191) | class DiscretizedNormal(DiscretizedCtsDistribution):
method __init__ (line 192) | def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max...
class Bernoulli (line 210) | class Bernoulli(DiscreteDistribution):
method __init__ (line 211) | def __init__(self, logits):
method probs (line 215) | def probs(self):
method mode (line 220) | def mode(self):
method log_prob (line 223) | def log_prob(self, x):
method sample (line 226) | def sample(self, sample_shape=torch.Size([])):
class DiscretizedBernoulli (line 230) | class DiscretizedBernoulli(DiscretizedDistribution):
method __init__ (line 231) | def __init__(self, logits):
method probs (line 236) | def probs(self):
method mode (line 241) | def mode(self):
method log_prob (line 244) | def log_prob(self, x):
method sample (line 247) | def sample(self, sample_shape=torch.Size([])):
class DeltaDistribution (line 251) | class DeltaDistribution(CtsDistribution):
method __init__ (line 252) | def __init__(self, mean, clip_range=1.0):
method mode (line 258) | def mode(self):
method mean (line 262) | def mean(self):
method sample (line 265) | def sample(self, sample_shape=torch.Size([])):
class Categorical (line 269) | class Categorical(DiscreteDistribution):
method __init__ (line 270) | def __init__(self, logits):
method probs (line 275) | def probs(self):
method mode (line 279) | def mode(self):
method log_prob (line 282) | def log_prob(self, x):
method sample (line 285) | def sample(self, sample_shape=torch.Size([])):
class DiscretizedCategorical (line 289) | class DiscretizedCategorical(DiscretizedDistribution):
method __init__ (line 290) | def __init__(self, logits=None, probs=None):
method probs (line 300) | def probs(self):
method mode (line 304) | def mode(self):
method log_prob (line 307) | def log_prob(self, x):
method sample (line 310) | def sample(self, sample_shape=torch.Size([])):
class CtsDistributionFactory (line 314) | class CtsDistributionFactory:
method get_dist (line 316) | def get_dist(self, params: torch.Tensor, input_params=None, t=None) ->...
class GMMFactory (line 321) | class GMMFactory(CtsDistributionFactory):
method __init__ (line 322) | def __init__(self, min_std_dev=1e-3, max_std_dev=10, log_dev=True):
method get_dist (line 327) | def get_dist(self, params, input_params=None, t=None):
class NormalFactory (line 335) | class NormalFactory(CtsDistributionFactory):
method __init__ (line 336) | def __init__(self, min_std_dev=1e-3, max_std_dev=10):
method get_dist (line 340) | def get_dist(self, params, input_params=None, t=None):
class DeltaFactory (line 346) | class DeltaFactory(CtsDistributionFactory):
method __init__ (line 347) | def __init__(self, clip_range=1.0):
method get_dist (line 350) | def get_dist(self, params, input_params=None, t=None):
class DiscreteDistributionFactory (line 354) | class DiscreteDistributionFactory:
method get_dist (line 356) | def get_dist(self, params: torch.Tensor, input_params=None, t=None) ->...
class BernoulliFactory (line 361) | class BernoulliFactory(DiscreteDistributionFactory):
method get_dist (line 362) | def get_dist(self, params, input_params=None, t=None):
class CategoricalFactory (line 366) | class CategoricalFactory(DiscreteDistributionFactory):
method get_dist (line 367) | def get_dist(self, params, input_params=None, t=None):
class DiscretizedBernoulliFactory (line 371) | class DiscretizedBernoulliFactory(DiscreteDistributionFactory):
method get_dist (line 372) | def get_dist(self, params, input_params=None, t=None):
class DiscretizedCategoricalFactory (line 376) | class DiscretizedCategoricalFactory(DiscreteDistributionFactory):
method get_dist (line 377) | def get_dist(self, params, input_params=None, t=None):
class DiscretizedGMMFactory (line 381) | class DiscretizedGMMFactory(DiscreteDistributionFactory):
method __init__ (line 382) | def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=...
method get_dist (line 390) | def get_dist(self, params, input_params=None, t=None):
class DiscretizedNormalFactory (line 402) | class DiscretizedNormalFactory(DiscreteDistributionFactory):
method __init__ (line 403) | def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=...
method get_dist (line 411) | def get_dist(self, params, input_params=None, t=None):
function noise_pred_params_to_data_pred_params (line 423) | def noise_pred_params_to_data_pred_params(noise_pred_params: torch.Tenso...
class PredDistToDataDistFactory (line 459) | class PredDistToDataDistFactory(DiscreteDistributionFactory):
method __init__ (line 460) | def __init__(self, data_dist_factory, min_variance, min_t=1e-6):
method get_dist (line 466) | def get_dist(self, params, input_params, t):
FILE: sample.py
function main (line 24) | def main(cfg: DictConfig) -> torch.Tensor:
FILE: test.py
function setup (line 32) | def setup(cfg: DictConfig) -> Tuple[nn.Module, DataLoader]:
function test (line 47) | def test(model: BFN, dataloader: DataLoader, n_steps: int, n_repeats: in...
function main (line 73) | def main(cfg: DictConfig) -> tuple[float, float, float, float]:
FILE: train.py
function setup (line 56) | def setup(cfg) -> Tuple[nn.Module, dict, optim.Optimizer]:
function validate (line 70) | def validate(
function train (line 116) | def train(
function main (line 178) | def main(cfg):
FILE: utils_model.py
function sandwich (line 28) | def sandwich(x: Tensor):
function safe_log (line 32) | def safe_log(data: Tensor):
function safe_exp (line 36) | def safe_exp(data: Tensor):
function idx_to_float (line 40) | def idx_to_float(idx: np.ndarray, num_bins: int):
function float_to_idx (line 45) | def float_to_idx(flt: np.ndarray, num_bins: int):
function quantize (line 50) | def quantize(flt, num_bins: int):
function pe_encode (line 54) | def pe_encode(sequence_length: int, embedding_size: int) -> Tensor:
function pe_encode_float (line 69) | def pe_encode_float(x: Tensor, max_freq: float, embedding_size: int) -> ...
FILE: utils_train.py
function stringify_unsupported (line 29) | def stringify_unsupported(x):
function seed_everything (line 49) | def seed_everything(seed: Optional[int]):
function worker_init_function (line 58) | def worker_init_function(worker_id: int) -> None:
function init_checkpointing (line 65) | def init_checkpointing(checkpoint_dir: Union[str, Path, None], run_id: s...
function checkpoint_training_state (line 77) | def checkpoint_training_state(checkpoint_dir, accelerator, ema_model, st...
function log (line 89) | def log(key_handler, value, step, cond=True):
function log_cfg (line 95) | def log_cfg(cfg, run: "neptune.Run"):
function update_ema (line 104) | def update_ema(ema_model, model, ema_decay):
function ddict (line 110) | def ddict():
function make_infinite (line 115) | def make_infinite(dataloader: DataLoader) -> Generator[dict, None, None]:
function make_progress_bar (line 121) | def make_progress_bar(is_main: bool, text="[red]loss: {task.fields[loss]...
function make_dataloaders (line 132) | def make_dataloaders(cfg: DictConfig):
function make_from_cfg (line 149) | def make_from_cfg(module, cfg, **parameters):
function make_bfn (line 153) | def make_bfn(cfg: DictConfig):
function make_config (line 205) | def make_config(cfg_file: str):