Repository: nnaisense/bayesian-flow-networks
Branch: main
Commit: b62568e5d064
Files: 23
Total size: 146.9 KB

Directory structure:
gitextract__0riu0_z/
├── .gitignore
├── LICENSE
├── README.md
├── configs/
│   ├── cifar10_continuous_16bins.yaml
│   ├── cifar10_continuous_256bins.yaml
│   ├── cifar10_discretized_16bins.yaml
│   ├── cifar10_discretized_256bins.yaml
│   ├── mnist_discrete.yaml
│   └── text8_discrete.yaml
├── data.py
├── env.yml
├── model.py
├── networks/
│   ├── __init__.py
│   ├── adapters.py
│   ├── transformer.py
│   ├── unet_improved.py
│   └── unet_vdm.py
├── probability.py
├── sample.py
├── test.py
├── train.py
├── utils_model.py
└── utils_train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Data, checkpoints, logs
data
checkpoints
.neptune

# Files generated by setuptools_scm
__version.py

# MacOS
.DS_Store

# Visual Studio Code
.vscode/
*.code-workspace
.history/

# Created by https://www.gitignore.io/api/python
# Edit at https://www.gitignore.io/?templates=python

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# PyCharm
.idea/

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# End of https://www.gitignore.io/api/python

================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks.
This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

================================================
FILE: README.md
================================================
# Bayesian Flow Networks

This is the official code release for [Bayesian Flow Networks](https://arxiv.org/abs/2308.07037) by Alex Graves, Rupesh Kumar Srivastava, Timothy Atkinson and Faustino Gomez.

*(Figure: overview of the BFN process)*

## Reading Guide

- `model.py` contains all the main contributions of the paper: definitions of Bayesian Flows for both continuous and discrete data, together with the corresponding continuous-time and discrete-time loss functions. See comments in the base classes in that file for details (a minimal usage sketch follows this list).
- `probability.py` defines the probability distributions used by the models.
- `train.py`, `test.py` and `sample.py` are scripts for training, testing and sampling (see below for usage).
- `data.py` contains utilities related to data loading and processing.
- `networks/` contains implementations of the network architectures used by the models.
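For orientation, here is a minimal, hypothetical sketch of how these pieces compose for discrete data; the factory class is assumed to live in `probability.py` (its name is taken from `configs/text8_discrete.yaml`), and `net` stands for any module mapping `(inputs, t)` to output-distribution parameters:

```python
# Illustrative only: mirrors how the configs wire flow, loss and network together.
from model import BFN, DiscreteBayesianFlow, DiscreteBayesianFlowLoss
from probability import CategoricalFactory  # assumed: class_name used in the text8 config

flow = DiscreteBayesianFlow(n_classes=27, max_sqrt_beta=0.75)
loss = DiscreteBayesianFlowLoss(bayesian_flow=flow, distribution_factory=CategoricalFactory())
# `net` would be e.g. networks.GPT wrapped with the adapters in networks/adapters.py:
# bfn = BFN(net=net, bayesian_flow=flow, loss=loss)
# training_loss = bfn(batch)                  # continuous-time loss by default
# samples = bfn.sample((2, 256), n_steps=100)
```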
## Setup

```shell
# Create a new conda env with all dependencies including pytorch and CUDA
conda env create -f env.yml
conda activate bfn

# Or, install additional dependencies into an existing pytorch env
pip install accelerate==0.19.0 matplotlib omegaconf rich

# Optional, if you want to enable logging to neptune.ai
pip install neptune
```

## Training

The models in the paper can be trained using the configs provided in the `configs` dir as follows:

```shell
# mnist experiment on 1 GPU
accelerate launch train.py config_file=configs/mnist_discrete.yaml

# cifar10 experiment on 1 GPU (A100)
accelerate launch train.py config_file=configs/cifar10_discretized_256bins.yaml

# text8 experiment on 8 GPUs (A100)
accelerate launch --multi_gpu --num_processes=8 --num_machines=1 --dynamo_backend=no --mixed_precision=fp16 train.py config_file=configs/text8_discrete.yaml
```

## Testing

> [!NOTE]
> Depending on your GPU, you may wish to adjust the batch size used for testing in `test.py`.

```shell
# Optional: Download pretrained checkpoints (make sure you have git-lfs installed: https://git-lfs.com/)
git clone git@hf.co:rupspace/pretrained-BFNs

# Compute 784-step loss on MNIST
python test.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt n_steps=784 n_repeats=2000

# Compute 10-step loss on CIFAR-10
python test.py seed=1 config_file=./configs/cifar10_discretized_256bins.yaml load_model=./pretrained-BFNs/cifar10_256d_ema.pt n_steps=10 n_repeats=100

# Compute continuous-time loss on text8
python test.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt n_steps=0 n_repeats=1
```

> [!IMPORTANT]
> All computed results will be in nats-per-data-dimension. To convert to bits, divide by ln(2); see the snippet below.
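For example, a one-liner for the conversion (the loss value here is a stand-in for whatever `test.py` reports):

```python
import math

nats_per_dim = 1.05                        # hypothetical value printed by test.py
bits_per_dim = nats_per_dim / math.log(2)  # nats -> bits
print(f"{bits_per_dim:.3f} bits/dim")
```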
## Sampling

You can sample from a pre-trained model as follows (change options as desired):

```shell
# Sample 4 binarized MNIST images using 100 steps
python sample.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt samples_shape="[4, 28, 28, 1]" n_steps=100 save_file=./samples_mnist.pt

# Sample 4 CIFAR-10 16-bin images modeled as discretized data using 1000 steps
python sample.py seed=1 config_file=./configs/cifar10_discretized_16bins.yaml load_model=./pretrained-BFNs/cifar10_16d_ema.pt samples_shape="[4, 32, 32, 3]" n_steps=1000 save_file=./samples_cifar.pt

# Sample 2 text8 sequences of length 256 using 100 steps
python sample.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt samples_shape="[2, 256]" n_steps=100 save_file=./samples_text8.pt
```

The samples are stored as PyTorch tensors in the `save_file`, and can be visualized by loading them and then using the utilities `batch_to_images` and `batch_to_str` in `data.py`. For example:

```shell
# batch_to_images returns a matplotlib Figure object
python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_mnist.pt')).savefig('mnist.png')"
python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_cifar.pt')).savefig('cifar.png')"

# batch_to_str returns a list of str
python -c "import torch; from data import batch_to_str; print(batch_to_str(torch.load('./samples_text8.pt')))"
```

## Reproducibility

If a high degree of reproducibility is desired (e.g. during sampling), set the following:

```python
torch.set_float32_matmul_precision("highest")
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
```

## Acknowledgements

We are grateful to [@Higgcz](https://github.com/Higgcz) for generous support with the experiment infrastructure and code release.

================================================
FILE: configs/cifar10_continuous_16bins.yaml
================================================
meta:
  neptune:
    debug: False
data:
  dataset: "cifar10"
  horizontal_flip: False
  num_bins: 16
  train_loader:
    batch_size: 32
    shuffle: True
    num_workers: 8
    pin_memory: True
    drop_last: True
    persistent_workers: True
  val_loader:
    batch_size: 500
    shuffle: False
    num_workers: 8
    pin_memory: True
model:
  net:
    class_name: "UNetVDM"
    parameters:
      embedding_dim: 128
      n_blocks: 32
      n_attention_heads: 1
      dropout_prob: 0.1
      norm_groups: 32
      input_channels: 3
      use_fourier_features: True
      attention_everywhere: False
      image_size: 32
  input_adapter:
    class_name: "FourierImageInputAdapter"
    parameters:
      input_channels: 3
      input_shape: [32, 32]
      output_height: 3
      add_pos_feats: False
      add_mask: False
  output_adapter:
    class_name: "OutputAdapter"
    parameters:
      input_height: 131
      output_channels: 3  # (r,g,b)
      output_height: 1
  bayesian_flow:
    class_name: "CtsBayesianFlow"
    parameters:
      min_variance: 1e-3
  loss:
    class_name: "CtsBayesianFlowLoss"
    parameters:
      noise_pred: True
    distribution_factory:
      class_name: "DeltaFactory"
      parameters: {}
optimizer:
  lr: 2e-4
  betas: [0.9,0.99]
  weight_decay: 0.01
  eps: 1e-8
training:
  checkpoint_interval: 10_000
  ema_decay: 0.9999
  grad_clip_norm: 5.0
  log_interval: 1
  n_training_steps: 1_000_000
  val_interval: 50_000
  val_repeats: 100

================================================
FILE: configs/cifar10_continuous_256bins.yaml
================================================
meta:
  neptune:
    debug: False
data:
  dataset: "cifar10"
  horizontal_flip: False
  num_bins: 256
  train_loader:
    batch_size: 32
    shuffle: True
    num_workers: 8
    pin_memory: True
    drop_last: True
    persistent_workers: True
  val_loader:
    batch_size: 500
    shuffle: False
    num_workers: 8
    pin_memory: True
model:
  net:
    class_name: "UNetVDM"
    parameters:
      embedding_dim: 128
      n_blocks: 32
      n_attention_heads: 1
      dropout_prob: 0.1
      norm_groups: 32
      input_channels: 3
      use_fourier_features: True
      attention_everywhere: False
      image_size: 32
  input_adapter:
    class_name: "FourierImageInputAdapter"
    parameters:
      input_channels: 3
      input_shape: [32, 32]
      output_height: 3
      add_pos_feats: False
      add_mask: False
  output_adapter:
    class_name: "OutputAdapter"
    parameters:
      input_height: 131
      output_channels: 3  # (r,g,b)
      output_height: 1
  bayesian_flow:
    class_name: "CtsBayesianFlow"
    parameters:
      min_variance: 1e-6
  loss:
    class_name: "CtsBayesianFlowLoss"
    parameters:
      noise_pred: True
    distribution_factory:
      class_name: "DeltaFactory"
      parameters: {}
optimizer:
  lr: 2e-4
  betas: [0.9,0.99]
  weight_decay: 0.01
  eps: 1e-8
training:
  checkpoint_interval: 10_000
  ema_decay: 0.9999
  grad_clip_norm: 5.0
  log_interval: 1
  n_training_steps: 1_000_000
  val_interval: 50_000
  val_repeats: 100
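# Note (illustrative, not from the original file): these configs are plain
# OmegaConf YAML, and the README's key=value command-line overrides could be
# merged in roughly this way (the exact wiring in train.py may differ):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("configs/cifar10_continuous_256bins.yaml")
#   cfg = OmegaConf.merge(cfg, OmegaConf.from_cli())  # e.g. `data.num_bins=16`
#   cfg.model.net.parameters.n_blocks                 # -> 32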
================================================
FILE: configs/cifar10_discretized_16bins.yaml
================================================
meta:
  neptune:
    debug: False
data:
  dataset: "cifar10"
  horizontal_flip: False
  num_bins: 16
  train_loader:
    batch_size: 32
    shuffle: True
    num_workers: 8
    pin_memory: True
    drop_last: True
    persistent_workers: True
  val_loader:
    batch_size: 1000
    shuffle: False
    num_workers: 8
    pin_memory: True
model:
  net:
    class_name: "UNetVDM"
    parameters:
      embedding_dim: 128
      n_blocks: 32
      n_attention_heads: 1
      dropout_prob: 0.1
      norm_groups: 32
      input_channels: 3
      use_fourier_features: True
      attention_everywhere: False
      image_size: 32
  input_adapter:
    class_name: "FourierImageInputAdapter"
    parameters:
      input_channels: 3
      input_shape: [32, 32]
      output_height: 3
      add_pos_feats: False
      add_mask: False
  output_adapter:
    class_name: "OutputAdapter"
    parameters:
      input_height: 131
      output_channels: 3  # (r,g,b)
      output_height: 2  # mean, std
  bayesian_flow:
    class_name: "CtsBayesianFlow"
    parameters:
      min_variance: 1e-3
  loss:
    class_name: "CtsBayesianFlowLoss"
    parameters:
      noise_pred: True
    distribution_factory:
      class_name: "DiscretizedNormalFactory"
      parameters:
        num_bins: 16
        clip: True
optimizer:
  lr: 2e-4
  betas: [0.9,0.99]
  weight_decay: 0.01
  eps: 1e-8
training:
  checkpoint_interval: 10_000
  ema_decay: 0.9999
  grad_clip_norm: 5.0
  log_interval: 1
  n_training_steps: 1_000_000
  val_interval: 50_000
  val_repeats: 100

================================================
FILE: configs/cifar10_discretized_256bins.yaml
================================================
meta:
  neptune:
    debug: False
data:
  dataset: "cifar10"
  horizontal_flip: False
  num_bins: 256
  train_loader:
    batch_size: 32
    shuffle: True
    num_workers: 8
    pin_memory: True
    drop_last: True
    persistent_workers: True
  val_loader:
    batch_size: 1000
    shuffle: False
    num_workers: 8
    pin_memory: True
model:
  net:
    class_name: "UNetVDM"
    parameters:
      embedding_dim: 128
      n_blocks: 32
      n_attention_heads: 1
      dropout_prob: 0.1
      norm_groups: 32
      input_channels: 3
      use_fourier_features: True
      attention_everywhere: False
      image_size: 32
  input_adapter:
    class_name: "FourierImageInputAdapter"
    parameters:
      input_channels: 3
      input_shape: [32, 32]
      output_height: 3
      add_pos_feats: False
      add_mask: False
  output_adapter:
    class_name: "OutputAdapter"
    parameters:
      input_height: 131
      output_channels: 3  # (r,g,b)
      output_height: 2  # mean, std
  bayesian_flow:
    class_name: "CtsBayesianFlow"
    parameters:
      min_variance: 1e-6
  loss:
    class_name: "CtsBayesianFlowLoss"
    parameters:
      noise_pred: True
    distribution_factory:
      class_name: "DiscretizedNormalFactory"
      parameters:
        num_bins: 256
        clip: True
optimizer:
  lr: 2e-4
  betas: [0.9,0.99]
  weight_decay: 0.01
  eps: 1e-8
training:
  checkpoint_interval: 10_000
  ema_decay: 0.9999
  grad_clip_norm: 5.0
  log_interval: 1
  n_training_steps: 1_000_000
  val_interval: 50_000
  val_repeats: 100

================================================
FILE: configs/mnist_discrete.yaml
================================================
meta:
  neptune:
    debug: False
data:
  dataset: "bin_mnist"
  train_loader:
    batch_size: 512
    shuffle: True
    num_workers: 8
    pin_memory: True
    drop_last: True
  val_loader:
    batch_size: 1000
    shuffle: False
    num_workers: 8
    pin_memory: True
model:
  net:
    class_name: "UNetModel"
    parameters:
      image_size: 28
      in_channels: 2
      model_channels: 128
      out_channels: 128
      num_res_blocks: 2
      attention_resolutions: [8,16]
      dropout: 0.5
      channel_mult: [1, 2, 2]
      conv_resample: True
      dims: 2
      num_heads: 4
      num_heads_upsample: -1
      project_input: True
      skip: True
  input_adapter:
    class_name: "FourierImageInputAdapter"
    parameters:
      input_channels: 1
      input_shape: [28, 28]
      output_height: 2
      add_pos_feats: False
  output_adapter:
    class_name: "OutputAdapter"
    parameters:
      input_height: 256
      output_channels: 1
      output_height: 1
  bayesian_flow:
    class_name: "DiscreteBayesianFlow"
    parameters:
      n_classes: 2
      max_sqrt_beta: 3
      discretize: False
  loss:
    class_name: "DiscreteBayesianFlowLoss"
    parameters: {}
    distribution_factory:
      class_name: "BernoulliFactory"
      parameters: {}
optimizer:
  lr: 1e-4
  betas: [0.9,0.98]
training:
  checkpoint_interval: 10_000
  ema_decay: 0.9999
  grad_clip_norm: 5.0
  log_interval: 1
  n_training_steps: 1_000_000
  val_interval: 50_000
  val_repeats: 1000

================================================
FILE: configs/text8_discrete.yaml
================================================
meta:
  neptune:
    debug: False
data:
  dataset: "text8"
  seq_len: 256
  train_loader:
    batch_size: 416
    shuffle: True
    num_workers: 8
    pin_memory: True
    drop_last: True
  val_loader:
    batch_size: 200
    shuffle: True
    num_workers: 8
    pin_memory: True
model:
  net:
    class_name: "GPT"
    parameters:
      vocab_size: 27
      n_layer: 24
      n_head: 12
      n_embd: 768
      dropout: 0.0
      skip: True
      bias: True
  input_adapter:
    class_name: "TextInputAdapter"
    parameters:
      vocab_size: 27
      seq_len: 256
      output_size: 768
      learn_pos_embedding: False
  output_adapter: null
  bayesian_flow:
    class_name: "DiscreteBayesianFlow"
    parameters:
      n_classes: 27
      max_sqrt_beta: 0.75
  loss:
    class_name: "DiscreteBayesianFlowLoss"
    parameters: {}
    distribution_factory:
      class_name: "CategoricalFactory"
      parameters: {}
optimizer:
  lr: 1e-4
  betas: [0.9, 0.98]
  weight_decay: 0.01
training:
  accumulate: 1
  checkpoint_interval: 10_000
  ema_decay: 0.9999
  grad_clip_norm: 5
  log_interval: 1
  max_val_batches: 5_000
  n_training_steps: 10_000_000
  val_interval: 100_000
  val_repeats: 1

================================================
FILE: data.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import pathlib
import pickle
import zipfile
from typing import Union

import numpy as np
import requests
import torch
import torchvision
from matplotlib import pyplot as plt
from omegaconf import DictConfig
from torch.utils.data import Dataset, random_split
from torchvision import transforms
from torchvision.utils import make_grid

from utils_model import quantize

TEXT8_CHARS = list("_abcdefghijklmnopqrstuvwxyz")


def bin_mnist_transform(x):
    return torch.bernoulli(x.permute(1, 2, 0).contiguous()).int()


def bin_mnist_cts_transform(x):
    return torch.bernoulli(x.permute(1, 2, 0).contiguous()) - 0.5


def rgb_image_transform(x, num_bins=256):
    return quantize((x * 2) - 1, num_bins).permute(1, 2, 0).contiguous()


class MyLambda(torchvision.transforms.Lambda):
    def __init__(self, lambd, arg1):
        super().__init__(lambd)
        self.arg1 = arg1

    def __call__(self, x):
        return self.lambd(x, self.arg1)


class CIFAR10(torchvision.datasets.CIFAR10):
    def __getitem__(self, idx):
        return super().__getitem__(idx)[0]


class MNIST(torchvision.datasets.MNIST):
    def __getitem__(self, idx):
        return super().__getitem__(idx)[0]


def make_datasets(cfg: DictConfig) -> tuple[Dataset, Dataset, Dataset]:
    """
    Mandatory keys: dataset (must be cifar10, mnist, bin_mnist, bin_mnist_cts or text8), data_dir
    Optional for vision: num_bins (default 256), val_frac (default 0.01), horizontal_flip (default: False)
    Mandatory for text: seq_len
    """
    num_bins = cfg.get("num_bins", 256)
    if cfg.dataset == "cifar10":
        train_transform_list = [transforms.ToTensor()]
        if cfg.get("horizontal_flip", False):
            train_transform_list.append(transforms.RandomHorizontalFlip())
        train_transform_list.append(MyLambda(rgb_image_transform, num_bins))
        train_transform = transforms.Compose(train_transform_list)
        test_transform = transforms.Compose([transforms.ToTensor(), MyLambda(rgb_image_transform, num_bins)])
        train_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=train_transform)
        val_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=test_transform)
        test_set = CIFAR10(root=cfg.data_dir, train=False, download=True, transform=test_transform)
    elif cfg.dataset == "mnist":
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                MyLambda(rgb_image_transform, num_bins),
            ]
        )
        train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)
    elif cfg.dataset == "bin_mnist":
        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_transform)])
        train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)
    elif cfg.dataset == "bin_mnist_cts":
        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_cts_transform)])
        train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)
    elif cfg.dataset == "text8":
        train_set = Text8Dataset(cfg.data_dir, "train", download=True, seq_len=cfg.seq_len)
        val_set = Text8Dataset(cfg.data_dir, "val", download=True, seq_len=cfg.seq_len)
        test_set = Text8Dataset(cfg.data_dir, "test", download=True, seq_len=cfg.seq_len)
    else:
        raise NotImplementedError(cfg.dataset)

    if cfg.dataset != "text8":
        # For vision datasets we split the train set into train and val
        val_frac = cfg.get("val_frac", 0.01)
        train_val_split = [1.0 - val_frac, val_frac]
        seed = 2147483647
        train_set = random_split(train_set, train_val_split, generator=torch.Generator().manual_seed(seed))[0]
        val_set = random_split(val_set, train_val_split, generator=torch.Generator().manual_seed(seed))[1]
    return train_set, val_set, test_set


def prepare_text8(data_dir: pathlib.Path):
    data_dir.mkdir(parents=True, exist_ok=True)
    data_url = "http://mattmahoney.net/dc/text8.zip"
    with open(data_dir / "text8.zip", "wb") as f:
        print("Downloading text8")
        f.write(requests.get(data_url).content)
        print("Done")
    with zipfile.ZipFile(data_dir / "text8.zip") as f:
        f.extractall(data_dir)
    os.remove(data_dir / "text8.zip")
    data = (data_dir / "text8").read_text()

    # get all the unique characters that occur in this text
    chars = sorted(list(set(data)))
    vocab_size = len(chars)
    print("all the unique characters:", "".join(chars))
    print(f"vocab size: {vocab_size:,}")

    # create a mapping from characters to integers
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}

    def encode(s):
        return [stoi[c] for c in s]  # encoder: take a string, output a list of integers

    # encode both to integers
    n = len(data)
    train_data = data[: int(n * 0.9)]
    val_data = data[int(n * 0.9) : int(n * 0.95)]
    test_data = data[int(n * 0.95) :]
    train_ids = encode(train_data)
    val_ids = encode(val_data)
    test_ids = encode(test_data)
    print(f"train has {len(train_ids):,} tokens")
    print(f"val has {len(val_ids):,} tokens")
    print(f"test has {len(test_ids):,} tokens")

    # export to bin files
    train_ids = np.array(train_ids, dtype=np.uint16)
    val_ids = np.array(val_ids, dtype=np.uint16)
    test_ids = np.array(test_ids, dtype=np.uint16)
    train_ids.tofile(data_dir / "train.bin")
    val_ids.tofile(data_dir / "val.bin")
    test_ids.tofile(data_dir / "test.bin")
    print(f"Saved to {data_dir / 'train.bin'}, {data_dir / 'val.bin'}, {data_dir / 'test.bin'}")

    # save the meta information as well, to help us encode/decode later
    meta = {
        "vocab_size": vocab_size,
        "itos": itos,
        "stoi": stoi,
    }
    with open(os.path.join(data_dir / "meta.pkl"), "wb") as f:
        pickle.dump(meta, f)
    print(f"text8 dataset downloaded and prepared in dir {data_dir}")


class Text8Dataset(Dataset):
    def __init__(self, data_dir: Union[str, pathlib.Path], split: str, download: bool, seq_len: int):
        """
        seq_len should include context length. Example: seq_len=512 for modeling 256 chars with 256 chars of context.
        context is only used for correct preparation of val/test sets.
        """
""" self.root_dir = pathlib.Path(data_dir) self.split = split self.seq_len = seq_len fname = {"train": "train.bin", "val": "val.bin", "test": "test.bin"}[self.split] assert self.split in ["train", "val", "test"] data_dir = self.root_dir / "text8" if not os.path.exists(data_dir): if download: prepare_text8(data_dir) else: raise NotADirectoryError(f"dir {data_dir} does not exist and download is False") self.data = np.memmap(data_dir / fname, np.uint16, "r") def __getitem__(self, index) -> torch.Tensor: seq = torch.from_numpy(self.data[index : index + self.seq_len].astype(np.int64)) return seq def __len__(self): return self.data.size - self.seq_len def char_ids_to_str(char_ids: Union[list[int], np.array, torch.Tensor]) -> str: """Decode a 1D sequence of character IDs to a string.""" return "".join([TEXT8_CHARS[i] for i in char_ids]) def batch_to_str(text_batch: Union[list[list], np.array, torch.Tensor]) -> list[str]: """Decode a batch of character IDs to a list of strings.""" return [char_ids_to_str(row_char_ids) for row_char_ids in text_batch] def batch_to_images(image_batch: torch.Tensor, ncols: int = None) -> plt.Figure: if ncols is None: ncols = math.ceil(math.sqrt(len(image_batch))) if image_batch.size(-1) == 3: # for color images (CIFAR-10) image_batch = (image_batch + 1) / 2 grid = make_grid(image_batch.permute(0, 3, 1, 2), ncols, pad_value=1).permute(1, 2, 0) fig = plt.figure(figsize=(grid.size(1) / 30, grid.size(0) / 30)) plt.imshow(grid.cpu().clip(min=0, max=1), interpolation="nearest") plt.grid(False) plt.axis("off") return fig ================================================ FILE: env.yml ================================================ name: bfn channels: - pytorch - nvidia dependencies: - python=3.9 - pytorch=2.0.0 - pytorch-cuda=11.8 - torchvision=0.15.0 - pip - pip: - accelerate==0.19.0 - matplotlib - omegaconf - rich ================================================ FILE: model.py ================================================ # Copyright 2023 NNAISENSE SA # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ This file implements the Bayesian Flow and BFN loss for continuous and discrete variables. Finally it implements the BFN using these objects. For consistency we use always use a tuple to store input parameters. It has just one element for discrete data (the probabilities) and two for continuous/discretized (mean & variance). The probability distributions and network architectures are defined in probability.py and networks dir. "Cts" is an abbreviation of "Continuous". 
""" import math from abc import abstractmethod, ABC from typing import Union, Optional import torch import torch.distributions as D import torch.nn.functional as F from torch import nn, Tensor from probability import ( DiscreteDistributionFactory, CtsDistributionFactory, PredDistToDataDistFactory, DiscretizedCtsDistribution, ) from utils_model import sandwich, float_to_idx class BayesianFlow(nn.Module, ABC): def __init__(self): super().__init__() @abstractmethod def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, ...]: """Returns the initial input params (for a batch) at t=0. Used during sampling. For discrete data, the tuple has length 1 and contains the initial class probabilities. For continuous data, the tuple has length 2 and contains the mean and precision.""" pass @abstractmethod def params_to_net_inputs(self, params: tuple[Tensor, ...]) -> Tensor: """Utility method to convert input distribution params to network inputs if needed.""" pass @abstractmethod def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> float: """Returns the alpha at step i of total n_steps according to the flow schedule. Used: a) during sampling, when i and alpha are the same for all samples in the batch. b) during discrete time loss computation, when i and alpha are different for samples in the batch.""" pass @abstractmethod def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution: """Returns the sender distribution with accuracy alpha obtained by adding appropriate noise to the data x. Used: a) during sampling (same alpha for whole batch) to sample from the output distribution produced by the net. b) during discrete time loss computation when alpha are different for samples in the batch.""" pass @abstractmethod def update_input_params(self, input_params: tuple[Tensor, ...], y: Tensor, alpha: float) -> tuple[Tensor, ...]: """Updates the distribution parameters using Bayes' theorem in light of noisy sample y. Used during sampling when alpha is the same for the whole batch.""" pass @abstractmethod def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, ...]: """Returns a sample from the Bayesian Flow distribution over input parameters at time t conditioned on data. Used during training when t (and thus accuracies) are different for different samples in the batch. For discrete data, the returned tuple has length 1 and contains the class probabilities. For continuous data, the returned tuple has length 2 and contains the mean and precision.""" pass class Loss(nn.Module, ABC): def __init__(self): super().__init__() @abstractmethod def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor) -> Tensor: """Returns the continuous time KL loss (and any other losses) at time t (between 0 and 1). The input params are only used when the network is parameterized to predict the noise for continuous data.""" pass @abstractmethod def discrete_time_loss( self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples: int = 20 ) -> Tensor: """Returns the discrete time KL loss for n_steps total of communication at time t (between 0 and 1) using n_samples for Monte Carlo estimation of the discrete loss. 
        The input params are only used when the network is parameterized to predict the noise for continuous data."""
        pass

    @abstractmethod
    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
        """Returns the reconstruction loss, i.e. the final cost of transmitting clean data.
        The input params are only used when the network is parameterized to predict the noise for continuous data."""
        pass


# Continuous or Discretized data


class CtsBayesianFlow(BayesianFlow):
    def __init__(
        self,
        min_variance: float = 1e-6,
    ):
        super().__init__()
        self.min_variance = min_variance

    @torch.no_grad()
    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, None]:
        post_var = torch.pow(self.min_variance, t)
        alpha_t = 1 - post_var
        mean_mean = alpha_t * data
        mean_var = alpha_t * post_var
        mean_std_dev = mean_var.sqrt()
        noise = torch.randn(mean_mean.shape, device=mean_mean.device)
        mean = mean_mean + (mean_std_dev * noise)
        # We don't need to compute the variance because it is not needed by the network, so set it to None
        input_params = (mean, None)
        return input_params

    def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
        return params[0]  # Only the mean is used by the network

    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, float]:
        return torch.zeros(*data_shape, device=device), 1.0

    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:
        sigma_1 = math.sqrt(self.min_variance)
        return (sigma_1 ** (-2 * i / n_steps)) * (1 - sigma_1 ** (2 / n_steps))

    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
        dist = D.Normal(x, 1.0 / alpha**0.5)
        return dist

    def update_input_params(self, input_params: tuple[Tensor, float], y: Tensor, alpha: float) -> tuple[Tensor, float]:
        input_mean, input_precision = input_params
        new_precision = input_precision + alpha
        new_mean = ((input_precision * input_mean) + (alpha * y)) / new_precision
        return new_mean, new_precision


class CtsBayesianFlowLoss(Loss):
    def __init__(
        self,
        bayesian_flow: CtsBayesianFlow,
        distribution_factory: Union[CtsDistributionFactory, DiscreteDistributionFactory],
        min_loss_variance: float = -1,
        noise_pred: bool = True,
    ):
        super().__init__()
        self.bayesian_flow = bayesian_flow
        self.distribution_factory = distribution_factory
        self.min_loss_variance = min_loss_variance
        self.C = -0.5 * math.log(bayesian_flow.min_variance)
        self.noise_pred = noise_pred
        if self.noise_pred:
            self.distribution_factory.log_dev = False
            self.distribution_factory = PredDistToDataDistFactory(
                self.distribution_factory, self.bayesian_flow.min_variance
            )

    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:
        output_params = sandwich(output_params)
        t = t.flatten(start_dim=1).float()
        posterior_var = torch.pow(self.bayesian_flow.min_variance, t)
        flat_target = data.flatten(start_dim=1)
        pred_dist = self.distribution_factory.get_dist(output_params, input_params, t)
        pred_mean = pred_dist.mean
        mse_loss = (pred_mean - flat_target).square()
        if self.min_loss_variance > 0:
            posterior_var = posterior_var.clamp(min=self.min_loss_variance)
        loss = self.C * mse_loss / posterior_var
        return loss

    def discrete_time_loss(
        self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10
    ) -> Tensor:
        output_params = sandwich(output_params)
        t = t.flatten(start_dim=1).float()
        output_dist = self.distribution_factory.get_dist(output_params, input_params, t)
        if hasattr(output_dist, "probs"):  # output distribution is discretized normal
            flat_target = data.flatten(start_dim=1)
            t = t.flatten(start_dim=1)
            i = t * n_steps + 1  # since t = (i - 1) / n
            alpha = self.bayesian_flow.get_alpha(i, n_steps)
            sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)
            receiver_mix_wts = sandwich(output_dist.probs)
            receiver_mix_dist = D.Categorical(probs=receiver_mix_wts, validate_args=False)
            receiver_components = D.Normal(
                output_dist.class_centres, (1.0 / alpha.sqrt()).unsqueeze(-1), validate_args=False
            )
            receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components, validate_args=False)
            y = sender_dist.sample(torch.Size([n_samples]))
            loss = (
                (sender_dist.log_prob(y) - receiver_dist.log_prob(y))
                .mean(0)
                .flatten(start_dim=1)
                .mean(1, keepdims=True)
            )
        else:  # output distribution is normal
            pred_mean = output_dist.mean
            flat_target = data.flatten(start_dim=1)
            mse_loss = (pred_mean - flat_target).square()
            i = t * n_steps + 1
            alpha = self.bayesian_flow.get_alpha(i, n_steps)
            loss = alpha * mse_loss / 2
        return n_steps * loss

    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
        output_params = sandwich(output_params)
        flat_data = data.flatten(start_dim=1)
        t = torch.ones_like(data).flatten(start_dim=1).float()
        output_dist = self.distribution_factory.get_dist(output_params, input_params, t)
        if hasattr(output_dist, "probs"):  # output distribution is discretized normal
            reconstruction_loss = -output_dist.log_prob(flat_data)
        else:  # output distribution is normal, but we use discretized normal to make results comparable (see Sec. 7.2)
            if self.bayesian_flow.min_variance == 1e-3:  # used for 16 bin CIFAR10
                noise_dev = 0.7 * math.sqrt(self.bayesian_flow.min_variance)
                num_bins = 16
            else:
                noise_dev = math.sqrt(self.bayesian_flow.min_variance)
                num_bins = 256
            mean = output_dist.mean.flatten(start_dim=1)
            final_dist = D.Normal(mean, noise_dev)
            final_dist = DiscretizedCtsDistribution(final_dist, num_bins, device=t.device, batch_dims=mean.ndim - 1)
            reconstruction_loss = -final_dist.log_prob(flat_data)
        return reconstruction_loss


# Discrete Data


class DiscreteBayesianFlow(BayesianFlow):
    def __init__(
        self,
        n_classes: int,
        min_sqrt_beta: float = 1e-10,
        discretize: bool = False,
        epsilon: float = 1e-6,
        max_sqrt_beta: float = 1,
    ):
        super().__init__()
        self.n_classes = n_classes
        self.min_sqrt_beta = min_sqrt_beta
        self.discretize = discretize
        self.epsilon = epsilon
        self.max_sqrt_beta = max_sqrt_beta
        self.uniform_entropy = math.log(self.n_classes)

    def t_to_sqrt_beta(self, t):
        return t * self.max_sqrt_beta

    def count_dist(self, x, beta=None):
        mean = (self.n_classes * F.one_hot(x.long(), self.n_classes)) - 1
        std_dev = math.sqrt(self.n_classes)
        if beta is not None:
            mean = mean * beta
            std_dev = std_dev * beta.sqrt()
        return D.Normal(mean, std_dev, validate_args=False)

    def count_sample(self, x, beta):
        return self.count_dist(x, beta).rsample()

    @torch.no_grad()
    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor]:
        return (torch.ones(*data_shape, self.n_classes, device=device) / self.n_classes,)

    @torch.no_grad()
    def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
        params = params[0]
        if self.n_classes == 2:
            params = params * 2 - 1  # We scale-shift here for MNIST instead of in the network like for text
            params = params[..., :1]
        return params

    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:
        return ((self.max_sqrt_beta / n_steps) ** 2) * (2 * i - 1)

    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
        e_x = F.one_hot(x.long(), self.n_classes)
        alpha = alpha.unsqueeze(-1) if isinstance(alpha, Tensor) else alpha
        dist = D.Normal(alpha * ((self.n_classes * e_x) - 1), (self.n_classes * alpha) ** 0.5)
        return dist

    def update_input_params(self, input_params: tuple[Tensor], y: Tensor, alpha: float) -> tuple[Tensor]:
        new_input_params = input_params[0] * y.exp()
        new_input_params /= new_input_params.sum(-1, keepdims=True)
        return (new_input_params,)

    @torch.no_grad()
    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor]:
        if self.discretize:
            data = float_to_idx(data, self.n_classes)
        sqrt_beta = self.t_to_sqrt_beta(t.clamp(max=1 - self.epsilon))
        lo_beta = sqrt_beta < self.min_sqrt_beta
        sqrt_beta = sqrt_beta.clamp(min=self.min_sqrt_beta)
        beta = sqrt_beta.square().unsqueeze(-1)
        logits = self.count_sample(data, beta)
        probs = F.softmax(logits, -1)
        probs = torch.where(lo_beta.unsqueeze(-1), torch.ones_like(probs) / self.n_classes, probs)
        if self.n_classes == 2:
            probs = probs[..., :1]
            probs = probs.reshape_as(data)
        input_params = (probs,)
        return input_params


class DiscreteBayesianFlowLoss(Loss):
    def __init__(
        self,
        bayesian_flow: DiscreteBayesianFlow,
        distribution_factory: DiscreteDistributionFactory,
    ):
        super().__init__()
        self.bayesian_flow = bayesian_flow
        self.distribution_factory = distribution_factory
        self.K = self.bayesian_flow.n_classes

    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:
        flat_output = sandwich(output_params)
        pred_probs = self.distribution_factory.get_dist(flat_output).probs
        flat_target = data.flatten(start_dim=1)
        if self.bayesian_flow.discretize:
            flat_target = float_to_idx(flat_target, self.K)
        tgt_mean = torch.nn.functional.one_hot(flat_target.long(), self.K)
        kl = self.K * ((tgt_mean - pred_probs).square()).sum(-1)
        t = t.flatten(start_dim=1).float()
        loss = t * (self.bayesian_flow.max_sqrt_beta**2) * kl
        return loss

    def discrete_time_loss(
        self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10
    ) -> Tensor:
        flat_target = data.flatten(start_dim=1)
        if self.bayesian_flow.discretize:
            flat_target = float_to_idx(flat_target, self.K)
        i = t * n_steps + 1
        alpha = self.bayesian_flow.get_alpha(i, n_steps).flatten(start_dim=1)
        sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)

        flat_output = sandwich(output_params)
        receiver_mix_wts = self.distribution_factory.get_dist(flat_output).probs
        receiver_mix_dist = D.Categorical(probs=receiver_mix_wts.unsqueeze(-2))
        classes = torch.arange(self.K, device=flat_target.device).long().unsqueeze(0).unsqueeze(0)
        receiver_components = self.bayesian_flow.get_sender_dist(classes, alpha.unsqueeze(-1))
        receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components)

        y = sender_dist.sample(torch.Size([n_samples]))
        loss = n_steps * (sender_dist.log_prob(y) - receiver_dist.log_prob(y)).mean(0).sum(-1).mean(1, keepdims=True)
        return loss

    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
        flat_outputs = sandwich(output_params)
        flat_data = data.flatten(start_dim=1)
        output_dist = self.distribution_factory.get_dist(flat_outputs)
        return -output_dist.log_prob(flat_data)
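# Illustrative toy check (not from the original file): the discrete update in
# DiscreteBayesianFlow.update_input_params multiplies the running categorical
# parameters by exp(y) and renormalizes. Assumed shapes mirror the text8 config.
#
#   flow = DiscreteBayesianFlow(n_classes=27, max_sqrt_beta=0.75)
#   (probs,) = flow.get_prior_input_params((2, 256), device=torch.device("cpu"))
#   alpha = flow.get_alpha(1, n_steps=100)
#   y = flow.get_sender_dist(torch.randint(0, 27, (2, 256)), alpha).sample()
#   (probs,) = flow.update_input_params((probs,), y, alpha)
#   probs.sum(-1)  # rows still sum to 1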
class BFN(nn.Module):
    def __init__(self, net: nn.Module, bayesian_flow: BayesianFlow, loss: Loss):
        super().__init__()
        self.net = net
        self.bayesian_flow = bayesian_flow
        self.loss = loss

    @staticmethod
    @torch.no_grad()
    def sample_t(data: Tensor, n_steps: Optional[int]) -> Tensor:
        if n_steps == 0 or n_steps is None:
            t = torch.rand(data.size(0), device=data.device).unsqueeze(-1)
        else:
            t = torch.randint(0, n_steps, (data.size(0),), device=data.device).unsqueeze(-1) / n_steps
        t = (torch.ones_like(data).flatten(start_dim=1) * t).reshape_as(data)
        return t

    def forward(
        self, data: Tensor, t: Optional[Tensor] = None, n_steps: Optional[int] = None
    ) -> tuple[Tensor, dict[str, Tensor], Tensor, Tensor]:
        """
        Compute an MC estimate of the continuous (when n_steps=None or 0) or discrete time KL loss.
        t is sampled randomly if None. If t is not None, expect t.shape == data.shape.
        """
        t = self.sample_t(data, n_steps) if t is None else t
        # sample input parameter flow
        input_params = self.bayesian_flow(data, t)
        net_inputs = self.bayesian_flow.params_to_net_inputs(input_params)
        # compute output distribution parameters
        output_params: Tensor = self.net(net_inputs, t)

        # compute KL loss in float32
        with torch.autocast(device_type=data.device.type if data.device.type != "mps" else "cpu", enabled=False):
            if n_steps == 0 or n_steps is None:
                loss = self.loss.cts_time_loss(data, output_params.float(), input_params, t)
            else:
                loss = self.loss.discrete_time_loss(data, output_params.float(), input_params, t, n_steps)

        # loss shape is (batch_size, 1)
        return loss.mean()

    @torch.inference_mode()
    def compute_reconstruction_loss(self, data: Tensor) -> Tensor:
        t = torch.ones_like(data).float()
        input_params = self.bayesian_flow(data, t)
        net_inputs = self.bayesian_flow.params_to_net_inputs(input_params)
        output_params: Tensor = self.net(net_inputs, t)
        return self.loss.reconstruction_loss(data, output_params, input_params).flatten(start_dim=1).mean()

    @torch.inference_mode()
    def sample(self, data_shape: tuple, n_steps: int) -> Tensor:
        device = next(self.parameters()).device
        input_params = self.bayesian_flow.get_prior_input_params(data_shape, device)
        distribution_factory = self.loss.distribution_factory

        for i in range(1, n_steps + 1):
            t = torch.ones(*data_shape, device=device) * (i - 1) / n_steps
            output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)
            output_sample = distribution_factory.get_dist(output_params, input_params, t).sample()
            output_sample = output_sample.reshape(*data_shape)
            alpha = self.bayesian_flow.get_alpha(i, n_steps)
            y = self.bayesian_flow.get_sender_dist(output_sample, alpha).sample()
            input_params = self.bayesian_flow.update_input_params(input_params, y, alpha)

        t = torch.ones(*data_shape, device=device)
        output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)
        output_sample = distribution_factory.get_dist(output_params, input_params, t).mode
        output_sample = output_sample.reshape(*data_shape)
        return output_sample
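# Illustrative toy check (not from the original file; assumed values):
# CtsBayesianFlow.update_input_params is a precision-weighted Bayesian update
# of a normal mean, which is what the sampling loop above iterates.
#
#   flow = CtsBayesianFlow(min_variance=1e-6)
#   mean, precision = flow.get_prior_input_params((2, 3), device=torch.device("cpu"))
#   alpha = flow.get_alpha(1, n_steps=10)
#   y = flow.get_sender_dist(torch.zeros(2, 3), alpha).sample()
#   mean, precision = flow.update_input_params((mean, precision), y, alpha)
#   # new_mean = (old_prec * old_mean + alpha * y) / (old_prec + alpha); precision -> 1 + alpha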
================================================
FILE: networks/__init__.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = (
    "GPT",
    "UNetVDM",
    "UNetModel",
    "adapters",
)

from .transformer import GPT
from .unet_vdm import UNetVDM
from .unet_improved import UNetModel
from . import adapters

================================================
FILE: networks/adapters.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Tuple

import torch
from torch import Tensor
from torch import nn

from utils_model import sandwich, pe_encode, pe_encode_float


class TextInputAdapter(nn.Module):
    """
    A module to convert sequences of text class tokens to embedding tokens with learned positional embeddings.
    """

    def __init__(
        self,
        vocab_size: int,
        seq_len: int,
        output_size: int = 256,
        learn_pos_embedding: bool = False,
    ):
        super().__init__()
        self.learn_pos_embedding = learn_pos_embedding
        if learn_pos_embedding:
            self.pos_embedding = nn.Embedding(seq_len, output_size)
        else:
            self.register_buffer("pos_embedding", pe_encode(seq_len, output_size))
        self.inp_embedding = nn.Linear(vocab_size, output_size)
        self.t_embedding = nn.Linear(1, output_size)

    def forward(self, probs: torch.Tensor, t: torch.Tensor) -> Tensor:
        inp_emb = self.inp_embedding(2 * probs - 1)
        if self.learn_pos_embedding:
            pos_emb = self.pos_embedding(torch.arange(0, probs.size(1)).to(probs.device))
        else:
            pos_emb = self.pos_embedding
        pos_emb = pos_emb.unsqueeze(0).expand(inp_emb.size(0), -1, -1)
        t_emb = self.t_embedding((2 * t - 1).unsqueeze(-1))
        output = inp_emb + pos_emb + t_emb
        return output


class FourierImageInputAdapter(nn.Module):
    """
    A module to convert 2D image coordinates into a set of vectors represented as a matrix, with fourier position codes.
    """
""" def __init__( self, input_channels: int = 3, input_shape: Tuple[int, int] = (224, 224), n_freq_bands: int = 64, output_height: int = 256, value_res: int = -1, mask_res: int = -1, add_pos_feats: bool = True, add_mask: bool = True, learn_pos_feats: bool = False, pos_embed_size: int = 32, init_scale: float = 0.02, ): super().__init__() self.input_shape = input_shape self.n_freq_bands = n_freq_bands self.value_res = value_res self.mask_res = mask_res self.add_pos_feats = add_pos_feats self.add_mask = add_mask if learn_pos_feats: pos_feats = nn.Parameter( init_scale * torch.randn(1, input_shape[0] * input_shape[1], pos_embed_size) ) self.register_parameter("pos_feats", pos_feats) else: x = torch.linspace(-1.0, 1.0, steps=input_shape[0]) y = torch.linspace(-1.0, 1.0, steps=input_shape[1]) x_pos, y_pos = torch.meshgrid(x, y, indexing="ij") pos = torch.stack((x_pos, y_pos), dim=-1) pos = pos.reshape(-1, 2) x_bands = torch.linspace(1.0, input_shape[0] / 2, steps=n_freq_bands) y_bands = torch.linspace(1.0, input_shape[1] / 2, steps=n_freq_bands) bands = torch.stack((x_bands, y_bands), dim=0) vals = pos[:, :, None] * bands[None, :, :] vals = math.pi * vals.reshape(vals.shape[0], -1) pos_feats = torch.cat([vals.sin(), vals.cos()], dim=-1) pos_feats = torch.cat([pos_feats, pos], dim=-1) self.register_buffer("pos_feats", pos_feats) img_feat_height = input_channels pos_feat_height = pos_feats.size(-1) if self.mask_res > 0: mask_feat_height = (n_freq_bands * 2) + 1 else: mask_feat_height = 1 all_feat_height = img_feat_height if add_mask: all_feat_height += mask_feat_height if add_pos_feats: all_feat_height += pos_feat_height self.output_projection = None if output_height != all_feat_height: self.output_projection = nn.Linear(all_feat_height, output_height) def forward(self, img: Tensor, t: Tensor) -> Tensor: flat_img = sandwich(img) flat_t = sandwich(t) t_feats = (flat_t.float()[..., :1] * 2) - 1 if self.mask_res > 0: t_feats = torch.cat( [ t_feats, pe_encode_float( t_feats, self.mask_res, self.n_freq_bands * 2 ).flatten(start_dim=2), ], -1, ) fourier_feats = self.pos_feats.expand(img.size(0), -1, -1) all_feat_list = [flat_img] if self.add_mask: all_feat_list.append(t_feats) if self.add_pos_feats: all_feat_list.append(fourier_feats) all_feats = torch.cat(all_feat_list, dim=-1) if self.output_projection is None: output = all_feats else: output = self.output_projection(all_feats) return output class OutputAdapter(nn.Module): def __init__(self, input_height: int, output_channels: int, output_height: int): super().__init__() self.output_channels = output_channels self.output_height = output_height self.output_projection = nn.Linear( input_height, output_channels * output_height ) def forward(self, inp: torch.Tensor) -> torch.Tensor: output = self.output_projection(inp) return output.reshape( output.size(0), -1, self.output_channels, self.output_height ) ================================================ FILE: networks/transformer.py ================================================ # Source: https://github.com/karpathy/nanoGPT # # MIT License # # Copyright (c) 2022 Andrej Karpathy # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following 
conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # Modifications: # - Added data_adapters to GPT to preprocess the inputs and (optionally) postprocess the outputs # - Added the `skip` option to concat the input and output of the network before the final projection # - Added time `t` as an input to `forward()` import math import torch import torch.nn as nn import torch.nn.functional as F def gelu(x): return F.gelu(x, approximate="tanh") class LayerNorm(nn.Module): """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" def __init__(self, ndim, bias): super().__init__() self.weight = nn.Parameter(torch.ones(ndim)) self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None def forward(self, input): return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) class SelfAttention(nn.Module): def __init__(self, n_head, n_embd, dropout, bias, is_causal): super().__init__() assert n_embd % n_head == 0 # key, query, value projections for all heads, but in a batch self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias) # output projection self.c_proj = nn.Linear(n_embd, n_embd, bias=bias) # regularization self.attn_dropout = nn.Dropout(dropout) self.resid_dropout = nn.Dropout(dropout) self.n_head = n_head self.n_embd = n_embd self.dropout = dropout self.is_causal = is_causal def forward(self, x): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim q, k, v = self.c_attn(x).split(self.n_embd, dim=2) k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) y = torch.nn.functional.scaled_dot_product_attention( q, k, v, dropout_p=self.dropout if self.training else 0, is_causal=self.is_causal ) y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) return y class MLP(nn.Module): def __init__(self, n_embd, dropout, bias): super().__init__() self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=bias) self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=bias) self.dropout = nn.Dropout(dropout) def forward(self, x): x = self.c_fc(x) x = gelu(x) x = self.c_proj(x) x = self.dropout(x) return x class Block(nn.Module): def __init__(self, n_head, n_embd, dropout, bias, is_causal): super().__init__() self.ln_1 = LayerNorm(n_embd, bias=bias) self.attn = SelfAttention(n_head, n_embd, dropout, bias, is_causal) self.ln_2 = LayerNorm(n_embd, bias=bias) self.mlp = MLP(n_embd, dropout, bias) def forward(self, x): x = x + self.attn(self.ln_1(x)) x = x + self.mlp(self.ln_2(x)) return x class GPT(nn.Module): def __init__( self, data_adapters: 
dict, vocab_size: int, n_layer: int = 12, n_head: int = 12, n_embd: int = 768, dropout: float = 0.0, bias: bool = True, skip: bool = False, is_causal: bool = False, ): super().__init__() self.n_layer = n_layer self.n_head = n_head self.n_embd = n_embd self.input_adapter = data_adapters["input_adapter"] self.output_adapter = data_adapters["output_adapter"] self.transformer = nn.ModuleDict( dict( drop=nn.Dropout(dropout), h=nn.ModuleList([Block(n_head, n_embd, dropout, bias, is_causal) for _ in range(n_layer)]), ln_f=LayerNorm(n_embd, bias=bias), ) ) self.is_causal = is_causal if self.is_causal: self.skip = False else: self.skip = skip if self.skip: self.lm_head = nn.Linear(2 * n_embd, vocab_size, bias=bias) else: self.lm_head = nn.Linear(n_embd, vocab_size, bias=bias) # init all weights self.apply(self._init_weights) # apply special scaled init to the residual projections, per GPT-2 paper for pn, p in self.named_parameters(): if pn.endswith("c_proj.weight"): torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layer)) # report number of parameters print(f"number of parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6:.2f}M") def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, data: torch.Tensor, t: torch.Tensor) -> torch.Tensor: x_in = self.input_adapter(data, t) x = self.transformer.drop(x_in) for block in self.transformer.h: x = block(x) x = self.transformer.ln_f(x) if self.skip: x = torch.cat([x, x_in], -1) logits = self.output_adapter(self.lm_head(x)) if self.output_adapter else self.lm_head(x) return logits def get_optim_groups(self, weight_decay: float): decay = set() no_decay = set() whitelist_weight_modules = (torch.nn.Linear,) blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding) for mn, m in self.named_modules(): for pn, p in m.named_parameters(): fpn = "%s.%s" % (mn, pn) if mn else pn # full param name # random note: because named_modules and named_parameters are recursive # we will see the same tensors p many many times. but doing it this way # allows us to know which parent module any tensor p belongs to... if pn.endswith("bias"): # all biases will not be decayed no_decay.add(fpn) elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): # weights of whitelist modules will be weight decayed decay.add(fpn) elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): # weights of blacklist modules will NOT be weight decayed no_decay.add(fpn) # We don't use weight tying so comment this out # decay.remove('lm_head.weight') # validate that we considered every parameter param_dict = {pn: p for pn, p in self.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) assert ( len(param_dict.keys() - union_params) == 0 ), "parameters %s were not separated into either decay/no_decay set!"
% (str(param_dict.keys() - union_params),) # create the pytorch optimizer groups optim_groups = [ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, ] return optim_groups ================================================ FILE: networks/unet_improved.py ================================================ # Source: https://github.com/openai/improved-diffusion # # MIT License # # Copyright (c) 2021 OpenAI # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # Modifications: # - Added data_adapters to UNetModel to preprocess the inputs and postprocess the outputs # - Added the `skip` option to concat the input and output of the network before the final projection # - Replaced `timesteps` argument of `UNetModel.forward()` with time `t`, which is used to compute the `timesteps` from abc import abstractmethod import math import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F from utils_model import sandwich from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors """ Helpers to train with 16-bit precision. """ def convert_module_to_f16(module): """ Convert primitive modules to float16. """ if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): module.weight.data = module.weight.data.half() module.bias.data = module.bias.data.half() def convert_module_to_f32(module): """ Convert primitive modules to float32, undoing convert_module_to_f16(). """ if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): module.weight.data = module.weight.data.float() module.bias.data = module.bias.data.float() def make_master_params(model_params): """ Copy model parameters into a (differently-shaped) list of full-precision parameters. """ master_params = _flatten_dense_tensors([param.detach().float() for param in model_params]) master_params = nn.Parameter(master_params) master_params.requires_grad = True return [master_params] def model_grads_to_master_grads(model_params, master_params): """ Copy the gradients from the model parameters into the master parameters from make_master_params(). """ master_params[0].grad = _flatten_dense_tensors([param.grad.data.detach().float() for param in model_params]) def master_params_to_model_params(model_params, master_params): """ Copy the master parameter data back into the model parameters. """ # Without copying to a list, if a generator is passed, this will # silently not copy any parameters. 
model_params = list(model_params) for param, master_param in zip(model_params, unflatten_master_params(model_params, master_params)): param.detach().copy_(master_param) def unflatten_master_params(model_params, master_params): """ Unflatten the master parameters to look like model_params. """ return _unflatten_dense_tensors(master_params[0].detach(), model_params) def zero_grad(model_params): for param in model_params: # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group if param.grad is not None: param.grad.detach_() param.grad.zero_() # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. class SiLU(nn.Module): def forward(self, x): return x * th.sigmoid(x) class GroupNorm32(nn.GroupNorm): def forward(self, x): return super().forward(x.float()).type(x.dtype) def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. """ if dims == 1: return nn.Conv1d(*args, **kwargs) elif dims == 2: return nn.Conv2d(*args, **kwargs) elif dims == 3: return nn.Conv3d(*args, **kwargs) raise ValueError(f"unsupported dimensions: {dims}") def linear(*args, **kwargs): """ Create a linear module. """ return nn.Linear(*args, **kwargs) def avg_pool_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D average pooling module. """ if dims == 1: return nn.AvgPool1d(*args, **kwargs) elif dims == 2: return nn.AvgPool2d(*args, **kwargs) elif dims == 3: return nn.AvgPool3d(*args, **kwargs) raise ValueError(f"unsupported dimensions: {dims}") def update_ema(target_params, source_params, rate=0.99): """ Update target parameters to be closer to those of source parameters using an exponential moving average. :param target_params: the target parameter sequence. :param source_params: the source parameter sequence. :param rate: the EMA rate (closer to 1 means slower). """ for targ, src in zip(target_params, source_params): targ.detach().mul_(rate).add_(src, alpha=1 - rate) def zero_module(module): """ Zero out the parameters of a module and return it. """ for p in module.parameters(): p.detach().zero_() return module def scale_module(module, scale): """ Scale the parameters of a module and return it. """ for p in module.parameters(): p.detach().mul_(scale) return module def mean_flat(tensor): """ Take the mean over all non-batch dimensions. """ return tensor.mean(dim=list(range(1, len(tensor.shape)))) def normalization(channels): """ Make a standard normalization layer. :param channels: number of input channels. :return: an nn.Module for normalization. """ return GroupNorm32(32, channels) def timestep_embedding(timesteps, dim, max_period=10000): """ Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. """ half = dim // 2 freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to( device=timesteps.device ) args = timesteps[:, None].float() * freqs[None] embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) if dim % 2: embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) return embedding def checkpoint(func, inputs, params, flag): """ Evaluate a function without caching intermediate activations, allowing for reduced memory at the expense of extra compute in the backward pass. :param func: the function to evaluate. 
:param inputs: the argument sequence to pass to `func`. :param params: a sequence of parameters `func` depends on but does not explicitly take as arguments. :param flag: if False, disable gradient checkpointing. """ if flag: args = tuple(inputs) + tuple(params) return CheckpointFunction.apply(func, len(inputs), *args) else: return func(*inputs) class CheckpointFunction(th.autograd.Function): @staticmethod def forward(ctx, run_function, length, *args): ctx.run_function = run_function ctx.input_tensors = list(args[:length]) ctx.input_params = list(args[length:]) with th.no_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) return output_tensors @staticmethod def backward(ctx, *output_grads): ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] with th.enable_grad(): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d # Tensors. shallow_copies = [x.view_as(x) for x in ctx.input_tensors] output_tensors = ctx.run_function(*shallow_copies) input_grads = th.autograd.grad( output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True, ) del ctx.input_tensors del ctx.input_params del output_tensors return (None, None) + input_grads class TimestepBlock(nn.Module): """ Any module where forward() takes timestep embeddings as a second argument. """ @abstractmethod def forward(self, x, emb): """ Apply the module to `x` given `emb` timestep embeddings. """ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): """ A sequential module that passes timestep embeddings to the children that support it as an extra input. """ def forward(self, x, emb): for layer in self: if isinstance(layer, TimestepBlock): x = layer(x, emb) else: x = layer(x) return x class Upsample(nn.Module): """ An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. """ def __init__(self, channels, use_conv, dims=2): super().__init__() self.channels = channels self.use_conv = use_conv self.dims = dims if use_conv: self.conv = conv_nd(dims, channels, channels, 3, padding=1) def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") else: x = F.interpolate(x, scale_factor=2, mode="nearest") if self.use_conv: x = self.conv(x) return x class Downsample(nn.Module): """ A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. """ def __init__(self, channels, use_conv, dims=2): super().__init__() self.channels = channels self.use_conv = use_conv self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if use_conv: self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1) else: self.op = avg_pool_nd(stride) def forward(self, x): assert x.shape[1] == self.channels return self.op(x) class ResBlock(TimestepBlock): """ A residual block that can optionally change the number of channels. :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. 
:param out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing on this module. """ def __init__( self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_scale_shift_norm=False, dims=2, use_checkpoint=False, ): super().__init__() self.channels = channels self.emb_channels = emb_channels self.dropout = dropout self.out_channels = out_channels or channels self.use_conv = use_conv self.use_checkpoint = use_checkpoint self.use_scale_shift_norm = use_scale_shift_norm self.in_layers = nn.Sequential( normalization(channels), SiLU(), conv_nd(dims, channels, self.out_channels, 3, padding=1), ) self.emb_layers = nn.Sequential( SiLU(), linear( emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels, ), ) self.out_layers = nn.Sequential( normalization(self.out_channels), SiLU(), nn.Dropout(p=dropout), zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) def forward(self, x, emb): """ Apply the block to a Tensor, conditioned on a timestep embedding. :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. """ return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) def _forward(self, x, emb): h = self.in_layers(x) emb_out = self.emb_layers(emb).type(h.dtype) while len(emb_out.shape) < len(h.shape): emb_out = emb_out[..., None] if self.use_scale_shift_norm: out_norm, out_rest = self.out_layers[0], self.out_layers[1:] scale, shift = th.chunk(emb_out, 2, dim=1) h = out_norm(h) * (1 + scale) + shift h = out_rest(h) else: h = h + emb_out h = self.out_layers(h) return self.skip_connection(x) + h class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted to the N-d case. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. """ def __init__(self, channels, num_heads=1, use_checkpoint=False): super().__init__() self.channels = channels self.num_heads = num_heads self.use_checkpoint = use_checkpoint self.norm = normalization(channels) self.qkv = conv_nd(1, channels, channels * 3, 1) self.attention = QKVAttention() self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) def _forward(self, x): b, c, *spatial = x.shape x = x.reshape(b, c, -1) qkv = self.qkv(self.norm(x)) qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2]) h = self.attention(qkv) h = h.reshape(b, -1, h.shape[-1]) h = self.proj_out(h) return (x + h).reshape(b, c, *spatial) class QKVAttention(nn.Module): """ A module which performs QKV attention. """ def forward(self, qkv): """ Apply QKV attention. :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs. :return: an [N x C x T] tensor after attention. 
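Note: when called from AttentionBlock, N is batch_size * num_heads, because the heads are folded into the batch dimension before this module is applied.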
""" ch = qkv.shape[1] // 3 q, k, v = th.split(qkv, ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) return th.einsum("bts,bcs->bct", weight, v) @staticmethod def count_flops(model, _x, y): """ A counter for the `thop` package to count the operations in an attention operation. Meant to be used like: macs, params = thop.profile( model, inputs=(inputs, timestamps), custom_ops={QKVAttention: QKVAttention.count_flops}, ) """ b, c, *spatial = y[0].shape num_spatial = int(np.prod(spatial)) # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += th.DoubleTensor([matmul_ops]) class UNetModel(nn.Module): """ The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param conv_resample: if True, use learned convolutions for upsampling and downsampling. :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes. :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention heads in each attention layer. 
""" def __init__( self, data_adapters, image_size=32, in_channels=3, model_channels=128, out_channels=128, num_res_blocks=3, attention_resolutions=[8, 16], dropout=0, channel_mult=(1, 2, 2, 2), conv_resample=True, dims=2, skip=True, num_classes=None, use_checkpoint=False, num_heads=4, num_heads_upsample=-1, use_scale_shift_norm=False, project_input=False, ): super().__init__() self.input_adapter = data_adapters["input_adapter"] self.output_adapter = data_adapters["output_adapter"] if num_heads_upsample == -1: num_heads_upsample = num_heads self.image_size = image_size self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint self.num_heads = num_heads self.num_heads_upsample = num_heads_upsample self.skip = skip self.project_input = project_input if project_input: self.input_projection = nn.Linear(self.in_channels, self.model_channels) in_channels = self.model_channels time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), SiLU(), linear(time_embed_dim, time_embed_dim), ) if self.num_classes is not None: self.label_emb = nn.Embedding(num_classes, time_embed_dim) self.input_blocks = nn.ModuleList( [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] ) input_block_chans = [model_channels] ch = model_channels ds = 1 for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ ResBlock( ch, time_embed_dim, dropout, out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = mult * model_channels if ds in attention_resolutions: layers.append(AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads)) self.input_blocks.append(TimestepEmbedSequential(*layers)) input_block_chans.append(ch) if level != len(channel_mult) - 1: self.input_blocks.append(TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims))) input_block_chans.append(ch) ds *= 2 self.middle_block = TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads), ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), ) self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: for i in range(num_res_blocks + 1): layers = [ ResBlock( ch + input_block_chans.pop(), time_embed_dim, dropout, out_channels=model_channels * mult, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = model_channels * mult if ds in attention_resolutions: layers.append( AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads_upsample, ) ) if level and i == num_res_blocks: layers.append(Upsample(ch, conv_resample, dims=dims)) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) self.out = nn.Sequential( normalization(ch), SiLU(), zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), ) def convert_to_fp16(self): """ Convert the torso of the model to float16. 
""" self.input_blocks.apply(convert_module_to_f16) self.middle_block.apply(convert_module_to_f16) self.output_blocks.apply(convert_module_to_f16) def convert_to_fp32(self): """ Convert the torso of the model to float32. """ self.input_blocks.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) @property def inner_dtype(self): """ Get the dtype used by the torso of the model. """ return next(self.input_blocks.parameters()).dtype def forward( self, data: th.Tensor, t: th.Tensor, ) -> th.Tensor: """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. """ y = None flat_x = self.input_adapter(data, t) x = flat_x.reshape(flat_x.size(0), self.image_size, self.image_size, self.in_channels) if self.project_input: x = self.input_projection(x) x_perm = x.permute(0, 3, 1, 2).contiguous() timesteps = t.flatten(start_dim=1)[:, 0] * 4000 assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) if self.num_classes is not None: assert y.shape == (x.shape[0],) emb = emb + self.label_emb(y) h = x_perm.type(self.inner_dtype) for module in self.input_blocks: h = module(h, emb) hs.append(h) h = self.middle_block(h, emb) for module in self.output_blocks: cat_in = th.cat([h, hs.pop()], dim=1) h = module(cat_in, emb) h = h.type(x.dtype) out = sandwich(self.out(h).permute(0, 2, 3, 1).contiguous()) if self.skip: out = th.cat([sandwich(x), out], -1) out = self.output_adapter(out) return out def get_feature_vectors(self, x, timesteps, y=None): """ Apply the model and return all of the intermediate tensors. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :param y: an [N] Tensor of labels, if class-conditional. :return: a dict with the following keys: - 'down': a list of hidden state tensors from downsampling. - 'middle': the tensor of the output of the lowest-resolution block in the model. - 'up': a list of hidden state tensors from upsampling. 
""" hs = [] emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) if self.num_classes is not None: assert y.shape == (x.shape[0],) emb = emb + self.label_emb(y) result = dict(down=[], up=[]) h = x.type(self.inner_dtype) for module in self.input_blocks: h = module(h, emb) hs.append(h) result["down"].append(h.type(x.dtype)) h = self.middle_block(h, emb) result["middle"] = h.type(x.dtype) for module in self.output_blocks: cat_in = th.cat([h, hs.pop()], dim=1) h = module(cat_in, emb) result["up"].append(h.type(x.dtype)) return result ================================================ FILE: networks/unet_vdm.py ================================================ # Source: https://github.com/addtt/variational-diffusion-models # # MIT License # # Copyright (c) 2022 Andrea Dittadi # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# # Modifications: # - Added data_adapters to UNetVDM to preprocess the inputs and postprocess the outputs # - Replaced `timesteps` argument of `UNetModel.forward()` with time `t`, which is used to compute the `timesteps` # - Added 1/1000 to t before computing timesteps embeddings so t isn't 0 # - Added concatenation of input and output of the network before the final projection import numpy as np import torch from torch import einsum, nn, pi, softmax from utils_model import sandwich @torch.no_grad() def zero_init(module: nn.Module) -> nn.Module: """Sets to zero all the parameters of a module, and returns the module.""" for p in module.parameters(): nn.init.zeros_(p.data) return module class UNetVDM(nn.Module): def __init__( self, data_adapters, embedding_dim: int = 128, n_blocks: int = 32, n_attention_heads: int = 1, dropout_prob: float = 0.1, norm_groups: int = 32, input_channels: int = 3, use_fourier_features: bool = True, attention_everywhere: bool = False, image_size: int = 32, ): super().__init__() self.input_adapter = data_adapters["input_adapter"] self.output_adapter = data_adapters["output_adapter"] attention_params = dict( n_heads=n_attention_heads, n_channels=embedding_dim, norm_groups=norm_groups, ) resnet_params = dict( ch_in=embedding_dim, ch_out=embedding_dim, condition_dim=4 * embedding_dim, dropout_prob=dropout_prob, norm_groups=norm_groups, ) if use_fourier_features: self.fourier_features = FourierFeatures() self.embed_conditioning = nn.Sequential( nn.Linear(embedding_dim, embedding_dim * 4), nn.SiLU(), nn.Linear(embedding_dim * 4, embedding_dim * 4), nn.SiLU(), ) total_input_ch = input_channels if use_fourier_features: total_input_ch *= 1 + self.fourier_features.num_features self.conv_in = nn.Conv2d(total_input_ch, embedding_dim, 3, padding=1) # Down path: n_blocks blocks with a resnet block and maybe attention. self.down_blocks = nn.ModuleList( UpDownBlock( resnet_block=ResnetBlock(**resnet_params), attention_block=AttentionBlock(**attention_params) if attention_everywhere else None, ) for _ in range(n_blocks) ) self.mid_resnet_block_1 = ResnetBlock(**resnet_params) self.mid_attn_block = AttentionBlock(**attention_params) self.mid_resnet_block_2 = ResnetBlock(**resnet_params) # Up path: n_blocks+1 blocks with a resnet block and maybe attention. resnet_params["ch_in"] *= 2 # double input channels due to skip connections self.up_blocks = nn.ModuleList( UpDownBlock( resnet_block=ResnetBlock(**resnet_params), attention_block=AttentionBlock(**attention_params) if attention_everywhere else None, ) for _ in range(n_blocks + 1) ) self.conv_out = nn.Sequential( nn.GroupNorm(num_groups=norm_groups, num_channels=embedding_dim), nn.SiLU(), zero_init(nn.Conv2d(embedding_dim, embedding_dim, 3, padding=1)), ) self.embedding_dim = embedding_dim self.input_channels = input_channels self.image_size = image_size self.use_fourier_features = use_fourier_features def forward( self, data: torch.Tensor, t: torch.Tensor, ) -> torch.Tensor: flat_x = self.input_adapter(data, t) x = flat_x.reshape(flat_x.size(0), self.image_size, self.image_size, self.input_channels) x_perm = x.permute(0, 3, 1, 2).contiguous() t = t.float().flatten(start_dim=1)[:, 0] t_embedding = get_timestep_embedding(t + 0.001, self.embedding_dim) # We will condition on time embedding. 
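# t_embedding has shape (B, embedding_dim); embed_conditioning lifts it to (B, 4 * embedding_dim), the condition_dim expected by each ResnetBlock.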
cond = self.embed_conditioning(t_embedding) h = self.maybe_concat_fourier(x_perm) h = self.conv_in(h) # (B, embedding_dim, H, W) hs = [] for down_block in self.down_blocks: # n_blocks times hs.append(h) h = down_block(h, cond) hs.append(h) h = self.mid_resnet_block_1(h, cond) h = self.mid_attn_block(h) h = self.mid_resnet_block_2(h, cond) for up_block in self.up_blocks: # n_blocks+1 times h = torch.cat([h, hs.pop()], dim=1) h = up_block(h, cond) out = sandwich(self.conv_out(h).permute(0, 2, 3, 1).contiguous()) out = torch.cat([sandwich(x), out], -1) out = self.output_adapter(out) return out def maybe_concat_fourier(self, z): if self.use_fourier_features: return torch.cat([z, self.fourier_features(z)], dim=1) return z class ResnetBlock(nn.Module): def __init__( self, ch_in, ch_out=None, condition_dim=None, dropout_prob=0.0, norm_groups=32, ): super().__init__() ch_out = ch_in if ch_out is None else ch_out self.ch_out = ch_out self.condition_dim = condition_dim self.net1 = nn.Sequential( nn.GroupNorm(num_groups=norm_groups, num_channels=ch_in), nn.SiLU(), nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1), ) if condition_dim is not None: self.cond_proj = zero_init(nn.Linear(condition_dim, ch_out, bias=False)) self.net2 = nn.Sequential( nn.GroupNorm(num_groups=norm_groups, num_channels=ch_out), nn.SiLU(), nn.Dropout(dropout_prob), zero_init(nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1)), ) if ch_in != ch_out: self.skip_conv = nn.Conv2d(ch_in, ch_out, kernel_size=1) def forward(self, x, condition): h = self.net1(x) if condition is not None: assert condition.shape == (x.shape[0], self.condition_dim) condition = self.cond_proj(condition) condition = condition[:, :, None, None] h = h + condition h = self.net2(h) if x.shape[1] != self.ch_out: x = self.skip_conv(x) assert x.shape == h.shape return x + h def get_timestep_embedding( timesteps, embedding_dim: int, dtype=torch.float32, max_timescale=10_000, min_timescale=1, ): # Adapted from tensor2tensor and VDM codebase. assert timesteps.ndim == 1 assert embedding_dim % 2 == 0 timesteps *= 1000.0 # In DDPM the time step is in [0, 1000], here [0, 1] num_timescales = embedding_dim // 2 inv_timescales = torch.logspace( # or exp(-linspace(log(min), log(max), n)) -np.log10(min_timescale), -np.log10(max_timescale), num_timescales, device=timesteps.device, ) emb = timesteps.to(dtype)[:, None] * inv_timescales[None, :] # (T, D/2) return torch.cat([emb.sin(), emb.cos()], dim=1) # (T, D) class FourierFeatures(nn.Module): def __init__(self, first=5.0, last=6.0, step=1.0): super().__init__() self.freqs_exponent = torch.arange(first, last + 1e-8, step) @property def num_features(self): return len(self.freqs_exponent) * 2 def forward(self, x): assert len(x.shape) >= 2 # Compute (2pi * 2^n) for n in freqs. freqs_exponent = self.freqs_exponent.to(dtype=x.dtype, device=x.device) # (F, ) freqs = 2.0**freqs_exponent * 2 * pi # (F, ) freqs = freqs.view(-1, *([1] * (x.dim() - 1))) # (F, 1, 1, ...) # Compute (2pi * 2^n * x) for n in freqs. features = freqs * x.unsqueeze(1) # (B, F, X1, X2, ...) features = features.flatten(1, 2) # (B, F * C, X1, X2, ...) # Output features are cos and sin of above. Shape (B, 2 * F * C, H, W). return torch.cat([features.sin(), features.cos()], dim=1) def attention_inner_heads(qkv, num_heads): """Computes attention with heads inside of qkv in the channel dimension. Args: qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where: H = number of heads, C = number of channels per head. num_heads: number of heads. 
Returns: Attention output of shape (B, H*C, T). """ bs, width, length = qkv.shape ch = width // (3 * num_heads) # Split into (q, k, v) of shape (B, H*C, T). q, k, v = qkv.chunk(3, dim=1) # Rescale q and k. This makes them contiguous in memory. scale = ch ** (-1 / 4) # scale with 4th root = scaling output by sqrt q = q * scale k = k * scale # Reshape qkv to (B*H, C, T). new_shape = (bs * num_heads, ch, length) q = q.view(*new_shape) k = k.view(*new_shape) v = v.reshape(*new_shape) # Compute attention. weight = einsum("bct,bcs->bts", q, k) # (B*H, T, T) weight = softmax(weight.float(), dim=-1).to(weight.dtype) # (B*H, T, T) out = einsum("bts,bcs->bct", weight, v) # (B*H, C, T) return out.reshape(bs, num_heads * ch, length) # (B, H*C, T) class Attention(nn.Module): """Based on https://github.com/openai/guided-diffusion.""" def __init__(self, n_heads): super().__init__() self.n_heads = n_heads def forward(self, qkv): assert qkv.dim() >= 3, qkv.dim() assert qkv.shape[1] % (3 * self.n_heads) == 0 spatial_dims = qkv.shape[2:] qkv = qkv.view(*qkv.shape[:2], -1) # (B, 3*H*C, T) out = attention_inner_heads(qkv, self.n_heads) # (B, H*C, T) return out.view(*out.shape[:2], *spatial_dims).contiguous() class AttentionBlock(nn.Module): """Self-attention residual block.""" def __init__(self, n_heads, n_channels, norm_groups): super().__init__() assert n_channels % n_heads == 0 self.layers = nn.Sequential( nn.GroupNorm(num_groups=norm_groups, num_channels=n_channels), nn.Conv2d(n_channels, 3 * n_channels, kernel_size=1), # (B, 3 * C, H, W) Attention(n_heads), zero_init(nn.Conv2d(n_channels, n_channels, kernel_size=1)), ) def forward(self, x): return self.layers(x) + x class UpDownBlock(nn.Module): def __init__(self, resnet_block, attention_block=None): super().__init__() self.resnet_block = resnet_block self.attention_block = attention_block def forward(self, x, cond): x = self.resnet_block(x, cond) if self.attention_block is not None: x = self.attention_block(x) return x ================================================ FILE: probability.py ================================================ # Copyright 2023 NNAISENSE SA # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
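# Usage sketch (illustrative only; the parameter tensor and its shape below are
# assumptions, not taken from this repository):
#
#   import torch
#   params = torch.randn(2, 64, 2)  # per-variable (mean, log_std_dev)
#   dist = DiscretizedNormal(params, num_bins=16, clip=True)
#   x = dist.sample()               # samples quantized to the 16-bin grid in [-1, 1]
#   log_p = dist.log_prob(x)        # per-variable log probabilities, shape (2, 64)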
import torch import functools from abc import abstractmethod from torch.distributions.normal import Normal from torch.distributions.categorical import Categorical as torch_Categorical from torch.distributions.bernoulli import Bernoulli as torch_Bernoulli from torch.distributions.mixture_same_family import MixtureSameFamily from torch.distributions.uniform import Uniform from math import log from utils_model import ( safe_exp, safe_log, idx_to_float, float_to_idx, quantize, sandwich, ) class CtsDistribution: @abstractmethod def log_prob(self, x): pass @abstractmethod def sample(self): pass class DiscreteDistribution: @property @abstractmethod def probs(self): pass @functools.cached_property def log_probs(self): return safe_log(self.probs) @functools.cached_property def mean(self): pass @functools.cached_property def mode(self): pass @abstractmethod def log_prob(self, x): pass @abstractmethod def sample(self): pass class DiscretizedDistribution(DiscreteDistribution): def __init__(self, num_bins, device): self.num_bins = num_bins self.bin_width = 2.0 / num_bins self.half_bin_width = self.bin_width / 2.0 self.device = device @functools.cached_property def class_centres(self): return torch.arange(self.half_bin_width - 1, 1, self.bin_width, device=self.device) @functools.cached_property def class_boundaries(self): return torch.arange(self.bin_width - 1, 1 - self.half_bin_width, self.bin_width, device=self.device) @functools.cached_property def mean(self): return (self.probs * self.class_centres).sum(-1) @functools.cached_property def mode(self): mode_idx = self.probs.argmax(-1).flatten() return self.class_centres[mode_idx].reshape(self.probs.shape[:-1]) class DiscretizedCtsDistribution(DiscretizedDistribution): def __init__(self, cts_dist, num_bins, device, batch_dims, clip=True, min_prob=1e-5): super().__init__(num_bins, device) self.cts_dist = cts_dist self.log_bin_width = log(self.bin_width) self.batch_dims = batch_dims self.clip = clip self.min_prob = min_prob @functools.cached_property def probs(self): bdry_cdfs = self.cts_dist.cdf(self.class_boundaries.reshape([-1] + ([1] * self.batch_dims))) bdry_slice = bdry_cdfs[:1] if self.clip: cdf_min = torch.zeros_like(bdry_slice) cdf_max = torch.ones_like(bdry_slice) bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0) return (bdry_cdfs[1:] - bdry_cdfs[:-1]).moveaxis(0, -1) else: cdf_min = self.cts_dist.cdf(torch.zeros_like(bdry_slice) - 1) cdf_max = self.cts_dist.cdf(torch.ones_like(bdry_slice)) bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0) cdf_range = cdf_max - cdf_min cdf_mask = cdf_range < self.min_prob cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range) probs = (bdry_cdfs[1:] - bdry_cdfs[:-1]) / cdf_range probs = torch.where(cdf_mask, (probs * 0) + (1 / self.num_bins), probs) return probs.moveaxis(0, -1) def prob(self, x): class_idx = float_to_idx(x, self.num_bins) centre = idx_to_float(class_idx, self.num_bins) cdf_lo = self.cts_dist.cdf(centre - self.half_bin_width) cdf_hi = self.cts_dist.cdf(centre + self.half_bin_width) if self.clip: cdf_lo = torch.where(class_idx <= 0, torch.zeros_like(centre), cdf_lo) cdf_hi = torch.where(class_idx >= (self.num_bins - 1), torch.ones_like(centre), cdf_hi) return cdf_hi - cdf_lo else: cdf_min = self.cts_dist.cdf(torch.zeros_like(centre) - 1) cdf_max = self.cts_dist.cdf(torch.ones_like(centre)) cdf_range = cdf_max - cdf_min cdf_mask = cdf_range < self.min_prob cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range) prob = (cdf_hi - cdf_lo) / cdf_range return 
torch.where(cdf_mask, (prob * 0) + (1 / self.num_bins), prob) def log_prob(self, x): prob = self.prob(x) return torch.where( prob < self.min_prob, self.cts_dist.log_prob(quantize(x, self.num_bins)) + self.log_bin_width, safe_log(prob), ) def sample(self, sample_shape=torch.Size([])): if self.clip: return quantize(self.cts_dist.sample(sample_shape), self.num_bins) else: assert hasattr(self.cts_dist, "icdf") cdf_min = self.cts_dist.cdf(torch.zeros_like(self.cts_dist.mean) - 1) cdf_max = self.cts_dist.cdf(torch.ones_like(cdf_min)) u = Uniform(cdf_min, cdf_max, validate_args=False).sample(sample_shape) cts_samp = self.cts_dist.icdf(u) return quantize(cts_samp, self.num_bins) class GMM(MixtureSameFamily): def __init__(self, mix_wt_logits, means, std_devs): mix_wts = torch_Categorical(logits=mix_wt_logits, validate_args=False) components = Normal(means, std_devs, validate_args=False) super().__init__(mix_wts, components, validate_args=False) class DiscretizedGMM(DiscretizedCtsDistribution): def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True): assert params.size(-1) % 3 == 0 if min_std_dev < 0: min_std_dev = 1.0 / (num_bins * 5) mix_wt_logits, means, std_devs = params.chunk(3, -1) if log_dev: std_devs = safe_exp(std_devs) std_devs = std_devs.clamp(min=min_std_dev, max=max_std_dev) super().__init__( cts_dist=GMM(mix_wt_logits, means, std_devs), num_bins=num_bins, device=params.device, batch_dims=params.ndim - 1, clip=clip, min_prob=min_prob, ) class DiscretizedNormal(DiscretizedCtsDistribution): def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True): assert params.size(-1) == 2 if min_std_dev < 0: min_std_dev = 1.0 / (num_bins * 5) mean, std_dev = params.split(1, -1)[:2] if log_dev: std_dev = safe_exp(std_dev) std_dev = std_dev.clamp(min=min_std_dev, max=max_std_dev) super().__init__( cts_dist=Normal(mean.squeeze(-1), std_dev.squeeze(-1), validate_args=False), num_bins=num_bins, device=params.device, batch_dims=params.ndim - 1, clip=clip, min_prob=min_prob, ) class Bernoulli(DiscreteDistribution): def __init__(self, logits): self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False) @functools.cached_property def probs(self): p = self.bernoulli.probs.unsqueeze(-1) return torch.cat([1 - p, p], -1) @functools.cached_property def mode(self): return self.bernoulli.mode def log_prob(self, x): return self.bernoulli.log_prob(x.float()) def sample(self, sample_shape=torch.Size([])): return self.bernoulli.sample(sample_shape) class DiscretizedBernoulli(DiscretizedDistribution): def __init__(self, logits): super().__init__(2, logits.device) self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False) @functools.cached_property def probs(self): p = self.bernoulli.probs.unsqueeze(-1) return torch.cat([1 - p, p], -1) @functools.cached_property def mode(self): return idx_to_float(self.bernoulli.mode, 2) def log_prob(self, x): return self.bernoulli.log_prob(float_to_idx(x, 2).float()) def sample(self, sample_shape=torch.Size([])): return idx_to_float(self.bernoulli.sample(sample_shape), 2) class DeltaDistribution(CtsDistribution): def __init__(self, mean, clip_range=1.0): if clip_range > 0: mean = mean.clip(min=-clip_range, max=clip_range) self.mean = mean @functools.cached_property def mode(self): return self.mean def sample(self, sample_shape=torch.Size([])): return self.mean class Categorical(DiscreteDistribution):
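"""Categorical distribution over the last dimension of logits, wrapping torch.distributions.Categorical."""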
def __init__(self, logits): self.categorical = torch_Categorical(logits=logits, validate_args=False) self.n_classes = logits.size(-1) @functools.cached_property def probs(self): return self.categorical.probs @functools.cached_property def mode(self): return self.categorical.mode def log_prob(self, x): return self.categorical.log_prob(x) def sample(self, sample_shape=torch.Size([])): return self.categorical.sample(sample_shape) class DiscretizedCategorical(DiscretizedDistribution): def __init__(self, logits=None, probs=None): assert (logits is not None) or (probs is not None) if logits is not None: super().__init__(logits.size(-1), logits.device) self.categorical = torch_Categorical(logits=logits, validate_args=False) else: super().__init__(probs.size(-1), probs.device) self.categorical = torch_Categorical(probs=probs, validate_args=False) @functools.cached_property def probs(self): return self.categorical.probs @functools.cached_property def mode(self): return idx_to_float(self.categorical.mode, self.num_bins) def log_prob(self, x): return self.categorical.log_prob(float_to_idx(x, self.num_bins)) def sample(self, sample_shape=torch.Size([])): return idx_to_float(self.categorical.sample(sample_shape), self.num_bins) class CtsDistributionFactory: @abstractmethod def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> CtsDistribution: """Note: input_params and t are not used but kept here for consistency with DiscreteDistributionFactory.""" pass class GMMFactory(CtsDistributionFactory): def __init__(self, min_std_dev=1e-3, max_std_dev=10, log_dev=True): self.min_std_dev = min_std_dev self.max_std_dev = max_std_dev self.log_dev = log_dev def get_dist(self, params, input_params=None, t=None): mix_wt_logits, means, std_devs = params.chunk(3, -1) if self.log_dev: std_devs = safe_exp(std_devs) std_devs = std_devs.clamp(min=self.min_std_dev, max=self.max_std_dev) return GMM(mix_wt_logits, means, std_devs) class NormalFactory(CtsDistributionFactory): def __init__(self, min_std_dev=1e-3, max_std_dev=10): self.min_std_dev = min_std_dev self.max_std_dev = max_std_dev def get_dist(self, params, input_params=None, t=None): mean, log_std_dev = params.split(1, -1)[:2] std_dev = safe_exp(log_std_dev).clamp(min=self.min_std_dev, max=self.max_std_dev) return Normal(mean.squeeze(-1), std_dev.squeeze(-1), validate_args=False) class DeltaFactory(CtsDistributionFactory): def __init__(self, clip_range=1.0): self.clip_range = clip_range def get_dist(self, params, input_params=None, t=None): return DeltaDistribution(params.squeeze(-1), self.clip_range) class DiscreteDistributionFactory: @abstractmethod def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> DiscreteDistribution: """Note: input_params and t are only required by PredDistToDataDistFactory.""" pass class BernoulliFactory(DiscreteDistributionFactory): def get_dist(self, params, input_params=None, t=None): return Bernoulli(logits=params.squeeze(-1)) class CategoricalFactory(DiscreteDistributionFactory): def get_dist(self, params, input_params=None, t=None): return Categorical(logits=params) class DiscretizedBernoulliFactory(DiscreteDistributionFactory): def get_dist(self, params, input_params=None, t=None): return DiscretizedBernoulli(logits=params.squeeze(-1)) class DiscretizedCategoricalFactory(DiscreteDistributionFactory): def get_dist(self, params, input_params=None, t=None): return DiscretizedCategorical(logits=params) class DiscretizedGMMFactory(DiscreteDistributionFactory): def __init__(self, num_bins, clip=True,
min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True): self.num_bins = num_bins self.clip = clip self.min_std_dev = min_std_dev self.max_std_dev = max_std_dev self.min_prob = min_prob self.log_dev = log_dev def get_dist(self, params, input_params=None, t=None): return DiscretizedGMM( params, num_bins=self.num_bins, clip=self.clip, min_std_dev=self.min_std_dev, max_std_dev=self.max_std_dev, min_prob=self.min_prob, log_dev=self.log_dev, ) class DiscretizedNormalFactory(DiscreteDistributionFactory): def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True): self.num_bins = num_bins self.clip = clip self.min_std_dev = min_std_dev self.max_std_dev = max_std_dev self.min_prob = min_prob self.log_dev = log_dev def get_dist(self, params, input_params=None, t=None): return DiscretizedNormal( params, num_bins=self.num_bins, clip=self.clip, min_std_dev=self.min_std_dev, max_std_dev=self.max_std_dev, min_prob=self.min_prob, log_dev=self.log_dev, ) def noise_pred_params_to_data_pred_params(noise_pred_params: torch.Tensor, input_mean: torch.Tensor, t: torch.Tensor, min_variance: float, min_t=1e-6): """Convert output parameters that predict the noise added to data, to parameters that predict the data.""" data_shape = list(noise_pred_params.shape)[:-1] noise_pred_params = sandwich(noise_pred_params) input_mean = input_mean.flatten(start_dim=1) if torch.is_tensor(t): t = t.flatten(start_dim=1) else: t = (input_mean * 0) + t alpha_mask = (t < min_t).unsqueeze(-1) posterior_var = torch.pow(min_variance, t.clamp(min=min_t)) gamma = 1 - posterior_var A = (input_mean / gamma).unsqueeze(-1) B = (posterior_var / gamma).sqrt().unsqueeze(-1) data_pred_params = [] if noise_pred_params.size(-1) == 1: noise_pred_mean = noise_pred_params elif noise_pred_params.size(-1) == 2: noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(2, -1) else: assert noise_pred_params.size(-1) % 3 == 0 mix_wt_logits, noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(3, -1) data_pred_params.append(mix_wt_logits) data_pred_mean = A - (B * noise_pred_mean) data_pred_mean = torch.where(alpha_mask, 0 * data_pred_mean, data_pred_mean) data_pred_params.append(data_pred_mean) if noise_pred_params.size(-1) >= 2: noise_pred_dev = safe_exp(noise_pred_log_dev) data_pred_dev = B * noise_pred_dev data_pred_dev = torch.where(alpha_mask, 1 + (0 * data_pred_dev), data_pred_dev) data_pred_params.append(data_pred_dev) data_pred_params = torch.cat(data_pred_params, -1) data_pred_params = data_pred_params.reshape(data_shape + [-1]) return data_pred_params class PredDistToDataDistFactory(DiscreteDistributionFactory): def __init__(self, data_dist_factory, min_variance, min_t=1e-6): self.data_dist_factory = data_dist_factory self.data_dist_factory.log_dev = False self.min_variance = min_variance self.min_t = min_t def get_dist(self, params, input_params, t): data_pred_params = noise_pred_params_to_data_pred_params(params, input_params[0], t, self.min_variance, self.min_t) return self.data_dist_factory.get_dist(data_pred_params) ================================================ FILE: sample.py ================================================ # Copyright 2023 NNAISENSE SA # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from omegaconf import OmegaConf, DictConfig from utils_train import seed_everything, make_config, make_bfn torch.set_float32_matmul_precision("high") torch.backends.cudnn.benchmark = True def main(cfg: DictConfig) -> torch.Tensor: """ Config entries: seed (int): Optional config_file (str): Name of config file containing model and data config for a saved checkpoint load_model (str): Path to a saved checkpoint to be tested samples_shape (list): Shape of sample batch, e.g.: (3, 256) for sampling 3 sequences of length 256 from the text8 model. (2, 32, 32, 3) for sampling 2 images from the CIFAR10 model. (4, 28, 28, 1) for sampling 4 images from the MNIST model. n_steps (int): Number of sampling steps (positive integer). save_file (str): File path to save the generated sample tensor. Skip saving if None. """ seed_everything(cfg.seed) print(f"Seeded everything with seed {cfg.seed}") # Get model config from the training config file train_cfg = make_config(cfg.config_file) bfn = make_bfn(train_cfg.model) bfn.load_state_dict(torch.load(cfg.load_model, weights_only=True, map_location="cpu")) if torch.cuda.is_available(): bfn.to("cuda") samples = bfn.sample(cfg.samples_shape, cfg.n_steps) if cfg.save_file is not None: torch.save(samples.to("cpu"), cfg.save_file) return samples if __name__ == "__main__": main(OmegaConf.from_cli()) ================================================ FILE: test.py ================================================ # Copyright 2023 NNAISENSE SA # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
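# Example invocation (the file paths are placeholders; see the main() docstring
# below for the recognised config entries):
#
#   python test.py seed=1 config_file=./configs/mnist_discrete.yaml \
#       load_model=./checkpoints/mnist/best/model.pt n_steps=10 n_repeats=2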
import math from typing import Tuple import torch from omegaconf import OmegaConf, DictConfig from rich import print from torch import nn from torch.utils.data import DataLoader from data import make_datasets from model import BFN from utils_train import seed_everything, make_config, make_bfn, worker_init_function, make_progress_bar torch.set_float32_matmul_precision("high") torch.backends.cudnn.benchmark = True def setup(cfg: DictConfig) -> Tuple[nn.Module, DataLoader]: test_ds = make_datasets(cfg.data)[-1] test_dl = DataLoader( dataset=test_ds, worker_init_fn=worker_init_function, batch_size=100, shuffle=False, num_workers=8, pin_memory=True, ) model = make_bfn(cfg.model) return model, test_dl @torch.inference_mode() def test(model: BFN, dataloader: DataLoader, n_steps: int, n_repeats: int) -> tuple[float, float, float, float]: if torch.cuda.is_available(): model.to("cuda") model.eval() losses, recon_losses = [], [] pbar = make_progress_bar(True, "[red]loss: {task.fields[loss]:.4f} repeat: {task.fields[r]}") with pbar: task_id = pbar.add_task("Test", visible=True, total=n_repeats * len(dataloader), loss=math.nan, r=0) for r in range(n_repeats): _losses, _recon_losses = [], [] for eval_batch in dataloader: eval_batch = eval_batch.to("cuda") if torch.cuda.is_available() else eval_batch loss = model(eval_batch, n_steps=n_steps).item() recon_loss = model.compute_reconstruction_loss(eval_batch).item() _losses.append(loss) _recon_losses.append(recon_loss) pbar.update(task_id, advance=1, loss=torch.tensor(_losses).mean() + torch.tensor(_recon_losses).mean(), r=r+1) losses.append(torch.tensor(_losses).mean()) recon_losses.append(torch.tensor(_recon_losses).mean()) losses = torch.stack(losses) loss_mean, loss_err = losses.mean(), losses.std(correction=0).item() / math.sqrt(len(losses)) recon_losses = torch.stack(recon_losses) recon_mean, recon_err = recon_losses.mean(), recon_losses.std(correction=0).item() / math.sqrt(len(recon_losses)) return loss_mean, loss_err, recon_mean, recon_err def main(cfg: DictConfig) -> tuple[float, float, float, float]: """ Config entries: seed (int): Optional config_file (str): Name of config file containing model and data config for a saved checkpoint load_model (str): Path to a saved checkpoint to be tested n_steps (int): Number of Bayesian flow steps. Set to None for continuous time Bayesian flow loss. n_repeats (int): Number of times to iterate through the dataset. """ seed_everything(cfg.seed) print(f"Seeded everything with seed {cfg.seed}") # Get model and data config from the training config file train_cfg = make_config(cfg.config_file) model, dataloader = setup(train_cfg) model.load_state_dict(torch.load(cfg.load_model, weights_only=True, map_location="cpu")) loss_mean, loss_err, recon_mean, recon_err = test(model, dataloader, cfg.n_steps, cfg.n_repeats) print(f"For {cfg.n_steps} steps with {cfg.n_repeats} repeats:") print(f"Loss is {loss_mean:.6f} +- {loss_err:.6f}") print(f"Reconstruction Loss is {recon_mean:.6f} +- {recon_err:.6f}") print(f"Total loss mean = {loss_mean + recon_mean}") return loss_mean, loss_err, recon_mean, recon_err if __name__ == "__main__": main(OmegaConf.from_cli()) ================================================ FILE: train.py ================================================ # Copyright 2023 NNAISENSE SA # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
================================================
FILE: train.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Optional, Tuple

import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from omegaconf import OmegaConf
from rich.logging import RichHandler
from rich.progress import Progress
from torch import nn, optim
from torch.utils.data import DataLoader

from model import BFN
from utils_train import (
    seed_everything,
    log_cfg,
    checkpoint_training_state,
    init_checkpointing,
    log,
    update_ema,
    ddict,
    make_infinite,
    make_progress_bar,
    make_config,
    make_dataloaders,
    make_bfn,
)

torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True

logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(rich_tracebacks=True, show_time=False)],
)
logger = get_logger(__name__)


def setup(cfg) -> Tuple[nn.Module, dict, optim.Optimizer]:
    """Create the model, dataloaders and optimizer"""
    dataloaders = make_dataloaders(cfg)
    model = make_bfn(cfg.model)
    if "weight_decay" in cfg.optimizer.keys() and hasattr(model.net, "get_optim_groups"):
        params = model.net.get_optim_groups(cfg.optimizer.weight_decay)
    else:
        params = model.net.parameters()
    # Instantiate the optimizer using the hyperparameters in the config
    optimizer = optim.AdamW(params=params, **cfg.optimizer)
    return model, dataloaders, optimizer


@torch.no_grad()
def validate(
    cfg,
    model: BFN,
    ema_model: nn.Module,
    val_dataloader: DataLoader,
    step: int,
    run: "neptune.Run",
    pbar: Optional[Progress],
    best_val_loss: float,
    checkpoint_root_dir: Optional[Path],
    accelerator: Accelerator,
) -> float:
    """Evaluate the model on validation data and save a checkpoint if the loss improves"""
    dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[accelerator.mixed_precision]
    model_to_eval = ema_model if ema_model is not None else model
    model_to_eval.eval()
    pbar = pbar or Progress()
    max_steps = cfg.max_val_batches if cfg.max_val_batches > 0 else len(val_dataloader)
    val_id = pbar.add_task("Validating", visible=True, total=cfg.val_repeats * max_steps, transient=True, loss=math.nan)
    loss, count = 0.0, 0
    for _ in range(cfg.val_repeats):
        for idx, eval_batch in enumerate(val_dataloader):
            enabled = dtype in (torch.float16, torch.bfloat16)
            with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):
                loss += model_to_eval(eval_batch.to(accelerator.device)).item()
                count += 1
            pbar.update(val_id, advance=1, loss=loss / count)
            if (idx + 1) >= max_steps:
                break
    loss /= count
    pbar.remove_task(val_id)
    log(run["metrics"]["val"]["loss"], loss, step)
    if checkpoint_root_dir is not None and (loss < best_val_loss or math.isinf(best_val_loss)):
        logger.info(f"loss improved: new value is {loss}")
        step_checkpoint_path = checkpoint_root_dir / "best"
        run_id = "BFN" if isinstance(run, defaultdict) else run["sys"]["id"].fetch()
        checkpoint_training_state(step_checkpoint_path, accelerator, ema_model, step, run_id)
        run["metrics/best/loss/metric"] = loss
        run["metrics/best/loss/step"] = step
    model.train()
    return loss
def train(
    cfg,
    accelerator: Accelerator,
    model: BFN,
    ema_model: Optional[nn.Module],
    dataloaders: dict,
    optimizer: optim.Optimizer,
    run: "neptune.Run",
):
    is_main = accelerator.is_main_process
    pbar = make_progress_bar(is_main)
    run_id = "BFN" if isinstance(run, defaultdict) else run["sys"]["id"].fetch()
    train_id = pbar.add_task(f"Training {run_id}", start=cfg.start_step, total=cfg.n_training_steps, loss=math.nan)
    checkpoint_root_dir = init_checkpointing(cfg.checkpoint_dir, run_id) if is_main else None
    best_val_loss = math.inf
    train_iter = make_infinite(dataloaders["train"])

    model.train()
    with pbar:
        for step in range(cfg.start_step, cfg.n_training_steps + 1):
            step_loss = 0.0
            for _ in range(cfg.accumulate):
                with accelerator.accumulate(model):
                    train_batch = next(train_iter)
                    loss = model(train_batch)
                    accelerator.backward(loss)
                    if accelerator.sync_gradients and cfg.grad_clip_norm > 0:
                        accelerator.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
                    step_loss += loss.item()
            update_ema(ema_model, model, cfg.ema_decay)

            if is_main and (step % cfg.checkpoint_interval == 0):
                checkpoint_training_state(checkpoint_root_dir / "last", accelerator, ema_model, step, run_id)
                run["checkpoints/last"].track_files(str(checkpoint_root_dir / "last"))

            log(run["metrics"]["train"]["loss"], step_loss / cfg.accumulate, step, is_main and step % cfg.log_interval == 0)
            log(run["metrics"]["epoch"], step // len(dataloaders["train"]), step, is_main)

            if is_main and (step % cfg.val_interval == 0) and "val" in dataloaders:
                val_loss = validate(
                    cfg=cfg,
                    model=model,
                    ema_model=ema_model,
                    val_dataloader=dataloaders["val"],
                    step=step,
                    run=run,
                    pbar=pbar,
                    best_val_loss=best_val_loss,
                    checkpoint_root_dir=checkpoint_root_dir,
                    accelerator=accelerator,
                )
                best_val_loss = min(val_loss, best_val_loss)

            pbar.update(train_id, advance=1, loss=loss.item())


def main(cfg):
    acc = Accelerator(gradient_accumulation_steps=cfg.training.accumulate)
    seed_everything(cfg.training.seed)
    logger.info(f"Seeded everything with seed {cfg.training.seed}", main_process_only=True)

    with acc.main_process_first():
        model, dataloaders, optimizer = setup(cfg)
    ema = copy.deepcopy(model) if acc.is_main_process and cfg.training.ema_decay > 0 else None  # EMA on main proc only
    model, optimizer, dataloaders["train"] = acc.prepare(model, optimizer, dataloaders["train"])

    run = ddict()
    if acc.is_main_process:
        if ema is not None:  # guard: ema is None when ema_decay <= 0
            ema.to(acc.device)
        try:
            if cfg.meta.neptune:
                import neptune

                run = neptune.init_run(project=cfg.meta.neptune, mode="debug" if cfg.meta.debug else None)
                run["accelerate"] = dict(amp=acc.mixed_precision, nproc=acc.num_processes)
                log_cfg(cfg, run)
        except ImportError:
            logger.info("Did not find neptune installed. Logging will be disabled.")

    train(cfg.training, acc, model, ema, dataloaders, optimizer, run)


if __name__ == "__main__":
    cfg_file = OmegaConf.from_cli()["config_file"]
    main(make_config(cfg_file))
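# A usage sketch. A single-process run only needs a config file; any setting
# not given there comes from default_train_config in utils_train.py:
#
#   python train.py config_file=configs/mnist_discrete.yaml
#
# Multi-GPU runs would launch the same script via `accelerate launch` (the
# Accelerator handles device placement and gradient synchronization).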
================================================
FILE: utils_model.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import torch
from torch import Tensor

CONST_log_range = 20
CONST_log_min = 1e-10
CONST_summary_rescale = 10
CONST_exp_range = 10
CONST_min_std_dev = math.exp(-CONST_exp_range)


def sandwich(x: Tensor):
    return x.reshape(x.size(0), -1, x.size(-1))


def safe_log(data: Tensor):
    return data.clamp(min=CONST_log_min).log()


def safe_exp(data: Tensor):
    return data.clamp(min=-CONST_exp_range, max=CONST_exp_range).exp()


def idx_to_float(idx: np.ndarray, num_bins: int):
    flt_zero_one = (idx + 0.5) / num_bins
    return (2.0 * flt_zero_one) - 1.0


def float_to_idx(flt: np.ndarray, num_bins: int):
    flt_zero_one = (flt / 2.0) + 0.5
    return torch.clamp(torch.floor(flt_zero_one * num_bins), min=0, max=num_bins - 1).long()


def quantize(flt, num_bins: int):
    return idx_to_float(float_to_idx(flt, num_bins), num_bins)


def pe_encode(sequence_length: int, embedding_size: int) -> Tensor:
    """Sinusoidal positional encoding in the style of "Attention Is All You Need"
    (note: base 1000 here, rather than the paper's 10000)"""
    pe = torch.zeros((sequence_length, embedding_size))
    pos = torch.arange(sequence_length).unsqueeze(1)
    pe[:, 0::2] = torch.sin(
        pos / torch.pow(1000, torch.arange(0, embedding_size, 2, dtype=torch.float32) / embedding_size)
    )
    pe[:, 1::2] = torch.cos(
        pos / torch.pow(1000, torch.arange(1, embedding_size, 2, dtype=torch.float32) / embedding_size)
    )
    return pe


def pe_encode_float(x: Tensor, max_freq: float, embedding_size: int) -> Tensor:
    """Sinusoidal encoding of float positions: x in [-1, 1] is rescaled to [0, max_freq]"""
    pe = torch.zeros(list(x.shape) + [embedding_size], device=x.device)
    pos = (((x + 1) / 2) * max_freq).unsqueeze(-1)
    pe[..., 0::2] = torch.sin(
        pos / torch.pow(10000, torch.arange(0, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size)
    )
    pe[..., 1::2] = torch.cos(
        pos / torch.pow(10000, torch.arange(1, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size)
    )
    return pe
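# A round-trip sketch of the quantization helpers above, assuming 16 bins:
# bin centers sit at (i + 0.5) / num_bins in [0, 1], rescaled to [-1, 1].
#
#   >>> import torch
#   >>> x = torch.tensor([-0.99, 0.0, 0.99])
#   >>> float_to_idx(x, num_bins=16)
#   tensor([ 0,  8, 15])
#   >>> quantize(x, num_bins=16)  # idx_to_float(float_to_idx(x, 16), 16)
#   tensor([-0.9375,  0.0625,  0.9375])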
================================================
FILE: utils_train.py
================================================
# Copyright 2023 NNAISENSE SA
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
import random
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Optional, Generator, Union

try:
    import neptune
    from neptune.utils import stringify_unsupported
except ImportError:
    neptune = None

    def stringify_unsupported(x):
        return x

import numpy as np
import torch
from accelerate.logging import get_logger
from omegaconf import OmegaConf, DictConfig
from rich.progress import Progress, SpinnerColumn, MofNCompleteColumn, TimeElapsedColumn, TextColumn
from torch.utils.data import DataLoader

import model
import networks
import probability
from data import make_datasets
from networks import adapters

logger = get_logger(__name__)


def seed_everything(seed: Optional[int]):
    assert seed is not None
    # Offset the seed by the process rank so distributed workers draw different streams
    seed += torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def worker_init_function(worker_id: int) -> None:
    """https://pytorch.org/docs/stable/notes/randomness.html#dataloader"""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def init_checkpointing(checkpoint_dir: Union[str, Path, None], run_id: str) -> Optional[Path]:
    if checkpoint_dir is None:
        return None
    checkpoint_dir = Path(checkpoint_dir) / run_id
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    last_dir = checkpoint_dir / "last"
    last_dir.mkdir(parents=True, exist_ok=True)
    best_dir = checkpoint_dir / "best"
    best_dir.mkdir(parents=True, exist_ok=True)
    return checkpoint_dir


def checkpoint_training_state(checkpoint_dir, accelerator, ema_model, step: int, run_id: str):
    if checkpoint_dir is None:
        return
    logger.info(f"Checkpointing training state to {checkpoint_dir} at step {step}")
    accelerator.save_state(checkpoint_dir)
    with open(checkpoint_dir / "info.json", "w") as f:
        json.dump({"step": step, "run_id": run_id}, f)
    if ema_model is not None:
        ema_checkpoint_path = checkpoint_dir / "ema_model.pt"
        torch.save(ema_model.state_dict(), ema_checkpoint_path)
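# Sketch of the on-disk layout these two functions produce (run_id falls back
# to "BFN" when no neptune run exists; files beyond those written here come
# from accelerator.save_state):
#
#   <checkpoint_dir>/
#   └── <run_id>/
#       ├── last/              # written every cfg.checkpoint_interval steps
#       │   ├── info.json      # {"step": ..., "run_id": ...}
#       │   └── ema_model.pt   # present only when EMA is enabled
#       └── best/              # written when the validation loss improves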
def log(key_handler, value, step, cond=True):
    """Log a metric series to neptune only if cond is True.

    Helps with distributed training and conditional logging."""
    if not isinstance(key_handler, defaultdict) and cond and math.isfinite(value):
        key_handler.log(value, step=step)


def log_cfg(cfg, run: "neptune.Run"):
    with tempfile.TemporaryDirectory() as tmpdir:
        cfg_temp_filename: Path = Path(tmpdir) / "cfg.yaml"
        cfg_temp_filename.write_text(OmegaConf.to_yaml(cfg, resolve=True))
        run["cfg"].upload(str(cfg_temp_filename), wait=True)
        run["hyperparameters"] = stringify_unsupported(OmegaConf.to_container(cfg, resolve=True))


@torch.no_grad()
def update_ema(ema_model, model, ema_decay):
    # Equivalent to ema = ema_decay * ema + (1 - ema_decay) * model
    if ema_model is not None and ema_decay > 0:
        for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):
            ema_param.sub_((1 - ema_decay) * (ema_param - model_param))


def ddict():
    """Infinite default dict to fake a neptune run on non-main processes"""
    return defaultdict(ddict)


def make_infinite(dataloader: DataLoader) -> Generator[dict, None, None]:
    while True:
        for data in dataloader:
            yield data


def make_progress_bar(is_main: bool, text="[red]loss: {task.fields[loss]:.3f}"):
    return Progress(
        SpinnerColumn(),
        MofNCompleteColumn(),
        *Progress.get_default_columns(),
        TimeElapsedColumn(),
        TextColumn(text),
        disable=not is_main,
    )


def make_dataloaders(cfg: DictConfig):
    train_set, val_set, _ = make_datasets(cfg.data)
    dataloaders = {
        "train": DataLoader(
            dataset=train_set,
            worker_init_fn=worker_init_function,
            **cfg.train_loader,
        ),
        "val": DataLoader(
            dataset=val_set,
            worker_init_fn=worker_init_function,
            **cfg.val_loader,
        ),
    }
    return dataloaders


def make_from_cfg(module, cfg, **parameters):
    return getattr(module, cfg.class_name)(**cfg.parameters, **parameters) if cfg is not None else None


def make_bfn(cfg: DictConfig):
    data_adapters = {
        "input_adapter": make_from_cfg(adapters, cfg.input_adapter),
        "output_adapter": make_from_cfg(adapters, cfg.output_adapter),
    }
    net = make_from_cfg(networks, cfg.net, data_adapters=data_adapters)
    bayesian_flow = make_from_cfg(model, cfg.bayesian_flow)
    distribution_factory = make_from_cfg(probability, cfg.distribution_factory)
    loss = make_from_cfg(model, cfg.loss, bayesian_flow=bayesian_flow, distribution_factory=distribution_factory)
    bfn = model.BFN(net=net, bayesian_flow=bayesian_flow, loss=loss)
    return bfn


default_train_config = {
    "meta": {
        "neptune": None,
        "debug": False,
        "root_dir": ".",
    },
    "data": {
        "dataset": "",
        "data_dir": "./data",
    },
    "train_loader": {
        "batch_size": 1,
        "shuffle": True,
        "num_workers": 0,
        "pin_memory": True,
        "drop_last": True,
    },
    "val_loader": {
        "batch_size": 1,
        "shuffle": False,
        "num_workers": 0,
        "pin_memory": True,
        "drop_last": False,
    },
    "training": {
        "accumulate": 1,
        "checkpoint_dir": "./checkpoints",
        "checkpoint_interval": None,
        "ema_decay": -1,
        "grad_clip_norm": -1,
        "log_interval": 50,
        "max_val_batches": -1,
        "seed": 666,
        "start_step": 1,
        "val_repeats": 1,
    },
}


def make_config(cfg_file: str):
    file_conf = OmegaConf.load(cfg_file)
    # Start with the default config, then merge the file on top so its values win
    cfg = OmegaConf.create(default_train_config)
    cfg = OmegaConf.merge(cfg, file_conf)
    return cfg
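# Merge semantics sketch: values from the loaded file override the defaults,
# while unspecified keys keep their default values (standard OmegaConf.merge
# behaviour):
#
#   >>> from omegaconf import OmegaConf
#   >>> base = OmegaConf.create({"training": {"seed": 666, "accumulate": 1}})
#   >>> override = OmegaConf.create({"training": {"seed": 1}})
#   >>> merged = OmegaConf.merge(base, override)
#   >>> merged.training.seed, merged.training.accumulate
#   (1, 1)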