Repository: s-sahoo/duo
Branch: main
Commit: 492505208b36
Files: 110
Total size: 270.1 KB
Directory structure:
gitextract_9im7g5qx/
├── .gitignore
├── LICENSE
├── README.md
├── algo.py
├── configs/
│ ├── algo/
│ │ ├── ar.yaml
│ │ ├── d3pm.yaml
│ │ ├── distillation.yaml
│ │ ├── duo.yaml
│ │ ├── duo_base.yaml
│ │ ├── mdlm.yaml
│ │ ├── ot-finetune.yaml
│ │ └── sedd.yaml
│ ├── callbacks/
│ │ ├── checkpoint_every_n_steps.yaml
│ │ ├── checkpoint_monitor.yaml
│ │ ├── grad_record.yaml
│ │ └── learning_rate_monitor.yaml
│ ├── config.yaml
│ ├── data/
│ │ ├── ag_news.yaml
│ │ ├── cifar10.yaml
│ │ ├── fineweb-edu.yaml
│ │ ├── lambada.yaml
│ │ ├── lm1b-gpt2.yaml
│ │ ├── lm1b-streaming.yaml
│ │ ├── lm1b-wrap.yaml
│ │ ├── lm1b.yaml
│ │ ├── openwebtext-split.yaml
│ │ ├── openwebtext-streaming.yaml
│ │ ├── openwebtext.yaml
│ │ ├── ptb.yaml
│ │ ├── scientific_papers_arxiv.yaml
│ │ ├── scientific_papers_pubmed.yaml
│ │ ├── synthetic.yaml
│ │ ├── text8-crop.yaml
│ │ ├── text8.yaml
│ │ ├── wikitext103.yaml
│ │ └── wikitext2.yaml
│ ├── lr_scheduler/
│ │ ├── constant_warmup.yaml
│ │ ├── cosine_decay_warmup.yaml
│ │ └── step_scheduler.yaml
│ ├── model/
│ │ ├── medium.yaml
│ │ ├── small.yaml
│ │ ├── tiny-dimamba.yaml
│ │ ├── tiny.yaml
│ │ └── unet.yaml
│ ├── noise/
│ │ ├── cosine.yaml
│ │ └── log-linear.yaml
│ ├── prior/
│ │ └── none.yaml
│ └── strategy/
│ ├── ddp.yaml
│ └── fsdp.yaml
├── dataloader.py
├── discrete_diffusion_harness.py
├── integral/
│ ├── bert-base-uncased.pkl
│ └── gpt2.pkl
├── main.py
├── metrics.py
├── models/
│ ├── __init__.py
│ ├── dit.py
│ ├── ema.py
│ ├── unet.py
│ └── unit_test_attention.py
├── requirements.txt
├── scripts/
│ ├── distil_owt.sh
│ ├── eval_lm1b_duo.sh
│ ├── eval_owt_ar.sh
│ ├── eval_owt_duo.sh
│ ├── eval_owt_mdlm.sh
│ ├── eval_owt_sedd.sh
│ ├── fid_cifar10_duo_ancestral_cosine.sh
│ ├── fid_cifar10_duo_base_ancestral_cosine.sh
│ ├── fid_cifar10_mdlm_ancestral_cosine.sh
│ ├── gen_ppl_lm1b_ar.sh
│ ├── gen_ppl_lm1b_duo.sh
│ ├── gen_ppl_owt_ar.sh
│ ├── gen_ppl_owt_duo.sh
│ ├── gen_ppl_owt_mdlm.sh
│ ├── gen_ppl_owt_sedd.sh
│ ├── psi_samplers/
│ │ ├── cifar10/
│ │ │ ├── duo_constant_remdm.sh
│ │ │ ├── duo_max_capped_remdm.sh
│ │ │ ├── duo_max_rescale_eta.sh
│ │ │ ├── duo_psi_pc.sh
│ │ │ ├── mdlm_constant_remdm.sh
│ │ │ ├── mdlm_max_capped_remdm.sh
│ │ │ ├── mdlm_max_rescale_eta.sh
│ │ │ └── mdlm_psi_pc.sh
│ │ └── owt/
│ │ ├── duo_loop_remdm.sh
│ │ ├── duo_max_capped_remdm.sh
│ │ ├── duo_max_rescale_eta.sh
│ │ ├── mdlm_loop_remdm.sh
│ │ ├── mdlm_max_capped_remdm.sh
│ │ └── mdlm_max_rescale_eta.sh
│ ├── train_cifar10_duo_base_cosine.sh
│ ├── train_cifar10_duo_cosine.sh
│ ├── train_cifar10_mdlm_cosine.sh
│ ├── train_lm1b_ar.sh
│ ├── train_lm1b_ar_sentencepacking.sh
│ ├── train_lm1b_d3pm.sh
│ ├── train_lm1b_duo.sh
│ ├── train_lm1b_duo_sentencepacking.sh
│ ├── train_lm1b_mdlm.sh
│ ├── train_lm1b_mdlm_sentencepacking.sh
│ ├── train_owt_duo.sh
│ ├── train_owt_duo_finetune.sh
│ ├── train_owt_mdlm.sh
│ ├── train_owt_sedd.sh
│ ├── zero_shot_ar.sh
│ ├── zero_shot_duo.sh
│ ├── zero_shot_mdlm.sh
│ └── zero_shot_sedd.sh
├── trainer_base.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.DS_Store
.hf_cache
test/
outputs/
wandb/
watch_folder/
notes.md
grid_search.sh
*.ipynb
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2025 Subham Sekhar Sahoo
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# The Diffusion Duality Series
## [Chapter I (ICML 2025)](https://arxiv.org/abs/2506.10892)
By [Subham Sekhar Sahoo](https://s-sahoo.github.io), [Justin Deschenaux](https://jdeschena.com), [Aaron Gokaslan](https://skylion007.github.io),
[Guanghan Wang](https://tech.cornell.edu/people/guanghan-wang/), [Justin Chiu](https://justinchiu.netlify.app), [Volodymyr Kuleshov](https://www.cs.cornell.edu/~kuleshov/)
[Code (`ch-1` branch)](https://github.com/s-sahoo/duo/tree/ch-1) |
[Colab demo](https://colab.research.google.com/drive/1Sf7R-dqdR6gq-H8nyZ9E3ZkyvqMTqcwq?usp=sharing) |
[Talk](https://youtu.be/FCO-nnqHOqQ?si=4eGnj5zbRgyCYWwI) |
[Project page](http://s-sahoo.github.io/duo) |
[arXiv](https://arxiv.org/abs/2506.10892) |
[HuggingFace checkpoints](https://huggingface.co/collections/s-sahoo/duo-67f9ff8fde919224e5fbd875)
**Unlocks few-step generation in discrete diffusion-LLMs via the underlying Gaussian diffusion.**
## [Chapter II: Ψ-Samplers and Efficient Curriculum (ICLR 2026)](https://arxiv.org/abs/2602.21185)
By [Justin Deschenaux](https://jdeschena.com), [Caglar Gulcehre](https://www.caglar.ai),
[Subham Sekhar Sahoo](https://s-sahoo.github.io)
[Colab demo](https://colab.research.google.com/drive/1uFSzrfG0KXhGcohRIfWIM2Y7V9Q7cQNA?usp=sharing) |
[Project page](http://s-sahoo.github.io/duo-ch2) |
[arXiv](https://arxiv.org/abs/2602.21185)
**Uniform-state beats Masked diffusion on text and image generation!**
This repository contains the code for the two papers in the Diffusion Duality series. It includes:
- **Duo / $\text{Duo}^\text{++}$** sampling (ancestral, ReMDM, $\Psi$-samplers, greedy-tail) — [Sampling & Eval](#sampling--eval)
- Original and efficient curriculum training strategies — [Training](#training)
- Discrete Consistency Distillation (DCD) — [Distillation](#discrete-consistency-distillation)
- Baselines (AR, MDLM, SEDD, D3PM) — [Baselines](#baselines)
[Getting Started](#getting-started) | [Checkpoints](#checkpoints) | [Citation](#acknowledgements--citation)
# Getting Started
To get started, create a conda environment containing the required dependencies.
```bash
conda create -n duo python=3.12
conda activate duo
conda install nvidia/label/cuda-12.4.0::cuda-toolkit
pip install -r requirements.txt
pip install flash_attn==2.7.4.post1
```
# Checkpoints
* **Duo** (Language Modeling): Trained on OpenWebText for `1M` training steps (distilled / base):
* [Huggingface](https://huggingface.co/collections/s-sahoo/duo-67f9ff8fde919224e5fbd875)🤗.
* [Google Drive folder](https://drive.google.com/drive/folders/1JpqFM8XRvifwIkjWPfMyuDvu41r1yk0t?usp=share_link) — use these if you want to fine-tune, since the HF checkpoints can't be fine-tuned.
* **Duo** (Image Modeling): Trained on CIFAR-10
* [Huggingface (contains the raw checkpoints)](https://huggingface.co/jdeschena/duo2-cifar10)
* **Baselines** (SEDD, MDLM, AR): Trained on OpenWebText
* [Google Drive folder](https://drive.google.com/drive/folders/16LuuptK7Xfk-vzhQYZBZ0SA-B-BFluau?usp=sharing) — download `ar.ckpt`, `mdlm.ckpt`, `sedd.ckpt`.
# Training
This repo implements the original Duo curriculum as well as the fast $\text{Duo}^\text{++}$ curriculum. By default, the training scripts use the original curriculum. To enable the efficient curriculum, replace `algo.curriculum.mode=simple` with `algo.curriculum.mode=poly9` (see the comments in each training script).
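For example, the following sketch launches an OWT Duo run with the efficient curriculum; the flag values here are illustrative, and [`scripts/train_owt_duo.sh`](./scripts/train_owt_duo.sh) remains the authoritative reference:

```bash
# Illustrative override: identical to a standard Duo run except for
# algo.curriculum.mode, which selects the efficient poly9 curriculum.
python main.py \
  algo=duo \
  algo.curriculum.mode=poly9 \
  data=openwebtext-split \
  model=small
```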
To train $\text{Duo}^\text{++}$, use the following scripts:
* LM1B
* w/ sentencepacking (same as in D3PM)
* Training script: [`scripts/train_lm1b_duo_sentencepacking.sh`](./scripts/train_lm1b_duo_sentencepacking.sh)
* [Wandb run](https://api.wandb.ai/links/kuleshov-group/huwt0ek3)
* w/o sentencepacking (same as in MDLM, SEDD)
* Training script: [`scripts/train_lm1b_duo.sh`](./scripts/train_lm1b_duo.sh)
* [Wandb run](https://api.wandb.ai/links/sahoo-diffusion/lkv5z3tm)
* OWT: [`scripts/train_owt_duo.sh`](./scripts/train_owt_duo.sh).
* CIFAR-10:
* Duo: [`scripts/train_cifar10_duo_cosine.sh`](./scripts/train_cifar10_duo_cosine.sh)
* MDLM: [`scripts/train_cifar10_mdlm_cosine.sh`](./scripts/train_cifar10_mdlm_cosine.sh)
* Both scripts default to a cosine noise schedule. To use log-linear instead, set `noise=log-linear`.
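To apply the log-linear override from the last bullet, a minimal sketch (all other flags as in the training script):

```bash
python main.py data=cifar10 algo=duo noise=log-linear
```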
**Notes:**
* Run `mkdir watch_folder` to create a directory for Slurm logs, and then run any script in [`scripts/`](scripts) as a Slurm job: `sbatch scripts/ABC_XYZ.sh`
* Control the per-GPU batch size with the argument `loader.batch_size`. If `loader.batch_size * num_gpus < loader.global_batch_size`, PyTorch Lightning makes up the difference with gradient accumulation; see the worked example below.
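A worked example of the accumulation rule, with illustrative numbers:

```bash
# On a single node with 8 GPUs, each optimizer step sees
# 8 * 32 = 256 sequences on-device, so Lightning accumulates
# 512 / 256 = 2 micro-batches to reach the global batch size.
python main.py \
  loader.global_batch_size=512 \
  loader.batch_size=32 \
  trainer.devices=8
```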
# Discrete Consistency Distillation
To distill a model using Discrete Consistency Distillation (`Alg. 1` in the [Duo paper](https://arxiv.org/abs/2506.10892)), use [`scripts/distil_owt.sh`](scripts/distil_owt.sh).
# Sampling & Eval
## Likelihood
To compute perplexity on the OWT validation set, use [`scripts/eval_owt_duo.sh`](scripts/eval_owt_duo.sh); for zero-shot perplexities, use [`scripts/zero_shot_duo.sh`](scripts/zero_shot_duo.sh).
## Sampling
Use the scripts in [`scripts/gen_ppl_*.sh`](scripts/) for ancestral sampling. To sample with PC samplers such as ReMDM and our $\Psi$-samplers, use the scripts in [`scripts/psi_samplers`](scripts/psi_samplers); that directory contains examples for sampling both text and images.
To use the "Greedy-tail sampler" (equivalent to nucleus sampling in AR models; see `Sec. 4.2` in the paper), set `sampling.noise_removal=greedy`. Using the default `sampling.noise_removal=ancestral` will produce more diverse samples (higher entropy) but with worse generative perplexity.
To sample from a HuggingFace checkpoint (text only), run the following command:
```bash
python main.py \
mode=sample_eval \
loader.batch_size=2 \
loader.eval_batch_size=8 \
data=openwebtext-split \
algo=duo_base \
algo.backbone=hf_dit \
eval.checkpoint_path=s-sahoo/duo-distilled \
sampling.steps=8 \
sampling.num_sample_batches=1 \
sampling.noise_removal=greedy \
+wandb.offline=true
```
To use the example scripts with raw checkpoints (see [Checkpoints](#checkpoints)), download them and set the checkpoint path in the script.
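For instance, pointing the command above at a downloaded raw checkpoint might look like the following sketch (the `.ckpt` path is a placeholder):

```bash
python main.py \
  mode=sample_eval \
  algo=duo \
  eval.checkpoint_path=/path/to/duo.ckpt \
  sampling.steps=8 \
  +wandb.offline=true
```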
# Baselines
Download the baseline checkpoints (see [Checkpoints](#checkpoints)) and specify the paths appropriately in the respective shell scripts:
* [`scripts/eval_owt_*.sh`](scripts/) for computing validation perplexity on OWT.
* [`scripts/gen_ppl_*.sh`](scripts/) for generating text samples and evaluating them.
* [`scripts/zero_shot_*.sh`](scripts/) for computing zero-shot perplexities.
* [`scripts/train_*.sh`](scripts/) for training the models.
# Acknowledgements & Citation
This repository was built on top of [MDLM's GitHub repository](https://github.com/kuleshov-group/mdlm). Cite our papers using:
```
@inproceedings{
sahoo2025the,
title={The Diffusion Duality},
author={Subham Sekhar Sahoo and Justin Deschenaux and Aaron Gokaslan and Guanghan Wang and Justin T Chiu and Volodymyr Kuleshov},
booktitle={Forty-second International Conference on Machine Learning},
year={2025},
url={https://openreview.net/forum?id=9P9Y8FOSOk}
}
@inproceedings{
deschenaux2026the,
title={The Diffusion Duality, Chapter {II}: $\Psi$-Samplers and Efficient Curriculum},
author={Justin Deschenaux and Caglar Gulcehre and Subham Sekhar Sahoo},
booktitle={The Fourteenth International Conference on Learning Representations},
year={2026},
url={https://openreview.net/forum?id=RSIoYWIzaP}
}
```
================================================
FILE: algo.py
================================================
import os
import collections
import copy
import pickle
from typing import Optional
import fsspec
import numpy as np
import torch
import torch.nn.functional as F
import trainer_base
import utils
class AR(trainer_base.TrainerBase):
def __init__(self, config, tokenizer):
vocab_size = tokenizer.vocab_size
if (not hasattr(tokenizer, 'mask_token')
or tokenizer.mask_token is None):
self.mask_index = vocab_size
vocab_size += 1
else:
self.mask_index = tokenizer.mask_token_id
super().__init__(config, tokenizer,
vocab_size=vocab_size)
self.save_hyperparameters()
self._validate_configuration()
def _validate_configuration(self):
super()._validate_configuration()
assert not self.config.algo.time_conditioning
assert self.config.prior.type == 'none'
def _process_model_input(self, x0, valid_tokens):
input_tokens = x0[:, :-1]
output_tokens = x0[:, 1:]
valid_tokens = valid_tokens[:, 1:]
return input_tokens, output_tokens, valid_tokens
def nll(self, input_tokens, output_tokens,
current_accumulation_step):
del current_accumulation_step
output = self.backbone(input_tokens, None)
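    # Forbid the mask token at every position so the log-softmax
    # renormalizes over the real vocabulary.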
output[:, :, self.mask_index] = self.neg_infinity
output = output.log_softmax(-1)
return - output.gather(
-1, output_tokens[:, :, None])[:, :, 0]
def generate_samples(self, num_samples, **kwargs):
# precompute token buffer
num_pred_tokens = self.num_tokens - 1
x = torch.zeros(
(num_samples, num_pred_tokens + 1),
dtype=torch.long,
device=self.device)
x[:, 0] = self.tokenizer.bos_token_id
# precompute noise
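    # Gumbel-max trick: adding Gumbel(0, 1) noise to the log-probs and
    # taking the argmax draws exact samples from the categorical.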
noise = (torch.distributions.Gumbel(0, 1)
.sample((num_samples, num_pred_tokens, self.vocab_size))
.to(self.device))
if self.config.sampling.use_float64:
noise = noise.to(torch.float64)
for i in range(num_pred_tokens):
output = self.backbone(x[:, :i + 1], None)
output[:, :, self.mask_index] = self.neg_infinity
output = output.log_softmax(-1)
y = (output[:, -1, :] + noise[:, i, :]).argmax(-1)
x[:, i + 1] = y
return x
def _process_sigma(self, sigma):
del sigma
return None
class MDLM(trainer_base.AbsorbingState):
def __init__(self, config, tokenizer):
super().__init__(config, tokenizer)
self._validate_configuration()
def _validate_configuration(self):
assert self.sampler != 'ancestral', \
'sampling.predictor=ancestral is not desirable because ' \
'it is slow. Please set sampling.predictor=ancestral_cache'
def _process_model_output(self, model_output, xt, sigma):
del sigma
model_output[:, :, self.mask_index] += self.neg_infinity
# Normalize the model_output such that x.exp() is
# a probability distribution over vocab_size.
model_output = model_output - torch.logsumexp(
model_output, dim=-1, keepdim=True)
# Apply updates directly in the logits matrix.
# For the logits of the unmasked tokens, set all values
# to -infinity except for the indices corresponding to
# the unmasked tokens.
unmasked_indices = (xt != self.mask_index)
model_output[unmasked_indices] = self.neg_infinity
model_output[unmasked_indices, xt[unmasked_indices]] = 0
return model_output
def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
dalpha_t, low_var=False):
del xt
log_p_theta = torch.gather(
input=log_x_theta,
dim=-1,
index=x0[:, :, None]).squeeze(-1)
return log_p_theta * dalpha_t / (1 - alpha_t)
def _get_score(self, x, sigma):
model_output = self.forward(x, sigma)
# score(x, t) = p_t(y) / p_t(x)
# => log score(x, t) = log p_t(y) - log p_t(x)
# case 1: x = masked
# (i) y = unmasked
# log score(x, t) = log p_\theta(x)|_y + log k
# where k = exp(- sigma) / (1 - exp(- sigma))
# (ii) y = masked
# log score(x, t) = 0
# case 2: x = unmasked
# (i) y != masked, y != x
# log score(x_i, t) = - inf
# (ii) y = x
# log score(x_i, t) = 0
# (iii) y = masked token
# log score(x_i, t) = - log k
# where k = exp(- sigma) / (1 - exp(- sigma))
log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
assert log_k.ndim == 1
masked_score = model_output + log_k[:, None, None]
masked_score[:, :, self.mask_index] = 0
unmasked_score = self.neg_infinity * torch.ones_like(
model_output)
unmasked_score = torch.scatter(
unmasked_score,
-1,
x[..., None],
torch.zeros_like(unmasked_score[..., :1]))
unmasked_score[:, :, self.mask_index] = - (
log_k[:, None] * torch.ones_like(x))
masked_indices = (x == self.mask_index).to(
model_output.dtype)[:, :, None]
model_output = (
masked_score * masked_indices
+ unmasked_score * (1 - masked_indices))
return model_output.exp()
class D3PMAbsorb(trainer_base.AbsorbingState):
def __init__(self, config, tokenizer):
super().__init__(config, tokenizer)
self._validate_configuration()
def _validate_configuration(self):
super()._validate_configuration()
assert self.noise.type == 'log-linear'
assert self.parameterization == 'mean'
def _process_model_output(self, model_output, xt, sigma):
del xt
del sigma
if self.subs_masking:
model_output[:, :, self.mask_index] += self.neg_infinity
return model_output.log_softmax(dim=-1)
def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
dalpha_t, low_var=False):
del dalpha_t
assert not low_var
dt = 1 / self.T
t = 1 - alpha_t # Only valid for log-linear schedule.
t = t.clamp(0., 1.0 - 1e-4)
alpha_t = alpha_t + torch.zeros_like(xt)
    alpha_s = 1 - (t - dt) + torch.zeros_like(xt)  # Signal level at s = t - dt.
assert alpha_s.shape == xt.shape
assert alpha_t.shape == xt.shape
log_x_theta_at_x0 = torch.gather(
log_x_theta, -1, x0[:, :, None]).squeeze(-1)
log_x_theta_at_m = log_x_theta[:, :, self.mask_index]
x_theta_at_m = log_x_theta_at_m.exp()
term_1_coef = dt / t
term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)
term_1_log_dr = log_x_theta_at_x0
term_2_coef = 1 - dt / t
term_2_log_nr = term_1_log_nr
term_2_log_dr = torch.log(
alpha_s * x_theta_at_m / (t - dt) + 1)
L_vb_masked = (
term_1_coef * (term_1_log_nr - term_1_log_dr)
+ term_2_coef * (term_2_log_nr - term_2_log_dr))
diffusion_loss = self.T * L_vb_masked * (xt == self.mask_index)
return self._reconstruction_loss(x0) + diffusion_loss
class SEDDAbsorb(trainer_base.AbsorbingState):
def __init__(self, config, tokenizer):
super().__init__(config, tokenizer)
self._validate_configuration()
def _validate_configuration(self):
super()._validate_configuration()
assert self.config.sampling.predictor == 'analytic'
def _get_score(self, x, sigma):
return self.forward(x, sigma).exp()
def _process_model_output(self, model_output, xt, sigma):
esigm1_log = torch.where(
sigma < 0.5,
torch.expm1(sigma),
sigma.exp() - 1).log().to(model_output.dtype)
# logits shape
# (batch_size, context_length, vocab_size)
model_output = (model_output
- esigm1_log[:, None, None]
- np.log(model_output.shape[-1] - 1))
# The below scatter operation sets the log score
# for the input word to 0.
model_output = torch.scatter(
model_output, -1, xt[..., None],
torch.zeros_like(model_output[..., :1]))
return model_output
def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
dalpha_t, low_var=False):
"""Computes the SEDD loss for the Absorbing State Diffusion.
Args:
log_x_theta: float torch.Tensor with shape (batch_size,
context_length, vocab_size),
log score, output of the denoising network.
xt: int torch.Tensor with shape (batch_size,
context_length), input.
x0: int torch.Tensor with shape (batch_size,
context_length), input.
alpha_t: float torch.Tensor with shape (batch_size, 1),
signal level.
dalpha_t: float or float torch.Tensor with shape (batch_size, 1),
time derivative of signal level.
low_var: bool, low variance loss during training.
Returns:
loss with shape (batch_size, context_length).
"""
assert not low_var
masked_indices = xt == self.mask_index
sigma = self._sigma_from_alphat(alpha_t)
dsigma = - dalpha_t / alpha_t
expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
q_ratio = 1 / expsig_minus_1[masked_indices]
words_that_were_masked = x0[masked_indices]
neg_term = q_ratio * torch.gather(
log_x_theta[masked_indices],
-1,
words_that_were_masked[..., None]).squeeze(-1)
score = log_x_theta[masked_indices].exp()
if self.mask_index == self.vocab_size - 1:
pos_term = score[:, :-1].sum(dim=-1)
else:
pos_term = score[:, : self.mask_index].sum(
dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)
const = q_ratio * (q_ratio.log() - 1)
entropy = torch.zeros(* xt.shape, device=xt.device)
entropy[masked_indices] += pos_term - neg_term + const
return dsigma * entropy
class DUO_BASE(trainer_base.UniformState):
def __init__(self, config, tokenizer):
super().__init__(config, tokenizer)
self._validate_configuration()
def on_save_checkpoint(self, checkpoint):
checkpoint['state_dict'] = collections.OrderedDict(
(k, v) for k, v in checkpoint['state_dict'].items()
if not k.startswith('teacher'))
super().on_save_checkpoint(checkpoint)
def on_load_checkpoint(self, checkpoint):
checkpoint['state_dict'] = collections.OrderedDict(
(k, v) for k, v in checkpoint['state_dict'].items()
if not k.startswith('teacher'))
super().on_load_checkpoint(checkpoint)
def _process_model_output(self, model_output, xt, sigma):
del xt, sigma
return model_output.log_softmax(dim=-1)
def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t):
"""Computes the posterior / approximate posterior.
Args:
      x0: Either the clean input `x0` (one-hot) or the model's
        predicted `x_theta`, of shape (B, L, V).
xt: The noisy latent (as indices) of shape (B, L).
alpha_s: Noise level at s of shape (B, [L | 1], 1).
alpha_t: Noise level at t of shape (B, [L | 1], 1).
Returns:
Posterior / approximate posterior of shape (B, L, V).
"""
if self.config.sampling.use_float64:
x0 = x0.to(torch.float64)
if alpha_s.ndim == 2:
alpha_s = alpha_s.unsqueeze(-1)
if alpha_t.ndim == 2:
alpha_t = alpha_t.unsqueeze(-1)
alpha_ts = alpha_t / alpha_s
d_alpha = alpha_s - alpha_t
xt_one_hot = F.one_hot(xt, self.vocab_size).to(
self.dtype).to(self.device)
return (
(alpha_t * self.vocab_size * x0 * xt_one_hot + (
alpha_ts - alpha_t) * xt_one_hot + d_alpha * x0 + (
1 - alpha_ts) * (1 - alpha_s) / self.vocab_size) / (
alpha_t * self.vocab_size * torch.gather(
x0, -1, xt[..., None]) + (1 - alpha_t)))
def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
dalpha_t, low_var=False):
assert alpha_t.ndim == 2
assert x0.ndim == 2
assert xt.ndim == 2
assert not torch.is_tensor(dalpha_t) or dalpha_t.ndim == 2
x_reconst = log_x_theta.exp()
x_bar_theta = self.vocab_size * alpha_t[
:, :, None] * x_reconst + 1 - alpha_t[:, :, None]
coeff = dalpha_t / (self.vocab_size * alpha_t)
x_eq_xt = (x0 == xt).float()
x_neq_xt = 1 - x_eq_xt
xbar_xt = (1 - alpha_t) + self.vocab_size * alpha_t * x_eq_xt
xbar_theta_xt = torch.gather(
x_bar_theta, -1, xt.unsqueeze(-1)).squeeze(-1)
xbar_theta_x = torch.gather(
x_bar_theta, -1, x0.unsqueeze(-1)).squeeze(-1)
term1 = self.vocab_size * (1 / xbar_xt
- 1 / xbar_theta_xt)
const = (1 - alpha_t) / (self.vocab_size * alpha_t
+ 1 - alpha_t)
term2_coefs = x_eq_xt * const + x_neq_xt
term2_offset = ((self.vocab_size - 1) * const * x_eq_xt
- (1 / const) * x_neq_xt) * const.log()
term2_theta = - term2_coefs * (
x_bar_theta.log().sum(-1)
- self.vocab_size * xbar_theta_xt.log())
term2_theta = (
term2_theta
- self.vocab_size * alpha_t / (1 - alpha_t) * (
xbar_theta_x.log() - xbar_theta_xt.log()) * x_neq_xt)
term2 = term2_theta + term2_offset
diffusion_loss = coeff * (term1 - term2)
assert diffusion_loss.ndim == 2
return diffusion_loss
class Integral(torch.autograd.Function):
"""
torch module calculating UDLM's p_t
"""
@staticmethod
def forward(ctx, gamma_t, data):
gamma_max = data['gamma_max']
gamma_min = data['gamma_min']
if (gamma_t.max() > gamma_max) or (
gamma_t.min() < gamma_min):
print('max:{} {}'.format(gamma_t.max(), gamma_max))
print('min:{} {}'.format(gamma_t.min(), gamma_min))
gamma_t = torch.clip(gamma_t, gamma_min, gamma_max)
indices = torch.round(
(data['num_points'] - 1) * (gamma_t - gamma_min) / (
gamma_max - gamma_min)).long()
grad_pt = data['grad_pt']
ctx.grad_pt = grad_pt[indices]
pt = data['pt'][indices]
assert pt.shape == gamma_t.shape
return pt
@staticmethod
def backward(ctx, grad_output):
return ctx.grad_pt * grad_output, None
class DUO(DUO_BASE):
def __init__(self, config, tokenizer):
super().__init__(config, tokenizer)
self.gamma_min = self.config.algo.curriculum.gamma_min
self.gamma_max = self.config.algo.curriculum.gamma_max
self.gumbel_tau_log10_start = \
self.config.algo.curriculum.gumbel_tau_log10_start
self.gumbel_tau_log10_end = \
self.config.algo.curriculum.gumbel_tau_log10_end
self.curriculum_start = self.config.algo.curriculum.start
self.curriculum_end = self.config.algo.curriculum.end
self.loss_type = self.config.algo.loss_type
self._initialize_curriculum_coefficients()
self._validate_configuration()
def _initialize_curriculum_coefficients(self):
if self.config.algo.curriculum.mode in {'simple',
'efficient_cached'}:
self._init_curriculum_cached()
elif self.config.algo.curriculum.mode == 'series':
self._init_curriculum_series()
elif self.config.algo.curriculum.mode in {'sigmoid',
'sigmoid-edge-corrected', 'poly3', 'poly5', 'poly7',
'poly9'}:
self._init_curriculum_approx()
else:
raise ValueError(self.config.algo.curriculum.mode)
def _init_curriculum_cached(self):
fpath = self.config.algo.curriculum.integral_cache_path
with fsspec.open(fpath, 'rb') as f:
self.integral_cache = pickle.load(f)
self.integral_cache['pt'] = torch.from_numpy(
self.integral_cache['pt'])
self.integral_cache['grad_pt'] = torch.from_numpy(
self.integral_cache['grad_pt'])
def _init_curriculum_series(self):
m, i = utils.compute_duo_series_coefficients(
self.config.algo.curriculum.n_series_terms,
self.vocab_size)
self.register_buffer('coefficients_m', m,
persistent=False)
self.register_buffer('coefficients_i', i,
persistent=False)
self.register_buffer('power_arange',
torch.arange(self.config.algo.curriculum.n_series_terms,
dtype=torch.float64)[None], persistent=False)
def _init_curriculum_approx(self):
fname = f'{self.config.algo.curriculum.mode}.npy'
fpath = os.path.join(self.config.algo.curriculum.cache_dir,
fname)
if not os.path.exists(fpath):
# Compute the coefficients on the fly
coefficients, _, _, _ = utils.compute_duo_operator_approx(
num_coefficients=self.config.algo.curriculum.n_series_terms,
vocab_size=self.vocab_size,
gamma_min=self.gamma_min,
gamma_max=self.gamma_max,
fct_name=self.config.algo.curriculum.mode)
# Tuples are for torch compile, tuples are immutable
coefficients = tuple(coefficients)
parent_dir = os.path.dirname(fpath)
os.makedirs(parent_dir, exist_ok=True)
np.save(fpath, coefficients)
else:
coefficients = tuple(np.load(fpath).tolist())
mode = self.config.algo.curriculum.mode
if mode == 'sigmoid':
fn = utils.duo_to_alpha_dalpha_sigmoid
elif mode == 'sigmoid-edge-corrected':
fn = utils.duo_t_to_alpha_dalpha_sigm_corrected
elif mode in ('poly3', 'poly5', 'poly7', 'poly9'):
fn = utils.duo_to_alpha_dalpha_poly
else:
raise ValueError(mode)
fn = torch.compile(fn)
self._t_to_alpha_dalpha_compiled = \
lambda t: fn(t, *coefficients)
def cuda(self, device=None):
self = super().cuda(device=device)
if hasattr(self, 'integral_cache'):
self.integral_cache['pt'] = self.integral_cache[
'pt'].cuda(device=device)
self.integral_cache['grad_pt'] = self.integral_cache[
'grad_pt'].cuda(device=device)
return self
def cpu(self):
self = super().cpu()
if hasattr(self, 'integral_cache'):
self.integral_cache['pt'] = self.integral_cache[
'pt'].cpu()
self.integral_cache['grad_pt'] = self.integral_cache[
'grad_pt'].cpu()
return self
def to(self, *args, **kwargs):
self = super().to(*args, **kwargs)
if hasattr(self, 'integral_cache'):
self.integral_cache['pt'] = self.integral_cache[
'pt'].to(*args, **kwargs)
self.integral_cache['grad_pt'] = self.integral_cache[
'grad_pt'].to(*args, **kwargs)
return self
def _compute_gumbel_tau_inverse(self):
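    # Anneal log10(tau) linearly from `start` to `end` over the curriculum
    # window; afterwards clamp to log10(tau) = -10, i.e. an effectively
    # hard argmax. Returns the inverse temperature 1 / tau.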
start = self.gumbel_tau_log10_start
end = self.gumbel_tau_log10_end
delta = end - start
if self.global_step < self.curriculum_start:
tau = start
elif self.global_step < self.curriculum_end:
frac = (self.global_step - self.curriculum_start) / (
self.curriculum_end - self.curriculum_start)
tau = start + frac * delta
else:
tau = -10
return 10 ** (-tau)
def training_step(self, batch, batch_idx):
self.log(name='gumbel_tau_log10',
value=1 / self._compute_gumbel_tau_inverse(),
on_step=True,
on_epoch=False,
sync_dist=True)
return super().training_step(batch, batch_idx)
def _gamma_to_alpha_dalpha(self, gamma_t, t):
if self.config.algo.curriculum.mode in ('simple',
'efficient_cached'):
return self._gamma_to_alpha_dalpha_cached(gamma_t)
elif self.config.algo.curriculum.mode == 'series':
return utils.compute_duo_gamma_to_alpha_dalpha_series(
gamma_t, self.coefficients_m, self.coefficients_i,
self.power_arange, self.vocab_size, self.gamma_min,
self.gamma_max)
elif self.config.algo.curriculum.mode in ('sigmoid',
'sigmoid-edge-corrected', 'poly3', 'poly5', 'poly7',
'poly9'):
return self._t_to_alpha_dalpha_compiled(t)
else:
raise ValueError(self.config.algo.curriculum.mode)
def _gamma_to_alphat_integral(self, gamma_t):
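    # Map the Gaussian diffusion's gamma_t to the uniform-state
    # diffusion's alpha_t via the cached integral (Eq. 10 in the paper).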
integral = Integral.apply(gamma_t, self.integral_cache)
return (self.vocab_size * integral - 1) / (
self.vocab_size - 1)
def _gamma_to_alpha_dalpha_cached(self, gamma_t):
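    # Chain rule: gamma is linear in t, so dgamma/dt is constant, and
    # dalpha/dgamma is estimated by a forward finite difference of step 1/T.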
gamma_t_prime = self.gamma_max - self.gamma_min
usdm_alpha_t = DUO._gamma_to_alphat_integral(self, gamma_t)
T = 1000
usdm_dalpha_t = gamma_t_prime * T * (
DUO._gamma_to_alphat_integral(self, gamma_t + 1 / T)
- usdm_alpha_t)
return usdm_alpha_t, usdm_dalpha_t
def _prior_loss(self):
alpha_1 = self._gamma_to_alphat_integral(
torch.tensor(self.gamma_max))
loss = ((alpha_1 + (1 - alpha_1) / self.vocab_size) \
* torch.log((self.vocab_size - 1) * alpha_1 + 1) \
+ (1 - 1 / self.vocab_size) * (1 - alpha_1) \
* torch.log(1 - alpha_1))
return loss.item()
def _q_xt_gaussian(self, x, gamma_t):
"""Computes the noisy sample xt."""
assert gamma_t.ndim == 1
assert x.ndim == 3
gamma_t = gamma_t.unsqueeze(-1).unsqueeze(-1)
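    # Variance-preserving schedule: alpha_t^2 = sigmoid(-gamma_t) and
    # sigma_t^2 = sigmoid(gamma_t), so alpha_t^2 + sigma_t^2 = 1.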
alpha_t = torch.sigmoid(-gamma_t).sqrt()
sigma_t = torch.sigmoid(gamma_t).sqrt()
epsilon = torch.randn(x.shape, dtype=torch.float32,
device=self.device)
return alpha_t * x + sigma_t * epsilon
def nll(self, x0, output_tokens,
current_accumulation_step=None, train_mode=False):
use_true_nll = (self.global_step > self.curriculum_end
or not train_mode)
if use_true_nll:
return super().nll(x0, output_tokens,
current_accumulation_step)
del output_tokens
t = self._sample_t(x0.shape[0], current_accumulation_step)
gamma_t = self.gamma_min + t * (self.gamma_max
- self.gamma_min)
usdm_alpha_t, usdm_dalpha_t = \
self._gamma_to_alpha_dalpha(gamma_t, t)
usdm_alpha_t = usdm_alpha_t.unsqueeze(-1)
assert usdm_alpha_t.ndim == 2
usdm_dalpha_t = usdm_dalpha_t.unsqueeze(-1)
sigma = self._sigma_from_alphat(usdm_alpha_t)
# Default Duo curriculum
if self.config.algo.curriculum.mode == 'simple':
x0_one_hot = F.one_hot(x0, self.vocab_size)
xt = self._q_xt_gaussian(x0_one_hot, gamma_t)
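      # Temper the Gaussian latent by the inverse temperature 1 / tau;
      # the backbone consumes the soft latent, while the loss uses its
      # argmax, whose law is the uniform-state discrete marginal.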
xt = xt * self._compute_gumbel_tau_inverse()
xt_usdm = xt.argmax(-1)
log_x_theta = self.forward(xt, sigma=sigma)
else: # Efficient variant
softmax_approx, topk_indices, xt_usdm = \
utils.sample_tempered_softmax_topk(
extra_index=x0,
alpha=torch.sigmoid(-gamma_t).sqrt(),
sigma=torch.sigmoid(gamma_t).sqrt(),
l=x0.shape[1],
k=self.config.algo.curriculum.top_k,
vocab_size=self.vocab_size,
inverse_temperature=self._compute_gumbel_tau_inverse())
log_x_theta = self.forward(topk_indices, sigma=sigma,
weights=softmax_approx)
return self.nll_per_token(log_x_theta=log_x_theta,
xt=xt_usdm,
x0=x0,
alpha_t=usdm_alpha_t,
dalpha_t=usdm_dalpha_t,
low_var=False)
class Distillation(DUO):
def __init__(self, config, tokenizer):
super().__init__(config, tokenizer)
self.update_teacher_every = config.algo.update_teacher_every
self.save_hyperparameters()
self.teacher = None
self.teacher_ema = config.algo.teacher_ema
self.linear_growth_dt = config.algo.linear_growth_dt
self.linear_growth_min = config.algo.linear_growth_min
self.linear_growth_max = config.algo.linear_growth_max
def _validate_configuration(self):
assert os.path.exists(
self.config.algo.integral_cache_path), (
'The integral cache (Eq. 10 in the paper) for '
f'the {self.config.data.tokenizer_name_or_path} '
        "tokenizer doesn't exist at "
f'{self.config.algo.integral_cache_path}. '
'Please generate it by running the utils.py script, '
'and ensure the correct path is specified using the '
'algo.integral_cache_path flag.')
assert self.loss_type in {
'kl-fwd', 'kl-bwd', 'posterior', 'kl-posterior'}
def _maybe_update_teacher_weights(self):
if self.global_step % self.update_teacher_every != 0:
return
if self.teacher_ema:
self.ema.copy_to(self.teacher.parameters())
else:
for better_param, current_param in zip(
self.backbone.parameters(), self.teacher.parameters()):
if current_param.requires_grad:
current_param.data.copy_(better_param.data)
@torch.no_grad()
def _teacher_logits(self, xt, sigma):
if self.teacher is None:
self.teacher = copy.deepcopy(self.backbone)
self._maybe_update_teacher_weights()
sigma = self._process_sigma(sigma)
with torch.amp.autocast('cuda', dtype=torch.float32):
model_output = self.teacher(xt, sigma)
logits = self._process_model_output(
model_output=model_output, xt=xt, sigma=sigma)
return logits.detach()
def _sample_trajectory(self, x0, gamma_t, gamma_s):
"""Computes the noisy sample xt."""
assert gamma_t.ndim == 1
assert gamma_s.ndim == 1
assert x0.ndim == 2
x0 = F.one_hot(x0, self.vocab_size).to(
self.dtype).to(self.device)
gamma_t = gamma_t.unsqueeze(-1).unsqueeze(-1)
alpha_t = torch.sigmoid(-gamma_t).sqrt()
sigma_t = torch.sigmoid(gamma_t).sqrt()
gamma_s = gamma_s.unsqueeze(-1).unsqueeze(-1)
alpha_s = torch.sigmoid(-gamma_s).sqrt()
sigma_s = torch.sigmoid(gamma_s).sqrt()
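    # Share a single Gaussian draw between times s and t so that
    # (xt, xs) lie on the same diffusion trajectory.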
epsilon = torch.randn(x0.shape, dtype=torch.float32,
device=self.device)
xt = alpha_t * x0 + sigma_t * epsilon
xs = alpha_s * x0 + sigma_s * epsilon
return xt, xs
def _compute_dt(self):
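    # Either grow dt linearly over training, or double it every
    # `update_teacher_every` steps (dt = 2^n / T), so each distillation
    # round halves the number of steps the student must match.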
if self.linear_growth_dt:
scale = self.global_step / self.trainer.max_steps
return self.linear_growth_min + scale * (
self.linear_growth_max - self.linear_growth_min)
n = self.global_step // self.update_teacher_every
return 2 ** n / self.T
def nll(self, x0, output_tokens,
current_accumulation_step=None, train_mode=None):
del output_tokens, train_mode
t = self._sample_t(x0.shape[0], current_accumulation_step)
dt = self._compute_dt()
t = torch.clip(t + dt, 0, 1)
gamma_t = self.gamma_min + t * (self.gamma_max
- self.gamma_min)
gamma_s = self.gamma_min + (
t - dt) * (self.gamma_max - self.gamma_min)
usdm_alpha_t = self._gamma_to_alphat_integral(gamma_t)
usdm_alpha_t = usdm_alpha_t.unsqueeze(-1)
assert usdm_alpha_t.ndim == 2
usdm_alpha_s = self._gamma_to_alphat_integral(gamma_s)
usdm_alpha_s = usdm_alpha_s.unsqueeze(-1)
assert usdm_alpha_s.ndim == 2
xt, xs = self._sample_trajectory(x0, gamma_t, gamma_s)
xt_discrete = xt.argmax(-1)
xs_discrete = xs.argmax(-1)
log_x_theta_student = self.forward(
xt_discrete, sigma=self._sigma_from_alphat(usdm_alpha_t))
log_x_theta_teacher = self._teacher_logits(
xs_discrete, sigma=self._sigma_from_alphat(usdm_alpha_s))
if self.config.training.loss_precision == 'float64':
log_x_theta_student = log_x_theta_student.to(torch.float64)
log_x_theta_teacher = log_x_theta_teacher.to(torch.float64)
if self.loss_type == 'kl-fwd':
return (log_x_theta_teacher.exp() * (
log_x_theta_teacher - log_x_theta_student)).sum(-1)
elif self.loss_type == 'kl-bwd':
return (log_x_theta_student.exp() * (
log_x_theta_student - log_x_theta_teacher)).sum(-1)
def training_step(self, batch, batch_idx):
self.log(name='dt',
value=self._compute_dt(),
on_step=True,
on_epoch=False,
sync_dist=True)
return super().training_step(batch, batch_idx)
================================================
FILE: configs/algo/ar.yaml
================================================
name: ar
backbone: dit
parameterization: ar
time_conditioning: False
causal_attention: True
# Irrelevant flags
T: 0
subs_masking: False
ignore_bos: False
================================================
FILE: configs/algo/d3pm.yaml
================================================
name: d3pm
backbone: dit # dit / dimamba
parameterization: mean
time_conditioning: True
T: 1000
subs_masking: False # True / False
causal_attention: False
ignore_bos: False
loss_type: elbo # elbo, low_var
================================================
FILE: configs/algo/distillation.yaml
================================================
name: distillation
backbone: dit # dit / dimamba / hf_dit
parameterization: mean
time_conditioning: True
subs_masking: False
causal_attention: False
gumbel_tau_log10_start: -1
gumbel_tau_log10_end: -1
curriculum_start: -1
curriculum_end: -1
integral_cache_path: ${hydra:runtime.cwd}/integral/${data.tokenizer_name_or_path}.pkl
loss_type: kl-bwd # kl-fwd, kl-bwd, posterior
update_teacher_every: 10_000
ignore_bos: False
T: 64
gamma_min: -4
gamma_max: -1
teacher_ema: False
posterior_loss_weight: 0.0
linear_growth_dt: False
linear_growth_min: 0.001
linear_growth_max: 0.25
================================================
FILE: configs/algo/duo.yaml
================================================
name: duo
backbone: dit # dit / dimamba / hf_dit
parameterization: mean
time_conditioning: True
T: 0 # 0 (continuous time) / 1000
subs_masking: False
causal_attention: False
loss_type: elbo
ignore_bos: False
# Curriculum arguments
curriculum:
# Simple is the original duo curriculum, poly9 is the one we use in duo2 for speed
mode: simple # simple / efficient_cached / series / sigmoid / sigmoid-edge-corrected / poly3 / poly5 / poly7 / poly9
# Arguments for the simple & efficient variant
gumbel_tau_log10_start: -1.0
gumbel_tau_log10_end: -2.0
start: 100_000
end: 200_000
gamma_min: -3.5
gamma_max: -1.75
# Argument for the simple and efficient_cached variants
integral_cache_path: ${hydra:runtime.cwd}/integral/${data.tokenizer_name_or_path}.pkl
# Arguments for the efficient variant only
top_k: -1
cache_dir: './transform_approx_coefficients'
n_series_terms: 150
================================================
FILE: configs/algo/duo_base.yaml
================================================
name: duo_base
backbone: dit # dit / dimamba / hf_dit
parameterization: mean
time_conditioning: True
T: 0 # 0 (continuous time) / 1000
subs_masking: False
causal_attention: False
ignore_bos: False
loss_type: elbo # elbo, low_var
================================================
FILE: configs/algo/mdlm.yaml
================================================
name: mdlm
backbone: dit # dit / dimamba / hf_dit
parameterization: subs
time_conditioning: False
T: 0 # 0 (continuous time) / 1000
subs_masking: False
causal_attention: False
ignore_bos: False
loss_type: elbo
================================================
FILE: configs/algo/ot-finetune.yaml
================================================
name: ot-finetune
backbone: dit # dit / dimamba / hf_dit
parameterization: mean
time_conditioning: True
T: 0 # 0 (continuous time) / 1000
subs_masking: False
causal_attention: False
ignore_bos: False
delta_ts: 0.5
================================================
FILE: configs/algo/sedd.yaml
================================================
name: sedd
backbone: dit # dit / dimamba
parameterization: score
time_conditioning: True
T: 0 # 0 (continuous time) / 1000
subs_masking: False
causal_attention: False
ignore_bos: False
loss_type: elbo # elbo, low_var
================================================
FILE: configs/callbacks/checkpoint_every_n_steps.yaml
================================================
checkpoint_every_n_steps:
_target_: lightning.pytorch.callbacks.ModelCheckpoint
save_top_k: -1 # Do not save any "best" models; this callback is being used to save every n train steps
save_last: True # save model as ${save_dir}/checkpoints/last.ckpt
dirpath: ${checkpointing.save_dir}/checkpoints
verbose: True
auto_insert_metric_name: False
every_n_train_steps: 500
================================================
FILE: configs/callbacks/checkpoint_monitor.yaml
================================================
checkpoint_monitor:
_target_: lightning.pytorch.callbacks.ModelCheckpoint
monitor: val/nll # name of the logged metric which determines when model is improving
mode: min # can be "max" or "min"
save_top_k: 1 # save k best models (determined by above metric)
save_last: False # True = additionally always save model from last epoch
dirpath: ${checkpointing.save_dir}/checkpoints
filename: best
auto_insert_metric_name: False
verbose: True
================================================
FILE: configs/callbacks/grad_record.yaml
================================================
grad_record:
_target_: utils.GradientInspectionCallback
num_grads_log: 4
================================================
FILE: configs/callbacks/learning_rate_monitor.yaml
================================================
learning_rate_monitor:
_target_: lightning.pytorch.callbacks.LearningRateMonitor
logging_interval: step
================================================
FILE: configs/config.yaml
================================================
defaults:
- _self_
- /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
- /data: openwebtext
- /model: small
- /strategy: ddp
- /noise: log-linear
- /lr_scheduler: constant_warmup
- /prior: none
- /algo: duo_base
mode: train # train / ppl_eval / sample_eval
seed: 1
loader:
global_batch_size: 512
eval_global_batch_size: ${.global_batch_size}
# Note: batch_size and eval_batch_size are **per machine**
batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
pin_memory: True
sampling:
predictor: ancestral # ancestral_cache (only for MDLM), ancestral, analytic, psi (for psi-samplers)
steps: 1000
noise_removal: ancestral # 'ancestral', 'greedy', 'none'
use_float64: True
p_nucleus: 1.0
num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
num_sample_log: 2
semi_ar: False
stride_length: 1
num_strides: 1
guid_weight: null # null -> no classifier-free guidance
psi:
time_profile: linear # linear / linear-constant-linear-alpha / linear-constant-linear-alpha-inv
# Note: kappa = 1 -> pure posterior, kappa = 0 -> pure PC
# modes: pure-posterior / pure-pc / constant-eta / constant-remdm-eta / max-capped-eta / max-rescale-eta
high_mode: pure-posterior
middle_mode: pure-posterior
low_mode: pure-posterior
# Fraction of [0, 1] spent in the high / middle modes (the rest being in low mode)
# Example: high_frac=0.2, middle_frac=0.6. Then,
# - For t in [1.0, 0.8] -> use high_mode
# - For t in [0.8, 0.2] -> use middle_mode
# - For t in [0.2, 0.0] -> use low_mode
high_frac: 0.2
middle_frac: 0.6
training:
ema: 0.9999
antithetic_sampling: True
importance_sampling: False
sampling_eps: 1e-3
change_of_variables: False
loss_precision: 'bf16' # bf16, float32, float64
finetune_path: ''
class_dropout_p: 0.1 # only used with class-conditional datasets (e.g., cifar10)
eval:
checkpoint_path: '' # Used to evaluate a checkpoint after training.
disable_ema: False
compute_generative_perplexity: False
perplexity_batch_size: 8
compute_perplexity_on_sanity: False
gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf
generate_samples: True
generated_samples_path: ${cwd:}/samples.json
optim:
weight_decay: 0
lr: 3e-4
beta1: 0.9
beta2: 0.999
eps: 1e-8
trainer:
_target_: lightning.Trainer
accelerator: cuda
num_nodes: 1
devices: ${device_count:}
accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
gradient_clip_val: 1.0
precision: 'bf16'
num_sanity_val_steps: 2
max_steps: 1_000_000
log_every_n_steps: 100
limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run
limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run
val_check_interval: 5000
check_val_every_n_epoch: null # Disable end-of-epoch validation. Instead, validate every trainer.val_check_interval steps.
wandb:
project: duo
group: null
job_type: null
name: null
id: ${.name}_${seed}
tags:
- ${noise.type}
- ${data.train}
- ${data.valid}
- ${algo.name}
hydra:
run:
dir: ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
job:
chdir: true
checkpointing:
# Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
save_dir: ${cwd:}
# Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
resume_from_ckpt: true
resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt
================================================
FILE: configs/data/ag_news.yaml
================================================
train: ag_news
valid: ag_news
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/cifar10.yaml
================================================
train: cifar10
valid: cifar10
modality: image
tokenizer_name_or_path: cifar10
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
num_classes: 10 # 10 / null
streaming: False
size: 1024
length: 3072
wrap: False
insert_train_eos: False
insert_valid_eos: False
insert_train_spacial: False
================================================
FILE: configs/data/fineweb-edu.yaml
================================================
train: HuggingFaceFW/fineweb-edu
valid: openwebtext-valid #wikitext103
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: True
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/lambada.yaml
================================================
train: lambada
valid: lambada
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/lm1b-gpt2.yaml
================================================
train: lm1b
valid: lm1b
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/lm1b-streaming.yaml
================================================
train: lm1b
valid: lm1b
modality: text
tokenizer_name_or_path: bert-base-uncased
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: False
streaming: True
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/lm1b-wrap.yaml
================================================
train: lm1b
valid: lm1b
modality: text
tokenizer_name_or_path: bert-base-uncased
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/lm1b.yaml
================================================
train: lm1b
valid: lm1b
modality: text
tokenizer_name_or_path: bert-base-uncased
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: False
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/openwebtext-split.yaml
================================================
train: openwebtext-train
valid: openwebtext-valid
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/openwebtext-streaming.yaml
================================================
train: openwebtext
valid: wikitext103
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /tmp/data
wrap: True
streaming: True
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/openwebtext.yaml
================================================
train: openwebtext
valid: wikitext103
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/ptb.yaml
================================================
train: ptb
valid: ptb
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/scientific_papers_arxiv.yaml
================================================
train: scientific_papers_arxiv
valid: scientific_papers_arxiv
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/scientific_papers_pubmed.yaml
================================================
train: scientific_papers_pubmed
valid: scientific_papers_pubmed
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/synthetic.yaml
================================================
train: synthetic
valid: synthetic
modality: text
tokenizer_name_or_path: synthetic
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: True
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/text8-crop.yaml
================================================
# TODO: When using this dataset, set model.length = 256 to match D3PM setup
train: text8-crop
valid: text8
modality: text
tokenizer_name_or_path: text8
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/text8.yaml
================================================
# TODO: When using this dataset, set model.length = 256 to match D3PM setup
train: text8
valid: text8
modality: text
tokenizer_name_or_path: text8
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/wikitext103.yaml
================================================
train: wikitext103
valid: wikitext103
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/data/wikitext2.yaml
================================================
train: wikitext2
valid: wikitext2
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True
================================================
FILE: configs/lr_scheduler/constant_warmup.yaml
================================================
_target_: transformers.get_constant_schedule_with_warmup
num_warmup_steps: 2500
================================================
FILE: configs/lr_scheduler/cosine_decay_warmup.yaml
================================================
_target_: utils.CosineDecayWarmupLRScheduler
t_in_epochs: False
t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
warmup_prefix: True
warmup_lr_init: 1e-6
warmup_t: ${eval:0.1*${trainer.max_steps}}
lr_min: 1e-6
================================================
FILE: configs/lr_scheduler/step_scheduler.yaml
================================================
_target_: torch.optim.lr_scheduler.LambdaLR
lr_lambda:
_target_: utils.LRHalveScheduler
warmup_steps: 500
n_halve_steps: 10_000
================================================
FILE: configs/model/medium.yaml
================================================
name: medium
type: ddit
hidden_size: 1024
cond_dim: 128
length: 1024
n_blocks: 24
n_heads: 16
scale_by_sigma: True
dropout: 0.1
tie_word_embeddings: False
vocab_lookup: True
================================================
FILE: configs/model/small.yaml
================================================
name: small
type: ddit
hidden_size: 768
cond_dim: 128
length: 1024
n_blocks: 12
n_heads: 12
scale_by_sigma: True
dropout: 0.1
tie_word_embeddings: False
vocab_lookup: True
================================================
FILE: configs/model/tiny-dimamba.yaml
================================================
name: tiny
type: dimamba
hidden_size: 512
cond_dim: 128
length: 1024
n_blocks: 14
n_heads: 8
scale_by_sigma: True
dropout: 0.1
temb_strategy: adaln
tie_word_embeddings: False
================================================
FILE: configs/model/tiny.yaml
================================================
name: tiny
type: ddit
hidden_size: 256
cond_dim: 128
length: 1024
n_blocks: 8
n_heads: 8
scale_by_sigma: True
dropout: 0.1
tie_word_embeddings: False
vocab_lookup: True
================================================
FILE: configs/model/unet.yaml
================================================
name: unet
type: unet
ch: 128
num_res_blocks: 2
num_scales: 4
ch_mult: [1, 2, 2, 2]
input_channels: 3
output_channels: -1 # determined by vocab_size
scale_count_to_put_attn: 1 # at 16 res
data_min_max: [0, 255] # No need currently
dropout: 0.1
skip_rescale: True
time_conditioning: True # Whether to add in time embeddings
time_scale_factor: 1000
time_embed_dim: ${.ch}
fix_logistic: False
size: ${data.size}
cond_dim: ${.ch}
length: ${data.length}
================================================
FILE: configs/noise/cosine.yaml
================================================
type: cosine
eps: 1e-3
================================================
FILE: configs/noise/log-linear.yaml
================================================
type: log-linear
eps: 1e-3
================================================
FILE: configs/prior/none.yaml
================================================
type: none
latent_width: 0
latent_height: 0
================================================
FILE: configs/strategy/ddp.yaml
================================================
_target_: lightning.pytorch.strategies.DDPStrategy
find_unused_parameters: false
================================================
FILE: configs/strategy/fsdp.yaml
================================================
# TODO(yair): Currently not compatible with grad clipping
_target_: lightning.pytorch.strategies.FSDPStrategy
sharding_strategy: SHARD_GRAD_OP
================================================
FILE: dataloader.py
================================================
import functools
import itertools
import json
import math
import os
import re
import shutil
import typing
import urllib
import zipfile
from typing import Optional
import datasets
import einops
import fsspec
import numpy as np
import requests
import tokenizers
import torch
import torchvision
from torchvision import transforms as th_transforms
import transformers
import utils
LOGGER = utils.get_logger(__name__)
class RawPixelsVisionTokenizer:
def __init__(self, vocab_size, image_size,
add_mask_token=True, add_special_tokens=True):
self.pad_token_id = None
self.pad_token = None
if add_mask_token:
self.mask_token = vocab_size
self.mask_token_id = vocab_size
self.vocab_size = vocab_size + 1 # mask token
else:
self.vocab_size = vocab_size
if add_special_tokens:
self.bos_token_id = vocab_size
self.bos_token = vocab_size
self.eos_token_id = vocab_size + 1
self.eos_token = vocab_size + 1
# mask token, bos_token, eos_token
self.vocab_size = self.vocab_size + 2
else:
self.vocab_size = self.vocab_size
self.image_size = image_size
def __call__(self, x):
return x
def batch_decode(self, x):
x = einops.rearrange(x, 'b (c h w) -> b c h w', c=3,
h=self.image_size)
x = x.to(dtype=torch.uint8)
return x
def decode(self, x):
x = einops.rearrange(x, '(c h w) -> h w c', c=3,
h=self.image_size)
x = x.to(dtype=torch.uint8)
return x
def __len__(self):
return self.vocab_size
class DiscreteCIFAR10(torch.utils.data.Dataset):
def __init__(self, cache_dir, train):
self._dataset = torchvision.datasets.CIFAR10(
root=cache_dir, train=train, download=True)
transforms = []
if train:
transforms += [th_transforms.RandomHorizontalFlip()]
transforms += [th_transforms.Lambda(
lambda x: torch.from_numpy(np.array(x))),
th_transforms.Lambda(
lambda x: einops.rearrange(x, "h w c -> (c h w)")),]
self.transform = th_transforms.Compose(transforms)
def __len__(self):
return len(self._dataset)
def __getitem__(self, index):
img, labels = self._dataset[index]
img = self.transform(img)
attention_mask = torch.ones_like(img)
return {'input_ids': img.to(torch.long), 'labels': labels,
'attention_mask': attention_mask}
def wt_detokenizer(string):
# contractions
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# number separators
string = string.replace(" @-@ ", "-")
string = string.replace(" @,@ ", ",")
string = string.replace(" @.@ ", ".")
# punctuation
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" . ", ". ")
string = string.replace(" ! ", "! ")
string = string.replace(" ? ", "? ")
string = string.replace(" , ", ", ")
# double brackets
string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
# miscellaneous
string = string.replace("= = = =", "====")
string = string.replace("= = =", "===")
string = string.replace("= =", "==")
string = string.replace(" " + chr(176) + " ", chr(176))
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" N ", " 1 ")
string = string.replace(" 's", "'s")
return string
def ptb_detokenizer(x):
x = x.replace(" 's", "'s")
x = x.replace("s ' ", "s' ")
x = x.replace(" n't", "n't")
x = x.replace(" \n ", "\n")
x = x.replace("\\/", "/")
for _ in range(10):
x = x.replace(" N ", " 1 ")
x = x.replace("$ 1", "$1")
x = x.replace("# 1", "#1")
  x = x.replace("<unk>", "?")
return x
def lm1b_detokenizer(x):
x = x.replace('http : / / ', 'http://')
x = x.replace('https : / / ', 'https://')
x = re.sub(r' \'(\w+)', r"'\1", x)
x = re.sub(r' (\w+) \. ', r' \1. ', x)
x = re.sub(r' (\w+) \.$', r' \1.', x)
x = x.replace(' ? ', '? ')
x = re.sub(r' \?$', '?', x)
x = x.replace(' ! ', '! ')
x = re.sub(r' \!$', '!', x)
x = x.replace(' , ', ', ')
x = x.replace(' : ', ': ')
x = x.replace(' ; ', '; ')
x = x.replace(' / ', '/')
x = re.sub(r'\" ([^\"]+) \"', r'"\1"', x)
x = re.sub(r'\' ([^\']+) \'', r"'\1'", x)
x = re.sub(r'\( ([^\(\)]+) \)', r"(\1)", x)
x = re.sub(r'\[ ([^\[\]]+) \]', r"[\1]", x)
x = x.replace('$ ', '$')
x = x.replace('£ ', '£')
return x
def lambada_detokenizer(text):
text = text.replace("“", '"')
text = text.replace("”", '"')
return '\n'+text.strip()
def scientific_papers_detokenizer(x):
x = wt_detokenizer(x)
x = lm1b_detokenizer(x)
return x
class SyntheticTokenizer(
transformers.PreTrainedTokenizer):
def __init__(
self,
vocab_size,
bos_token="[BOS]",
eos_token="[EOS]",
sep_token=None,
cls_token=None,
pad_token=None,
mask_token=None,
unk_token=None,
**kwargs):
self.tokens = []
    for i in range(vocab_size - 2):
# appending space for readability
self.tokens.append(str(i) + " ")
self._vocab_str_to_int = {
'[BOS]': vocab_size - 2,
'[EOS]': vocab_size - 1,
** {ch: i for i, ch in enumerate(self.tokens)}}
self._vocab_int_to_str = {
v: k for k, v in self._vocab_str_to_int.items()}
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
unk_token=unk_token,
**kwargs)
@property
def vocab_size(self) -> int:
return len(self._vocab_str_to_int)
def _tokenize(self, text: str, **kwargs) -> typing.List[str]:
return list(text.lower())
  def _convert_token_to_id(self, token: str) -> int:
    # Note: unlike Text8Tokenizer, this vocab defines no '[UNK]' entry;
    # unknown tokens are not expected for pre-tokenized synthetic data.
    return self._vocab_str_to_int[token]
def _convert_id_to_token(self, index: int) -> str:
return self._vocab_int_to_str[index]
def convert_tokens_to_string(self, tokens):
return ''.join(tokens)
def get_vocab(self) -> typing.Dict[str, int]:
return self._vocab_str_to_int
def _generate_synthetic_data(dataset_size,
seq_len, vocab_size):
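  # Each position holds its right neighbor's value integer-divided by 4;
  # rows are filled right-to-left, and once the value reaches 0, a fresh
  # one is drawn uniformly from [0, vocab_size - 3]. E.g., read right to
  # left: 13 -> 3 -> 0 -> resample.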
dataset = np.zeros((dataset_size, seq_len), dtype=int)
# tokens representing sequence boundary
dataset[:, 0] = vocab_size - 2 # bos
dataset[:, -1] = vocab_size - 1 # eos
for i in range(dataset_size):
# sample from 0, 1, ..., vocab_size - 3
temp = np.random.randint(vocab_size - 2)
for j in reversed(range(1, seq_len - 1)):
dataset[i, j] = temp
if temp != 0:
temp = temp // 4
else:
temp = np.random.randint(vocab_size - 2)
return dataset
def generate_synthetic_dataset(train_dataset_size,
validation_dataset_size,
seq_len, vocab_size):
np.random.seed(42)
train_data = torch.from_numpy(
_generate_synthetic_data(train_dataset_size,
seq_len, vocab_size))
train_dataset = datasets.Dataset.from_dict({
'input_ids': train_data,
'attention_mask': torch.ones_like(train_data),
})
train_dataset.set_format(type='torch')
np.random.seed(41)
validation_data = torch.from_numpy(
_generate_synthetic_data(validation_dataset_size,
seq_len, vocab_size))
validation_dataset = datasets.Dataset.from_dict({
'input_ids': validation_data,
'attention_mask': torch.ones_like(validation_data),
})
validation_dataset.set_format(type='torch')
return {
'train': train_dataset,
'validation': validation_dataset,
}
class Text8Tokenizer(transformers.PreTrainedTokenizer):
def __init__(
self,
bos_token='[BOS]',
eos_token='[EOS]',
sep_token='[SEP]',
cls_token='[CLS]',
pad_token='[PAD]',
mask_token='[MASK]',
unk_token='[UNK]',
**kwargs):
self.characters = list('abcdefghijklmnopqrstuvwxyz ')
self._vocab_str_to_int = {
'[CLS]': 0,
'[SEP]': 1,
'[BOS]': 2,
'[EOS]': 3,
'[MASK]': 4,
'[PAD]': 5,
'[RESERVED]': 6,
'[UNK]': 7,
** {ch: i + 8 for i, ch in enumerate(self.characters)}}
self._vocab_int_to_str = {
v: k for k, v in self._vocab_str_to_int.items()}
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
unk_token=unk_token,
**kwargs)
@property
def vocab_size(self) -> int:
return len(self._vocab_str_to_int)
def _tokenize(self, text: str, **kwargs) -> typing.List[str]:
return list(text.lower())
def _convert_token_to_id(self, token: str) -> int:
return self._vocab_str_to_int.get(
token, self._vocab_str_to_int['[UNK]'])
def _convert_id_to_token(self, index: int) -> str:
return self._vocab_int_to_str[index]
def convert_tokens_to_string(self, tokens):
return ''.join(tokens)
def get_vocab(self) -> typing.Dict[str, int]:
return self._vocab_str_to_int
def get_lambada_test_dataset():
url = "https://openaipublic.blob.core.windows.net/gpt-2/data/lambada_test.jsonl"
def read_jsonl_to_list(url):
response = requests.get(url, stream=True)
data_list = []
# Process each line in the response content
for line in response.iter_lines(decode_unicode=True):
if line:
data = json.loads(line)
data_list.append(data)
return data_list
lambada_data = read_jsonl_to_list(url)
dataset = datasets.Dataset.from_list(lambada_data)
return dataset
def get_text8_dataset(cache_dir, max_seq_length=256,
drop_last=True, crop_train=False):
"""Adapted from:
https://github.com/google-research/google-research/blob/master/d3pm/text/datasets.py#L344
Args:
cache_dir: str, path to cache directory.
max_seq_length: int, maximum length of sequences.
(default: 256, as in D3PM codebase.)
drop_last: bool, whether to drop the last incomplete
batch. (default: True, as in D3PM codebase.)
crop_train: bool, whether to subsample contiguous
subsequences from training example. serves to
make sure transformer models with absolute position
embeddings do not have incorrect position-wise
marginals. (default: False, but necessary to match D3PM AR)
Returns:
    dataset: datasets.DatasetDict, with keys 'train',
      'validation', 'test'.
"""
url = 'http://mattmahoney.net/dc/text8.zip'
if not crop_train:
cache_dir = f'{cache_dir}/text8'
else:
cache_dir = f'{cache_dir}/text8-crop-train'
split_names = ['train', 'validation', 'test']
if not all([
utils.fsspec_exists(os.path.join(cache_dir, split))
for split in split_names
]):
# Check if raw data exists
raw_cache_dir = os.path.join(cache_dir, 'raw_data')
if not all([
utils.fsspec_exists(
os.path.join(raw_cache_dir, f'text8.{split}.txt'))
for split in split_names
]):
if not utils.fsspec_exists(
os.path.join(raw_cache_dir, 'text8.zip')):
utils.fsspec_mkdirs(raw_cache_dir, exist_ok=True)
LOGGER.info('Downloading text8 from URL {}.'.format(url))
with (urllib.request.urlopen(url) as in_stream,
open(os.path.join(raw_cache_dir, 'text8.zip'),
'wb') as out_file):
shutil.copyfileobj(in_stream, out_file)
with fsspec.open(
os.path.join(raw_cache_dir, 'text8.zip'),
'rb') as f:
rawdata = zipfile.ZipFile(f).read(
'text8').decode('utf-8')
# Splits taken from D3PM codebase
splits = {
'train': rawdata[:90000000],
'validation': rawdata[90000000: 95000000],
'test': rawdata[95000000:],
}
for split, data in splits.items():
_path = os.path.join(raw_cache_dir,
f'text8.{split}.txt')
with fsspec.open(_path, 'w') as f:
f.write(data)
else:
splits = {}
for split in split_names:
_path = os.path.join(raw_cache_dir,
f'text8.{split}.txt')
with fsspec.open(_path, 'r') as f:
splits[split] = f.read()
# Chunk and save as datasets.DatasetDict
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]
dataset_dict = {}
for k, v in splits.items():
if k == 'train' and crop_train == True:
chunk_size = 2 * max_seq_length
else:
chunk_size = max_seq_length
text = list(chunks(v, chunk_size))
if drop_last and len(text[-1]) < chunk_size:
text = text[:-1]
dataset_dict[k] = datasets.Dataset.from_dict({'text': text})
dataset = datasets.DatasetDict(dataset_dict)
dataset.save_to_disk(cache_dir)
else:
dataset = datasets.load_from_disk(cache_dir)
return dataset
def _group_texts(examples, block_size, bos, eos):
# Concatenate all texts.
concatenated_examples = list(itertools.chain(* examples['input_ids']))
total_length = len(concatenated_examples)
# TODO(yair): look into not dropping the remainder but rather padding it.
# We drop the small remainder, and if the total_length < block_size - 2
# we exclude this batch and return an empty dict.
# We could add padding if the model supported it instead of
# this drop, you can customize this part to your needs.
new_block_size = block_size - 2 # [BOS] and [EOS] to be added
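  # E.g. with block_size=1024, the concatenated texts are split into
  # chunks of 1022 raw tokens, each emitted as [BOS] + chunk + [EOS]
  # with an all-ones attention mask.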
total_length = (total_length // new_block_size) * new_block_size
# Split by chunks of max_len.
result = {}
_values = []
_attn_masks = []
for i in range(0, total_length, new_block_size):
_values.append(
[bos]
+ concatenated_examples[i : i + new_block_size]
+ [eos])
_attn_masks.append(torch.ones(block_size))
result['input_ids'] = _values
result['attention_mask'] = _attn_masks
return result
def get_dataset(dataset_name,
tokenizer,
wrap,
mode,
cache_dir,
insert_eos=True,
block_size=1024,
num_proc=len(os.sched_getaffinity(0)),
streaming=False,
revision : Optional[str]=None):
if dataset_name == 'cifar10':
assert mode in ('train', 'validation')
return DiscreteCIFAR10(cache_dir=cache_dir,
train=mode=='train')
eos_tag = ''
if not insert_eos:
eos_tag = '_eosFalse'
if wrap:
filename = f'{dataset_name}_{mode}_bs{block_size}_wrapped{eos_tag}.dat'
else:
filename = f'{dataset_name}_{mode}_bs{block_size}_unwrapped{eos_tag}.dat'
_path = os.path.join(cache_dir, filename)
if utils.fsspec_exists(_path):
LOGGER.info(f'Loading data from: {_path}')
return datasets.load_from_disk(_path).with_format('torch')
LOGGER.info(f'Generating new data at: {_path}')
LOGGER.info(f'{streaming=}')
crop_train = dataset_name == 'text8-crop'
if mode == 'train' and crop_train:
# double block size for sub-sampling
block_size *= 2
if dataset_name == 'wikitext103':
dataset = datasets.load_dataset(
'wikitext',
name='wikitext-103-raw-v1',
cache_dir=cache_dir,
revision=revision)
elif dataset_name == 'wikitext2':
dataset = datasets.load_dataset(
'wikitext',
name='wikitext-2-raw-v1',
cache_dir=cache_dir,
revision=revision)
elif dataset_name == 'ptb':
dataset = datasets.load_dataset(
'ptb_text_only',
cache_dir=cache_dir,
revision=revision)
elif dataset_name == 'lambada':
dataset = get_lambada_test_dataset()
elif dataset_name == 'text8':
assert wrap
assert revision is None
dataset = get_text8_dataset(
cache_dir, max_seq_length=block_size)
elif dataset_name == 'text8-crop':
assert revision is None
dataset = get_text8_dataset(
cache_dir, max_seq_length=block_size, crop_train=True)
elif dataset_name == 'openwebtext-train':
dataset = datasets.load_dataset(
'openwebtext',
split='train[:-100000]',
cache_dir=cache_dir,
revision=revision,
streaming=False,
num_proc=num_proc)
elif dataset_name == 'openwebtext-valid':
dataset = datasets.load_dataset(
'openwebtext',
split='train[-100000:]',
cache_dir=cache_dir,
revision=revision,
streaming=False,
num_proc=num_proc)
elif dataset_name == 'scientific_papers_arxiv':
dataset = datasets.load_dataset(
'scientific_papers', 'arxiv',
cache_dir=cache_dir,
streaming=streaming,
revision=revision)
elif dataset_name == 'scientific_papers_pubmed':
dataset = datasets.load_dataset(
'scientific_papers', 'pubmed',
cache_dir=cache_dir,
streaming=streaming,
revision=revision)
elif dataset_name == 'ag_news':
dataset = datasets.load_dataset(
'ag_news',
cache_dir=cache_dir,
streaming=streaming,
revision=revision)
elif dataset_name == 'synthetic':
assert streaming
assert wrap # i.e., no pad tokens
dataset = generate_synthetic_dataset(
train_dataset_size=100000,
validation_dataset_size=1024,
seq_len=32,
vocab_size=256,
)
else:
dataset = datasets.load_dataset(
dataset_name,
cache_dir=cache_dir,
streaming=streaming,
revision=revision)
if dataset_name in ['lambada', 'openwebtext-train',
'openwebtext-valid']:
data = dataset
else:
data = dataset[mode]
if dataset_name == 'synthetic':
# already tokenized, no further actions required
return data
if dataset_name.startswith('wikitext'):
detokenizer = wt_detokenizer
elif dataset_name == 'ptb':
detokenizer = ptb_detokenizer
elif dataset_name == 'lm1b':
detokenizer = lm1b_detokenizer
elif dataset_name == 'lambada':
detokenizer = lambada_detokenizer
elif dataset_name.startswith('scientific_papers'):
detokenizer = scientific_papers_detokenizer
else:
detokenizer = None
def _apply_detokenizer(detokenizer):
def detok(text):
for i, t in enumerate(text, 0):
text[i] = detokenizer(t)
return text
return detok
EOS = tokenizer.encode(tokenizer.eos_token)[0]
BOS = tokenizer.encode(tokenizer.bos_token)[0]
def preprocess_and_tokenize(example):
if dataset_name == 'ptb':
text = example['sentence']
elif 'scientific_papers' in dataset_name:
text = example['article']
else:
text = example['text']
if detokenizer is not None:
text = _apply_detokenizer(detokenizer)(text)
tokenizer.padding_side = 'right'
tokenizer.truncation_side = 'right'
if wrap:
tokens = tokenizer(text,
add_special_tokens=False,
return_attention_mask=False,
return_token_type_ids=False)
if insert_eos:
tokens = {'input_ids':
[t + [EOS] for t in tokens['input_ids']]}
# Still missing BOS, but will be added in group_texts
else:
tokens = tokenizer(text,
max_length=block_size,
padding='max_length',
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_token_type_ids=True)
return tokens
if streaming:
tokenized_dataset = data.map(
preprocess_and_tokenize,
batched=True)
else:
tokenized_dataset = data.map(
preprocess_and_tokenize,
batched=True,
num_proc=num_proc,
load_from_cache_file=True,
desc='Tokenizing')
if dataset_name == 'ptb':
tokenized_dataset = tokenized_dataset.remove_columns(
'sentence')
elif 'scientific_papers' in dataset_name:
tokenized_dataset = tokenized_dataset.remove_columns([
'article', 'abstract', 'section_names'])
elif dataset_name == 'ag_news':
tokenized_dataset = tokenized_dataset.remove_columns(
['text', 'label'])
else:
tokenized_dataset = tokenized_dataset.remove_columns(
'text')
if not wrap:
if not streaming:
tokenized_dataset.save_to_disk(_path)
return tokenized_dataset.with_format('torch')
group_texts = functools.partial(
_group_texts, block_size=block_size, bos=BOS, eos=EOS)
if streaming:
chunked_dataset = tokenized_dataset.map(
group_texts,
batched=True)
else:
chunked_dataset = tokenized_dataset.map(
group_texts,
batched=True,
num_proc=num_proc,
load_from_cache_file=True,
desc='Grouping')
chunked_dataset.save_to_disk(_path)
chunked_dataset = chunked_dataset.with_format('torch')
return chunked_dataset
def get_tokenizer(config):
if config.data.tokenizer_name_or_path == 'text8':
tokenizer = Text8Tokenizer()
elif config.data.tokenizer_name_or_path == 'bert-base-uncased':
tokenizer = transformers.BertTokenizer.\
from_pretrained('bert-base-uncased')
elif config.data.tokenizer_name_or_path == 'synthetic':
tokenizer = SyntheticTokenizer(vocab_size=256)
elif config.data.tokenizer_name_or_path == 'cifar10':
return RawPixelsVisionTokenizer(
vocab_size=256, image_size=32, add_special_tokens=False,
add_mask_token='mdlm' in config.algo.name)
else:
tokenizer = transformers.AutoTokenizer.from_pretrained(
config.data.tokenizer_name_or_path)
if (isinstance(tokenizer, transformers.GPT2TokenizerFast)
or isinstance(tokenizer, transformers.GPT2Tokenizer)):
tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing(
(tokenizer.bos_token, tokenizer.bos_token_id),
(tokenizer.eos_token, tokenizer.eos_token_id))
# For wrapped batches:
# [BOS] sent1 [EOS] sent2-fragment [EOS]
# [BOS] sent2-fragment [EOS] sent3 [EOS]
if tokenizer.bos_token is None:
if tokenizer.cls_token is None:
raise AttributeError(
'Tokenizer must have a bos_token or '
f'cls_token: {tokenizer}')
tokenizer.bos_token = tokenizer.cls_token
if tokenizer.eos_token is None:
if tokenizer.sep_token is None:
raise AttributeError(
'Tokenizer must have a eos_token '
f'or sep_token: {tokenizer}')
tokenizer.eos_token = tokenizer.sep_token
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
return tokenizer
def get_dataloaders(config, tokenizer, skip_train=False,
skip_valid=False, valid_seed=None):
num_gpus = torch.cuda.device_count()
assert (config.loader.global_batch_size
== (config.loader.batch_size
* config.trainer.num_nodes
* num_gpus
* config.trainer.accumulate_grad_batches))
if config.loader.global_batch_size % (
num_gpus * config.trainer.accumulate_grad_batches) != 0:
raise ValueError(
      f'Global batch size {config.loader.global_batch_size} '
f'not divisible by {num_gpus} gpus with accumulation '
f'{config.trainer.accumulate_grad_batches}.')
if config.loader.eval_global_batch_size % num_gpus != 0:
raise ValueError(
      f'Eval global batch size {config.loader.eval_global_batch_size} '
f'not divisible by {num_gpus}.')
if skip_train:
train_set = None
else:
train_set = get_dataset(
config.data.train,
tokenizer,
mode='train',
wrap=config.data.wrap,
insert_eos=config.data.insert_train_eos,
cache_dir=config.data.cache_dir,
block_size=config.model.length,
streaming=config.data.streaming,
num_proc=config.loader.num_workers,
revision=config.data.get("train_revision", None))
if config.data.valid in ['text8', 'lm1b', 'ag_news']:
validation_split = 'test'
else:
validation_split = 'validation'
if skip_valid:
valid_set = None
else:
valid_set = get_dataset(
config.data.valid,
tokenizer,
wrap=config.data.wrap,
mode=validation_split,
cache_dir=config.data.cache_dir,
insert_eos=config.data.insert_valid_eos,
block_size=config.model.length,
streaming=config.data.streaming,
num_proc=config.loader.num_workers,
revision=config.data.get("valid_revision", None))
if skip_train:
train_loader = None
else:
train_loader = torch.utils.data.DataLoader(
train_set,
batch_size=config.loader.batch_size,
num_workers=config.loader.num_workers,
pin_memory=config.loader.pin_memory,
shuffle=not config.data.streaming,
persistent_workers=True)
train_loader.tokenizer = tokenizer
if skip_valid:
valid_loader = None
else:
if valid_seed is None:
shuffle_valid = False
generator = None
else:
shuffle_valid = True
generator = torch.Generator().manual_seed(valid_seed)
valid_loader = torch.utils.data.DataLoader(
valid_set,
batch_size=config.loader.eval_batch_size,
num_workers=config.loader.num_workers,
pin_memory=config.loader.pin_memory,
shuffle=shuffle_valid,
generator=generator)
# Will be used in generative perplexity calculation
valid_loader.tokenizer = tokenizer
return train_loader, valid_loader
# Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py
class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):
def __init__(self, *args, generator=None, **kwargs):
    # TD [2022-07-17]: We don't force the seed to be zero. We generate a random seed,
# which should be reproducible if pl.seed_everything was called beforehand.
# This means that changing the seed of the experiment will also change the
# sampling order.
if generator is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator().manual_seed(seed)
kwargs.pop('shuffle', None)
super().__init__(*args, generator=generator, **kwargs)
self.counter = 0
self.restarting = False
def state_dict(self):
return {'random_state': self.generator.get_state(),
'counter': self.counter}
def load_state_dict(self, state_dict):
self.generator.set_state(state_dict.get('random_state'))
self.counter = state_dict['counter']
# self.start_counter = self.counter
self.restarting = True
# TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
  # epoch, and subsequent epochs will have very few batches.
def __iter__(self) -> typing.Iterator[int]:
n = len(self.data_source)
self.state = self.generator.get_state()
indices = torch.randperm(n, generator=self.generator).tolist()
if not self.restarting:
self.counter = 0
else:
indices = indices[self.counter:]
self.restarting = False
for index in indices:
self.counter += 1
yield index
self.counter = 0
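# Usage sketch (hypothetical): persist `sampler.state_dict()` alongside the
# model checkpoint; after a restart, `sampler.load_state_dict(saved_state)`
# restores the RNG state and counter, so the next __iter__ skips the first
# `counter` indices instead of replaying the whole pass.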
class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.counter = 0
self.restarting = False
def state_dict(self):
return {'epoch': self.epoch, 'counter': self.counter}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
self.counter = state_dict['counter']
self.restarting = True
# TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
  # epoch, and subsequent epochs will have very few batches.
def __iter__(self):
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
else:
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
if not self.drop_last:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(
padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
if not self.restarting:
self.counter = 0
else:
indices = indices[self.counter:]
self.restarting = False
for index in indices:
self.counter += 1
yield index
self.counter = 0
================================================
FILE: discrete_diffusion_harness.py
================================================
import torch
from omegaconf import OmegaConf
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from datasets import Dataset
from tqdm import tqdm
import numpy as np
import algo
import dataloader
"""
# Instructions on running the eval
## What is the script doing?
The script evaluates (or approximates) the log-likelihood of
prefix + suffix. The script will load a checkpoint, and
depending on the config, use the corresponding class
(e.g. mdlm, duo, ar, etc).
- For MCQ, the log-likelihood of all continuations given the
same prefix is evaluated. The most likely continuation is
selected as the "correct" answer according to the model.
- For lambada_openai, the model is correct if the true
continuation is generated as the argmax. For diffusion, we
inject noise in the continuation, and check whether the true
answer computed in a single forward pass is the most likely.
This naturally favors AR models, as they run one forward
pass per token, while diffusion uses a single pass for all
tokens for simplicity. Therefore, I usually only compare on
MCQ since it is fairer.
To run the script, you need to install the lm-eval-harness package:
```
pip install git+https://github.com/EleutherAI/lm-evaluation-harness
```
## Tasks that we tested with
**MCQ**: arc_easy, arc_challenge, hellaswag, winogrande,
boolq, openbookqa, race, social_iqa, mathqa, piqa
**Likelihood based**: lambada_openai
## Batch size that fits at the small scale
- boolq -> 64
- openbookqa -> 256
- race -> 32
- social_iqa -> 512
- winogrande -> 64
- mathqa -> 64
- lambada_openai -> 64
- arc_easy -> 256
- arc_challenge -> 256
- hellaswag -> 64
- piqa -> 64
## Important flags
--trust_remote_code -> some datasets execute code when loading
from huggingface. Without this flag,
the script might crash.
--batch_size -> max. num elements to use in parallel
to eval the likelihood (for diffusion).
For simplicity, inputs are NOT padded
and batched for AR, though it should
be fairly easy to add.
--tasks -> one or multiple comma-separated tasks
to evaluate on.
--model_args -> string (without spaces) that contains
the arguments to pass to the evaluator
(stuff like checkpoints path, number
of MC samples to evaluate the
likelihood, path to the sentencepiece
tokenizer, etc).
--output_path -> path to a json file where the evaluation
results will be saved, instead of
only being printed in the terminal
(they will always be printed).
--limit -> debugging flag; limit the number of
examples to a fixed amount, instead
of using the whole dataset
## Example commands
### Run a single task with an AR model
python discrete_diffusion_harness.py \
--tasks arc_easy \
--batch_size 256 \
--model dlm \
--model_args checkpoint_path=/home/username/baselines/ar/1M.ckpt \
  --output_path ./harness_results/ar/1M/arc_easy.json
### Run a single task with an MDLM model, and 2048 MC samples to approximate the likelihood
python discrete_diffusion_harness.py \
--tasks arc_easy \
--model dlm \
--batch_size 256 \
--model_args checkpoint_path=/home/username/baselines/mdlm/1M.ckpt,num_mc_samples=2048 \
  --output_path ./harness_results/mdlm/1M/arc_easy.json
### Debug with 20 examples only
python discrete_diffusion_harness.py \
--tasks arc_easy \
--model dlm \
--batch_size 256 \
--model_args checkpoint_path=/home/username/baselines/mdlm/1M.ckpt,num_mc_samples=2048 \
--limit 20 \
  --output_path ./harness_results/mdlm/1M/arc_easy.json
### Run the benchmarks from "The Diffusion Duality (Chapter 2)":
(Arc-e, Arc-c, HSwag, WinoG, PIQA, MathQA, OQA)
-> just change the checkpoint to evaluate mdlm, ar, or duo
for task_config in "arc_easy 256" "arc_challenge 256" "hellaswag 64" "winogrande 64" "piqa 64" "mathqa 64" "openbookqa 256"; do
task=$(echo $task_config | cut -d' ' -f1)
batch_size=$(echo $task_config | cut -d' ' -f2)
python discrete_diffusion_harness.py \
--batch_size $batch_size \
--tasks $task \
--model dlm \
--model_args checkpoint_path=/path/to/checkpoint.ckpt \
--output_path ./harness_results/duo/$task.json
done
"""
def requests_to_dataset(config, requests, tokenizer, num_proc):
def _tokenize(e):
eos_idx = tokenizer.eos_token_id
bos_idx = tokenizer.bos_token_id
prefix_tokens = tokenizer(e['prefix'],
return_attention_mask=False,
add_special_tokens=False
)['input_ids']
target_tokens = tokenizer(e['target'],
return_attention_mask=False,
add_special_tokens=False
)['input_ids']
prefix_tokens = [bos_idx] + prefix_tokens
target_tokens = target_tokens + [eos_idx]
return {
'prefix_text': e['prefix'],
'target_text': e['target'],
'prefix': prefix_tokens,
'target': target_tokens,
}
ds = []
ds = [{'prefix': req.args[0], 'target': req.args[1]}
for req in requests]
ds = Dataset.from_list(ds)
ds = ds.map(_tokenize, num_proc=num_proc)
ds = ds.with_format('torch')
  seq_lengths = [len(x['prefix']) + len(x['target'])
                 for x in ds]
  num_larger = len([x for x in seq_lengths
                    if x > config.model.length])
if num_larger > 0:
print(f'\033[91mThere are some examples that are longer '
          f'than the context length; they will be ignored '
f'during evaluation. Number of such sequences: '
f'{num_larger}\033[0m')
return ds
def _eval_suffix_nll_generators(config, module, prefix,
suffix, batch_size, num_samples, loss_avg_mode):
device = module.device
assert num_samples % batch_size == 0
full_sentence = torch.cat([prefix, suffix], dim=-1
).repeat(batch_size, 1).to(module.device)
all_ts = module._sample_t(num_samples, accum_step=None)
for idx in range(0, num_samples, batch_size):
t = all_ts[idx:idx+batch_size].unsqueeze(-1)
dalpha_t, alpha_t = module.noise(t)
alpha_t = alpha_t.to(device)
sigma = module._sigma_from_alphat(alpha_t)
x0 = full_sentence.to(device)
# Inject noise
xt = module.q_xt(full_sentence, alpha_t).to(device)
if loss_avg_mode == 'full':
pass # nothing to do
elif loss_avg_mode == 'suffix':
xt[:, :len(prefix)] = prefix
# recompute alpha_t based on number of masked tokens,
# for conditioning of the backbone
alpha_t = (xt == x0).float().mean(dim=1)[:, None]
t = module.noise.get_t_for_alpha(alpha_t)
# We need to get dalpha_t for the loss:
alpha_t = alpha_t.to(device)
sigma = module._sigma_from_alphat(alpha_t)
else:
raise ValueError(loss_avg_mode)
yield xt, x0, t, sigma, alpha_t, dalpha_t
def eval_suffix_nll(config, module, prefix, suffix, batch_size,
num_samples, loss_avg_mode):
if config.algo.name in ('mdlm', 'duo', 'duo_base',
'distillation', 'ot-finetune'):
return eval_suffix_nll_diffusion(config, module, prefix,
suffix, batch_size, num_samples, loss_avg_mode)
elif config.algo.name == 'ar':
return eval_suffix_nll_ar(config, module, prefix,
suffix, batch_size, num_samples, loss_avg_mode)
else:
raise ValueError(config.algo.name)
def eval_suffix_nll_ar(config, module, prefix, suffix,
batch_size, num_samples, loss_avg_mode):
x_cat = torch.cat([prefix, suffix[:-1]], dim=-1)
x_cat = x_cat.reshape(1, -1).to(module.device)
with torch.amp.autocast('cuda', dtype=torch.float32):
out = module.backbone(x_cat, sigma=None)
out[:, :, module.mask_index] = module.neg_infinity
suffix_out = out[:, len(prefix) - 1:, :]
suffix_logits = torch.log_softmax(suffix_out, dim=-1)
index = suffix[None, :, None].to(module.device)
nll = torch.gather(-suffix_logits, dim=-1,
index=index).mean()
return float(nll.cpu())
def eval_suffix_nll_diffusion(config, module, prefix, suffix,
batch_size, num_samples, loss_avg_mode):
all_losses = []
generator = _eval_suffix_nll_generators(config, module,
prefix, suffix, batch_size, num_samples, loss_avg_mode)
for xt, x0, t, sigma, alpha_t, dalpha_t in generator:
log_x_theta = module(xt, sigma, labels=None)
token_nll = module.nll_per_token(log_x_theta, xt, x0,
alpha_t, dalpha_t)
if loss_avg_mode == 'full':
loss = float(token_nll.mean())
elif loss_avg_mode == 'suffix':
loss = float(token_nll[:, len(prefix):].mean())
all_losses.append(loss)
return float(np.mean(all_losses))
@register_model("dlm")
class DiscreteDiffusionHarness(LM):
def __init__(self, pretrained="NONE", max_length=1024,
num_mc_samples=1024, batch_size=64, device="cuda",
checkpoint_path=None, num_proc=8, loss_avg_mode='full',
*args, **kwargs):
super().__init__()
    # Whether to average the NLL over the full sequence or the
    # suffix only; 'full' should be the correct way.
assert loss_avg_mode in ('full', 'suffix')
ckpt = torch.load(checkpoint_path, map_location='cpu',
weights_only=False)
config = ckpt['hyper_parameters']['config']
# Backfill missing keys into legacy checkpoints
if not hasattr(config.training, 'class_dropout_p'):
OmegaConf.set_struct(config, False)
config.training.class_dropout_p = 0.0
OmegaConf.set_struct(config, True)
self.tokenizer = dataloader.get_tokenizer(config)
if config.algo.name == 'mdlm':
self.model = algo.MDLM(config, self.tokenizer)
elif config.algo.name in ('duo', 'duo_base',
'distillation', 'ot-finetune'):
self.model = algo.DUO_BASE(config, self.tokenizer)
elif config.algo.name == 'ar':
self.model = algo.AR(config, self.tokenizer)
else:
raise ValueError(f'Implement for {config.algo.name}')
self.config = config
self.num_proc = num_proc
self.num_mc_samples = num_mc_samples
self.batch_size = int(batch_size)
self.device = device
self.loss_avg_mode = loss_avg_mode
self.model.load_state_dict(ckpt['state_dict'])
self.model.to(device)
self.model.eval()
def suffix_greedy_prediction(self, prefix, target):
if self.config.algo.name == 'mdlm':
return self._suffix_greedy_prediction_mdlm(prefix,
target)
elif self.config.algo.name in ('duo', 'duo_base',
'distillation', 'ot-finetune'):
return self._suffix_greedy_prediction_duo_base(prefix,
target)
elif self.config.algo.name == 'ar':
return self._suffix_greedy_prediction_ar(prefix, target)
else:
raise ValueError(self.config.algo.name)
def _suffix_greedy_prediction_ar(self, prefix, target):
x_cat = torch.cat([prefix, target[:-1]],
dim=-1).reshape(1, -1).to(self.device)
# Follows generate_samples in AR (algo.py)
out = self.model.backbone(x_cat, sigma=None)
out[:, :, self.model.mask_index] = self.model.neg_infinity
out = out.log_softmax(-1)
preds_suffix = out[:, len(prefix) - 1:, :]
greedy_preds = preds_suffix.argmax(-1).flatten()
return (greedy_preds.cpu() == target).all().item()
def _suffix_greedy_prediction_mdlm(self, prefix, target):
mask_idx = self.model.mask_index
eos_idx = self.tokenizer.eos_token_id
# Note: because of the preprocessing, we know that the
# last token is an eos token.
noisy_target = [mask_idx] * (len(target) - 1) + [eos_idx]
noisy_target = torch.tensor(noisy_target,
device=self.device)
prefix = prefix.to(self.device)
seq = torch.concatenate([prefix, noisy_target],
dim=-1).reshape(1, -1)
sigma = torch.zeros(size=(seq.shape[0], 1),
device=self.device)
logits = self.model(seq, sigma, labels=None)
assert logits.shape[0] == 1
suffix_logits = logits[0, len(prefix):]
target_preds = suffix_logits.argmax(-1).cpu()
correct = target_preds == target
correct = correct.all()
return bool(correct)
def _suffix_greedy_prediction_duo_base(self, prefix, target):
noisy_suffix = torch.randint(
0, self.model.vocab_size, size=target.shape,
dtype=target.dtype, device=self.device)
prefix = prefix.to(self.device)
# shape: (1, len)
noisy_seq = torch.concatenate([prefix, noisy_suffix])[None, :]
# Set the EOS token at the end
noisy_seq[0, -1] = target[-1]
clean_seq = torch.concatenate([prefix, target.to(self.device)])[None, :]
alpha_t = (noisy_seq == clean_seq).float().mean(dim=1)[:, None]
sigma = self.model._sigma_from_alphat(alpha_t)
t = self.model.noise.get_t_for_alpha(alpha_t)
logits = self.model(noisy_seq, sigma, labels=None)
assert logits.shape[0] == 1
suffix_logits = logits[0, len(prefix):]
target_preds = suffix_logits.argmax(-1).cpu()
correct = target_preds == target
correct = correct.all()
return bool(correct)
@torch.no_grad()
def loglikelihood(self, requests: list[Instance]) \
-> list[tuple[float, bool]]:
dataset = requests_to_dataset(self.config, requests,
self.tokenizer, self.num_proc)
out = []
for elem in tqdm(dataset, 'Computing likelihood...'):
prefix = elem['prefix']
target = elem['target']
if len(prefix) + len(target) > self.model.config.model.length:
# If the request is too long, skip it.
ll = 0.0
is_target_greedy_dec = False
out.append((ll, is_target_greedy_dec))
print("SKIPPING")
continue
ll = -eval_suffix_nll(self.config, self.model, prefix,
target, self.batch_size, self.num_mc_samples,
self.loss_avg_mode)
is_target_greedy_dec = self.suffix_greedy_prediction(
prefix, target)
out.append((ll, bool(is_target_greedy_dec)))
return out
def loglikelihood_rolling(
self, requests: list[Instance]
) -> list[tuple[float, bool]]:
raise NotImplementedError
def generate_until(self, context, max_length, stop,
**generation_kwargs):
raise NotImplementedError
if __name__ == "__main__":
cli_evaluate()
================================================
FILE: main.py
================================================
import json
import os
import fsspec
import hydra
import lightning as L
from lightning.fabric import Fabric
import omegaconf
import rich.syntax
import rich.tree
import torch
from torch.utils.data.distributed import DistributedSampler
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from tqdm import tqdm, trange
import algo
import dataloader
import utils
omegaconf.OmegaConf.register_new_resolver(
'cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver(
'device_count', torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver(
'eval', eval)
omegaconf.OmegaConf.register_new_resolver(
'div_up', lambda x, y: (x + y - 1) // y)
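# e.g. div_up(512, 8) == 64 and div_up(10, 3) == 4 (ceiling division)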
def _load_from_checkpoint(diffusion_model, config, tokenizer):
if 'hf' in config.algo.backbone:
return diffusion_model(
config, tokenizer=tokenizer).to('cuda')
return diffusion_model.load_from_checkpoint(
config.eval.checkpoint_path,
tokenizer=tokenizer,
config=config)
@L.pytorch.utilities.rank_zero_only
def _print_config(
config: omegaconf.DictConfig,
resolve: bool = True,
save_cfg: bool = True) -> None:
"""Prints content of DictConfig using Rich library and its tree structure.
Args:
config (DictConfig): Configuration composed by Hydra.
resolve (bool): Whether to resolve reference fields of DictConfig.
save_cfg (bool): Whether to save the configuration tree to a file.
"""
style = 'dim'
tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)
fields = config.keys()
for field in fields:
branch = tree.add(field, style=style, guide_style=style)
config_section = config.get(field)
branch_content = str(config_section)
if isinstance(config_section, omegaconf.DictConfig):
branch_content = omegaconf.OmegaConf.to_yaml(
config_section, resolve=resolve)
branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
rich.print(tree)
if save_cfg:
with fsspec.open(
'{}/config_tree.txt'.format(
config.checkpointing.save_dir), 'w') as fp:
rich.print(tree, file=fp)
@L.pytorch.utilities.rank_zero_only
def _print_batch(config, train_ds, valid_ds, tokenizer, k=64):
for dl_type, dl in [
('train', train_ds), ('valid', valid_ds)]:
print(f'Printing {dl_type} dataloader batch.')
batch = next(iter(dl))
print('Batch input_ids.shape', batch['input_ids'].shape)
if config.data.modality == 'text':
first = batch['input_ids'][0, :k]
last = batch['input_ids'][0, -k:]
print(f'First {k} tokens:', tokenizer.decode(first))
print('ids:', first)
print(f'Last {k} tokens:', tokenizer.decode(last))
print('ids:', last)
def _generate_samples(diffusion_model, config, logger,
tokenizer):
logger.info('Starting Sample Eval.')
model = _load_from_checkpoint(
diffusion_model=diffusion_model,
config=config,
tokenizer=tokenizer)
model.metrics.gen_ppl.reset()
model.metrics.sample_entropy.reset()
if config.eval.disable_ema:
logger.info('Disabling EMA.')
model.ema = None
stride_length = config.sampling.stride_length
num_strides = config.sampling.num_strides
all_samples = []
for _ in trange(config.sampling.num_sample_batches):
if config.sampling.semi_ar:
_, intermediate_samples, _ = model.restore_model_and_semi_ar_sample(
stride_length=stride_length,
num_strides=num_strides,
dt=1 / config.sampling.steps)
text_samples = intermediate_samples[-1]
      # Note: Samples generated using the semi-ar method
      # need to be processed before computing generative perplexity
# since these samples contain numerous <|endoftext|> tokens
# and diffusion.compute_generative_perplexity() discards
# any text after the first EOS token.
else:
samples = model.restore_model_and_sample(
num_steps=config.sampling.steps)
model.metrics.record_entropy(samples)
text_samples = model.tokenizer.batch_decode(samples)
model.metrics.record_generative_perplexity(
text_samples, config.model.length, model.device)
all_samples.extend(list(text_samples))
generative_ppl = 0.
entropy = 0.
if not config.sampling.semi_ar:
generative_ppl = model.metrics.gen_ppl.compute().item()
entropy = model.metrics.sample_entropy.compute().item()
logger.info(f'Generative perplexity: {generative_ppl}')
logger.info(f'Sample entropy: {entropy}')
samples_path = config.eval.generated_samples_path
with fsspec.open(samples_path, 'w') as f:
json.dump({'generative_ppl': generative_ppl,
'entropy': entropy,
'generated_seqs': all_samples}, f, indent=4)
logger.info(f'Samples saved at: {samples_path}',)
def _eval_ppl(diffusion_model, config, logger, tokenizer):
logger.info('Starting Perplexity Eval.')
model = _load_from_checkpoint(
diffusion_model=diffusion_model,
config=config,
tokenizer=tokenizer)
if config.eval.disable_ema:
logger.info('Disabling EMA.')
model.ema = None
wandb_logger = None
if config.get('wandb', None) is not None:
wandb_logger = L.pytorch.loggers.WandbLogger(
config=omegaconf.OmegaConf.to_object(config),
** config.wandb)
callbacks = []
if 'callbacks' in config:
for _, callback in config.callbacks.items():
callbacks.append(hydra.utils.instantiate(callback))
trainer = hydra.utils.instantiate(
config.trainer,
default_root_dir=os.getcwd(),
callbacks=callbacks,
strategy=hydra.utils.instantiate(config.strategy),
logger=wandb_logger)
_, valid_ds = dataloader.get_dataloaders(
config, tokenizer, skip_train=True, valid_seed=config.seed)
trainer.validate(model, valid_ds)
def _train(diffusion_model, config, logger, tokenizer):
logger.info('Starting Training.')
wandb_logger = None
if config.get('wandb', None) is not None:
wandb_logger = L.pytorch.loggers.WandbLogger(
config=omegaconf.OmegaConf.to_object(config),
**config.wandb)
if (config.checkpointing.resume_from_ckpt
and config.checkpointing.resume_ckpt_path is not None
and utils.fsspec_exists(
config.checkpointing.resume_ckpt_path)):
ckpt_path = config.checkpointing.resume_ckpt_path
else:
ckpt_path = None
# Lightning callbacks
callbacks = []
if 'callbacks' in config:
for _, callback in config.callbacks.items():
callbacks.append(hydra.utils.instantiate(callback))
train_ds, valid_ds = dataloader.get_dataloaders(
config, tokenizer)
_print_batch(config, train_ds, valid_ds, tokenizer)
if config.training.finetune_path != '':
assert utils.fsspec_exists(config.training.finetune_path)
model = diffusion_model.load_from_checkpoint(
config.training.finetune_path,
tokenizer=tokenizer,
config=config)
else:
model = diffusion_model(config, tokenizer=valid_ds.tokenizer)
trainer = hydra.utils.instantiate(
config.trainer,
default_root_dir=os.getcwd(),
callbacks=callbacks,
strategy=hydra.utils.instantiate(config.strategy),
logger=wandb_logger)
trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path)
def _eval_fid(diffusion_model, config, logger, tokenizer):
logger.info('Preparing data and model for FID eval.')
fabric = Fabric(accelerator=config.trainer.accelerator,
devices=config.trainer.devices,
num_nodes=config.trainer.num_nodes)
fabric.launch()
seed = config.seed + fabric.global_rank
L.seed_everything(seed)
print(f'(Rank {fabric.global_rank}): seed: {seed}')
model = _load_from_checkpoint(
diffusion_model=diffusion_model,
config=config,
tokenizer=tokenizer)
model.to(fabric.device)
if config.eval.disable_ema:
logger.info('Disabling EMA.')
model.ema = None
model._eval_mode()
assert config.data.train == 'cifar10', \
'FID eval only implemented for CIFAR-10'
# Like in flow matching papers: FID against train
loader, _ = dataloader.get_dataloaders(config,
tokenizer=tokenizer, skip_valid=True)
sampler = DistributedSampler(
loader.dataset,
num_replicas=fabric.world_size,
rank=fabric.global_rank,
shuffle=False)
loader = torch.utils.data.DataLoader(
loader.dataset,
batch_size=config.loader.eval_batch_size,
sampler=sampler,
num_workers=loader.num_workers if hasattr(loader, 'num_workers') else 0,
pin_memory=getattr(loader, 'pin_memory', False))
  # Check that each GPU generates the same number of images
assert len(loader) == len(loader.dataset) // loader.batch_size // fabric.world_size, \
f'{len(loader)=}, {len(loader.dataset)=}, {loader.batch_size=}, {fabric.world_size=}'
fid_calculator = FrechetInceptionDistance(
normalize=False).to(fabric.device)
is_calculator = InceptionScore(
normalize=False).to(fabric.device)
desc = f'(Rank {fabric.global_rank}) Sampling...'
for batch in tqdm(loader, desc=desc):
real_samples = batch['input_ids']
# Generate images with labels matching the true data
labels = batch['labels']
gen_samples = model.generate_samples(
num_samples=real_samples.shape[0],
num_steps=config.sampling.steps,
labels=labels)
# Reshape 1D seq -> 2D image
gen_samples = model.tokenizer.batch_decode(gen_samples)
real_samples = model.tokenizer.batch_decode(
real_samples).to(fabric.device)
fid_calculator.update(gen_samples, real=False)
fid_calculator.update(real_samples, real=True)
is_calculator.update(gen_samples)
fabric.barrier()
logger.info('Done sampling. Computing FID & IS...')
fid = fid_calculator.compute()
incep_score = is_calculator.compute()
if fabric.global_rank == 0:
logger.info(f'FID: {fid}')
logger.info(f'IS: {incep_score}')
fabric.barrier()
@hydra.main(version_base=None, config_path='configs',
config_name='config')
def main(config):
"""Main entry point for training."""
L.seed_everything(config.seed)
_print_config(config, resolve=True, save_cfg=True)
logger = utils.get_logger(__name__)
tokenizer = dataloader.get_tokenizer(config)
if config.algo.name == 'ar':
diffusion_model = algo.AR
elif config.algo.name == 'mdlm':
diffusion_model = algo.MDLM
elif config.algo.name == 'duo_base':
diffusion_model = algo.DUO_BASE
elif config.algo.name == 'd3pm':
diffusion_model = algo.D3PMAbsorb
elif config.algo.name == 'sedd':
diffusion_model = algo.SEDDAbsorb
elif config.algo.name == 'duo':
diffusion_model = algo.DUO
elif config.algo.name == 'distillation':
diffusion_model = algo.Distillation
elif config.algo.name == 'ot-finetune':
diffusion_model = algo.OptimalTransportFinetune
else:
raise ValueError(
f'Invalid algorithm name: {config.algo.name}')
kwargs = {'diffusion_model': diffusion_model,
'config': config,
'tokenizer': tokenizer,
'logger': logger}
if config.mode == 'sample_eval':
_generate_samples(**kwargs)
elif config.mode == 'ppl_eval':
_eval_ppl(**kwargs)
elif config.mode == 'fid_eval':
_eval_fid(**kwargs)
else:
_train(**kwargs)
if __name__ == '__main__':
main()
================================================
FILE: metrics.py
================================================
import math
import os
import typing
import torch
import torch.nn.functional as F
import torchmetrics
import transformers
LOG2 = math.log(2)
class NLL(torchmetrics.aggregation.MeanMetric):
def update(self,
value:typing.Union[float, torch.Tensor],
weight:typing.Union[float, torch.Tensor]=1.0) -> None:
"""Update state with data.
Args:
value: Either a float or tensor containing data.
Additional tensor dimensions will be flattened
weight: Either a float or tensor containing weights
for calculating the average. Shape of weight should
be able to broadcast with the shape of `value`.
Defaults to `1.0`, corresponding to a simple
(unweighted) average.
"""
# broadcast weight to value shape
if not isinstance(value, torch.Tensor):
value = torch.as_tensor(value, dtype=self.dtype,
device=self.device)
if (weight is not None
and not isinstance(weight, torch.Tensor)):
weight = torch.as_tensor(weight,
dtype=self.dtype,
device=self.device)
weight = torch.broadcast_to(weight, value.shape)
value, weight = self._cast_and_nan_check_input(value,
weight)
if value.numel() == 0:
return
self.mean_value += value.sum()
self.weight += weight.sum()
class BPD(NLL):
def compute(self) -> torch.Tensor:
"""Computes the bits per dimension.
Returns:
bpd
"""
return self.mean_value / self.weight / LOG2
class Perplexity(NLL):
def compute(self) -> torch.Tensor:
"""Computes the Perplexity.
Returns:
Perplexity
"""
return torch.exp(self.mean_value / self.weight)
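# Minimal usage sketch (illustrative values, not part of the training
# loop): all three metrics share the same token-weighted running mean,
# so BPD = mean / ln 2 and PPL = exp(mean).
# >>> ppl = Perplexity()
# >>> ppl.update(torch.tensor([2.0, 4.0]))
# >>> ppl.compute()  # exp((2.0 + 4.0) / 2) = exp(3.0) ~ 20.09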
class Metrics:
def __init__(self, gen_ppl_eval_model_name_or_path=None,
eval_ppl_batch_size=None) -> None:
metrics = torchmetrics.MetricCollection({
'nll': NLL(), 'bpd': BPD(), 'ppl': Perplexity()})
metrics.set_dtype(torch.float64)
self.train_nlls = metrics.clone(prefix='train/')
self.train_aux = BPD()
self.valid_nlls = metrics.clone(prefix='val/')
self.valid_aux = BPD()
self.gen_ppl = Perplexity()
self.sample_entropy = torchmetrics.aggregation.MeanMetric()
self.eval_ppl_batch_size = eval_ppl_batch_size
self.gen_ppl_eval_model_name_or_path = gen_ppl_eval_model_name_or_path
self.tokenizer = transformers.AutoTokenizer.\
from_pretrained(gen_ppl_eval_model_name_or_path)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
def to(self, *args, **kwargs):
self.gen_ppl = self.gen_ppl.to(*args, **kwargs)
self.sample_entropy = self.sample_entropy.to(*args, **kwargs)
self.train_nlls = self.train_nlls.to(*args, **kwargs)
self.train_aux = self.train_aux.to(*args, **kwargs)
self.valid_nlls = self.valid_nlls.to(*args, **kwargs)
self.valid_aux = self.valid_aux.to(*args, **kwargs)
def reset(self):
self.gen_ppl.reset()
self.sample_entropy.reset()
self.train_nlls.reset()
self.train_aux.reset()
self.valid_nlls.reset()
self.valid_aux.reset()
def update_train(self, nll, aux_loss, num_tokens):
self.train_nlls.update(nll, num_tokens)
self.train_aux.update(aux_loss, num_tokens)
def update_valid(self, nll, aux_loss, num_tokens):
self.valid_nlls.update(nll, num_tokens)
self.valid_aux.update(aux_loss, num_tokens)
@torch.no_grad()
def _eval_retokenize(self, text_samples, max_length,
device):
"""Retokenizes samples for the eval model.
Args:
text_samples: List of sentences generated by the model.
max_length: Maximum sequence length for the eval tokenizer.
device: Device for the tokenized batch (non-llama2 models).
Returns:
samples: Samples re-tokenized for the eval model
attn_mask: Attention mask for the eval model
eval_context_size: Size of the context for the eval model
"""
if 'llama2' in self.gen_ppl_eval_model_name_or_path:
tokenizer_kwargs = {
'return_tensors': 'pt',
'return_token_type_ids': False,
'return_attention_mask': True,
'truncation': True,
'padding': True,
'max_length': max_length,
}
eval_context_size = 4096
else:
tokenizer_kwargs = {
'return_tensors': 'pt',
'return_token_type_ids': False,
'return_attention_mask': True,
'truncation': True,
'padding': True,
'max_length': max_length,
}
eval_context_size = 1024
samples = self.tokenizer(text_samples,
**tokenizer_kwargs)
attn_mask = samples['attention_mask']
samples = samples['input_ids']
if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
attn_mask = attn_mask.to(device)
samples = samples.to(device)
return samples, attn_mask, eval_context_size
@torch.no_grad()
def record_entropy(self, tokens):
for sample in tokens:
_, counts = torch.unique(
sample, return_counts=True, sorted=False)
entropy = torch.special.entr(
counts.float() / counts.sum()).sum().item()
self.sample_entropy.update(entropy)
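# Sketch of the statistic above: torch.special.entr(p) = -p * log(p), so
# a sample whose tokens are spread uniformly over k distinct ids scores
# entropy log(k); e.g. counts [2, 2] -> p = [0.5, 0.5] -> entropy = ln 2.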
@torch.no_grad()
def record_generative_perplexity(
self,
text_samples: typing.List[str],
max_length: int,
retokenize: bool = True,
device='cuda') -> None:
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
eval_model = transformers.AutoModelForCausalLM.from_pretrained(
self.gen_ppl_eval_model_name_or_path).eval()
if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
eval_model = eval_model.to(device)
# Re-tokenize using eval model's tokenizer
if retokenize:
(samples, attn_mask,
eval_context_size) = self._eval_retokenize(
text_samples, max_length=max_length, device=device)
else:
samples = text_samples
attn_mask = torch.ones(samples.shape).to(device)
eval_context_size = samples.shape[-1]
batch_size = min(self.eval_ppl_batch_size,
samples.shape[0])
num_batches = samples.shape[0] // batch_size
for i in range(num_batches):
_samples = torch.split(
samples[i * batch_size: (i + 1) * batch_size],
eval_context_size,
dim=-1)
_attn_mask = torch.split(
attn_mask[i * batch_size: (i + 1) * batch_size],
eval_context_size,
dim=-1)
for (sample_chunk, attn_mask_chunk) in zip(_samples,
_attn_mask):
logits = eval_model(sample_chunk,
attention_mask=attn_mask_chunk)
logits = logits[0].transpose(-1, -2)
nlls = F.cross_entropy(logits[..., :-1],
sample_chunk[..., 1:],
reduction='none')
first_eos = (
sample_chunk
== self.tokenizer.eos_token_id).cumsum(-1) == 1
token_mask = sample_chunk != self.tokenizer.eos_token_id
valid_tokens = first_eos[..., 1:] + token_mask[..., 1:]
self.gen_ppl.update(nlls * valid_tokens, valid_tokens)
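# valid_tokens above weights every non-EOS position plus the first EOS,
# so for samples padded with EOS after generation ends, the padding does
# not contribute to the generative-perplexity estimate.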
================================================
FILE: models/__init__.py
================================================
from . import dit
from . import ema
from . import unet
================================================
FILE: models/dit.py
================================================
import math
import typing
import einops
import flash_attn
import flash_attn.layers.rotary
import huggingface_hub
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
# Flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
def bias_dropout_add_scale(
x: torch.Tensor,
bias: typing.Optional[torch.Tensor],
scale: torch.Tensor,
residual: typing.Optional[torch.Tensor],
prob: float,
training: bool) -> torch.Tensor:
if bias is not None:
out = scale * F.dropout(x + bias, p=prob, training=training)
else:
out = scale * F.dropout(x, p=prob, training=training)
if residual is not None:
out = residual + out
return out
def get_bias_dropout_add_scale(training):
def _bias_dropout_add(x, bias, scale, residual, prob):
return bias_dropout_add_scale(
x, bias, scale, residual, prob, training)
return _bias_dropout_add
# function overload
def modulate(x: torch.Tensor,
shift: torch.Tensor,
scale: torch.Tensor) -> torch.Tensor:
return x * (1 + scale) + shift
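# With the zero-initialized adaLN projections used further below,
# shift == scale == 0 at the start of training, so modulate(x, 0, 0) == x
# and every block begins as an identity mapping.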
@torch.jit.script
def bias_dropout_add_scale_fused_train(
x: torch.Tensor,
bias: typing.Optional[torch.Tensor],
scale: torch.Tensor,
residual: typing.Optional[torch.Tensor],
prob: float) -> torch.Tensor:
return bias_dropout_add_scale(
x, bias, scale, residual, prob, True)
@torch.jit.script
def bias_dropout_add_scale_fused_inference(
x: torch.Tensor,
bias: typing.Optional[torch.Tensor],
scale: torch.Tensor,
residual: typing.Optional[torch.Tensor],
prob: float) -> torch.Tensor:
return bias_dropout_add_scale(
x, bias, scale, residual, prob, False)
@torch.jit.script
def modulate_fused(x: torch.Tensor,
shift: torch.Tensor,
scale: torch.Tensor) -> torch.Tensor:
return modulate(x, shift, scale)
class Rotary(torch.nn.Module):
def __init__(self, dim, base=10_000):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
def forward(self, x, seq_dim=1):
seq_len = x.shape[seq_dim]
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
# dims are: batch, seq_len, qkv, head, dim
self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1)
self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1)
# This makes the transformation on v an identity.
self.cos_cached[:,:,2,:,:].fill_(1.)
self.sin_cached[:,:,2,:,:].fill_(0.)
return self.cos_cached, self.sin_cached
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
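# E.g. for the two halves (x1, x2) of the last dim, rotate_half returns
# (-x2, x1); combined with the cached cos/sin tables this applies the 2D
# rotation that defines rotary position embeddings.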
def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin):
with torch.amp.autocast('cuda', enabled=False):
cos, sin = rotary_cos_sin
cos = cos.to(qkv.dtype)
sin = sin.to(qkv.dtype)
cos = cos[0,:,0,0,:cos.shape[-1]//2]
sin = sin[0,:,0,0,:sin.shape[-1]//2]
q, k, v = qkv.chunk(3, dim=2)
q = flash_attn.layers.rotary.apply_rotary_emb_torch(
q.squeeze(dim=2), cos, sin)
k = flash_attn.layers.rotary.apply_rotary_emb_torch(
k.squeeze(dim=2), cos, sin)
v = v.squeeze(dim=2)
return q, k, v
def apply_rotary_pos_emb(qkv, cos, sin):
cos = cos[0,:,0,0,:cos.shape[-1]//2]
sin = sin[0,:,0,0,:sin.shape[-1]//2]
return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
def regular_attention_multi_headed(q, k, v):
# q, k, v each have shape [batch, seq_len, num_heads, head_dim];
# they arrive already split out of the packed qkv tensor.
attention_output = F.scaled_dot_product_attention(
query=q.transpose(1, 2),
key=k.transpose(1, 2),
value=v.transpose(1, 2),
attn_mask=None,
dropout_p=0.0,
is_causal=False)
# [batch_size, seq_len, num_heads, head_dim]
attention_output = attention_output.transpose(1, 2)
return einops.rearrange(attention_output, 'b s h d -> b s (h d)')
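# Shape sketch (illustrative): q, k, v each [batch, seq, heads, head_dim]
# in, merged heads out.
# >>> q = k = v = torch.randn(2, 8, 4, 16)
# >>> regular_attention_multi_headed(q, k, v).shape
# torch.Size([2, 8, 64])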
#################################################################################
# Layers #
#################################################################################
class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.weight = nn.Parameter(torch.ones([dim]))
self.dim = dim
def forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
x = F.layer_norm(x.float(), [self.dim])
return x * self.weight[None, None, :]
def residual_linear(x, W, x_skip, residual_scale):
"""x_skip + residual_scale * W @ x"""
dim_out, dim_in = W.shape[0], W.shape[1]
return torch.addmm(
x_skip.view(-1, dim_out),
x.view(-1, dim_in),
W.T,
alpha=residual_scale).view(*x.shape[:-1], dim_out)
#################################################################################
# Embedding Layers for Timesteps and Class Labels #
#################################################################################
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True))
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
- math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
/ half)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding,
torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
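# Sketch: a batch of N (possibly fractional) timesteps maps to an (N, dim)
# sinusoidal table before the MLP, cos terms first, sin terms second.
# >>> TimestepEmbedder.timestep_embedding(torch.tensor([0.0, 1.0]), 8).shape
# torch.Size([2, 8])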
class LabelEmbedder(nn.Module):
"""Embeds class labels into vector representations.
Also handles label dropout for classifier-free guidance.
"""
def __init__(self, num_classes, cond_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
self.num_classes = num_classes
# TODO think of initializing with 0.02 std deviation like in original DiT paper
def forward(self, labels):
embeddings = self.embedding_table(labels)
return embeddings
#################################################################################
# Core Model #
#################################################################################
class DDiTBlockCausal(nn.Module):
def __init__(self, dim, n_heads, mlp_ratio=4, dropout=0.1):
super().__init__()
self.n_heads = n_heads
self.norm1 = LayerNorm(dim)
self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
self.attn_out = nn.Linear(dim, dim, bias=False)
self.dropout1 = nn.Dropout(dropout)
self.norm2 = LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_ratio * dim, bias=True),
nn.GELU(approximate='tanh'),
nn.Linear(mlp_ratio * dim, dim, bias=True))
self.dropout2 = nn.Dropout(dropout)
self.dropout = dropout
def _get_bias_dropout_scale(self):
if self.training:
return bias_dropout_add_scale_fused_train
else:
return bias_dropout_add_scale_fused_inference
def forward(self, x, rotary_cos_sin, **kwargs):
del kwargs
batch_size, seq_len = x.shape[0], x.shape[1]
bias_dropout_scale_fn = self._get_bias_dropout_scale()
# attention operation
x_skip = x
x = self.norm1(x)
qkv = self.attn_qkv(x)
qkv = einops.rearrange(
qkv,
'b s (three h d) -> b s three h d',
three=3,
h=self.n_heads)
with torch.amp.autocast('cuda', enabled=False):
cos, sin = rotary_cos_sin
qkv = apply_rotary_pos_emb(
qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)
)
qkv = einops.rearrange(qkv, 'b s ... -> (b s) ...')
cu_seqlens = torch.arange(
0, (batch_size + 1) * seq_len,
step=seq_len, dtype=torch.int32, device=qkv.device)
x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, seq_len, 0.0, causal=True)
x = einops.rearrange(x, '(b s) h d -> b s (h d)',
b=batch_size)
scale = torch.ones(1, device=x.device, dtype=x.dtype)
x = bias_dropout_scale_fn(
self.attn_out(x), None, scale, x_skip, self.dropout)
# mlp operation
x = bias_dropout_scale_fn(
self.mlp(self.norm2(x)), None, scale, x, self.dropout)
return x
class DDiTBlock(nn.Module):
def __init__(self, dim, n_heads, adaLN,
cond_dim=None, mlp_ratio=4,
dropout=0.1):
super().__init__()
self.n_heads = n_heads
self.adaLN = adaLN
self.norm1 = LayerNorm(dim)
self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
self.attn_out = nn.Linear(dim, dim, bias=False)
self.dropout1 = nn.Dropout(dropout)
self.norm2 = LayerNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, mlp_ratio * dim, bias=True),
nn.GELU(approximate='tanh'),
nn.Linear(mlp_ratio * dim, dim, bias=True))
self.dropout2 = nn.Dropout(dropout)
self.dropout = dropout
if self.adaLN:
self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim)
self.adaLN_modulation.weight.data.zero_()
self.adaLN_modulation.bias.data.zero_()
def _get_bias_dropout_scale(self):
if self.training:
return bias_dropout_add_scale_fused_train
else:
return bias_dropout_add_scale_fused_inference
def forward(self, x, rotary_cos_sin, c=None):
bias_dropout_scale_fn = self._get_bias_dropout_scale()
x_skip = x
x = self.norm1(x)
if self.adaLN:
# self.adaLN_modulation(c): (128, 1536)
# self.adaLN_modulation(c)[:, None]: (128, 1, 1536)
# "" .chunk(6, dim=2) returns 6 tuples of shapes (128, 1, 256)
(shift_msa, scale_msa, gate_msa, shift_mlp,
scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
x = modulate_fused(x, shift_msa, scale_msa)
qkv = einops.rearrange(
self.attn_qkv(x),
'b s (three h d) -> b s three h d',
three=3,
h=self.n_heads)
q, k, v = split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)
x = regular_attention_multi_headed(q, k, v)
if self.adaLN:
x = bias_dropout_scale_fn(self.attn_out(x),
None,
gate_msa,
x_skip,
self.dropout)
x = bias_dropout_scale_fn(
self.mlp(modulate_fused(
self.norm2(x), shift_mlp, scale_mlp)),
None, gate_mlp, x, self.dropout)
else:
scale = torch.ones(1, device=x.device, dtype=x.dtype)
x = bias_dropout_scale_fn(
self.attn_out(x), None, scale, x_skip, self.dropout)
x = bias_dropout_scale_fn(
self.mlp(self.norm2(x)), None, scale, x, self.dropout)
return x
class EmbeddingLayer(nn.Module):
def __init__(self, dim, vocab_dim):
super().__init__()
self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))
def forward(self, x, weights=None):
if weights is not None:
bs, seq_len, k = x.shape
flat_x = x.reshape(-1, k)
flat_w = weights.reshape(-1, k).float()
bag = F.embedding_bag(flat_x, self.embedding.float(),
per_sample_weights=flat_w,
mode='sum')
return bag.view(bs, seq_len, -1)
elif x.ndim == 2:
return self.embedding[x]
assert x.ndim == 3
return torch.einsum(
"blv,ve->ble",
torch.nn.functional.softmax(x, dim=-1).float(),
self.embedding.float()).to(x.dtype)
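# The three branches above handle (reading off the call sites):
#   1. x: (B, S, K) indices with weights (B, S, K) -> weighted embedding bag,
#   2. x: (B, S) integer token ids -> plain table lookup,
#   3. x: (B, S, V) logits -> softmax-weighted mix over the vocab table.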
class DDiTFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, cond_dim,
adaLN):
super().__init__()
self.norm_final = LayerNorm(hidden_size)
self.linear = nn.Linear(hidden_size, out_channels)
self.linear.weight.data.zero_()
self.linear.bias.data.zero_()
self.adaLN = adaLN
if self.adaLN:
self.adaLN_modulation = nn.Linear(cond_dim,
2 * hidden_size,
bias=True)
self.adaLN_modulation.weight.data.zero_()
self.adaLN_modulation.bias.data.zero_()
def forward(self, x, c):
x = self.norm_final(x)
if self.adaLN:
shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
x = modulate_fused(x, shift, scale)
x = self.linear(x)
return x
class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
def __init__(self, config, vocab_size: int):
super().__init__()
if isinstance(config, dict):
config = omegaconf.OmegaConf.create(config)
self.causal = config.algo.causal_attention
self.adaLN = not self.causal
self.config = config
self.vocab_size = vocab_size
dim = config.model.hidden_size
cond_dim = config.model.cond_dim
self.vocab_embed = EmbeddingLayer(dim, vocab_size)
if not self.causal:
self.sigma_map = TimestepEmbedder(cond_dim)
self.rotary_emb = Rotary(dim // config.model.n_heads)
blocks = []
for _ in range(config.model.n_blocks):
if self.causal:
block = DDiTBlockCausal(
dim=dim,
n_heads=config.model.n_heads,
dropout=config.model.dropout)
else:
block = DDiTBlock(
dim=dim,
n_heads=config.model.n_heads,
cond_dim=cond_dim,
adaLN=self.adaLN,
dropout=config.model.dropout)
blocks.append(block)
self.blocks = nn.ModuleList(blocks)
self.output_layer = DDiTFinalLayer(
hidden_size=dim,
out_channels=vocab_size,
cond_dim=cond_dim,
adaLN=self.adaLN)
self.scale_by_sigma = config.model.scale_by_sigma
def _get_bias_dropout_scale(self):
if self.training:
return bias_dropout_add_scale_fused_train
else:
return bias_dropout_add_scale_fused_inference
def forward(self, x, sigma, class_cond=None, weights=None):
assert class_cond is None, 'Not implemented for DiT'
x = self.vocab_embed(x, weights)
if self.causal:
t_cond = None
else:
t_cond = F.silu(self.sigma_map(sigma))
rotary_cos_sin = self.rotary_emb(x)
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
for i in range(len(self.blocks)):
x = self.blocks[i](x, rotary_cos_sin, c=t_cond)
x = self.output_layer(x, c=t_cond)
return x
================================================
FILE: models/ema.py
================================================
import torch
class ExponentialMovingAverage:
"""
Maintains (exponential) moving average of a set of parameters.
"""
def __init__(self, parameters, decay, use_num_updates=True):
"""
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the result of
`model.parameters()`.
decay: The exponential decay.
use_num_updates: Whether to use number of updates when computing
averages.
"""
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.decay = decay
self.num_updates = 0 if use_num_updates else None
self.shadow_params = [p.clone().detach()
for p in parameters if p.requires_grad]
self.collected_params = []
def move_shadow_params_to_device(self, device):
self.shadow_params = [i.to(device) for i in self.shadow_params]
def update(self, parameters):
"""
Update currently maintained parameters.
Call this every time the parameters are updated, such as the result of
the `optimizer.step()` call.
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the same set of
parameters used to initialize this object.
"""
decay = self.decay
if self.num_updates is not None:
self.num_updates += 1
decay = min(decay, (1 + self.num_updates) /
(10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
s_param.sub_(one_minus_decay * (s_param - param))
def copy_to(self, parameters):
"""
Copy current parameters into given collection of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored moving averages.
"""
parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
if param.requires_grad:
param.data.copy_(s_param.data)
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
def state_dict(self):
return dict(decay=self.decay,
num_updates=self.num_updates,
shadow_params=self.shadow_params)
def load_state_dict(self, state_dict):
self.decay = state_dict['decay']
self.num_updates = state_dict['num_updates']
self.shadow_params = state_dict['shadow_params']
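# Typical usage sketch, assembled from the method docstrings above (model
# and optimizer are placeholders; the actual trainer wiring lives
# elsewhere in the repo):
# >>> ema = ExponentialMovingAverage(model.parameters(), decay=0.9999)
# >>> optimizer.step()                 # each training step ...
# >>> ema.update(model.parameters())   # ... is followed by an EMA update
# >>> ema.store(model.parameters())    # validate with averaged weights:
# >>> ema.copy_to(model.parameters())
# >>> # ... run validation ...
# >>> ema.restore(model.parameters())  # then restore the raw weights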
================================================
FILE: models/unet.py
================================================
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import numpy as np
import omegaconf
import transformers
from einops import rearrange
from .dit import LabelEmbedder, EmbeddingLayer
# From https://github.com/yang-song/score_sde_pytorch/ which is from
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
def transformer_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
half_dim = embedding_dim // 2
# magic number 10000 is from transformers
emb = math.log(max_positions) / (half_dim - 1)
# emb = math.log(2.) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
# emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
# emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = F.pad(emb, (0, 1), mode='constant')
assert emb.shape == (timesteps.shape[0], embedding_dim)
return emb
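# Sketch: timesteps of shape (B,) map to an embedding of shape
# (B, embedding_dim), sin terms first, cos terms second.
# >>> transformer_timestep_embedding(torch.arange(4), 16).shape
# torch.Size([4, 16])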
# Code modified from https://github.com/yang-song/score_sde_pytorch
def variance_scaling(scale, mode, distribution,
in_axis=1, out_axis=0,
dtype=torch.float32,
device='cpu'):
"""Ported from JAX. """
def _compute_fans(shape, in_axis=1, out_axis=0):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
def init(shape, dtype=dtype, device=device):
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
if mode == "fan_in":
denominator = fan_in
elif mode == "fan_out":
denominator = fan_out
elif mode == "fan_avg":
denominator = (fan_in + fan_out) / 2
else:
raise ValueError(
"invalid mode for variance scaling initializer: {}".format(mode))
variance = scale / denominator
if distribution == "normal":
return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
elif distribution == "uniform":
return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
else:
raise ValueError("invalid distribution for variance scaling initializer")
return init
def default_init(scale=1.):
"""The same initialization used in DDPM."""
scale = 1e-10 if scale == 0 else scale
return variance_scaling(scale, 'fan_avg', 'uniform')
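# Sketch: default_init() returns an initializer closure; calling it with a
# shape draws fan_avg-scaled uniform weights (the DDPM convention).
# >>> w = default_init()((64, 32, 3, 3))
# >>> w.shape
# torch.Size([64, 32, 3, 3])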
class NiN(nn.Module):
def __init__(self, in_ch, out_ch, init_scale=0.1):
super().__init__()
self.W = nn.Parameter(default_init(scale=init_scale)((in_ch, out_ch)), requires_grad=True)
self.b = nn.Parameter(torch.zeros(out_ch), requires_grad=True)
def forward(self, x, # ["batch", "in_ch", "H", "W"]
):
x = x.permute(0, 2, 3, 1)
# x (batch, H, W, in_ch)
y = torch.einsum('bhwi,ik->bhwk', x, self.W) + self.b
# y (batch, H, W, out_ch)
return y.permute(0, 3, 1, 2)
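# NiN is a 1x1 convolution written as an einsum: it mixes channels
# pointwise and is equivalent (up to initialization) to
# nn.Conv2d(in_ch, out_ch, kernel_size=1).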
class AttnBlock(nn.Module):
"""Channel-wise self-attention block."""
def __init__(self, channels, skip_rescale=True):
super().__init__()
self.skip_rescale = skip_rescale
self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels//4, 32),
num_channels=channels, eps=1e-6)
self.NIN_0 = NiN(channels, channels)
self.NIN_1 = NiN(channels, channels)
self.NIN_2 = NiN(channels, channels)
self.NIN_3 = NiN(channels, channels, init_scale=0.)
def forward(self, x, # ["batch", "channels", "H", "W"]
):
B, C, H, W = x.shape
h = self.GroupNorm_0(x)
q = self.NIN_0(h)
k = self.NIN_1(h)
v = self.NIN_2(h)
w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
w = torch.reshape(w, (B, H, W, H * W))
w = F.softmax(w, dim=-1)
w = torch.reshape(w, (B, H, W, H, W))
h = torch.einsum('bhwij,bcij->bchw', w, v)
h = self.NIN_3(h)
if self.skip_rescale:
return (x + h) / np.sqrt(2.)
else:
return x + h
class ResBlock(nn.Module):
def __init__(self, in_ch, out_ch, temb_dim=None, dropout=0.1, skip_rescale=True):
super().__init__()
self.in_ch = in_ch
self.out_ch = out_ch
self.skip_rescale = skip_rescale
self.act = nn.functional.silu
self.groupnorm0 = nn.GroupNorm(
num_groups=min(in_ch // 4, 32),
num_channels=in_ch, eps=1e-6
)
self.conv0 = nn.Conv2d(
in_ch, out_ch, kernel_size=3, padding=1
)
if temb_dim is not None:
self.dense0 = nn.Linear(temb_dim, out_ch)
nn.init.zeros_(self.dense0.bias)
self.groupnorm1 = nn.GroupNorm(
num_groups=min(out_ch // 4, 32),
num_channels=out_ch, eps=1e-6
)
self.dropout0 = nn.Dropout(dropout)
self.conv1 = nn.Conv2d(
out_ch, out_ch, kernel_size=3, padding=1
)
if out_ch != in_ch:
self.nin = NiN(in_ch, out_ch)
def forward(self, x, # ["batch", "in_ch", "H", "W"]
temb=None, # ["batch", "temb_dim"]
):
assert x.shape[1] == self.in_ch
h = self.groupnorm0(x)
h = self.act(h)
h = self.conv0(h)
if temb is not None:
h += self.dense0(self.act(temb))[:, :, None, None]
h = self.groupnorm1(h)
h = self.act(h)
h = self.dropout0(h)
h = self.conv1(h)
if h.shape[1] != self.in_ch:
x = self.nin(x)
assert x.shape == h.shape
if self.skip_rescale:
return (x + h) / np.sqrt(2.)
else:
return x + h
class Downsample(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = nn.Conv2d(channels, channels, kernel_size=3,
stride=2, padding=0)
def forward(self, x, # ["batch", "ch", "inH", "inW"]
):
B, C, H, W = x.shape
x = nn.functional.pad(x, (0, 1, 0, 1))
x = self.conv(x)
assert x.shape == (B, C, H // 2, W // 2)
return x
class Upsample(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
def forward(self, x, # ["batch", "ch", "inH", "inW"]
):
B, C, H, W = x.shape
h = F.interpolate(x, (H*2, W*2), mode='nearest')
h = self.conv(h)
assert h.shape == (B, C, H*2, W*2)
return h
class UNet(nn.Module):
def __init__(self, config, vocab_size=None):
super().__init__()
if isinstance(config, dict):
config = omegaconf.OmegaConf.create(config)
assert config.model.name == 'unet'
self.ch = config.model.ch
self.num_res_blocks = config.model.num_res_blocks
self.num_scales = config.model.num_scales
self.ch_mult = config.model.ch_mult
assert self.num_scales == len(self.ch_mult)
self.input_channels = config.model.input_channels
self.output_channels = 2 * config.model.input_channels
self.scale_count_to_put_attn = config.model.scale_count_to_put_attn
self.data_min_max = [0, vocab_size]  # min/max of the input (formerly config.model.data_min_max); used to rescale it to [-1, 1]
self.dropout = config.model.dropout
self.skip_rescale = config.model.skip_rescale
self.time_conditioning = config.model.time_conditioning # Whether to add in time embeddings
self.time_scale_factor = config.model.time_scale_factor # scale to make the range of times be 0 to 1000
self.time_embed_dim = config.model.time_embed_dim
self.vocab_size = vocab_size
self.size = config.model.size
self.length = config.model.length
# truncated logistic
self.fix_logistic = config.model.fix_logistic
self.act = nn.functional.silu
if self.time_conditioning:
self.temb_modules = []
self.temb_modules.append(nn.Linear(self.time_embed_dim, self.time_embed_dim*4))
nn.init.zeros_(self.temb_modules[-1].bias)
self.temb_modules.append(nn.Linear(self.time_embed_dim*4, self.time_embed_dim*4))
nn.init.zeros_(self.temb_modules[-1].bias)
self.temb_modules = nn.ModuleList(self.temb_modules)
self.expanded_time_dim = 4 * self.time_embed_dim if self.time_conditioning else None
self.input_conv = nn.Conv2d(
in_channels=self.input_channels, out_channels=self.ch,
kernel_size=3, padding=1
)
h_cs = [self.ch]
in_ch = self.ch
# Downsampling
self.downsampling_modules = []
for scale_count in range(self.num_scales):
for res_count in range(self.num_res_blocks):
out_ch = self.ch * self.ch_mult[scale_count]
self.downsampling_modules.append(
ResBlock(in_ch, out_ch, temb_dim=self.expanded_time_dim,
dropout=self.dropout, skip_rescale=self.skip_rescale)
)
in_ch = out_ch
h_cs.append(in_ch)
if scale_count == self.scale_count_to_put_attn:
self.downsampling_modules.append(
AttnBlock(in_ch, skip_rescale=self.skip_rescale)
)
if scale_count != self.num_scales - 1:
self.downsampling_modules.append(Downsample(in_ch))
h_cs.append(in_ch)
self.downsampling_modules = nn.ModuleList(self.downsampling_modules)
# Middle
self.middle_modules = []
self.middle_modules.append(
ResBlock(in_ch, in_ch, temb_dim=self.expanded_time_dim,
dropout=self.dropout, skip_rescale=self.skip_rescale)
)
self.middle_modules.append(
AttnBlock(in_ch, skip_rescale=self.skip_rescale)
)
self.middle_modules.append(
ResBlock(in_ch, in_ch, temb_dim=self.expanded_time_dim,
dropout=self.dropout, skip_rescale=self.skip_rescale)
)
self.middle_modules = nn.ModuleList(self.middle_modules)
# Upsampling
self.upsampling_modules = []
for scale_count in reversed(range(self.num_scales)):
for res_count in range(self.num_res_blocks+1):
out_ch = self.ch * self.ch_mult[scale_count]
self.upsampling_modules.append(
ResBlock(in_ch + h_cs.pop(),
out_ch,
temb_dim=self.expanded_time_dim,
dropout=self.dropout,
skip_rescale=self.skip_rescale
)
)
in_ch = out_ch
if scale_count == self.scale_count_to_put_attn:
self.upsampling_modules.append(
AttnBlock(in_ch, skip_rescale=self.skip_rescale)
)
if scale_count != 0:
self.upsampling_modules.append(Upsample(in_ch))
self.upsampling_modules = nn.ModuleList(self.upsampling_modules)
assert len(h_cs) == 0
# output
self.output_modules = []
self.output_modules.append(
nn.GroupNorm(min(in_ch//4, 32), in_ch, eps=1e-6)
)
self.output_modules.append(
nn.Conv2d(in_ch, self.output_channels, kernel_size=3, padding=1)
)
self.output_modules = nn.ModuleList(self.output_modules)
self.cond_map = LabelEmbedder(
config.data.num_classes,
self.time_embed_dim*4)
def _center_data(self, x):
out = (x - self.data_min_max[0]) / (self.data_min_max[1] - self.data_min_max[0]) # [0, 1]
return 2 * out - 1 # to put it in [-1, 1]
def _time_embedding(self, timesteps):
if self.time_conditioning:
temb = transformer_timestep_embedding(
timesteps * self.time_scale_factor, self.time_embed_dim
)
temb = self.temb_modules[0](temb)
temb = self.temb_modules[1](self.act(temb))
else:
temb = None
return temb
def _do_input_conv(self, h):
h = self.input_conv(h)
hs = [h]
return h, hs
def _do_downsampling(self, h, hs, temb):
m_idx = 0
for scale_count in range(self.num_scales):
for res_count in range(self.num_res_blocks):
h = self.downsampling_modules[m_idx](h, temb)
m_idx += 1
if scale_count == self.scale_count_to_put_attn:
h = self.downsampling_modules[m_idx](h)
m_idx += 1
hs.append(h)
if scale_count != self.num_scales - 1:
h = self.downsampling_modules[m_idx](h)
hs.append(h)
m_idx += 1
assert m_idx == len(self.downsampling_modules)
return h, hs
def _do_middle(self, h, temb):
m_idx = 0
h = self.middle_modules[m_idx](h, temb)
m_idx += 1
h = self.middle_modules[m_idx](h)
m_idx += 1
h = self.middle_modules[m_idx](h, temb)
m_idx += 1
assert m_idx == len(self.middle_modules)
return h
def _do_upsampling(self, h, hs, temb):
m_idx = 0
for scale_count in reversed(range(self.num_scales)):
for res_count in range(self.num_res_blocks+1):
h = self.upsampling_modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
m_idx += 1
if scale_count == self.scale_count_to_put_attn:
h = self.upsampling_modules[m_idx](h)
m_idx += 1
if scale_count != 0:
h = self.upsampling_modules[m_idx](h)
m_idx += 1
assert len(hs) == 0
assert m_idx == len(self.upsampling_modules)
return h
def _do_output(self, h):
h = self.output_modules[0](h)
h = self.act(h)
h = self.output_modules[1](h)
return h
def _logistic_output_res(self,
h, # ["B", "twoC", "H", "W"]
centered_x_in, # ["B", "C", "H", "W"]
):
B, twoC, H, W = h.shape
C = twoC//2
h[:, 0:C, :, :] = torch.tanh(centered_x_in + h[:, 0:C, :, :])
return h
def _log_minus_exp(self, a, b, eps=1e-6):
"""
Compute log(exp(a) - exp(b)) for b < a in a numerically stable way.
"""
return a + torch.log1p(-torch.exp(b - a) + eps)
def forward(self, x, sigma, class_cond=None):
# The signature and reshape below are reconstructed; the "(c h w)"
# input layout and img_size = self.size are assumptions consistent
# with _center_data and the call sites further down.
img_size = self.size
h = rearrange(x, "b (c h w) -> b c h w", h=img_size,
w=img_size, c=3)
h = self._center_data(h)
centered_x_in = h
temb = self._time_embedding(sigma)
if class_cond is not None:
if self.cond_map is None:
raise ValueError("Conditioning variable provided, "
"but Model was not initialized "
"with condition embedding layer.")
else:
assert class_cond.shape == (x.shape[0],)
temb = temb + self.cond_map(class_cond)
h, hs = self._do_input_conv(h)
h, hs = self._do_downsampling(h, hs, temb)
h = self._do_middle(h, temb)
h = self._do_upsampling(h, hs, temb)
h = self._do_output(h)
# h (B, 2*C, H, W)
h = self._logistic_output_res(h, centered_x_in)
h = self._truncated_logistic_output(h) # (B, D, S)
return h
================================================
FILE: models/unit_test_attention.py
================================================
import unittest
import torch
# from flash_attn import flash_attention
import torch.nn.functional as F
def attention_inner_heads_flash(qkv, num_heads):
"""Computes attention with heads inside of qkv in the channel dimension using FlashAttention.
Args:
qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where:
H = number of heads,
C = number of channels per head.
num_heads: number of heads.
Returns:
Attention output of shape (B, H*C, T).
"""
# bs, width, length = qkv.shape
# ch = width // (3 * num_heads)
# # Split into (q, k, v) of shape (B, H*C, T).
# q, k, v = qkv.chunk(3, dim=1)
# # Rescale q and k. This makes them contiguous in memory.
# scale = ch ** (-1 / 4) # scale with 4th root = scaling output by sqrt
# q = q * scale
# k = k * scale
# # Reshape q, k, v to (B, H, T, C) for FlashAttention
# q = q.view(bs, num_heads, ch, length).permute(0, 1, 3, 2) # (B, H, T, C)
# k = k.view(bs, num_heads, ch, length).permute(0, 1, 3, 2) # (B, H, T, C)
# v = v.view(bs, num_heads, ch, length).permute(0, 1, 3, 2) # (B, H, T, C)
# # Compute attention using FlashAttention
# out = flash_attention(q, k, v) # (B, H, T, C)
# # Reshape back to (B, H*C, T)
# out = out.permute(0, 1, 3, 2).reshape(bs, num_heads * ch, length)
# return out
bs, width, length = qkv.shape
ch = width // (3 * num_heads)
# Split into (q, k, v) and reshape directly to (B, H, T, C)
q, k, v = qkv.chunk(3, dim=1)
q = q.view(bs, num_heads, ch, length).transpose(2, 3) # (B, H, T, C)
k = k.view(bs, num_heads, ch, length).transpose(2, 3) # (B, H, T, C)
v = v.view(bs, num_heads, ch, length).transpose(2, 3) # (B, H, T, C)
# Compute scaled dot-product attention
out = F.scaled_dot_product_attention(q, k, v) # (B, H, T, C)
# Reshape back to (B, H*C, T) in one step
out = out.transpose(2, 3).reshape(bs, num_heads * ch, length)
return out
class TestAttentionInnerHeadsFlash(unittest.TestCase):
def setUp(self):
# Set up common test variables
self.batch_size = 2
self.num_heads = 4
self.seq_len = 8
self.channels_per_head = 16
self.qkv_dim = 3 * self.num_heads * self.channels_per_head
# Create a random qkv tensor
self.qkv = torch.randn(self.batch_size, self.qkv_dim, self.seq_len, dtype=torch.float32)
def attention_inner_heads_old(self, qkv, num_heads):
"""Original implementation of attention for reference."""
bs, width, length = qkv.shape
ch = width // (3 * num_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = ch ** (-1 / 4)
q = q * scale
k = k * scale
q = q.view(bs * num_heads, ch, length)
k = k.view(bs * num_heads, ch, length)
v = v.reshape(bs * num_heads, ch, length)
weight = torch.einsum("bct,bcs->bts", q, k)
weight = torch.softmax(weight.float(), dim=-1).to(weight.dtype)
out = torch.einsum("bts,bcs->bct", weight, v)
return out.reshape(bs, num_heads * ch, length)
def test_attention_inner_heads_flash(self):
# Compute the output using the old and flash attention implementations
output_old = self.attention_inner_heads_old(self.qkv, self.num_heads)
output_flash = attention_inner_heads_flash(self.qkv, self.num_heads)
# Verify the shapes are the same
self.assertEqual(output_old.shape, output_flash.shape)
# Verify that the outputs are close (numerically similar)
self.assertTrue(torch.allclose(output_old, output_flash, atol=1e-5))
if __name__ == "__main__":
unittest.main()
================================================
FILE: requirements.txt
================================================
# conda install nvidia/label/cuda-12.4.0::cuda-toolkit
datasets==2.15.0
einops==0.7.0
fsspec
git-lfs==1.6
h5py==3.10.0
hydra-core==1.3.2
ipdb==0.13.13
lightning==2.2.1
notebook==7.1.1
nvitop==1.3.2
omegaconf==2.3.0
packaging==23.2
pandas==2.2.1
rich==13.7.1
seaborn==0.13.2
scikit-learn==1.4.0
transformers==4.38.2
triton==2.2.0
torch==2.3.1
torchaudio==2.3.1
torchmetrics==1.6.1
torchvision==0.18.1
wandb
timm
ocifs
hf_transfer
huggingface-hub
# Install flash attention only after installing the above modules via pip install -r requirements.txt
# flash_attn==2.7.4.post1
================================================
FILE: scripts/distil_owt.sh
================================================
#!/bin/bash
#SBATCH -J posterior # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=64000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
export HYDRA_FULL_ERROR=1
finetune_path=/path/to/duo.ckpt
srun python -u -m main \
mode=train \
loader.batch_size=2 \
loader.eval_batch_size=2 \
data=openwebtext-split \
model=small \
algo=distillation \
training.finetune_path=$finetune_path \
sampling.num_sample_batches=10 \
sampling.steps=32 \
eval.compute_generative_perplexity=True \
algo.T=512 \
lr_scheduler.num_warmup_steps=500 \
trainer.val_check_interval=1000 \
trainer.max_steps=50000 \
loader.global_batch_size=128 \
training.ema=0.999 \
algo.update_teacher_every=10000 \
optim.lr=6e-5 \
trainer.limit_val_batches=8 \
algo.teacher_ema=False \
algo.linear_growth_dt=false \
+wandb.offline=true
================================================
FILE: scripts/eval_lm1b_duo.sh
================================================
#!/bin/bash
#SBATCH -J eval_mdlm # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=100000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=gpu # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/flow-ode/flow-ode-W2ZcFy-small-conjugate-lm1b-wrap-1M-gmin-350-gmax-175-3/checkpoints/7-100000.ckpt
export HYDRA_FULL_ERROR=1
srun python -u -m main \
mode=ppl_eval \
loader.batch_size=64 \
loader.eval_batch_size=64 \
data=lm1b-wrap \
model=small \
model.length=128 \
algo=duo_base \
eval.checkpoint_path=$checkpoint_path \
sampling.num_sample_batches=0 \
+wandb.offline=true
================================================
FILE: scripts/eval_owt_ar.sh
================================================
#!/bin/bash
#SBATCH -J eval_ar # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=100000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-AgBZrc-small-ar-param-ar_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt
export HYDRA_FULL_ERROR=1
srun python -u -m main \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
data=openwebtext-split \
algo=ar \
model.length=1024 \
eval.checkpoint_path=$checkpoint_path \
+wandb.offline=true
================================================
FILE: scripts/eval_owt_duo.sh
================================================
#!/bin/bash
#SBATCH -J owt_duo_anneal # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=100000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/flow-ode/flow-ode-VlCQLK-small-conjugate-OWT-anneal/checkpoints/last.ckpt
export HYDRA_FULL_ERROR=1
srun python -u -m main \
mode=ppl_eval \
loader.batch_size=8 \
loader.eval_batch_size=8 \
data=openwebtext-split \
model=small \
algo=duo_base \
eval.checkpoint_path=$checkpoint_path \
sampling.num_sample_batches=0 \
+wandb.offline=true
================================================
FILE: scripts/eval_owt_mdlm.sh
================================================
#!/bin/bash
#SBATCH -J eval_mdlm # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=100000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=gpu # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/mdlm.ckpt
export HYDRA_FULL_ERROR=1
srun python -u -m main \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
data=openwebtext-split \
model=small \
algo=mdlm \
eval.checkpoint_path=$checkpoint_path \
sampling.num_sample_batches=0 \
+wandb.offline=true
================================================
FILE: scripts/eval_owt_sedd.sh
================================================
#!/bin/bash
#SBATCH -J eval_sedd # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=100000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-nBm2gE-small-param-sedd_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt
export HYDRA_FULL_ERROR=1
srun python -u -m main \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
data=openwebtext-split \
model=small \
algo=sedd \
eval.checkpoint_path=$checkpoint_path \
sampling.num_sample_batches=0 \
+wandb.offline=true
================================================
FILE: scripts/fid_cifar10_duo_ancestral_cosine.sh
================================================
export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1
python -u -m main \
mode=fid_eval \
sampling.steps=64 \
sampling.guid_weight=1.0 \
data=cifar10 \
data.cache_dir= \
model=unet \
noise=cosine \
algo=duo_base \
algo.backbone=unet \
trainer.num_nodes=1 \
loader.eval_batch_size=500 \
eval.checkpoint_path=
================================================
FILE: scripts/fid_cifar10_duo_base_ancestral_cosine.sh
================================================
export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1
python -u -m main \
mode=fid_eval \
sampling.steps=64 \
sampling.guid_weight=1.0 \
data=cifar10 \
data.cache_dir= \
model=unet \
noise=cosine \
algo=duo_base \
algo.backbone=unet \
trainer.num_nodes=1 \
loader.eval_batch_size=500 \
eval.checkpoint_path=
================================================
FILE: scripts/fid_cifar10_mdlm_ancestral_cosine.sh
================================================
python -u -m main \
mode=fid_eval \
sampling.steps=64 \
sampling.guid_weight=1.0 \
sampling.predictor=ancestral_cache \
data=cifar10 \
data.cache_dir= \
model=unet \
noise=cosine \
algo=mdlm \
algo.backbone=unet \
loader.eval_batch_size=500 \
eval.checkpoint_path=
================================================
FILE: scripts/gen_ppl_lm1b_ar.sh
================================================
#!/bin/bash
#SBATCH -J sample_ar # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=gpu # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/lm1b-ar/last_copy.ckpt
export HYDRA_FULL_ERROR=1
srun python -u -m main \
mode=sample_eval \
loader.batch_size=2 \
loader.eval_batch_size=64 \
data=lm1b-wrap \
algo=ar \
model=small \
model.length=128 \
eval.checkpoint_path=$checkpoint_path \
sampling.num_sample_batches=15 \
+wandb.offline=true
================================================
FILE: scripts/gen_ppl_lm1b_duo.sh
================================================
#!/bin/bash
#SBATCH -J sample_ar # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=gpu # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/flow-ode/6eTwW0-distil-kl-bwd-32/checkpoints/0-1000.ckpt
steps=32
export HYDRA_FULL_ERROR=1
srun python -u -m main \
mode=sample_eval \
loader.batch_size=2 \
loader.eval_batch_size=64 \
data=lm1b-wrap \
algo=duo_base \
model=small \
model.length=128 \
eval.checkpoint_path=/share/kuleshov/ssahoo/flow-ode/6eTwW0-distil7-kl-bwd/checkpoints/last.ckpt \
sampling.num_sample_batches=15 \
sampling.steps=$steps \
+wandb.offline=true \
sampling.noise_removal=greedy
================================================
FILE: scripts/gen_ppl_owt_ar.sh
================================================
#!/bin/bash
#SBATCH -J sample_ar # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=16000 # server memory requested (per node)
#SBATCH -t 24:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov,gpu # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-AgBZrc-small-ar-param-ar_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints
seed=1
export HYDRA_FULL_ERROR=1
srun python -u -m main \
mode=sample_eval \
seed=$seed \
loader.batch_size=2 \
loader.eval_batch_size=8 \
data=openwebtext-split \
algo=ar \
model=small \
eval.checkpoint_path=$checkpoint_path/last.ckpt \
sampling.num_sample_batches=100 \
+wandb.offline=true \
eval.generated_samples_path=$checkpoint_path/$seed-ckpt-last.json
================================================
FILE: scripts/gen_ppl_owt_duo.sh
================================================
#!/bin/bash
#SBATCH -J an_owt_duo # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=16000 # server memory requested (per node)
#SBATCH -t 24:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov,gpu # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
export HYDRA_FULL_ERROR=1
while [[ "$#" -gt 0 ]]; do
case $1 in
--steps) steps="$2"; shift ;;
--seed) seed="$2"; shift ;;
--ckpt) ckpt="$2"; shift ;;
*) echo "Unknown parameter: $1"; exit 1 ;;
esac
shift
done
checkpoint_path=/share/kuleshov/ssahoo/flow-ode/distil-distil-vjrpZb-distillation-OWT/checkpoints
ckpt=${ckpt:-0-50000}
steps=${steps:-32}
seed=${seed:-1}
echo " Steps: $steps"
echo " Seed: $seed"
echo " ckpt: $ckpt"
srun python -u -m main \
mode=sample_eval \
seed=$seed \
loader.batch_size=2 \
loader.eval_batch_size=8 \
data=openwebtext-split \
algo=duo_base \
model=small \
eval.checkpoint_path=$checkpoint_path/$ckpt.ckpt \
sampling.num_sample_batches=100 \
sampling.steps=$steps \
+wandb.offline=true \
eval.generated_samples_path=$checkpoint_path/samples_ancestral/$seed-$steps-ckpt-$ckpt.json
================================================
FILE: scripts/gen_ppl_owt_mdlm.sh
================================================
#!/bin/bash
#SBATCH -J sample_mdlm # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=64000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/mdlm.ckpt
export HYDRA_FULL_ERROR=1
srun python -u -m main \
mode=sample_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
data=openwebtext-split \
model=small \
algo=mdlm \
eval.checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/mdlm.ckpt \
sampling.num_sample_batches=4 \
+wandb.offline=true
================================================
FILE: scripts/gen_ppl_owt_sedd.sh
================================================
#!/bin/bash
#SBATCH -J sedd_samples # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=16000 # server memory requested (per node)
#SBATCH -t 24:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov,gpu # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
while [[ "$#" -gt 0 ]]; do
case $1 in
--steps) steps="$2"; shift ;;
--seed) seed="$2"; shift ;;
*) echo "Unknown parameter: $1"; exit 1 ;;
esac
shift
done
steps=${steps:-32}
seed=${seed:-1}
echo " Steps: $steps"
echo " Seed: $seed"
export HYDRA_FULL_ERROR=1
checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-nBm2gE-small-param-sedd_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints
ckpt=last
srun python -u -m main \
mode=sample_eval \
seed=$seed \
loader.batch_size=2 \
loader.eval_batch_size=8 \
data=openwebtext-split \
algo=sedd \
model=small \
eval.checkpoint_path=$checkpoint_path/$ckpt.ckpt \
sampling.num_sample_batches=100 \
sampling.steps=$steps \
sampling.predictor=analytic \
eval.generated_samples_path=$checkpoint_path/$seed-$steps-ckpt-$ckpt.json \
+wandb.offline=true
================================================
FILE: scripts/psi_samplers/cifar10/duo_constant_remdm.sh
================================================
# DUO psi-sampler with constant-remdm-eta mode (ReMDM loop)
NUM_STEPS=256
ETA=0.01
NOISE=cosine
CHECKPOINT_PATH=
DATA_CACHE_DIR=
EVAL_BATCH_SIZE=500
python -u -m main \
mode=fid_eval \
data=cifar10 \
data.cache_dir=$DATA_CACHE_DIR \
model=unet \
algo=duo_base \
algo.backbone=unet \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.guid_weight=1.0 \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear-constant-linear-0.9-inv \
sampling.psi.high_mode=pure-posterior \
sampling.psi.middle_mode=constant-remdm-$ETA \
sampling.psi.low_mode=pure-posterior \
sampling.psi.high_frac=0.45 \
sampling.psi.middle_frac=0.5
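# Profile note (a sketch, per _get_sampling_time_profile in trainer_base.py):
# 'linear-constant-linear-0.9-inv' parses c=0.9 as an alpha level, and the
# 'inv' suffix maps it to t* = noise.get_t_for_alpha(0.9) on the cosine
# schedule. The time grid then descends linearly from 1 to t* over the first
# 45% of points (high_frac), holds at t* for the middle 50% (middle_frac),
# and descends linearly to eps over the remainder.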
================================================
FILE: scripts/psi_samplers/cifar10/duo_max_capped_remdm.sh
================================================
# DUO psi-sampler with max-capped-eta mode (ReMDM cap)
NUM_STEPS=256
ETA=0.005
NOISE=cosine
CHECKPOINT_PATH=
DATA_CACHE_DIR=
EVAL_BATCH_SIZE=500
python -u -m main \
mode=fid_eval \
data=cifar10 \
data.cache_dir=$DATA_CACHE_DIR \
model=unet \
algo=duo_base \
algo.backbone=unet \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.guid_weight=1.0 \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear \
sampling.psi.high_mode=max-capped-$ETA \
sampling.psi.middle_mode=max-capped-$ETA \
sampling.psi.low_mode=max-capped-$ETA \
sampling.psi.high_frac=0.0 \
sampling.psi.middle_frac=0.0
================================================
FILE: scripts/psi_samplers/cifar10/duo_max_rescale_eta.sh
================================================
# DUO psi-sampler with max-rescale-eta mode — CIFAR-10 FID eval
NUM_STEPS=256
ETA=0.01
NOISE=cosine
CHECKPOINT_PATH=
DATA_CACHE_DIR=
EVAL_BATCH_SIZE=500
python -u -m main \
mode=fid_eval \
data=cifar10 \
data.cache_dir=$DATA_CACHE_DIR \
model=unet \
algo=duo_base \
algo.backbone=unet \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.guid_weight=1.0 \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear \
sampling.psi.high_mode=max-rescale-$ETA \
sampling.psi.middle_mode=max-rescale-$ETA \
sampling.psi.low_mode=max-rescale-$ETA \
sampling.psi.high_frac=0.0 \
sampling.psi.middle_frac=0.0
================================================
FILE: scripts/psi_samplers/cifar10/duo_psi_pc.sh
================================================
# DUO psi-sampler with constant kappa in pc phase
# Kappa controls the posterior/PC mix: 1 = pure posterior, 0 = pure PC
# t in [1.0, 0.5] -> pure-posterior
# t in [0.5, 0.1] -> constant-kappa
# t in [0.1, 0.0] -> pure-posterior
NUM_STEPS=256
KAPPA=0.95
NOISE=cosine
HIGH_FRAC=0.5
MIDDLE_FRAC=0.4
CHECKPOINT_PATH=
DATA_CACHE_DIR=
EVAL_BATCH_SIZE=500
python -u -m main \
mode=fid_eval \
data=cifar10 \
data.cache_dir=$DATA_CACHE_DIR \
model=unet \
algo=duo_base \
algo.backbone=unet \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.guid_weight=1.0 \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear \
sampling.psi.high_mode=pure-posterior \
sampling.psi.middle_mode=constant-$KAPPA \
sampling.psi.low_mode=pure-posterior \
sampling.psi.high_frac=$HIGH_FRAC \
sampling.psi.middle_frac=$MIDDLE_FRAC
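# Phase-split sketch (the psi sampler rounds fractions of the step grid, as
# in trainer_base.py): on the 257-point grid (NUM_STEPS + 1), HIGH_FRAC=0.5
# gives ~128 pure-posterior steps, MIDDLE_FRAC=0.4 gives ~103 steps at
# kappa=0.95, and the remaining ~26 steps revert to the pure posterior.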
================================================
FILE: scripts/psi_samplers/cifar10/mdlm_constant_remdm.sh
================================================
# MDLM psi-sampler with constant-remdm-eta mode (ReMDM loop) — CIFAR-10 FID eval
NUM_STEPS=256
ETA=0.01
NOISE=cosine
CHECKPOINT_PATH=
DATA_CACHE_DIR=
EVAL_BATCH_SIZE=500
python -u -m main \
mode=fid_eval \
data=cifar10 \
data.cache_dir=$DATA_CACHE_DIR \
model=unet \
algo=mdlm \
algo.backbone=unet \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.guid_weight=1.0 \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear-constant-linear-0.9-inv \
sampling.psi.high_mode=pure-posterior \
sampling.psi.middle_mode=constant-remdm-$ETA \
sampling.psi.low_mode=pure-posterior \
sampling.psi.high_frac=0.45 \
sampling.psi.middle_frac=0.5
================================================
FILE: scripts/psi_samplers/cifar10/mdlm_max_capped_remdm.sh
================================================
# MDLM psi-sampler with max-capped-eta mode (ReMDM cap)
NUM_STEPS=256
ETA=0.005
NOISE=cosine
CHECKPOINT_PATH=
DATA_CACHE_DIR=
EVAL_BATCH_SIZE=500
python -u -m main \
mode=fid_eval \
data=cifar10 \
data.cache_dir=$DATA_CACHE_DIR \
model=unet \
algo=mdlm \
algo.backbone=unet \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.guid_weight=1.0 \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear \
sampling.psi.high_mode=max-capped-$ETA \
sampling.psi.middle_mode=max-capped-$ETA \
sampling.psi.low_mode=max-capped-$ETA \
sampling.psi.high_frac=0.0 \
sampling.psi.middle_frac=0.0
================================================
FILE: scripts/psi_samplers/cifar10/mdlm_max_rescale_eta.sh
================================================
# MDLM psi-sampler with max-rescale-eta mode
NUM_STEPS=256
ETA=0.01
NOISE=cosine
CHECKPOINT_PATH=
DATA_CACHE_DIR=
EVAL_BATCH_SIZE=500
python -u -m main \
mode=fid_eval \
data=cifar10 \
data.cache_dir=$DATA_CACHE_DIR \
model=unet \
algo=mdlm \
algo.backbone=unet \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.guid_weight=1.0 \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear \
sampling.psi.high_mode=max-rescale-$ETA \
sampling.psi.middle_mode=max-rescale-$ETA \
sampling.psi.low_mode=max-rescale-$ETA \
sampling.psi.high_frac=0.0 \
sampling.psi.middle_frac=0.0
================================================
FILE: scripts/psi_samplers/cifar10/mdlm_psi_pc.sh
================================================
# MDLM psi-sampler with constant kappa during pc phase
# Kappa controls the posterior/PC mix: 1 = pure posterior, 0 = pure PC
# t in [1.0, 0.1] -> constant kappa
# t in [0.1, 0.0] -> pure posterior
NUM_STEPS=256
KAPPA=0.99
NOISE=cosine
HIGH_FRAC=0.0
MIDDLE_FRAC=0.9
CHECKPOINT_PATH=
DATA_CACHE_DIR=
EVAL_BATCH_SIZE=500
python -u -m main \
mode=fid_eval \
data=cifar10 \
data.cache_dir=$DATA_CACHE_DIR \
model=unet \
algo=mdlm \
algo.backbone=unet \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.guid_weight=1.0 \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear \
sampling.psi.high_mode=pure-posterior \
sampling.psi.middle_mode=constant-$KAPPA \
sampling.psi.low_mode=pure-posterior \
sampling.psi.high_frac=$HIGH_FRAC \
sampling.psi.middle_frac=$MIDDLE_FRAC
================================================
FILE: scripts/psi_samplers/owt/duo_loop_remdm.sh
================================================
# DUO psi-sampler with constant-remdm-eta mode (ReMDM loop)
NUM_STEPS=256
ETA=0.01
NUCLEUS_P=0.95
NOISE=log-linear
CHECKPOINT_PATH=???
DATA_CACHE_DIR=???
EVAL_BATCH_SIZE=16
NUM_SAMPLE_BATCHES=32
python -u -m main \
mode=sample_eval \
data=openwebtext-split \
data.cache_dir=$DATA_CACHE_DIR \
model=small \
algo=duo_base \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.p_nucleus=$NUCLEUS_P \
sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear-constant-linear-0.9-inv \
sampling.psi.high_mode=pure-posterior \
sampling.psi.middle_mode=constant-remdm-$ETA \
sampling.psi.low_mode=pure-posterior \
sampling.psi.high_frac=0.45 \
sampling.psi.middle_frac=0.5
================================================
FILE: scripts/psi_samplers/owt/duo_max_capped_remdm.sh
================================================
# DUO psi-sampler with max-capped-eta mode (ReMDM cap)
NUM_STEPS=256
ETA=0.01
NUCLEUS_P=0.9
NOISE=log-linear
CHECKPOINT_PATH=???
DATA_CACHE_DIR=???
EVAL_BATCH_SIZE=16
NUM_SAMPLE_BATCHES=32
python -u -m main \
mode=sample_eval \
data=openwebtext-split \
data.cache_dir=$DATA_CACHE_DIR \
model=small \
algo=duo_base \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.p_nucleus=$NUCLEUS_P \
sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear \
sampling.psi.high_mode=max-capped-$ETA \
sampling.psi.middle_mode=max-capped-$ETA \
sampling.psi.low_mode=max-capped-$ETA \
sampling.psi.high_frac=0.0 \
sampling.psi.middle_frac=0.0
================================================
FILE: scripts/psi_samplers/owt/duo_max_rescale_eta.sh
================================================
# DUO psi-sampler with max-rescale-eta mode (ReMDM rescale)
NUM_STEPS=256
ETA=0.05
NUCLEUS_P=0.9
NOISE=log-linear
CHECKPOINT_PATH=???
DATA_CACHE_DIR=???
EVAL_BATCH_SIZE=16
NUM_SAMPLE_BATCHES=32
python -u -m main \
mode=sample_eval \
data=openwebtext-split \
data.cache_dir=$DATA_CACHE_DIR \
model=small \
algo=duo_base \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.p_nucleus=$NUCLEUS_P \
sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear \
sampling.psi.high_mode=max-rescale-$ETA \
sampling.psi.middle_mode=max-rescale-$ETA \
sampling.psi.low_mode=max-rescale-$ETA \
sampling.psi.high_frac=0.0 \
sampling.psi.middle_frac=0.0
================================================
FILE: scripts/psi_samplers/owt/mdlm_loop_remdm.sh
================================================
# MDLM psi-sampler with constant-remdm-eta mode (ReMDM loop)
NUM_STEPS=256
ETA=0.01
NUCLEUS_P=0.95
NOISE=log-linear
CHECKPOINT_PATH=???
DATA_CACHE_DIR=???
EVAL_BATCH_SIZE=16
NUM_SAMPLE_BATCHES=32
python -u -m main \
mode=sample_eval \
data=openwebtext-split \
data.cache_dir=$DATA_CACHE_DIR \
model=small \
algo=mdlm \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.p_nucleus=$NUCLEUS_P \
sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear-constant-linear-0.9-inv \
sampling.psi.high_mode=pure-posterior \
sampling.psi.middle_mode=constant-remdm-$ETA \
sampling.psi.low_mode=pure-posterior \
sampling.psi.high_frac=0.45 \
sampling.psi.middle_frac=0.5
================================================
FILE: scripts/psi_samplers/owt/mdlm_max_capped_remdm.sh
================================================
# MDLM psi-sampler with max-capped-eta mode (ReMDM cap)
NUM_STEPS=256
ETA=0.01
NUCLEUS_P=0.9
NOISE=log-linear
CHECKPOINT_PATH=???
DATA_CACHE_DIR=???
EVAL_BATCH_SIZE=16
NUM_SAMPLE_BATCHES=32
python -u -m main \
mode=sample_eval \
data=openwebtext-split \
data.cache_dir=$DATA_CACHE_DIR \
model=small \
algo=mdlm \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.p_nucleus=$NUCLEUS_P \
sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear \
sampling.psi.high_mode=max-capped-$ETA \
sampling.psi.middle_mode=max-capped-$ETA \
sampling.psi.low_mode=max-capped-$ETA \
sampling.psi.high_frac=0.0 \
sampling.psi.middle_frac=0.0
================================================
FILE: scripts/psi_samplers/owt/mdlm_max_rescale_eta.sh
================================================
# MDLM psi-sampler with max-rescale-eta mode (ReMDM rescale)
NUM_STEPS=256
ETA=0.05
NUCLEUS_P=0.9
NOISE=log-linear
CHECKPOINT_PATH=???
DATA_CACHE_DIR=???
EVAL_BATCH_SIZE=16
NUM_SAMPLE_BATCHES=32
python -u -m main \
mode=sample_eval \
data=openwebtext-split \
data.cache_dir=$DATA_CACHE_DIR \
model=small \
algo=mdlm \
noise=$NOISE \
sampling.predictor=psi \
sampling.steps=$NUM_STEPS \
sampling.p_nucleus=$NUCLEUS_P \
sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \
eval.checkpoint_path=$CHECKPOINT_PATH \
loader.eval_batch_size=$EVAL_BATCH_SIZE \
sampling.psi.time_profile=linear \
sampling.psi.high_mode=max-rescale-$ETA \
sampling.psi.middle_mode=max-rescale-$ETA \
sampling.psi.low_mode=max-rescale-$ETA \
sampling.psi.high_frac=0.0 \
sampling.psi.middle_frac=0.0
================================================
FILE: scripts/train_cifar10_duo_base_cosine.sh
================================================
python -u -m main \
data=cifar10 \
data.cache_dir= \
model=unet \
algo=duo_base \
algo.backbone=unet \
noise=cosine \
loader.global_batch_size=128 \
loader.batch_size=32 \
loader.eval_batch_size=32 \
loader.num_workers=8 \
trainer.val_check_interval=2500 \
trainer.max_steps=1_500_000 \
lr_scheduler.num_warmup_steps=5000 \
eval.generate_samples=False \
optim.lr=2e-4 \
callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \
wandb.name=duo_base_1_5M_d3pm_like_cosine \
hydra.run.dir=./outputs/cifar10/duo_base_1_5M_d3pm_like_cosine
================================================
FILE: scripts/train_cifar10_duo_cosine.sh
================================================
python -u -m main \
data=cifar10 \
data.cache_dir= \
model=unet \
algo=duo_base \
algo.backbone=unet \
noise=cosine \
loader.global_batch_size=128 \
loader.batch_size=32 \
loader.eval_batch_size=32 \
loader.num_workers=8 \
trainer.val_check_interval=2500 \
trainer.max_steps=1_500_000 \
lr_scheduler.num_warmup_steps=5000 \
eval.generate_samples=False \
optim.lr=2e-4 \
callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \
wandb.name=duo_base_1_5M_d3pm_like_cosine \
hydra.run.dir=./outputs/cifar10/duo_base_1_5M_d3pm_like_cosine
================================================
FILE: scripts/train_cifar10_mdlm_cosine.sh
================================================
python -u -m main \
data=cifar10 \
data.cache_dir= \
model=unet \
algo=mdlm \
algo.backbone=unet \
noise=cosine \
loader.global_batch_size=128 \
loader.batch_size=32 \
loader.eval_batch_size=32 \
loader.num_workers=8 \
trainer.val_check_interval=2500 \
trainer.max_steps=1_500_000 \
lr_scheduler.num_warmup_steps=5000 \
eval.generate_samples=False \
optim.lr=2e-4 \
callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \
wandb.name=mdlm_1_5M_d3pm_like_cosine \
hydra.run.dir=./outputs/cifar10/mdlm_1_5M_d3pm_like_cosine
================================================
FILE: scripts/train_lm1b_ar.sh
================================================
#!/bin/bash
#SBATCH -J train_ar_lm1b # Job name
#SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon pre-emption
# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.
srun python -u -m main \
loader.batch_size=64 \
loader.eval_batch_size=64 \
algo=ar \
data=lm1b \
wandb.name=ar-lm1b-small \
model=small \
model.length=128
================================================
FILE: scripts/train_lm1b_ar_sentencepacking.sh
================================================
#!/bin/bash
#SBATCH -J train_ar_lm1b # Job name
#SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=2
#SBATCH --gres=gpu:2 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon pre-emption
# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.
srun python -u -m main \
loader.batch_size=256 \
loader.eval_batch_size=256 \
algo=ar \
data=lm1b-wrap \
wandb.name=ar-lm1b-wrap-small \
model=small \
model.length=128
================================================
FILE: scripts/train_lm1b_d3pm.sh
================================================
#!/bin/bash
#SBATCH -J train_d3pm # Job name
#SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon pre-emption
# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.
srun python -u -m main \
loader.batch_size=64 \
loader.eval_batch_size=64 \
model=small \
data=lm1b \
wandb.name=d3pm-lm1b \
algo=d3pm \
model.length=128 \
eval.compute_generative_perplexity=False
================================================
FILE: scripts/train_lm1b_duo.sh
================================================
#!/bin/bash
#SBATCH -J duo-lm1b # Job name
#SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=64000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon pre-emption
# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.
# Note: we use algo.curriculum.mode=poly9 in the duo2 paper for speed
srun python -u -m main \
loader.batch_size=64 \
loader.eval_batch_size=64 \
data=lm1b \
wandb.name=duo-lm1b \
model=small \
algo=duo \
model.length=128 \
algo.curriculum.mode=simple \
algo.curriculum.gumbel_tau_log10_start=-3.0 \
algo.curriculum.gumbel_tau_log10_end=-3.0 \
algo.curriculum.gamma_min=-3.5 \
algo.curriculum.gamma_max=-1.75 \
algo.curriculum.start=0 \
algo.curriculum.end=500000
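# Curriculum note (a sketch inferred from the flag names): the
# gumbel_tau_log10_* values are base-10 logs, so with start = end = -3.0 the
# Gumbel temperature is held constant at tau = 1e-3; the curriculum itself is
# active from step 0 to step 500000, over the gamma range
# [gamma_min, gamma_max] = [-3.5, -1.75].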
================================================
FILE: scripts/train_lm1b_duo_sentencepacking.sh
================================================
#!/bin/bash
#SBATCH -J duo-lm1b # Job name
#SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=64000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon pre-emption
# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.
# Note: we use algo.curriculum.mode=poly9 in the duo2 paper for speed
srun python -u -m main \
loader.batch_size=64 \
loader.eval_batch_size=64 \
data=lm1b-wrap \
wandb.name=duo-lm1b \
model=small \
algo=duo \
model.length=128 \
algo.curriculum.mode=simple \
algo.curriculum.gumbel_tau_log10_start=-3.0 \
algo.curriculum.gumbel_tau_log10_end=-3.0 \
algo.curriculum.gamma_min=-3.5 \
algo.curriculum.gamma_max=-1.75 \
algo.curriculum.start=0 \
algo.curriculum.end=500000
================================================
FILE: scripts/train_lm1b_mdlm.sh
================================================
#!/bin/bash
#SBATCH -J lm1b_mdlm # Job name
#SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=2
#SBATCH --gres=gpu:2 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon pre-emption
# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.
srun python -u -m main \
loader.batch_size=64 \
loader.eval_batch_size=64 \
data=lm1b \
wandb.name=mdlm-lm1b \
model=small \
algo=mdlm \
model.length=128
================================================
FILE: scripts/train_lm1b_mdlm_sentencepacking.sh
================================================
#!/bin/bash
#SBATCH -J lm1b_mdlm # Job name
#SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=2
#SBATCH --gres=gpu:2 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon pre-emption
# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.
srun python -u -m main \
loader.batch_size=64 \
loader.eval_batch_size=64 \
data=lm1b-wrap \
wandb.name=mdlm-lm1b-wrap-small \
model=small \
algo=mdlm \
model.length=128
================================================
FILE: scripts/train_owt_duo.sh
================================================
#!/bin/bash
#SBATCH -J duo-lm1b # Job name
#SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=64000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon pre-emption
# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.
# Note: we use algo.curriculum.mode=poly9 in the duo2 paper for speed
srun python -u -m main \
loader.batch_size=32 \
loader.eval_batch_size=32 \
data=openwebtext-split \
wandb.name=duo-owt \
model=small \
algo=duo \
model.length=1024 \
algo.curriculum.mode=simple \
algo.curriculum.gumbel_tau_log10_start=-3.0 \
algo.curriculum.gumbel_tau_log10_end=-3.0 \
algo.curriculum.gamma_min=-3.55 \
algo.curriculum.gamma_max=-1.85 \
algo.curriculum.start=0 \
algo.curriculum.end=500000
================================================
FILE: scripts/train_owt_duo_finetune.sh
================================================
#!/bin/bash
#SBATCH -J duo-base # Job name
#SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=64000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon pre-emption
# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.
finetune_path=/path/to/intermediate_duo_500k.ckpt
# Assuming the finetune_path corresponds to the DUO model
# trained for 500K steps with curriculum learning, we train the
# model for 500K more steps.
srun python -u -m main \
loader.batch_size=64 \
loader.eval_batch_size=64 \
data=openwebtext-split \
wandb.name=duo-owt-finetune \
model=small \
algo=duo_base \
model.length=1024 \
training.finetune_path=$finetune_path \
sampling.num_sample_batches=0 \
trainer.max_steps=500000
================================================
FILE: scripts/train_owt_mdlm.sh
================================================
#!/bin/bash
#SBATCH -J train_mdlm # Job name
#SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon pre-emption
# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.
srun python -u -m main \
loader.batch_size=2 \
loader.eval_batch_size=2 \
model=small \
data=openwebtext-split \
wandb.name=mdlm-owt \
algo=mdlm \
model.length=1024 \
eval.compute_generative_perplexity=False \
+wandb.offline=True
================================================
FILE: scripts/train_owt_sedd.sh
================================================
#!/bin/bash
#SBATCH -J train_sedd # Job name
#SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon pre-emption
# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.
srun python -u -m main \
loader.batch_size=16 \
loader.eval_batch_size=16 \
model=small \
data=openwebtext-split \
wandb.name=sedd-owt \
algo=sedd \
model.length=1024 \
eval.compute_generative_perplexity=True \
sampling.predictor=analytic
================================================
FILE: scripts/zero_shot_ar.sh
================================================
#!/bin/bash
#SBATCH -J zeroshot_ar # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=gpu # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-AgBZrc-small-ar-param-ar_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt
export HYDRA_FULL_ERROR=1
datasets=("ag_news"
"scientific_papers_pubmed"
"scientific_papers_arxiv"
"lambada"
"wikitext2"
"wikitext103"
"ptb"
"lm1b-gpt2")
for data in "${datasets[@]}"; do
echo "$data"
srun python -u -m main \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
data="$data" \
data.insert_valid_eos=False \
algo=ar \
model.length=1024 \
eval.checkpoint_path=$checkpoint_path \
+wandb.offline=true
done
================================================
FILE: scripts/zero_shot_duo.sh
================================================
#!/bin/bash
#SBATCH -J zeroshot_duo # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/flow-ode/flow-ode-VlCQLK-small-conjugate-OWT-anneal/checkpoints/last.ckpt
export HYDRA_FULL_ERROR=1
datasets=("ag_news"
"scientific_papers_pubmed"
"scientific_papers_arxiv"
"lambada"
"wikitext2"
"wikitext103"
"ptb"
"lm1b-gpt2")
for data in "${datasets[@]}"; do
echo "$data"
srun python -u -m main \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
loader.eval_global_batch_size=128 \
data="$data" \
data.insert_valid_eos=False \
model=small \
algo=duo_base \
model.length=1024 \
eval.checkpoint_path=$checkpoint_path \
sampling.num_sample_batches=0 \
+wandb.offline=true
done
================================================
FILE: scripts/zero_shot_mdlm.sh
================================================
#!/bin/bash
#SBATCH -J zeroshot_mdlm_noeos # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diff-clean-s-owt-no-t-mQ4fQG-param-subs_data-openwebtext-split/checkpoints/61-1000000.ckpt
export HYDRA_FULL_ERROR=1
datasets=("ag_news"
"scientific_papers_pubmed"
"scientific_papers_arxiv"
"lambada"
"wikitext2"
"wikitext103"
"ptb"
"lm1b-gpt2")
for data in "${datasets[@]}"; do
echo "$data"
srun python -u -m main \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
loader.eval_global_batch_size=128 \
data="$data" \
data.insert_valid_eos=False \
model=small \
algo=mdlm \
model.length=1024 \
eval.checkpoint_path=$checkpoint_path \
+wandb.offline=true
done
================================================
FILE: scripts/zero_shot_sedd.sh
================================================
#!/bin/bash
#SBATCH -J zeroshot_sedd # Job name
#SBATCH -o watch_folder/%x_%j.out # log file (out & err)
#SBATCH -N 1 # Total number of nodes requested
#SBATCH --get-user-env # retrieve the users login environment
#SBATCH --mem=32000 # server memory requested (per node)
#SBATCH -t 960:00:00 # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1 # Type/number of GPUs needed
#SBATCH --open-mode=append # Do not overwrite logs
#SBATCH --requeue # Requeue upon preemption
checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-nBm2gE-small-param-sedd_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt
export HYDRA_FULL_ERROR=1
datasets=("ag_news"
"scientific_papers_pubmed"
"scientific_papers_arxiv"
"lambada"
"wikitext2"
"wikitext103"
"ptb"
"lm1b-gpt2")
for data in "${datasets[@]}"; do
echo "$data"
srun python -u -m main \
mode=ppl_eval \
loader.batch_size=16 \
loader.eval_batch_size=16 \
loader.eval_global_batch_size=128 \
data="$data" \
data.insert_valid_eos=False \
model=small \
algo=sedd \
model.length=1024 \
eval.checkpoint_path=$checkpoint_path \
+wandb.offline=true
done
================================================
FILE: trainer_base.py
================================================
import itertools
from dataclasses import dataclass
import hydra.utils
import lightning as L
import numpy as np
import torch
import torch.nn.functional as F
import transformers
import dataloader
import metrics
import models
import utils
@dataclass
class Loss:
loss: torch.FloatTensor
nlls: torch.FloatTensor
prior_loss: torch.FloatTensor
num_tokens: torch.FloatTensor
class LogLinear(torch.nn.Module):
def __init__(self, eps):
super().__init__()
self.eps = eps # 1e-3 by default, to be consistent with SEDD: https://github.com/louaaron/Score-Entropy-Discrete-Diffusion/blob/0605786da5ccb5747545e26d66fdf477187598b6/noise_lib.py#L56
def forward(self, t):
t = (1 - self.eps) * t
alpha_t = 1 - t
dalpha_t = - torch.ones_like(alpha_t) * (1 - self.eps)
return dalpha_t, alpha_t
def get_t_for_alpha(self, alpha_t):
return 1 - alpha_t
class Cosine(torch.nn.Module):
def __init__(self, eps):
super().__init__()
self.eps = eps
self.half_pi = torch.pi / 2
def forward(self, t):
t = (1 - self.eps) * t
alpha_t = 1 - torch.cos(self.half_pi * (1 - t))
dalpha_t = - torch.sin(self.half_pi * (1 - t)) * self.half_pi
return dalpha_t, alpha_t
def get_t_for_alpha(self, alpha_t):
is_tensor = torch.is_tensor(alpha_t)
if not is_tensor:
alpha_t = torch.tensor([alpha_t])
t = 1 - 2 / torch.pi * torch.acos(1 - alpha_t)
if not is_tensor:
t = t.cpu().item()
return t
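# Inversion note for Cosine.get_t_for_alpha: forward() uses
# alpha_t = 1 - cos(pi/2 * (1 - t)), so solving for t gives
# t = 1 - (2 / pi) * arccos(1 - alpha_t). As in LogLinear.get_t_for_alpha,
# the eps rescaling of t is ignored, so this inverts the eps=0 schedule.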
def sample_categorical(categorical_probs):
gumbel_norm = (
1e-10
- (torch.rand_like(categorical_probs) + 1e-10).log())
return (categorical_probs / gumbel_norm).argmax(dim=-1)
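# Note: dividing the probabilities by i.i.d. Exp(1) noise and taking the
# argmax is the exponential-race (equivalently, Gumbel-max) trick:
# argmax_i p_i / E_i selects index i with probability proportional to p_i.
# Here -log(U), U ~ Uniform(0, 1), is Exp(1)-distributed, and the 1e-10
# offsets guard against log(0) and division by zero.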
def _unsqueeze(x, reference):
return x.view(
* x.shape,
* ((1,) * (len(reference.shape) - len(x.shape))))
class TrainerBase(L.LightningModule):
def __init__(
self,
config,
tokenizer: transformers.PreTrainedTokenizer,
vocab_size=None):
super().__init__()
self.save_hyperparameters()
self.config = config
if hasattr(self.config.algo, 'ignore_bos'):
self.ignore_bos = config.algo.ignore_bos
else:
self.ignore_bos = False
if hasattr(self.config.algo, 'loss_type'):
self.loss_type = config.algo.loss_type
self.tokenizer = tokenizer
if vocab_size is None:
self.vocab_size = len(self.tokenizer)
else:
self.vocab_size = vocab_size
self.sampler = self.config.sampling.predictor
self.antithetic_sampling = self.config.training.antithetic_sampling
self.parameterization = self.config.algo.parameterization
if self.config.algo.backbone == 'dit':
self.backbone = models.dit.DIT(
self.config, vocab_size=self.vocab_size)
elif self.config.algo.backbone == 'unet':
self.backbone = models.unet.UNet(self.config, self.vocab_size)
elif self.config.algo.backbone == 'dimamba':
self.backbone = models.dimamba.DiMamba(
self.config,
vocab_size=self.vocab_size,
pad_token_id=self.tokenizer.pad_token_id)
elif self.config.algo.backbone == 'hf_dit':
self.backbone = transformers.AutoModelForMaskedLM.from_pretrained(
config.eval.checkpoint_path, trust_remote_code=True)
self.T = self.config.algo.T
self.num_tokens = self.config.model.length
self.p_nucleus = self.config.sampling.p_nucleus
# Noise Schedule
if config.noise.type == 'log-linear':
self.noise = LogLinear(config.noise.eps)
elif config.noise.type == 'cosine':
self.noise = Cosine(config.noise.eps)
else:
raise ValueError(config.noise.type)
# Class-conditional training arguments
self.num_classes = config.data.get('num_classes', None)
self.class_conditional = self.num_classes is not None
self.class_cond_dropout = config.training.class_dropout_p
self.metrics = metrics.Metrics(
gen_ppl_eval_model_name_or_path=\
self.config.eval.gen_ppl_eval_model_name_or_path,
eval_ppl_batch_size=\
self.config.eval.perplexity_batch_size)
if self.config.training.ema > 0:
self.ema = models.ema.ExponentialMovingAverage(
self._get_parameters(),
decay=self.config.training.ema)
else:
self.ema = None
self.lr = self.config.optim.lr
self.sampling_eps = self.config.training.sampling_eps
self.time_conditioning = self.config.algo.time_conditioning
self.neg_infinity = -1000000.0
self.fast_forward_epochs = None
self.fast_forward_batches = None
def _validate_configuration(self):
assert self.config.algo.backbone in {'dit', 'hf_dit',
'unet'}
if self.config.algo.parameterization == 'ar':
assert not self.config.algo.time_conditioning
assert self.config.prior.type == 'none'
if self.time_conditioning:
assert self.config.sampling.predictor != 'ancestral_cache'
if self.parameterization in {'score', 'mean'}:
assert self.time_conditioning
if self.T > 0:
assert self.parameterization != 'score'
def to(self, *args, **kwargs):
self = super().to(*args, **kwargs)
self.metrics.to(*args, **kwargs)
return self
def q_xt(self, x, alpha_t):
raise NotImplementedError
def _get_parameters(self):
return itertools.chain(self.backbone.parameters(),
self.noise.parameters())
def _eval_mode(self):
if self.ema:
self.ema.store(self._get_parameters())
self.ema.copy_to(self._get_parameters())
self.backbone.eval()
self.noise.eval()
def _train_mode(self):
if self.ema:
self.ema.restore(self._get_parameters())
self.backbone.train()
self.noise.train()
def on_load_checkpoint(self, checkpoint):
if self.ema:
self.ema.load_state_dict(checkpoint['ema'])
# Copied from:
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
self.fast_forward_epochs = checkpoint['loops'][
'fit_loop']['epoch_progress']['current']['completed']
self.fast_forward_batches = checkpoint['loops'][
'fit_loop']['epoch_loop.batch_progress'][
'current']['completed']
def on_save_checkpoint(self, checkpoint):
if self.ema:
checkpoint['ema'] = self.ema.state_dict()
# Copied from:
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
# ['epoch_loop.batch_progress']['total']['completed']
# is 1 iteration behind, so we're using the optimizer's progress.
checkpoint['loops']['fit_loop'][
'epoch_loop.batch_progress']['total'][
'completed'] = checkpoint['loops']['fit_loop'][
'epoch_loop.automatic_optimization.optim_progress'][
'optimizer']['step']['total'][
'completed'] * self.trainer.accumulate_grad_batches
checkpoint['loops']['fit_loop'][
'epoch_loop.batch_progress']['current'][
'completed'] = checkpoint['loops']['fit_loop'][
'epoch_loop.automatic_optimization.optim_progress'][
'optimizer']['step']['current'][
'completed'] * self.trainer.accumulate_grad_batches
# _batches_that_stepped tracks the number of global steps,
# not the number of local steps, so we don't multiply with
# self.trainer.accumulate_grad_batches here.
checkpoint['loops']['fit_loop'][
'epoch_loop.state_dict'][
'_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
'epoch_loop.automatic_optimization.optim_progress'][
'optimizer']['step']['total']['completed']
if 'sampler' not in checkpoint.keys():
checkpoint['sampler'] = {}
if hasattr(self.trainer.train_dataloader.sampler,
'state_dict'):
sampler_state_dict = self.trainer.\
train_dataloader.sampler.state_dict()
checkpoint['sampler'][
'random_state'] = sampler_state_dict.get(
'random_state', None)
else:
checkpoint['sampler']['random_state'] = None
def on_train_start(self):
if self.ema:
self.ema.move_shadow_params_to_device(self.device)
# Adapted from:
# https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
distributed = (
self.trainer._accelerator_connector.use_distributed_sampler
and self.trainer._accelerator_connector.is_distributed)
if distributed:
sampler_cls = dataloader.FaultTolerantDistributedSampler
else:
sampler_cls = dataloader.RandomFaultTolerantSampler
updated_dls = []
for dl in self.trainer.fit_loop._combined_loader.flattened:
if hasattr(dl.sampler, 'shuffle'):
dl_sampler = sampler_cls(
dl.dataset, shuffle=dl.sampler.shuffle)
else:
dl_sampler = sampler_cls(dl.dataset)
if (distributed
and self.fast_forward_epochs is not None
and self.fast_forward_batches is not None):
dl_sampler.load_state_dict({
'epoch': self.fast_forward_epochs,
'counter': (self.fast_forward_batches
* self.config.loader.batch_size)})
updated_dls.append(
torch.utils.data.DataLoader(
dl.dataset,
batch_size=self.config.loader.batch_size,
num_workers=self.config.loader.num_workers,
pin_memory=self.config.loader.pin_memory,
sampler=dl_sampler,
shuffle=False,
persistent_workers=True))
self.trainer.fit_loop._combined_loader.flattened = updated_dls
def optimizer_step(self, *args, **kwargs):
super().optimizer_step(*args, **kwargs)
if self.ema:
self.ema.update(self._get_parameters())
def _process_sigma(self, sigma):
raise NotImplementedError
def _process_model_output(self, model_output, xt, sigma):
raise NotImplementedError
def forward(self, xt, sigma, labels=None, weights=None,
nn_input_idxs=None):
if nn_input_idxs is None:
nn_input_idxs = xt
sigma = self._process_sigma(sigma)
with torch.amp.autocast('cuda', dtype=torch.float32):
model_output = self.backbone(
x=nn_input_idxs, sigma=sigma, class_cond=labels,
weights=weights)
return self._process_model_output(
model_output=model_output, xt=xt, sigma=sigma)
def on_train_epoch_start(self):
self.metrics.reset()
assert self.metrics.train_nlls.nll.mean_value == 0
assert self.metrics.train_nlls.nll.weight == 0
def training_step(self, batch, batch_idx):
current_accumulation_step = (
batch_idx % self.trainer.accumulate_grad_batches)
losses = self._loss(batch['input_ids'],
batch.get('labels', None),
batch['attention_mask'],
current_accumulation_step,
train_mode=True)
self.metrics.update_train(losses.nlls, losses.prior_loss,
losses.num_tokens)
self.log(name='trainer/loss',
value=losses.loss.item(),
on_step=True,
on_epoch=False,
sync_dist=True)
return losses.loss
def on_train_epoch_end(self):
for k, v in self.metrics.train_nlls.items():
self.log(name=k, value=v.compute(), on_step=False,
on_epoch=True, sync_dist=True)
def on_validation_epoch_start(self):
self.metrics.reset()
self._eval_mode()
assert self.metrics.valid_nlls.nll.mean_value == 0
assert self.metrics.valid_nlls.nll.weight == 0
def validation_step(self, batch, batch_idx):
del batch_idx
losses = self._loss(batch['input_ids'],
batch.get('labels', None),
batch['attention_mask'])
self.metrics.update_valid(losses.nlls, losses.prior_loss,
losses.num_tokens)
return losses.loss
def on_validation_epoch_end(self):
for k, v in self.metrics.valid_nlls.items():
self.log(name=k, value=v.compute(), on_step=False,
on_epoch=True, sync_dist=True)
if ((self.config.eval.compute_perplexity_on_sanity
or not self.trainer.sanity_checking)
and self.config.eval.generate_samples):
samples, text_samples = None, None
for _ in range(
self.config.sampling.num_sample_batches):
samples = self.generate_samples(
num_samples=self.config.loader.eval_batch_size)
self.metrics.record_entropy(samples)
# Decode the samples to be re-tokenized by eval model
text_samples = self.tokenizer.batch_decode(samples)
if self.config.eval.compute_generative_perplexity:
self.metrics.record_generative_perplexity(
text_samples, self.num_tokens, self.device)
if text_samples is not None:
if self.trainer.global_rank == 0 and hasattr(
self.trainer.logger, 'log_table'):
# Log the last generated samples
text_samples = text_samples[
: self.config.sampling.num_sample_log]
self.trainer.logger.log_table(
key=f'samples@global_step{self.global_step}',
columns=['Generated Samples'],
data=[[s] for s in text_samples])
if self.config.eval.compute_generative_perplexity:
self.log('val/gen_ppl',
self.metrics.gen_ppl.compute(),
on_epoch=True,
on_step=False,
sync_dist=True)
self.log('val/sample_entropy',
self.metrics.sample_entropy.compute(),
on_epoch=True,
on_step=False,
sync_dist=True)
self._train_mode()
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self._get_parameters(),
lr=self.config.optim.lr,
betas=(self.config.optim.beta1,
self.config.optim.beta2),
eps=self.config.optim.eps,
weight_decay=self.config.optim.weight_decay)
scheduler = hydra.utils.instantiate(
self.config.lr_scheduler, optimizer=optimizer)
scheduler_dict = {'scheduler': scheduler,
'interval': 'step',
'monitor': 'val/loss',
'name': 'trainer/lr'}
return [optimizer], [scheduler_dict]
def generate_samples(self, num_samples, num_steps, eps):
raise NotImplementedError
def restore_model_and_sample(self, num_steps, eps=1e-5):
"""Generate samples from the model."""
# Lightning auto-casting is not working in this method for some reason
self._eval_mode()
samples = self.generate_samples(
num_samples=self.config.loader.eval_batch_size,
num_steps=num_steps,
eps=eps)
self._train_mode()
return samples
def _process_model_input(self, x0, valid_tokens):
raise NotImplementedError
def nll(self, input_tokens, labels, output_tokens,
current_accumulation_step=None, train_mode=False):
raise NotImplementedError
def _loss(self, x0, labels, valid_tokens,
current_accumulation_step=None,
train_mode=False):
(input_tokens, output_tokens,
valid_tokens) = self._process_model_input(
x0, valid_tokens)
loss = self.nll(input_tokens, labels, output_tokens,
current_accumulation_step, train_mode)
assert loss.ndim == 2
if self.ignore_bos:
# Exclude the BOS token from the NLL and the token count.
loss = loss[:, 1:]
valid_tokens = valid_tokens[:, 1:]
nlls = (loss * valid_tokens).sum()
num_tokens = valid_tokens.sum()
token_nll = nlls / num_tokens
return Loss(loss=token_nll,
nlls=nlls,
prior_loss=0.0,
num_tokens=num_tokens)
class Diffusion(TrainerBase):
def _validate_configuration(self):
super()._validate_configuration()
assert self.config.sampling.noise_removal in {
'none', 'ancestral', 'greedy'}
assert self.loss_type in {'elbo', 'low_var'}
if self.config.sampling.noise_removal == 'greedy':
assert self.sampler != 'analytic'
assert self.parameterization in {'mean', 'subs'}
def _process_model_input(self, x0, valid_tokens):
return x0, None, valid_tokens
def _process_sigma(self, sigma):
assert sigma.ndim == 2
sigma = sigma.mean(-1).squeeze()
if sigma.ndim == 0:
sigma = sigma.unsqueeze(0)
if not self.time_conditioning:
sigma = torch.zeros_like(sigma)
assert sigma.ndim == 1, sigma.shape
return sigma
def _sample_t(self, n, accum_step):
if accum_step is not None:
# During training
batch_dim = n
n = self.config.loader.global_batch_size
_eps_t = torch.rand(n, device=self.device)
if self.antithetic_sampling:
offset = torch.arange(n, device=self.device) / n
_eps_t = (_eps_t / n + offset) % 1
t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
if accum_step is not None:
t = t.chunk(self.trainer.num_nodes)[self.trainer.node_rank]
t = t.chunk(self.trainer.num_devices)[self.trainer.local_rank]
t = t.chunk(self.trainer.accumulate_grad_batches)[
accum_step]
# corner case for the last datapoint
t = t[:batch_dim]
return t
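# Stratification note: during training, t is stratified over the *global*
# batch (one draw per bin of width 1/global_batch_size), and the chunking
# above hands each (node, device, accumulation step) its own contiguous
# slice, so every optimizer step sees well-spread, low-variance timesteps.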
def _sigma_from_alphat(self, alpha_t):
return -torch.log(alpha_t)
def _reconstruction_loss(self, x0):
t0 = torch.zeros(x0.shape[0], 1, dtype=self.dtype,
device=self.device)
sigma_t0 = self._sigma_from_alphat(self.noise(t0)[1])
model_output_t0 = self.forward(x0, sigma_t0)
return - torch.gather(input=model_output_t0,
dim=-1,
index=x0[:, :, None]).squeeze(-1)
def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
dalpha_t, low_var):
raise NotImplementedError
def nll(self, x0, labels, output_tokens,
current_accumulation_step=None, train_mode=False):
del output_tokens
t = self._sample_t(x0.shape[0],
current_accumulation_step)
assert t.shape[0] == x0.shape[0]
if self.T > 0:
t = (t * self.T).to(torch.int)
t = t / self.T
# t \in {1/T, 2/T, ..., 1}
t += (1 / self.T)
dalpha_t, alpha_t = self.noise(t)
alpha_t = alpha_t.unsqueeze(-1)
dalpha_t = dalpha_t.unsqueeze(-1)
assert alpha_t.ndim == 2
assert dalpha_t.ndim == 2
sigma = self._sigma_from_alphat(alpha_t)
xt = self.q_xt(x0, alpha_t)
# Handle class-conditional training, with class dropout
if self.class_conditional:
assert labels is not None
rand = torch.rand(size=labels.shape, dtype=torch.float32,
device=self.device)
# The label value num_classes represents the absence of class-conditioning
labels = torch.where(rand < self.class_cond_dropout,
self.num_classes, labels)
else:
assert labels is None
log_x_theta = self.forward(xt, sigma=sigma, labels=labels)
utils.print_nans(log_x_theta, 'model_output')
return self.nll_per_token(
log_x_theta=log_x_theta,
xt=xt,
x0=x0,
alpha_t=alpha_t,
dalpha_t=dalpha_t,
low_var=train_mode and self.loss_type == 'low_var')
def _get_score(self, **kwargs):
del kwargs
raise NotImplementedError
def _denoiser_update(self, x, t):
raise NotImplementedError
def _analytic_update(self, x, t, dt):
raise NotImplementedError
def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t):
"""From the clean x0, or denoiser predictions, implement
the mathematical expression for q(z_s | z_t, x)"""
raise NotImplementedError
def _forward_process(self, q_x0, alpha_s):
"""Apply forward noising: q(x_s | x_0), where x_0 is a
K-dimensional vector"""
raise NotImplementedError
def _get_ancestral_posterior(self, xt, sigma, labels,
alpha_s, alpha_t, p_x0):
"""Top-level call from generate_samples. Compute the
standard posterior sampling distribution from the
noisy sequence xt"""
gamma = self.config.sampling.guid_weight
# If class-conditional but gamma is None or 0, this is the same as
# sampling from the class-unconditional model.
if not self.class_conditional or gamma in (None, 0.0):
if labels is not None:
labels = torch.full_like(labels, self.num_classes)
return self._get_posterior_from_xt(xt, sigma, labels,
alpha_s, alpha_t, p_x0)
elif gamma == 1.0:
# Simply sample from the cond. posterior only.
assert labels is not None
return self._get_posterior_from_xt(xt, sigma, labels,
alpha_s, alpha_t, p_x0)
else:
# Case gamma not in {0, 1}, and class conditional
# -> mix conditional and unconditional predictions.
return self._get_guided_posterior_from_xt(
xt, sigma, labels, gamma, alpha_s, alpha_t, p_x0)
def _get_posterior_from_xt(self, xt, sigma, labels, alpha_s,
alpha_t, p_x0=None):
"""From xt, compute a single denoiser predictions, cast
to correct dtype, and call _posterior_from_x0."""
if p_x0 is None:
log_x0_pred = self.forward(xt, sigma, labels)
if self.config.sampling.use_float64:
log_x0_pred = log_x0_pred.to(torch.float64)
if self.p_nucleus < 1:
log_x0_pred = utils.top_k_top_p_filtering(
log_x0_pred, top_p=self.p_nucleus)
p_x0 = log_x0_pred.exp()
return p_x0, self._posterior_from_x0(x0=p_x0, xt=xt,
alpha_s=alpha_s, alpha_t=alpha_t)
def _get_guided_posterior_from_xt(self, xt, sigma, labels,
gamma, alpha_s, alpha_t, p_x0=None):
"""From xt, combine the class cond / uncond predictions
of the denoiser. Call _get_posterior_from_xt."""
# unpack the cache
if p_x0 is None:
p_x0_cond = p_x0_uncond = None
else:
p_x0_cond, p_x0_uncond = p_x0
p_x0_cond, cond_posterior = self._get_posterior_from_xt(
xt, sigma, labels, alpha_s, alpha_t, p_x0_cond)
log_cond_posterior = cond_posterior.log()
# NOTE: conditioning on self.num_classes represents the
# class-unconditional predictions.
p_x0_uncond, uncond_posterior = self._get_posterior_from_xt(
xt, sigma, torch.full_like(labels, self.num_classes),
alpha_s, alpha_t, p_x0_uncond)
log_uncond_posterior = uncond_posterior.log()
un_normalized_posterior = gamma * log_cond_posterior \
+ (1 - gamma) * log_uncond_posterior
# Handle cases where the posterior is zero (e.g. after
# nucleus sampling)
is_inf_mask = torch.logical_or(
log_cond_posterior.isinf(), log_uncond_posterior.isinf())
un_normalized_posterior[is_inf_mask] = self.neg_infinity
return ((p_x0_cond, p_x0_uncond),
un_normalized_posterior.softmax(-1))
def _ancestral_update(self, x, t, labels, dt, p_x0=None,
noise_removal_step=False):
_, alpha_t = self.noise(t)
if noise_removal_step:
alpha_s = torch.ones_like(alpha_t)
else:
_, alpha_s = self.noise(t - dt)
sigma = self._sigma_from_alphat(alpha_t)
assert alpha_t.ndim == 2, f'{alpha_t.ndim=}'
p_x0, q_xs = self._get_ancestral_posterior(x, sigma,
labels, alpha_s, alpha_t, p_x0)
return p_x0, sample_categorical(q_xs)
def _psi_update(self, x, t, labels, dt, kappa, p_x0=None,
noise_removal_step=False):
_, alpha_t = self.noise(t)
if noise_removal_step:
alpha_s = torch.ones_like(alpha_t)
else:
_, alpha_s = self.noise(t - dt)
alpha_0 = torch.ones_like(alpha_t)
sigma = self._sigma_from_alphat(alpha_t)
# Standard posterior q(x_s | x_t)
p_x0, q_xs = self._get_ancestral_posterior(
x, sigma, labels, alpha_s, alpha_t, p_x0)
# Posterior targeting t=0, reuse predictions p_x0
_, q_x0 = self._get_ancestral_posterior(
x, sigma, labels, alpha_0, alpha_t, p_x0)
# PC: forward-noise q_x0 back to time s
pc_q_xs = self._forward_process(q_x0, alpha_s)
q_sample = kappa * q_xs + (1 - kappa) * pc_q_xs
return p_x0, sample_categorical(q_sample)
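# Psi-update note: q_xs is the standard ancestral posterior targeting time
# s, while pc_q_xs denoises all the way to t=0 (reusing the cached p_x0)
# and then forward-noises back to s, i.e. a predictor-corrector proposal;
# kappa in [0, 1] linearly mixes the two, with kappa=1 recovering plain
# ancestral sampling.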
def _get_sampling_time_profile(self, eps, num_steps):
profile = self.config.sampling.psi.time_profile
num_steps += 1
if profile == 'linear' \
or self.config.sampling.predictor != 'psi':
# Default: linearly decrease
return torch.linspace(1, eps, num_steps)
if not profile.startswith('linear-constant-linear'):
raise ValueError(profile)
c = float(profile.split('-')[3])
if 'inv' in profile:
c = self.noise.get_t_for_alpha(c)
psi_cfg = self.config.sampling.psi
n_hi = round(psi_cfg.high_frac * num_steps)
n_mid = round(psi_cfg.middle_frac * num_steps)
return torch.cat([
torch.linspace(1, c, n_hi),
torch.full((n_mid,), c),
torch.linspace(c, eps, num_steps - n_hi - n_mid)])
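# Worked example (a sketch): with profile 'linear-constant-linear-0.9-inv',
# high_frac=0.45, middle_frac=0.5 and num_steps=256, the grid has 257
# points: round(0.45 * 257) = 116 descend linearly from 1 to
# t* = get_t_for_alpha(0.9), round(0.5 * 257) = 128 hold at t*, and the
# remaining 13 descend linearly to eps.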
def _mode_to_psi_kappas(self, mode, timesteps):
n = len(timesteps)
if mode == 'pure-posterior':
return torch.ones(n)
if mode == 'pure-pc':
return torch.zeros(n)
eta = float(mode.split('-')[-1])
if (mode.startswith('constant-')
and not mode.startswith('constant-remdm')):
return torch.full((n,), eta)
# Noise-schedule-dependent modes (ReMDM-like)
_, all_alphas = self.noise(timesteps)
alpha_t, alpha_s = all_alphas[:-1], all_alphas[1:]
eta_t = torch.tensor([eta])
if mode.startswith('max-capped-'):
sigma = torch.minimum(
eta_t.expand_as(alpha_t), (1 - alpha_s) / alpha_t)
sigma = torch.where(alpha_t == 0, eta, sigma)
elif mode.startswith('max-rescale-'):
sigma_max = torch.minimum(
eta_t.expand_as(alpha_t), (1 - alpha_s) / alpha_t)
sigma = torch.where(alpha_t > 0, sigma_max, 1) * eta
elif mode.startswith('constant-remdm'):
sigma = eta_t
else:
raise ValueError(mode)
kappas = torch.clip(1 - sigma / (1 - alpha_s), 0, 1)
if len(kappas) > 0:
kappas = torch.cat([kappas, torch.ones(1)])
return kappas
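# Relation note: a ReMDM-style remasking rate sigma is converted to the
# posterior/PC mixing weight via kappa = clip(1 - sigma / (1 - alpha_s), 0, 1),
# so sigma = 0 recovers the pure posterior (kappa = 1) and larger sigma
# shifts mass onto the forward-noised PC proposal; the appended kappa = 1
# makes the final transition purely ancestral.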
def _get_kappas(self, timesteps):
cfg = self.config.sampling.psi
n = len(timesteps)
n_hi = round(cfg.high_frac * n)
n_mid = round(cfg.middle_frac * n)
kappas = torch.cat([
self._mode_to_psi_kappas(cfg.high_mode, timesteps[:n_hi]),
self._mode_to_psi_kappas(cfg.middle_mode,
timesteps[n_hi:n_hi + n_mid]),
self._mode_to_psi_kappas(cfg.low_mode,
timesteps[n_hi + n_mid:])])
assert (kappas >= 0).all() and (kappas <= 1).all()
assert len(kappas) == n
return kappas
@torch.no_grad()
def generate_samples(self, num_samples, labels=None,
num_steps=None, eps=1e-5):
"""Generate samples from the model."""
# Lightning auto-casting is not working in this method for some reason
if num_steps is None:
num_steps = self.config.sampling.steps
x = self.prior_sample(num_samples, self.num_tokens)
use_psi_sampler = self.config.sampling.predictor == 'psi'
timesteps = self._get_sampling_time_profile(eps,
num_steps)
if use_psi_sampler:
kappas = self._get_kappas(timesteps).to(self.device)
p_x0_cache = None
if labels is not None:
labels = labels.to(self.device)
for i in range(num_steps):
t = timesteps[i] * torch.ones(
x.shape[0], 1, device=self.device)
dt = timesteps[i] - timesteps[i+1]
if self.sampler == 'ancestral':
_, x = self._ancestral_update(
x=x, t=t, labels=labels, dt=dt, p_x0=None)
elif self.sampler == 'ancestral_cache':
p_x0_cache, x_next = self._ancestral_update(
x=x, t=t, labels=labels, dt=dt, p_x0=p_x0_cache)
if (not torch.allclose(x_next, x)
or self.time_conditioning):
# Disable caching
p_x0_cache = None
x = x_next
elif self.sampler == 'psi':
_, x = self._psi_update(x=x, t=t, kappa=kappas[i],
labels=labels, dt=dt,
p_x0=None)
elif self.sampler == 'analytic':
assert labels is None, 'class-conditional sampling ' \
'is not implemented with the analytic sampler'
x = self._analytic_update(x=x, t=t, dt=dt)
else:
raise ValueError(self.sampler)
t0 = timesteps[-1] * torch.ones(x.shape[0], 1,
device=self.device)
if self.config.sampling.noise_removal == 'ancestral':
if self.sampler == 'analytic':
x = self._denoiser_update(x=x, t=t0)
else:
_, x = self._ancestral_update(x=x, t=t0, labels=labels,
dt=None, p_x0=p_x0_cache, noise_removal_step=True)
elif self.config.sampling.noise_removal == 'greedy':
sigma = self._sigma_from_alphat(self.noise(t0)[1])
x = self.forward(xt=x, sigma=sigma).argmax(dim=-1)
return x
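# Minimal usage sketch (assumes a trained instance `model` of this
# class; the sample counts are illustrative):
#   tokens = model.generate_samples(num_samples=8, num_steps=1024)
#   text = model.tokenizer.batch_decode(tokens)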
@torch.no_grad()
def _semi_ar_sampler(
self, n_samples, stride_length, num_strides, dt=0.001):
# TODO(subham): Test this method after refactoring.
ones = torch.ones(n_samples, dtype=self.dtype,
device=self.device)
num_steps = int(1 / dt)
sampling_steps = 0
intermediate_tokens = []
target = None
for _ in range(num_strides + 1):
p_x0_cache = None
x = self.prior_sample(n_samples, self.num_tokens)
if target is not None:
x[:, : -stride_length] = target
for i in range(num_steps + 1):
p_x0_cache, x_next = self._ancestral_update(
x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
if (not torch.allclose(x_next, x)
or self.time_conditioning):
p_x0_cache = None
sampling_steps += 1
x = x_next
x = self.forward(x, 0 * ones).argmax(dim=-1)
intermediate_tokens.append(
x[:, :stride_length].cpu().numpy())
target = x[:, stride_length:]
intermediate_tokens.append(target.cpu().numpy())
intermediate_text_samples = []
sequence_lengths = ((
np.concatenate(intermediate_tokens, axis=1)[:, 1:]
== self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
for i in range(2, len(intermediate_tokens) + 1):
intermediate_text_samples.append(
self.tokenizer.batch_decode(
np.concatenate(intermediate_tokens[:i], axis=1)))
return (sampling_steps, intermediate_text_samples,
sequence_lengths)
def restore_model_and_semi_ar_sample(
self, stride_length, num_strides, dt=0.001):
"""Generate samples from the model."""
# Lightning auto-casting is not working in this method for some reason
# TODO(subham): Test this method after refactoring.
self._eval_mode()
(sampling_steps, samples,
sequence_lengths) = self._semi_ar_sampler(
n_samples=self.config.loader.eval_batch_size,
stride_length=stride_length,
num_strides=num_strides,
dt=dt)
self._train_mode()
return sampling_steps, samples, sequence_lengths
class AbsorbingState(Diffusion):
def __init__(self, config, tokenizer):
# NOTE: Ideally, we should do
# vocab_size = len(tokenizer), so that we account
# for the special tokens added in dataloader.py.
# But we use tokenizer.vocab_size so as to be
# consistent with the prior checkpoints.
vocab_size = tokenizer.vocab_size
if (not hasattr(tokenizer, 'mask_token')
or tokenizer.mask_token is None):
self.mask_index = vocab_size
vocab_size += 1
else:
self.mask_index = tokenizer.mask_token_id
self.subs_masking = config.algo.subs_masking
super().__init__(config, tokenizer,
vocab_size=vocab_size)
self.save_hyperparameters()
def _validate_configuration(self):
super()._validate_configuration()
if self.parameterization in {'score', 'mean'}:
assert self.time_conditioning
assert not (self.parameterization == 'mean'
and self.T == 0)
if self.T > 0:
assert self.parameterization in {'mean', 'subs'}
if self.subs_masking:
assert self.parameterization == 'mean'
def q_xt(self, x, alpha_t):
"""Computes the noisy sample xt.
Args:
x: int torch.Tensor with shape (batch_size,
diffusion_model_input_length), input.
alpha_t: float torch.Tensor with shape (batch_size, 1).
"""
move_indices = torch.rand(
*x.shape, device=x.device) < 1 - alpha_t
xt = torch.where(move_indices, self.mask_index, x)
if self.ignore_bos:
xt[:, 0] = x[:, 0]
return xt
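# E.g. with alpha_t = 0.25, each token is independently replaced by
# the mask index with probability 1 - alpha_t = 0.75; the BOS token
# is restored when ignore_bos is set.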
def prior_sample(self, *batch_dims):
return self.mask_index * torch.ones(
*batch_dims, dtype=torch.int64, device=self.device)
def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t):
"""From the clean x0, or denoiser predictions, implement
the mathematical expression for q(z_s | z_t, x)"""
assert x0.dtype == torch.float64, 'Requires float64 prec.'
# should be one-hot on clean tokens
orig_mask = xt[:, :, None] != self.mask_index
orig_mask = orig_mask.expand(-1, -1, x0.shape[-1])
orig_output_on_clean = x0[orig_mask]
q_xs = ((alpha_s - alpha_t) / (1 - alpha_t))[..., None] * x0
q_xs[..., self.mask_index] = (1 - alpha_s) / (1 - alpha_t)
q_xs[orig_mask] = orig_output_on_clean
return q_xs
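# For reference, this is the standard masked-diffusion posterior:
# for a token that is still masked at time t,
#   q(z_s = v    | z_t = MASK, x) = (alpha_s - alpha_t) / (1 - alpha_t) * x0[v]
#   q(z_s = MASK | z_t = MASK, x) = (1 - alpha_s) / (1 - alpha_t)
# and already-unmasked tokens are copied through unchanged.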
def _forward_process(self, x0, alpha_s):
out = alpha_s[..., None] * x0
out[..., self.mask_index] = 1 - alpha_s
return out
def _staggered_score(self, score, dsigma):
score = score.clone()
extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
score *= dsigma.exp()[:, None]
score[..., self.mask_index] += extra_const
return score
def _analytic_update(self, x, t, dt):
sigma_t = self._sigma_from_alphat(self.noise(t)[1])
sigma_s = self._sigma_from_alphat(self.noise(t - dt)[1])
dsigma = sigma_t - sigma_s
score = self._get_score(x, sigma_t)
if self.config.sampling.use_float64:
score = score.to(torch.float64)
stag_score = self._staggered_score(score, dsigma)
probs = stag_score * self._transp_transition(x, dsigma)
return sample_categorical(probs)
def _denoiser_update(self, x, t):
sigma = self._sigma_from_alphat(self.noise(t)[1])
score = self._get_score(x, sigma)
if self.config.sampling.use_float64:
score = score.to(torch.float64)
stag_score = self._staggered_score(score, sigma)
probs = stag_score * self._transp_transition(x, sigma)
probs[..., self.mask_index] = 0
samples = sample_categorical(probs)
return samples
def _transp_transition(self, i, sigma):
sigma = _unsqueeze(sigma, reference=i[..., None])
edge = torch.exp(-sigma) * F.one_hot(
i, num_classes=self.vocab_size)
edge += torch.where(i == self.mask_index,
1 - torch.exp(-sigma).squeeze(-1),
0)[..., None]
return edge
class UniformState(Diffusion):
def _validate_configuration(self):
super()._validate_configuration()
assert self.time_conditioning
assert self.parameterization == 'mean'
if self.config.algo.name != 'distillation':
assert self.T == 0
def _forward_process(self, x0, alpha_s):
return (alpha_s[..., None] * x0
+ (1 - alpha_s[..., None]) / self.vocab_size)
def q_xt(self, x, alpha_t):
"""Computes the noisy sample xt.
Args:
x: int torch.Tensor with shape (batch_size,
diffusion_model_input_length), input.
alpha_t: float torch.Tensor with shape
(batch_size, 1).
"""
move_indices = torch.rand(
*x.shape, device=x.device) < 1 - alpha_t
uniform_tensor = torch.randint(
0, self.vocab_size, x.shape, device=x.device)
xt = torch.where(move_indices, uniform_tensor, x)
if self.ignore_bos:
xt[:, 0] = x[:, 0]
return xt
def prior_sample(self, *batch_dims):
return torch.randint(
0, self.vocab_size, batch_dims, dtype=torch.int64,
device=self.device)
================================================
FILE: utils.py
================================================
"""Console logger utilities.
Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
"""
import argparse
import logging
import os
import pickle
import time
from typing import Optional
import fsspec
import lightning
import numpy as np
import torch
from scipy.integrate import quad
from scipy.optimize import curve_fit
from scipy.stats import norm
from timm.scheduler import CosineLRScheduler
def count_parameters(model):
return sum(p.numel()
for p in model.parameters()
if p.requires_grad)
def fsspec_exists(filename):
"""Check if a file exists using fsspec."""
fs, _ = fsspec.core.url_to_fs(filename)
return fs.exists(filename)
def fsspec_listdir(dirname):
"""Listdir in manner compatible with fsspec."""
fs, _ = fsspec.core.url_to_fs(dirname)
return fs.ls(dirname)
def fsspec_mkdirs(dirname, exist_ok=True):
"""Mkdirs in manner compatible with fsspec."""
fs, _ = fsspec.core.url_to_fs(dirname)
fs.makedirs(dirname, exist_ok=exist_ok)
def print_nans(tensor, name):
if torch.isnan(tensor).any():
print(name, tensor)
class LRHalveScheduler:
def __init__(self, warmup_steps, n_halve_steps):
self.warmup_steps = warmup_steps
self.n_halve_steps = n_halve_steps
def __call__(self, current_step):
if current_step < self.warmup_steps:
return current_step / self.warmup_steps
return 0.5 ** ((current_step - self.warmup_steps)
// self.n_halve_steps)
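# A tiny sanity-check sketch for the scheduler above (the step
# counts are illustrative, not from any config):
def _example_lr_halve_schedule():
  sched = LRHalveScheduler(warmup_steps=1000, n_halve_steps=10000)
  assert sched(500) == 0.5     # mid-warmup: linear ramp
  assert sched(1000) == 1.0    # warmup done: 0.5 ** 0
  assert sched(21000) == 0.25  # (21000 - 1000) // 10000 == 2 halvings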
class CosineDecayWarmupLRScheduler(
CosineLRScheduler,
torch.optim.lr_scheduler._LRScheduler):
"""Wrap timm.scheduler.CosineLRScheduler
Enables calling scheduler.step() without passing in epoch.
Supports resuming as well.
Adapted from:
https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._last_epoch = -1
self.step(epoch=0)
def step(self, epoch=None):
if epoch is None:
self._last_epoch += 1
else:
self._last_epoch = epoch
# We call either step or step_update, depending on
# whether we're using the scheduler every epoch or every
# step.
# Otherwise, lightning will always call step (i.e.,
# meant for each epoch), and if we set scheduler
# interval to "step", then the learning rate update will
# be wrong.
if self.t_in_epochs:
super().step(epoch=self._last_epoch)
else:
super().step_update(num_updates=self._last_epoch)
class LoggingContext:
"""Context manager for selective logging."""
def __init__(self, logger, level=None, handler=None, close=True):
self.logger = logger
self.level = level
self.handler = handler
self.close = close
def __enter__(self):
if self.level is not None:
self.old_level = self.logger.level
self.logger.setLevel(self.level)
if self.handler:
self.logger.addHandler(self.handler)
def __exit__(self, et, ev, tb):
if self.level is not None:
self.logger.setLevel(self.old_level)
if self.handler:
self.logger.removeHandler(self.handler)
if self.handler and self.close:
self.handler.close()
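# Minimal usage sketch for LoggingContext (logger name and handler
# are illustrative): the level and handler apply only inside the
# `with` block and are restored/removed on exit.
def _example_logging_context():
  logger = logging.getLogger('duo-demo')
  handler = logging.StreamHandler()
  with LoggingContext(logger, level=logging.DEBUG, handler=handler):
    logger.debug('visible only inside the context')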
class GradientInspectionCallback(lightning.Callback):
def __init__(self, num_grads_log):
self.num_grads_log = num_grads_log
def on_before_optimizer_step(self, trainer, pl_module, optimizer):
gradients = []
for name, param in pl_module.backbone.blocks.named_parameters():
if param.grad is not None:
gradients.append(param.grad.view(-1))
if gradients:
grads = torch.cat(gradients)
if not hasattr(pl_module, 'grad_accum_buffer'):
pl_module.grad_step = torch.tensor(
0, device=pl_module.device)
pl_module.grad_accum_buffer = torch.zeros(
self.num_grads_log,
grads.shape[0],
device=pl_module.device)
pl_module.grad_accum_buffer[pl_module.grad_step] = grads
pl_module.grad_step += 1
if (hasattr(pl_module, 'grad_accum_buffer')
and pl_module.grad_step == self.num_grads_log):
grads = pl_module.grad_accum_buffer
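# Std across the buffered optimizer steps, averaged over all
# gradient entries (logged below as trainer/grad_var).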
grad_var = grads.std(0).mean()
pl_module.log(name='trainer/grad_var',
value=grad_var.item(),
on_step=True,
on_epoch=False,
sync_dist=True)
# import ipdb; ipdb.set_trace()
# should save the grads tensor as a numpy array
# and visualize mean, median, top-k
pl_module.grad_accum_buffer.zero_()
pl_module.grad_step = 0
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
"""Initializes multi-GPU-friendly python logger."""
logger = logging.getLogger(name)
logger.setLevel(level)
# this ensures all logging levels get marked with the rank zero decorator
# otherwise logs would get multiplied for each GPU process in multi-GPU setup
for level_name in ('debug', 'info', 'warning', 'error',
'exception', 'fatal', 'critical'):
setattr(logger,
level_name,
lightning.pytorch.utilities.rank_zero_only(
getattr(logger, level_name)))
return logger
# Copied from https://github.com/jdeschena/sdtt/blob/bbc54d5b3c5fcffd79602cff17ed34dde1f3eff6/src/sdtt/core/sampling/utils.py#L10
def top_k_top_p_filtering(
logits,
top_k=0,
top_p=0.0,
filter_value=-float("Inf"),
dim=-1):
"""Filter a distribution of logits using top-k/top-p (nucleus) filtering.
Adapted from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
Args:
logits (Tensor): Tensor of logits
top_k (int, optional): Number of top values to keep.
Deactivated if k is 0. Defaults to 0.
top_p (float, optional): Cumulative mass to retain.
Deactivated if p = 0. Defaults to 0.0.
filter_value (float, optional): Fill value to replace
the entries removed by top-k/top-p filtering.
Defaults to -float('Inf').
dim (int, optional): Dimension of the filtering. Defaults to -1.
Returns:
logits: Tensor whose axis `dim` was filtered.
"""
if dim != -1:
logits = torch.transpose(logits, dim, -1)
assert top_k < logits.size(-1)
if top_k > 0:
# Remove all tokens with a probability less than
# the last token of the top-k
values, _ = torch.topk(logits, k=top_k, dim=-1)
to_remove_mask = (
logits < torch.min(values, dim=-1, keepdim=True)[0]
) # min returns a tuple (values, indices)
logits[to_remove_mask] = filter_value
if top_p > 0.0:
sorted_logits, sorted_indices = torch.sort(
logits, descending=True, dim=-1)
cum_probs = torch.cumsum(
torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
# Ensures at least one token is kept
sorted_indices_to_remove[..., 1:] = \
sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
mask_to_remove = torch.empty_like(
sorted_indices_to_remove)
mask_to_remove.scatter_(dim=-1,
index=sorted_indices,
src=sorted_indices_to_remove)
logits[mask_to_remove] = filter_value
if dim != -1:
logits = torch.transpose(logits, dim, -1)
# Re-normalize as in ReMDM. Alternatively,
# could apply a log-softmax. This assumes that the input
# tensor `logits` has been pre-processed with log_softmax.
probs = logits.exp()
Z = probs.sum(-1, keepdim=True)
logits = (probs / Z).log()
return logits
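# Usage sketch (shapes and thresholds are illustrative; per the
# comment above, the input is expected to be log-softmaxed, and the
# function modifies its argument in place, hence the clone):
def _example_top_k_top_p():
  logits = torch.randn(2, 8).log_softmax(dim=-1)
  filtered = top_k_top_p_filtering(logits.clone(), top_k=3, top_p=0.9)
  # Entries outside the top-k / nucleus are -inf; the rest are
  # re-normalized log-probabilities.
  return filtered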
def _discrete_prob_map(gamma_t, N=10):
snr_sqrt = np.exp(-gamma_t / 2)
def value(x):
cdf = norm.cdf(x, scale=1) ** (N - 1)
pdf = norm.pdf(x, loc=snr_sqrt, scale=1)
return pdf * cdf
return value
def _discrete_prob_grad(gamma_t, N=10):
snr_sqrt = np.exp(-gamma_t / 2)
def value(x):
coef = -0.5 * snr_sqrt * (x - snr_sqrt)
cdf = norm.cdf(x, scale=1) ** (N - 1)
pdf = norm.pdf(x, loc=snr_sqrt, scale=1)
return coef * pdf * cdf
return value
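# Interpretation: integrating _discrete_prob_map over the real line
# gives the probability that a Gaussian centred at sqrt(SNR) =
# exp(-gamma_t / 2) exceeds N - 1 independent standard normals,
# i.e. the chance that the clean token wins the Gaussian race at
# noise level gamma_t; _discrete_prob_grad is the corresponding
# d/d(gamma_t) integrand.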
def _cache_prob_usdm_in_partition(
vocab_size=30522, partition_index=0, num_partitions=1,
log10_num_points=5):
print(f'Caching partition:{partition_index} / {num_partitions}')
path = 'integral'
gamma_min = -5
gamma_max = -1
num_points = 10 ** log10_num_points
p_cache = []
grad_p_cache = []
start_time = time.time()
gammas = np.linspace(gamma_min, gamma_max, num_points)
n = num_points // num_partitions
for gamma in gammas[partition_index * n:
(partition_index + 1) * n]:
pt, _ = quad(_discrete_prob_map(gamma, vocab_size),
-np.inf, np.inf)
p_cache.append(pt)
grad_pt, _ = quad(_discrete_prob_grad(gamma, vocab_size),
-np.inf, np.inf)
grad_p_cache.append(grad_pt)
if len(p_cache) % 100 == 0:
print('{}% completed. Time elapsed:{:.2f} mins'.format(
int(100 * len(p_cache) / num_points),
(time.time() - start_time) / 60))
filename = os.path.join(
path, '{}_{}_{}-{}.pkl'.format(
vocab_size, log10_num_points, partition_index,
num_partitions))
with open(filename, 'wb') as f:
pickle.dump({
'vocab_size': vocab_size,
'gamma_min': gamma_min,
'gamma_max': gamma_max,
'num_points': num_points,
'pt': np.asarray(p_cache),
'grad_pt': np.asarray(grad_p_cache)}, f)
def test_cache_prob_usdm_in_partition(
partition_index=0, num_partitions=1, vocab_size=30522,
log10_num_points=5):
path = 'integral/{}_{}_{}-{}.pkl'.format(
vocab_size, log10_num_points, partition_index,
num_partitions)
with open(path, 'rb') as f:
data = pickle.load(f)
num_points = data['num_points']
def _get_index(x):
return round((num_points - 1) * (x - data['gamma_min']) / (
data['gamma_max'] - data['gamma_min']))
pt_errors = []
grad_pt_errors = []
gammas = np.linspace(data['gamma_min'],
data['gamma_max'],
num_points)
n = num_points // num_partitions
for gamma in gammas[partition_index * n:
(partition_index + 1) * n]:
pt, _ = quad(
_discrete_prob_map(gamma, data['vocab_size']),
-np.inf, np.inf)
grad_pt, _ = quad(
_discrete_prob_grad(gamma, data['vocab_size']),
-np.inf, np.inf)
idx = _get_index(gamma)
print(idx)
pt_errors.append((pt - data['pt'][idx]) ** 2)
grad_pt_errors.append((grad_pt - data['grad_pt'][idx]) ** 2)
print('Integral MSE:{} Integral Squared:{:.4f}'.format(
np.mean(pt_errors), np.mean(data['pt'] ** 2)))
print('Integral Grad MSE:{} Integral Grad Squared:{:.4f}'.format(
np.mean(grad_pt_errors), np.mean(data['grad_pt'] ** 2)))
def compute_duo_series_coefficients(num_coefficients,
vocab_size):
def integrand_m(z, n, K):
z = np.float64(z)
return z**n * norm.pdf(z) * norm.cdf(z)**(K-1)
def integrand_i(z, n, K):
z = np.float64(z)
return z ** (n+1) * norm.pdf(z) * norm.cdf(z) ** (K-1)
arange = np.cumprod(np.arange(1, num_coefficients
).astype(np.float64))
factorials = np.concatenate([[1.0], arange],
dtype=np.float64)
# quad expects scalar bounds, not 1-element arrays.
lo = np.float64(-100)
hi = np.float64(100)
coefficients_m = []
coefficients_i = []
for n in range(num_coefficients):
f = lambda z: integrand_m(np.float64(z), np.float64(n),
vocab_size)
g = lambda z: integrand_i(np.float64(z), np.float64(n),
vocab_size)
m, _ = quad(f, lo, hi)
i, _ = quad(g, lo, hi)
coefficients_m.append(m / factorials[n])
coefficients_i.append(i / factorials[n])
return (torch.tensor(coefficients_m)[None],
torch.tensor(coefficients_i)[None])
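# For reference, the coefficients computed above are
#   m_n = (1 / n!) * Int z**n     * pdf(z) * cdf(z)**(K - 1) dz
#   i_n = (1 / n!) * Int z**(n+1) * pdf(z) * cdf(z)**(K - 1) dz
# which compute_duo_gamma_to_alpha_dalpha_series consumes as a power
# series in mu_t.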
def compute_duo_gamma_to_alpha_dalpha_series(
gamma_t, coefficients_m, coefficients_i, power_arange,
vocab_size, gamma_min, gamma_max):
gamma_t = gamma_t.to(torch.float64)[:, None]
sigmoid_neg_gamma = torch.sigmoid(-gamma_t)
sigmoid_gamma = torch.sigmoid(gamma_t)
alpha_t_squared = sigmoid_neg_gamma
alpha_t = alpha_t_squared.sqrt()
one_minus_alpha_t_squared = 1 - alpha_t_squared
mu_t = alpha_t / one_minus_alpha_t_squared.sqrt()
arange = power_arange.to(device=gamma_t.device)
mu_t_pow = mu_t ** arange
exp_term = (-mu_t**2/2).exp().squeeze(-1)
vocab_scale = vocab_size / (vocab_size - 1)
# Compute alpha
sum_term_alpha = (mu_t_pow * coefficients_m).sum(-1)
alpha_usdm = (sum_term_alpha * exp_term
- 1 / vocab_size) * vocab_scale
# Compute alpha'
sum_term_dalpha = (
mu_t_pow * (coefficients_i - mu_t * coefficients_m)).sum(-1)
dalpha_usdm = exp_term * sum_term_dalpha * vocab_scale
final_scale = - (sigmoid_gamma.squeeze(-1)
* sigmoid_neg_gamma.squeeze(-1) ** 0.5 *
0.5 * (gamma_max - gamma_min))
dalpha_usdm = dalpha_usdm \
/ one_minus_alpha_t_squared.squeeze(-1) ** 1.5 \
* final_scale
return alpha_usdm.squeeze(-1), dalpha_usdm.squeeze(-1)
def duo_t_to_alpha_dalpha_sigm_corrected(
t, a: float, b: float, c: float, d: float, e: float,
f: float, alpha: float):
# Shared quantities
sigm_bc = (torch.tanh(b * t + c) + 1) / 2
sigm_ef = (torch.tanh(e * t + f) + 1) / 2
# Compute alpha_t
base = a * sigm_bc + d
edge_gate = 1 - 4 * sigm_ef * (1 - sigm_ef)
edge_correction = alpha * (t - 0.5) * edge_gate
alpha_t = base + edge_correction
# Compute d_alpha_t
dbase = a * b * sigm_bc * (1 - sigm_bc)
dgate = -4 * e * sigm_ef * (1 - sigm_ef) * (1 - 2 * sigm_ef)
dcorrection = alpha * edge_gate + alpha * (t - 0.5) * dgate
dalpha_t = dbase + dcorrection
return alpha_t, dalpha_t
def duo_to_alpha_dalpha_sigmoid(t: torch.Tensor, a: float,
b: float, c: float, d: float):
sigm_bc = (torch.tanh(b * t + c) + 1) / 2
alpha_t = a * sigm_bc + d
dalpha_t = a * b * sigm_bc * (1 - sigm_bc)
return alpha_t, dalpha_t
def duo_to_alpha_dalpha_poly(t: torch.Tensor,
*coefficients: float):
alpha_t = coefficients[0] # a0 term
for i, a in enumerate(coefficients[1:], 1):
alpha_t = alpha_t + a * t**i
dalpha_t = coefficients[1] # a1 term
for i, a in enumerate(coefficients[2:], 2):
dalpha_t = dalpha_t + i * a * t**(i-1)
return alpha_t, dalpha_t
def compute_duo_operator_approx(num_coefficients, vocab_size,
gamma_min, gamma_max,
fct_name='sigmoid'):
series_m, series_i = compute_duo_series_coefficients(
num_coefficients, vocab_size)
ts = torch.linspace(0, 1, steps=100_000)
gammas = gamma_min + ts * (gamma_max - gamma_min)
power_arange = torch.arange(num_coefficients,
dtype=torch.float64)[None]
alpha_approx = compute_duo_gamma_to_alpha_dalpha_series(
gammas, series_m, series_i, power_arange, vocab_size,
gamma_min, gamma_max)[0].float()
t_np = ts.numpy()
y_np = alpha_approx.numpy()
def sigmoid(x):
return (np.tanh(x) + 1) / 2
if fct_name == 'sigmoid':
def func(t, a, b, c, d):
return a * sigmoid(b * t + c) + d
p0 = [0.5, 2.0, -1.0, 0.5]
elif fct_name == 'sigmoid-edge-corrected':
def func(t, a, b, c, d, e, f, alpha):
base = a * sigmoid(b * t + c) + d
edge_gate = 1 - 4 * sigmoid(e * t + f) \
* sigmoid(-e * t - f)
edge_correction = alpha * (t - 0.5) * edge_gate
return base + edge_correction
p0 = [0.5, 2.0, -1.0, 0.1, 3.0, 0.0, 0.1]
elif fct_name == 'poly3':
def func(t, a0, a1, a2, a3):
return a0 + a1*t + a2*t**2 + a3*t**3
p0 = [0.1] * 4
elif fct_name == 'poly5':
def func(t, a0, a1, a2, a3, a4, a5):
return a0 + a1*t + a2*t**2 + a3*t**3 + a4*t**4 + a5*t**5
p0 = [0.1] * 6
elif fct_name == 'poly7':
def func(t, a0, a1, a2, a3, a4, a5, a6, a7):
return (a0 + a1*t + a2*t**2 + a3*t**3 +
a4*t**4 + a5*t**5 + a6*t**6 + a7*t**7)
p0 = [0.1] * 8
elif fct_name == 'poly9':
def func(t, a0, a1, a2, a3, a4, a5, a6, a7, a8, a9):
return (a0 + a1*t + a2*t**2 + a3*t**3 + a4*t**4 +
a5*t**5 + a6*t**6 + a7*t**7 + a8*t**8 + a9*t**9)
p0 = [0.1] * 10
else:
raise ValueError(fct_name)
popt, _ = curve_fit(func, t_np, y_np, p0=p0, maxfev=10000)
preds = func(t_np, *popt)
return list(popt), y_np, preds, t_np
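# Usage sketch (arguments are illustrative; the fit evaluates the
# series on a 100k-point grid, so it can take a while):
#   popt, target, preds, ts = compute_duo_operator_approx(
#       num_coefficients=100, vocab_size=50257,
#       gamma_min=-5, gamma_max=-1, fct_name='sigmoid')
#   # popt holds the fitted (a, b, c, d) of a * sigmoid(b*t + c) + d.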
def _sample_k_int(bs: int, l: int, k: int, max_value: int,
device: torch.device):
# Robert Floyd's algorithm:
# https://www.nowherenearithaca.com/2013/05/robert-floyds-tiny-and-beautiful.html
out = torch.empty(size=(bs, l, k), dtype=torch.int64,
device=device)
for t, i in enumerate(range(max_value - k, max_value)):
j = torch.randint(0, i + 1, size=(bs, l), device=device)
if t > 0:
# Does j already appear in previously chosen slots?
dup = (out[..., :t] == j[..., None]).any(dim=-1)
# write j where it is new, otherwise write i
out[..., t] = torch.where(dup, i, j)
else:
out[..., 0] = j
return out
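# Sanity-check sketch for the Floyd sampler above (sizes are
# illustrative): every (bs, l) slot gets k distinct ints in
# [0, max_value).
def _example_sample_k_int():
  out = _sample_k_int(bs=2, l=4, k=3, max_value=10,
                      device=torch.device('cpu'))
  assert out.shape == (2, 4, 3)
  assert (out >= 0).all() and (out < 10).all()
  sorted_out, _ = out.sort(dim=-1)
  # Floyd's invariant: no duplicates along the last axis.
  assert (sorted_out[..., 1:] != sorted_out[..., :-1]).all()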
def _sample_topk_gaussian(N: int,
sigma: Optional[torch.Tensor]=None, l: int=0, k: int=0,
batch: int=None, device: str=None,
dtype: torch.dtype=torch.float64):
"""
Sample the top-k order statistics of N zero-mean Gaussians.
Operates in log space for stability.
"""
if sigma is None:
assert batch is not None
assert device is not None
assert dtype is not None
else:
batch = sigma.shape[0]
device = sigma.device
dtype = sigma.dtype
log_u = torch.log(torch.rand(batch, l, k, device=device,
dtype=dtype))
divisors = torch.arange(N, N - k, -1, device=device,
dtype=dtype) # (k,)
log_rj = log_u / divisors # (batch, l, k)
log_prod = torch.cumsum(log_rj, dim=-1) # (batch, l, k)
uniforms = torch.exp(log_prod) # (batch, l, k)
# convert to Gaussian and rescale
topk = torch.special.ndtri(uniforms)
if sigma is not None:
topk = topk * sigma[:, None, None]
return topk
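# The uniforms above come from the standard descending-order-statistics
# recursion: with U_j ~ Uniform(0, 1),
#   V_1 = U_1 ** (1 / N),  V_j = V_{j-1} * U_j ** (1 / (N - j + 1)),
# so V_1 >= V_2 >= ... are the top-k uniform order statistics;
# ndtri then maps them to Gaussian order statistics, scaled by sigma
# when it is provided.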
def _sample_topk_and_extra(N: int, alpha: torch.Tensor,
sigma: torch.Tensor, l: int, k: int):
"""
Sample the top-k order statistics among N - 1 zero-mean
Gaussians plus a single Gaussian with mean alpha.
"""
top_k_others = _sample_topk_gaussian(N - 1, sigma, l, k)
extra = alpha[:, None] + torch.randn(
size=(alpha.shape[0], l), device=alpha.device
) * sigma[:, None] # (bs, l)
min_values = top_k_others[:, :, -1]
is_extra_in_topk = (extra > min_values) # bs x l
top_k_others[:, :, -1][is_extra_in_topk] = extra[is_extra_in_topk]
return extra, top_k_others, is_extra_in_topk
def _log_mean_exp_trunc_normal(c: torch.Tensor,
sigma: torch.Tensor):
"""
Compute log(E[exp(X) | X < c]) for X ~ N(0, sigma^2).
Closed-form expression:
mu = exp(sigma**2 / 2)
* Phi((c - sigma**2) / sigma)
/ Phi(c / sigma)
where Phi is the standard normal CDF. Operate in log space
for stability.
"""
log_num = torch.special.log_ndtr((c - sigma**2) / sigma)
log_den = torch.special.log_ndtr(c / sigma)
return sigma**2 / 2.0 + log_num - log_den
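# A Monte-Carlo cross-check sketch for the closed form above (c and
# sigma are illustrative; the tolerance is deliberately loose):
def _example_trunc_normal_mgf():
  c = torch.tensor([0.5], dtype=torch.float64)
  sigma = torch.tensor([1.0], dtype=torch.float64)
  analytic = _log_mean_exp_trunc_normal(c, sigma)
  x = torch.randn(1_000_000, dtype=torch.float64) * sigma
  x = x[x < c]
  empirical = x.exp().mean().log()
  assert torch.allclose(analytic, empirical, atol=1e-2)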
def sample_tempered_softmax_topk(
extra_index: torch.Tensor,
alpha: torch.Tensor,
sigma: torch.Tensor,
l: int,
k: int,
vocab_size: int,
# 1 / T. If low temperature, inverse will be large
inverse_temperature: float = 1.0):
assert alpha.ndim == 1
assert sigma.ndim == 1
# float64 needed for numerical precision
alpha = alpha.to(torch.float64)
sigma = sigma.to(torch.float64)
# Sample the top k between (vocab_size - 1) zero-mean
# Gaussians, and a single Gaussian with mean alpha.
extra, top_k, is_extra_in_topk = _sample_topk_and_extra(
vocab_size, alpha, sigma, l, k)
min_rv = torch.where(is_extra_in_topk, top_k[:, :, -2],
top_k[:, :, -1]) # (bs, l)
scaled_sigma = sigma[:, None] * inverse_temperature # (bs, 1)
scaled_c = min_rv * inverse_temperature # (bs, l)
log_mu = _log_mean_exp_trunc_normal(scaled_c, scaled_sigma)
log_topk = top_k * inverse_temperature
# Approximate contribution of the unknown variables
count = torch.where(is_extra_in_topk, vocab_size - k,
vocab_size - k - 1).to(log_mu.dtype)
log_tail = torch.log(count) + log_mu
# Contribution of the extra variable, when NOT in the topk
log_extra = extra * inverse_temperature
extra_not_in_topk = ~is_extra_in_topk
log_extra_masked = torch.full_like(log_tail, float('-inf'))
log_extra_masked[extra_not_in_topk] = \
log_extra[extra_not_in_topk]
log_contribs = torch.cat([
log_topk,
log_tail[..., None],
log_extra_masked[..., None]],
dim=-1) # (bs, l, k+2)
log_denom = torch.logsumexp(log_contribs, dim=-1,
keepdim=True) # (bs, l, 1)
softmax_approx = torch.exp(log_topk - log_denom) # (bs, l, k)
# If the sum over the k categories is zero, fall back to a
# one-hot on the largest entry (avoids division by zero).
normalizer = softmax_approx.sum(dim=-1, keepdim=True)
zero_sum = normalizer == 0.0
softmax_approx = torch.where(zero_sum, 0.0, softmax_approx)
softmax_approx[..., 0][zero_sum[..., 0]] = 1.0
indices = _sample_k_int(alpha.shape[0], l, k,
# Note the -1:
max_value=vocab_size - 1,
device=alpha.device)
# Ensure x0 (true token) is not generated
indices[indices >= extra_index[..., None]] += 1
indices[..., -1][is_extra_in_topk] = \
extra_index[is_extra_in_topk]
xt_usdm = torch.where(is_extra_in_topk, extra_index,
indices[..., 0])
return softmax_approx, indices, xt_usdm
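# Shapes of the return values above: softmax_approx is (bs, l, k),
# indices is (bs, l, k) candidate token ids (with the true token
# spliced into the last slot when it reaches the top-k), and
# xt_usdm is (bs, l).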
if __name__ == "__main__":
# Usage: python utils.py --vocab_size=N
parser = argparse.ArgumentParser(
description='Caches the integral appearing in the '
'Diffusion Transformation operator.')
parser.add_argument(
'--vocab_size',
type=int,
default=50257, # For the gpt2 tokenizer
help='Vocabulary size (default: 50257)')
parser.add_argument(
'--partition_index',
type=int,
default=0,
help='Helps parallelize caching')
parser.add_argument(
'--num_partitions',
type=int,
default=1,
help='Helps parallelize caching')
parser.add_argument(
'--log10_num_points',
type=int,
default=5,
help=('The integral is a function that needs to be '
'evaluated for inputs in the range [-5, 1]. '
'This argument is the base-10 logarithm of the '
'number of discretization bins.'))
args = parser.parse_args()
# Computing the integral over [-5, 1] can be slow,
# so one might prefer splitting it into `num_partitions`
# bins, computing each separately, and merging them later.
_cache_prob_usdm_in_partition(
partition_index=args.partition_index,
num_partitions=args.num_partitions,
vocab_size=args.vocab_size,
log10_num_points=args.log10_num_points)
test_cache_prob_usdm_in_partition(
partition_index=args.partition_index,
num_partitions=args.num_partitions,
vocab_size=args.vocab_size,
log10_num_points=args.log10_num_points)