Repository: s-sahoo/duo
Branch: main
Commit: 492505208b36
Files: 110
Total size: 270.1 KB

Directory structure:
gitextract_9im7g5qx/

├── .gitignore
├── LICENSE
├── README.md
├── algo.py
├── configs/
│   ├── algo/
│   │   ├── ar.yaml
│   │   ├── d3pm.yaml
│   │   ├── distillation.yaml
│   │   ├── duo.yaml
│   │   ├── duo_base.yaml
│   │   ├── mdlm.yaml
│   │   ├── ot-finetune.yaml
│   │   └── sedd.yaml
│   ├── callbacks/
│   │   ├── checkpoint_every_n_steps.yaml
│   │   ├── checkpoint_monitor.yaml
│   │   ├── grad_record.yaml
│   │   └── learning_rate_monitor.yaml
│   ├── config.yaml
│   ├── data/
│   │   ├── ag_news.yaml
│   │   ├── cifar10.yaml
│   │   ├── fineweb-edu.yaml
│   │   ├── lambada.yaml
│   │   ├── lm1b-gpt2.yaml
│   │   ├── lm1b-streaming.yaml
│   │   ├── lm1b-wrap.yaml
│   │   ├── lm1b.yaml
│   │   ├── openwebtext-split.yaml
│   │   ├── openwebtext-streaming.yaml
│   │   ├── openwebtext.yaml
│   │   ├── ptb.yaml
│   │   ├── scientific_papers_arxiv.yaml
│   │   ├── scientific_papers_pubmed.yaml
│   │   ├── synthetic.yaml
│   │   ├── text8-crop.yaml
│   │   ├── text8.yaml
│   │   ├── wikitext103.yaml
│   │   └── wikitext2.yaml
│   ├── lr_scheduler/
│   │   ├── constant_warmup.yaml
│   │   ├── cosine_decay_warmup.yaml
│   │   └── step_scheduler.yaml
│   ├── model/
│   │   ├── medium.yaml
│   │   ├── small.yaml
│   │   ├── tiny-dimamba.yaml
│   │   ├── tiny.yaml
│   │   └── unet.yaml
│   ├── noise/
│   │   ├── cosine.yaml
│   │   └── log-linear.yaml
│   ├── prior/
│   │   └── none.yaml
│   └── strategy/
│       ├── ddp.yaml
│       └── fsdp.yaml
├── dataloader.py
├── discrete_diffusion_harness.py
├── integral/
│   ├── bert-base-uncased.pkl
│   └── gpt2.pkl
├── main.py
├── metrics.py
├── models/
│   ├── __init__.py
│   ├── dit.py
│   ├── ema.py
│   ├── unet.py
│   └── unit_test_attention.py
├── requirements.txt
├── scripts/
│   ├── distil_owt.sh
│   ├── eval_lm1b_duo.sh
│   ├── eval_owt_ar.sh
│   ├── eval_owt_duo.sh
│   ├── eval_owt_mdlm.sh
│   ├── eval_owt_sedd.sh
│   ├── fid_cifar10_duo_ancestral_cosine.sh
│   ├── fid_cifar10_duo_base_ancestral_cosine.sh
│   ├── fid_cifar10_mdlm_ancestral_cosine.sh
│   ├── gen_ppl_lm1b_ar.sh
│   ├── gen_ppl_lm1b_duo.sh
│   ├── gen_ppl_owt_ar.sh
│   ├── gen_ppl_owt_duo.sh
│   ├── gen_ppl_owt_mdlm.sh
│   ├── gen_ppl_owt_sedd.sh
│   ├── psi_samplers/
│   │   ├── cifar10/
│   │   │   ├── duo_constant_remdm.sh
│   │   │   ├── duo_max_capped_remdm.sh
│   │   │   ├── duo_max_rescale_eta.sh
│   │   │   ├── duo_psi_pc.sh
│   │   │   ├── mdlm_constant_remdm.sh
│   │   │   ├── mdlm_max_capped_remdm.sh
│   │   │   ├── mdlm_max_rescale_eta.sh
│   │   │   └── mdlm_psi_pc.sh
│   │   └── owt/
│   │       ├── duo_loop_remdm.sh
│   │       ├── duo_max_capped_remdm.sh
│   │       ├── duo_max_rescale_eta.sh
│   │       ├── mdlm_loop_remdm.sh
│   │       ├── mdlm_max_capped_remdm.sh
│   │       └── mdlm_max_rescale_eta.sh
│   ├── train_cifar10_duo_base_cosine.sh
│   ├── train_cifar10_duo_cosine.sh
│   ├── train_cifar10_mdlm_cosine.sh
│   ├── train_lm1b_ar.sh
│   ├── train_lm1b_ar_sentencepacking.sh
│   ├── train_lm1b_d3pm.sh
│   ├── train_lm1b_duo.sh
│   ├── train_lm1b_duo_sentencepacking.sh
│   ├── train_lm1b_mdlm.sh
│   ├── train_lm1b_mdlm_sentencepacking.sh
│   ├── train_owt_duo.sh
│   ├── train_owt_duo_finetune.sh
│   ├── train_owt_mdlm.sh
│   ├── train_owt_sedd.sh
│   ├── zero_shot_ar.sh
│   ├── zero_shot_duo.sh
│   ├── zero_shot_mdlm.sh
│   └── zero_shot_sedd.sh
├── trainer_base.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.DS_Store

.hf_cache
test/
outputs/
wandb/
watch_folder/
notes.md
grid_search.sh
*.ipynb

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2025 Subham Sekhar Sahoo

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<div align="center">

# The Diffusion Duality Series

</div>

## [Chapter I (ICML 2025)](https://arxiv.org/abs/2506.10892)

By [Subham Sekhar Sahoo](https://s-sahoo.github.io), [Justin Deschenaux](https://jdeschena.com), [Aaron Gokaslan](https://skylion007.github.io),
[Guanghan Wang](https://tech.cornell.edu/people/guanghan-wang/), [Justin Chiu](https://justinchiu.netlify.app), [Volodymyr Kuleshov](https://www.cs.cornell.edu/~kuleshov/)

[![GitHub](https://img.shields.io/badge/GitHub-Repo-181717?logo=github&logoColor=white)](https://github.com/s-sahoo/duo/tree/ch-1)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Sf7R-dqdR6gq-H8nyZ9E3ZkyvqMTqcwq?usp=sharing)
[![YouTube](https://img.shields.io/badge/YouTube-%23FF0000.svg?logo=YouTube&logoColor=white)](https://youtu.be/FCO-nnqHOqQ?si=4eGnj5zbRgyCYWwI)
[![deploy](https://img.shields.io/badge/Blog%20%20-8A2BE2)](http://s-sahoo.github.io/duo)
[![arXiv](https://img.shields.io/badge/arXiv-2506.10892-red.svg)](https://arxiv.org/abs/2506.10892)
[![deploy](https://img.shields.io/badge/🤗-Huggingface-blue)](https://huggingface.co/collections/s-sahoo/duo-67f9ff8fde919224e5fbd875)

**Unlocks few-step generation in discrete diffusion-LLMs via the underlying Gaussian diffusion.**

<div align="center">
  <img src="https://github.com/s-sahoo/duo/blob/gh-pages/static/images/duo_schematic.png" width="60%">
</div>

## [Chapter II: Ψ-Samplers and Efficient Curriculum (ICLR 2026)](https://arxiv.org/abs/2602.21185)
By  [Justin Deschenaux](https://jdeschena.com), [Caglar Gulcehre](https://www.caglar.ai),
[Subham Sekhar Sahoo](https://s-sahoo.github.io)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1uFSzrfG0KXhGcohRIfWIM2Y7V9Q7cQNA?usp=sharing)
[![deploy](https://img.shields.io/badge/Blog%20%20-8A2BE2)](http://s-sahoo.github.io/duo-ch2)
[![arXiv](https://img.shields.io/badge/arXiv-2602.21185-red.svg)](https://arxiv.org/abs/2602.21185)
<!-- [![deploy](https://img.shields.io/badge/🤗-Huggingface-blue)](https://huggingface.co/collections/s-sahoo/duo-67f9ff8fde919224e5fbd875) -->

**Uniform-state diffusion beats masked diffusion on text and image generation!**
<div align="center">
  <img src="https://github.com/s-sahoo/duo-ch2/blob/gh-pages/static/images/psi-samplers-figure.png" width="90%">
</div>

This repository contains the code for the two papers in the Diffusion Duality series. It includes:
- **Duo / $\text{Duo}^\text{++}$** sampling (ancestral, ReMDM, $\Psi$-samplers, greedy-tail) — [Sampling & Eval](#sampling--eval)
- Original and efficient curriculum training strategies — [Training](#training)
- Discrete Consistency Distillation (DCD) — [Distillation](#discrete-consistency-distillation)
- Baselines (AR, MDLM, SEDD, D3PM) — [Baselines](#baselines)

[Getting Started](#getting-started) | [Checkpoints](#checkpoints) | [Citation](#acknowledgements--citation)

# Getting Started
<a name="getting_started"></a>

To get started, create a conda environment containing the required dependencies.

```bash
conda create -n duo python=3.12
conda activate duo
conda install nvidia/label/cuda-12.4.0::cuda-toolkit
pip install -r requirements.txt
pip install flash_attn==2.7.4.post1
```

# Checkpoints
<a name="checkpoints"></a>

* **Duo** (Language Modeling): Trained on OpenWebText for `1M` training steps (distilled / base):
  * [Huggingface](https://huggingface.co/collections/s-sahoo/duo-67f9ff8fde919224e5fbd875)🤗.
  * [Google Drive folder](https://drive.google.com/drive/folders/1JpqFM8XRvifwIkjWPfMyuDvu41r1yk0t?usp=share_link) with the raw checkpoints; use these for finetuning, since the HF checkpoints can't be finetuned.
* **Duo** (Image Modeling): Trained on CIFAR-10:
  * [Huggingface (contains the raw checkpoints)](https://huggingface.co/jdeschena/duo2-cifar10)
* **Baselines** (SEDD, MDLM, AR): Trained on OpenWebText
  * [Google Drive folder](https://drive.google.com/drive/folders/16LuuptK7Xfk-vzhQYZBZ0SA-B-BFluau?usp=sharing) — download `ar.ckpt`, `mdlm.ckpt`, `sedd.ckpt`.

# Training
<a name="training"></a>

This repo implements both the original Duo curriculum and the faster $\text{Duo}^\text{++}$ curriculum. By default, the training scripts use the original curriculum. To enable the efficient curriculum, replace `algo.curriculum.mode=simple` with `algo.curriculum.mode=poly9` (see the comments in each training script).

To train $\text{Duo}^\text{++}$, use the following scripts:
* LM1B
  * w/ sentence packing (same as in D3PM)
    * Training script: [`scripts/train_lm1b_duo_sentencepacking.sh`](./scripts/train_lm1b_duo_sentencepacking.sh)
    * [Wandb run](https://api.wandb.ai/links/kuleshov-group/huwt0ek3)
  * w/o sentence packing (same as in MDLM, SEDD)
    * Training script: [`scripts/train_lm1b_duo.sh`](./scripts/train_lm1b_duo.sh)
    * [Wandb run](https://api.wandb.ai/links/sahoo-diffusion/lkv5z3tm)
    
* OWT: [`scripts/train_owt_duo.sh`](./scripts/train_owt_duo.sh).
* CIFAR-10:
  * Duo: [`scripts/train_cifar10_duo_cosine.sh`](./scripts/train_cifar10_duo_cosine.sh)
  * MDLM: [`scripts/train_cifar10_mdlm_cosine.sh`](./scripts/train_cifar10_mdlm_cosine.sh)
  * Both scripts default to a cosine noise schedule. To use log-linear instead, set `noise=log-linear`.
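
The two noise schedules differ only in how fast the signal level $\alpha_t$ decays as $t$ goes from 0 to 1. The sketch below is illustrative only: it assumes the log-linear schedule satisfies $\alpha_t = 1 - (1 - \epsilon)t$ (consistent with the `t = 1 - alpha_t` comment in `D3PMAbsorb.nll_per_token`) and assumes a cosine schedule of the form $\alpha_t = \cos(\pi t / 2)$. The exact parameterization in `configs/noise/cosine.yaml` may differ, and both function names are hypothetical.

```python
import math


def alpha_log_linear(t: float, eps: float = 0.0) -> float:
    """Assumed log-linear schedule: alpha_t = 1 - (1 - eps) * t."""
    return 1.0 - (1.0 - eps) * t


def alpha_cosine(t: float) -> float:
    """Assumed cosine schedule: alpha_t = cos(pi * t / 2)."""
    return math.cos(0.5 * math.pi * t)


# Compare how much signal each schedule keeps along the trajectory.
for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"t={t:.2f}  log-linear={alpha_log_linear(t):.3f}  "
          f"cosine={alpha_cosine(t):.3f}")
```

Under these assumed forms, the cosine schedule keeps $\alpha_t$ higher at intermediate times (about 0.707 versus 0.5 at $t = 0.5$), so it corrupts less of the input in the middle of the trajectory.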

**Notes:**
* Run `mkdir watch_folder` to create a directory for Slurm logs,
  then submit any script in [`scripts/`](scripts) as a Slurm job: `sbatch scripts/ABC_XYZ.sh`
* Control the batch size per GPU with `loader.batch_size`. If `loader.batch_size * num_gpus < loader.global_batch_size`, PyTorch Lightning uses gradient accumulation to reach the global batch size.
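
The relation between the per-GPU and global batch size can be sketched as follows. This is a minimal sketch, assuming the number of accumulation steps is `global_batch_size / (batch_size * num_gpus)`, that `num_gpus` counts all devices across nodes, and that the division must be exact; the helper `accumulation_steps` is hypothetical and not part of the repo.

```python
def accumulation_steps(global_batch_size: int, batch_size: int,
                       num_gpus: int) -> int:
    """Hypothetical helper: forward/backward passes per optimizer step.

    Assumes each optimizer step should see exactly `global_batch_size`
    sequences, split evenly across `num_gpus` devices that each process
    `batch_size` sequences per forward pass.
    """
    per_pass = batch_size * num_gpus
    if global_batch_size % per_pass != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} is not divisible by "
            f"batch_size * num_gpus = {per_pass}")
    return global_batch_size // per_pass


# A global batch of 512 on 8 GPUs with 32 sequences per GPU
# needs 512 / (32 * 8) = 2 accumulation steps.
print(accumulation_steps(512, 32, 8))
```

Halving `loader.batch_size` under this assumption doubles the accumulation steps while keeping the effective batch (and thus the optimization trajectory) roughly the same, at the cost of more forward passes per update.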

# Discrete Consistency Distillation
<a name="distillation"></a>

To distill a model with Discrete Consistency Distillation (`Alg. 1` in the [Duo paper](https://arxiv.org/abs/2506.10892)), use [`scripts/distil_owt.sh`](scripts/distil_owt.sh).


# Sampling & Eval
<a name="sampling"></a>

## Likelihood
To compute test perplexity on the OWT validation set, use [`scripts/eval_owt_duo.sh`](scripts/eval_owt_duo.sh); for zero-shot perplexities, use [`scripts/zero_shot_duo.sh`](scripts/zero_shot_duo.sh).
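
Perplexity is the exponential of the average per-token negative log-likelihood (in nats). For the diffusion models, the per-token loss comes from a variational bound, so the reported value bounds the true perplexity from above. The sketch below averages over a flat list of token NLLs; it is an illustration, and the repo's `metrics.py` may additionally mask padding tokens or aggregate across batches differently. The function name `perplexity` is hypothetical.

```python
import math
from typing import Sequence


def perplexity(token_nlls: Sequence[float]) -> float:
    """exp of the mean per-token negative log-likelihood (natural log)."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


# A model that is uniform over a 4-token vocabulary has NLL = log 4 per
# token, so its perplexity is 4 (up to floating-point rounding).
print(perplexity([math.log(4.0)] * 10))
```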

## Sampling
You can sample with ancestral sampling using the scripts in [`scripts/gen_ppl_*.sh`](scripts/). To sample with predictor-corrector (PC) samplers such as ReMDM and our $\Psi$-samplers, use the scripts in [`scripts/psi_samplers`](scripts/psi_samplers), which contains examples for sampling both text and images.

To use the "Greedy-tail sampler" (equivalent to nucleus sampling in AR models; see `Sec. 4.2` in the paper), set `sampling.noise_removal=greedy`. The default, `sampling.noise_removal=ancestral`, produces more diverse samples (higher entropy) at the cost of worse generative perplexity.
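
The entropy trade-off between the two `sampling.noise_removal` settings can be illustrated on a single categorical distribution. This is a minimal sketch, assuming that at the final denoising step `ancestral` draws each token from the model's predicted distribution while `greedy` takes its argmax; the functions `final_step` and `empirical_entropy` and the toy probabilities are hypothetical and do not reproduce the repo's sampler.

```python
import math
import random
from typing import List


def final_step(probs: List[float], mode: str, rng: random.Random) -> int:
    """Pick a token id from the categorical distribution `probs`."""
    if mode == "greedy":
        # Deterministic: always the most likely token.
        return max(range(len(probs)), key=probs.__getitem__)
    if mode == "ancestral":
        # Stochastic: sample proportionally to the probabilities.
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]
    raise ValueError(f"unknown mode: {mode}")


def empirical_entropy(samples: List[int]) -> float:
    """Entropy (nats) of the empirical distribution of `samples`."""
    counts = {}
    for s in samples:
        counts[s] = counts.get(s, 0) + 1
    n = len(samples)
    return sum(-(c / n) * math.log(c / n) for c in counts.values())


rng = random.Random(0)
probs = [0.5, 0.3, 0.2]
greedy = [final_step(probs, "greedy", rng) for _ in range(1000)]
ancestral = [final_step(probs, "ancestral", rng) for _ in range(1000)]
print(empirical_entropy(greedy), empirical_entropy(ancestral))
```

Greedy decoding collapses the final step onto a single token (zero entropy here), while ancestral sampling keeps the spread of the distribution; this mirrors the diversity versus generative-perplexity trade-off described above.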

To sample from a HuggingFace checkpoint (text only), run the following command:
```bash
python main.py \
  mode=sample_eval \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=duo_base \
  algo.backbone=hf_dit \
  eval.checkpoint_path=s-sahoo/duo-distilled \
  sampling.steps=8 \
  sampling.num_sample_batches=1 \
  sampling.noise_removal=greedy \
  +wandb.offline=true 
```

To use the example scripts with raw checkpoints (see [Checkpoints](#checkpoints)), download them and set the checkpoint path in the script.

# Baselines
<a name="baselines"></a>
Download the baseline checkpoints (see [Checkpoints](#checkpoints)) and specify the paths appropriately in the respective shell scripts:
* [`scripts/eval_owt_*.sh`](scripts/) for computing validation perplexity on OWT.
* [`scripts/gen_ppl_*.sh`](scripts/) for generating text samples and evaluating them.
* [`scripts/zero_shot_*.sh`](scripts/) for computing zero-shot perplexities.
* [`scripts/train_*.sh`](scripts/) for training the models.

# Acknowledgements & Citation
This repository builds on [MDLM's GitHub repository](https://github.com/kuleshov-group/mdlm). Cite our papers using:
```
@inproceedings{
    sahoo2025the,
    title={The Diffusion Duality},
    author={Subham Sekhar Sahoo and Justin Deschenaux and Aaron Gokaslan and Guanghan Wang and Justin T Chiu and Volodymyr Kuleshov},
    booktitle={Forty-second International Conference on Machine Learning},
    year={2025},
    url={https://openreview.net/forum?id=9P9Y8FOSOk}
}

@inproceedings{
    deschenaux2026the,
    title={The Diffusion Duality, Chapter {II}: \${\textbackslash}Psi\$-Samplers and Efficient Curriculum},
    author={Justin Deschenaux and Caglar Gulcehre and Subham Sekhar Sahoo},
    booktitle={The Fourteenth International Conference on Learning Representations},
    year={2026},
    url={https://openreview.net/forum?id=RSIoYWIzaP}
}
```


================================================
FILE: algo.py
================================================
import os
import collections
import copy
import pickle
from typing import Optional

import fsspec
import numpy as np
import torch
import torch.nn.functional as F

import trainer_base
import utils


class AR(trainer_base.TrainerBase):
  def __init__(self, config, tokenizer):
    vocab_size = tokenizer.vocab_size
    if (not hasattr(tokenizer, 'mask_token')
        or tokenizer.mask_token is None):
      self.mask_index = vocab_size
      vocab_size += 1
    else:
      self.mask_index = tokenizer.mask_token_id
    super().__init__(config, tokenizer,
                     vocab_size=vocab_size)
    self.save_hyperparameters()
    self._validate_configuration()

  def _validate_configuration(self):
    super()._validate_configuration()
    assert not self.config.algo.time_conditioning
    assert self.config.prior.type == 'none'

  def _process_model_input(self, x0, valid_tokens):
    input_tokens = x0[:, :-1]
    output_tokens = x0[:, 1:]
    valid_tokens = valid_tokens[:, 1:]
    return input_tokens, output_tokens, valid_tokens

  def nll(self, input_tokens, output_tokens,
          current_accumulation_step):
    del current_accumulation_step
    output = self.backbone(input_tokens, None)
    output[:, :, self.mask_index] = self.neg_infinity
    output = output.log_softmax(-1)
    return - output.gather(
      -1, output_tokens[:, :, None])[:, :, 0]

  def generate_samples(self, num_samples, **kwargs):
    # precompute token buffer
    num_pred_tokens = self.num_tokens - 1
    x = torch.zeros(
      (num_samples, num_pred_tokens + 1),
      dtype=torch.long,
      device=self.device)
    x[:, 0] = self.tokenizer.bos_token_id
    # precompute noise
    noise = (torch.distributions.Gumbel(0, 1)
             .sample((num_samples, num_pred_tokens, self.vocab_size))
             .to(self.device))
    if self.config.sampling.use_float64:
      noise = noise.to(torch.float64)
    for i in range(num_pred_tokens):
      output = self.backbone(x[:, :i + 1], None)
      output[:, :, self.mask_index] = self.neg_infinity
      output = output.log_softmax(-1)
      y = (output[:, -1, :] + noise[:, i, :]).argmax(-1)
      x[:, i + 1] = y
    return x

  def _process_sigma(self, sigma):
    del sigma
    return None


class MDLM(trainer_base.AbsorbingState):
  def __init__(self, config, tokenizer):
    super().__init__(config, tokenizer)
    self._validate_configuration()

  def _validate_configuration(self):
    assert self.sampler != 'ancestral', \
      'sampling.predictor=ancestral is not desirable because ' \
      'it is slow. Please set sampling.predictor=ancestral_cache'

  def _process_model_output(self, model_output, xt, sigma):
    del sigma
    model_output[:, :, self.mask_index] += self.neg_infinity
    
    # Normalize the model_output such that x.exp() is
    # a probability distribution over vocab_size.
    model_output = model_output - torch.logsumexp(
      model_output, dim=-1, keepdim=True)
    # Apply updates directly in the logits matrix.
    # For the logits of the unmasked tokens, set all values
    # to -infinity except for the indices corresponding to
    # the unmasked tokens.
    unmasked_indices = (xt != self.mask_index)
    model_output[unmasked_indices] = self.neg_infinity
    model_output[unmasked_indices, xt[unmasked_indices]] = 0
    return model_output

  def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
                    dalpha_t, low_var=False):
    del xt
    log_p_theta = torch.gather(
      input=log_x_theta,
      dim=-1,
      index=x0[:, :, None]).squeeze(-1)
    return log_p_theta * dalpha_t / (1 - alpha_t)

  def _get_score(self, x, sigma):
    model_output = self.forward(x, sigma)
    # score(x, t) = p_t(y) / p_t(x)
    # => log score(x, t) = log p_t(y) - log p_t(x)
    
    # case 1: x = masked
    #   (i) y = unmasked
    #     log score(x, t) = log p_\theta(x)|_y + log k
    #     where k = exp(- sigma) / (1 - exp(- sigma))
    #   (ii) y = masked
    #     log score(x, t) = 0

    # case 2: x = unmasked
    #   (i) y != masked, y != x
    #     log score(x_i, t) = - inf
    #   (ii) y = x 
    #     log score(x_i, t) = 0
    #   (iii) y = masked token
    #     log score(x_i, t) = - log k
    #     where k = exp(- sigma) / (1 - exp(- sigma))
    
    log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
    assert log_k.ndim == 1
    
    masked_score = model_output + log_k[:, None, None]
    masked_score[:, :, self.mask_index] = 0

    unmasked_score = self.neg_infinity * torch.ones_like(
      model_output)
    unmasked_score = torch.scatter(
      unmasked_score,
      -1,
      x[..., None],
      torch.zeros_like(unmasked_score[..., :1]))
    unmasked_score[:, :, self.mask_index] = - (
      log_k[:, None] * torch.ones_like(x))
    
    masked_indices = (x == self.mask_index).to(
      model_output.dtype)[:, :, None]
    model_output = (
      masked_score * masked_indices
      + unmasked_score * (1 - masked_indices))
    return model_output.exp()

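
# Illustrative sketch (hypothetical helper, not part of the original
# repository and not used by the trainers): a pure-Python, single-position
# version of the SUBS parameterization applied in
# `MDLM._process_model_output`. It assumes `logits` is a plain list over
# the vocabulary. For a masked position, the mask logit is set to -inf
# (zero-masking) before a log-softmax. For an unmasked position, the
# input token is copied through (carry-over) as a one-hot in log space.
def _example_subs_log_probs(logits, xt_token, mask_index):
  import math
  neg_inf = float('-inf')
  if xt_token != mask_index:
    # Carry-over: log-probability 0 at the observed token, -inf elsewhere.
    return [0.0 if i == xt_token else neg_inf
            for i in range(len(logits))]
  # Zero-masking: the model may never predict the mask token.
  logits = [neg_inf if i == mask_index else v
            for i, v in enumerate(logits)]
  # Numerically stable log-softmax (exp(-inf) evaluates to 0.0).
  max_logit = max(logits)
  log_norm = max_logit + math.log(
    sum(math.exp(v - max_logit) for v in logits))
  return [v - log_norm for v in logits]
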

class D3PMAbsorb(trainer_base.AbsorbingState):
  def __init__(self, config, tokenizer):
    super().__init__(config, tokenizer)
    self._validate_configuration()

  def _validate_configuration(self):
    super()._validate_configuration()
    assert self.noise.type == 'log-linear'
    assert self.parameterization == 'mean'

  def _process_model_output(self, model_output, xt, sigma):
    del xt 
    del sigma
    if self.subs_masking:
      model_output[:, :, self.mask_index] += self.neg_infinity
    return model_output.log_softmax(dim=-1)

  def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
                    dalpha_t, low_var=False):
    del dalpha_t
    assert not low_var
    dt = 1 / self.T
    t = 1 - alpha_t  # Only valid for log-linear schedule.
    t = t.clamp(0., 1.0 - 1e-4)
    alpha_t = alpha_t + torch.zeros_like(xt)
    # Log-linear schedule: alpha_s = 1 - s with s = t - dt.
    alpha_s = 1 - (t - dt) + torch.zeros_like(xt)
    assert alpha_s.shape == xt.shape
    assert alpha_t.shape == xt.shape
    log_x_theta_at_x0 = torch.gather(
      log_x_theta, -1, x0[:, :, None]).squeeze(-1)
    log_x_theta_at_m = log_x_theta[:, :, self.mask_index]
    x_theta_at_m = log_x_theta_at_m.exp()
    
    term_1_coef = dt / t
    term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)
    term_1_log_dr = log_x_theta_at_x0
    
    term_2_coef = 1 - dt / t
    term_2_log_nr = term_1_log_nr
    term_2_log_dr = torch.log(
      alpha_s * x_theta_at_m / (t - dt) + 1)
    L_vb_masked = (
      term_1_coef * (term_1_log_nr - term_1_log_dr)
      + term_2_coef * (term_2_log_nr - term_2_log_dr))

    diffusion_loss = self.T * L_vb_masked * (xt == self.mask_index)
    return self._reconstruction_loss(x0) + diffusion_loss

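
# Illustrative sketch (hypothetical helper, not part of the original
# repository and not used by the trainers): the discrete-time bookkeeping
# behind `D3PMAbsorb.nll_per_token`. It assumes the log-linear schedule
# alpha_t = 1 - t and T discrete steps. The preceding step is
# s = t - 1/T, so alpha_s = 1 - (t - 1/T), which is always larger
# (less noisy) than alpha_t.
def _example_loglinear_step(alpha_t, num_steps):
  dt = 1 / num_steps
  t = 1 - alpha_t  # Only valid for the log-linear schedule.
  t = min(max(t, 0.0), 1.0 - 1e-4)
  alpha_s = 1 - (t - dt)
  return t, dt, alpha_s
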

class SEDDAbsorb(trainer_base.AbsorbingState):
  def __init__(self, config, tokenizer):
    super().__init__(config, tokenizer)
    self._validate_configuration()

  def _validate_configuration(self):
    super()._validate_configuration()
    assert self.config.sampling.predictor == 'analytic'

  def _get_score(self, x, sigma):
    return self.forward(x, sigma).exp()

  def _process_model_output(self, model_output, xt, sigma):
    esigm1_log = torch.where(
      sigma < 0.5,
      torch.expm1(sigma),
      sigma.exp() - 1).log().to(model_output.dtype)
    # logits shape
    # (batch_size, context_length, vocab_size)
    model_output = (model_output
                    - esigm1_log[:, None, None]
                    - np.log(model_output.shape[-1] - 1))
    # The below scatter operation sets the log score
    # for the input word to 0.
    model_output = torch.scatter(
      model_output, -1, xt[..., None],
      torch.zeros_like(model_output[..., :1]))
    return model_output

  def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
                    dalpha_t, low_var=False):
    """Computes the SEDD loss for the Absorbing State Diffusion.

    Args:
      log_x_theta: float torch.Tensor with shape (batch_size,
          context_length, vocab_size),
          log score, output of the denoising network.
      xt: int torch.Tensor with shape (batch_size,
          context_length), noisy input.
      x0: int torch.Tensor with shape (batch_size,
          context_length), clean input.
      alpha_t: float torch.Tensor with shape (batch_size, 1),
          signal level.
      dalpha_t: float or float torch.Tensor with shape (batch_size, 1),
          time derivative of signal level.
      low_var: bool, low variance loss during training.
    
    Returns:
      loss with shape (batch_size, context_length).
    """
    assert not low_var
    masked_indices = xt == self.mask_index
    sigma = self._sigma_from_alphat(alpha_t)
    dsigma = - dalpha_t / alpha_t

    expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
    q_ratio = 1 / expsig_minus_1[masked_indices]

    words_that_were_masked = x0[masked_indices]

    neg_term = q_ratio * torch.gather(
      log_x_theta[masked_indices],
      -1,
      words_that_were_masked[..., None]).squeeze(-1)
    score = log_x_theta[masked_indices].exp()
    if self.mask_index == self.vocab_size - 1:
      pos_term = score[:, :-1].sum(dim=-1)
    else:
      pos_term = score[:, : self.mask_index].sum(
        dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)
    const = q_ratio * (q_ratio.log() - 1)

    entropy = torch.zeros(*xt.shape, device=xt.device)
    entropy[masked_indices] += pos_term - neg_term + const
    return dsigma * entropy

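
# Illustrative sketch (hypothetical helper, not part of the original
# repository and not used by the trainers): the score-entropy term that
# `SEDDAbsorb.nll_per_token` accumulates for one masked position, before
# the dsigma weighting. It assumes `log_score` is a plain list of log
# scores over the vocabulary. For the absorbing process, the true ratio
# is nonzero only at the clean token x0, where it equals
# q = 1 / (exp(sigma) - 1). The term therefore vanishes when
# exp(log_score[x0]) == q and all other non-mask scores are zero.
def _example_sedd_masked_entropy(log_score, x0_token, mask_index, sigma):
  import math
  q_ratio = 1 / math.expm1(sigma)
  pos_term = sum(math.exp(s) for i, s in enumerate(log_score)
                 if i != mask_index)
  neg_term = q_ratio * log_score[x0_token]
  const = q_ratio * (math.log(q_ratio) - 1)
  return pos_term - neg_term + const
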

class DUO_BASE(trainer_base.UniformState):
  def __init__(self, config, tokenizer):
    super().__init__(config, tokenizer)
    self._validate_configuration()

  def on_save_checkpoint(self, checkpoint):
    checkpoint['state_dict'] = collections.OrderedDict(
      (k, v) for k, v in checkpoint['state_dict'].items()
      if not k.startswith('teacher'))
    super().on_save_checkpoint(checkpoint)

  def on_load_checkpoint(self, checkpoint):
    checkpoint['state_dict'] = collections.OrderedDict(
      (k, v) for k, v in checkpoint['state_dict'].items()
      if not k.startswith('teacher'))
    super().on_load_checkpoint(checkpoint)

  def _process_model_output(self, model_output, xt, sigma):
    del xt, sigma
    return model_output.log_softmax(dim=-1)

  def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t):
    """Computes the posterior / approximate posterior.

    Args:
      x0: Either clean input `x0` (one-hot),
        or model's predicted `x_theta` of shape (B, L, V).
      xt: The noisy latent (as indices) of shape (B, L).
      alpha_s: Signal level at s of shape (B, [L | 1], 1).
      alpha_t: Signal level at t of shape (B, [L | 1], 1).

    Returns:
      Posterior / approximate posterior of shape (B, L, V).
    """
    if self.config.sampling.use_float64:
      x0 = x0.to(torch.float64)
    if alpha_s.ndim == 2:
      alpha_s = alpha_s.unsqueeze(-1)
    if alpha_t.ndim == 2:
      alpha_t = alpha_t.unsqueeze(-1)
    alpha_ts = alpha_t / alpha_s
    d_alpha = alpha_s - alpha_t
    xt_one_hot = F.one_hot(xt, self.vocab_size).to(
      self.dtype).to(self.device)
    return (
      (alpha_t * self.vocab_size * x0 * xt_one_hot + (
        alpha_ts - alpha_t) * xt_one_hot + d_alpha * x0 + (
          1 - alpha_ts) * (1 - alpha_s) / self.vocab_size) / (
            alpha_t * self.vocab_size * torch.gather(
              x0, -1, xt[..., None]) + (1 - alpha_t)))

  def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
                    dalpha_t, low_var=False):
    assert alpha_t.ndim == 2
    assert x0.ndim == 2
    assert xt.ndim == 2
    assert not torch.is_tensor(dalpha_t) or dalpha_t.ndim == 2
    x_reconst = log_x_theta.exp()
    x_bar_theta = self.vocab_size * alpha_t[
        :, :, None] * x_reconst + 1 - alpha_t[:, :, None]
    coeff = dalpha_t / (self.vocab_size * alpha_t)
    x_eq_xt = (x0 == xt).float()
    x_neq_xt = 1 - x_eq_xt
    xbar_xt = (1 - alpha_t) + self.vocab_size * alpha_t * x_eq_xt
    xbar_theta_xt = torch.gather(
      x_bar_theta, -1, xt.unsqueeze(-1)).squeeze(-1)
    xbar_theta_x = torch.gather(
      x_bar_theta, -1, x0.unsqueeze(-1)).squeeze(-1)
    term1 = self.vocab_size * (1 / xbar_xt
                                - 1 / xbar_theta_xt)
    
    const = (1 - alpha_t) / (self.vocab_size * alpha_t
                             + 1 - alpha_t)
    term2_coefs = x_eq_xt * const + x_neq_xt
    term2_offset = ((self.vocab_size - 1) * const * x_eq_xt
                    - (1 / const) * x_neq_xt) * const.log()
    term2_theta = - term2_coefs * (
      x_bar_theta.log().sum(-1)
      - self.vocab_size * xbar_theta_xt.log())
    term2_theta = (
      term2_theta
      - self.vocab_size * alpha_t / (1 - alpha_t) * (
        xbar_theta_x.log() - xbar_theta_xt.log()) * x_neq_xt)
    term2 = term2_theta + term2_offset
    diffusion_loss = coeff * (term1 - term2)
    assert diffusion_loss.ndim == 2
    return diffusion_loss

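
# Illustrative sketch (hypothetical helper, not part of the original
# repository and not used by the trainers): a pure-Python, single-position
# version of `DUO_BASE._posterior_from_x0` for the uniform-state process.
# It makes these assumptions:
#   - `x0` is a list of V probabilities: one-hot for the true posterior,
#     or the model's prediction for the approximate posterior.
#   - `xt` is the observed token index.
#   - alpha_s >= alpha_t are the signal levels.
# The result sums to 1 for any probability vector x0.
def _example_uniform_posterior(x0, xt, alpha_s, alpha_t):
  vocab_size = len(x0)
  alpha_ts = alpha_t / alpha_s
  d_alpha = alpha_s - alpha_t
  denom = alpha_t * vocab_size * x0[xt] + (1 - alpha_t)
  posterior = []
  for j, x0_j in enumerate(x0):
    xt_j = 1.0 if j == xt else 0.0
    numer = (alpha_t * vocab_size * x0_j * xt_j
             + (alpha_ts - alpha_t) * xt_j
             + d_alpha * x0_j
             + (1 - alpha_ts) * (1 - alpha_s) / vocab_size)
    posterior.append(numer / denom)
  return posterior
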

class Integral(torch.autograd.Function):
  """
  torch module calculating UDLM's p_t 
  """

  @staticmethod
  def forward(ctx, gamma_t, data):
    gamma_max = data['gamma_max']
    gamma_min = data['gamma_min']
    if (gamma_t.max() > gamma_max) or (
      gamma_t.min() < gamma_min):
      print('max:{} {}'.format(gamma_t.max(), gamma_max))
      print('min:{} {}'.format(gamma_t.min(), gamma_min))
      gamma_t = torch.clip(gamma_t, gamma_min, gamma_max)
    indices = torch.round(
      (data['num_points'] - 1) * (gamma_t - gamma_min) / (
          gamma_max - gamma_min)).long()
    grad_pt = data['grad_pt']
    ctx.grad_pt = grad_pt[indices]
    
    pt = data['pt'][indices]
    assert pt.shape == gamma_t.shape
    return pt

  @staticmethod
  def backward(ctx, grad_output):
    return ctx.grad_pt * grad_output, None

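
# Illustrative sketch (hypothetical helper, not part of the original
# repository and not used by the trainers): the nearest-grid-point lookup
# performed by `Integral.forward`, followed by the rescaling in
# `DUO._gamma_to_alphat_integral`. It assumes `pt_table` is a list of
# p_t values tabulated on evenly spaced gamma values in
# [gamma_min, gamma_max]. Out-of-range gammas are clipped. Python's
# round(), like torch.round, rounds halves to even.
def _example_gamma_to_alpha(gamma, pt_table, gamma_min, gamma_max,
                            vocab_size):
  gamma = min(max(gamma, gamma_min), gamma_max)
  num_points = len(pt_table)
  index = round((num_points - 1) * (gamma - gamma_min)
                / (gamma_max - gamma_min))
  pt = pt_table[index]
  # Rescale so that p_t = 1 maps to alpha = 1 and p_t = 1 / V maps to
  # alpha = 0.
  return (vocab_size * pt - 1) / (vocab_size - 1)
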

class DUO(DUO_BASE):
  def __init__(self, config, tokenizer):
    super().__init__(config, tokenizer)
    self.gamma_min = self.config.algo.curriculum.gamma_min
    self.gamma_max = self.config.algo.curriculum.gamma_max
    self.gumbel_tau_log10_start = (
      self.config.algo.curriculum.gumbel_tau_log10_start)
    self.gumbel_tau_log10_end = (
      self.config.algo.curriculum.gumbel_tau_log10_end)
    self.curriculum_start = self.config.algo.curriculum.start
    self.curriculum_end = self.config.algo.curriculum.end
    self.loss_type = self.config.algo.loss_type
    self._initialize_curriculum_coefficients()
    self._validate_configuration()
    
  def _initialize_curriculum_coefficients(self):
    if self.config.algo.curriculum.mode in {'simple', 
      'efficient_cached'}:
      self._init_curriculum_cached()
    elif self.config.algo.curriculum.mode == 'series':
      self._init_curriculum_series()
    elif self.config.algo.curriculum.mode in {'sigmoid',
      'sigmoid-edge-corrected', 'poly3', 'poly5', 'poly7',
      'poly9'}:
      self._init_curriculum_approx()
    else:
      raise ValueError(self.config.algo.curriculum.mode)

  def _init_curriculum_cached(self):
    fpath = self.config.algo.curriculum.integral_cache_path
    with fsspec.open(fpath, 'rb') as f:
      self.integral_cache = pickle.load(f)
    self.integral_cache['pt'] = torch.from_numpy(
      self.integral_cache['pt'])
    self.integral_cache['grad_pt'] = torch.from_numpy(
      self.integral_cache['grad_pt'])

  def _init_curriculum_series(self):
    m, i = utils.compute_duo_series_coefficients(
      self.config.algo.curriculum.n_series_terms, 
      self.vocab_size)

    self.register_buffer('coefficients_m', m, 
                         persistent=False)
    self.register_buffer('coefficients_i', i,
                         persistent=False)
    self.register_buffer('power_arange',
      torch.arange(self.config.algo.curriculum.n_series_terms,
        dtype=torch.float64)[None], persistent=False)

  def _init_curriculum_approx(self):
    fname = f'{self.config.algo.curriculum.mode}.npy'
    fpath = os.path.join(self.config.algo.curriculum.cache_dir, 
                         fname)
    if not os.path.exists(fpath):
      # Compute the coefficients on the fly
      coefficients, _, _, _ = utils.compute_duo_operator_approx(
        num_coefficients=self.config.algo.curriculum.n_series_terms,
        vocab_size=self.vocab_size,
        gamma_min=self.gamma_min,
        gamma_max=self.gamma_max,
        fct_name=self.config.algo.curriculum.mode)
      # Store as an immutable tuple so torch.compile can treat the
      # coefficients as constants.
      coefficients = tuple(coefficients)
      parent_dir = os.path.dirname(fpath)
      os.makedirs(parent_dir, exist_ok=True)
      np.save(fpath, coefficients)
    else:
      coefficients = tuple(np.load(fpath).tolist())
    mode = self.config.algo.curriculum.mode
    if mode == 'sigmoid':
      fn = utils.duo_to_alpha_dalpha_sigmoid
    elif mode == 'sigmoid-edge-corrected':
      fn = utils.duo_t_to_alpha_dalpha_sigm_corrected
    elif mode in ('poly3', 'poly5', 'poly7', 'poly9'):
      fn = utils.duo_to_alpha_dalpha_poly
    else:
      raise ValueError(mode)
    fn = torch.compile(fn)
    self._t_to_alpha_dalpha_compiled = \
      lambda t: fn(t, *coefficients)

  def cuda(self, device=None):
    self = super().cuda(device=device)
    if hasattr(self, 'integral_cache'):
      self.integral_cache['pt'] = self.integral_cache[
        'pt'].cuda(device=device)
      self.integral_cache['grad_pt'] = self.integral_cache[
        'grad_pt'].cuda(device=device)
    return self

  def cpu(self):
    self = super().cpu()
    if hasattr(self, 'integral_cache'):
      self.integral_cache['pt'] = self.integral_cache[
        'pt'].cpu()
      self.integral_cache['grad_pt'] = self.integral_cache[
        'grad_pt'].cpu()
    return self

  def to(self, *args, **kwargs):
    self = super().to(*args, **kwargs)
    if hasattr(self, 'integral_cache'):
      self.integral_cache['pt'] = self.integral_cache[
        'pt'].to(*args, **kwargs)
      self.integral_cache['grad_pt'] = self.integral_cache[
        'grad_pt'].to(*args, **kwargs)
    return self

  def _compute_gumbel_tau_inverse(self):
    # The curriculum interpolates linearly in log10(tau);
    # the method returns 1 / tau.
    start = self.gumbel_tau_log10_start
    end = self.gumbel_tau_log10_end
    delta = end - start
    if self.global_step < self.curriculum_start:
      log10_tau = start
    elif self.global_step < self.curriculum_end:
      frac = (self.global_step - self.curriculum_start) / (
        self.curriculum_end - self.curriculum_start)
      log10_tau = start + frac * delta
    else:
      log10_tau = -10
    return 10 ** (-log10_tau)

  def training_step(self, batch, batch_idx):
    self.log(name='gumbel_tau_log10',
             value=-np.log10(self._compute_gumbel_tau_inverse()),
             on_step=True,
             on_epoch=False,
             sync_dist=True)
    return super().training_step(batch, batch_idx)

  def _gamma_to_alpha_dalpha(self, gamma_t, t):
    if self.config.algo.curriculum.mode in ('simple', 
      'efficient_cached'):
      return self._gamma_to_alpha_dalpha_cached(gamma_t)
    elif self.config.algo.curriculum.mode == 'series':
      return utils.compute_duo_gamma_to_alpha_dalpha_series(
        gamma_t, self.coefficients_m, self.coefficients_i,
        self.power_arange, self.vocab_size, self.gamma_min,
        self.gamma_max)
    elif self.config.algo.curriculum.mode in ('sigmoid',
      'sigmoid-edge-corrected', 'poly3', 'poly5', 'poly7',
      'poly9'):
      return self._t_to_alpha_dalpha_compiled(t)
    else:
      raise ValueError(self.config.algo.curriculum.mode)

  def _gamma_to_alphat_integral(self, gamma_t):
    integral = Integral.apply(gamma_t, self.integral_cache)
    return (self.vocab_size * integral - 1) / (
      self.vocab_size - 1)

  def _gamma_to_alpha_dalpha_cached(self, gamma_t):
    gamma_t_prime = self.gamma_max - self.gamma_min
    usdm_alpha_t = DUO._gamma_to_alphat_integral(self, gamma_t)
    T = 1000
    usdm_dalpha_t = gamma_t_prime * T * (
      DUO._gamma_to_alphat_integral(self, gamma_t + 1 / T) 
      - usdm_alpha_t)
    return usdm_alpha_t, usdm_dalpha_t

  def _prior_loss(self):
    alpha_1 = self._gamma_to_alphat_integral(
      torch.tensor(self.gamma_max))
    loss = ((alpha_1 + (1 - alpha_1) / self.vocab_size)
            * torch.log((self.vocab_size - 1) * alpha_1 + 1)
            + (1 - 1 / self.vocab_size) * (1 - alpha_1)
            * torch.log(1 - alpha_1))
    return loss.item()

  def _q_xt_gaussian(self, x, gamma_t):
    """Computes the noisy sample xt."""
    assert gamma_t.ndim == 1
    assert x.ndim == 3
    gamma_t = gamma_t.unsqueeze(-1).unsqueeze(-1)
    alpha_t = torch.sigmoid(-gamma_t).sqrt()
    sigma_t = torch.sigmoid(gamma_t).sqrt()
    epsilon = torch.randn(x.shape, dtype=torch.float32,
                          device=self.device)
    return alpha_t * x + sigma_t * epsilon

  def nll(self, x0, output_tokens,
          current_accumulation_step=None, train_mode=False):
    use_true_nll = (self.global_step > self.curriculum_end
                    or not train_mode)
    if use_true_nll:
      return super().nll(x0, output_tokens,
                         current_accumulation_step)
    del output_tokens
    t = self._sample_t(x0.shape[0], current_accumulation_step)
    gamma_t = self.gamma_min + t * (self.gamma_max
                                    - self.gamma_min)
    usdm_alpha_t, usdm_dalpha_t = \
      self._gamma_to_alpha_dalpha(gamma_t, t)

    usdm_alpha_t = usdm_alpha_t.unsqueeze(-1)
    assert usdm_alpha_t.ndim == 2
    usdm_dalpha_t = usdm_dalpha_t.unsqueeze(-1)
    sigma = self._sigma_from_alphat(usdm_alpha_t)
    # Default Duo curriculum
    if self.config.algo.curriculum.mode == 'simple':
      x0_one_hot = F.one_hot(x0, self.vocab_size)
      xt = self._q_xt_gaussian(x0_one_hot, gamma_t)
      xt = xt * self._compute_gumbel_tau_inverse()
      xt_usdm = xt.argmax(-1)
      log_x_theta = self.forward(xt, sigma=sigma)
    else:  # Efficient variants: every curriculum mode except 'simple'
      softmax_approx, topk_indices, xt_usdm = \
        utils.sample_tempered_softmax_topk(
        extra_index=x0,
        alpha=torch.sigmoid(-gamma_t).sqrt(),
        sigma=torch.sigmoid(gamma_t).sqrt(),
        l=x0.shape[1],
        k=self.config.algo.curriculum.top_k,
        vocab_size=self.vocab_size,
        inverse_temperature=self._compute_gumbel_tau_inverse())
      log_x_theta = self.forward(topk_indices, sigma=sigma, 
                                 weights=softmax_approx)

    return self.nll_per_token(log_x_theta=log_x_theta,
                              xt=xt_usdm,
                              x0=x0,
                              alpha_t=usdm_alpha_t,
                              dalpha_t=usdm_dalpha_t,
                              low_var=False)

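
# Illustrative sketch (hypothetical helper, not part of the original
# repository and not used by the trainers): the Gaussian-to-discrete step
# in `DUO.nll` ('simple' curriculum mode) for one position.
# - The one-hot x0 is noised with variance-preserving coefficients
#   alpha = sqrt(sigmoid(-gamma)) and sigma = sqrt(sigmoid(gamma)).
# - The argmax of the noisy vector gives the discrete latent.
# - Scaling by the inverse Gumbel temperature does not change the argmax.
# The `rng` argument is assumed to be a `random.Random` instance, chosen
# here for reproducibility.
def _example_gaussian_argmax(x0_token, vocab_size, gamma, rng):
  import math
  alpha = math.sqrt(1 / (1 + math.exp(gamma)))   # sqrt(sigmoid(-gamma))
  sigma = math.sqrt(1 / (1 + math.exp(-gamma)))  # sqrt(sigmoid(gamma))
  noisy = [alpha * (1.0 if i == x0_token else 0.0)
           + sigma * rng.gauss(0.0, 1.0)
           for i in range(vocab_size)]
  return max(range(vocab_size), key=noisy.__getitem__), alpha, sigma
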

class Distillation(DUO):
  def __init__(self, config, tokenizer):
    super().__init__(config, tokenizer)
    self.update_teacher_every = config.algo.update_teacher_every
    self.save_hyperparameters()
    self.teacher = None
    self.teacher_ema = config.algo.teacher_ema
    self.linear_growth_dt = config.algo.linear_growth_dt
    self.linear_growth_min = config.algo.linear_growth_min
    self.linear_growth_max = config.algo.linear_growth_max

  def _validate_configuration(self):
    assert os.path.exists(
      self.config.algo.integral_cache_path), (
        'The integral cache (Eq. 10 in the paper) for '
        f'the {self.config.data.tokenizer_name_or_path} '
        "tokenizer doesn't exist at "
        f'{self.config.algo.integral_cache_path}. '
        'Please generate it by running the utils.py script, '
        'and ensure the correct path is specified using the '
        'algo.integral_cache_path flag.')
    assert self.loss_type in {
      'kl-fwd', 'kl-bwd', 'posterior', 'kl-posterior'}

  def _maybe_update_teacher_weights(self):
    if self.global_step % self.update_teacher_every != 0:
      return
    if self.teacher_ema:
      self.ema.copy_to(self.teacher.parameters())
    else:
      for better_param, current_param in zip(
        self.backbone.parameters(), self.teacher.parameters()):
        if current_param.requires_grad:
          current_param.data.copy_(better_param.data)

  @torch.no_grad()
  def _teacher_logits(self, xt, sigma):
    if self.teacher is None:
      self.teacher = copy.deepcopy(self.backbone)
    self._maybe_update_teacher_weights()

    sigma = self._process_sigma(sigma)
    with torch.amp.autocast('cuda', dtype=torch.float32):
      model_output = self.teacher(xt, sigma)
    logits = self._process_model_output(
      model_output=model_output, xt=xt, sigma=sigma)
    return logits.detach()

  def _sample_trajectory(self, x0, gamma_t, gamma_s):
    """Computes the noisy sample xt."""
    assert gamma_t.ndim == 1
    assert gamma_s.ndim == 1
    assert x0.ndim == 2
    x0 = F.one_hot(x0, self.vocab_size).to(
      self.dtype).to(self.device)
    gamma_t = gamma_t.unsqueeze(-1).unsqueeze(-1)
    alpha_t = torch.sigmoid(-gamma_t).sqrt()
    sigma_t = torch.sigmoid(gamma_t).sqrt()

    gamma_s = gamma_s.unsqueeze(-1).unsqueeze(-1)
    alpha_s = torch.sigmoid(-gamma_s).sqrt()
    sigma_s = torch.sigmoid(gamma_s).sqrt()
    
    epsilon = torch.randn(x0.shape, dtype=torch.float32,
                          device=self.device)
    xt = alpha_t * x0 + sigma_t * epsilon
    xs = alpha_s * x0 + sigma_s * epsilon
    return xt, xs

  def _compute_dt(self):
    if self.linear_growth_dt:
      scale = self.global_step / self.trainer.max_steps
      return self.linear_growth_min + scale * (
        self.linear_growth_max - self.linear_growth_min)
    n = self.global_step // self.update_teacher_every
    return 2 ** n / self.T

  def nll(self, x0, output_tokens,
          current_accumulation_step=None, train_mode=None):
    del output_tokens, train_mode
    t = self._sample_t(x0.shape[0], current_accumulation_step)
    dt = self._compute_dt()
    t = torch.clip(t + dt, 0, 1)

    gamma_t = self.gamma_min + t * (self.gamma_max
                                    - self.gamma_min)
    gamma_s = self.gamma_min + (
      t - dt) * (self.gamma_max - self.gamma_min)

    usdm_alpha_t = self._gamma_to_alphat_integral(gamma_t)
    usdm_alpha_t = usdm_alpha_t.unsqueeze(-1)
    assert usdm_alpha_t.ndim == 2
    usdm_alpha_s = self._gamma_to_alphat_integral(gamma_s)
    usdm_alpha_s = usdm_alpha_s.unsqueeze(-1)
    assert usdm_alpha_s.ndim == 2

    xt, xs = self._sample_trajectory(x0, gamma_t, gamma_s)
    xt_discrete = xt.argmax(-1)
    xs_discrete = xs.argmax(-1)
    log_x_theta_student = self.forward(
      xt_discrete, sigma=self._sigma_from_alphat(usdm_alpha_t))
    log_x_theta_teacher = self._teacher_logits(
      xs_discrete, sigma=self._sigma_from_alphat(usdm_alpha_s))
    if self.config.training.loss_precision == 'float64':
      log_x_theta_student = log_x_theta_student.to(torch.float64)
      log_x_theta_teacher = log_x_theta_teacher.to(torch.float64)
    if self.loss_type == 'kl-fwd':
      return (log_x_theta_teacher.exp() * (
        log_x_theta_teacher - log_x_theta_student)).sum(-1)
    elif self.loss_type == 'kl-bwd':
      return (log_x_theta_student.exp() * (
        log_x_theta_student - log_x_theta_teacher)).sum(-1)
    else:
      raise NotImplementedError(
        f'Unsupported distillation loss_type: {self.loss_type}')
    
  def training_step(self, batch, batch_idx):
    self.log(name='dt',
             value=self._compute_dt(),
             on_step=True,
             on_epoch=False,
             sync_dist=True)
    return super().training_step(batch, batch_idx)

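
# Illustrative sketch (hypothetical helper, not part of the original
# repository and not used by the trainers): the step-size schedule in
# `Distillation._compute_dt`, written as a pure function.
# - Default doubling schedule: dt starts at 1 / T and doubles every
#   `update_teacher_every` steps.
# - With `linear_growth`, dt instead interpolates linearly from
#   `growth_min` to `growth_max` over `max_steps`.
# The default growth bounds mirror configs/algo/distillation.yaml.
def _example_distillation_dt(global_step, num_steps_T,
                             update_teacher_every, max_steps=None,
                             linear_growth=False, growth_min=0.001,
                             growth_max=0.25):
  if linear_growth:
    scale = global_step / max_steps
    return growth_min + scale * (growth_max - growth_min)
  n = global_step // update_teacher_every
  return 2 ** n / num_steps_T
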

================================================
FILE: configs/algo/ar.yaml
================================================
name: ar
backbone: dit
parameterization: ar
time_conditioning: False
causal_attention: True
# Irrelevant flags
T: 0
subs_masking: False
ignore_bos: False


================================================
FILE: configs/algo/d3pm.yaml
================================================
name: d3pm
backbone: dit  # dit / dimamba
parameterization: mean
time_conditioning: True
T: 1000 
subs_masking: False  # True / False
causal_attention: False
ignore_bos: False
loss_type: elbo  # elbo, low_var


================================================
FILE: configs/algo/distillation.yaml
================================================
name: distillation
backbone: dit  # dit / dimamba / hf_dit
parameterization: mean
time_conditioning: True
subs_masking: False
causal_attention: False
gumbel_tau_log10_start: -1
gumbel_tau_log10_end: -1
curriculum_start: -1
curriculum_end: -1

integral_cache_path: ${hydra:runtime.cwd}/integral/${data.tokenizer_name_or_path}.pkl
loss_type: kl-bwd  # kl-fwd, kl-bwd, posterior
update_teacher_every: 10_000
ignore_bos: False
T: 64
gamma_min: -4
gamma_max: -1
teacher_ema: False
posterior_loss_weight: 0.0
linear_growth_dt: False
linear_growth_min: 0.001
linear_growth_max: 0.25

================================================
FILE: configs/algo/duo.yaml
================================================
name: duo
backbone: dit  # dit / dimamba / hf_dit
parameterization: mean
time_conditioning: True
T: 0  # 0 (continuous time) / 1000 
subs_masking: False
causal_attention: False
loss_type: elbo
ignore_bos: False

# Curriculum arguments
curriculum:
  # 'simple' is the original DUO curriculum; 'poly9' is the faster variant used in DUO2
  mode: simple  # simple / efficient_cached / series / sigmoid / sigmoid-edge-corrected / poly3 / poly5 / poly7 / poly9
  # Arguments shared by all curriculum modes
  gumbel_tau_log10_start: -1.0
  gumbel_tau_log10_end: -2.0
  start: 100_000
  end: 200_000
  gamma_min: -3.5
  gamma_max: -1.75
  # Argument for the simple and efficient_cached modes only
  integral_cache_path: ${hydra:runtime.cwd}/integral/${data.tokenizer_name_or_path}.pkl
  # Argument for every mode except simple (top-k softmax relaxation)
  top_k: -1
  # Argument for the sigmoid and poly modes
  cache_dir: './transform_approx_coefficients'
  # Argument for the series, sigmoid, and poly modes
  n_series_terms: 150


================================================
FILE: configs/algo/duo_base.yaml
================================================
name: duo_base
backbone: dit  # dit / dimamba / hf_dit
parameterization: mean
time_conditioning: True
T: 0  # 0 (continuous time) / 1000 
subs_masking: False
causal_attention: False
ignore_bos: False
loss_type: elbo  # elbo, low_var


================================================
FILE: configs/algo/mdlm.yaml
================================================
name: mdlm
backbone: dit  # dit / dimamba / hf_dit
parameterization: subs
time_conditioning: False
T: 0  # 0 (continuous time) / 1000 
subs_masking: False
causal_attention: False
ignore_bos: False
loss_type: elbo


================================================
FILE: configs/algo/ot-finetune.yaml
================================================
name: ot-finetune
backbone: dit  # dit / dimamba / hf_dit
parameterization: mean
time_conditioning: True
T: 0  # 0 (continuous time) / 1000 
subs_masking: False
causal_attention: False
ignore_bos: False
delta_ts: 0.5

================================================
FILE: configs/algo/sedd.yaml
================================================
name: sedd
backbone: dit  # dit / dimamba
parameterization: score 
time_conditioning: True
T: 0  # 0 (continuous time) / 1000 
subs_masking: False
causal_attention: False
ignore_bos: False
loss_type: elbo  # elbo, low_var


================================================
FILE: configs/callbacks/checkpoint_every_n_steps.yaml
================================================
checkpoint_every_n_steps:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  save_top_k: -1 # Do not save any "best" models; this callback is being used to save every n train steps
  save_last: True # save model as ${save_dir}/checkpoints/last.ckpt
  dirpath: ${checkpointing.save_dir}/checkpoints
  verbose: True
  auto_insert_metric_name: False
  every_n_train_steps: 500


================================================
FILE: configs/callbacks/checkpoint_monitor.yaml
================================================
checkpoint_monitor:
  _target_: lightning.pytorch.callbacks.ModelCheckpoint
  monitor: val/nll # name of the logged metric which determines when model is improving
  mode: min # can be "max" or "min"
  save_top_k: 1 # save k best models (determined by above metric)
  save_last: False # True = additionally always save model from last epoch
  dirpath: ${checkpointing.save_dir}/checkpoints
  filename: best
  auto_insert_metric_name: False
  verbose: True


================================================
FILE: configs/callbacks/grad_record.yaml
================================================
grad_record:
  _target_: utils.GradientInspectionCallback
  num_grads_log: 4

================================================
FILE: configs/callbacks/learning_rate_monitor.yaml
================================================
learning_rate_monitor:
  _target_: lightning.pytorch.callbacks.LearningRateMonitor
  logging_interval: step


================================================
FILE: configs/config.yaml
================================================
defaults:
  - _self_
  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]
  - /data: openwebtext
  - /model: small
  - /strategy: ddp
  - /noise: log-linear
  - /lr_scheduler: constant_warmup
  - /prior: none
  - /algo: duo_base

mode: train  # train / ppl_eval / sample_eval

seed: 1

loader:
  global_batch_size: 512
  eval_global_batch_size: ${.global_batch_size}
  # Note: batch_size and eval_batch_size are **per machine**
  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}
  num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"}
  pin_memory: True

sampling:
  predictor: ancestral  # ancestral_cache (only for MDLM), ancestral, analytic, psi (for psi-samplers)
  steps: 1000
  noise_removal: ancestral  # 'ancestral', 'greedy', 'none'
  use_float64: True
  p_nucleus: 1.0
  num_sample_batches: 2  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches
  num_sample_log: 2
  semi_ar: False
  stride_length: 1
  num_strides: 1
  guid_weight: null  # null -> no classifier-free guidance 
  psi:
    time_profile: linear  # linear / linear-constant-linear-alpha / linear-constant-linear-alpha-inv
    # Note: kappa = 1 -> pure posterior, kappa = 0 -> pure PC
    # modes: pure-posterior / pure-pc / constant-eta / constant-remdm-eta / max-capped-eta / max-rescale-eta
    high_mode: pure-posterior
    middle_mode: pure-posterior
    low_mode: pure-posterior
    # Fraction of [0, 1] spent in the high / middle modes (the rest being in low mode)
    # Example: high_frac=0.2, middle_frac=0.6. Then,
    #  - For t in [1.0, 0.8] -> use high_mode
    #  - For t in [0.8, 0.2] -> use middle_mode
    #  - For t in [0.2, 0.0] -> use low_mode
    high_frac: 0.2
    middle_frac: 0.6

training:
  ema: 0.9999
  antithetic_sampling: True
  importance_sampling: False
  sampling_eps: 1e-3
  change_of_variables: False
  loss_precision: 'bf16'  # bf16, float32, float64
  finetune_path: ''
  class_dropout_p: 0.1  # only used with class-conditional datasets (eg cifar10)

eval:
  checkpoint_path: ''  # Used to evaluate a checkpoint after training.
  disable_ema: False
  compute_generative_perplexity: False
  perplexity_batch_size: 8
  compute_perplexity_on_sanity: False
  gen_ppl_eval_model_name_or_path: gpt2-large  # gpt2-large, meta-llama/Llama-2-7b-hf
  generate_samples: True
  generated_samples_path: ${cwd:}/samples.json

optim:
  weight_decay: 0
  lr: 3e-4
  beta1: 0.9
  beta2: 0.999
  eps: 1e-8

trainer:
  _target_: lightning.Trainer
  accelerator: cuda
  num_nodes: 1
  devices: ${device_count:}
  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}
  gradient_clip_val: 1.0
  precision: 'bf16'
  num_sanity_val_steps: 2
  max_steps: 1_000_000
  log_every_n_steps: 100
  limit_train_batches: 1.0   # train on full dataset, can be used to toggle quick run
  limit_val_batches: 1.0     # validate on full dataset, can be used to toggle quick run
  val_check_interval: 5000
  check_val_every_n_epoch: null  # Disable end-of-epoch validation. Instead, validate every trainer.val_check_interval steps.

wandb:
  project: duo
  group: null
  job_type: null
  name: null
  id: ${.name}_${seed}
  tags:
    - ${noise.type}
    - ${data.train}
    - ${data.valid}
    - ${algo.name}

hydra:
  run:
    dir: ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}
  job:
    chdir: true

checkpointing:
  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is
  save_dir: ${cwd:}
  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`
  resume_from_ckpt: true
  resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt


================================================
FILE: configs/data/ag_news.yaml
================================================
train: ag_news
valid: ag_news
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/cifar10.yaml
================================================
train: cifar10
valid: cifar10
modality: image
tokenizer_name_or_path: cifar10
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
num_classes: 10  # 10 / null
streaming: False
size: 1024
length: 3072
wrap: False
insert_train_eos: False
insert_valid_eos: False
insert_train_spacial: False

================================================
FILE: configs/data/fineweb-edu.yaml
================================================
train: HuggingFaceFW/fineweb-edu
valid: openwebtext-valid  #wikitext103
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: True
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/lambada.yaml
================================================
train: lambada
valid: lambada
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/lm1b-gpt2.yaml
================================================
train: lm1b
valid: lm1b
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/lm1b-streaming.yaml
================================================
train: lm1b
valid: lm1b
modality: text
tokenizer_name_or_path: bert-base-uncased
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: False
streaming: True
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/lm1b-wrap.yaml
================================================
train: lm1b
valid: lm1b
modality: text
tokenizer_name_or_path: bert-base-uncased
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/lm1b.yaml
================================================
train: lm1b
valid: lm1b
modality: text
tokenizer_name_or_path: bert-base-uncased
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: False
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/openwebtext-split.yaml
================================================
train: openwebtext-train
valid: openwebtext-valid
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/openwebtext-streaming.yaml
================================================
train: openwebtext
valid: wikitext103
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /tmp/data
wrap: True
streaming: True
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/openwebtext.yaml
================================================
train: openwebtext
valid: wikitext103
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/ptb.yaml
================================================
train: ptb
valid: ptb
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/scientific_papers_arxiv.yaml
================================================
train: scientific_papers_arxiv
valid: scientific_papers_arxiv
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/scientific_papers_pubmed.yaml
================================================
train: scientific_papers_pubmed
valid: scientific_papers_pubmed
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/synthetic.yaml
================================================
train: synthetic
valid: synthetic
modality: text
tokenizer_name_or_path: synthetic
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: True
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/text8-crop.yaml
================================================
# TODO: When using this dataset, set model.length = 256 to match D3PM setup
train: text8-crop
valid: text8
modality: text
tokenizer_name_or_path: text8
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/text8.yaml
================================================
# TODO: When using this dataset, set model.length = 256 to match D3PM setup
train: text8
valid: text8
modality: text
tokenizer_name_or_path: text8
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/wikitext103.yaml
================================================
train: wikitext103
valid: wikitext103
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/data/wikitext2.yaml
================================================
train: wikitext2
valid: wikitext2
modality: text
tokenizer_name_or_path: gpt2
cache_dir: /share/kuleshov/ssahoo/textdiffusion/data
wrap: True
streaming: False
insert_train_eos: True
insert_valid_eos: True

================================================
FILE: configs/lr_scheduler/constant_warmup.yaml
================================================
_target_: transformers.get_constant_schedule_with_warmup
num_warmup_steps: 2500

================================================
FILE: configs/lr_scheduler/cosine_decay_warmup.yaml
================================================
_target_: utils.CosineDecayWarmupLRScheduler
t_in_epochs: False
t_initial: ${eval:${trainer.max_steps}-${.warmup_t}}
warmup_prefix: True
warmup_lr_init: 1e-6
warmup_t: ${eval:0.1*${trainer.max_steps}}
lr_min: 1e-6
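With `warmup_prefix: True`, the cosine phase begins only after the `warmup_t = 0.1 * max_steps` warmup steps. It then lasts `t_initial = max_steps - warmup_t` steps, so the schedule ends exactly at `trainer.max_steps`. The parameter names match timm's `CosineLRScheduler`. The sketch below assumes `utils.CosineDecayWarmupLRScheduler` follows those semantics: a linear warmup from `warmup_lr_init`, then cosine decay to `lr_min`. `utils` is not shown here, so treat this as an approximation rather than the repository's implementation. `base_lr` stands in for `optim.lr` (3e-4 in `config.yaml`).

```python
import math


def cosine_decay_warmup_lr(step: int, base_lr: float, max_steps: int,
                           warmup_frac: float = 0.1,
                           warmup_lr_init: float = 1e-6,
                           lr_min: float = 1e-6) -> float:
  """Approximate per-step learning rate for cosine_decay_warmup.yaml."""
  warmup_t = warmup_frac * max_steps
  t_initial = max_steps - warmup_t
  if step < warmup_t:
    # Linear ramp from warmup_lr_init up to base_lr.
    return warmup_lr_init + step * (base_lr - warmup_lr_init) / warmup_t
  # warmup_prefix=True: the cosine clock starts after warmup.
  t = min(step - warmup_t, t_initial)
  cosine = 0.5 * (1.0 + math.cos(math.pi * t / t_initial))
  return lr_min + (base_lr - lr_min) * cosine
```

With `max_steps = 1_000_000`, the rate starts at 1e-6 and peaks at `base_lr` at step 100,000. It reaches the midpoint between `base_lr` and `lr_min` at step 550,000 and is back at `lr_min` at step 1,000,000.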

================================================
FILE: configs/lr_scheduler/step_scheduler.yaml
================================================
_target_: torch.optim.lr_scheduler.LambdaLR
lr_lambda:
  _target_: utils.LRHalveScheduler
  warmup_steps: 500
  n_halve_steps: 10_000

================================================
FILE: configs/model/medium.yaml
================================================
name: medium
type: ddit
hidden_size: 1024
cond_dim: 128
length: 1024
n_blocks: 24
n_heads: 16
scale_by_sigma: True
dropout: 0.1
tie_word_embeddings: False
vocab_lookup: True

================================================
FILE: configs/model/small.yaml
================================================
name: small
type: ddit
hidden_size: 768
cond_dim: 128
length: 1024
n_blocks: 12
n_heads: 12
scale_by_sigma: True
dropout: 0.1
tie_word_embeddings: False
vocab_lookup: True

================================================
FILE: configs/model/tiny-dimamba.yaml
================================================
name: tiny
type: dimamba
hidden_size: 512
cond_dim: 128
length: 1024
n_blocks: 14
n_heads: 8
scale_by_sigma: True
dropout: 0.1
temb_strategy: adaln 
tie_word_embeddings: False

================================================
FILE: configs/model/tiny.yaml
================================================
name: tiny
type: ddit
hidden_size: 256
cond_dim: 128
length: 1024
n_blocks: 8
n_heads: 8
scale_by_sigma: True
dropout: 0.1
tie_word_embeddings: False
vocab_lookup: True

================================================
FILE: configs/model/unet.yaml
================================================
name: unet
type: unet
ch: 128 
num_res_blocks: 2
num_scales: 4
ch_mult: [1, 2, 2, 2]
input_channels: 3
output_channels: -1 # determined by vocab_size
scale_count_to_put_attn: 1 # i.e., attention at the 16x16 resolution
data_min_max: [0, 255] # currently unused
dropout: 0.1
skip_rescale: True
time_conditioning: True # Whether to add in time embeddings
time_scale_factor: 1000 
time_embed_dim: ${.ch}
fix_logistic: False
size: ${data.size}
cond_dim: ${.ch}
length: ${data.length}

================================================
FILE: configs/noise/cosine.yaml
================================================
type: cosine
eps: 1e-3

================================================
FILE: configs/noise/log-linear.yaml
================================================
type: log-linear
eps: 1e-3
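Both noise configs expose an `eps` that keeps the schedule finite at t = 1. The noise implementation is not part of this excerpt, so the sketch below assumes the MDLM-style log-linear definition. Under that definition the total noise is sigma(t) = -log(1 - (1 - eps) t). The signal rate alpha(t) = exp(-sigma(t)) = 1 - (1 - eps) t therefore falls linearly from 1 at t = 0 to eps at t = 1.

```python
import math


def log_linear_sigma(t: float, eps: float = 1e-3) -> float:
  """Total noise sigma(t) = -log(1 - (1 - eps) * t), assumed MDLM-style."""
  if not 0.0 <= t <= 1.0:
    raise ValueError(f't must lie in [0, 1], got {t}')
  # log1p keeps precision when (1 - eps) * t is small.
  return -math.log1p(-(1.0 - eps) * t)


def log_linear_alpha(t: float, eps: float = 1e-3) -> float:
  """Signal rate alpha(t) = exp(-sigma(t)) = 1 - (1 - eps) * t."""
  return math.exp(-log_linear_sigma(t, eps))
```

Without `eps`, sigma(1) would be infinite. With `eps = 1e-3`, sigma(1) = -log(1e-3), which is roughly 6.9.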

================================================
FILE: configs/prior/none.yaml
================================================
type: none
latent_width: 0
latent_height: 0

================================================
FILE: configs/strategy/ddp.yaml
================================================
_target_: lightning.pytorch.strategies.DDPStrategy
find_unused_parameters: false


================================================
FILE: configs/strategy/fsdp.yaml
================================================
# TODO(yair): Currently not compatible with grad clipping
_target_: lightning.pytorch.strategies.FSDPStrategy
sharding_strategy: SHARD_GRAD_OP


================================================
FILE: dataloader.py
================================================
import functools
import itertools
import json
import math
import os
import re
import shutil
import typing
import urllib.request
import zipfile
from typing import Optional

import datasets
import einops
import fsspec
import numpy as np
import requests
import tokenizers
import torch
import torchvision
from torchvision import transforms as th_transforms

import transformers

import utils

LOGGER = utils.get_logger(__name__)



class RawPixelsVisionTokenizer:
  def __init__(self, vocab_size, image_size,
               add_mask_token=True, add_special_tokens=True):
    self.pad_token_id = None
    self.pad_token = None
    if add_mask_token:
      self.mask_token = vocab_size
      self.mask_token_id = vocab_size
      self.vocab_size = vocab_size + 1  # mask token
    else:
      self.vocab_size = vocab_size
    if add_special_tokens:
      # Place BOS/EOS after the mask token (if any) so the ids never collide.
      self.bos_token_id = self.vocab_size
      self.bos_token = self.vocab_size
      self.eos_token_id = self.vocab_size + 1
      self.eos_token = self.vocab_size + 1
      self.vocab_size = self.vocab_size + 2
    self.image_size = image_size

  def __call__(self, x):
    return x

  def batch_decode(self, x):
    x = einops.rearrange(x, 'b (c h w) -> b c h w', c=3,
                         h=self.image_size)
    x = x.to(dtype=torch.uint8)
    return x

  def decode(self, x):
    x = einops.rearrange(x, '(c h w) -> h w c', c=3,
                         h=self.image_size)
    x = x.to(dtype=torch.uint8)
    return x
  
  def __len__(self):
    return self.vocab_size
  

class DiscreteCIFAR10(torch.utils.data.Dataset):
  def __init__(self, cache_dir, train):
    self._dataset = torchvision.datasets.CIFAR10(
      root=cache_dir, train=train, download=True)

    transforms = []
    if train:
      transforms += [th_transforms.RandomHorizontalFlip()]
    # PIL image -> uint8 tensor (H, W, C) -> flat (C*H*W,) sequence.
    transforms += [
      th_transforms.Lambda(lambda x: torch.from_numpy(np.array(x))),
      th_transforms.Lambda(
        lambda x: einops.rearrange(x, 'h w c -> (c h w)')),
    ]

    self.transform = th_transforms.Compose(transforms)

  def __len__(self):
    return len(self._dataset)

  def __getitem__(self, index):
    img, labels = self._dataset[index]

    img = self.transform(img)
    attention_mask = torch.ones_like(img)
    return {'input_ids': img.to(torch.long), 'labels': labels,
            'attention_mask': attention_mask}
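`DiscreteCIFAR10` flattens each 32x32x3 image channel-first with `einops.rearrange(x, 'h w c -> (c h w)')`. This produces the `length: 3072` sequence in `configs/data/cifar10.yaml`. `RawPixelsVisionTokenizer.decode` inverts it with `'(c h w) -> h w c'`. The NumPy sketch below reproduces the same ordering with `transpose`/`reshape` to make the layout explicit. The function names are hypothetical.

```python
import numpy as np


def flatten_chw(img: np.ndarray) -> np.ndarray:
  """(H, W, C) -> (C*H*W,), matching einops 'h w c -> (c h w)'."""
  return img.transpose(2, 0, 1).reshape(-1)


def unflatten_hwc(flat: np.ndarray, image_size: int,
                  channels: int = 3) -> np.ndarray:
  """(C*H*W,) -> (H, W, C), matching einops '(c h w) -> h w c'."""
  return flat.reshape(channels, image_size, -1).transpose(1, 2, 0)
```

In the flattened sequence, all 1024 red values come first, then all green values, then all blue values. Each pixel intensity in [0, 255] is used directly as a token id.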


def wt_detokenizer(string):
  # contractions
  string = string.replace("s '", "s'")
  string = re.sub(r"' ([0-9])", r"'\1", string)
  # number separators
  string = string.replace(" @-@ ", "-")
  string = string.replace(" @,@ ", ",")
  string = string.replace(" @.@ ", ".")
  # punctuation
  string = string.replace(" : ", ": ")
  string = string.replace(" ; ", "; ")
  string = string.replace(" . ", ". ")
  string = string.replace(" ! ", "! ")
  string = string.replace(" ? ", "? ")
  string = string.replace(" , ", ", ")
  # double brackets
  string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
  string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
  string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
  string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
  string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
  # miscellaneous
  string = string.replace("= = = =", "====")
  string = string.replace("= = =", "===")
  string = string.replace("= =", "==")
  string = string.replace(" " + chr(176) + " ", chr(176))
  string = string.replace(" \n", "\n")
  string = string.replace("\n ", "\n")
  string = string.replace(" N ", " 1 ")
  string = string.replace(" 's", "'s")
  return string

def ptb_detokenizer(x):
  x = x.replace(" 's", "'s")
  x = x.replace("s ' ", "s' ")
  x = x.replace(" n't", "n't")
  x = x.replace(" \n ", "\n")
  x = x.replace("\\/", "/")
  # Repeat because str.replace skips overlapping matches such as " N N ".
  for _ in range(10):
    x = x.replace(" N ", " 1 ")
  x = x.replace("$ 1", "$1")
  x = x.replace("# 1", "#1")
  x = x.replace("<unk>", "?")
  return x


def lm1b_detokenizer(x):
  x = x.replace('http : / / ', 'http://')
  x = x.replace('https : / / ', 'https://')
  x = re.sub(r' \'(\w+)', r"'\1", x)
  x = re.sub(r' (\w+) \. ', r' \1. ', x)
  x = re.sub(r' (\w+) \.$', r' \1.', x)
  x = x.replace(' ? ', '? ')
  x = re.sub(r' \?$', '?', x)
  x = x.replace(' ! ', '! ')
  x = re.sub(r' \!$', '!', x)
  x = x.replace(' , ', ', ')
  x = x.replace(' : ', ': ')
  x = x.replace(' ; ', '; ')
  x = x.replace(' / ', '/')
  x = re.sub(r'\" ([^\"]+) \"', r'"\1"', x)
  x = re.sub(r'\' ([^\']+) \'', r"'\1'", x)
  x = re.sub(r'\( ([^\(\)]+) \)', r"(\1)", x)
  x = re.sub(r'\[ ([^\[\]]+) \]', r"[\1]", x)
  x = x.replace('$ ', '$')
  x = x.replace('£ ', '£')
  return x


def lambada_detokenizer(text):
  text = text.replace("“", '"')
  text = text.replace("”", '"')
  return '\n'+text.strip()


def scientific_papers_detokenizer(x):
  x = wt_detokenizer(x)
  x = lm1b_detokenizer(x)
  return x


class SyntheticTokenizer(transformers.PreTrainedTokenizer):
  def __init__(
    self,
    vocab_size,
    bos_token="[BOS]",
    eos_token="[EOS]",
    sep_token=None,
    cls_token=None,
    pad_token=None,
    mask_token=None,
    unk_token=None,
    **kwargs):
    
    self.tokens = []
    
    for i in range(vocab_size - 2):
      # appending space for readability
      self.tokens.append(str(i) + " ")
    
    self._vocab_str_to_int = {
      '[BOS]': vocab_size - 2,
      '[EOS]': vocab_size - 1,
      ** {ch: i for i, ch in enumerate(self.tokens)}}
    
    self._vocab_int_to_str = {
      v: k for k, v in self._vocab_str_to_int.items()}
    
    super().__init__(
      bos_token=bos_token,
      eos_token=eos_token,
      sep_token=sep_token,
      cls_token=cls_token,
      pad_token=pad_token,
      mask_token=mask_token,
      unk_token=unk_token,
      **kwargs)

  @property
  def vocab_size(self) -> int:
    return len(self._vocab_str_to_int)

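  def _tokenize(self, text: str, **kwargs) -> typing.List[str]:
    # Vocabulary entries are '<id> ' strings, so split on whitespace
    # and restore the trailing space.
    return [t + ' ' for t in text.split()]

  def _convert_token_to_id(self, token: str) -> int:
    # This vocabulary defines no '[UNK]' entry, so report the offending
    # token instead of failing on a missing '[UNK]' key.
    if token not in self._vocab_str_to_int:
      raise KeyError(f'Unknown token for SyntheticTokenizer: {token!r}')
    return self._vocab_str_to_int[token]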

  def _convert_id_to_token(self, index: int) -> str:
    return self._vocab_int_to_str[index]

  def convert_tokens_to_string(self, tokens):
    return ''.join(tokens)

  def get_vocab(self) -> typing.Dict[str, int]:
    return self._vocab_str_to_int


def _generate_synthetic_data(dataset_size, 
                             seq_len, vocab_size):
  dataset = np.zeros((dataset_size, seq_len), dtype=int)
  # tokens representing sequence boundary
  dataset[:, 0] = vocab_size - 2  # bos
  dataset[:, -1] = vocab_size - 1  # eos

  for i in range(dataset_size):
    # sample from 0, 1, ..., vocab_size - 3
    temp = np.random.randint(vocab_size - 2)
    for j in reversed(range(1, seq_len - 1)):
      dataset[i, j] = temp
      if temp != 0:
        temp = temp // 4
      else:
        temp = np.random.randint(vocab_size - 2)

  return dataset
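Each synthetic row has the form `[BOS] body [EOS]`, and the body is filled from right to left. A random value in `[0, vocab_size - 3]` is repeatedly integer-divided by 4 until it reaches 0. After that, a fresh value is drawn. As a result, every body token whose right neighbour is non-zero equals that neighbour `// 4`. The hypothetical checker below tests this invariant. It uses a standalone copy of the generator built on NumPy's `Generator` API instead of the global `np.random` state.

```python
import numpy as np


def generate_rows(num_rows, seq_len, vocab_size, rng):
  """Standalone copy of _generate_synthetic_data using a Generator."""
  data = np.zeros((num_rows, seq_len), dtype=int)
  data[:, 0] = vocab_size - 2   # BOS
  data[:, -1] = vocab_size - 1  # EOS
  for i in range(num_rows):
    temp = rng.integers(vocab_size - 2)
    for j in reversed(range(1, seq_len - 1)):
      data[i, j] = temp
      temp = temp // 4 if temp != 0 else rng.integers(vocab_size - 2)
  return data


def is_valid_row(row, vocab_size):
  """Check the BOS/EOS boundary and the right-to-left // 4 invariant."""
  if row[0] != vocab_size - 2 or row[-1] != vocab_size - 1:
    return False
  body = row[1:-1]
  if body.min() < 0 or body.max() > vocab_size - 3:
    return False
  for j in range(len(body) - 1):
    right = body[j + 1]
    if right != 0 and body[j] != right // 4:
      return False
  return True
```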


def generate_synthetic_dataset(train_dataset_size, 
                               validation_dataset_size, 
                               seq_len, vocab_size):
  np.random.seed(42)
  train_data = torch.from_numpy(
    _generate_synthetic_data(train_dataset_size, 
                             seq_len, vocab_size))
  train_dataset = datasets.Dataset.from_dict({
    'input_ids': train_data, 
    'attention_mask': torch.ones_like(train_data),
  })
  train_dataset.set_format(type='torch')

  np.random.seed(41)
  validation_data = torch.from_numpy(
    _generate_synthetic_data(validation_dataset_size, 
                             seq_len, vocab_size))
  validation_dataset = datasets.Dataset.from_dict({
    'input_ids': validation_data, 
    'attention_mask': torch.ones_like(validation_data),
  })
  validation_dataset.set_format(type='torch')

  return {
    'train': train_dataset,
    'validation': validation_dataset,
  }


class Text8Tokenizer(transformers.PreTrainedTokenizer):
  def __init__(
    self,
    bos_token='[BOS]',
    eos_token='[EOS]',
    sep_token='[SEP]',
    cls_token='[CLS]',
    pad_token='[PAD]',
    mask_token='[MASK]',
    unk_token='[UNK]',
    **kwargs):
    self.characters = list('abcdefghijklmnopqrstuvwxyz ')
    self._vocab_str_to_int = {
      '[CLS]': 0,
      '[SEP]': 1,
      '[BOS]': 2,
      '[EOS]': 3,
      '[MASK]': 4,
      '[PAD]': 5,
      '[RESERVED]': 6,
      '[UNK]': 7,
      ** {ch: i + 8 for i, ch in enumerate(self.characters)}}
    self._vocab_int_to_str = {
      v: k for k, v in self._vocab_str_to_int.items()}
    super().__init__(
      bos_token=bos_token,
      eos_token=eos_token,
      sep_token=sep_token,
      cls_token=cls_token,
      pad_token=pad_token,
      mask_token=mask_token,
      unk_token=unk_token,
      **kwargs)

  @property
  def vocab_size(self) -> int:
    return len(self._vocab_str_to_int)

  def _tokenize(self, text: str, **kwargs) -> typing.List[str]:
    return list(text.lower())

  def _convert_token_to_id(self, token: str) -> int:
    return self._vocab_str_to_int.get(
      token, self._vocab_str_to_int['[UNK]'])

  def _convert_id_to_token(self, index: int) -> str:
    return self._vocab_int_to_str[index]

  def convert_tokens_to_string(self, tokens):
    return ''.join(tokens)

  def get_vocab(self) -> typing.Dict[str, int]:
    return self._vocab_str_to_int


def get_lambada_test_dataset():
  url = "https://openaipublic.blob.core.windows.net/gpt-2/data/lambada_test.jsonl"

  def read_jsonl_to_list(url):
    response = requests.get(url, stream=True)
    response.raise_for_status()
    data_list = []
    # Parse each non-empty line of the JSONL response.
    for line in response.iter_lines(decode_unicode=True):
      if line:
        data_list.append(json.loads(line))
    return data_list

  lambada_data = read_jsonl_to_list(url)
  return datasets.Dataset.from_list(lambada_data)


def get_text8_dataset(cache_dir, max_seq_length=256,
                      drop_last=True, crop_train=False):
  """Adapted from:
    https://github.com/google-research/google-research/blob/master/d3pm/text/datasets.py#L344

    Args:
      cache_dir: str, path to cache directory.
      max_seq_length: int, maximum length of sequences.
          (default: 256, as in D3PM codebase.)
      drop_last: bool, whether to drop the last incomplete
          batch. (default: True, as in D3PM codebase.)
      crop_train: bool, whether to subsample contiguous
          subsequences from training examples. This ensures
          that transformer models with absolute position
          embeddings do not learn incorrect position-wise
          marginals. (default: False, but necessary to match D3PM AR)

    Returns:
      dataset: datasets.DatasetDict, with keys 'train',
          'validation', 'test'.
  """
  url = 'http://mattmahoney.net/dc/text8.zip'
  if not crop_train:
    cache_dir = f'{cache_dir}/text8'
  else:
    cache_dir = f'{cache_dir}/text8-crop-train'
  split_names = ['train', 'validation', 'test']
  if not all([
    utils.fsspec_exists(os.path.join(cache_dir, split))
    for split in split_names
  ]):
    # Check if raw data exists
    raw_cache_dir = os.path.join(cache_dir, 'raw_data')
    if not all([
      utils.fsspec_exists(
        os.path.join(raw_cache_dir, f'text8.{split}.txt'))
      for split in split_names
    ]):
      if not utils.fsspec_exists(
        os.path.join(raw_cache_dir, 'text8.zip')):
        utils.fsspec_mkdirs(raw_cache_dir, exist_ok=True)
        LOGGER.info('Downloading text8 from URL {}.'.format(url))
        with (urllib.request.urlopen(url) as in_stream,
              open(os.path.join(raw_cache_dir, 'text8.zip'),
                   'wb') as out_file):
          shutil.copyfileobj(in_stream, out_file)

      with fsspec.open(
        os.path.join(raw_cache_dir, 'text8.zip'),
        'rb') as f:
        rawdata = zipfile.ZipFile(f).read(
          'text8').decode('utf-8')

      # Splits taken from D3PM codebase
      splits = {
        'train': rawdata[:90000000],
        'validation': rawdata[90000000: 95000000],
        'test': rawdata[95000000:],
      }

      for split, data in splits.items():
        _path = os.path.join(raw_cache_dir,
                             f'text8.{split}.txt')
        with fsspec.open(_path, 'w') as f:
          f.write(data)
    else:
      splits = {}
      for split in split_names:
        _path = os.path.join(raw_cache_dir,
                             f'text8.{split}.txt')
        with fsspec.open(_path, 'r') as f:
          splits[split] = f.read()

    # Chunk and save as datasets.DatasetDict
    def chunks(lst, n):
      """Yield successive n-sized chunks from lst."""
      for i in range(0, len(lst), n):
        yield lst[i:i + n]

    dataset_dict = {}
    for k, v in splits.items():
      if k == 'train' and crop_train:
        chunk_size = 2 * max_seq_length
      else:
        chunk_size = max_seq_length
      text = list(chunks(v, chunk_size))
      if drop_last and len(text[-1]) < chunk_size:
        text = text[:-1]
      dataset_dict[k] = datasets.Dataset.from_dict({'text': text})
    dataset = datasets.DatasetDict(dataset_dict)
    dataset.save_to_disk(cache_dir)
  else:
    dataset = datasets.load_from_disk(cache_dir)

  return dataset
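The chunking step in `get_text8_dataset` splits each raw split into fixed-size character blocks. With `drop_last=True`, it discards a short final block. Below is a minimal standalone sketch of that logic; the names `chunk_text` and `num_full_chunks` are hypothetical. As an example, the 90,000,000-character training split with `max_seq_length=256` yields 351,562 full chunks, and the trailing 128 characters are dropped.

```python
def chunk_text(text: str, chunk_size: int, drop_last: bool = True) -> list:
  """Split text into chunk_size pieces, optionally dropping a short tail."""
  chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
  if drop_last and chunks and len(chunks[-1]) < chunk_size:
    chunks = chunks[:-1]
  return chunks


def num_full_chunks(num_chars: int, chunk_size: int) -> int:
  # Number of chunks kept when drop_last=True.
  return num_chars // chunk_size
```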


def _group_texts(examples, block_size, bos, eos):
  # Concatenate all texts.
  concatenated_examples = list(itertools.chain(* examples['input_ids']))
  total_length = len(concatenated_examples)
  # TODO(yair): look into not dropping the remainder but rather padding it.
  # We drop the small remainder; if total_length < block_size - 2, this
  # batch yields no blocks (empty 'input_ids' and 'attention_mask' lists).
  # Padding could replace this drop if the model supported it; customize
  # this part to your needs.
  new_block_size = block_size - 2  # [BOS] and [EOS] to be added
  total_length = (total_length // new_block_size) * new_block_size
  # Split into chunks of new_block_size tokens, each wrapped in [BOS]/[EOS].
  result = {}
  _values = []
  _attn_masks = []
  for i in range(0, total_length, new_block_size):
    _values.append(
      [bos]
      + concatenated_examples[i : i + new_block_size]
      + [eos])
    _attn_masks.append(torch.ones(block_size))
  result['input_ids'] = _values
  result['attention_mask'] = _attn_masks
  return result
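Here is a list-only restatement of `_group_texts` (the name `group_token_lists` is hypothetical, and attention masks are omitted). It shows how wrapping produces the `[BOS] sent1 [EOS] sent2-fragment [EOS]` layout described in `get_tokenizer`. Documents already end in EOS because `preprocess_and_tokenize` appends it when `insert_eos` is set. Each block holds `block_size - 2` concatenated tokens plus a fresh BOS and EOS. Any remainder shorter than that is dropped.

```python
import itertools


def group_token_lists(input_ids, block_size, bos, eos):
  """Concatenate documents, cut into block_size - 2 chunks, wrap with BOS/EOS."""
  concatenated = list(itertools.chain(*input_ids))
  inner = block_size - 2
  total = (len(concatenated) // inner) * inner  # drop the short remainder
  return [[bos] + concatenated[i:i + inner] + [eos]
          for i in range(0, total, inner)]
```

Take BOS = 0, EOS = 9 and the documents `[[1, 2, 3, 9], [4, 5, 9]]` with `block_size=5`. The result is `[[0, 1, 2, 3, 9], [0, 9, 4, 5, 9]]`, and the final EOS is dropped as remainder.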


def get_dataset(dataset_name,
                tokenizer,
                wrap,
                mode,
                cache_dir,
                insert_eos=True,
                block_size=1024,
                num_proc=len(os.sched_getaffinity(0)),
                streaming=False,
                revision : Optional[str]=None):
  if dataset_name == 'cifar10':
    assert mode in ('train', 'validation')
    return DiscreteCIFAR10(cache_dir=cache_dir, 
                           train=mode=='train')
  eos_tag = ''
  if not insert_eos:
    eos_tag = '_eosFalse'
  if wrap:
    filename = f'{dataset_name}_{mode}_bs{block_size}_wrapped{eos_tag}.dat'
  else:
    filename = f'{dataset_name}_{mode}_bs{block_size}_unwrapped{eos_tag}.dat'
  _path = os.path.join(cache_dir, filename)
  
  if utils.fsspec_exists(_path):
    LOGGER.info(f'Loading data from: {_path}')
    return datasets.load_from_disk(_path).with_format('torch')
  LOGGER.info(f'Generating new data at: {_path}')
  LOGGER.info(f'{streaming=}')  

  crop_train = dataset_name == 'text8-crop'
  if mode == 'train' and crop_train:
    # double block size for sub-sampling
    block_size *= 2
  
  if dataset_name == 'wikitext103':
    dataset = datasets.load_dataset(
      'wikitext',
      name='wikitext-103-raw-v1',
      cache_dir=cache_dir,
      revision=revision)
  elif dataset_name == 'wikitext2':
    dataset = datasets.load_dataset(
      'wikitext',
      name='wikitext-2-raw-v1',
      cache_dir=cache_dir,
      revision=revision)
  elif dataset_name == 'ptb':
    dataset = datasets.load_dataset(
      'ptb_text_only',
      cache_dir=cache_dir,
      revision=revision)
  elif dataset_name == 'lambada':
    dataset = get_lambada_test_dataset()
  elif dataset_name == 'text8':
    assert wrap
    assert revision is None
    dataset = get_text8_dataset(
      cache_dir, max_seq_length=block_size)
  elif dataset_name == 'text8-crop':
    assert revision is None
    dataset = get_text8_dataset(
      cache_dir, max_seq_length=block_size, crop_train=True)
  elif dataset_name == 'openwebtext-train':
    dataset = datasets.load_dataset(
      'openwebtext',
      split='train[:-100000]',
      cache_dir=cache_dir,
      revision=revision,
      streaming=False,
      num_proc=num_proc)
  elif dataset_name == 'openwebtext-valid':
    dataset = datasets.load_dataset(
      'openwebtext',
      split='train[-100000:]',
      cache_dir=cache_dir,
      revision=revision,
      streaming=False,
      num_proc=num_proc)
  elif dataset_name == 'scientific_papers_arxiv':
    dataset = datasets.load_dataset(
      'scientific_papers', 'arxiv',
      cache_dir=cache_dir,
      streaming=streaming,
      revision=revision)
  elif dataset_name == 'scientific_papers_pubmed':
    dataset = datasets.load_dataset(
      'scientific_papers', 'pubmed',
      cache_dir=cache_dir,
      streaming=streaming,
      revision=revision)
  elif dataset_name == 'ag_news':
    dataset = datasets.load_dataset(
      'ag_news',
      cache_dir=cache_dir,
      streaming=streaming,
      revision=revision)
  elif dataset_name == 'synthetic':
    assert streaming
    assert wrap  # i.e., no pad tokens
    dataset = generate_synthetic_dataset(
      train_dataset_size=100000,
      validation_dataset_size=1024,
      seq_len=32,
      vocab_size=256,
    )
  else:
    dataset = datasets.load_dataset(
      dataset_name,
      cache_dir=cache_dir,
      streaming=streaming,
      revision=revision)

  if dataset_name in ['lambada', 'openwebtext-train',
                      'openwebtext-valid']:
    data = dataset
  else:
    data = dataset[mode]
    if dataset_name == 'synthetic':
      # already tokenized, no further actions required
      return data

  if dataset_name.startswith('wikitext'):
    detokenizer = wt_detokenizer
  elif dataset_name == 'ptb':
    detokenizer = ptb_detokenizer
  elif dataset_name == 'lm1b':
    detokenizer = lm1b_detokenizer
  elif dataset_name == 'lambada':
    detokenizer = lambada_detokenizer
  elif dataset_name.startswith('scientific_papers'):
    detokenizer = scientific_papers_detokenizer
  else:
    detokenizer = None

  def _apply_detokenizer(detokenizer):
    def detok(text):
      for i, t in enumerate(text, 0):
        text[i] = detokenizer(t)
      return text
    return detok
  
  EOS = tokenizer.encode(tokenizer.eos_token)[0]
  BOS = tokenizer.encode(tokenizer.bos_token)[0]

  def preprocess_and_tokenize(example):
    if dataset_name == 'ptb':
      text = example['sentence']
    elif 'scientific_papers' in dataset_name:
      text = example['article']
    else:
      text = example['text']
    
    if detokenizer is not None:
      text = _apply_detokenizer(detokenizer)(text)

    tokenizer.padding_side = 'right'
    tokenizer.truncation_side = 'right'

    if wrap:
      tokens = tokenizer(text,
                         add_special_tokens=False,
                         return_attention_mask=False,
                         return_token_type_ids=False)
      if insert_eos:
        tokens = {'input_ids':
                  [t + [EOS] for t in tokens['input_ids']]}
      # Still missing BOS, but will be added in group_texts
    else:
      tokens = tokenizer(text,
                         max_length=block_size,
                         padding='max_length',
                         truncation=True,
                         add_special_tokens=True,
                         return_attention_mask=True,
                         return_token_type_ids=True)
    return tokens

  if streaming:
    tokenized_dataset = data.map(
      preprocess_and_tokenize,
      batched=True)
  else:
    tokenized_dataset = data.map(
      preprocess_and_tokenize,
      batched=True,
      num_proc=num_proc,
      load_from_cache_file=True,
      desc='Tokenizing')
  if dataset_name == 'ptb':
    tokenized_dataset = tokenized_dataset.remove_columns(
      'sentence')
  elif 'scientific_papers' in dataset_name:
    tokenized_dataset = tokenized_dataset.remove_columns([
      'article', 'abstract', 'section_names'])
  elif dataset_name == 'ag_news':
    tokenized_dataset = tokenized_dataset.remove_columns(
      ['text', 'label'])
  else:
    tokenized_dataset = tokenized_dataset.remove_columns(
      'text')

  if not wrap:
    if not streaming:
      tokenized_dataset.save_to_disk(_path)
    return tokenized_dataset.with_format('torch')

  group_texts = functools.partial(
    _group_texts, block_size=block_size, bos=BOS, eos=EOS)
  if streaming:
    chunked_dataset = tokenized_dataset.map(
      group_texts,
      batched=True)
  else:
    chunked_dataset = tokenized_dataset.map(
      group_texts,
      batched=True,
      num_proc=num_proc,
      load_from_cache_file=True,
      desc='Grouping')
    chunked_dataset.save_to_disk(_path)
  chunked_dataset = chunked_dataset.with_format('torch')
  return chunked_dataset


def get_tokenizer(config):
  if config.data.tokenizer_name_or_path == 'text8':
    tokenizer = Text8Tokenizer()
  elif config.data.tokenizer_name_or_path == 'bert-base-uncased':
    tokenizer = transformers.BertTokenizer.\
      from_pretrained('bert-base-uncased')
  elif config.data.tokenizer_name_or_path == 'synthetic':
    tokenizer = SyntheticTokenizer(vocab_size=256)
  elif config.data.tokenizer_name_or_path == 'cifar10':
    return RawPixelsVisionTokenizer(
      vocab_size=256, image_size=32, add_special_tokens=False, 
      add_mask_token='mdlm' in config.algo.name)
  else:
    tokenizer = transformers.AutoTokenizer.from_pretrained(
      config.data.tokenizer_name_or_path)

  if (isinstance(tokenizer, transformers.GPT2TokenizerFast)
      or isinstance(tokenizer, transformers.GPT2Tokenizer)):
    tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing(
      (tokenizer.bos_token, tokenizer.bos_token_id),
      (tokenizer.eos_token, tokenizer.eos_token_id))

  # For wrapped batches:
  #  [BOS] sent1 [EOS] sent2-fragment [EOS]
  #  [BOS] sent2-fragment [EOS] sent3 [EOS]
  if tokenizer.bos_token is None:
    if tokenizer.cls_token is None:
      raise AttributeError(
        'Tokenizer must have a bos_token or '
        f'cls_token: {tokenizer}')
    tokenizer.bos_token = tokenizer.cls_token
  if tokenizer.eos_token is None:
    if tokenizer.sep_token is None:
      raise AttributeError(
        'Tokenizer must have an eos_token '
        f'or sep_token: {tokenizer}')
    tokenizer.eos_token = tokenizer.sep_token
  if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

  return tokenizer
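

# Sketch (toy vocabulary; `eot_wrapped_ids_sketch` is a hypothetical
# helper): this shows what the BertProcessing post-processor that
# get_tokenizer installs on GPT-2 tokenizers does. BertProcessing takes
# (sep, cls) pairs and frames a single sequence as cls + tokens + sep.
# GPT-2 uses '<|endoftext|>' for both BOS and EOS. Passing the BOS pair
# as `sep` and the EOS pair as `cls` therefore still yields
# <|endoftext|> text <|endoftext|>. Assumes the `tokenizers` package,
# which this module already imports.
from tokenizers import Tokenizer, models, pre_tokenizers, processors


def eot_wrapped_ids_sketch(text):
  """Encodes whitespace-separated words and frames them with EOT ids."""
  vocab = {'<|endoftext|>': 0, 'hello': 1, 'world': 2, '[UNK]': 3}
  tok = Tokenizer(models.WordLevel(vocab=vocab, unk_token='[UNK]'))
  tok.pre_tokenizer = pre_tokenizers.Whitespace()
  # BertProcessing(sep, cls): a single sequence becomes cls $A sep.
  tok.post_processor = processors.BertProcessing(
    ('<|endoftext|>', 0), ('<|endoftext|>', 0))
  return tok.encode(text).ids

# Usage: eot_wrapped_ids_sketch('hello world') -> [0, 1, 2, 0]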
    

def get_dataloaders(config, tokenizer, skip_train=False,
                    skip_valid=False, valid_seed=None):
  num_gpus = torch.cuda.device_count()
  assert (config.loader.global_batch_size
          == (config.loader.batch_size
              * config.trainer.num_nodes
              * num_gpus
              * config.trainer.accumulate_grad_batches))
  if config.loader.global_batch_size % (
    num_gpus * config.trainer.accumulate_grad_batches) != 0:
    raise ValueError(
      f'Train Batch Size {config.loader.global_batch_size} '
      f'not divisible by {num_gpus} gpus with accumulation '
      f'{config.trainer.accumulate_grad_batches}.')
  if config.loader.eval_global_batch_size % num_gpus != 0:
    raise ValueError(
      f'Eval Batch Size {config.loader.eval_global_batch_size} '
      f'not divisible by {num_gpus}.')
  if skip_train:
    train_set = None
  else:
    train_set = get_dataset(
      config.data.train,
      tokenizer,
      mode='train',
      wrap=config.data.wrap,
      insert_eos=config.data.insert_train_eos,
      cache_dir=config.data.cache_dir,
      block_size=config.model.length,
      streaming=config.data.streaming,
      num_proc=config.loader.num_workers,
      revision=config.data.get("train_revision", None))
  
  if config.data.valid in ['text8', 'lm1b', 'ag_news']:
    validation_split = 'test'
  else:
    validation_split = 'validation'
  if skip_valid:
    valid_set = None
  else:
    valid_set = get_dataset(
      config.data.valid,
      tokenizer,
      wrap=config.data.wrap,
      mode=validation_split,
      cache_dir=config.data.cache_dir,
      insert_eos=config.data.insert_valid_eos,
      block_size=config.model.length,
      streaming=config.data.streaming,
      num_proc=config.loader.num_workers,
      revision=config.data.get("valid_revision", None))

  if skip_train:
    train_loader = None
  else:
    train_loader = torch.utils.data.DataLoader(
      train_set,
      batch_size=config.loader.batch_size,
      num_workers=config.loader.num_workers,
      pin_memory=config.loader.pin_memory,
      shuffle=not config.data.streaming,
      persistent_workers=config.loader.num_workers > 0)
    train_loader.tokenizer = tokenizer
  if skip_valid:
    valid_loader = None
  else:
    if valid_seed is None:
      shuffle_valid = False
      generator = None
    else:
      shuffle_valid = True
      generator = torch.Generator().manual_seed(valid_seed)
    valid_loader = torch.utils.data.DataLoader(
      valid_set,
      batch_size=config.loader.eval_batch_size,
      num_workers=config.loader.num_workers,
      pin_memory=config.loader.pin_memory,
      shuffle=shuffle_valid,
      generator=generator)
    # Will be used in generative perplexity calculation
    valid_loader.tokenizer = tokenizer

  return train_loader, valid_loader
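

# Sketch (hypothetical helper): this shows the batch-size bookkeeping
# that get_dataloaders asserts. One optimizer step consumes
#   global = per_device * num_nodes * gpus_per_node * accumulate_grad_batches
# sequences, so the per-device loader batch size follows by division.
# Unlike the divisibility check above, this sketch also includes
# num_nodes in the divisor, which matches the assert.


def per_device_batch_size_sketch(global_batch_size, num_nodes,
                                 gpus_per_node, accumulate_grad_batches):
  """Returns the per-device batch size, or raises if it is not integral."""
  divisor = num_nodes * gpus_per_node * accumulate_grad_batches
  if global_batch_size % divisor != 0:
    raise ValueError(
      f'Global batch size {global_batch_size} is not divisible by '
      f'{num_nodes} nodes x {gpus_per_node} gpus x '
      f'{accumulate_grad_batches} accumulation steps.')
  return global_batch_size // divisor

# Usage: per_device_batch_size_sketch(512, 1, 8, 2) -> 32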


# Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py


class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):

  def __init__(self, *args, generator=None, **kwargs):
    # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
    # which should be reproducible if pl.seed_everything was called beforehand.
    # This means that changing the seed of the experiment will also change the
    # sampling order.
    if generator is None:
      seed = int(torch.empty((), dtype=torch.int64).random_().item())
      generator = torch.Generator().manual_seed(seed)
    kwargs.pop('shuffle', None)
    super().__init__(*args, generator=generator, **kwargs)
    self.counter = 0
    self.restarting = False

  def state_dict(self):
    return {'random_state': self.generator.get_state(),
            'counter': self.counter}

  def load_state_dict(self, state_dict):
    self.generator.set_state(state_dict.get('random_state'))
    self.counter = state_dict['counter']
    # self.start_counter = self.counter
    self.restarting = True

  # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
  # epoch, and subsequent epochs will have very few batches.

  def __iter__(self) -> typing.Iterator[int]:
    n = len(self.data_source)

    self.state = self.generator.get_state()
    indices = torch.randperm(n, generator=self.generator).tolist()

    if not self.restarting:
      self.counter = 0
    else:
      indices = indices[self.counter:]
      self.restarting = False

    for index in indices:
      self.counter += 1
      yield index

    self.counter = 0
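

# Sketch (hypothetical helper, not part of the sampler API): resuming
# RandomFaultTolerantSampler relies on torch.Generator state round-trips.
# __iter__ records the generator state in self.state before torch.randperm
# is drawn. Restoring a generator to that state reproduces the same
# permutation, so skipping the first `counter` entries yields exactly the
# indices that have not been consumed yet.
import torch


def remaining_indices_sketch(n, state_before_perm, counter):
  """Rebuilds the permutation from a saved state and skips consumed ids."""
  g = torch.Generator()
  g.set_state(state_before_perm)
  perm = torch.randperm(n, generator=g).tolist()
  return perm[counter:]

# Usage:
#   g = torch.Generator().manual_seed(0)
#   state = g.get_state()
#   full = torch.randperm(10, generator=g).tolist()
#   remaining_indices_sketch(10, state, counter=3) == full[3:]  # True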


class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.counter = 0
    self.restarting = False

  def state_dict(self):
    return {'epoch': self.epoch, 'counter': self.counter}

  def load_state_dict(self, state_dict):
    self.epoch = state_dict['epoch']
    self.counter = state_dict['counter']
    self.restarting = True

  # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
  # epoch, and subsequent epochs will have very few batches.
  def __iter__(self):
    if self.shuffle:
      # deterministically shuffle based on epoch and seed
      g = torch.Generator()
      g.manual_seed(self.seed + self.epoch)
      indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
    else:
      indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

    if not self.drop_last:
      # add extra samples to make it evenly divisible
      padding_size = self.total_size - len(indices)
      if padding_size <= len(indices):
        indices += indices[:padding_size]
      else:
        indices += (indices * math.ceil(
          padding_size / len(indices)))[:padding_size]
    else:
      # remove tail of data to make it evenly divisible.
      indices = indices[:self.total_size]
    assert len(indices) == self.total_size

    # subsample
    indices = indices[self.rank:self.total_size:self.num_replicas]
    assert len(indices) == self.num_samples

    if not self.restarting:
      self.counter = 0
    else:
      indices = indices[self.counter:]
      self.restarting = False

    for index in indices:
      self.counter += 1
      yield index

    self.counter = 0
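

# Sketch (hypothetical helper, plain Python): this reproduces the index
# arithmetic in FaultTolerantDistributedSampler.__iter__ for an unshuffled
# dataset of size n. num_samples follows the definition used by
# torch.utils.data.DistributedSampler. Without drop_last, the list is
# padded by repeating indices so that every rank gets the same count.
# With drop_last, the tail is cut instead. Each rank takes every
# num_replicas-th index, starting at its own rank. On resume, the first
# `counter` indices of the shard are skipped.
import math


def shard_indices_sketch(n, num_replicas, rank, drop_last=False, counter=0):
  """Returns the indices a given rank would still yield."""
  indices = list(range(n))
  if drop_last and n % num_replicas != 0:
    num_samples = math.ceil((n - num_replicas) / num_replicas)
  else:
    num_samples = math.ceil(n / num_replicas)
  total_size = num_samples * num_replicas
  if not drop_last:
    padding_size = total_size - len(indices)
    if padding_size <= len(indices):
      indices += indices[:padding_size]
    else:
      indices += (indices * math.ceil(
        padding_size / len(indices)))[:padding_size]
  else:
    indices = indices[:total_size]
  shard = indices[rank:total_size:num_replicas]
  return shard[counter:]

# Usage: shard_indices_sketch(10, 4, 2) -> [2, 6, 0]  (padded with 0, 1)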


================================================
FILE: discrete_diffusion_harness.py
================================================
import torch
from omegaconf import OmegaConf

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from datasets import Dataset
from tqdm import tqdm
import numpy as np
import algo
import dataloader

"""
# Instructions on running the eval

## What is the script doing?
The script evaluates (or approximates) the log-likelihood of
prefix + suffix. The script will load a checkpoint, and 
depending on the config, use the corresponding class 
(e.g. mdlm, duo, ar, etc).

- For MCQ, the log-likelihood of all continuations given the 
  same prefix is evaluated. The most likely continuation is 
  selected as the "correct" answer according to the model.

- For lambada_openai, the model is counted as correct if the true
  continuation is its argmax prediction. For diffusion models, we
  inject noise into the continuation and check whether the true
  answer is the most likely one in a single forward pass.
  This naturally favors AR models, as they run one forward pass
  per token, while diffusion models use a single pass for all
  tokens for simplicity. Therefore, I usually only compare on
  MCQ, since it is fairer.
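
A minimal sketch of the MCQ selection rule follows. It assumes
lm-eval's plain `acc` metric, which takes the argmax over the
per-choice log-likelihoods returned by `loglikelihood`.
`pick_choice_sketch` is hypothetical and not part of lm-eval or this
script. Note that a request skipped for exceeding the context length
is returned with ll = 0.0, which outranks any negative
log-likelihood under this rule.

```python
def pick_choice_sketch(results):
  # results: one (loglikelihood, is_greedy) pair per answer choice
  lls = [ll for ll, _ in results]
  return max(range(len(lls)), key=lambda i: lls[i])
```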

To run the script, you need to install the lm-eval-harness package:
```
pip install git+https://github.com/EleutherAI/lm-evaluation-harness
```

## Tasks that we tested with
**MCQ**: arc_easy, arc_challenge, hellaswag, winogrande, 
         boolq, openbookqa, race, social_iqa, mathqa, piqa
**Likelihood based**: lambada_openai

## Batch sizes that fit at the small scale
    - boolq -> 64
    - openbookqa -> 256
    - race -> 32
    - social_iqa -> 512
    - winogrande -> 64
    - mathqa -> 64
    - lambada_openai -> 64
    - arc_easy -> 256
    - arc_challenge -> 256
    - hellaswag -> 64
    - piqa -> 64

## Important flags
  --trust_remote_code -> some datasets execute code when loading 
                         from huggingface. Without this flag, 
                         the script might crash.
  --batch_size        -> max. num elements to use in parallel 
                         to eval the likelihood (for diffusion). 
                         For simplicity, inputs are NOT padded 
                         and batched for AR, though it should 
                         be fairly easy to add.
  --tasks             -> one or multiple comma-separated tasks 
                         to evaluate on.
  --model_args        -> string (without spaces) that contains 
                         the arguments to pass to the evaluator 
                         (stuff like checkpoints path, number 
                         of MC samples to evaluate the 
                         likelihood, path to the sentencepiece 
                         tokenizer, etc).
  --output_path       -> path to a json file where the evaluation 
                         results will be saved, instead of 
                         only being printed in the terminal 
                         (they will always be printed).
  --limit             -> debugging flag; limit the number of 
                         examples to a fixed amount, instead 
                         of using the whole dataset


## Example commands


### Run a single task with an AR model
python discrete_diffusion_harness.py \
    --tasks arc_easy \
    --batch_size 256 \
    --model dlm \
    --model_args checkpoint_path=/home/username/baselines/ar/1M.ckpt \
    --output_path ./harness_results/ar/1M/arc_easy.json

### Run a single task with an MDLM model, and 2048 MC samples to approximate the likelihood
python discrete_diffusion_harness.py \
    --tasks arc_easy \
    --model dlm \
    --batch_size 256 \
    --model_args checkpoint_path=/home/username/baselines/mdlm/1M.ckpt,num_mc_samples=2048 \
    --output_path ./harness_results/mdlm/1M/arc_easy.json

### Debug with 20 examples only
python discrete_diffusion_harness.py \
    --tasks arc_easy \
    --model dlm \
    --batch_size 256 \
    --model_args checkpoint_path=/home/username/baselines/mdlm/1M.ckpt,num_mc_samples=2048 \
    --limit 20 \
    --output_path ./harness_results/mdlm/1M/arc_easy.json


### Run the benchmarks from "The Diffusion Duality (Chapter 2)":
    (Arc-e, Arc-c, HSwag, WinoG, PIQA, MathQA, OQA)
-> just change the checkpoint to evaluate mdlm, ar, or duo

for task_config in "arc_easy 256" "arc_challenge 256" "hellaswag 64" "winogrande 64" "piqa 64" "mathqa 64" "openbookqa 256"; do
  task=$(echo $task_config | cut -d' ' -f1)
  batch_size=$(echo $task_config | cut -d' ' -f2)
  python discrete_diffusion_harness.py \
    --batch_size $batch_size \
    --tasks $task \
    --model dlm \
    --model_args checkpoint_path=/path/to/checkpoint.ckpt \
    --output_path ./harness_results/duo/$task.json
done

"""


def requests_to_dataset(config, requests, tokenizer, num_proc):
  def _tokenize(e):
    eos_idx = tokenizer.eos_token_id
    bos_idx = tokenizer.bos_token_id
    prefix_tokens = tokenizer(e['prefix'], 
                              return_attention_mask=False, 
                              add_special_tokens=False
                              )['input_ids']
    target_tokens = tokenizer(e['target'], 
                              return_attention_mask=False, 
                              add_special_tokens=False
                              )['input_ids']
    prefix_tokens = [bos_idx] + prefix_tokens
    target_tokens = target_tokens + [eos_idx]
    
    return {
        'prefix_text': e['prefix'],
        'target_text': e['target'],
        'prefix': prefix_tokens,
        'target': target_tokens,
    }
  ds = [{'prefix': req.args[0], 'target': req.args[1]}
        for req in requests]
  ds = Dataset.from_list(ds)
  ds = ds.map(_tokenize, num_proc=num_proc)
  ds = ds.with_format('torch')
  seq_lengths = [len(x['prefix']) + len(x['target'])
                 for x in ds]

  num_larger = len([x for x in seq_lengths
                    if x > config.model.length])
  if num_larger > 0:
    print(f'\033[91mThere are some examples that are longer '
          f'than the context length, they will be ignored '
          f'during evaluation. Number of such sequences: '
          f'{num_larger}\033[0m')

  return ds
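

# Sketch (hypothetical helper; a whitespace split over a toy vocabulary
# stands in for the real tokenizer): this shows the token layout that
# requests_to_dataset builds for each (prefix, target) request. BOS is
# prepended to the prefix and EOS is appended to the target, so the
# scored sequence is [BOS] prefix target [EOS].
# _suffix_greedy_prediction_mdlm relies on that trailing EOS. Requests
# whose total length exceeds the model context are skipped later, in
# loglikelihood.


def build_request_tokens_sketch(prefix, target, vocab, bos_id, eos_id,
                                max_len):
  """Returns (prefix_ids, target_ids, fits_in_context)."""
  prefix_ids = [bos_id] + [vocab[w] for w in prefix.split()]
  target_ids = [vocab[w] for w in target.split()] + [eos_id]
  fits = len(prefix_ids) + len(target_ids) <= max_len
  return prefix_ids, target_ids, fits

# Usage:
#   build_request_tokens_sketch('the cat', 'sat', {'the': 1, 'cat': 2,
#     'sat': 3}, bos_id=0, eos_id=9, max_len=8) -> ([0, 1, 2], [3, 9], True)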


def _eval_suffix_nll_generators(config, module, prefix, 
  suffix, batch_size, num_samples, loss_avg_mode):
  device = module.device
  assert num_samples % batch_size == 0
  full_sentence = torch.cat([prefix, suffix], dim=-1
                  ).repeat(batch_size, 1).to(module.device)
  all_ts = module._sample_t(num_samples, accum_step=None)
  for idx in range(0, num_samples, batch_size):
    t = all_ts[idx:idx+batch_size].unsqueeze(-1)
    dalpha_t, alpha_t = module.noise(t)
    alpha_t = alpha_t.to(device)
    sigma = module._sigma_from_alphat(alpha_t)
    x0 = full_sentence.to(device)
    # Inject noise
    xt = module.q_xt(full_sentence, alpha_t).to(device)
    if loss_avg_mode == 'full':
      pass  # nothing to do
    elif loss_avg_mode == 'suffix':
      xt[:, :len(prefix)] = prefix
      # recompute alpha_t based on number of masked tokens, 
      #   for conditioning of the backbone
      alpha_t = (xt == x0).float().mean(dim=1)[:, None]
      t = module.noise.get_t_for_alpha(alpha_t)
      # Recompute dalpha_t at the new t, since the loss needs it:
      dalpha_t, _ = module.noise(t)
      alpha_t = alpha_t.to(device)
      sigma = module._sigma_from_alphat(alpha_t)
    else:
      raise ValueError(loss_avg_mode)

    yield xt, x0, t, sigma, alpha_t, dalpha_t
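

# Sketch (toy tensors; hypothetical helper): in 'suffix' mode the prefix
# is copied back into the noised sequence, which raises the fraction of
# clean tokens. alpha_t is then re-estimated per sample as that fraction,
# mirroring (xt == x0).float().mean(dim=1)[:, None] above. The prefix is
# taken from x0 here, which is equivalent because x0 starts with the
# prefix.
import torch


def restore_prefix_alpha_sketch(xt, x0, prefix_len):
  """Returns (xt with clean prefix, per-sample clean fraction of shape (B, 1))."""
  xt = xt.clone()  # leave the caller's tensor untouched
  xt[:, :prefix_len] = x0[:, :prefix_len]
  alpha_t = (xt == x0).float().mean(dim=1)[:, None]
  return xt, alpha_t

# Usage (mask id 0): xt=[[0, 0, 7, 0]], x0=[[5, 6, 7, 8]], prefix_len=2
#   -> xt=[[5, 6, 7, 0]], alpha_t=[[0.75]]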
    

def eval_suffix_nll(config, module, prefix, suffix, batch_size, 
                    num_samples, loss_avg_mode):
  if config.algo.name in ('mdlm', 'duo', 'duo_base', 
                          'distillation', 'ot-finetune'):
    return eval_suffix_nll_diffusion(config, module, prefix, 
      suffix, batch_size, num_samples, loss_avg_mode)
  elif config.algo.name == 'ar':
    return eval_suffix_nll_ar(config, module, prefix, 
      suffix, batch_size, num_samples, loss_avg_mode)
  else:
    raise ValueError(config.algo.name)


def eval_suffix_nll_ar(config, module, prefix, suffix,
                    batch_size, num_samples, loss_avg_mode):
  x_cat = torch.cat([prefix, suffix[:-1]], dim=-1)
  x_cat = x_cat.reshape(1, -1).to(module.device)

  with torch.amp.autocast('cuda', dtype=torch.float32):
    out = module.backbone(x_cat, sigma=None)

  out[:, :, module.mask_index] = module.neg_infinity
  suffix_out = out[:, len(prefix) - 1:, :]
  suffix_logits = torch.log_softmax(suffix_out, dim=-1)
  index = suffix[None, :, None].to(module.device)
  nll = torch.gather(-suffix_logits, dim=-1, 
                     index=index).mean()
  return float(nll.cpu())
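

# Sketch (toy logits; hypothetical helper): this explains why
# eval_suffix_nll_ar feeds prefix + suffix[:-1] and slices the logits
# from len(prefix) - 1. Position i of a causal model predicts token i + 1.
# The logits at positions len(prefix) - 1 through
# len(prefix) + len(suffix) - 2 therefore line up one-to-one with the
# suffix tokens.
import torch


def suffix_nll_from_logits_sketch(logits, prefix_len, suffix):
  """logits: (1, prefix_len + len(suffix) - 1, vocab); suffix: int64 ids."""
  log_probs = torch.log_softmax(logits[:, prefix_len - 1:, :], dim=-1)
  index = suffix[None, :, None]
  return float(torch.gather(-log_probs, dim=-1, index=index).mean())

# Usage: all-zero logits over a vocab of 4 give an NLL of log(4) per token.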

def eval_suffix_nll_diffusion(config, module, prefix, suffix, 
  batch_size, num_samples, loss_avg_mode):
  all_losses = []
  generator = _eval_suffix_nll_generators(config, module,
    prefix, suffix, batch_size, num_samples, loss_avg_mode)
  for xt, x0, t, sigma, alpha_t, dalpha_t in generator:
    log_x_theta = module(xt, sigma, labels=None)
    token_nll = module.nll_per_token(log_x_theta, xt, x0, 
                                     alpha_t, dalpha_t)
    if loss_avg_mode == 'full':
      loss = float(token_nll.mean())
    elif loss_avg_mode == 'suffix':
      loss = float(token_nll[:, len(prefix):].mean())
    all_losses.append(loss)
  return float(np.mean(all_losses))
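

# Sketch (toy numbers; hypothetical helper): eval_suffix_nll_diffusion
# averages per-batch means with np.mean. _eval_suffix_nll_generators
# asserts num_samples % batch_size == 0, so every batch has the same size.
# As a result, the mean of the batch means equals the mean over all Monte
# Carlo samples.
import numpy as np


def mc_mean_of_batch_means_sketch(per_sample_losses, batch_size):
  """Averages equal-sized batch means, as the evaluation loop does."""
  assert len(per_sample_losses) % batch_size == 0
  batch_means = [np.mean(per_sample_losses[i:i + batch_size])
                 for i in range(0, len(per_sample_losses), batch_size)]
  return float(np.mean(batch_means))

# Usage: [1, 2, 3, 4, 5, 6] with batch_size=2 -> mean(1.5, 3.5, 5.5) = 3.5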


@register_model("dlm")
class DiscreteDiffusionHarness(LM):
  def __init__(self, pretrained="NONE", max_length=1024,
    num_mc_samples=1024, batch_size=64, device="cuda",
    checkpoint_path=None, num_proc=8, loss_avg_mode='full',
    *args, **kwargs):
    super().__init__()
    # Whether to use the full sequence, or suffix only to 
    #  approximate the NLL. Full should be the correct way.
    assert loss_avg_mode in ('full', 'suffix')
    ckpt = torch.load(checkpoint_path, map_location='cpu', 
                      weights_only=False)
    config = ckpt['hyper_parameters']['config']
    # Backfill missing keys into legacy checkpoints
    if not hasattr(config.training, 'class_dropout_p'):
      OmegaConf.set_struct(config, False)
      config.training.class_dropout_p = 0.0
      OmegaConf.set_struct(config, True)
    self.tokenizer = dataloader.get_tokenizer(config)
    if config.algo.name == 'mdlm':
      self.model = algo.MDLM(config, self.tokenizer)
    elif config.algo.name in ('duo', 'duo_base', 
                              'distillation', 'ot-finetune'):
      self.model = algo.DUO_BASE(config, self.tokenizer)
    elif config.algo.name == 'ar':
      self.model = algo.AR(config, self.tokenizer)
    else:
      raise ValueError(f'Implement for {config.algo.name}')
    self.config = config
    self.num_proc = num_proc
    self.num_mc_samples = num_mc_samples
    self.batch_size = int(batch_size)
    self.device = device
    self.loss_avg_mode = loss_avg_mode

    self.model.load_state_dict(ckpt['state_dict'])
    self.model.to(device)
    self.model.eval()

  def suffix_greedy_prediction(self, prefix, target):
    if self.config.algo.name == 'mdlm':
      return self._suffix_greedy_prediction_mdlm(prefix, 
                                                 target)
    elif self.config.algo.name in ('duo', 'duo_base',
                                   'distillation', 'ot-finetune'):
      return self._suffix_greedy_prediction_duo_base(prefix, 
                                                     target)
    elif self.config.algo.name == 'ar':
      return self._suffix_greedy_prediction_ar(prefix, target)
    else:
      raise ValueError(self.config.algo.name)
    
  def _suffix_greedy_prediction_ar(self, prefix, target):
    x_cat = torch.cat([prefix, target[:-1]], 
      dim=-1).reshape(1, -1).to(self.device)

    # Follows generate_samples in AR (algo.py)
    out = self.model.backbone(x_cat, sigma=None)
    out[:, :, self.model.mask_index] = self.model.neg_infinity
    out = out.log_softmax(-1)
    preds_suffix = out[:, len(prefix) - 1:, :]
    greedy_preds = preds_suffix.argmax(-1).flatten()
    return (greedy_preds.cpu() == target).all().item()


  def _suffix_greedy_prediction_mdlm(self, prefix, target):
    mask_idx = self.model.mask_index
    eos_idx = self.tokenizer.eos_token_id
    # Note: because of the preprocessing, we know that the 
    #  last token is an eos token.
    noisy_target = [mask_idx] * (len(target) - 1) + [eos_idx]
    noisy_target = torch.tensor(noisy_target, 
                                device=self.device)
    prefix = prefix.to(self.device)
    seq = torch.concatenate([prefix, noisy_target], 
                            dim=-1).reshape(1, -1)
    sigma = torch.zeros(size=(seq.shape[0], 1), 
                        device=self.device)
    logits = self.model(seq, sigma, labels=None)
    assert logits.shape[0] == 1
    suffix_logits = logits[0, len(prefix):]
    target_preds = suffix_logits.argmax(-1).cpu()
    correct = target_preds == target
    correct = correct.all()
    return bool(correct)
  
  def _suffix_greedy_prediction_duo_base(self, prefix, target):
    noisy_suffix = torch.randint(
      0, self.model.vocab_size, size=target.shape, 
      dtype=target.dtype, device=self.device)

    prefix = prefix.to(self.device)
    # shape: (1, len)
    noisy_seq = torch.concatenate([prefix, noisy_suffix])[None, :]
    # Set the EOS token at the end
    noisy_seq[0, -1] = target[-1]
    clean_seq = torch.concatenate([prefix, target.to(self.device)])[None, :]
    
    alpha_t = (noisy_seq == clean_seq).float().mean(dim=1)[:, None]
    sigma = self.model._sigma_from_alphat(alpha_t)
    t = self.model.noise.get_t_for_alpha(alpha_t)
    logits = self.model(noisy_seq, sigma, labels=None)
    assert logits.shape[0] == 1
    suffix_logits = logits[0, len(prefix):]
    target_preds = suffix_logits.argmax(-1).cpu()

    correct = target_preds == target
    correct = correct.all()
    return bool(correct)
  
  @torch.no_grad()
  def loglikelihood(self, requests: list[Instance]) \
                                -> list[tuple[float, bool]]:
    dataset = requests_to_dataset(self.config, requests, 
                                  self.tokenizer, self.num_proc)
    out = []
    for elem in tqdm(dataset, 'Computing likelihood...'):
      prefix = elem['prefix']
      target = elem['target']
      
      if len(prefix) + len(target) > self.model.config.model.length:
        # If the request is too long, skip it.
        ll = 0.0
        is_target_greedy_dec = False
        out.append((ll, is_target_greedy_dec))
        print("SKIPPING")
        continue

      ll = -eval_suffix_nll(self.config, self.model, prefix, 
        target, self.batch_size, self.num_mc_samples, 
        self.loss_avg_mode)
      is_target_greedy_dec = self.suffix_greedy_prediction(
        prefix, target)
      out.append((ll, bool(is_target_greedy_dec)))
    return out

  def loglikelihood_rolling(
        self, requests: list[Instance]) -> list[float]:
    raise NotImplementedError

  def generate_until(self, requests: list[Instance]) -> list[str]:
    raise NotImplementedError
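

# Sketch (toy ids; hypothetical helper): this builds the single-pass input
# used by DiscreteDiffusionHarness._suffix_greedy_prediction_mdlm. The
# prefix stays clean and every target token except the final EOS becomes
# the mask id. The prediction counts as correct only if the argmax at
# every target position matches the true token.
import torch


def mdlm_greedy_input_sketch(prefix, target, mask_id, eos_id):
  """Returns a (1, len(prefix) + len(target)) tensor: prefix, masks, EOS."""
  masks = torch.full((len(target) - 1,), mask_id, dtype=prefix.dtype)
  eos = torch.tensor([eos_id], dtype=prefix.dtype)
  return torch.cat([prefix, masks, eos]).reshape(1, -1)

# Usage: prefix=[0, 5, 6], target=[7, 8, 9], mask_id=99, eos_id=9
#   -> [[0, 5, 6, 99, 99, 9]]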


if __name__ == "__main__":
  cli_evaluate()

================================================
FILE: main.py
================================================
import json
import os

import fsspec
import hydra
import lightning as L
from lightning.fabric import Fabric
import omegaconf
import rich.syntax
import rich.tree
import torch
from torch.utils.data.distributed import DistributedSampler
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from tqdm import tqdm, trange

import algo
import dataloader
import utils

omegaconf.OmegaConf.register_new_resolver(
  'cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver(
  'device_count', torch.cuda.device_count)
omegaconf.OmegaConf.register_new_resolver(
  'eval', eval)
omegaconf.OmegaConf.register_new_resolver(
  'div_up', lambda x, y: (x + y - 1) // y)
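

# Sketch (illustrative; `div_up_sketch` is a hypothetical resolver name,
# chosen so that it does not clash with 'div_up' above): this shows how a
# resolver registered with OmegaConf.register_new_resolver is invoked from
# a config interpolation. It uses the same ceiling-division formula as
# 'div_up'. For example, ${div_up_sketch:10,3} resolves to 4.
import omegaconf


def ceil_div_config_sketch(total, per_step):
  """Resolves a ceiling-division interpolation inside a tiny config."""
  name = 'div_up_sketch'
  if not omegaconf.OmegaConf.has_resolver(name):
    omegaconf.OmegaConf.register_new_resolver(
      name, lambda x, y: (x + y - 1) // y)
  cfg = omegaconf.OmegaConf.create({
    'total': total,
    'per_step': per_step,
    'steps': '${div_up_sketch:${total},${per_step}}'})
  return cfg.steps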


def _load_from_checkpoint(diffusion_model, config, tokenizer):
  if 'hf' in config.algo.backbone:
    return diffusion_model(
      config, tokenizer=tokenizer).to('cuda')
  
  return diffusion_model.load_from_checkpoint(
    config.eval.checkpoint_path,
    tokenizer=tokenizer,
    config=config)


@L.pytorch.utilities.rank_zero_only
def _print_config(
  config: omegaconf.DictConfig,
  resolve: bool = True,
  save_cfg: bool = True) -> None:
  """Prints content of DictConfig using Rich library and its tree structure.
  
  Args:
    config (DictConfig): Configuration composed by Hydra.
    resolve (bool): Whether to resolve reference fields of DictConfig.
    save_cfg (bool): Whether to save the configuration tree to a file.
  """

  style = 'dim'
  tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)

  fields = config.keys()
  for field in fields:
    branch = tree.add(field, style=style, guide_style=style)

    config_section = config.get(field)
    branch_content = str(config_section)
    if isinstance(config_section, omegaconf.DictConfig):
      branch_content = omegaconf.OmegaConf.to_yaml(
        config_section, resolve=resolve)

    branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
  rich.print(tree)
  if save_cfg:
    with fsspec.open(
      '{}/config_tree.txt'.format(
        config.checkpointing.save_dir), 'w') as fp:
      rich.print(tree, file=fp)


@L.pytorch.utilities.rank_zero_only
def _print_batch(config, train_ds, valid_ds, tokenizer, k=64):
  for dl_type, dl in [
    ('train', train_ds), ('valid', valid_ds)]:
    print(f'Printing {dl_type} dataloader batch.')
    batch = next(iter(dl))
    print('Batch input_ids.shape', batch['input_ids'].shape)
    if config.data.modality == 'text':
      first = batch['input_ids'][0, :k]
      last = batch['input_ids'][0, -k:]
      print(f'First {k} tokens:', tokenizer.decode(first))
      print('ids:', first)
      print(f'Last {k} tokens:', tokenizer.decode(last))
      print('ids:', last)


def _generate_samples(diffusion_model, config, logger,
                      tokenizer):
  logger.info('Starting Sample Eval.')
  model = _load_from_checkpoint(
    diffusion_model=diffusion_model,
    config=config,
    tokenizer=tokenizer)
  model.metrics.gen_ppl.reset()
  model.metrics.sample_entropy.reset()
  if config.eval.disable_ema:
    logger.info('Disabling EMA.')
    model.ema = None
  stride_length = config.sampling.stride_length
  num_strides = config.sampling.num_strides
  all_samples = []
  for _ in trange(config.sampling.num_sample_batches):
    if config.sampling.semi_ar:
      _, intermediate_samples, _ = model.restore_model_and_semi_ar_sample(
        stride_length=stride_length,
        num_strides=num_strides,
        dt=1 / config.sampling.steps)
      text_samples = intermediate_samples[-1]
      # Note: Samples generated using semi-ar method
      # need to be processed before computing generative perplexity
      # since these samples contain numerous <|endoftext|> tokens
      # and diffusion.compute_generative_perplexity() discards
      # any text after the first EOS token.
    else:
      samples = model.restore_model_and_sample(
        num_steps=config.sampling.steps)
      model.metrics.record_entropy(samples)
      text_samples = model.tokenizer.batch_decode(samples)
      model.metrics.record_generative_perplexity(
        text_samples, config.model.length, model.device)
      all_samples.extend(list(text_samples))
  generative_ppl = 0.
  entropy = 0.
  if not config.sampling.semi_ar:
    generative_ppl = model.metrics.gen_ppl.compute().item()
    entropy = model.metrics.sample_entropy.compute().item()
    logger.info(f'Generative perplexity: {generative_ppl}')
    logger.info(f'Sample entropy: {entropy}')
  samples_path = config.eval.generated_samples_path
  with fsspec.open(samples_path, 'w') as f:
    json.dump({'generative_ppl': generative_ppl,
               'entropy': entropy,
               'generated_seqs': all_samples}, f, indent=4)
  logger.info(f'Samples saved at: {samples_path}')

def _eval_ppl(diffusion_model, config, logger, tokenizer):
  logger.info('Starting Perplexity Eval.')

  model = _load_from_checkpoint(
    diffusion_model=diffusion_model,
    config=config,
    tokenizer=tokenizer)
  if config.eval.disable_ema:
    logger.info('Disabling EMA.')
    model.ema = None

  wandb_logger = None
  if config.get('wandb', None) is not None:
    wandb_logger = L.pytorch.loggers.WandbLogger(
      config=omegaconf.OmegaConf.to_object(config),
      **config.wandb)
  callbacks = []
  if 'callbacks' in config:
    for _, callback in config.callbacks.items():
      callbacks.append(hydra.utils.instantiate(callback))
  trainer = hydra.utils.instantiate(
    config.trainer,
    default_root_dir=os.getcwd(),
    callbacks=callbacks,
    strategy=hydra.utils.instantiate(config.strategy),
    logger=wandb_logger)
  _, valid_ds = dataloader.get_dataloaders(
    config, tokenizer, skip_train=True, valid_seed=config.seed)
  trainer.validate(model, valid_ds)


def _train(diffusion_model, config, logger, tokenizer):
  logger.info('Starting Training.')
  wandb_logger = None
  if config.get('wandb', None) is not None:
    wandb_logger = L.pytorch.loggers.WandbLogger(
      config=omegaconf.OmegaConf.to_object(config),
      **config.wandb)

  if (config.checkpointing.resume_from_ckpt
      and config.checkpointing.resume_ckpt_path is not None
      and utils.fsspec_exists(
        config.checkpointing.resume_ckpt_path)):
    ckpt_path = config.checkpointing.resume_ckpt_path
  else:
    ckpt_path = None

  # Lightning callbacks
  callbacks = []
  if 'callbacks' in config:
    for _, callback in config.callbacks.items():
      callbacks.append(hydra.utils.instantiate(callback))

  train_ds, valid_ds = dataloader.get_dataloaders(
    config, tokenizer)
  _print_batch(config, train_ds, valid_ds, tokenizer)

  if config.training.finetune_path != '':
    assert utils.fsspec_exists(config.training.finetune_path)
    model = diffusion_model.load_from_checkpoint(
      config.training.finetune_path,
      tokenizer=tokenizer,
      config=config)
  else:
    model = diffusion_model(config, tokenizer=valid_ds.tokenizer)

  trainer = hydra.utils.instantiate(
    config.trainer,
    default_root_dir=os.getcwd(),
    callbacks=callbacks,
    strategy=hydra.utils.instantiate(config.strategy),
    logger=wandb_logger)
  trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path)


def _eval_fid(diffusion_model, config, logger, tokenizer):
  logger.info('Preparing data and model for FID eval.')
  fabric = Fabric(accelerator=config.trainer.accelerator,
                  devices=config.trainer.devices,
                  num_nodes=config.trainer.num_nodes)
  
  fabric.launch()
  seed = config.seed + fabric.global_rank
  L.seed_everything(seed)
  print(f'(Rank {fabric.global_rank}): seed: {seed}')
  model = _load_from_checkpoint(
    diffusion_model=diffusion_model,
    config=config,
    tokenizer=tokenizer)
  model.to(fabric.device)
  if config.eval.disable_ema:
    logger.info('Disabling EMA.')
    model.ema = None
  model._eval_mode()

  assert config.data.train == 'cifar10', \
                    'FID eval only implemented for CIFAR-10'

  # Like in flow matching papers: FID against train
  loader, _ = dataloader.get_dataloaders(config, 
    tokenizer=tokenizer, skip_valid=True)

  sampler = DistributedSampler(
    loader.dataset,
    num_replicas=fabric.world_size,
    rank=fabric.global_rank,
    shuffle=False)
  
  loader = torch.utils.data.DataLoader(
    loader.dataset,
    batch_size=config.loader.eval_batch_size,
    sampler=sampler,
    num_workers=loader.num_workers if hasattr(loader, 'num_workers') else 0,
    pin_memory=getattr(loader, 'pin_memory', False))
  
  # Check that each GPU generates the same number of images
  assert len(loader) == len(loader.dataset) // loader.batch_size // fabric.world_size, \
     f'{len(loader)=}, {len(loader.dataset)=}, {loader.batch_size=}, {fabric.world_size=}'

  fid_calculator = FrechetInceptionDistance(
    normalize=False).to(fabric.device)
  is_calculator = InceptionScore(
    normalize=False).to(fabric.device)
  
  desc = f'(Rank {fabric.global_rank}) Sampling...'
  for batch in tqdm(loader, desc=desc):
    real_samples = batch['input_ids']
    # Generate images with labels matching the true data
    labels = batch['labels']

    gen_samples = model.generate_samples(
      num_samples=real_samples.shape[0], 
      num_steps=config.sampling.steps,
      labels=labels)
    # Reshape 1D seq -> 2D image
    gen_samples = model.tokenizer.batch_decode(gen_samples)
    real_samples = model.tokenizer.batch_decode(
      real_samples).to(fabric.device)

    fid_calculator.update(gen_samples, real=False)
    fid_calculator.update(real_samples, real=True)
    is_calculator.update(gen_samples)
  
  fabric.barrier()
  logger.info('Done sampling. Computing FID & IS...')
  fid = fid_calculator.compute()
  incep_score = is_calculator.compute()

  if fabric.global_rank == 0:
    logger.info(f'FID: {fid}')
    logger.info(f'IS: {incep_score}')
  fabric.barrier()
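The even-split assertion in `_eval_fid` can be checked before launching a run. The sketch below assumes PyTorch's documented defaults: `DistributedSampler(drop_last=False)` gives each rank `ceil(N / world_size)` indices (padding with repeated samples), and `DataLoader(drop_last=False)` yields `ceil(n / batch_size)` batches. Both helper names are hypothetical.

```python
import math


def loader_len_per_rank(num_items: int, world_size: int, batch_size: int) -> int:
  """Batches each rank sees, mirroring DistributedSampler + DataLoader defaults."""
  # DistributedSampler(drop_last=False): every rank receives ceil(N / W) indices.
  per_rank = math.ceil(num_items / world_size)
  # DataLoader(drop_last=False): the final partial batch is kept.
  return math.ceil(per_rank / batch_size)


def passes_fid_split_check(num_items: int, world_size: int, batch_size: int) -> bool:
  """Same comparison as the assertion in `_eval_fid`."""
  expected = num_items // batch_size // world_size
  return loader_len_per_rank(num_items, world_size, batch_size) == expected
```

For the 50,000 CIFAR-10 training images, 8 ranks with batch size 125 pass, while 3 ranks with batch size 100 fail (167 batches per rank versus the expected 166).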


@hydra.main(version_base=None, config_path='configs',
            config_name='config')
def main(config):
  """Entry point: trains or evaluates depending on `config.mode`."""
  L.seed_everything(config.seed)
  _print_config(config, resolve=True, save_cfg=True)
  
  logger = utils.get_logger(__name__)
  tokenizer = dataloader.get_tokenizer(config)
  if config.algo.name == 'ar':
    diffusion_model = algo.AR
  elif config.algo.name == 'mdlm':
    diffusion_model = algo.MDLM
  elif config.algo.name == 'duo_base':
    diffusion_model = algo.DUO_BASE
  elif config.algo.name == 'd3pm':
    diffusion_model = algo.D3PMAbsorb
  elif config.algo.name == 'sedd':
    diffusion_model = algo.SEDDAbsorb
  elif config.algo.name == 'duo':
    diffusion_model = algo.DUO
  elif config.algo.name == 'distillation':
    diffusion_model = algo.Distillation
  elif config.algo.name == 'ot-finetune':
    diffusion_model = algo.OptimalTransportFinetune
  else:
    raise ValueError(
      f'Invalid algorithm name: {config.algo.name}')
  kwargs = {'diffusion_model': diffusion_model,
            'config': config,
            'tokenizer': tokenizer,
            'logger': logger}
  if config.mode == 'sample_eval':
    _generate_samples(**kwargs)
  elif config.mode == 'ppl_eval':
    _eval_ppl(**kwargs)
  elif config.mode == 'fid_eval':
    _eval_fid(**kwargs)
  else:
    _train(**kwargs)


if __name__ == '__main__':
  main()
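The `if`/`elif` chain in `main` maps `config.algo.name` to a model class and `config.mode` to a routine. A table-driven equivalent is sketched below; it stores class and function names as strings so it runs without the repository, and `resolve` is a hypothetical helper. Inside `main.py` one would look the names up with `getattr(algo, ...)` instead.

```python
ALGO_CLASSES = {
  'ar': 'AR',
  'mdlm': 'MDLM',
  'duo_base': 'DUO_BASE',
  'd3pm': 'D3PMAbsorb',
  'sedd': 'SEDDAbsorb',
  'duo': 'DUO',
  'distillation': 'Distillation',
  'ot-finetune': 'OptimalTransportFinetune',
}

MODE_ROUTINES = {
  'sample_eval': '_generate_samples',
  'ppl_eval': '_eval_ppl',
  'fid_eval': '_eval_fid',
}


def resolve(algo_name: str, mode: str) -> tuple:
  """Return (class name, routine name); unknown modes fall back to training."""
  if algo_name not in ALGO_CLASSES:
    raise ValueError(f'Invalid algorithm name: {algo_name}')
  return ALGO_CLASSES[algo_name], MODE_ROUTINES.get(mode, '_train')
```

A lookup table keeps the list of supported algorithms in one place, which makes adding a new one a one-line change.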

================================================
FILE: metrics.py
================================================
import math
import os
import typing

import torch
import torch.nn.functional as F
import torchmetrics
import transformers

LOG2 = math.log(2)


class NLL(torchmetrics.aggregation.MeanMetric):
  def update(self,
             value:typing.Union[float, torch.Tensor],
             weight:typing.Union[float, torch.Tensor]=1.0) -> None:
    """Update state with data.

    Unlike `MeanMetric.update`, `value` is accumulated as-is
    (it is not multiplied by `weight`), and `weight` only
    accumulates the denominator. The inherited `compute()`
    therefore returns sum(value) / sum(weight), so callers
    pass values that are already summed or masked, e.g. a
    summed NLL together with its token count.

    Args:
      value: Either a float or tensor containing data.
        Additional tensor dimensions will be flattened.
      weight: Either a float or tensor containing the
        normalizer. Its shape must broadcast to the shape of
        `value`. Defaults to `1.0`, which yields a simple
        arithmetic average.
    """
    if not isinstance(value, torch.Tensor):
      value = torch.as_tensor(value, dtype=self.dtype,
                              device=self.device)
    if (weight is not None
        and not isinstance(weight, torch.Tensor)):
      weight = torch.as_tensor(weight,
                               dtype=self.dtype,
                               device=self.device)
    # Broadcast weight to the shape of value.
    weight = torch.broadcast_to(weight, value.shape)
    value, weight = self._cast_and_nan_check_input(value,
                                                   weight)

    if value.numel() == 0:
      return
    self.mean_value += value.sum()
    self.weight += weight.sum()


class BPD(NLL):
  def compute(self) -> torch.Tensor:
    """Computes bits per dimension: the mean NLL (nats) divided by ln 2.

    Returns:
      bpd
    """
    return self.mean_value / self.weight / LOG2


class Perplexity(NLL):
  def compute(self) -> torch.Tensor:
    """Computes the perplexity, exp(sum(value) / sum(weight)).

    Returns:
      Perplexity
    """
    return torch.exp(self.mean_value / self.weight)
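`NLL`, `BPD` and `Perplexity` share one accumulator: `mean_value` sums the raw values and `weight` sums the normalizers, so each metric is a different view of sum(value) / sum(weight). The pure-Python sketch below mirrors that arithmetic without torchmetrics; `RunningNLL` is a hypothetical stand-in, not a class from the repository.

```python
import math

LOG2 = math.log(2)


class RunningNLL:
  """Mirrors NLL.update: values are summed unweighted, weights form the denominator."""

  def __init__(self) -> None:
    self.total = 0.0  # plays the role of `mean_value`
    self.count = 0.0  # plays the role of `weight`

  def update(self, values, weights) -> None:
    self.total += sum(values)
    self.count += sum(weights)

  def nll(self) -> float:
    return self.total / self.count  # nats per token

  def bpd(self) -> float:
    return self.nll() / LOG2  # bits per token

  def ppl(self) -> float:
    return math.exp(self.nll())
```

For example, a summed NLL of ln 4 nats over 2 tokens corresponds to 1 bit per token and a perplexity of 2.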


class Metrics:
  def __init__(self, gen_ppl_eval_model_name_or_path=None,
               eval_ppl_batch_size=None) -> None:
    metrics = torchmetrics.MetricCollection({
        'nll': NLL(), 'bpd': BPD(), 'ppl': Perplexity()})
    metrics.set_dtype(torch.float64)
    self.train_nlls = metrics.clone(prefix='train/')
    self.train_aux = BPD()
    self.valid_nlls = metrics.clone(prefix='val/')
    self.valid_aux = BPD()
    self.gen_ppl = Perplexity()
    self.sample_entropy = torchmetrics.aggregation.MeanMetric()
    self.eval_ppl_batch_size = eval_ppl_batch_size
    self.gen_ppl_eval_model_name_or_path = gen_ppl_eval_model_name_or_path
    self.tokenizer = transformers.AutoTokenizer.\
      from_pretrained(gen_ppl_eval_model_name_or_path)
    if self.tokenizer.pad_token is None:
      self.tokenizer.pad_token = self.tokenizer.eos_token
      self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

  def to(self, *args, **kwargs):
    self.gen_ppl = self.gen_ppl.to(*args, **kwargs)
    self.sample_entropy = self.sample_entropy.to(*args, **kwargs)
    self.train_nlls = self.train_nlls.to(*args, **kwargs)
    self.train_aux = self.train_aux.to(*args, **kwargs)
    self.valid_nlls = self.valid_nlls.to(*args, **kwargs)
    self.valid_aux = self.valid_aux.to(*args, **kwargs)

  def reset(self):
    self.gen_ppl.reset()
    self.sample_entropy.reset()
    self.train_nlls.reset()
    self.train_aux.reset()
    self.valid_nlls.reset()
    self.valid_aux.reset()

  def update_train(self, nll, aux_loss, num_tokens):
    self.train_nlls.update(nll, num_tokens)
    self.train_aux.update(aux_loss, num_tokens)

  def update_valid(self, nll, aux_loss, num_tokens):
    self.valid_nlls.update(nll, num_tokens)
    self.valid_aux.update(aux_loss, num_tokens)


  @torch.no_grad()
  def _eval_retokenize(self, text_samples, max_length,
                       device):
    """Retokenizes samples for the eval model.

    Args:
      text_samples: List of sentences generated by the model.
      max_length: Maximum number of tokens kept per sample.
      device: Device for the outputs; they stay on CPU when the
        eval model is a llama2 checkpoint.
    Returns:
      samples: Samples re-tokenized for the eval model.
      attn_mask: Attention mask for the eval model.
      eval_context_size: Context size of the eval model.
    """
    tokenizer_kwargs = {
      'return_tensors': 'pt',
      'return_token_type_ids': False,
      'return_attention_mask': True,
      'truncation': True,
      'padding': True,
      'max_length': max_length,
    }
    if 'llama2' in self.gen_ppl_eval_model_name_or_path:
      eval_context_size = 4096
    else:
      eval_context_size = 1024
    samples = self.tokenizer(text_samples,
                             **tokenizer_kwargs)
    attn_mask = samples['attention_mask']
    samples = samples['input_ids']
    if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
      attn_mask = attn_mask.to(device)
      samples = samples.to(device)      
    return samples, attn_mask, eval_context_size

  @torch.no_grad()
  def record_entropy(self, tokens):
    for sample in tokens:
      _, counts = torch.unique(
        sample, return_counts=True, sorted=False)
      entropy = torch.special.entr(
        counts.float() / counts.sum()).sum().item()
      self.sample_entropy.update(entropy)
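`record_entropy` scores each sample by the Shannon entropy (in nats) of its empirical token distribution; `torch.special.entr(p)` computes -p log p elementwise. A dependency-free sketch of the same quantity, where `sample_entropy` is a hypothetical name:

```python
import math
from collections import Counter


def sample_entropy(tokens) -> float:
  """Entropy in nats of the unigram distribution of `tokens`."""
  counts = Counter(tokens)
  total = sum(counts.values())
  return -sum((c / total) * math.log(c / total) for c in counts.values())
```

A fully repeated sample has entropy 0, while four tokens split evenly across two ids give ln 2, so unusually low averages can indicate repetitive generations.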

  @torch.no_grad()
  def record_generative_perplexity(
    self,
    text_samples: typing.List[str],
    max_length: int,
    retokenize: bool = True,
    device='cuda') -> None:
    
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    eval_model = transformers.AutoModelForCausalLM.from_pretrained(
      self.gen_ppl_eval_model_name_or_path).eval()
    if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
      eval_model = eval_model.to(device)
    # Re-tokenize using eval model's tokenizer
    if retokenize:
      (samples, attn_mask,
       eval_context_size) = self._eval_retokenize(
         text_samples, max_length=max_length, device=device)
    else:
      samples = text_samples
      attn_mask = torch.ones(samples.shape).to(device)
      eval_context_size = samples.shape[-1]
    batch_size = min(self.eval_ppl_batch_size,
                     samples.shape[0])
    # Round up so a final partial batch is not silently dropped.
    num_batches = math.ceil(samples.shape[0] / batch_size)
    for i in range(num_batches):
      _samples = torch.split(
        samples[i * batch_size: (i + 1) * batch_size],
        eval_context_size,
        dim=-1)
      _attn_mask = torch.split(
        attn_mask[i * batch_size: (i + 1) * batch_size],
        eval_context_size,
        dim=-1)
      for (sample_chunk, attn_mask_chunk) in zip(_samples,
                                                 _attn_mask):
        logits = eval_model(sample_chunk,
                            attention_mask=attn_mask_chunk)
        logits = logits[0].transpose(-1, -2)
        nlls = F.cross_entropy(logits[..., :-1],
                               sample_chunk[..., 1:],
                               reduction='none')
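        # Score every non-EOS target plus the first EOS in the
        # chunk; later EOS tokens (e.g. padding, which reuses EOS
        # when the tokenizer has no pad token) are excluded.
        first_eos = (
          sample_chunk
          == self.tokenizer.eos_token_id).cumsum(-1) == 1
        token_mask = sample_chunk != self.tokenizer.eos_token_id
        valid_tokens = first_eos[..., 1:] + token_mask[..., 1:]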
        self.gen_ppl.update(nlls * valid_tokens, valid_tokens)
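The mask built at the end of `record_generative_perplexity` keeps every non-EOS target plus the first EOS of each chunk, and drops the targets for any later EOS. The list-based sketch below reproduces the `cumsum(-1) == 1` trick on a single sequence; `valid_target_mask` is a hypothetical helper.

```python
from itertools import accumulate


def valid_target_mask(sample, eos_id):
  """Per-target mask matching `first_eos[..., 1:] + token_mask[..., 1:]`."""
  is_eos = [tok == eos_id for tok in sample]
  # cumsum == 1 is True from the first EOS up to (not including) the second.
  first_eos = [c == 1 for c in accumulate(int(e) for e in is_eos)]
  token_mask = [not e for e in is_eos]
  # Adding two bool tensors acts as a logical OR.
  valid = [a or b for a, b in zip(first_eos, token_mask)]
  # Targets start at position 1 because logits[..., :-1] predict sample[..., 1:].
  return valid[1:]
```

With EOS id 0, the sample `[5, 7, 0, 0, 0]` scores the targets `7` and the first `0` only.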


================================================
FILE: models/__init__.py
================================================
from . import dit
from . import ema
from . import unet

================================================
FILE: models/dit.py
================================================
import math
import typing

import einops
import flash_attn
import flash_attn.layers.rotary
import huggingface_hub
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F

# Flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def bias_dropout_add_scale(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
    training: bool) -> torch.Tensor:
  if bias is not None:
    out = scale * F.dropout(x + bias, p=prob, training=training)
  else:
    out = scale * F.dropout(x, p=prob, training=training)

  if residual is not None:
    out = residual + out
  return out


def get_bias_dropout_add_scale(training):
  def _bias_dropout_add(x, bias, scale, residual, prob):
    return bias_dropout_add_scale(
      x, bias, scale, residual, prob, training)

  return _bias_dropout_add


# Reference (non-JIT) adaLN modulation; scripted below as modulate_fused.
def modulate(x: torch.Tensor,
             shift: torch.Tensor,
             scale: torch.Tensor) -> torch.Tensor:
  return x * (1 + scale) + shift


@torch.jit.script
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
  return bias_dropout_add_scale(
    x, bias, scale, residual, prob, True)


@torch.jit.script
def bias_dropout_add_scale_fused_inference(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float) -> torch.Tensor:
  return bias_dropout_add_scale(
    x, bias, scale, residual, prob, False)


@torch.jit.script
def modulate_fused(x: torch.Tensor,
                   shift: torch.Tensor,
                   scale: torch.Tensor) -> torch.Tensor:
  return modulate(x, shift, scale)


class Rotary(torch.nn.Module):
  def __init__(self, dim, base=10_000):
    super().__init__()
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    self.register_buffer('inv_freq', inv_freq)
    self.seq_len_cached = None
    self.cos_cached = None
    self.sin_cached = None

  def forward(self, x, seq_dim=1):
    seq_len = x.shape[seq_dim]
    if seq_len != self.seq_len_cached:
      self.seq_len_cached = seq_len
      t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
      freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
      emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
      # dims are: batch, seq_len, qkv, head, dim
      self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1)
      self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1)
      # This makes the transformation on v an identity.
      self.cos_cached[:,:,2,:,:].fill_(1.)
      self.sin_cached[:,:,2,:,:].fill_(0.)

    return self.cos_cached, self.sin_cached


def rotate_half(x):
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
  return torch.cat((-x2, x1), dim=-1)


def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin):
  with torch.amp.autocast('cuda', enabled=False):
    cos, sin = rotary_cos_sin
    cos = cos.to(qkv.dtype)
    sin = sin.to(qkv.dtype)
    cos = cos[0,:,0,0,:cos.shape[-1]//2]
    sin = sin[0,:,0,0,:sin.shape[-1]//2]
    q, k, v = qkv.chunk(3, dim=2)
    q = flash_attn.layers.rotary.apply_rotary_emb_torch(
      q.squeeze(dim=2), cos, sin)
    k = flash_attn.layers.rotary.apply_rotary_emb_torch(
      k.squeeze(dim=2), cos, sin)
    v = v.squeeze(dim=2)
  return q, k, v


def apply_rotary_pos_emb(qkv, cos, sin):
  cos = cos[0,:,0,0,:cos.shape[-1]//2]
  sin = sin[0,:,0,0,:sin.shape[-1]//2]
  return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)
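`Rotary` caches cos/sin tables with `inv_freq[i] = base ** (-2i / dim)`, and the non-interleaved `apply_rotary_emb_torch` rotates each pair (x[i], x[i + dim/2]) by `position * inv_freq[i]`, i.e. it computes `x * cos + rotate_half(x) * sin`. The pure-Python sketch below (hypothetical helpers `apply_rope` and `dot`) applies one such rotation and checks the property that motivates it: query/key dot products depend only on the relative position.

```python
import math


def apply_rope(x, pos, base=10_000):
  """Rotate-half RoPE for one vector `x` (even length) at integer position `pos`."""
  dim = len(x)
  half = dim // 2
  out = [0.0] * dim
  for i in range(half):
    theta = pos / base ** (2 * i / dim)  # pos * inv_freq[i]
    c, s = math.cos(theta), math.sin(theta)
    a, b = x[i], x[i + half]
    # x * cos + rotate_half(x) * sin, with rotate_half(x) = cat(-x2, x1)
    out[i] = a * c - b * s
    out[i + half] = b * c + a * s
  return out


def dot(u, v):
  return sum(a * b for a, b in zip(u, v))
```

Because each pair is rotated (not scaled), vector norms are preserved, and `Rotary` fills the `v` slot with cos = 1 and sin = 0 so values pass through unrotated.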


def regular_attention_multi_headed(q, k, v):
  # q, k, v: [batch, seq_len, num_heads, head_dim] (already split and
  # rotary-embedded); SDPA expects the head axis before the sequence axis.
  attention_output = F.scaled_dot_product_attention(
    query=q.transpose(1, 2),
    key=k.transpose(1, 2),
    value=v.transpose(1, 2),
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False)
  # [batch_size, seq_len, num_heads, head_dim]
  attention_output = attention_output.transpose(1, 2)
  return einops.rearrange(attention_output, 'b s h d -> b s (h d)')
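With no mask and `is_causal=False`, `F.scaled_dot_product_attention` computes softmax(q k^T / sqrt(head_dim)) v independently for each batch element and head. A single-head, pure-Python sketch of that formula follows; `attend` is a hypothetical name, and dropout is omitted to match `dropout_p=0.0`.

```python
import math


def attend(q, k, v):
  """Single-head attention; q: [Lq][d], k: [Lk][d], v: [Lk][dv]."""
  d = len(q[0])
  out = []
  for qi in q:
    scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
    m = max(scores)  # subtract the max for numerical stability
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    out.append([sum(wj / z * vj[c] for wj, vj in zip(w, v))
                for c in range(len(v[0]))])
  return out
```

With identical keys the weights are uniform and the output is the mean of the values; a key strongly aligned with the query dominates the mixture.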


#################################################################################
#                                  Layers                                       #
#################################################################################
class LayerNorm(nn.Module):
  def __init__(self, dim):
    super().__init__()
    self.weight = nn.Parameter(torch.ones([dim]))
    self.dim = dim
  def forward(self, x):
    with torch.amp.autocast('cuda', enabled=False):
      x = F.layer_norm(x.float(), [self.dim])
    return x * self.weight[None, None, :]


def residual_linear(x, W, x_skip, residual_scale):
  """x_skip + residual_scale * W @ x"""
  dim_out, dim_in = W.shape[0], W.shape[1]
  return torch.addmm(
    x_skip.view(-1, dim_out),
    x.view(-1, dim_in),
    W.T,
    alpha=residual_scale).view(*x.shape[:-1], dim_out)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################
class TimestepEmbedder(nn.Module):
  """
  Embeds scalar timesteps into vector representations.
  """
  def __init__(self, hidden_size, frequency_embedding_size=256):
    super().__init__()
    self.mlp = nn.Sequential(
      nn.Linear(frequency_embedding_size, hidden_size, bias=True),
      nn.SiLU(),
      nn.Linear(hidden_size, hidden_size, bias=True))
    self.frequency_embedding_size = frequency_embedding_size

  @staticmethod
  def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    half = dim // 2
    freqs = torch.exp(
      - math.log(max_period)
      * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
      / half)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
      embedding = torch.cat(
        [embedding,
         torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

  def forward(self, t):
    t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
    t_emb = self.mlp(t_freq)
    return t_emb
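`TimestepEmbedder.timestep_embedding` follows the GLIDE recipe: frequencies `exp(-ln(max_period) * i / half)` for i in [0, half), a `[cos | sin]` layout, and one zero column when `dim` is odd. A pure-Python transcription for a single scalar timestep (`sinusoidal_embedding` is a hypothetical name) makes the layout easy to inspect:

```python
import math


def sinusoidal_embedding(t: float, dim: int, max_period: float = 10000.0):
  """[cos(t * f_i) for i < half] + [sin(t * f_i) for i < half] (+ 0.0 if dim is odd)."""
  half = dim // 2
  freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
  args = [t * f for f in freqs]
  emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
  if dim % 2:
    emb.append(0.0)
  return emb
```

The MLP in `TimestepEmbedder` then maps this fixed feature vector to the conditioning size.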


class LabelEmbedder(nn.Module):
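  """Embeds class labels into vector representations.

  The table has `num_classes + 1` rows, so the extra index
  `num_classes` can act as a null label for classifier-free
  guidance. Label dropout itself is not applied here.
  """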
  def __init__(self, num_classes, cond_size):
    super().__init__()
    self.embedding_table = nn.Embedding(num_classes + 1, cond_size)
    self.num_classes = num_classes

    # TODO: consider initializing with std 0.02, as in the original DiT paper.

  def forward(self, labels):
    embeddings = self.embedding_table(labels)
    return embeddings
    

#################################################################################
#                                 Core Model                                    #
#################################################################################

class DDiTBlockCausal(nn.Module):
  def __init__(self, dim, n_heads, mlp_ratio=4, dropout=0.1):
    super().__init__()
    self.n_heads = n_heads

    self.norm1 = LayerNorm(dim)
    self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
    self.attn_out = nn.Linear(dim, dim, bias=False)
    self.dropout1 = nn.Dropout(dropout)

    self.norm2 = LayerNorm(dim)
    self.mlp = nn.Sequential(
      nn.Linear(dim, mlp_ratio * dim, bias=True),
      nn.GELU(approximate='tanh'),
      nn.Linear(mlp_ratio * dim, dim, bias=True))
    self.dropout2 = nn.Dropout(dropout)
    self.dropout = dropout

  def _get_bias_dropout_scale(self):
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference

  def forward(self, x, rotary_cos_sin, **kwargs):
    del kwargs
    batch_size, seq_len = x.shape[0], x.shape[1]

    bias_dropout_scale_fn = self._get_bias_dropout_scale()

    # attention operation
    x_skip = x
    x = self.norm1(x)

    qkv = self.attn_qkv(x)
    qkv = einops.rearrange(
      qkv,
      'b s (three h d) -> b s three h d',
      three=3,
      h=self.n_heads)
    with torch.amp.autocast('cuda', enabled=False):
      cos, sin = rotary_cos_sin
      qkv = apply_rotary_pos_emb(
        qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)
      )
    qkv = einops.rearrange(qkv, 'b s ... -> (b s) ...')
    cu_seqlens = torch.arange(
      0, (batch_size + 1) * seq_len,
      step=seq_len, dtype=torch.int32, device=qkv.device)
    x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(
      qkv, cu_seqlens, seq_len, 0.0, causal=True)

    x = einops.rearrange(x, '(b s) h d -> b s (h d)',
                         b=batch_size)

    scale = torch.ones(1, device=x.device, dtype=x.dtype)
    x = bias_dropout_scale_fn(
      self.attn_out(x), None, scale, x_skip, self.dropout)

    # mlp operation
    x = bias_dropout_scale_fn(
      self.mlp(self.norm2(x)), None, scale, x, self.dropout)
    return x



class DDiTBlock(nn.Module):
  def __init__(self, dim, n_heads, adaLN,
               cond_dim=None, mlp_ratio=4,
               dropout=0.1):
    super().__init__()
    self.n_heads = n_heads
    self.adaLN = adaLN

    self.norm1 = LayerNorm(dim)
    self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
    self.attn_out = nn.Linear(dim, dim, bias=False)
    self.dropout1 = nn.Dropout(dropout)

    self.norm2 = LayerNorm(dim)
    self.mlp = nn.Sequential(
      nn.Linear(dim, mlp_ratio * dim, bias=True),
      nn.GELU(approximate='tanh'),
      nn.Linear(mlp_ratio * dim, dim, bias=True))
    self.dropout2 = nn.Dropout(dropout)
    self.dropout = dropout

    if self.adaLN:
      self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim)
      self.adaLN_modulation.weight.data.zero_()
      self.adaLN_modulation.bias.data.zero_()


  def _get_bias_dropout_scale(self):
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference


  def forward(self, x, rotary_cos_sin, c=None):

    bias_dropout_scale_fn = self._get_bias_dropout_scale()

    x_skip = x
    x = self.norm1(x)

    if self.adaLN:
      # adaLN_modulation(c): (batch, 6 * dim); [:, None] adds a
      # sequence axis -> (batch, 1, 6 * dim); chunk(6, dim=2) yields
      # six tensors of shape (batch, 1, dim).
      (shift_msa, scale_msa, gate_msa, shift_mlp,
       scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
      x = modulate_fused(x, shift_msa, scale_msa)

    qkv = einops.rearrange(
      self.attn_qkv(x),
      'b s (three h d) -> b s three h d',
      three=3,
      h=self.n_heads)
    q, k, v = split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)
    
    x = regular_attention_multi_headed(q, k, v)

    if self.adaLN:
      x = bias_dropout_scale_fn(self.attn_out(x),
                                None,
                                gate_msa,
                                x_skip,
                                self.dropout)
      x = bias_dropout_scale_fn(
        self.mlp(modulate_fused(
          self.norm2(x), shift_mlp, scale_mlp)),
        None, gate_mlp, x, self.dropout)
    else:
      scale = torch.ones(1, device=x.device, dtype=x.dtype)
      x = bias_dropout_scale_fn(
        self.attn_out(x), None, scale, x_skip, self.dropout)
      x = bias_dropout_scale_fn(
        self.mlp(self.norm2(x)), None, scale, x, self.dropout)
    return x
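Because `adaLN_modulation` is zero-initialized, every shift, scale and gate starts at 0, so `modulate` is the identity and both gated residual updates vanish: each `DDiTBlock` is an identity map at initialization (the adaLN-Zero scheme from DiT). The pure-Python sketch below checks this per token with hypothetical helpers and arbitrary stand-in sublayers; it omits LayerNorm and dropout, which cannot change an update that is multiplied by a zero gate.

```python
def split_adaln(mod, dim):
  """chunk(6): (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp)."""
  assert len(mod) == 6 * dim
  return [mod[j * dim:(j + 1) * dim] for j in range(6)]


def modulate(x, shift, scale):
  return [xi * (1 + sc) + sh for xi, sh, sc in zip(x, shift, scale)]


def gated_residual(residual, update, gate):
  # bias_dropout_add_scale with bias=None and dropout disabled
  return [r + g * u for r, u, g in zip(residual, update, gate)]


def adaln_block(x, mod, sublayer_attn, sublayer_mlp):
  """Per-token skeleton of DDiTBlock.forward (norms omitted, attention stubbed)."""
  (shift_msa, scale_msa, gate_msa,
   shift_mlp, scale_mlp, gate_mlp) = split_adaln(mod, len(x))
  x = gated_residual(x, sublayer_attn(modulate(x, shift_msa, scale_msa)), gate_msa)
  return gated_residual(x, sublayer_mlp(modulate(x, shift_mlp, scale_mlp)), gate_mlp)
```

Starting from the identity lets gradients flow cleanly through a deep stack before the conditioning signal is learned.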


class EmbeddingLayer(nn.Module):
  def __init__(self, dim, vocab_dim):
    super().__init__()
    self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))
    torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))

  def forward(self, x, weights=None):
    if weights is not None:
      bs, seq_len, k = x.shape
      flat_x = x.reshape(-1, k)
      flat_w = weights.reshape(-1, k).float()
      bag = F.embedding_bag(flat_x, self.embedding.float(),
                            per_sample_weights=flat_w,
                            mode='sum')
      return bag.view(bs, seq_len, -1)
    elif x.ndim == 2:
      return self.embedding[x]
    assert x.ndim == 3
    return torch.einsum(
      "blv,ve->ble",
      torch.nn.functional.softmax(x, dim=-1).float(),
      self.embedding.float()).to(x.dtype)
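`EmbeddingLayer.forward` accepts three input kinds: integer ids (a plain row lookup), `k` candidate ids per position together with `weights` (`F.embedding_bag` in `'sum'` mode, i.e. a weighted sum of rows), or logits over the vocabulary (a softmax-weighted mixture of all rows). The weighted-sum and softmax branches are sketched below in pure Python for one position; both function names are hypothetical.

```python
import math


def weighted_bag(table, ids, weights):
  """sum_j weights[j] * table[ids[j]], like embedding_bag(mode='sum')."""
  dim = len(table[0])
  return [sum(w * table[i][c] for i, w in zip(ids, weights)) for c in range(dim)]


def soft_embedding(table, logits):
  """softmax(logits) @ table, the 3-D (logits) branch."""
  m = max(logits)
  exps = [math.exp(x - m) for x in logits]
  z = sum(exps)
  return weighted_bag(table, range(len(table)), [e / z for e in exps])
```

The softmax branch is a dense special case of the bag: every vocabulary row participates, weighted by its probability.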


class DDiTFinalLayer(nn.Module):
  def __init__(self, hidden_size, out_channels, cond_dim,
               adaLN):
    super().__init__()
    self.norm_final = LayerNorm(hidden_size)
    self.linear = nn.Linear(hidden_size, out_channels)
    self.linear.weight.data.zero_()
    self.linear.bias.data.zero_()
    self.adaLN = adaLN
    if self.adaLN:
      self.adaLN_modulation = nn.Linear(cond_dim,
                                        2 * hidden_size,
                                        bias=True)
      self.adaLN_modulation.weight.data.zero_()
      self.adaLN_modulation.bias.data.zero_()

  def forward(self, x, c):
    x = self.norm_final(x)
    if self.adaLN:
      shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
      x = modulate_fused(x, shift, scale)
    x = self.linear(x)
    return x


class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
  def __init__(self, config, vocab_size: int):
    super().__init__()
    if isinstance(config, dict):
      config = omegaconf.OmegaConf.create(config)
    self.causal = config.algo.causal_attention
    self.adaLN = not self.causal
    self.config = config
    self.vocab_size = vocab_size
    dim = config.model.hidden_size
    cond_dim = config.model.cond_dim
    self.vocab_embed = EmbeddingLayer(dim, vocab_size)
    if not self.causal:
      self.sigma_map = TimestepEmbedder(cond_dim)
    self.rotary_emb = Rotary(dim // config.model.n_heads)

    blocks = []
    for _ in range(config.model.n_blocks):
      if self.causal:
        block = DDiTBlockCausal(
          dim=dim,
          n_heads=config.model.n_heads,
          dropout=config.model.dropout)
      else:
        block = DDiTBlock(
          dim=dim,
          n_heads=config.model.n_heads,
          cond_dim=cond_dim,
          adaLN=self.adaLN,
          dropout=config.model.dropout)
      blocks.append(block)
    self.blocks = nn.ModuleList(blocks)

    self.output_layer = DDiTFinalLayer(
      hidden_size=dim,
      out_channels=vocab_size,
      cond_dim=cond_dim,
      adaLN=self.adaLN)
    self.scale_by_sigma = config.model.scale_by_sigma

  def _get_bias_dropout_scale(self):
    if self.training:
      return bias_dropout_add_scale_fused_train
    else:
      return bias_dropout_add_scale_fused_inference

  def forward(self, x, sigma, class_cond=None, weights=None):
    assert class_cond is None, 'Not implemented for DiT'
    x = self.vocab_embed(x, weights)
    if self.causal:
      t_cond = None
    else:
      t_cond = F.silu(self.sigma_map(sigma))

    rotary_cos_sin = self.rotary_emb(x)
    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
      for i in range(len(self.blocks)):
        x = self.blocks[i](x, rotary_cos_sin, c=t_cond)
      x = self.output_layer(x, c=t_cond)

    return x


================================================
FILE: models/ema.py
================================================
import torch


class ExponentialMovingAverage:
  """
  Maintains (exponential) moving average of a set of parameters.
  """

  def __init__(self, parameters, decay, use_num_updates=True):
    """
    Args:
        parameters: Iterable of `torch.nn.Parameter`; usually the result of
            `model.parameters()`.
        decay: The (maximum) exponential decay rate, in [0, 1].
        use_num_updates: If True, warm up the decay as
            min(decay, (1 + n) / (10 + n)), where n is the number
            of updates performed so far.
    """
    if decay < 0.0 or decay > 1.0:
      raise ValueError('Decay must be between 0 and 1')
    self.decay = decay
    self.num_updates = 0 if use_num_updates else None
    self.shadow_params = [p.clone().detach()
                          for p in parameters if p.requires_grad]
    self.collected_params = []

  def move_shadow_params_to_device(self, device):
    self.shadow_params = [i.to(device) for i in self.shadow_params]

  def update(self, parameters):
    """
    Update currently maintained parameters.

    Call this every time the parameters are updated, such as the result of
    the `optimizer.step()` call.

    Args:
        parameters: Iterable of `torch.nn.Parameter`; usually the same set of
            parameters used to initialize this object.
    """
    decay = self.decay
    if self.num_updates is not None:
      self.num_updates += 1
      decay = min(decay, (1 + self.num_updates) /
                  (10 + self.num_updates))
    one_minus_decay = 1.0 - decay
    with torch.no_grad():
      parameters = [p for p in parameters if p.requires_grad]
      for s_param, param in zip(self.shadow_params, parameters):
        s_param.sub_(one_minus_decay * (s_param - param))

  def copy_to(self, parameters):
    """
    Copy current parameters into given collection of parameters.

    Args:
        parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored moving averages.
    """
    parameters = [p for p in parameters if p.requires_grad]
    for s_param, param in zip(self.shadow_params, parameters):
      if param.requires_grad:
        param.data.copy_(s_param.data)

  def store(self, parameters):
    """
    Save the current parameters for restoring later.

    Args:
        parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
    """
    self.collected_params = [param.clone() for param in parameters]

  def restore(self, parameters):
    """
    Restore the parameters stored with the `store` method.
    Useful to validate the model with EMA parameters without affecting the
    original optimization process. Store the parameters before the
    `copy_to` method. After validation (or model saving), use this to
    restore the former parameters.

    Args:
        parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
    """
    for c_param, param in zip(self.collected_params, parameters):
      param.data.copy_(c_param.data)

  def state_dict(self):
    return dict(decay=self.decay,
                num_updates=self.num_updates,
                shadow_params=self.shadow_params)

  def load_state_dict(self, state_dict):
    self.decay = state_dict['decay']
    self.num_updates = state_dict['num_updates']
    self.shadow_params = state_dict['shadow_params']
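With `use_num_updates=True`, `ExponentialMovingAverage.update` uses the warmed-up decay `min(decay, (1 + n) / (10 + n))`, so early updates track the raw weights closely before the average settles toward `decay`. A scalar pure-Python sketch of the same rule (`ema_step` is a hypothetical name):

```python
def ema_step(shadow: float, param: float, decay: float, num_updates: int):
  """One EMA update; returns (new_shadow, new_num_updates)."""
  num_updates += 1
  d = min(decay, (1 + num_updates) / (10 + num_updates))
  # shadow -= (1 - d) * (shadow - param), as in ExponentialMovingAverage.update
  return shadow - (1.0 - d) * (shadow - param), num_updates
```

Without the warm-up, a shadow copy of the randomly initialized weights would need on the order of 1 / (1 - decay) updates (10,000 for `decay=0.9999`) to forget them.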


================================================
FILE: models/unet.py
================================================
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import numpy as np
import omegaconf

import transformers
from einops import rearrange
from .dit import LabelEmbedder, EmbeddingLayer


# From https://github.com/yang-song/score_sde_pytorch/ which is from
#  https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
def transformer_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
  assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
  half_dim = embedding_dim // 2
  # magic number 10000 is from transformers
  emb = math.log(max_positions) / (half_dim - 1)
  # emb = math.log(2.) / (half_dim - 1)
  emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
  # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
  # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
  emb = timesteps.float()[:, None] * emb[None, :]
  emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
  if embedding_dim % 2 == 1:  # zero pad
    emb = F.pad(emb, (0, 1), mode='constant')
  assert emb.shape == (timesteps.shape[0], embedding_dim)
  return emb
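`transformer_timestep_embedding` differs from `TimestepEmbedder.timestep_embedding` in two details: it divides by `half_dim - 1`, so the frequencies run from 1 down to exactly `1 / max_positions`, and it lays out `[sin | cos]` rather than `[cos | sin]`. A pure-Python sketch of the same computation (both function names are hypothetical):

```python
import math


def score_sde_freqs(embedding_dim: int, max_positions: float = 10000.0):
  """exp(-i * ln(max_positions) / (half_dim - 1)) for i in [0, half_dim)."""
  half_dim = embedding_dim // 2
  step = math.log(max_positions) / (half_dim - 1)
  return [math.exp(-i * step) for i in range(half_dim)]


def score_sde_embedding(t: float, embedding_dim: int, max_positions: float = 10000.0):
  args = [t * f for f in score_sde_freqs(embedding_dim, max_positions)]
  emb = [math.sin(a) for a in args] + [math.cos(a) for a in args]
  if embedding_dim % 2 == 1:
    emb.append(0.0)  # zero pad, like F.pad(emb, (0, 1))
  return emb
```

As in the original, `embedding_dim` must be at least 4; otherwise `half_dim - 1` is zero.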


# Code modified from https://github.com/yang-song/score_sde_pytorch
def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
  """Ported from JAX. """

  def _compute_fans(shape, in_axis=1, out_axis=0):
    receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
    fan_in = shape[in_axis] * receptive_field_size
    fan_out = shape[out_axis] * receptive_field_size
    return fan_in, fan_out

  def init(shape, dtype=dtype, device=device):
    fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
    if mode == "fan_in":
      denominator = fan_in
    elif mode == "fan_out":
      denominator = fan_out
    elif mode == "fan_avg":
      denominator = (fan_in + fan_out) / 2
    else:
      raise ValueError(
        "invalid mode for variance scaling initializer: {}".format(mode))
    variance = scale / denominator
    if distribution == "normal":
      return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
    elif distribution == "uniform":
      return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
    else:
      raise ValueError("invalid distribution for variance scaling initializer")

  return init


def default_init(scale=1.):
  """The same initialization used in DDPM."""
  scale = 1e-10 if scale == 0 else scale
  return variance_scaling(scale, 'fan_avg', 'uniform')
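A minimal NumPy sketch of what `default_init(scale)` samples: a uniform distribution on [-a, a] with a = sqrt(3 * scale / fan_avg), because U(-a, a) has variance a^2 / 3. The helper name `fan_avg_uniform` is hypothetical. It also shows why `init_scale=0.` (used for `NIN_3` in `AttnBlock`) is mapped to 1e-10: the weights start almost exactly at zero, so the attention branch initially contributes next to nothing to the residual sum.

```python
import numpy as np


def fan_avg_uniform(shape, scale=1.0, in_axis=1, out_axis=0, seed=0):
  """Hypothetical NumPy mirror of default_init(scale)(shape)."""
  rng = np.random.default_rng(seed)
  receptive_field = np.prod(shape) / shape[in_axis] / shape[out_axis]
  fan_in = shape[in_axis] * receptive_field
  fan_out = shape[out_axis] * receptive_field
  variance = scale / ((fan_in + fan_out) / 2)
  # U(-a, a) has variance a**2 / 3, so a = sqrt(3 * variance).
  bound = np.sqrt(3 * variance)
  return rng.uniform(-bound, bound, size=shape), bound, variance


w, bound, variance = fan_avg_uniform((256, 512), scale=1.0)
print(bound, w.var(), variance)  # empirical variance should be close to `variance`

# default_init maps scale=0 to 1e-10, giving a near-zero initialization.
w0, bound0, _ = fan_avg_uniform((64, 64), scale=1e-10)
```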


class NiN(nn.Module):
  def __init__(self, in_ch, out_ch, init_scale=0.1):
    super().__init__()
    self.W = nn.Parameter(default_init(scale=init_scale)((in_ch, out_ch)), requires_grad=True)
    self.b = nn.Parameter(torch.zeros(out_ch), requires_grad=True)

  def forward(self, x, #  ["batch", "in_ch", "H", "W"]
    ):

    x = x.permute(0, 2, 3, 1)
    # x (batch, H, W, in_ch)
    y = torch.einsum('bhwi,ik->bhwk', x, self.W) + self.b
    # y (batch, H, W, out_ch)
    return y.permute(0, 3, 1, 2)
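`NiN` is a 1x1 convolution: the same (in_ch, out_ch) matrix is applied to the channel vector at every pixel. The NumPy sketch below (toy sizes chosen arbitrarily) checks that the einsum in `NiN.forward` matches an explicit per-pixel matrix multiply. Note that `NiN` stores `W` as (in_ch, out_ch), whereas `nn.Conv2d` stores its kernel as (out_ch, in_ch, 1, 1).

```python
import numpy as np

rng = np.random.default_rng(0)
B, C_in, C_out, H, W = 2, 3, 5, 4, 4
x = rng.standard_normal((B, C_in, H, W))
weight = rng.standard_normal((C_in, C_out))
bias = rng.standard_normal(C_out)

# NiN.forward: move channels last, contract over in_ch, move channels back.
y_nin = (np.einsum('bhwi,ik->bhwk', x.transpose(0, 2, 3, 1), weight) + bias).transpose(0, 3, 1, 2)

# Equivalent 1x1 convolution, written as a matmul at each spatial position.
y_loop = np.empty((B, C_out, H, W))
for i in range(H):
  for j in range(W):
    y_loop[:, :, i, j] = x[:, :, i, j] @ weight + bias
```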

class AttnBlock(nn.Module):
  """Channel-wise self-attention block."""
  def __init__(self, channels, skip_rescale=True):
    super().__init__()
    self.skip_rescale = skip_rescale
    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels//4, 32),
        num_channels=channels, eps=1e-6)
    self.NIN_0 = NiN(channels, channels)
    self.NIN_1 = NiN(channels, channels)
    self.NIN_2 = NiN(channels, channels)
    self.NIN_3 = NiN(channels, channels, init_scale=0.)

  def forward(self, x, # ["batch", "channels", "H", "W"]
    ):

    B, C, H, W = x.shape
    h = self.GroupNorm_0(x)
    q = self.NIN_0(h)
    k = self.NIN_1(h)
    v = self.NIN_2(h)

    w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
    w = torch.reshape(w, (B, H, W, H * W))
    w = F.softmax(w, dim=-1)
    w = torch.reshape(w, (B, H, W, H, W))
    h = torch.einsum('bhwij,bcij->bchw', w, v)
    h = self.NIN_3(h)

    if self.skip_rescale:
        return (x + h) / np.sqrt(2.)
    else:
        return x + h
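The core of `AttnBlock.forward` is single-head self-attention over the H*W spatial positions, with the channel dimension acting as the feature dimension. The hypothetical NumPy helper below mirrors those einsums (leaving out GroupNorm and the NiN projections) and checks two properties: each pixel's attention weights sum to 1, and a spatially constant `v` passes through unchanged.

```python
import numpy as np


def spatial_attention_np(q, k, v):
  """Hypothetical NumPy mirror of the attention core in AttnBlock.forward.

  q, k, v: arrays of shape (B, C, H, W). Every pixel attends to all H*W pixels.
  """
  B, C, H, W = q.shape
  w = np.einsum('bchw,bcij->bhwij', q, k) * C ** -0.5
  w = w.reshape(B, H, W, H * W)
  # Numerically stable softmax over all key positions.
  w = np.exp(w - w.max(axis=-1, keepdims=True))
  w = w / w.sum(axis=-1, keepdims=True)
  w = w.reshape(B, H, W, H, W)
  return np.einsum('bhwij,bcij->bchw', w, v), w


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 8, 4, 4)) for _ in range(3))
out, weights = spatial_attention_np(q, k, v)
```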


class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, temb_dim=None, dropout=0.1, skip_rescale=True):
        super().__init__()

        self.in_ch = in_ch
        self.out_ch = out_ch

        self.skip_rescale = skip_rescale

        self.act = nn.functional.silu
        self.groupnorm0 = nn.GroupNorm(
            num_groups=min(in_ch // 4, 32),
            num_channels=in_ch, eps=1e-6
        )
        self.conv0 = nn.Conv2d(
            in_ch, out_ch, kernel_size=3, padding=1
        )

        if temb_dim is not None:
            self.dense0 = nn.Linear(temb_dim, out_ch)
            nn.init.zeros_(self.dense0.bias)


        self.groupnorm1 = nn.GroupNorm(
            num_groups=min(out_ch // 4, 32),
            num_channels=out_ch, eps=1e-6
        )
        self.dropout0 = nn.Dropout(dropout)

        self.conv1 = nn.Conv2d(
            out_ch, out_ch, kernel_size=3, padding=1
        )
        if out_ch != in_ch:
            self.nin = NiN(in_ch, out_ch)

    def forward(self, x, # ["batch", "in_ch", "H", "W"]
                temb=None, #  ["batch", "temb_dim"]
        ):

        assert x.shape[1] == self.in_ch

        h = self.groupnorm0(x)
        h = self.act(h)
        h = self.conv0(h)

        if temb is not None:
            h += self.dense0(self.act(temb))[:, :, None, None]

        h = self.groupnorm1(h)
        h = self.act(h)
        h = self.dropout0(h)
        h = self.conv1(h)
        if h.shape[1] != self.in_ch:
            x = self.nin(x)

        assert x.shape == h.shape

        if self.skip_rescale:
            return (x + h) / np.sqrt(2.)
        else:
            return x + h
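Why `skip_rescale` divides by sqrt(2): if the skip path `x` and the residual branch `h` are roughly independent with unit variance, then Var(x + h) is about 2, and dividing by sqrt(2) brings it back to about 1, so activations do not grow as blocks are stacked. The independence and unit variance are simplifying assumptions for this numeric illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000)  # stands in for the skip path
h = rng.standard_normal(1_000_000)  # stands in for the residual branch (assumed independent)

plain = x + h                  # skip_rescale=False
rescaled = (x + h) / np.sqrt(2.)  # skip_rescale=True
print(np.var(plain), np.var(rescaled))  # roughly 2 and roughly 1
```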

class Downsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, 
            stride=2, padding=0)

    def forward(self, x, # ["batch", "ch", "inH", "inW"]
        ):
        B, C, H, W = x.shape
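        # Pad one column on the right and one row at the bottom so that the
        # stride-2, padding-0 convolution maps H -> H // 2 and W -> W // 2.
        x = nn.functional.pad(x, (0, 1, 0, 1))
        x = self.conv(x)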

        assert x.shape == (B, C, H // 2, W // 2)
        return x
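The shape assertion in `Downsample.forward` follows from the usual convolution output-size formula. `F.pad(x, (0, 1, 0, 1))` adds one column on the right and one row at the bottom, and a kernel-3, stride-2, padding-0 convolution on a length of H + 1 gives (H + 1 - 3) // 2 + 1 = H // 2 for any H >= 2, odd or even. The helper names below are hypothetical.

```python
def conv_out_size(size, kernel=3, stride=2, padding=0):
  """Standard Conv2d output-size formula along one spatial axis."""
  return (size + 2 * padding - kernel) // stride + 1


def downsample_out_size(size):
  # Downsample pads one extra row/column on the bottom/right only, then
  # applies the stride-2, padding-0 convolution.
  return conv_out_size(size + 1)


for h in (32, 16, 8, 7):
  print(h, '->', downsample_out_size(h))  # 32 -> 16, 16 -> 8, 8 -> 4, 7 -> 3
```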

class Upsample(nn.Module):
  def __init__(self, channels):
    super().__init__()
    self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

  def forward(self, x, # ["batch", "ch", "inH", "inW"]
    ):
    B, C, H, W = x.shape
    h = F.interpolate(x, (H*2, W*2), mode='nearest')
    h = self.conv(h)

    assert h.shape == (B, C, H*2, W*2)
    return h
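With an exact 2x target size, `F.interpolate(..., mode='nearest')` copies each input pixel into a 2x2 output block; the 3x3 convolution that follows then smooths those blocks. A NumPy illustration of the nearest-neighbour step on a tiny 2x2 input:

```python
import numpy as np

x = np.arange(4, dtype=np.float32).reshape(1, 1, 2, 2)
# Nearest-neighbour 2x upsampling: repeat rows, then columns.
up = x.repeat(2, axis=2).repeat(2, axis=3)
print(up[0, 0])
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]
#  [2. 2. 3. 3.]
#  [2. 2. 3. 3.]]
```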


class UNet(nn.Module):
    def __init__(self, config, vocab_size=None):
        super().__init__()
        if isinstance(config, dict):
            config = omegaconf.OmegaConf.create(config)
        assert config.model.name == 'unet'
        self.ch = config.model.ch
        self.num_res_blocks = config.model.num_res_blocks
        self.num_scales = config.model.num_scales
        self.ch_mult = config.model.ch_mult
        assert self.num_scales == len(self.ch_mult)
        self.input_channels = config.model.input_channels
        self.output_channels = 2 * config.model.input_channels
        self.scale_count_to_put_attn = config.model.scale_count_to_put_attn
        self.data_min_max = [0, vocab_size] # config.model.data_min_max # tuple of min and max value of input so it can be rescaled to [-1, 1]
        self.dropout = config.model.dropout
        self.skip_rescale = config.model.skip_rescale
        self.time_conditioning = config.model.time_conditioning # Whether to add in time embeddings
        self.time_scale_factor = config.model.time_scale_factor # scale to make the range of times be 0 to 1000
        self.time_embed_dim = config.model.time_embed_dim
        self.vocab_size = vocab_size

        self.size = config.model.size
        self.length = config.model.length

        # truncated logistic 
        self.fix_logistic = config.model.fix_logistic

        self.act = nn.functional.silu

        if self.time_conditioning:
            self.temb_modules = []
            self.temb_modules.append(nn.Linear(self.time_embed_dim, self.time_embed_dim*4))
            nn.init.zeros_(self.temb_modules[-1].bias)
            self.temb_modules.append(nn.Linear(self.time_embed_dim*4, self.time_embed_dim*4))
            nn.init.zeros_(self.temb_modules[-1].bias)
            self.temb_modules = nn.ModuleList(self.temb_modules)

        self.expanded_time_dim = 4 * self.time_embed_dim if self.time_conditioning else None

        self.input_conv = nn.Conv2d(
            in_channels=self.input_channels, out_channels=self.ch,
            kernel_size=3, padding=1
        )

        h_cs = [self.ch]
        in_ch = self.ch

        # Downsampling
        self.downsampling_modules = []

        for scale_count in range(self.num_scales):
            for res_count in range(self.num_res_blocks):
                out_ch = self.ch * self.ch_mult[scale_count]
                self.downsampling_modules.append(
                    ResBlock(in_ch, out_ch, temb_dim=self.expanded_time_dim,
                        dropout=self.dropout, skip_rescale=self.skip_rescale)
                )
                in_ch = out_ch
                h_cs.append(in_ch)
                if scale_count == self.scale_count_to_put_attn:
                    self.downsampling_modules.append(
                        AttnBlock(in_ch, skip_rescale=self.skip_rescale)
                    )

            if scale_count != self.num_scales - 1:
                self.downsampling_modules.append(Downsample(in_ch))
                h_cs.append(in_ch)

        self.downsampling_modules = nn.ModuleList(self.downsampling_modules)

        # Middle
        self.middle_modules = []

        self.middle_modules.append(
            ResBlock(in_ch, in_ch, temb_dim=self.expanded_time_dim,
                dropout=self.dropout, skip_rescale=self.skip_rescale)
        )
        self.middle_modules.append(
            AttnBlock(in_ch, skip_rescale=self.skip_rescale)
        )
        self.middle_modules.append(
            ResBlock(in_ch, in_ch, temb_dim=self.expanded_time_dim,
                dropout=self.dropout, skip_rescale=self.skip_rescale)
        )
        self.middle_modules = nn.ModuleList(self.middle_modules)

        # Upsampling
        self.upsampling_modules = []

        for scale_count in reversed(range(self.num_scales)):
            for res_count in range(self.num_res_blocks+1):
                out_ch = self.ch * self.ch_mult[scale_count]
                self.upsampling_modules.append(
                    ResBlock(in_ch + h_cs.pop(), 
                        out_ch,
                        temb_dim=self.expanded_time_dim,
                        dropout=self.dropout,
                        skip_rescale=self.skip_rescale
                    )
                )
                in_ch = out_ch

                if scale_count == self.scale_count_to_put_attn:
                    self.upsampling_modules.append(
                        AttnBlock(in_ch, skip_rescale=self.skip_rescale)
                    )
            if scale_count != 0:
                self.upsampling_modules.append(Upsample(in_ch))

        self.upsampling_modules = nn.ModuleList(self.upsampling_modules)

        assert len(h_cs) == 0

        # output
        self.output_modules = []
        
        self.output_modules.append(
            nn.GroupNorm(min(in_ch//4, 32), in_ch, eps=1e-6)
        )

        self.output_modules.append(
           nn.Conv2d(in_ch, self.output_channels, kernel_size=3, padding=1)
        )
        self.output_modules = nn.ModuleList(self.output_modules)

        self.cond_map = LabelEmbedder(
            config.data.num_classes,
            self.time_embed_dim*4)
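The `h_cs` list in `UNet.__init__` records the channel count of every skip activation pushed on the way down, and each upsampling `ResBlock` pops one of them. The pure-Python re-trace below (a hypothetical helper, not part of the repo) lists the (in_ch, out_ch) of each upsampling `ResBlock` and confirms that pushes and pops balance, which is what `assert len(h_cs) == 0` checks. The settings `ch=128, ch_mult=[1, 2, 2, 2], num_res_blocks=2` are illustrative assumptions; `configs/model/unet.yaml` may use different values.

```python
def skip_channel_plan(ch, ch_mult, num_res_blocks):
  """Hypothetical re-trace of the h_cs bookkeeping in UNet.__init__.

  Returns the (in_ch, out_ch) of every upsampling ResBlock and the number of
  skip entries left unused (0 when pushes and pops balance).
  """
  num_scales = len(ch_mult)
  h_cs = [ch]  # output of input_conv
  in_ch = ch
  for scale in range(num_scales):
    for _ in range(num_res_blocks):
      in_ch = ch * ch_mult[scale]
      h_cs.append(in_ch)
    if scale != num_scales - 1:
      h_cs.append(in_ch)  # output of Downsample
  up_blocks = []
  for scale in reversed(range(num_scales)):
    for _ in range(num_res_blocks + 1):
      out_ch = ch * ch_mult[scale]
      up_blocks.append((in_ch + h_cs.pop(), out_ch))
      in_ch = out_ch
  return up_blocks, len(h_cs)


blocks, leftover = skip_channel_plan(ch=128, ch_mult=[1, 2, 2, 2], num_res_blocks=2)
print(len(blocks), leftover)  # 12 0
```

Counting makes the balance explicit: the down path pushes 1 + num_scales * num_res_blocks + (num_scales - 1) entries, and the up path pops num_scales * (num_res_blocks + 1), which is the same number.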

    def _center_data(self, x):
        out = (x - self.data_min_max[0]) / (self.data_min_max[1] - self.data_min_max[0]) # [0, 1]
        return 2 * out - 1 # to put it in [-1, 1]
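`_center_data` maps token ids linearly into [-1, 1] using `data_min_max = [0, vocab_size]`. Because the upper bound is `vocab_size` rather than `vocab_size - 1`, the largest id lands slightly below 1. The sketch assumes `vocab_size=256` (8-bit pixel values), under which id 255 maps to 0.9921875. The helper name is hypothetical.

```python
import numpy as np


def center_data(x, data_min=0, data_max=256):
  """Hypothetical mirror of UNet._center_data with data_min_max = [0, vocab_size]."""
  out = (np.asarray(x, dtype=np.float64) - data_min) / (data_max - data_min)  # [0, 1)
  return 2 * out - 1  # [-1, 1)


centered = center_data([0, 128, 255])
print(centered)
```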

    def _time_embedding(self, timesteps):
        if self.time_conditioning:
            temb = transformer_timestep_embedding(
                timesteps * self.time_scale_factor, self.time_embed_dim
            )
            temb = self.temb_modules[0](temb)
            temb = self.temb_modules[1](self.act(temb))
        else:
            temb = None

        return temb

    def _do_input_conv(self, h):
        h = self.input_conv(h)
        hs = [h]
        return h, hs

    def _do_downsampling(self, h, hs, temb):
        m_idx = 0
        for scale_count in range(self.num_scales):
            for res_count in range(self.num_res_blocks):
                h = self.downsampling_modules[m_idx](h, temb)
                m_idx += 1
                if scale_count == self.scale_count_to_put_attn:
                    h = self.downsampling_modules[m_idx](h)
                    m_idx += 1
                hs.append(h)

            if scale_count != self.num_scales - 1:
                h = self.downsampling_modules[m_idx](h)
                hs.append(h)
                m_idx += 1

        assert m_idx == len(self.downsampling_modules)

        return h, hs

    def _do_middle(self, h, temb):
        m_idx = 0
        h = self.middle_modules[m_idx](h, temb)
        m_idx += 1
        h = self.middle_modules[m_idx](h)
        m_idx += 1
        h = self.middle_modules[m_idx](h, temb)
        m_idx += 1

        assert m_idx == len(self.middle_modules)

        return h

    def _do_upsampling(self, h, hs, temb):
        m_idx = 0
        for scale_count in reversed(range(self.num_scales)):
            for res_count in range(self.num_res_blocks+1):
                h = self.upsampling_modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
                m_idx += 1

                if scale_count == self.scale_count_to_put_attn:
                    h = self.upsampling_modules[m_idx](h)
                    m_idx += 1

            if scale_count != 0:
                h = self.upsampling_modules[m_idx](h)
                m_idx += 1

        assert len(hs) == 0
        assert m_idx == len(self.upsampling_modules)

        return h

    def _do_output(self, h):

        h = self.output_modules[0](h)
        h = self.act(h)
        h = self.output_modules[1](h)

        return h

    def _logistic_output_res(self,
        h, #  ["B", "twoC", "H", "W"]
        centered_x_in, # ["B", "C", "H", "W"]
    ):
        B, twoC, H, W = h.shape
        C = twoC//2
        h[:, 0:C, :, :] = torch.tanh(centered_x_in + h[:, 0:C, :, :])
        return h

    def _log_minus_exp(self, a, b, eps=1e-6):
        """ 
            Compute log (exp(a) - exp(b)) for (b<a)
            From https://arxiv.org/pdf/2107.03006.pdf
        """
        return a + torch.log1p(-torch.exp(b-a) + eps)
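Why `_log_minus_exp` factors out `a`: evaluating log(exp(a) - exp(b)) directly underflows to log(0) = -inf once `a` and `b` are very negative, whereas a + log1p(-exp(b - a)) only ever exponentiates the difference. The `eps` keeps the `log1p` argument above -1 when `b` equals `a`. A NumPy comparison:

```python
import numpy as np


def log_minus_exp(a, b, eps=1e-6):
  """NumPy mirror of UNet._log_minus_exp: log(exp(a) - exp(b)) for b < a."""
  return a + np.log1p(-np.exp(b - a) + eps)


a = np.array([-800.0, -2.0])
b = np.array([-801.0, -3.0])
with np.errstate(divide='ignore'):
  # exp(-800) underflows to 0, so the first entry becomes log(0) = -inf.
  naive = np.log(np.exp(a) - np.exp(b))
stable = log_minus_exp(a, b)
```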

    
    def _truncated_logistic_output(self, net_out):
        B, D = net_out.shape[0], self.length
        C = 3
        S = self.vocab_size

        # Truncated logistic output from https://arxiv.org/pdf/2107.03006.pdf
        mu = net_out[:, 0:C, :, :].unsqueeze(-1)
        log_scale = net_out[:, C:, :, :].unsqueeze(-1)

        inv_scale = torch.exp(- (log_scale - 2))

        bin_width = 2. / S
        bin_centers = torch.linspace(start=-1. + bin_width/2,
            end=1. - bin_width/2,
            steps=S,
            device=net_out.device).view(1, 1, 1, 1, S)

        sig_in_left = (bin_centers - bin_width/2 - mu) * inv_scale
        bin_left_logcdf = F.logsigmoid(sig_in_left)
        sig_in_right = (bin_centers + bin_width/2 - mu) * inv_scale
        bin_right_logcdf = F.logsigmoid(sig_in_right)

        logits_1 = self._log_minus_exp(bin_right_logcdf, bin_left_logcdf)
        logits_2 = self._log_minus_exp(-sig_in_left + bin_left_logcdf, -sig_in_right + bin_right_logcdf)
        if self.fix_logistic:
            logits = torch.min(logits_1, logits_2)
        else:
            logits = logits_1

        logits = logits.view(B,D,S)

        return logits
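A NumPy sketch of the per-bin log-probabilities built in `_truncated_logistic_output`, for one scalar location `mu`. It takes `inv_scale` directly (the module derives it as `exp(-(log_scale - 2))`) and omits the 1e-6 `eps` used by `_log_minus_exp`. `logits_1` corresponds to log(sigmoid(right) - sigmoid(left)). `logits_2` expresses the same mass through the upper tails, using 1 - sigmoid(x) = exp(-x + log_sigmoid(x)), and with `fix_logistic` the module keeps the elementwise minimum of the two. Since the bins only cover [-1, 1], the probabilities sum to slightly less than 1 (hence "truncated"). `mu=0.1`, `inv_scale=20` and 256 bins are illustrative values.

```python
import numpy as np


def log_sigmoid(x):
  # Stable log(sigmoid(x)) = -log(1 + exp(-x)).
  return -np.logaddexp(0.0, -x)


def discretized_logistic_logprobs(mu, inv_scale, num_bins):
  """Hypothetical NumPy sketch of logits_1 / logits_2 for a scalar mu."""
  bin_width = 2.0 / num_bins
  centers = np.linspace(-1 + bin_width / 2, 1 - bin_width / 2, num_bins)
  left = (centers - bin_width / 2 - mu) * inv_scale
  right = (centers + bin_width / 2 - mu) * inv_scale
  # logits_1: log(sigmoid(right) - sigmoid(left)) via the CDFs.
  cdf_l, cdf_r = log_sigmoid(left), log_sigmoid(right)
  from_cdf = cdf_r + np.log1p(-np.exp(cdf_l - cdf_r))
  # logits_2: the same mass as (1 - sigmoid(left)) - (1 - sigmoid(right)).
  tail_l, tail_r = -left + cdf_l, -right + cdf_r
  from_tail = tail_l + np.log1p(-np.exp(tail_r - tail_l))
  return from_cdf, from_tail


from_cdf, from_tail = discretized_logistic_logprobs(mu=0.1, inv_scale=20.0, num_bins=256)
logits = np.minimum(from_cdf, from_tail)  # what fix_logistic=True keeps
probs = np.exp(from_cdf)
print(probs.sum(), np.argmax(from_cdf))  # total mass just below 1; peak at the bin holding mu
```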


    def forward(self,
        x, # ["B", "C", "H", "W"]
        sigma=None, 
        class_cond=None,
        weights=None
    ):
        assert weights is None
        img_size = int(np.sqrt(self.size))

        h = rearrange(x, "b (c h w) -> b c h w", h=img_size, 
                      w=img_size, c=3)
        h = self._center_data(h)
        centered_x_in = h

        temb = self._time_embedding(sigma)
        if class_cond is not None:
            if self.cond_map is None:
                raise ValueError("Conditioning variable provided, "
                                "but Model was not initialized "
                                "with condition embedding layer.")
            else:
                assert class_cond.shape == (x.shape[0],)
                temb = temb + self.cond_map(class_cond)

        h, hs = self._do_input_conv(h)

        h, hs = self._do_downsampling(h, hs, temb)

        h = self._do_middle(h, temb)

        h = self._do_upsampling(h, hs, temb)

        h = self._do_output(h)

        # h (B, 2*C, H, W)
        h = self._logistic_output_res(h, centered_x_in)
        h = self._truncated_logistic_output(h) # (B, D, S)

        return h
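`forward` receives each image as a flat sequence of token ids and restores it with `rearrange(x, "b (c h w) -> b c h w", ...)`. In that pattern `c` is the outermost factor and `w` the innermost, so it is the same as a row-major (C-order) reshape. The sketch assumes CIFAR-10-sized inputs with `self.size = H * W = 1024` and `self.length = C * H * W = 3072`; this is inferred from `img_size = int(np.sqrt(self.size))` and `logits.view(B, D, S)` with `D = self.length`, not read from the config.

```python
import numpy as np

# Hypothetical CIFAR-10-sized batch: 2 images, 3 channels, 32x32, 8-bit values.
B, C, H, W = 2, 3, 32, 32
images = np.random.default_rng(0).integers(0, 256, size=(B, C, H, W))
flat = images.reshape(B, C * H * W)  # channel-major flattening, length 3072

# "b (c h w) -> b c h w" splits the flat axis with c outermost and w innermost,
# which is exactly a C-order reshape.
restored = flat.reshape(B, C, H, W)
img_size = int(np.sqrt(H * W))  # forward() derives the side length from self.size
```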

================================================
FILE: models/unit_test_attention.py
================================================
import unittest

import torch

# from flash_attn import flash_attention
import torch.nn.functional as F


def attention_inner_heads_flash(qkv, num_heads):
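    """Computes multi-head attention with the heads packed into the channel dimension of qkv.

    Despite the name, this calls torch.nn.functional.scaled_dot_product_attention,
    which may dispatch to a FlashAttention kernel on supported hardware.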

    Args:
        qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where:
            H = number of heads,
            C = number of channels per head.
        num_heads: number of heads.

    Returns:
        Attention output of shape (B, H*C, T).
    """

    # bs, width, length = qkv.shape
    # ch = width // (3 * num_heads)

    # # Split into (q, k, v) of shape (B, H*C, T).
    # q, k, v = qkv.chunk(3, dim=1)

    # # Rescale q and k. This makes them contiguous in memory.
    # scale = ch ** (-1 / 4)  # scale with 4th root = scaling output by sqrt
    # q = q * scale
    # k = k * scale

    # # Reshape q, k, v to (B, H, T, C) for FlashAttention
    # q = q.view(bs, num_heads, ch, length).permute(0, 1, 3, 2)  # (B, H, T, C)
    # k = k.view(bs, num_heads, ch, length).permute(0, 1, 3, 2)  # (B, H, T, C)
    # v = v.view(bs, num_heads, ch, length).permute(0, 1, 3, 2)  # (B, H, T, C)

    # # Compute attention using FlashAttention
    # out = flash_attention(q, k, v)  # (B, H, T, C)

    # # Reshape back to (B, H*C, T)
    # out = out.permute(0, 1, 3, 2).reshape(bs, num_heads * ch, length)
    # return out
    bs, width, length = qkv.shape
    ch = width // (3 * num_heads)

    # Split into (q, k, v) and reshape directly to (B, H, T, C)
    q, k, v = qkv.chunk(3, dim=1)
    q = q.view(bs, num_heads, ch, length).transpose(2, 3)  # (B, H, T, C)
    k = k.view(bs, num_heads, ch, length).transpose(2, 3)  # (B, H, T, C)
    v = v.view(bs, num_heads, ch, length).transpose(2, 3)  # (B, H, T, C)

    # Compute scaled dot-product attention
    out = F.scaled_dot_product_attention(q, k, v)  # (B, H, T, C)

    # Reshape back to (B, H*C, T) in one step
    out = out.transpose(2, 3).reshape(bs, num_heads * ch, length)
    return out



class TestAttentionInnerHeadsFlash(unittest.TestCase):
    def setUp(self):
        # Set up common test variables
        self.batch_size = 2
        self.num_heads = 4
        self.seq_len = 8
        self.channels_per_head = 16
        self.qkv_dim = 3 * self.num_heads * self.channels_per_head

        # Create a random qkv tensor
        self.qkv = torch.randn(self.batch_size, self.qkv_dim, self.seq_len, dtype=torch.float32)

    def attention_inner_heads_old(self, qkv, num_heads):
        """Original implementation of attention for reference."""
        bs, width, length = qkv.shape
        ch = width // (3 * num_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = ch ** (-1 / 4)
        q = q * scale
        k = k * scale
        q = q.view(bs * num_heads, ch, length)
        k = k.view(bs * num_heads, ch, length)
        v = v.reshape(bs * num_heads, ch, length)
        weight = torch.einsum("bct,bcs->bts", q, k)
        weight = torch.softmax(weight.float(), dim=-1).to(weight.dtype)
        out = torch.einsum("bts,bcs->bct", weight, v)
        return out.reshape(bs, num_heads * ch, length)

    def test_attention_inner_heads_flash(self):
        # Compute the output using the old and flash attention implementations
        output_old = self.attention_inner_heads_old(self.qkv, self.num_heads)
        output_flash = attention_inner_heads_flash(self.qkv, self.num_heads)

        # Verify the shapes are the same
        self.assertEqual(output_old.shape, output_flash.shape)

        # Verify that the outputs are close (numerically similar)
        self.assertTrue(torch.allclose(output_old, output_flash, atol=1e-5))

if __name__ == "__main__":
    unittest.main()
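Why the test can compare the two implementations without passing a `scale` to `F.scaled_dot_product_attention`: the reference multiplies both q and k by ch ** (-1/4), and the product of the two factors is 1 / sqrt(ch), which is SDPA's default scale (1 / sqrt of the query's last dimension). A NumPy check of that identity on the attention scores, with arbitrary toy sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
ch, T = 16, 8
q = rng.standard_normal((T, ch))
k = rng.standard_normal((T, ch))

# Reference implementation: scale q and k each by ch ** (-1/4) before the dot product.
s = ch ** (-1 / 4)
scores_split = (q * s) @ (k * s).T
# scaled_dot_product_attention default: scale the dot product by 1 / sqrt(ch).
scores_sdpa = (q @ k.T) / np.sqrt(ch)
```

Scaling q and k separately, as the reference does, keeps the intermediate values smaller, which can matter in low precision; in float32 both forms agree to within rounding.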

================================================
FILE: requirements.txt
================================================
# conda install nvidia/label/cuda-12.4.0::cuda-toolkit
datasets==2.15.0
einops==0.7.0
fsspec
git-lfs==1.6
h5py==3.10.0
hydra-core==1.3.2
ipdb==0.13.13
lightning==2.2.1
notebook==7.1.1
nvitop==1.3.2
omegaconf==2.3.0
packaging==23.2
pandas==2.2.1
rich==13.7.1
seaborn==0.13.2
scikit-learn==1.4.0
transformers==4.38.2
triton==2.2.0
torch==2.3.1
torchaudio==2.3.1
torchmetrics==1.6.1
torchvision==0.18.1
wandb
timm
ocifs
hf_transfer
huggingface-hub
# Install flash attention only after installing the above modules via pip install -r requirements.txt
# flash_attn==2.7.4.post1


================================================
FILE: scripts/distil_owt.sh
================================================
#!/bin/bash
#SBATCH -J posterior                 # Job name
#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the user's login environment
#SBATCH --mem=64000                  # server memory requested (per node)
#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov          # Request partition
#SBATCH --constraint="[a5000|a6000|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1                 # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon preemption

export HYDRA_FULL_ERROR=1
finetune_path=/path/to/duo.ckpt

srun python -u -m main \
  mode=train \
  loader.batch_size=2 \
  loader.eval_batch_size=2 \
  data=openwebtext-split \
  model=small \
  algo=distillation \
  training.finetune_path=$finetune_path \
  sampling.num_sample_batches=10 \
  sampling.steps=32 \
  eval.compute_generative_perplexity=True \
  algo.T=512 \
  lr_scheduler.num_warmup_steps=500 \
  trainer.val_check_interval=1000 \
  trainer.max_steps=50000 \
  loader.global_batch_size=128 \
  training.ema=0.999 \
  algo.update_teacher_every=10000 \
  optim.lr=6e-5 \
  trainer.limit_val_batches=8 \
  algo.teacher_ema=False \
  algo.linear_growth_dt=false \
  +wandb.offline=true


================================================
FILE: scripts/eval_lm1b_duo.sh
================================================
#!/bin/bash
#SBATCH -J eval_duo                   # Job name
#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the user's login environment
#SBATCH --mem=100000                  # server memory requested (per node)
#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)
#SBATCH --partition=gpu               # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon preemption

checkpoint_path=/share/kuleshov/ssahoo/flow-ode/flow-ode-W2ZcFy-small-conjugate-lm1b-wrap-1M-gmin-350-gmax-175-3/checkpoints/7-100000.ckpt

export HYDRA_FULL_ERROR=1

srun python -u -m main \
  mode=ppl_eval \
  loader.batch_size=64 \
  loader.eval_batch_size=64 \
  data=lm1b-wrap \
  model=small \
  model.length=128 \
  algo=duo_base \
  eval.checkpoint_path=$checkpoint_path \
  sampling.num_sample_batches=0 \
  +wandb.offline=true

================================================
FILE: scripts/eval_owt_ar.sh
================================================
#!/bin/bash
#SBATCH -J eval_ar                # Job name
#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the user's login environment
#SBATCH --mem=100000                  # server memory requested (per node)
#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov          # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon preemption

checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-AgBZrc-small-ar-param-ar_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt

export HYDRA_FULL_ERROR=1

srun python -u -m main \
  mode=ppl_eval \
  loader.batch_size=16 \
  loader.eval_batch_size=16 \
  data=openwebtext-split \
  algo=ar \
  model.length=1024 \
  eval.checkpoint_path=$checkpoint_path \
  +wandb.offline=true

================================================
FILE: scripts/eval_owt_duo.sh
================================================
#!/bin/bash
#SBATCH -J owt_duo_anneal                    # Job name
#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the user's login environment
#SBATCH --mem=100000                  # server memory requested (per node)
#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov          # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon preemption

checkpoint_path=/share/kuleshov/ssahoo/flow-ode/flow-ode-VlCQLK-small-conjugate-OWT-anneal/checkpoints/last.ckpt

export HYDRA_FULL_ERROR=1

srun python -u -m main \
  mode=ppl_eval \
  loader.batch_size=8 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  model=small \
  algo=duo_base \
  eval.checkpoint_path=$checkpoint_path \
  sampling.num_sample_batches=0 \
  +wandb.offline=true

================================================
FILE: scripts/eval_owt_mdlm.sh
================================================
#!/bin/bash
#SBATCH -J eval_mdlm                  # Job name
#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the users login environment
#SBATCH --mem=100000                  # server memory requested (per node)
#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)
#SBATCH --partition=gpu               # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon preemption

checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/mdlm.ckpt

export HYDRA_FULL_ERROR=1

srun python -u -m main \
  mode=ppl_eval \
  loader.batch_size=16 \
  loader.eval_batch_size=16 \
  data=openwebtext-split \
  model=small \
  algo=mdlm \
  eval.checkpoint_path=$checkpoint_path \
  sampling.num_sample_batches=0 \
  +wandb.offline=true

================================================
FILE: scripts/eval_owt_sedd.sh
================================================
#!/bin/bash
#SBATCH -J eval_sedd              # Job name
#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the user's login environment
#SBATCH --mem=100000                  # server memory requested (per node)
#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov          # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon preemption

checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-nBm2gE-small-param-sedd_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt

export HYDRA_FULL_ERROR=1

srun python -u -m main \
  mode=ppl_eval \
  loader.batch_size=16 \
  loader.eval_batch_size=16 \
  data=openwebtext-split \
  model=small \
  algo=sedd \
  eval.checkpoint_path=$checkpoint_path \
  sampling.num_sample_batches=0 \
  +wandb.offline=true

================================================
FILE: scripts/fid_cifar10_duo_ancestral_cosine.sh
================================================
export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1
python -u -m main \
    mode=fid_eval \
    sampling.steps=64 \
    sampling.guid_weight=1.0 \
    data=cifar10 \
    data.cache_dir=<YOUR-CACHE-PATH> \
    model=unet \
    noise=cosine \
    algo=duo_base \
    algo.backbone=unet \
    trainer.num_nodes=1 \
    loader.eval_batch_size=500 \
    eval.checkpoint_path=<PATH-TO-THE-DUO-CHECKPOINT>


================================================
FILE: scripts/fid_cifar10_duo_base_ancestral_cosine.sh
================================================
export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1
python -u -m main \
    mode=fid_eval \
    sampling.steps=64 \
    sampling.guid_weight=1.0 \
    data=cifar10 \
    data.cache_dir=<YOUR-CACHE-PATH> \
    model=unet \
    noise=cosine \
    algo=duo_base \
    algo.backbone=unet \
    trainer.num_nodes=1 \
    loader.eval_batch_size=500 \
    eval.checkpoint_path=<PATH-TO-THE-DUO-CHECKPOINT>


================================================
FILE: scripts/fid_cifar10_mdlm_ancestral_cosine.sh
================================================
python -u -m main \
    mode=fid_eval \
    sampling.steps=64 \
    sampling.guid_weight=1.0 \
    sampling.predictor=ancestral_cache \
    data=cifar10 \
    data.cache_dir=<YOUR-CACHE-PATH> \
    model=unet \
    noise=cosine \
    algo=mdlm \
    algo.backbone=unet \
    loader.eval_batch_size=500 \
    eval.checkpoint_path=<PATH-TO-THE-MDLM-CHECKPOINT>


================================================
FILE: scripts/gen_ppl_lm1b_ar.sh
================================================
#!/bin/bash
#SBATCH -J sample_ar                  # Job name
#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the user's login environment
#SBATCH --mem=32000                   # server memory requested (per node)
#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)
#SBATCH --partition=gpu               # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon preemption

checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/lm1b-ar/last_copy.ckpt

export HYDRA_FULL_ERROR=1

srun python -u -m main \
  mode=sample_eval \
  loader.batch_size=2 \
  loader.eval_batch_size=64 \
  data=lm1b-wrap \
  algo=ar \
  model=small \
  model.length=128 \
  eval.checkpoint_path=$checkpoint_path \
  sampling.num_sample_batches=15 \
  +wandb.offline=true


================================================
FILE: scripts/gen_ppl_lm1b_duo.sh
================================================
#!/bin/bash
#SBATCH -J sample_duo                 # Job name
#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the user's login environment
#SBATCH --mem=32000                   # server memory requested (per node)
#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)
#SBATCH --partition=gpu               # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon preemption

checkpoint_path=/share/kuleshov/ssahoo/flow-ode/6eTwW0-distil-kl-bwd-32/checkpoints/0-1000.ckpt
steps=32

export HYDRA_FULL_ERROR=1

srun python -u -m main \
  mode=sample_eval \
  loader.batch_size=2 \
  loader.eval_batch_size=64 \
  data=lm1b-wrap \
  algo=duo_base \
  model=small \
  model.length=128 \
  eval.checkpoint_path=/share/kuleshov/ssahoo/flow-ode/6eTwW0-distil7-kl-bwd/checkpoints/last.ckpt \
  sampling.num_sample_batches=15 \
  sampling.steps=$steps \
  +wandb.offline=true \
  sampling.noise_removal=greedy


================================================
FILE: scripts/gen_ppl_owt_ar.sh
================================================
#!/bin/bash
#SBATCH -J sample_ar                  # Job name
#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the users login environment
#SBATCH --mem=16000                   # server memory requested (per node)
#SBATCH -t 24:00:00                  # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov,gpu          # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon preemption

checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-AgBZrc-small-ar-param-ar_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints
seed=1

export HYDRA_FULL_ERROR=1

srun python -u -m main \
  mode=sample_eval \
  seed=$seed \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=ar \
  model=small \
  eval.checkpoint_path=$checkpoint_path/last.ckpt \
  sampling.num_sample_batches=100 \
  +wandb.offline=true \
  eval.generated_samples_path=$checkpoint_path/$seed-ckpt-last.json

================================================
FILE: scripts/gen_ppl_owt_duo.sh
================================================
#!/bin/bash
#SBATCH -J an_owt_duo                    # Job name
#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the users login environment
#SBATCH --mem=16000                   # server memory requested (per node)
#SBATCH -t 24:00:00                   # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov,gpu      # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon preemption

export HYDRA_FULL_ERROR=1

while [[ "$#" -gt 0 ]]; do
    case $1 in
        --steps) steps="$2"; shift ;;
        --seed) seed="$2"; shift ;;
        --ckpt) ckpt="$2"; shift ;;
        *) echo "Unknown parameter: $1"; exit 1 ;;
    esac
    shift
done

checkpoint_path=/share/kuleshov/ssahoo/flow-ode/distil-distil-vjrpZb-distillation-OWT/checkpoints
ckpt=${ckpt:-0-50000}

steps=${steps:-32}
seed=${seed:-1}
echo "  Steps: $steps"
echo "  Seed: $seed"
echo "  ckpt: $ckpt"

srun python -u -m main \
  mode=sample_eval \
  seed=$seed \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=duo_base \
  model=small \
  eval.checkpoint_path=$checkpoint_path/$ckpt.ckpt \
  sampling.num_sample_batches=100 \
  sampling.steps=$steps \
  +wandb.offline=true \
  eval.generated_samples_path=$checkpoint_path/samples_ancestral/$seed-$steps-ckpt-$ckpt.json
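The argument loop in gen_ppl_owt_duo.sh accepts `--steps`, `--seed` and `--ckpt`, so a sweep over sampling steps and seeds can be driven from a small wrapper. Below is a minimal sketch, assuming the script is submitted from the repository root as `scripts/gen_ppl_owt_duo.sh` and that `sbatch` forwards trailing arguments to the batch script. The function name `print_sweep` and the seed set `1 2 3` are hypothetical. The wrapper only prints the commands (a dry run) so they can be inspected first. Also check that `ckpt` is not reassigned after the argument loop, otherwise `--ckpt` has no effect.

```shell
#!/bin/bash
# Hypothetical dry-run driver for scripts/gen_ppl_owt_duo.sh: prints one
# sbatch command per (steps, seed) pair instead of submitting anything.
# Usage: print_sweep CKPT STEPS...
print_sweep() {
  local ckpt="$1"; shift
  local steps seed
  for steps in "$@"; do
    for seed in 1 2 3; do   # hypothetical seed set
      echo "sbatch scripts/gen_ppl_owt_duo.sh --steps $steps --seed $seed --ckpt $ckpt"
    done
  done
}

print_sweep 0-50000 8 16 32
# To actually submit: print_sweep 0-50000 8 16 32 | bash
```

Printing rather than submitting keeps the sweep auditable and makes it easy to filter out combinations that already have outputs.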

================================================
FILE: scripts/gen_ppl_owt_mdlm.sh
================================================
#!/bin/bash
#SBATCH -J sample_mdlm                # Job name
#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the user's login environment
#SBATCH --mem=64000                   # server memory requested (per node)
#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov          # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon preemption

checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/mdlm.ckpt

export HYDRA_FULL_ERROR=1

srun python -u -m main \
  mode=sample_eval \
  loader.batch_size=16 \
  loader.eval_batch_size=16 \
  data=openwebtext-split \
  model=small \
  algo=mdlm \
  eval.checkpoint_path=$checkpoint_path \
  sampling.num_sample_batches=4 \
  +wandb.offline=true

================================================
FILE: scripts/gen_ppl_owt_sedd.sh
================================================
#!/bin/bash
#SBATCH -J sedd_samples               # Job name
#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the user's login environment
#SBATCH --mem=16000                   # server memory requested (per node)
#SBATCH -t 24:00:00                   # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov,gpu      # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon preemption

while [[ "$#" -gt 0 ]]; do
    case $1 in
        --steps) steps="$2"; shift ;;
        --seed) seed="$2"; shift ;;
        *) echo "Unknown parameter: $1"; exit 1 ;;
    esac
    shift
done

steps=${steps:-32}
seed=${seed:-1}

echo "  Steps: $steps"
echo "  Seed: $seed"

export HYDRA_FULL_ERROR=1
checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-nBm2gE-small-param-sedd_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints
ckpt=last

srun python -u -m main \
  mode=sample_eval \
  seed=$seed \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=sedd \
  model=small \
  eval.checkpoint_path=$checkpoint_path/$ckpt.ckpt \
  sampling.num_sample_batches=100 \
  sampling.steps=$steps \
  sampling.predictor=analytic \
  eval.generated_samples_path=$checkpoint_path/$seed-$steps-ckpt-$ckpt.json \
  +wandb.offline=true
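Each SEDD run writes its samples to `$checkpoint_path/$seed-$steps-ckpt-$ckpt.json`. After a preemption and requeue, it can help to list which (steps, seed) combinations still have no sample file. A minimal sketch follows; the function name `missing_runs` and its argument order are hypothetical, and it only checks that the file exists, not that its contents are complete.

```shell
#!/bin/bash
# Hypothetical helper: prints "steps seed" for every combination whose
# sample file ($dir/$seed-$steps-ckpt-$ckpt.json) does not exist yet.
# Usage: missing_runs DIR CKPT "STEPS_LIST" "SEEDS_LIST"
missing_runs() {
  local dir="$1" ckpt="$2" steps_list="$3" seeds_list="$4"
  local steps seed
  # The lists are left unquoted on purpose so they split on whitespace.
  for steps in $steps_list; do
    for seed in $seeds_list; do
      if [[ ! -f "$dir/$seed-$steps-ckpt-$ckpt.json" ]]; then
        echo "$steps $seed"
      fi
    done
  done
}
```

For example, `missing_runs "$checkpoint_path" last "16 32 64" "1 2"` prints only the runs that still need to be submitted with `--steps` and `--seed`.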

================================================
FILE: scripts/psi_samplers/cifar10/duo_constant_remdm.sh
================================================
# DUO psi-sampler with constant-remdm-eta mode (ReMDM loop)

NUM_STEPS=256
ETA=0.01
NOISE=cosine
CHECKPOINT_PATH=<PATH-TO-DUO-CHECKPOINT>
DATA_CACHE_DIR=<YOUR-CACHE-PATH>
EVAL_BATCH_SIZE=500

python -u -m main \
    mode=fid_eval \
    data=cifar10 \
    data.cache_dir=$DATA_CACHE_DIR \
    model=unet \
    algo=duo_base \
    algo.backbone=unet \
    noise=$NOISE \
    sampling.predictor=psi \
    sampling.steps=$NUM_STEPS \
    sampling.guid_weight=1.0 \
    eval.checkpoint_path=$CHECKPOINT_PATH \
    loader.eval_batch_size=$EVAL_BATCH_SIZE \
    sampling.psi.time_profile=linear-constant-linear-0.9-inv \
    sampling.psi.high_mode=pure-posterior \
    sampling.psi.middle_mode=constant-remdm-$ETA \
    sampling.psi.low_mode=pure-posterior \
    sampling.psi.high_frac=0.45 \
    sampling.psi.middle_frac=0.5


================================================
FILE: scripts/psi_samplers/cifar10/duo_max_capped_remdm.sh
================================================
# DUO psi-sampler with max-capped-eta mode (ReMDM cap)

NUM_STEPS=256
ETA=0.005
NOISE=cosine
CHECKPOINT_PATH=<PATH-TO-DUO-CHECKPOINT>
DATA_CACHE_DIR=<YOUR-CACHE-PATH>
EVAL_BATCH_SIZE=500

python -u -m main \
    mode=fid_eval \
    data=cifar10 \
    data.cache_dir=$DATA_CACHE_DIR \
    model=unet \
    algo=duo_base \
    algo.backbone=unet \
    noise=$NOISE \
    sampling.predictor=psi \
    sampling.steps=$NUM_STEPS \
    sampling.guid_weight=1.0 \
    eval.checkpoint_path=$CHECKPOINT_PATH \
    loader.eval_batch_size=$EVAL_BATCH_SIZE \
    sampling.psi.time_profile=linear \
    sampling.psi.high_mode=max-capped-$ETA \
    sampling.psi.middle_mode=max-capped-$ETA \
    sampling.psi.low_mode=max-capped-$ETA \
    sampling.psi.high_frac=0.0 \
    sampling.psi.middle_frac=0.0
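This script, like the max-rescale and MDLM max-capped variants, applies one mode string to all three phases with `time_profile=linear` and both fractions set to `0.0`. A sketch of a helper that emits those six overrides for any mode string already used in these scripts (for example `max-capped-0.005` or `max-rescale-0.01`). The helper name `uniform_psi_overrides` and the `PSI_ARGS` array are hypothetical; `mapfile` requires bash 4 or later.

```shell
#!/bin/bash
# Hypothetical helper: prints, one per line, the sampling.psi.* overrides that
# apply a single MODE to every phase with a linear time profile, mirroring
# duo_max_capped_remdm.sh.
uniform_psi_overrides() {
  local mode="$1"
  printf '%s\n' \
    "sampling.psi.time_profile=linear" \
    "sampling.psi.high_mode=$mode" \
    "sampling.psi.middle_mode=$mode" \
    "sampling.psi.low_mode=$mode" \
    "sampling.psi.high_frac=0.0" \
    "sampling.psi.middle_frac=0.0"
}

# Collect the overrides into an array, one element per override.
mapfile -t PSI_ARGS < <(uniform_psi_overrides "max-capped-0.005")
printf '%s\n' "${PSI_ARGS[@]}"
```

The array can then stand in for the six `sampling.psi.*` lines by appending `"${PSI_ARGS[@]}"` to the `python -u -m main` call after the other overrides.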


================================================
FILE: scripts/psi_samplers/cifar10/duo_max_rescale_eta.sh
================================================
# DUO psi-sampler with max-rescale-eta mode — CIFAR-10 FID eval

NUM_STEPS=256
ETA=0.01
NOISE=cosine
CHECKPOINT_PATH=<PATH-TO-DUO-CHECKPOINT>
DATA_CACHE_DIR=<YOUR-CACHE-PATH>
EVAL_BATCH_SIZE=500

python -u -m main \
    mode=fid_eval \
    data=cifar10 \
    data.cache_dir=$DATA_CACHE_DIR \
    model=unet \
    algo=duo_base \
    algo.backbone=unet \
    noise=$NOISE \
    sampling.predictor=psi \
    sampling.steps=$NUM_STEPS \
    sampling.guid_weight=1.0 \
    eval.checkpoint_path=$CHECKPOINT_PATH \
    loader.eval_batch_size=$EVAL_BATCH_SIZE \
    sampling.psi.time_profile=linear \
    sampling.psi.high_mode=max-rescale-$ETA \
    sampling.psi.middle_mode=max-rescale-$ETA \
    sampling.psi.low_mode=max-rescale-$ETA \
    sampling.psi.high_frac=0.0 \
    sampling.psi.middle_frac=0.0


================================================
FILE: scripts/psi_samplers/cifar10/duo_psi_pc.sh
================================================
# DUO psi-sampler with constant kappa in the PC phase
# Kappa controls the posterior/PC mix: 1 = pure posterior, 0 = pure PC
#   t in [1.0, 0.5] -> pure-posterior
#   t in [0.5, 0.1] -> constant-kappa
#   t in [0.1, 0.0] -> pure-posterior

NUM_STEPS=256
KAPPA=0.95
NOISE=cosine
HIGH_FRAC=0.5
MIDDLE_FRAC=0.4
CHECKPOINT_PATH=<PATH-TO-DUO-CHECKPOINT>
DATA_CACHE_DIR=<YOUR-CACHE-PATH>
EVAL_BATCH_SIZE=500

python -u -m main \
    mode=fid_eval \
    data=cifar10 \
    data.cache_dir=$DATA_CACHE_DIR \
    model=unet \
    algo=duo_base \
    algo.backbone=unet \
    noise=$NOISE \
    sampling.predictor=psi \
    sampling.steps=$NUM_STEPS \
    sampling.guid_weight=1.0 \
    eval.checkpoint_path=$CHECKPOINT_PATH \
    loader.eval_batch_size=$EVAL_BATCH_SIZE \
    sampling.psi.time_profile=linear \
    sampling.psi.high_mode=pure-posterior \
    sampling.psi.middle_mode=constant-$KAPPA \
    sampling.psi.low_mode=pure-posterior \
    sampling.psi.high_frac=$HIGH_FRAC \
    sampling.psi.middle_frac=$MIDDLE_FRAC
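The header comment of duo_psi_pc.sh lists the t-intervals implied by `HIGH_FRAC=0.5` and `MIDDLE_FRAC=0.4`. Assuming, as that comment suggests, that the two fractions are taken from the t=1 end of [0, 1] and the low phase gets the remainder, the boundaries are `1 - HIGH_FRAC` and `1 - HIGH_FRAC - MIDDLE_FRAC`. The sketch below computes them with awk, since bash has no floating-point arithmetic. `psi_phase_bounds` is a hypothetical name, and how `main` maps these fractions onto discrete steps (especially for non-linear `time_profile` values) is not modelled here.

```shell
#!/bin/bash
# Hypothetical helper: prints the t-intervals (running from t=1 down to t=0)
# of the high, middle and low psi phases, ASSUMING high_frac and middle_frac
# are fractions of [0, 1] taken from the t=1 end.
psi_phase_bounds() {
  awk -v h="$1" -v m="$2" 'BEGIN {
    if (h < 0 || m < 0 || h + m > 1) {
      print "invalid fractions: need h, m >= 0 and h + m <= 1" > "/dev/stderr"
      exit 1
    }
    b1 = 1 - h        # boundary between the high and middle phases
    b2 = 1 - h - m    # boundary between the middle and low phases
    printf "high:   [1.00, %.2f]\n", b1
    printf "middle: [%.2f, %.2f]\n", b1, b2
    printf "low:    [%.2f, 0.00]\n", b2
  }'
}

psi_phase_bounds 0.5 0.4   # HIGH_FRAC and MIDDLE_FRAC from duo_psi_pc.sh
```

With the script's values this reproduces the intervals in its header comment: [1.0, 0.5], [0.5, 0.1] and [0.1, 0.0].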


================================================
FILE: scripts/psi_samplers/cifar10/mdlm_constant_remdm.sh
================================================
# MDLM psi-sampler with constant-remdm-eta mode (ReMDM loop) — CIFAR-10 FID eval

NUM_STEPS=256
ETA=0.01
NOISE=cosine
CHECKPOINT_PATH=<PATH-TO-MDLM-CHECKPOINT>
DATA_CACHE_DIR=<YOUR-CACHE-PATH>
EVAL_BATCH_SIZE=500

python -u -m main \
    mode=fid_eval \
    data=cifar10 \
    data.cache_dir=$DATA_CACHE_DIR \
    model=unet \
    algo=mdlm \
    algo.backbone=unet \
    noise=$NOISE \
    sampling.predictor=psi \
    sampling.steps=$NUM_STEPS \
    sampling.guid_weight=1.0 \
    eval.checkpoint_path=$CHECKPOINT_PATH \
    loader.eval_batch_size=$EVAL_BATCH_SIZE \
    sampling.psi.time_profile=linear-constant-linear-0.9-inv \
    sampling.psi.high_mode=pure-posterior \
    sampling.psi.middle_mode=constant-remdm-$ETA \
    sampling.psi.low_mode=pure-posterior \
    sampling.psi.high_frac=0.45 \
    sampling.psi.middle_frac=0.5


================================================
FILE: scripts/psi_samplers/cifar10/mdlm_max_capped_remdm.sh
================================================
# MDLM psi-sampler with max-capped-eta mode (ReMDM cap)

NUM_STEPS=256
ETA=0.005
NOISE=cosine
CHECKPOINT_PATH=<PATH-TO-MDLM-CHECKPOINT>
DATA_CACHE_DIR=<YOUR-CACHE-PATH>
EVAL_BATCH_SIZE=500

python -u -m main \
    mode=fid_eval \
    data=cifar10 \
    data.cache_dir=$DATA_CACHE_DIR \
    model=unet \
    algo=mdlm \
    algo.backbone=unet \
    noise=$NOISE \
    sampling.predictor=psi \
    sampling.steps=$NUM_STEPS \
    sampling.guid_weight=1.0 \
    eval.checkpoint_path=$CHECKPOINT_PATH \
    loader.eval_batch_size=$EVAL_BATCH_SIZE \
    sampling.psi.time_profile=linear \
    sampling.psi.high_mode=max-capped-$ETA \
    sampling.psi.middle_mode=max-capped-$ETA \
    sampling.psi.low_mode=max-capped-$ETA \
    sampling.psi.high_frac=0.0 \
    sampling.psi.middle_frac=0.0


================================================
FILE: scripts/psi_samplers/cifar10/mdlm_max_rescale_eta.s
SYMBOL INDEX (348 symbols across 11 files)

FILE: algo.py
  class AR (line 16) | class AR(trainer_base.TrainerBase):
    method __init__ (line 17) | def __init__(self, config, tokenizer):
    method _validate_configuration (line 30) | def _validate_configuration(self):
    method _process_model_input (line 35) | def _process_model_input(self, x0, valid_tokens):
    method nll (line 41) | def nll(self, input_tokens, output_tokens,
    method generate_samples (line 50) | def generate_samples(self, num_samples, **kwargs):
    method _process_sigma (line 72) | def _process_sigma(self, sigma):
  class MDLM (line 77) | class MDLM(trainer_base.AbsorbingState):
    method __init__ (line 78) | def __init__(self, config, tokenizer):
    method _validate_configuration (line 82) | def _validate_configuration(self):
    method _process_model_output (line 87) | def _process_model_output(self, model_output, xt, sigma):
    method nll_per_token (line 104) | def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
    method _get_score (line 113) | def _get_score(self, x, sigma):
  class D3PMAbsorb (line 158) | class D3PMAbsorb(trainer_base.AbsorbingState):
    method __init__ (line 159) | def __init__(self, config, tokenizer):
    method _validate_configuration (line 163) | def _validate_configuration(self):
    method _process_model_output (line 168) | def _process_model_output(self, model_output, xt, sigma):
    method nll_per_token (line 175) | def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
  class SEDDAbsorb (line 207) | class SEDDAbsorb(trainer_base.AbsorbingState):
    method __init__ (line 208) | def __init__(self, config, tokenizer):
    method _validate_configuration (line 212) | def _validate_configuration(self):
    method _get_score (line 216) | def _get_score(self, x, sigma):
    method _process_model_output (line 219) | def _process_model_output(self, model_output, xt, sigma):
    method nll_per_token (line 236) | def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
  class DUO_BASE (line 286) | class DUO_BASE(trainer_base.UniformState):
    method __init__ (line 287) | def __init__(self, config, tokenizer):
    method on_save_checkpoint (line 291) | def on_save_checkpoint(self, checkpoint):
    method on_load_checkpoint (line 297) | def on_load_checkpoint(self, checkpoint):
    method _process_model_output (line 303) | def _process_model_output(self, model_output, xt, sigma):
    method _posterior_from_x0 (line 307) | def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t):
    method nll_per_token (line 337) | def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
  class Integral (line 375) | class Integral(torch.autograd.Function):
    method forward (line 381) | def forward(ctx, gamma_t, data):
    method backward (line 400) | def backward(ctx, grad_output):
  class DUO (line 404) | class DUO(DUO_BASE):
    method __init__ (line 405) | def __init__(self, config, tokenizer):
    method _initialize_curriculum_coefficients (line 419) | def _initialize_curriculum_coefficients(self):
    method _init_curriculum_cached (line 432) | def _init_curriculum_cached(self):
    method _init_curriculum_series (line 441) | def _init_curriculum_series(self):
    method _init_curriculum_approx (line 454) | def _init_curriculum_approx(self):
    method to (line 486) | def to(self, *args, **kwargs):
    method cuda (line 494) | def cuda(self, device=None):
    method cpu (line 503) | def cpu(self):
    method to (line 512) | def to(self, *args, **kwargs):
    method _compute_gumbel_tau_inverse (line 521) | def _compute_gumbel_tau_inverse(self):
    method training_step (line 535) | def training_step(self, batch, batch_idx):
    method _gamma_to_alpha_dalpha (line 543) | def _gamma_to_alpha_dalpha(self, gamma_t, t):
    method _gamma_to_alphat_integral (line 559) | def _gamma_to_alphat_integral(self, gamma_t):
    method _gamma_to_alpha_dalpha_cached (line 564) | def _gamma_to_alpha_dalpha_cached(self, gamma_t):
    method _prior_loss (line 573) | def _prior_loss(self):
    method _q_xt_gaussian (line 582) | def _q_xt_gaussian(self, x, gamma_t):
    method nll (line 593) | def nll(self, x0, output_tokens,
  class Distillation (line 639) | class Distillation(DUO):
    method __init__ (line 640) | def __init__(self, config, tokenizer):
    method _validate_configuration (line 650) | def _validate_configuration(self):
    method _maybe_update_teacher_weights (line 663) | def _maybe_update_teacher_weights(self):
    method _teacher_logits (line 675) | def _teacher_logits(self, xt, sigma):
    method _sample_trajectory (line 687) | def _sample_trajectory(self, x0, gamma_t, gamma_s):
    method _compute_dt (line 708) | def _compute_dt(self):
    method nll (line 716) | def nll(self, x0, output_tokens,
    method training_step (line 752) | def training_step(self, batch, batch_idx):

FILE: dataloader.py
  class RawPixelsVisionTokenizer (line 31) | class RawPixelsVisionTokenizer:
    method __init__ (line 32) | def __init__(self, vocab_size, image_size,
    method __call__ (line 53) | def __call__(self, x):
    method batch_decode (line 56) | def batch_decode(self, x):
    method decode (line 62) | def decode(self, x):
    method __len__ (line 68) | def __len__(self):
  class DiscreteCIFAR10 (line 72) | class DiscreteCIFAR10(torch.utils.data.Dataset):
    method __init__ (line 73) | def __init__(self, cache_dir, train):
    method __len__ (line 88) | def __len__(self):
    method __getitem__ (line 91) | def __getitem__(self, index):
  function wt_detokenizer (line 100) | def wt_detokenizer(string):
  function ptb_detokenizer (line 132) | def ptb_detokenizer(x):
  function lm1b_detokenizer (line 146) | def lm1b_detokenizer(x):
  function lambada_detokenizer (line 169) | def lambada_detokenizer(text):
  function scientific_papers_detokenizer (line 175) | def scientific_papers_detokenizer(x):
  class SyntheticTokenizer (line 181) | class SyntheticTokenizer(
    method __init__ (line 184) | def __init__(
    method vocab_size (line 221) | def vocab_size(self) -> int:
    method _tokenize (line 224) | def _tokenize(self, text: str, **kwargs) -> typing.List[str]:
    method _convert_token_to_id (line 227) | def _convert_token_to_id(self, token: str) -> int:
    method _convert_id_to_token (line 231) | def _convert_id_to_token(self, index: int) -> str:
    method convert_tokens_to_string (line 234) | def convert_tokens_to_string(self, tokens):
    method get_vocab (line 237) | def get_vocab(self) -> typing.Dict[str, int]:
  function _generate_synthetic_data (line 241) | def _generate_synthetic_data(dataset_size,
  function generate_synthetic_dataset (line 261) | def generate_synthetic_dataset(train_dataset_size,
  class Text8Tokenizer (line 290) | class Text8Tokenizer(transformers.PreTrainedTokenizer):
    method __init__ (line 291) | def __init__(
    method vocab_size (line 325) | def vocab_size(self) -> int:
    method _tokenize (line 328) | def _tokenize(self, text: str, **kwargs) -> typing.List[str]:
    method _convert_token_to_id (line 331) | def _convert_token_to_id(self, token: str) -> int:
    method _convert_id_to_token (line 335) | def _convert_id_to_token(self, index: int) -> str:
    method convert_tokens_to_string (line 338) | def convert_tokens_to_string(self, tokens):
    method get_vocab (line 341) | def get_vocab(self) -> typing.Dict[str, int]:
  function get_lambada_test_dataset (line 345) | def get_lambada_test_dataset():
  function get_text8_dataset (line 365) | def get_text8_dataset(cache_dir, max_seq_length=256,
  function _group_texts (line 462) | def _group_texts(examples, block_size, bos, eos):
  function get_dataset (line 488) | def get_dataset(dataset_name,
  function get_tokenizer (line 712) | def get_tokenizer(config):
  function get_dataloaders (line 755) | def get_dataloaders(config, tokenizer, skip_train=False,
  class RandomFaultTolerantSampler (line 843) | class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):
    method __init__ (line 845) | def __init__(self, *args, generator=None, **kwargs):
    method state_dict (line 858) | def state_dict(self):
    method load_state_dict (line 862) | def load_state_dict(self, state_dict):
    method __iter__ (line 871) | def __iter__(self) -> typing.Iterator[int]:
  class FaultTolerantDistributedSampler (line 890) | class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):
    method __init__ (line 892) | def __init__(self, *args, **kwargs):
    method state_dict (line 897) | def state_dict(self):
    method load_state_dict (line 900) | def load_state_dict(self, state_dict):
    method __iter__ (line 907) | def __iter__(self):

FILE: discrete_diffusion_harness.py
  function requests_to_dataset (line 132) | def requests_to_dataset(config, requests, tokenizer, num_proc):
  function _eval_suffix_nll_generators (line 173) | def _eval_suffix_nll_generators(config, module, prefix,
  function eval_suffix_nll (line 205) | def eval_suffix_nll(config, module, prefix, suffix, batch_size,
  function eval_suffix_nll_ar (line 218) | def eval_suffix_nll_ar(config, module, prefix, suffix,
  function eval_suffix_nll_diffusion (line 234) | def eval_suffix_nll_diffusion(config, module, prefix, suffix,
  class DiscreteDiffusionHarness (line 252) | class DiscreteDiffusionHarness(LM):
    method __init__ (line 253) | def __init__(self, pretrained="NONE", max_length=1024,
    method suffix_greedy_prediction (line 290) | def suffix_greedy_prediction(self, prefix, target):
    method _suffix_greedy_prediction_ar (line 303) | def _suffix_greedy_prediction_ar(self, prefix, target):
    method _suffix_greedy_prediction_mdlm (line 316) | def _suffix_greedy_prediction_mdlm(self, prefix, target):
    method _suffix_greedy_prediction_duo_base (line 337) | def _suffix_greedy_prediction_duo_base(self, prefix, target):
    method loglikelihood (line 362) | def loglikelihood(self, requests: list[Instance]) \
    method loglikelihood_rolling (line 387) | def loglikelihood_rolling(
    method generate_until (line 392) | def generate_until(self, context, max_length, stop,

FILE: main.py
  function _load_from_checkpoint (line 31) | def _load_from_checkpoint(diffusion_model, config, tokenizer):
  function _print_config (line 43) | def _print_config(
  function _print_batch (line 78) | def _print_batch(config, train_ds, valid_ds, tokenizer, k=64):
  function _generate_samples (line 93) | def _generate_samples(diffusion_model, config, logger,
  function _eval_ppl (line 142) | def _eval_ppl(diffusion_model, config, logger, tokenizer):
  function _train (line 173) | def _train(diffusion_model, config, logger, tokenizer):
  function _eval_fid (line 217) | def _eval_fid(diffusion_model, config, logger, tokenizer):
  function main (line 298) | def main(config):

FILE: metrics.py
  class NLL (line 13) | class NLL(torchmetrics.aggregation.MeanMetric):
    method update (line 14) | def update(self,
  class BPD (line 47) | class BPD(NLL):
    method compute (line 48) | def compute(self) -> torch.Tensor:
  class Perplexity (line 57) | class Perplexity(NLL):
    method compute (line 58) | def compute(self) -> torch.Tensor:
  class Metrics (line 67) | class Metrics:
    method __init__ (line 68) | def __init__(self, gen_ppl_eval_model_name_or_path=None,
    method to (line 87) | def to(self, *args, **kwargs):
    method reset (line 95) | def reset(self):
    method update_train (line 103) | def update_train(self, nll, aux_loss, num_tokens):
    method update_valid (line 107) | def update_valid(self, nll, aux_loss, num_tokens):
    method _eval_retokenize (line 113) | def _eval_retokenize(self, text_samples, max_length,
    method record_entropy (line 155) | def record_entropy(self, tokens):
    method record_generative_perplexity (line 164) | def record_generative_perplexity(

FILE: models/dit.py
  function bias_dropout_add_scale (line 20) | def bias_dropout_add_scale(
  function get_bias_dropout_add_scale (line 37) | def get_bias_dropout_add_scale(training):
  function modulate (line 46) | def modulate(x: torch.Tensor,
  function bias_dropout_add_scale_fused_train (line 53) | def bias_dropout_add_scale_fused_train(
  function bias_dropout_add_scale_fused_inference (line 64) | def bias_dropout_add_scale_fused_inference(
  function modulate_fused (line 75) | def modulate_fused(x: torch.Tensor,
  class Rotary (line 81) | class Rotary(torch.nn.Module):
    method __init__ (line 82) | def __init__(self, dim, base=10_000):
    method forward (line 90) | def forward(self, x, seq_dim=1):
  function rotate_half (line 107) | def rotate_half(x):
  function split_and_apply_rotary_pos_emb (line 112) | def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin):
  function apply_rotary_pos_emb (line 128) | def apply_rotary_pos_emb(qkv, cos, sin):
  function regular_attention_multi_headed (line 134) | def regular_attention_multi_headed(q, k, v):
  class LayerNorm (line 152) | class LayerNorm(nn.Module):
    method __init__ (line 153) | def __init__(self, dim):
    method forward (line 157) | def forward(self, x):
  function residual_linear (line 163) | def residual_linear(x, W, x_skip, residual_scale):
  class TimestepEmbedder (line 176) | class TimestepEmbedder(nn.Module):
    method __init__ (line 180) | def __init__(self, hidden_size, frequency_embedding_size=256):
    method timestep_embedding (line 189) | def timestep_embedding(t, dim, max_period=10000):
    method forward (line 212) | def forward(self, t):
  class LabelEmbedder (line 218) | class LabelEmbedder(nn.Module):
    method __init__ (line 223) | def __init__(self, num_classes, cond_size):
    method forward (line 230) | def forward(self, labels):
  class DDiTBlockCausal (line 239) | class DDiTBlockCausal(nn.Module):
    method __init__ (line 240) | def __init__(self, dim, n_heads, mlp_ratio=4, dropout=0.1):
    method _get_bias_dropout_scale (line 257) | def _get_bias_dropout_scale(self):
    method forward (line 263) | def forward(self, x, rotary_cos_sin, **kwargs):
  class DDiTBlock (line 305) | class DDiTBlock(nn.Module):
    method __init__ (line 306) | def __init__(self, dim, n_heads, adaLN,
    method _get_bias_dropout_scale (line 332) | def _get_bias_dropout_scale(self):
    method forward (line 339) | def forward(self, x, rotary_cos_sin, c=None):
  class EmbeddingLayer (line 382) | class EmbeddingLayer(nn.Module):
    method __init__ (line 383) | def __init__(self, dim, vocab_dim):
    method forward (line 388) | def forward(self, x, weights=None):
  class DDiTFinalLayer (line 406) | class DDiTFinalLayer(nn.Module):
    method __init__ (line 407) | def __init__(self, hidden_size, out_channels, cond_dim,
    method forward (line 422) | def forward(self, x, c):
  class DIT (line 431) | class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):
    method __init__ (line 432) | def __init__(self, config, vocab_size: int):
    method _get_bias_dropout_scale (line 471) | def _get_bias_dropout_scale(self):
    method forward (line 477) | def forward(self, x, sigma, class_cond=None, weights=None):

FILE: models/ema.py
  class ExponentialMovingAverage (line 4) | class ExponentialMovingAverage:
    method __init__ (line 9) | def __init__(self, parameters, decay, use_num_updates=True):
    method move_shadow_params_to_device (line 26) | def move_shadow_params_to_device(self, device):
    method update (line 29) | def update(self, parameters):
    method copy_to (line 51) | def copy_to(self, parameters):
    method store (line 64) | def store(self, parameters):
    method restore (line 74) | def restore(self, parameters):
    method state_dict (line 89) | def state_dict(self):
    method load_state_dict (line 94) | def load_state_dict(self, state_dict):

FILE: models/unet.py
  function transformer_timestep_embedding (line 15) | def transformer_timestep_embedding(timesteps, embedding_dim, max_positio...
  function variance_scaling (line 33) | def variance_scaling(scale, mode, distribution,
  function default_init (line 67) | def default_init(scale=1.):
  class NiN (line 73) | class NiN(nn.Module):
    method __init__ (line 74) | def __init__(self, in_ch, out_ch, init_scale=0.1):
    method forward (line 79) | def forward(self, x, #  ["batch", "in_ch", "H", "W"]
  class AttnBlock (line 88) | class AttnBlock(nn.Module):
    method __init__ (line 90) | def __init__(self, channels, skip_rescale=True):
    method forward (line 100) | def forward(self, x, # ["batch", "channels", "H", "W"]
  class ResBlock (line 122) | class ResBlock(nn.Module):
    method __init__ (line 123) | def __init__(self, in_ch, out_ch, temb_dim=None, dropout=0.1, skip_res...
    method forward (line 157) | def forward(self, x, # ["batch", "in_ch", "H", "W"]
  class Downsample (line 184) | class Downsample(nn.Module):
    method __init__ (line 185) | def __init__(self, channels):
    method forward (line 190) | def forward(self, x, # ["batch", "ch", "inH", "inW"]
  class Upsample (line 199) | class Upsample(nn.Module):
    method __init__ (line 200) | def __init__(self, channels):
    method forward (line 204) | def forward(self, x, # ["batch", "ch", "inH", "inW"]
  class UNet (line 214) | class UNet(nn.Module):
    method __init__ (line 215) | def __init__(self, config, vocab_size=None):
    method _center_data (line 344) | def _center_data(self, x):
    method _time_embedding (line 348) | def _time_embedding(self, timesteps):
    method _do_input_conv (line 360) | def _do_input_conv(self, h):
    method _do_downsampling (line 365) | def _do_downsampling(self, h, hs, temb):
    method _do_middle (line 385) | def _do_middle(self, h, temb):
    method _do_upsampling (line 398) | def _do_upsampling(self, h, hs, temb):
    method _do_output (line 418) | def _do_output(self, h):
    method _logistic_output_res (line 426) | def _logistic_output_res(self,
    method _log_minus_exp (line 435) | def _log_minus_exp(self, a, b, eps=1e-6):
    method _truncated_logistic_output (line 443) | def _truncated_logistic_output(self, net_out):
    method forward (line 477) | def forward(self,

FILE: models/unit_test_attention.py
  function attention_inner_heads_flash (line 9) | def attention_inner_heads_flash(qkv, num_heads):
  class TestAttentionInnerHeadsFlash (line 62) | class TestAttentionInnerHeadsFlash(unittest.TestCase):
    method setUp (line 63) | def setUp(self):
    method attention_inner_heads_old (line 74) | def attention_inner_heads_old(self, qkv, num_heads):
    method test_attention_inner_heads_flash (line 90) | def test_attention_inner_heads_flash(self):

FILE: trainer_base.py
  class Loss (line 18) | class Loss:
  class LogLinear (line 25) | class LogLinear(torch.nn.Module):
    method __init__ (line 26) | def __init__(self, eps):
    method forward (line 30) | def forward(self, t):
    method get_t_for_alpha (line 36) | def get_t_for_alpha(self, alpha_t):
  class Cosine (line 40) | class Cosine(torch.nn.Module):
    method __init__ (line 41) | def __init__(self, eps):
    method forward (line 46) | def forward(self, t):
    method get_t_for_alpha (line 52) | def get_t_for_alpha(self, alpha_t):
  function sample_categorical (line 62) | def sample_categorical(categorical_probs):
  function _unsqueeze (line 69) | def _unsqueeze(x, reference):
  class TrainerBase (line 75) | class TrainerBase(L.LightningModule):
    method __init__ (line 76) | def __init__(
    method _validate_configuration (line 147) | def _validate_configuration(self):
    method to (line 161) | def to(self, *args, **kwargs):
    method q_xt (line 166) | def q_xt(self, x, alpha_t):
    method _get_parameters (line 169) | def _get_parameters(self):
    method _eval_mode (line 173) | def _eval_mode(self):
    method _train_mode (line 180) | def _train_mode(self):
    method on_load_checkpoint (line 186) | def on_load_checkpoint(self, checkpoint):
    method on_save_checkpoint (line 197) | def on_save_checkpoint(self, checkpoint):
    method on_train_start (line 236) | def on_train_start(self):
    method optimizer_step (line 273) | def optimizer_step(self, *args, **kwargs):
    method _process_sigma (line 278) | def _process_sigma(self, sigma):
    method _process_model_output (line 281) | def _process_model_output(self, model_output, xt, sigma):
    method forward (line 284) | def forward(self, xt, sigma, labels=None, weights=None,
    method on_train_epoch_start (line 297) | def on_train_epoch_start(self):
    method training_step (line 302) | def training_step(self, batch, batch_idx):
    method on_train_epoch_end (line 319) | def on_train_epoch_end(self):
    method on_validation_epoch_start (line 324) | def on_validation_epoch_start(self):
    method validation_step (line 330) | def validation_step(self, batch, batch_idx):
    method on_validation_epoch_end (line 339) | def on_validation_epoch_end(self):
    method configure_optimizers (line 381) | def configure_optimizers(self):
    method generate_samples (line 398) | def generate_samples(self, num_samples, num_steps, eps):
    method restore_model_and_sample (line 401) | def restore_model_and_sample(self, num_steps, eps=1e-5):
    method _process_model_input (line 412) | def _process_model_input(self, x0, valid_tokens):
    method nll (line 415) | def nll(self, input_tokens, labels, output_tokens,
    method _loss (line 419) | def _loss(self, x0, labels, valid_tokens,
  class Diffusion (line 442) | class Diffusion(TrainerBase):
    method _validate_configuration (line 443) | def _validate_configuration(self):
    method _process_model_input (line 452) | def _process_model_input(self, x0, valid_tokens):
    method _process_sigma (line 455) | def _process_sigma(self, sigma):
    method _sample_t (line 465) | def _sample_t(self, n, accum_step):
    method _sigma_from_alphat (line 484) | def _sigma_from_alphat(self, alpha_t):
    method _reconstruction_loss (line 487) | def _reconstruction_loss(self, x0):
    method nll_per_token (line 496) | def nll_per_token(self, model_output, xt, x0, alpha_t,
    method nll (line 500) | def nll(self, x0, labels, output_tokens,
    method _get_score (line 540) | def _get_score(self, **kwargs):
    method _denoiser_update (line 544) | def _denoiser_update(self, x, t):
    method _analytic_update (line 547) | def _analytic_update(self, x, t, dt):
    method _posterior_from_x0 (line 550) | def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t):
    method _forward_process (line 555) | def _forward_process(self, q_x0, alpha_s):
    method _get_ancestral_posterior (line 560) | def _get_ancestral_posterior(self, xt, sigma, labels,
    method _get_posterior_from_xt (line 584) | def _get_posterior_from_xt(self, xt, sigma, labels, alpha_s,
    method _get_guided_posterior_from_xt (line 602) | def _get_guided_posterior_from_xt(self, xt, sigma, labels,
    method _ancestral_update (line 633) | def _ancestral_update(self, x, t, labels, dt, p_x0=None,
    method _psi_update (line 650) | def _psi_update(self, x, t, labels, dt, kappa, p_x0=None,
    method _get_sampling_time_profile (line 674) | def _get_sampling_time_profile(self, eps, num_steps):
    method _mode_to_psi_kappas (line 694) | def _mode_to_psi_kappas(self, mode, timesteps):
    method _get_kappas (line 727) | def _get_kappas(self, timesteps):
    method generate_samples (line 743) | def generate_samples(self, num_samples, labels=None,
    method _semi_ar_sampler (line 800) | def _semi_ar_sampler(
    method restore_model_and_semi_ar_sample (line 840) | def restore_model_and_semi_ar_sample(
  class AbsorbingState (line 856) | class AbsorbingState(Diffusion):
    method __init__ (line 857) | def __init__(self, config, tokenizer):
    method _validate_configuration (line 875) | def _validate_configuration(self):
    method q_xt (line 886) | def q_xt(self, x, alpha_t):
    method prior_sample (line 901) | def prior_sample(self, *batch_dims):
    method _posterior_from_x0 (line 905) | def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t):
    method _forward_process (line 919) | def _forward_process(self, x0, alpha_s):
    method _staggered_score (line 924) | def _staggered_score(self, score, dsigma):
    method _analytic_update (line 931) | def _analytic_update(self, x, t, dt):
    method _denoiser_update (line 942) | def _denoiser_update(self, x, t):
    method _transp_transition (line 953) | def _transp_transition(self, i, sigma):
  class UniformState (line 963) | class UniformState(Diffusion):
    method _validate_configuration (line 964) | def _validate_configuration(self):
    method _forward_process (line 971) | def _forward_process(self, x0, alpha_s):
    method q_xt (line 975) | def q_xt(self, x, alpha_t):
    method prior_sample (line 993) | def prior_sample(self, *batch_dims):
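`trainer_base.py` lists a `LogLinear(eps)` schedule with `forward(t)` and `get_t_for_alpha(alpha_t)`, a `sample_categorical` helper, and `q_xt` forward-noising methods for the absorbing-state and uniform-state processes. The sketch below shows one common way these pieces fit together, under stated assumptions. It assumes the linear parameterization `alpha_t = 1 - (1 - eps) * t`, Gumbel-max sampling written in its exponential-race form, and the usual "keep with probability `alpha_t`" corruption. All names ending in `_sketch` or prefixed `q_xt_` are hypothetical, and return conventions (for example the order of `alpha_t` and its derivative) may differ from the repository.

```python
import torch


class LogLinearSketch(torch.nn.Module):
  """Hypothetical schedule: alpha_t = 1 - (1 - eps) * t, so alpha_1 = eps > 0."""

  def __init__(self, eps=1e-3):
    super().__init__()
    self.eps = eps

  def forward(self, t):
    alpha_t = 1 - (1 - self.eps) * t
    dalpha_t = torch.full_like(t, -(1 - self.eps))  # d alpha_t / dt
    return alpha_t, dalpha_t

  def get_t_for_alpha(self, alpha_t):
    # Inverse of forward: t = (1 - alpha_t) / (1 - eps).
    return (1 - alpha_t) / (1 - self.eps)


def sample_categorical_sketch(probs):
  """One index per row of `probs` (last dim) via the Gumbel-max trick.

  argmax_i p_i / E_i with E_i ~ Exp(1) is distributed as Categorical(p).
  """
  exp_noise = 1e-10 - (torch.rand_like(probs) + 1e-10).log()
  return (probs / exp_noise).argmax(dim=-1)


def q_xt_absorbing(x, alpha_t, mask_index):
  """Keep each token with prob alpha_t, otherwise replace it by the mask token."""
  move = torch.rand(x.shape, device=x.device) < 1 - alpha_t
  return torch.where(move, torch.full_like(x, mask_index), x)


def q_xt_uniform(x, alpha_t, vocab_size):
  """Keep each token with prob alpha_t, otherwise resample it uniformly."""
  move = torch.rand(x.shape, device=x.device) < 1 - alpha_t
  noise = torch.randint_like(x, vocab_size)
  return torch.where(move, noise, x)
```

Here `alpha_t` is expected to broadcast against `x`, for example shape `[batch, 1]` for tokens of shape `[batch, length]`. The `eps` offset mirrors the `eps: 1e-3` entry in `configs/noise/log-linear.yaml`. It keeps `alpha_t` strictly positive at `t = 1`.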

FILE: utils.py
  function count_parameters (line 24) | def count_parameters(model):
  function fsspec_exists (line 29) | def fsspec_exists(filename):
  function fsspec_listdir (line 35) | def fsspec_listdir(dirname):
  function fsspec_mkdirs (line 41) | def fsspec_mkdirs(dirname, exist_ok=True):
  function print_nans (line 47) | def print_nans(tensor, name):
  class LRHalveScheduler (line 52) | class LRHalveScheduler:
    method __init__ (line 53) | def __init__(self, warmup_steps, n_halve_steps):
    method __call__ (line 57) | def __call__(self, current_step):
  class CosineDecayWarmupLRScheduler (line 64) | class CosineDecayWarmupLRScheduler(
    method __init__ (line 74) | def __init__(self, *args, **kwargs):
    method step (line 79) | def step(self, epoch=None):
  class LoggingContext (line 97) | class LoggingContext:
    method __init__ (line 99) | def __init__(self, logger, level=None, handler=None, close=True):
    method __enter__ (line 105) | def __enter__(self):
    method __exit__ (line 112) | def __exit__(self, et, ev, tb):
  class GradientInspectionCallback (line 121) | class GradientInspectionCallback(lightning.Callback):
    method __init__ (line 122) | def __init__(self, num_grads_log):
    method on_before_optimizer_step (line 125) | def on_before_optimizer_step(self, trainer, pl_module, optimizer):
  function get_logger (line 158) | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
  function top_k_top_p_filtering (line 177) | def top_k_top_p_filtering(
  function _discrete_prob_map (line 244) | def _discrete_prob_map(gamma_t, N=10):
  function _discrete_prob_grad (line 253) | def _discrete_prob_grad(gamma_t, N=10):
  function _cache_prob_usdm_in_partition (line 263) | def _cache_prob_usdm_in_partition(
  function test_cache_prob_usdm_in_partition (line 303) | def test_cache_prob_usdm_in_partition(
  function compute_duo_series_coefficients (line 340) | def compute_duo_series_coefficients(num_coefficients,
  function compute_duo_gamma_to_alpha_dalpha_series (line 373) | def compute_duo_gamma_to_alpha_dalpha_series(
  function duo_t_to_alpha_dalpha_sigm_corrected (line 411) | def duo_t_to_alpha_dalpha_sigm_corrected(
  function duo_to_alpha_dalpha_sigmoid (line 432) | def duo_to_alpha_dalpha_sigmoid(t: torch.Tensor, a: float,
  function duo_to_alpha_dalpha_poly (line 440) | def duo_to_alpha_dalpha_poly(t: torch.Tensor,
  function compute_duo_operator_approx (line 453) | def compute_duo_operator_approx(num_coefficients, vocab_size,
  function _sample_k_int (line 511) | def _sample_k_int(bs: int, l: int, k: int, max_value: int,
  function _sample_topk_gaussian (line 529) | def _sample_topk_gaussian(N: int,
  function _sample_topk_and_extra (line 559) | def _sample_topk_and_extra(N: int, alpha: torch.Tensor,
  function _log_mean_exp_trunc_normal (line 575) | def _log_mean_exp_trunc_normal(c: torch.Tensor,
  function sample_tempered_softmax_topk (line 591) | def sample_tempered_softmax_topk(
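`utils.py` lists a `top_k_top_p_filtering` function, and several sampling scripts set a `NUCLEUS_P` value (for example 0.9 or 0.95). The sketch below shows the widely used top-k / nucleus filtering technique on logits. It is an illustrative assumption: the function name `top_k_top_p_filtering_sketch` and its parameters are hypothetical, and the repository's signature may differ.

```python
import torch


def top_k_top_p_filtering_sketch(logits, top_k=0, top_p=1.0,
                                 filter_value=-float("inf")):
  """Mask logits outside the top-k set and/or the nucleus (top-p) set.

  logits: [..., vocab]. Returns a filtered copy; the input is not modified.
  """
  logits = logits.clone()
  if top_k > 0:
    top_k = min(top_k, logits.size(-1))
    # Threshold = k-th largest logit in each row.
    kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    logits[logits < kth] = filter_value
  if top_p < 1.0:
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = sorted_logits.softmax(dim=-1)
    # Exclusive cumulative mass: probability of all strictly more likely tokens.
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    # Drop a token once the mass before it already exceeds top_p,
    # so the single most likely token is always kept.
    remove_sorted = mass_before > top_p
    # Map the mask from sorted order back to vocabulary order.
    remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
    logits[remove] = filter_value
  return logits
```

After filtering, sampling proceeds as usual from `softmax(filtered_logits)`. The masked entries get probability zero, so only the top-k or nucleus tokens can be drawn.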
Condensed preview: 110 files, each showing path, character count, and a content snippet (the full structured content is 293K chars).
[
  {
    "path": ".gitignore",
    "chars": 3167,
    "preview": ".DS_Store\n\n.hf_cache\ntest/\noutputs/\nwandb/\nwatch_folder/\nnotes.md\ngrid_search.sh\n*.ipynb\n\n# Byte-compiled / optimized / "
  },
  {
    "path": "LICENSE",
    "chars": 11349,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 8653,
    "preview": "<div align=\"center\">\n\n# The Diffusion Duality Series\n\n</div>\n\n## [Chapter I (ICML 2025)](https://arxiv.org/abs/2506.1089"
  },
  {
    "path": "algo.py",
    "chars": 27474,
    "preview": "import os\nimport collections\nimport copy\nimport pickle\nfrom typing import Optional\n\nimport fsspec\nimport numpy as np\nimp"
  },
  {
    "path": "configs/algo/ar.yaml",
    "chars": 154,
    "preview": "name: ar\nbackbone: dit\nparameterization: ar\ntime_conditioning: False\ncausal_attention: True\n# Irrelevant flags\nT: 0\nsubs"
  },
  {
    "path": "configs/algo/d3pm.yaml",
    "chars": 209,
    "preview": "name: d3pm\nbackbone: dit  # dit / dimamba\nparameterization: mean\ntime_conditioning: True\nT: 1000 \nsubs_masking: False  #"
  },
  {
    "path": "configs/algo/distillation.yaml",
    "chars": 575,
    "preview": "name: distillation\nbackbone: dit  # dit / dimamba / hf_dit\nparameterization: mean\ntime_conditioning: True\nsubs_masking: "
  },
  {
    "path": "configs/algo/duo.yaml",
    "chars": 902,
    "preview": "name: duo\nbackbone: dit  # dit / dimamba / hf_dit\nparameterization: mean\ntime_conditioning: True\nT: 0  # 0 (continuous t"
  },
  {
    "path": "configs/algo/duo_base.yaml",
    "chars": 233,
    "preview": "name: duo_base\nbackbone: dit  # dit / dimamba / hf_dit\nparameterization: mean\ntime_conditioning: True\nT: 0  # 0 (continu"
  },
  {
    "path": "configs/algo/mdlm.yaml",
    "chars": 213,
    "preview": "name: mdlm\nbackbone: dit  # dit / dimamba / hf_dit\nparameterization: subs\ntime_conditioning: False\nT: 0  # 0 (continuous"
  },
  {
    "path": "configs/algo/ot-finetune.yaml",
    "chars": 216,
    "preview": "name: ot-finetune\nbackbone: dit  # dit / dimamba / hf_dit\nparameterization: mean\ntime_conditioning: True\nT: 0  # 0 (cont"
  },
  {
    "path": "configs/algo/sedd.yaml",
    "chars": 222,
    "preview": "name: sedd\nbackbone: dit  # dit / dimamba\nparameterization: score \ntime_conditioning: True\nT: 0  # 0 (continuous time) /"
  },
  {
    "path": "configs/callbacks/checkpoint_every_n_steps.yaml",
    "chars": 381,
    "preview": "checkpoint_every_n_steps:\n  _target_: lightning.pytorch.callbacks.ModelCheckpoint\n  save_top_k: -1 # Do not save any \"be"
  },
  {
    "path": "configs/callbacks/checkpoint_monitor.yaml",
    "chars": 456,
    "preview": "checkpoint_monitor:\n  _target_: lightning.pytorch.callbacks.ModelCheckpoint\n  monitor: val/nll # name of the logged metr"
  },
  {
    "path": "configs/callbacks/grad_record.yaml",
    "chars": 76,
    "preview": "grad_record:\n  _target_: utils.GradientInspectionCallback\n  num_grads_log: 4"
  },
  {
    "path": "configs/callbacks/learning_rate_monitor.yaml",
    "chars": 108,
    "preview": "learning_rate_monitor:\n  _target_: lightning.pytorch.callbacks.LearningRateMonitor\n  logging_interval: step\n"
  },
  {
    "path": "configs/config.yaml",
    "chars": 3884,
    "preview": "defaults:\n  - _self_\n  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]\n  - /data: op"
  },
  {
    "path": "configs/data/ag_news.yaml",
    "chars": 200,
    "preview": "train: ag_news\nvalid: ag_news\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusio"
  },
  {
    "path": "configs/data/cifar10.yaml",
    "chars": 288,
    "preview": "train: cifar10\nvalid: cifar10\nmodality: image\ntokenizer_name_or_path: cifar10\ncache_dir: /share/kuleshov/ssahoo/textdiff"
  },
  {
    "path": "configs/data/fineweb-edu.yaml",
    "chars": 241,
    "preview": "train: HuggingFaceFW/fineweb-edu\nvalid: openwebtext-valid  #wikitext103\nmodality: text\ntokenizer_name_or_path: gpt2\ncach"
  },
  {
    "path": "configs/data/lambada.yaml",
    "chars": 200,
    "preview": "train: lambada\nvalid: lambada\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusio"
  },
  {
    "path": "configs/data/lm1b-gpt2.yaml",
    "chars": 194,
    "preview": "train: lm1b\nvalid: lm1b\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data"
  },
  {
    "path": "configs/data/lm1b-streaming.yaml",
    "chars": 207,
    "preview": "train: lm1b\nvalid: lm1b\nmodality: text\ntokenizer_name_or_path: bert-base-uncased\ncache_dir: /share/kuleshov/ssahoo/textd"
  },
  {
    "path": "configs/data/lm1b-wrap.yaml",
    "chars": 207,
    "preview": "train: lm1b\nvalid: lm1b\nmodality: text\ntokenizer_name_or_path: bert-base-uncased\ncache_dir: /share/kuleshov/ssahoo/textd"
  },
  {
    "path": "configs/data/lm1b.yaml",
    "chars": 208,
    "preview": "train: lm1b\nvalid: lm1b\nmodality: text\ntokenizer_name_or_path: bert-base-uncased\ncache_dir: /share/kuleshov/ssahoo/textd"
  },
  {
    "path": "configs/data/openwebtext-split.yaml",
    "chars": 220,
    "preview": "train: openwebtext-train\nvalid: openwebtext-valid\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov"
  },
  {
    "path": "configs/data/openwebtext-streaming.yaml",
    "chars": 175,
    "preview": "train: openwebtext\nvalid: wikitext103\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /tmp/data\nwrap: True\nstream"
  },
  {
    "path": "configs/data/openwebtext.yaml",
    "chars": 208,
    "preview": "train: openwebtext\nvalid: wikitext103\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/text"
  },
  {
    "path": "configs/data/ptb.yaml",
    "chars": 192,
    "preview": "train: ptb\nvalid: ptb\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nw"
  },
  {
    "path": "configs/data/scientific_papers_arxiv.yaml",
    "chars": 232,
    "preview": "train: scientific_papers_arxiv\nvalid: scientific_papers_arxiv\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /sh"
  },
  {
    "path": "configs/data/scientific_papers_pubmed.yaml",
    "chars": 234,
    "preview": "train: scientific_papers_pubmed\nvalid: scientific_papers_pubmed\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /"
  },
  {
    "path": "configs/data/synthetic.yaml",
    "chars": 208,
    "preview": "train: synthetic\nvalid: synthetic\nmodality: text\ntokenizer_name_or_path: synthetic\ncache_dir: /share/kuleshov/ssahoo/tex"
  },
  {
    "path": "configs/data/text8-crop.yaml",
    "chars": 278,
    "preview": "# TODO: When using this dataset, set model.length = 256 to match D3PM setup\ntrain: text8-crop\nvalid: text8\nmodality: tex"
  },
  {
    "path": "configs/data/text8.yaml",
    "chars": 273,
    "preview": "# TODO: When using this dataset, set model.length = 256 to match D3PM setup\ntrain: text8\nvalid: text8\nmodality: text\ntok"
  },
  {
    "path": "configs/data/wikitext103.yaml",
    "chars": 208,
    "preview": "train: wikitext103\nvalid: wikitext103\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/text"
  },
  {
    "path": "configs/data/wikitext2.yaml",
    "chars": 204,
    "preview": "train: wikitext2\nvalid: wikitext2\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiff"
  },
  {
    "path": "configs/lr_scheduler/constant_warmup.yaml",
    "chars": 79,
    "preview": "_target_: transformers.get_constant_schedule_with_warmup\nnum_warmup_steps: 2500"
  },
  {
    "path": "configs/lr_scheduler/cosine_decay_warmup.yaml",
    "chars": 213,
    "preview": "_target_: utils.CosineDecayWarmupLRScheduler\nt_in_epochs: False\nt_initial: ${eval:${trainer.max_steps}-${.warmup_t}}\nwar"
  },
  {
    "path": "configs/lr_scheduler/step_scheduler.yaml",
    "chars": 133,
    "preview": "_target_: torch.optim.lr_scheduler.LambdaLR\nlr_lambda:\n  _target_: utils.LRHalveScheduler\n  warmup_steps: 500\n  n_halve_"
  },
  {
    "path": "configs/model/medium.yaml",
    "chars": 173,
    "preview": "name: medium\ntype: ddit\nhidden_size: 1024\ncond_dim: 128\nlength: 1024\nn_blocks: 24\nn_heads: 16\nscale_by_sigma: True\ndropo"
  },
  {
    "path": "configs/model/small.yaml",
    "chars": 171,
    "preview": "name: small\ntype: ddit\nhidden_size: 768\ncond_dim: 128\nlength: 1024\nn_blocks: 12\nn_heads: 12\nscale_by_sigma: True\ndropout"
  },
  {
    "path": "configs/model/tiny-dimamba.yaml",
    "chars": 175,
    "preview": "name: tiny\ntype: dimamba\nhidden_size: 512\ncond_dim: 128\nlength: 1024\nn_blocks: 14\nn_heads: 8\nscale_by_sigma: True\ndropou"
  },
  {
    "path": "configs/model/tiny.yaml",
    "chars": 168,
    "preview": "name: tiny\ntype: ddit\nhidden_size: 256\ncond_dim: 128\nlength: 1024\nn_blocks: 8\nn_heads: 8\nscale_by_sigma: True\ndropout: 0"
  },
  {
    "path": "configs/model/unet.yaml",
    "chars": 450,
    "preview": "name: unet\ntype: unet\nch: 128 \nnum_res_blocks: 2\nnum_scales: 4\nch_mult: [1, 2, 2, 2]\ninput_channels: 3\noutput_channels: "
  },
  {
    "path": "configs/noise/cosine.yaml",
    "chars": 22,
    "preview": "type: cosine\neps: 1e-3"
  },
  {
    "path": "configs/noise/log-linear.yaml",
    "chars": 26,
    "preview": "type: log-linear\neps: 1e-3"
  },
  {
    "path": "configs/prior/none.yaml",
    "chars": 43,
    "preview": "type: none\nlatent_width: 0\nlatent_height: 0"
  },
  {
    "path": "configs/strategy/ddp.yaml",
    "chars": 81,
    "preview": "_target_: lightning.pytorch.strategies.DDPStrategy\nfind_unused_parameters: false\n"
  },
  {
    "path": "configs/strategy/fsdp.yaml",
    "chars": 142,
    "preview": "# TODO(yair): Currenly not compatible with grad clipping\n_target_: lightning.pytorch.strategies.FSDPStrategy\nsharding_st"
  },
  {
    "path": "dataloader.py",
    "chars": 29152,
    "preview": "import functools\nimport itertools\nimport json\nimport math\nimport os\nimport re\nimport shutil\nimport typing\nimport urllib\n"
  },
  {
    "path": "discrete_diffusion_harness.py",
    "chars": 15120,
    "preview": "import torch\nfrom omegaconf import OmegaConf\n\nfrom lm_eval.api.model import LM\nfrom lm_eval.api.registry import register"
  },
  {
    "path": "main.py",
    "chars": 11259,
    "preview": "import json\nimport os\n\nimport fsspec\nimport hydra\nimport lightning as L\nfrom lightning.fabric import Fabric\nimport omega"
  },
  {
    "path": "metrics.py",
    "chars": 7221,
    "preview": "import math\nimport os\nimport typing\n\nimport torch\nimport torch.nn.functional as F\nimport torchmetrics\nimport transformer"
  },
  {
    "path": "models/__init__.py",
    "chars": 54,
    "preview": "from . import dit\nfrom . import ema\nfrom . import unet"
  },
  {
    "path": "models/dit.py",
    "chars": 15825,
    "preview": "import math\nimport typing\n\nimport einops\nimport flash_attn\nimport flash_attn.layers.rotary\nimport huggingface_hub\nimport"
  },
  {
    "path": "models/ema.py",
    "chars": 3320,
    "preview": "import torch\n\n\nclass ExponentialMovingAverage:\n  \"\"\"\n  Maintains (exponential) moving average of a set of parameters.\n  "
  },
  {
    "path": "models/unet.py",
    "chars": 17152,
    "preview": "import torch\nimport torch.nn as nn\nimport math\nimport torch.nn.functional as F\nimport numpy as np\nimport omegaconf\n\nimpo"
  },
  {
    "path": "models/unit_test_attention.py",
    "chars": 3719,
    "preview": "import unittest\n\nimport torch\n\n# from flash_attn import flash_attention\nimport torch.nn.functional as F\n\n\ndef attention_"
  },
  {
    "path": "requirements.txt",
    "chars": 573,
    "preview": "# conda install nvidia/label/cuda-12.4.0::cuda-toolkit\ndatasets==2.15.0\neinops==0.7.0\nfsspec\ngit-lfs==1.6\nh5py==3.10.0\nh"
  },
  {
    "path": "scripts/distil_owt.sh",
    "chars": 1427,
    "preview": "#!/bin/bash\n#SBATCH -J posterior                 # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)"
  },
  {
    "path": "scripts/eval_lm1b_duo.sh",
    "chars": 1171,
    "preview": "#!/bin/bash\n#SBATCH -J eval_mdlm                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err"
  },
  {
    "path": "scripts/eval_owt_ar.sh",
    "chars": 1158,
    "preview": "#!/bin/bash\n#SBATCH -J eval_ar                # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#S"
  },
  {
    "path": "scripts/eval_owt_duo.sh",
    "chars": 1137,
    "preview": "#!/bin/bash\n#SBATCH -J owt_duo_anneal                    # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (ou"
  },
  {
    "path": "scripts/eval_owt_mdlm.sh",
    "chars": 1078,
    "preview": "#!/bin/bash\n#SBATCH -J eval_mdlm                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err"
  },
  {
    "path": "scripts/eval_owt_sedd.sh",
    "chars": 1187,
    "preview": "#!/bin/bash\n#SBATCH -J eval_sedd              # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#S"
  },
  {
    "path": "scripts/fid_cifar10_duo_ancestral_cosine.sh",
    "chars": 390,
    "preview": "export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1\npython -u -m main \\\n    mode=fid_eval \\\n    sampling.steps=64 \\\n    sampling.g"
  },
  {
    "path": "scripts/fid_cifar10_duo_base_ancestral_cosine.sh",
    "chars": 390,
    "preview": "export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1\npython -u -m main \\\n    mode=fid_eval \\\n    sampling.steps=64 \\\n    sampling.g"
  },
  {
    "path": "scripts/fid_cifar10_mdlm_ancestral_cosine.sh",
    "chars": 359,
    "preview": "python -u -m main \\\n    mode=fid_eval \\\n    sampling.steps=64 \\\n    sampling.guid_weight=1.0 \\\n    sampling.predictor=an"
  },
  {
    "path": "scripts/gen_ppl_lm1b_ar.sh",
    "chars": 1106,
    "preview": "#!/bin/bash\n#SBATCH -J sample_ar                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err"
  },
  {
    "path": "scripts/gen_ppl_lm1b_duo.sh",
    "chars": 1260,
    "preview": "#!/bin/bash\n#SBATCH -J sample_ar                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err"
  },
  {
    "path": "scripts/gen_ppl_owt_ar.sh",
    "chars": 1288,
    "preview": "#!/bin/bash\n#SBATCH -J sample_ar                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err"
  },
  {
    "path": "scripts/gen_ppl_owt_duo.sh",
    "chars": 1613,
    "preview": "#!/bin/bash\n#SBATCH -J an_owt_duo                    # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & "
  },
  {
    "path": "scripts/gen_ppl_owt_mdlm.sh",
    "chars": 1111,
    "preview": "#!/bin/bash\n#SBATCH -J sample_mdlm                # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err"
  },
  {
    "path": "scripts/gen_ppl_owt_sedd.sh",
    "chars": 1662,
    "preview": "#!/bin/bash\n#SBATCH -J sedd_samples               # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err"
  },
  {
    "path": "scripts/psi_samplers/cifar10/duo_constant_remdm.sh",
    "chars": 819,
    "preview": "# DUO psi-sampler with constant-remdm-eta mode (ReMDM loop)\n\nNUM_STEPS=256\nETA=0.01\nNOISE=cosine\nCHECKPOINT_PATH=<PATH-T"
  },
  {
    "path": "scripts/psi_samplers/cifar10/duo_max_capped_remdm.sh",
    "chars": 788,
    "preview": "# DUO psi-sampler with max-capped-eta mode (ReMDM cap)\n\nNUM_STEPS=256\nETA=0.005\nNOISE=cosine\nCHECKPOINT_PATH=<PATH-TO-DU"
  },
  {
    "path": "scripts/psi_samplers/cifar10/duo_max_rescale_eta.sh",
    "chars": 799,
    "preview": "# DUO psi-sampler with max-rescale-eta mode — CIFAR-10 FID eval\n\nNUM_STEPS=256\nETA=0.01\nNOISE=cosine\nCHECKPOINT_PATH=<PA"
  },
  {
    "path": "scripts/psi_samplers/cifar10/duo_psi_pc.sh",
    "chars": 1013,
    "preview": "# DUO psi-sampler with constant kappa in pc phase\n# Kappa controls the posterior/PC mix: 1 = pure posterior, 0 = pure PC"
  },
  {
    "path": "scripts/psi_samplers/cifar10/mdlm_constant_remdm.sh",
    "chars": 837,
    "preview": "# MDLM psi-sampler with constant-remdm-eta mode (ReMDM loop) — CIFAR-10 FID eval\n\nNUM_STEPS=256\nETA=0.01\nNOISE=cosine\nCH"
  },
  {
    "path": "scripts/psi_samplers/cifar10/mdlm_max_capped_remdm.sh",
    "chars": 786,
    "preview": "# MDLM psi-sampler with max-capped-eta mode (ReMDM cap)\n\nNUM_STEPS=256\nETA=0.005\nNOISE=cosine\nCHECKPOINT_PATH=<PATH-TO-M"
  },
  {
    "path": "scripts/psi_samplers/cifar10/mdlm_max_rescale_eta.sh",
    "chars": 777,
    "preview": "# MDLM psi-sampler with max-rescale-eta mode\n\nNUM_STEPS=256\nETA=0.01\nNOISE=cosine\nCHECKPOINT_PATH=<PATH-TO-MDLM-CHECKPOI"
  },
  {
    "path": "scripts/psi_samplers/cifar10/mdlm_psi_pc.sh",
    "chars": 977,
    "preview": "# MDLM psi-sampler with constant kappa during pc phase\n# Kappa controls the posterior/PC mix: 1 = pure posterior, 0 = pu"
  },
  {
    "path": "scripts/psi_samplers/owt/duo_loop_remdm.sh",
    "chars": 872,
    "preview": "# DUO psi-sampler with constant-remdm-eta mode (ReMDM loop)\n\nNUM_STEPS=256\nETA=0.01\nNUCLEUS_P=0.95\nNOISE=log-linear\nCHEC"
  },
  {
    "path": "scripts/psi_samplers/owt/duo_max_capped_remdm.sh",
    "chars": 839,
    "preview": "# DUO psi-sampler with max-capped-eta mode (ReMDM cap)\n\nNUM_STEPS=256\nETA=0.01\nNUCLEUS_P=0.9\nNOISE=log-linear\nCHECKPOINT"
  },
  {
    "path": "scripts/psi_samplers/owt/duo_max_rescale_eta.sh",
    "chars": 847,
    "preview": "# DUO psi-sampler with max-rescale-eta mode (ReMDM rescale)\n\nNUM_STEPS=256\nETA=0.05\nNUCLEUS_P=0.9\nNOISE=log-linear\nCHECK"
  },
  {
    "path": "scripts/psi_samplers/owt/mdlm_loop_remdm.sh",
    "chars": 869,
    "preview": "# MDLM psi-sampler with constant-remdm-eta mode (ReMDM loop)\n\nNUM_STEPS=256\nETA=0.01\nNUCLEUS_P=0.95\nNOISE=log-linear\nCHE"
  },
  {
    "path": "scripts/psi_samplers/owt/mdlm_max_capped_remdm.sh",
    "chars": 836,
    "preview": "# MDLM psi-sampler with max-capped-eta mode (ReMDM cap)\n\nNUM_STEPS=256\nETA=0.01\nNUCLEUS_P=0.9\nNOISE=log-linear\nCHECKPOIN"
  },
  {
    "path": "scripts/psi_samplers/owt/mdlm_max_rescale_eta.sh",
    "chars": 844,
    "preview": "# MDLM psi-sampler with max-rescale-eta mode (ReMDM rescale)\n\nNUM_STEPS=256\nETA=0.05\nNUCLEUS_P=0.9\nNOISE=log-linear\nCHEC"
  },
  {
    "path": "scripts/train_cifar10_duo_base_cosine.sh",
    "chars": 632,
    "preview": "python -u -m main \\\n    data=cifar10 \\\n    data.cache_dir=<YOUR-CACHE-PATH> \\\n    model=unet \\\n    algo=duo_base \\\n    a"
  },
  {
    "path": "scripts/train_cifar10_duo_cosine.sh",
    "chars": 632,
    "preview": "python -u -m main \\\n    data=cifar10 \\\n    data.cache_dir=<YOUR-CACHE-PATH> \\\n    model=unet \\\n    algo=duo_base \\\n    a"
  },
  {
    "path": "scripts/train_cifar10_mdlm_cosine.sh",
    "chars": 620,
    "preview": "python -u -m main \\\n    data=cifar10 \\\n    data.cache_dir=<YOUR-CACHE-PATH> \\\n    model=unet \\\n    algo=mdlm \\\n    algo."
  },
  {
    "path": "scripts/train_lm1b_ar.sh",
    "chars": 1057,
    "preview": "#!/bin/bash\n#SBATCH -J train_ar_lm1b              # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j exp"
  },
  {
    "path": "scripts/train_lm1b_ar_sentencepacking.sh",
    "chars": 1069,
    "preview": "#!/bin/bash\n#SBATCH -J train_ar_lm1b              # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j exp"
  },
  {
    "path": "scripts/train_lm1b_d3pm.sh",
    "chars": 1101,
    "preview": "#!/bin/bash\n#SBATCH -J train_d3pm                 # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j exp"
  },
  {
    "path": "scripts/train_lm1b_duo.sh",
    "chars": 1379,
    "preview": "#!/bin/bash\n#SBATCH -J duo-lm1b                   # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j exp"
  },
  {
    "path": "scripts/train_lm1b_duo_sentencepacking.sh",
    "chars": 1384,
    "preview": "#!/bin/bash\n#SBATCH -J duo-lm1b                   # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j exp"
  },
  {
    "path": "scripts/train_lm1b_mdlm.sh",
    "chars": 1056,
    "preview": "#!/bin/bash\n#SBATCH -J lm1b_mdlm                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j exp"
  },
  {
    "path": "scripts/train_lm1b_mdlm_sentencepacking.sh",
    "chars": 1072,
    "preview": "#!/bin/bash\n#SBATCH -J lm1b_mdlm                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j exp"
  },
  {
    "path": "scripts/train_owt_duo.sh",
    "chars": 1393,
    "preview": "#!/bin/bash\n#SBATCH -J duo-lm1b                   # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j exp"
  },
  {
    "path": "scripts/train_owt_duo_finetune.sh",
    "chars": 1411,
    "preview": "#!/bin/bash\n#SBATCH -J duo-base                   # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j exp"
  },
  {
    "path": "scripts/train_owt_mdlm.sh",
    "chars": 1134,
    "preview": "#!/bin/bash\n#SBATCH -J train_mdlm                 # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j exp"
  },
  {
    "path": "scripts/train_owt_sedd.sh",
    "chars": 1143,
    "preview": "#!/bin/bash\n#SBATCH -J train_sedd                 # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j exp"
  },
  {
    "path": "scripts/zero_shot_ar.sh",
    "chars": 1455,
    "preview": "#!/bin/bash\n#SBATCH -J zeroshot_ar                # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err"
  },
  {
    "path": "scripts/zero_shot_duo.sh",
    "chars": 1486,
    "preview": "#!/bin/bash\n#SBATCH -J zeroshot_duo               # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err"
  },
  {
    "path": "scripts/zero_shot_mdlm.sh",
    "chars": 1492,
    "preview": "#!/bin/bash\n#SBATCH -J zeroshot_mdlm_noeos              # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out"
  },
  {
    "path": "scripts/zero_shot_sedd.sh",
    "chars": 1513,
    "preview": "#!/bin/bash\n#SBATCH -J zeroshot_sedd              # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err"
  },
  {
    "path": "trainer_base.py",
    "chars": 35968,
    "preview": "import itertools\nfrom dataclasses import dataclass\n\nimport hydra.utils\nimport lightning as L\nimport numpy as np\nimport t"
  },
  {
    "path": "utils.py",
    "chars": 23672,
    "preview": "\"\"\"Console logger utilities.\n\nCopied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py\nCo"
  }
]

// ... and 2 more files (not shown)