Repository: s-sahoo/duo Branch: main Commit: 492505208b36 Files: 110 Total size: 270.1 KB Directory structure: gitextract_9im7g5qx/ ├── .gitignore ├── LICENSE ├── README.md ├── algo.py ├── configs/ │ ├── algo/ │ │ ├── ar.yaml │ │ ├── d3pm.yaml │ │ ├── distillation.yaml │ │ ├── duo.yaml │ │ ├── duo_base.yaml │ │ ├── mdlm.yaml │ │ ├── ot-finetune.yaml │ │ └── sedd.yaml │ ├── callbacks/ │ │ ├── checkpoint_every_n_steps.yaml │ │ ├── checkpoint_monitor.yaml │ │ ├── grad_record.yaml │ │ └── learning_rate_monitor.yaml │ ├── config.yaml │ ├── data/ │ │ ├── ag_news.yaml │ │ ├── cifar10.yaml │ │ ├── fineweb-edu.yaml │ │ ├── lambada.yaml │ │ ├── lm1b-gpt2.yaml │ │ ├── lm1b-streaming.yaml │ │ ├── lm1b-wrap.yaml │ │ ├── lm1b.yaml │ │ ├── openwebtext-split.yaml │ │ ├── openwebtext-streaming.yaml │ │ ├── openwebtext.yaml │ │ ├── ptb.yaml │ │ ├── scientific_papers_arxiv.yaml │ │ ├── scientific_papers_pubmed.yaml │ │ ├── synthetic.yaml │ │ ├── text8-crop.yaml │ │ ├── text8.yaml │ │ ├── wikitext103.yaml │ │ └── wikitext2.yaml │ ├── lr_scheduler/ │ │ ├── constant_warmup.yaml │ │ ├── cosine_decay_warmup.yaml │ │ └── step_scheduler.yaml │ ├── model/ │ │ ├── medium.yaml │ │ ├── small.yaml │ │ ├── tiny-dimamba.yaml │ │ ├── tiny.yaml │ │ └── unet.yaml │ ├── noise/ │ │ ├── cosine.yaml │ │ └── log-linear.yaml │ ├── prior/ │ │ └── none.yaml │ └── strategy/ │ ├── ddp.yaml │ └── fsdp.yaml ├── dataloader.py ├── discrete_diffusion_harness.py ├── integral/ │ ├── bert-base-uncased.pkl │ └── gpt2.pkl ├── main.py ├── metrics.py ├── models/ │ ├── __init__.py │ ├── dit.py │ ├── ema.py │ ├── unet.py │ └── unit_test_attention.py ├── requirements.txt ├── scripts/ │ ├── distil_owt.sh │ ├── eval_lm1b_duo.sh │ ├── eval_owt_ar.sh │ ├── eval_owt_duo.sh │ ├── eval_owt_mdlm.sh │ ├── eval_owt_sedd.sh │ ├── fid_cifar10_duo_ancestral_cosine.sh │ ├── fid_cifar10_duo_base_ancestral_cosine.sh │ ├── fid_cifar10_mdlm_ancestral_cosine.sh │ ├── gen_ppl_lm1b_ar.sh │ ├── gen_ppl_lm1b_duo.sh │ ├── gen_ppl_owt_ar.sh │ ├── gen_ppl_owt_duo.sh │ ├── gen_ppl_owt_mdlm.sh │ ├── gen_ppl_owt_sedd.sh │ ├── psi_samplers/ │ │ ├── cifar10/ │ │ │ ├── duo_constant_remdm.sh │ │ │ ├── duo_max_capped_remdm.sh │ │ │ ├── duo_max_rescale_eta.sh │ │ │ ├── duo_psi_pc.sh │ │ │ ├── mdlm_constant_remdm.sh │ │ │ ├── mdlm_max_capped_remdm.sh │ │ │ ├── mdlm_max_rescale_eta.sh │ │ │ └── mdlm_psi_pc.sh │ │ └── owt/ │ │ ├── duo_loop_remdm.sh │ │ ├── duo_max_capped_remdm.sh │ │ ├── duo_max_rescale_eta.sh │ │ ├── mdlm_loop_remdm.sh │ │ ├── mdlm_max_capped_remdm.sh │ │ └── mdlm_max_rescale_eta.sh │ ├── train_cifar10_duo_base_cosine.sh │ ├── train_cifar10_duo_cosine.sh │ ├── train_cifar10_mdlm_cosine.sh │ ├── train_lm1b_ar.sh │ ├── train_lm1b_ar_sentencepacking.sh │ ├── train_lm1b_d3pm.sh │ ├── train_lm1b_duo.sh │ ├── train_lm1b_duo_sentencepacking.sh │ ├── train_lm1b_mdlm.sh │ ├── train_lm1b_mdlm_sentencepacking.sh │ ├── train_owt_duo.sh │ ├── train_owt_duo_finetune.sh │ ├── train_owt_mdlm.sh │ ├── train_owt_sedd.sh │ ├── zero_shot_ar.sh │ ├── zero_shot_duo.sh │ ├── zero_shot_mdlm.sh │ └── zero_shot_sedd.sh ├── trainer_base.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .DS_Store .hf_cache test/ outputs/ wandb/ watch_folder/ notes.md grid_search.sh *.ipynb # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ 
develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. 
You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2025 Subham Sekhar Sahoo Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================
# The Diffusion Duality Series
## [Chapter I (ICML 2025)](https://arxiv.org/abs/2506.10892)

By [Subham Sekhar Sahoo](https://s-sahoo.github.io), [Justin Deschenaux](https://jdeschena.com), [Aaron Gokaslan](https://skylion007.github.io), [Guanghan Wang](https://tech.cornell.edu/people/guanghan-wang/), [Justin Chiu](https://justinchiu.netlify.app), [Volodymyr Kuleshov](https://www.cs.cornell.edu/~kuleshov/)

[![GitHub](https://img.shields.io/badge/GitHub-Repo-181717?logo=github&logoColor=white)](https://github.com/s-sahoo/duo/tree/ch-1) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Sf7R-dqdR6gq-H8nyZ9E3ZkyvqMTqcwq?usp=sharing) [![YouTube](https://img.shields.io/badge/YouTube-%23FF0000.svg?logo=YouTube&logoColor=white)](https://youtu.be/FCO-nnqHOqQ?si=4eGnj5zbRgyCYWwI) [![deploy](https://img.shields.io/badge/Blog%20%20-8A2BE2)](http://s-sahoo.github.io/duo) [![arXiv](https://img.shields.io/badge/arXiv-2506.10892-red.svg)](https://arxiv.org/abs/2506.10892) [![deploy](https://img.shields.io/badge/🤗-Huggingface-blue)](https://huggingface.co/collections/s-sahoo/duo-67f9ff8fde919224e5fbd875)

**Unlocks few-step generation in discrete diffusion LLMs via the underlying Gaussian diffusion.**
## [Chapter II: Ψ-Samplers and Efficient Curriculum (ICLR 2026)](https://arxiv.org/abs/2602.21185)

By [Justin Deschenaux](https://jdeschena.com), [Caglar Gulcehre](https://www.caglar.ai), [Subham Sekhar Sahoo](https://s-sahoo.github.io)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1uFSzrfG0KXhGcohRIfWIM2Y7V9Q7cQNA?usp=sharing) [![deploy](https://img.shields.io/badge/Blog%20%20-8A2BE2)](http://s-sahoo.github.io/duo-ch2) [![arXiv](https://img.shields.io/badge/arXiv-2602.21185-red.svg)](https://arxiv.org/abs/2602.21185)

**Uniform-state diffusion beats masked diffusion on text and image generation!**
This repository contains the code for the two papers in the Diffusion Duality series. It includes:
- **Duo / $\text{Duo}^\text{++}$** sampling (ancestral, ReMDM, $\Psi$-samplers, greedy-tail) — [Sampling & Eval](#sampling--eval)
- Original and efficient curriculum training strategies — [Training](#training)
- Discrete Consistency Distillation (DCD) — [Distillation](#discrete-consistency-distillation)
- Baselines (AR, MDLM, SEDD, D3PM) — [Baselines](#baselines)

[Getting Started](#getting-started) | [Checkpoints](#checkpoints) | [Citation](#acknowledgements--citation)

# Getting Started

To get started, create a conda environment containing the required dependencies.

```bash
conda create -n duo python=3.12
conda activate duo
conda install nvidia/label/cuda-12.4.0::cuda-toolkit
pip install -r requirements.txt
pip install flash_attn==2.7.4.post1
```

# Checkpoints

* **Duo** (Language Modeling): Trained on OpenWebText for `1M` training steps (distilled / base):
  * [Huggingface](https://huggingface.co/collections/s-sahoo/duo-67f9ff8fde919224e5fbd875)🤗.
  * [Google Drive folder](https://drive.google.com/drive/folders/1JpqFM8XRvifwIkjWPfMyuDvu41r1yk0t?usp=share_link), since the HF checkpoints can't be fine-tuned.
* **Duo** (Image Modeling): Trained on CIFAR-10
  * [Huggingface (contains the raw checkpoints)](https://huggingface.co/jdeschena/duo2-cifar10)
* **Baselines** (SEDD, MDLM, AR): Trained on OpenWebText
  * [Google Drive folder](https://drive.google.com/drive/folders/16LuuptK7Xfk-vzhQYZBZ0SA-B-BFluau?usp=sharing) — download `ar.ckpt`, `mdlm.ckpt`, `sedd.ckpt`.

# Training

This repo implements the original Duo curriculum as well as the fast $\text{Duo}^\text{++}$ curriculum. By default, the training scripts use the original curriculum; to enable the efficient curriculum, replace `algo.curriculum.mode=simple` with `algo.curriculum.mode=poly9` (see the comments in each training script). To train $\text{Duo}^\text{++}$, use the following scripts:
* LM1B
  * w/ sentencepacking (same as in D3PM)
    * Training script: [`scripts/train_lm1b_duo_sentencepacking.sh`](./scripts/train_lm1b_duo_sentencepacking.sh)
    * [Wandb run](https://api.wandb.ai/links/kuleshov-group/huwt0ek3)
  * w/o sentencepacking (same as in MDLM, SEDD)
    * Training script: [`scripts/train_lm1b_duo.sh`](./scripts/train_lm1b_duo.sh)
    * [Wandb run](https://api.wandb.ai/links/sahoo-diffusion/lkv5z3tm)
* OWT: [`scripts/train_owt_duo.sh`](./scripts/train_owt_duo.sh)
* CIFAR-10:
  * Duo: [`scripts/train_cifar10_duo_cosine.sh`](./scripts/train_cifar10_duo_cosine.sh)
  * MDLM: [`scripts/train_cifar10_mdlm_cosine.sh`](./scripts/train_cifar10_mdlm_cosine.sh)
  * Both scripts default to a cosine noise schedule. To use log-linear instead, set `noise=log-linear`.

**Notes:**
* Run `mkdir watch_folder` to create a directory for SLURM logs, and then run any script in [`scripts/`](scripts) as a SLURM job: `sbatch scripts/ABC_XYZ.sh`.
* Control the batch size per GPU using the argument `loader.batch_size`. If `loader.batch_size * num_gpus < loader.global_batch_size`, PyTorch Lightning resorts to gradient accumulation (e.g., `loader.global_batch_size=512` on 8 GPUs with `loader.batch_size=16` accumulates gradients over 4 micro-batches).

# Discrete Consistency Distillation

To distill a model using Discrete Consistency Distillation (`Alg. 1` in the [Duo paper](https://arxiv.org/abs/2506.10892)), use [`scripts/distil_owt.sh`](scripts/distil_owt.sh).

# Sampling & Eval

## Likelihood

To compute perplexity on the OWT validation split, use [`scripts/eval_owt_duo.sh`](scripts/eval_owt_duo.sh); for zero-shot perplexities, use [`scripts/zero_shot_duo.sh`](scripts/zero_shot_duo.sh).
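Both likelihood scripts reduce to a `main.py` invocation of roughly the following shape. This is a minimal sketch, not the scripts' exact contents: `mode=ppl_eval` is one of the three modes listed in `configs/config.yaml`, the `data`/`algo` choices mirror the documented sampling command, and the checkpoint path is a placeholder — defer to [`scripts/eval_owt_duo.sh`](scripts/eval_owt_duo.sh) for the flags it actually pins.

```bash
# Sketch of a likelihood (perplexity) evaluation run.
# /path/to/duo.ckpt is a placeholder for a downloaded raw checkpoint.
python main.py \
  mode=ppl_eval \
  data=openwebtext-split \
  algo=duo_base \
  eval.checkpoint_path=/path/to/duo.ckpt \
  +wandb.offline=true
```

The zero-shot evaluations follow the same pattern, pointing the same checkpoint at a different `data=` config from `configs/data/`.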
## Sampling

You can generate samples via ancestral sampling using the scripts in [`scripts/gen_ppl_*.sh`](scripts/). To sample with the PC samplers such as ReMDM and our $\Psi$-samplers, use the scripts in [`scripts/psi_samplers`](scripts/psi_samplers). This directory contains examples for sampling text and images.

To use the "Greedy-tail sampler" (equivalent to nucleus sampling in AR models; see `Sec. 4.2` in the paper), set `sampling.noise_removal=greedy`. The default `sampling.noise_removal=ancestral` produces more diverse samples (higher entropy) at the cost of worse generative perplexity.

To sample from a HuggingFace checkpoint (text only), run the following command:

```bash
python main.py \
  mode=sample_eval \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=duo_base \
  algo.backbone=hf_dit \
  eval.checkpoint_path=s-sahoo/duo-distilled \
  sampling.steps=8 \
  sampling.num_sample_batches=1 \
  sampling.noise_removal=greedy \
  +wandb.offline=true
```

To use the example scripts with raw checkpoints (see [Checkpoints](#checkpoints)), download them and set the checkpoint path in the script.

# Baselines

Download the baseline checkpoints (see [Checkpoints](#checkpoints)) and specify the paths appropriately in the respective shell scripts:
* [`scripts/eval_owt_*.sh`](scripts/) for computing validation perplexity on OWT.
* [`scripts/gen_ppl_*.sh`](scripts/) for generating text samples and evaluating them.
* [`scripts/zero_shot_*.sh`](scripts/) for computing zero-shot perplexities.
* [`scripts/train_*.sh`](scripts/) for training the models.

# Acknowledgements & Citation

This repository was built on top of [MDLM's GitHub repository](https://github.com/kuleshov-group/mdlm). Cite our papers using:

```
@inproceedings{
sahoo2025the,
title={The Diffusion Duality},
author={Subham Sekhar Sahoo and Justin Deschenaux and Aaron Gokaslan and Guanghan Wang and Justin T Chiu and Volodymyr Kuleshov},
booktitle={Forty-second International Conference on Machine Learning},
year={2025},
url={https://openreview.net/forum?id=9P9Y8FOSOk}
}

@inproceedings{
deschenaux2026the,
title={The Diffusion Duality, Chapter {II}: \${\textbackslash}Psi\$-Samplers and Efficient Curriculum},
author={Justin Deschenaux and Caglar Gulcehre and Subham Sekhar Sahoo},
booktitle={The Fourteenth International Conference on Learning Representations},
year={2026},
url={https://openreview.net/forum?id=RSIoYWIzaP}
}
```

================================================
FILE: algo.py
================================================
import os import collections import copy import pickle from typing import Optional import fsspec import numpy as np import torch import torch.nn.functional as F import trainer_base import utils class AR(trainer_base.TrainerBase): def __init__(self, config, tokenizer): vocab_size = tokenizer.vocab_size if (not hasattr(tokenizer, 'mask_token') or tokenizer.mask_token is None): self.mask_index = vocab_size vocab_size += 1 else: self.mask_index = tokenizer.mask_token_id super().__init__(config, tokenizer, vocab_size=vocab_size) self.save_hyperparameters() self._validate_configuration() def _validate_configuration(self): super()._validate_configuration() assert not self.config.algo.time_conditioning assert self.config.prior.type == 'none' def _process_model_input(self, x0, valid_tokens): input_tokens = x0[:, :-1] output_tokens = x0[:, 1:] valid_tokens = valid_tokens[:, 1:] return input_tokens, output_tokens, valid_tokens def nll(self, input_tokens, output_tokens, current_accumulation_step): del
current_accumulation_step output = self.backbone(input_tokens, None) output[:, :, self.mask_index] = self.neg_infinity output = output.log_softmax(-1) return - output.gather( -1, output_tokens[:, :, None])[:, :, 0] def generate_samples(self, num_samples, **kwargs): # precompute token buffer num_pred_tokens = self.num_tokens - 1 x = torch.zeros( (num_samples, num_pred_tokens + 1), dtype=torch.long, device=self.device) x[:, 0] = self.tokenizer.bos_token_id # precompute noise noise = (torch.distributions.Gumbel(0, 1) .sample((num_samples, num_pred_tokens, self.vocab_size)) .to(self.device)) if self.config.sampling.use_float64: noise = noise.to(torch.float64) for i in range(num_pred_tokens): output = self.backbone(x[:, :i + 1], None) output[:, :, self.mask_index] = self.neg_infinity output = output.log_softmax(-1) y = (output[:, -1, :] + noise[:, i, :]).argmax(-1) x[:, i + 1] = y return x def _process_sigma(self, sigma): del sigma return None class MDLM(trainer_base.AbsorbingState): def __init__(self, config, tokenizer): super().__init__(config, tokenizer) self._validate_configuration() def _validate_configuration(self): assert self.sampler != 'ancestral', \ 'sampling.predictor=ancestral is not desirable because ' \ 'it is slow. Please set sampling.predictor=ancestral_cache' def _process_model_output(self, model_output, xt, sigma): del sigma model_output[:, :, self.mask_index] += self.neg_infinity # Normalize the model_output such that x.exp() is # a probability distribution over vocab_size. model_output = model_output - torch.logsumexp( model_output, dim=-1, keepdim=True) # Apply updates directly in the logits matrix. # For the logits of the unmasked tokens, set all values # to -infinity except for the indices corresponding to # the unmasked tokens. unmasked_indices = (xt != self.mask_index) model_output[unmasked_indices] = self.neg_infinity model_output[unmasked_indices, xt[unmasked_indices]] = 0 return model_output def nll_per_token(self, log_x_theta, xt, x0, alpha_t, dalpha_t, low_var=False): del xt log_p_theta = torch.gather( input=log_x_theta, dim=-1, index=x0[:, :, None]).squeeze(-1) return log_p_theta * dalpha_t / (1 - alpha_t) def _get_score(self, x, sigma): model_output = self.forward(x, sigma) # score(x, t) = p_t(y) / p_t(x) # => log score(x, t) = log p_t(y) - log p_t(x) # case 1: x = masked # (i) y = unmasked # log score(x, t) = log p_\theta(x)|_y + log k # where k = exp(- sigma) / (1 - exp(- sigma)) # (ii) y = masked # log score(x, t) = 0 # case 2: x = unmasked # (i) y != masked, y != x # log score(x_i, t) = - inf # (ii) y = x # log score(x_i, t) = 0 # (iii) y = masked token # log score(x_i, t) = - log k # where k = exp(- sigma) / (1 - exp(- sigma)) log_k = - torch.log(torch.expm1(sigma)).squeeze(-1) assert log_k.ndim == 1 masked_score = model_output + log_k[:, None, None] masked_score[:, :, self.mask_index] = 0 unmasked_score = self.neg_infinity * torch.ones_like( model_output) unmasked_score = torch.scatter( unmasked_score, -1, x[..., None], torch.zeros_like(unmasked_score[..., :1])) unmasked_score[:, :, self.mask_index] = - ( log_k[:, None] * torch.ones_like(x)) masked_indices = (x == self.mask_index).to( model_output.dtype)[:, :, None] model_output = ( masked_score * masked_indices + unmasked_score * (1 - masked_indices)) return model_output.exp() class D3PMAbsorb(trainer_base.AbsorbingState): def __init__(self, config, tokenizer): super().__init__(config, tokenizer) self._validate_configuration() def _validate_configuration(self): super()._validate_configuration() assert 
self.noise.type == 'log-linear' assert self.parameterization == 'mean' def _process_model_output(self, model_output, xt, sigma): del xt del sigma if self.subs_masking: model_output[:, :, self.mask_index] += self.neg_infinity return model_output.log_softmax(dim=-1) def nll_per_token(self, log_x_theta, xt, x0, alpha_t, dalpha_t, low_var=False): del dalpha_t assert not low_var dt = 1 / self.T t = 1 - alpha_t # Only valid for log-linear schedule. t = t.clamp(0., 1.0 - 1e-4) alpha_t = alpha_t + torch.zeros_like(xt) alpha_s = t - dt + torch.zeros_like(xt) assert alpha_s.shape == xt.shape assert alpha_t.shape == xt.shape log_x_theta_at_x0 = torch.gather( log_x_theta, -1, x0[:, :, None]).squeeze(-1) log_x_theta_at_m = log_x_theta[:, :, self.mask_index] x_theta_at_m = log_x_theta_at_m.exp() term_1_coef = dt / t term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1) term_1_log_dr = log_x_theta_at_x0 term_2_coef = 1 - dt / t term_2_log_nr = term_1_log_nr term_2_log_dr = torch.log( alpha_s * x_theta_at_m / (t - dt) + 1) L_vb_masked = ( term_1_coef * (term_1_log_nr - term_1_log_dr) + term_2_coef * (term_2_log_nr - term_2_log_dr)) diffusion_loss = self.T * L_vb_masked * (xt == self.mask_index) return self._reconstruction_loss(x0) + diffusion_loss class SEDDAbsorb(trainer_base.AbsorbingState): def __init__(self, config, tokenizer): super().__init__(config, tokenizer) self._validate_configuration() def _validate_configuration(self): super()._validate_configuration() assert self.config.sampling.predictor == 'analytic' def _get_score(self, x, sigma): return self.forward(x, sigma).exp() def _process_model_output(self, model_output, xt, sigma): esigm1_log = torch.where( sigma < 0.5, torch.expm1(sigma), sigma.exp() - 1).log().to(model_output.dtype) # logits shape # (batch_size, context_length, vocab_size) model_output = (model_output - esigm1_log[:, None, None] - np.log(model_output.shape[-1] - 1)) # The below scatter operation sets the log score # for the input word to 0. model_output = torch.scatter( model_output, -1, xt[..., None], torch.zeros_like(model_output[..., :1])) return model_output def nll_per_token(self, log_x_theta, xt, x0, alpha_t, dalpha_t, low_var=False): """Computes the SEDD loss for the Absorbing State Diffusion. Args: log_x_theta: float torch.Tensor with shape (batch_size, context_length, vocab_size), log score, output of the denoising network. xt: int torch.Tensor with shape (batch_size, context_length), input. x0: int torch.Tensor with shape (batch_size, context_length), input. alpha_t: float torch.Tensor with shape (batch_size, 1), signal level. dalpha_t: float or float torch.Tensor with shape (batch_size, 1), time derivative of signal level. low_var: bool, low variance loss during training. Returns: loss with shape (batch_size, context_length).
""" assert not low_var masked_indices = xt == self.mask_index sigma = self._sigma_from_alphat(alpha_t) dsigma = - dalpha_t / alpha_t expsig_minus_1 = torch.expm1(sigma).expand_as(xt) q_ratio = 1 / expsig_minus_1[masked_indices] words_that_were_masked = x0[masked_indices] neg_term = q_ratio * torch.gather( log_x_theta[masked_indices], -1, words_that_were_masked[..., None]).squeeze(-1) score = log_x_theta[masked_indices].exp() if self.mask_index == self.vocab_size - 1: pos_term = score[:, :-1].sum(dim=-1) else: pos_term = score[:, : self.mask_index].sum( dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1) const = q_ratio * (q_ratio.log() - 1) entropy = torch.zeros(* xt.shape, device=xt.device) entropy[masked_indices] += pos_term - neg_term + const return dsigma * entropy class DUO_BASE(trainer_base.UniformState): def __init__(self, config, tokenizer): super().__init__(config, tokenizer) self._validate_configuration() def on_save_checkpoint(self, checkpoint): checkpoint['state_dict'] = collections.OrderedDict( (k, v) for k, v in checkpoint['state_dict'].items() if not k.startswith('teacher')) super().on_save_checkpoint(checkpoint) def on_load_checkpoint(self, checkpoint): checkpoint['state_dict'] = collections.OrderedDict( (k, v) for k, v in checkpoint['state_dict'].items() if not k.startswith('teacher')) super().on_load_checkpoint(checkpoint) def _process_model_output(self, model_output, xt, sigma): del xt, sigma return model_output.log_softmax(dim=-1) def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t): """Computes the posterior / approximate posterior. Args: x: Either clean input `x0` (one-hot), or model's predicted `x_theta` of shape (B, L, V). xt: The noisy latent (as indices) of shape (B, L). alpha_s: Noise level at s of shape (B, [L | 1], 1). alpha_t: Noise level at t of shape (B, [L | 1], 1). Returns: Posterior / approximate posterior of shape (B, L, V). 
""" if self.config.sampling.use_float64: x0 = x0.to(torch.float64) if alpha_s.ndim == 2: alpha_s = alpha_s.unsqueeze(-1) if alpha_t.ndim == 2: alpha_t = alpha_t.unsqueeze(-1) alpha_ts = alpha_t / alpha_s d_alpha = alpha_s - alpha_t xt_one_hot = F.one_hot(xt, self.vocab_size).to( self.dtype).to(self.device) return ( (alpha_t * self.vocab_size * x0 * xt_one_hot + ( alpha_ts - alpha_t) * xt_one_hot + d_alpha * x0 + ( 1 - alpha_ts) * (1 - alpha_s) / self.vocab_size) / ( alpha_t * self.vocab_size * torch.gather( x0, -1, xt[..., None]) + (1 - alpha_t))) def nll_per_token(self, log_x_theta, xt, x0, alpha_t, dalpha_t, low_var=False): assert alpha_t.ndim == 2 assert x0.ndim == 2 assert xt.ndim == 2 assert not torch.is_tensor(dalpha_t) or dalpha_t.ndim == 2 x_reconst = log_x_theta.exp() x_bar_theta = self.vocab_size * alpha_t[ :, :, None] * x_reconst + 1 - alpha_t[:, :, None] coeff = dalpha_t / (self.vocab_size * alpha_t) x_eq_xt = (x0 == xt).float() x_neq_xt = 1 - x_eq_xt xbar_xt = (1 - alpha_t) + self.vocab_size * alpha_t * x_eq_xt xbar_theta_xt = torch.gather( x_bar_theta, -1, xt.unsqueeze(-1)).squeeze(-1) xbar_theta_x = torch.gather( x_bar_theta, -1, x0.unsqueeze(-1)).squeeze(-1) term1 = self.vocab_size * (1 / xbar_xt - 1 / xbar_theta_xt) const = (1 - alpha_t) / (self.vocab_size * alpha_t + 1 - alpha_t) term2_coefs = x_eq_xt * const + x_neq_xt term2_offset = ((self.vocab_size - 1) * const * x_eq_xt - (1 / const) * x_neq_xt) * const.log() term2_theta = - term2_coefs * ( x_bar_theta.log().sum(-1) - self.vocab_size * xbar_theta_xt.log()) term2_theta = ( term2_theta - self.vocab_size * alpha_t / (1 - alpha_t) * ( xbar_theta_x.log() - xbar_theta_xt.log()) * x_neq_xt) term2 = term2_theta + term2_offset diffusion_loss = coeff * (term1 - term2) assert diffusion_loss.ndim == 2 return diffusion_loss class Integral(torch.autograd.Function): """ torch module calculating UDLM's p_t """ @staticmethod def forward(ctx, gamma_t, data): gamma_max = data['gamma_max'] gamma_min = data['gamma_min'] if (gamma_t.max() > gamma_max) or ( gamma_t.min() < gamma_min): print('max:{} {}'.format(gamma_t.max(), gamma_max)) print('min:{} {}'.format(gamma_t.min(), gamma_min)) gamma_t = torch.clip(gamma_t, gamma_min, gamma_max) indices = torch.round( (data['num_points'] - 1) * (gamma_t - gamma_min) / ( gamma_max - gamma_min)).long() grad_pt = data['grad_pt'] ctx.grad_pt = grad_pt[indices] pt = data['pt'][indices] assert pt.shape == gamma_t.shape return pt @staticmethod def backward(ctx, grad_output): return ctx.grad_pt * grad_output, None class DUO(DUO_BASE): def __init__(self, config, tokenizer): super().__init__(config, tokenizer) self.gamma_min = self.config.algo.curriculum.gamma_min self.gamma_max = self.config.algo.curriculum.gamma_max self.gumbel_tau_log10_start = \ self.config.algo.curriculum.gumbel_tau_log10_start self.gumbel_tau_log10_end = \ self.config.algo.curriculum.gumbel_tau_log10_end self.curriculum_start = self.config.algo.curriculum.start self.curriculum_end = self.config.algo.curriculum.end self.loss_type = self.config.algo.loss_type self._initialize_curriculum_coefficients() self._validate_configuration() def _initialize_curriculum_coefficients(self): if self.config.algo.curriculum.mode in {'simple', 'efficient_cached'}: self._init_curriculum_cached() elif self.config.algo.curriculum.mode == 'series': self._init_curriculum_series() elif self.config.algo.curriculum.mode in {'sigmoid', 'sigmoid-edge-corrected', 'poly3', 'poly5', 'poly7', 'poly9'}: self._init_curriculum_approx() else: raise 
ValueError(self.config.algo.curriculum.mode) def _init_curriculum_cached(self): fpath = self.config.algo.curriculum.integral_cache_path with fsspec.open(fpath, 'rb') as f: self.integral_cache = pickle.load(f) self.integral_cache['pt'] = torch.from_numpy( self.integral_cache['pt']) self.integral_cache['grad_pt'] = torch.from_numpy( self.integral_cache['grad_pt']) def _init_curriculum_series(self): m, i = utils.compute_duo_series_coefficients( self.config.algo.curriculum.n_series_terms, self.vocab_size) self.register_buffer('coefficients_m', m, persistent=False) self.register_buffer('coefficients_i', i, persistent=False) self.register_buffer('power_arange', torch.arange(self.config.algo.curriculum.n_series_terms, dtype=torch.float64)[None], persistent=False) def _init_curriculum_approx(self): fname = f'{self.config.algo.curriculum.mode}.npy' fpath = os.path.join(self.config.algo.curriculum.cache_dir, fname) if not os.path.exists(fpath): # Compute the coefficients on the fly coefficients, _, _, _ = utils.compute_duo_operator_approx( num_coefficients=self.config.algo.curriculum.n_series_terms, vocab_size=self.vocab_size, gamma_min=self.gamma_min, gamma_max=self.gamma_max, fct_name=self.config.algo.curriculum.mode) # Tuples are for torch compile, tuples are immutable coefficients = tuple(coefficients) parent_dir = os.path.dirname(fpath) os.makedirs(parent_dir, exist_ok=True) np.save(fpath, coefficients) else: coefficients = tuple(np.load(fpath).tolist()) mode = self.config.algo.curriculum.mode if mode == 'sigmoid': fn = utils.duo_to_alpha_dalpha_sigmoid elif mode == 'sigmoid-edge-corrected': fn = utils.duo_t_to_alpha_dalpha_sigm_corrected elif mode in ('poly3', 'poly5', 'poly7', 'poly9'): fn = utils.duo_to_alpha_dalpha_poly else: raise ValueError(mode) fn = torch.compile(fn) self._t_to_alpha_dalpha_compiled = \ lambda t: fn(t, *coefficients) def cuda(self, device=None): self = super().cuda(device=device) if hasattr(self, 'integral_cache'): self.integral_cache['pt'] = self.integral_cache[ 'pt'].cuda(device=device) self.integral_cache['grad_pt'] = self.integral_cache[ 'grad_pt'].cuda(device=device) return self def cpu(self): self = super().cpu() if hasattr(self, 'integral_cache'): self.integral_cache['pt'] = self.integral_cache[ 'pt'].cpu() self.integral_cache['grad_pt'] = self.integral_cache[ 'grad_pt'].cpu() return self def to(self, *args, **kwargs): self = super().to(*args, **kwargs) if hasattr(self, 'integral_cache'): self.integral_cache['pt'] = self.integral_cache[ 'pt'].to(*args, **kwargs) self.integral_cache['grad_pt'] = self.integral_cache[ 'grad_pt'].to(*args, **kwargs) return self def _compute_gumbel_tau_inverse(self): start = self.gumbel_tau_log10_start end = self.gumbel_tau_log10_end delta = end - start if self.global_step < self.curriculum_start: tau = start elif self.global_step < self.curriculum_end: frac = (self.global_step - self.curriculum_start) / ( self.curriculum_end - self.curriculum_start) tau = start + frac * delta else: tau = -10 return 10 ** (-tau) def training_step(self, batch, batch_idx): self.log(name='gumbel_tau_log10', value=1 / self._compute_gumbel_tau_inverse(), on_step=True, on_epoch=False, sync_dist=True) return super().training_step(batch, batch_idx) def _gamma_to_alpha_dalpha(self, gamma_t, t): if
self.config.algo.curriculum.mode in ('simple', 'efficient_cached'): return self._gamma_to_alpha_dalpha_cached(gamma_t) elif self.config.algo.curriculum.mode == 'series': return utils.compute_duo_gamma_to_alpha_dalpha_series( gamma_t, self.coefficients_m, self.coefficients_i, self.power_arange, self.vocab_size, self.gamma_min, self.gamma_max) elif self.config.algo.curriculum.mode in ('sigmoid', 'sigmoid-edge-corrected', 'poly3', 'poly5', 'poly7', 'poly9'): return self._t_to_alpha_dalpha_compiled(t) else: raise ValueError(self.config.algo.curriculum.mode) def _gamma_to_alphat_integral(self, gamma_t): integral = Integral.apply(gamma_t, self.integral_cache) return (self.vocab_size * integral - 1) / ( self.vocab_size - 1) def _gamma_to_alpha_dalpha_cached(self, gamma_t): gamma_t_prime = self.gamma_max - self.gamma_min usdm_alpha_t = DUO._gamma_to_alphat_integral(self, gamma_t) T = 1000 usdm_dalpha_t = gamma_t_prime * T * ( DUO._gamma_to_alphat_integral(self, gamma_t + 1 / T) - usdm_alpha_t) return usdm_alpha_t, usdm_dalpha_t def _prior_loss(self): alpha_1 = self._gamma_to_alphat_integral( torch.tensor(self.gamma_max)) loss = ((alpha_1 + (1 - alpha_1) / self.vocab_size) \ * torch.log((self.vocab_size - 1) * alpha_1 + 1) \ + (1 - 1 / self.vocab_size) * (1 - alpha_1) \ * torch.log(1 - alpha_1)) return loss.item() def _q_xt_gaussian(self, x, gamma_t): """Computes the noisy sample xt.""" assert gamma_t.ndim == 1 assert x.ndim == 3 gamma_t = gamma_t.unsqueeze(-1).unsqueeze(-1) alpha_t = torch.sigmoid(-gamma_t).sqrt() sigma_t = torch.sigmoid(gamma_t).sqrt() epsilon = torch.randn(x.shape, dtype=torch.float32, device=self.device) return alpha_t * x + sigma_t * epsilon def nll(self, x0, output_tokens, current_accumulation_step=None, train_mode=False): use_true_nll = (self.global_step > self.curriculum_end or not train_mode) if use_true_nll: return super().nll(x0, output_tokens, current_accumulation_step) del output_tokens t = self._sample_t(x0.shape[0], current_accumulation_step) gamma_t = self.gamma_min + t * (self.gamma_max - self.gamma_min) usdm_alpha_t, usdm_dalpha_t = \ self._gamma_to_alpha_dalpha(gamma_t, t) usdm_alpha_t = usdm_alpha_t.unsqueeze(-1) assert usdm_alpha_t.ndim == 2 usdm_dalpha_t = usdm_dalpha_t.unsqueeze(-1) sigma = self._sigma_from_alphat(usdm_alpha_t) # Default Duo curriculum if self.config.algo.curriculum.mode == 'simple': x0_one_hot = F.one_hot(x0, self.vocab_size) xt = self._q_xt_gaussian(x0_one_hot, gamma_t) xt = xt * self._compute_gumbel_tau_inverse() xt_usdm = xt.argmax(-1) log_x_theta = self.forward(xt, sigma=sigma) else: # Efficient variant softmax_approx, topk_indices, xt_usdm = \ utils.sample_tempered_softmax_topk( extra_index=x0, alpha=torch.sigmoid(-gamma_t).sqrt(), sigma=torch.sigmoid(gamma_t).sqrt(), l=x0.shape[1], k=self.config.algo.curriculum.top_k, vocab_size=self.vocab_size, inverse_temperature=self._compute_gumbel_tau_inverse()) log_x_theta = self.forward(topk_indices, sigma=sigma, weights=softmax_approx) return self.nll_per_token(log_x_theta=log_x_theta, xt=xt_usdm, x0=x0, alpha_t=usdm_alpha_t, dalpha_t=usdm_dalpha_t, low_var=False) class Distillation(DUO): def __init__(self, config, tokenizer): super().__init__(config, tokenizer) self.update_teacher_every = config.algo.update_teacher_every self.save_hyperparameters() self.teacher = None self.teacher_ema = config.algo.teacher_ema self.linear_growth_dt = config.algo.linear_growth_dt self.linear_growth_min = config.algo.linear_growth_min self.linear_growth_max = config.algo.linear_growth_max def 
_validate_configuration(self): assert os.path.exists( self.config.algo.integral_cache_path), ( 'The integral cache (Eq. 10 in the paper) for ' f'the {self.config.data.tokenizer_name_or_path} ' 'tokenizer does not exist at ' f'{self.config.algo.integral_cache_path}. ' 'Please generate it by running the utils.py script, ' 'and ensure the correct path is specified using the ' 'algo.integral_cache_path flag.') assert self.loss_type in { 'kl-fwd', 'kl-bwd', 'posterior', 'kl-posterior'} def _maybe_update_teacher_weights(self): if self.global_step % self.update_teacher_every != 0: return if self.teacher_ema: self.ema.copy_to(self.teacher.parameters()) else: for better_param, current_param in zip( self.backbone.parameters(), self.teacher.parameters()): if current_param.requires_grad: current_param.data.copy_(better_param.data) @torch.no_grad() def _teacher_logits(self, xt, sigma): if self.teacher is None: self.teacher = copy.deepcopy(self.backbone) self._maybe_update_teacher_weights() sigma = self._process_sigma(sigma) with torch.amp.autocast('cuda', dtype=torch.float32): model_output = self.teacher(xt, sigma) logits = self._process_model_output( model_output=model_output, xt=xt, sigma=sigma) return logits.detach() def _sample_trajectory(self, x0, gamma_t, gamma_s): """Computes the noisy sample xt.""" assert gamma_t.ndim == 1 assert gamma_s.ndim == 1 assert x0.ndim == 2 x0 = F.one_hot(x0, self.vocab_size).to( self.dtype).to(self.device) gamma_t = gamma_t.unsqueeze(-1).unsqueeze(-1) alpha_t = torch.sigmoid(-gamma_t).sqrt() sigma_t = torch.sigmoid(gamma_t).sqrt() gamma_s = gamma_s.unsqueeze(-1).unsqueeze(-1) alpha_s = torch.sigmoid(-gamma_s).sqrt() sigma_s = torch.sigmoid(gamma_s).sqrt() epsilon = torch.randn(x0.shape, dtype=torch.float32, device=self.device) xt = alpha_t * x0 + sigma_t * epsilon xs = alpha_s * x0 + sigma_s * epsilon return xt, xs def _compute_dt(self): if self.linear_growth_dt: scale = self.global_step / self.trainer.max_steps return self.linear_growth_min + scale * ( self.linear_growth_max - self.linear_growth_min) n = self.global_step // self.update_teacher_every return 2 ** n / self.T def nll(self, x0, output_tokens, current_accumulation_step=None, train_mode=None): del output_tokens, train_mode t = self._sample_t(x0.shape[0], current_accumulation_step) dt = self._compute_dt() t = torch.clip(t + dt, 0, 1) gamma_t = self.gamma_min + t * (self.gamma_max - self.gamma_min) gamma_s = self.gamma_min + ( t - dt) * (self.gamma_max - self.gamma_min) usdm_alpha_t = self._gamma_to_alphat_integral(gamma_t) usdm_alpha_t = usdm_alpha_t.unsqueeze(-1) assert usdm_alpha_t.ndim == 2 usdm_alpha_s = self._gamma_to_alphat_integral(gamma_s) usdm_alpha_s = usdm_alpha_s.unsqueeze(-1) assert usdm_alpha_s.ndim == 2 xt, xs = self._sample_trajectory(x0, gamma_t, gamma_s) xt_discrete = xt.argmax(-1) xs_discrete = xs.argmax(-1) log_x_theta_student = self.forward( xt_discrete, sigma=self._sigma_from_alphat(usdm_alpha_t)) log_x_theta_teacher = self._teacher_logits( xs_discrete, sigma=self._sigma_from_alphat(usdm_alpha_s)) if self.config.training.loss_precision == 'float64': log_x_theta_student = log_x_theta_student.to(torch.float64) log_x_theta_teacher = log_x_theta_teacher.to(torch.float64) if self.loss_type == 'kl-fwd': return (log_x_theta_teacher.exp() * ( log_x_theta_teacher - log_x_theta_student)).sum(-1) elif self.loss_type == 'kl-bwd': return (log_x_theta_student.exp() * ( log_x_theta_student - log_x_theta_teacher)).sum(-1) def training_step(self, batch, batch_idx): self.log(name='dt',
value=self._compute_dt(), on_step=True, on_epoch=False, sync_dist=True) return super().training_step(batch, batch_idx) ================================================ FILE: configs/algo/ar.yaml ================================================ name: ar backbone: dit parameterization: ar time_conditioning: False causal_attention: True # Irrelevant flags T: 0 subs_masking: False ignore_bos: False ================================================ FILE: configs/algo/d3pm.yaml ================================================ name: d3pm backbone: dit # dit / dimamba parameterization: mean time_conditioning: True T: 1000 subs_masking: False # True / False causal_attention: False ignore_bos: False loss_type: elbo # elbo, low_var ================================================ FILE: configs/algo/distillation.yaml ================================================ name: distillation backbone: dit # dit / dimamba / hf_dit parameterization: mean time_conditioning: True subs_masking: False causal_attention: False gumbel_tau_log10_start: -1 gumbel_tau_log10_end: -1 curriculum_start: -1 curriculum_end: -1 integral_cache_path: ${hydra:runtime.cwd}/integral/${data.tokenizer_name_or_path}.pkl loss_type: kl-bwd # kl-fwd, kl-bwd, posterior update_teacher_every: 10_000 ignore_bos: False T: 64 gamma_min: -4 gamma_max: -1 teacher_ema: False posterior_loss_weight: 0.0 linear_growth_dt: False linear_growth_min: 0.001 linear_growth_max: 0.25 ================================================ FILE: configs/algo/duo.yaml ================================================ name: duo backbone: dit # dit / dimamba / hf_dit parameterization: mean time_conditioning: True T: 0 # 0 (continuous time) / 1000 subs_masking: False causal_attention: False loss_type: elbo ignore_bos: False # Curriculum arguments curriculum: # Simple is the original duo curriculum, poly9 is the one we use in duo2 for speed mode: simple # simple / efficient_cached / series / sigmoid / sigmoid-edge-corrected / poly3 / poly5 / poly7 / poly9 # Arguments for the simple & efficient variant gumbel_tau_log10_start: -1.0 gumbel_tau_log10_end: -2.0 start: 100_000 end: 200_000 gamma_min: -3.5 gamma_max: -1.75 # Argument for the simple and efficient_cached variants integral_cache_path: ${hydra:runtime.cwd}/integral/${data.tokenizer_name_or_path}.pkl # Arguments for the efficient variant only top_k: -1 cache_dir: './transform_approx_coefficients' n_series_terms: 150 ================================================ FILE: configs/algo/duo_base.yaml ================================================ name: duo_base backbone: dit # dit / dimamba / hf_dit parameterization: mean time_conditioning: True T: 0 # 0 (continuous time) / 1000 subs_masking: False causal_attention: False ignore_bos: False loss_type: elbo # elbo, low_var ================================================ FILE: configs/algo/mdlm.yaml ================================================ name: mdlm backbone: dit # dit / dimamba / hf_dit parameterization: subs time_conditioning: False T: 0 # 0 (continuous time) / 1000 subs_masking: False causal_attention: False ignore_bos: False loss_type: elbo ================================================ FILE: configs/algo/ot-finetune.yaml ================================================ name: ot-finetune backbone: dit # dit / dimamba / hf_dit parameterization: mean time_conditioning: True T: 0 # 0 (continuous time) / 1000 subs_masking: False causal_attention: False ignore_bos: False delta_ts: 0.5 ================================================ FILE: configs/algo/sedd.yaml 
================================================ name: sedd backbone: dit # dit / dimamba parameterization: score time_conditioning: True T: 0 # 0 (continuous time) / 1000 subs_masking: False causal_attention: False ignore_bos: False loss_type: elbo # elbo, low_var ================================================ FILE: configs/callbacks/checkpoint_every_n_steps.yaml ================================================ checkpoint_every_n_steps: _target_: lightning.pytorch.callbacks.ModelCheckpoint save_top_k: -1 # Do not save any "best" models; this callback is being used to save every n train steps save_last: True # save model as ${save_dir}/checkpoints/last.ckpt dirpath: ${checkpointing.save_dir}/checkpoints verbose: True auto_insert_metric_name: False every_n_train_steps: 500 ================================================ FILE: configs/callbacks/checkpoint_monitor.yaml ================================================ checkpoint_monitor: _target_: lightning.pytorch.callbacks.ModelCheckpoint monitor: val/nll # name of the logged metric which determines when model is improving mode: min # can be "max" or "min" save_top_k: 1 # save k best models (determined by above metric) save_last: False # True = additionally always save model from last epoch dirpath: ${checkpointing.save_dir}/checkpoints filename: best auto_insert_metric_name: False verbose: True ================================================ FILE: configs/callbacks/grad_record.yaml ================================================ grad_record: _target_: utils.GradientInspectionCallback num_grads_log: 4 ================================================ FILE: configs/callbacks/learning_rate_monitor.yaml ================================================ learning_rate_monitor: _target_: lightning.pytorch.callbacks.LearningRateMonitor logging_interval: step ================================================ FILE: configs/config.yaml ================================================ defaults: - _self_ - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor] - /data: openwebtext - /model: small - /strategy: ddp - /noise: log-linear - /lr_scheduler: constant_warmup - /prior: none - /algo: duo_base mode: train # train / ppl_eval / sample_eval seed: 1 loader: global_batch_size: 512 eval_global_batch_size: ${.global_batch_size} # Note: batch_size and eval_batch_size are **per machine** batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}} eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}} num_workers: ${eval:"len(__import__('os').sched_getaffinity(0))"} pin_memory: True sampling: predictor: ancestral # ancestral_cache (only for MDLM), ancestral, analytic, psi (for psi-samplers) steps: 1000 noise_removal: ancestral # 'ancestral', 'greedy', 'none' use_float64: True p_nucleus: 1.0 num_sample_batches: 2 # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches num_sample_log: 2 semi_ar: False stride_length: 1 num_strides: 1 guid_weight: null # null -> no classifier-free guidance psi: time_profile: linear # linear / linear-constant-linear-alpha / linear-constant-linear-alpha-inv # Note: kappa = 1 -> pure posterior, kappa = 0 -> pure PC # modes: pure-posterior / pure-pc / constant-eta / constant-remdm-eta / max-capped-eta / max-rescale-eta high_mode: pure-posterior middle_mode: pure-posterior low_mode: pure-posterior # Fraction of [0, 1] spent in the high / middle modes (the rest being in low mode) # Example: 
high_frac=0.2, middle_frac=0.6. Then, # - For t in [1.0, 0.8] -> use high_mode # - For t in [0.8, 0.2] -> use middle_mode # - For t in [0.2, 0.0] -> use low_mode high_frac: 0.2 middle_frac: 0.6 training: ema: 0.9999 antithetic_sampling: True importance_sampling: False sampling_eps: 1e-3 change_of_variables: False loss_precision: 'bf16' # bf16, float32, float64 finetune_path: '' class_dropout_p: 0.1 # only used with class-conditional datasets (eg cifar10) eval: checkpoint_path: '' # Used to evaluate a checkpoint after training. disable_ema: False compute_generative_perplexity: False perplexity_batch_size: 8 compute_perplexity_on_sanity: False gen_ppl_eval_model_name_or_path: gpt2-large # gpt2-large, meta-llama/Llama-2-7b-hf generate_samples: True generated_samples_path: ${cwd:}/samples.json optim: weight_decay: 0 lr: 3e-4 beta1: 0.9 beta2: 0.999 eps: 1e-8 trainer: _target_: lightning.Trainer accelerator: cuda num_nodes: 1 devices: ${device_count:} accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}} gradient_clip_val: 1.0 precision: 'bf16' num_sanity_val_steps: 2 max_steps: 1_000_000 log_every_n_steps: 100 limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run limit_val_batches: 1.0 # validate on full dataset, can be used to toggle quick run val_check_interval: 5000 check_val_every_n_epoch: null # Disable end-of-epoch validation. Instead, validate every trainer.val_check_interval steps. wandb: project: duo group: null job_type: null name: null id: ${.name}_${seed} tags: - ${noise.type} - ${data.train} - ${data.valid} - ${algo.name} hydra: run: dir: ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S} job: chdir: true checkpointing: # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is save_dir: ${cwd:} # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath` resume_from_ckpt: true resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt ================================================ FILE: configs/data/ag_news.yaml ================================================ train: ag_news valid: ag_news modality: text tokenizer_name_or_path: gpt2 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/cifar10.yaml ================================================ train: cifar10 valid: cifar10 modality: image tokenizer_name_or_path: cifar10 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data num_classes: 10 # 10 / null streaming: False size: 1024 length: 3072 wrap: False insert_train_eos: False insert_valid_eos: False insert_train_spacial: False ================================================ FILE: configs/data/fineweb-edu.yaml ================================================ train: HuggingFaceFW/fineweb-edu valid: openwebtext-valid #wikitext103 modality: text tokenizer_name_or_path: gpt2 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: True insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/lambada.yaml ================================================ train: lambada valid: lambada modality: text tokenizer_name_or_path: gpt2 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: 
configs/data/lm1b-gpt2.yaml ================================================ train: lm1b valid: lm1b modality: text tokenizer_name_or_path: gpt2 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/lm1b-streaming.yaml ================================================ train: lm1b valid: lm1b modality: text tokenizer_name_or_path: bert-base-uncased cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: False streaming: True insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/lm1b-wrap.yaml ================================================ train: lm1b valid: lm1b modality: text tokenizer_name_or_path: bert-base-uncased cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/lm1b.yaml ================================================ train: lm1b valid: lm1b modality: text tokenizer_name_or_path: bert-base-uncased cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: False streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/openwebtext-split.yaml ================================================ train: openwebtext-train valid: openwebtext-valid modality: text tokenizer_name_or_path: gpt2 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/openwebtext-streaming.yaml ================================================ train: openwebtext valid: wikitext103 modality: text tokenizer_name_or_path: gpt2 cache_dir: /tmp/data wrap: True streaming: True insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/openwebtext.yaml ================================================ train: openwebtext valid: wikitext103 modality: text tokenizer_name_or_path: gpt2 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/ptb.yaml ================================================ train: ptb valid: ptb modality: text tokenizer_name_or_path: gpt2 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/scientific_papers_arxiv.yaml ================================================ train: scientific_papers_arxiv valid: scientific_papers_arxiv modality: text tokenizer_name_or_path: gpt2 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/scientific_papers_pubmed.yaml ================================================ train: scientific_papers_pubmed valid: scientific_papers_pubmed modality: text tokenizer_name_or_path: gpt2 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/synthetic.yaml ================================================ train: synthetic valid: 
synthetic modality: text tokenizer_name_or_path: synthetic cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: True insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/text8-crop.yaml ================================================ # TODO: When using this dataset, set model.length = 256 to match D3PM setup train: text8-crop valid: text8 modality: text tokenizer_name_or_path: text8 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/text8.yaml ================================================ # TODO: When using this dataset, set model.length = 256 to match D3PM setup train: text8 valid: text8 modality: text tokenizer_name_or_path: text8 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/wikitext103.yaml ================================================ train: wikitext103 valid: wikitext103 modality: text tokenizer_name_or_path: gpt2 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/data/wikitext2.yaml ================================================ train: wikitext2 valid: wikitext2 modality: text tokenizer_name_or_path: gpt2 cache_dir: /share/kuleshov/ssahoo/textdiffusion/data wrap: True streaming: False insert_train_eos: True insert_valid_eos: True ================================================ FILE: configs/lr_scheduler/constant_warmup.yaml ================================================ _target_: transformers.get_constant_schedule_with_warmup num_warmup_steps: 2500 ================================================ FILE: configs/lr_scheduler/cosine_decay_warmup.yaml ================================================ _target_: utils.CosineDecayWarmupLRScheduler t_in_epochs: False t_initial: ${eval:${trainer.max_steps}-${.warmup_t}} warmup_prefix: True warmup_lr_init: 1e-6 warmup_t: ${eval:0.1*${trainer.max_steps}} lr_min: 1e-6 ================================================ FILE: configs/lr_scheduler/step_scheduler.yaml ================================================ _target_: torch.optim.lr_scheduler.LambdaLR lr_lambda: _target_: utils.LRHalveScheduler warmup_steps: 500 n_halve_steps: 10_000 ================================================ FILE: configs/model/medium.yaml ================================================ name: medium type: ddit hidden_size: 1024 cond_dim: 128 length: 1024 n_blocks: 24 n_heads: 16 scale_by_sigma: True dropout: 0.1 tie_word_embeddings: False vocab_lookup: True ================================================ FILE: configs/model/small.yaml ================================================ name: small type: ddit hidden_size: 768 cond_dim: 128 length: 1024 n_blocks: 12 n_heads: 12 scale_by_sigma: True dropout: 0.1 tie_word_embeddings: False vocab_lookup: True ================================================ FILE: configs/model/tiny-dimamba.yaml ================================================ name: tiny type: dimamba hidden_size: 512 cond_dim: 128 length: 1024 n_blocks: 14 n_heads: 8 scale_by_sigma: True dropout: 0.1 temb_strategy: adaln tie_word_embeddings: False ================================================ FILE: configs/model/tiny.yaml 
================================================ name: tiny type: ddit hidden_size: 256 cond_dim: 128 length: 1024 n_blocks: 8 n_heads: 8 scale_by_sigma: True dropout: 0.1 tie_word_embeddings: False vocab_lookup: True ================================================ FILE: configs/model/unet.yaml ================================================ name: unet type: unet ch: 128 num_res_blocks: 2 num_scales: 4 ch_mult: [1, 2, 2, 2] input_channels: 3 output_channels: -1 # determined by vocab_size scale_count_to_put_attn: 1 # at 16 res data_min_max: [0, 255] # No need currently dropout: 0.1 skip_rescale: True time_conditioning: True # Whether to add in time embeddings time_scale_factor: 1000 time_embed_dim: ${.ch} fix_logistic: False size: ${data.size} cond_dim: ${.ch} length: ${data.length} ================================================ FILE: configs/noise/cosine.yaml ================================================ type: cosine eps: 1e-3 ================================================ FILE: configs/noise/log-linear.yaml ================================================ type: log-linear eps: 1e-3 ================================================ FILE: configs/prior/none.yaml ================================================ type: none latent_width: 0 latent_height: 0 ================================================ FILE: configs/strategy/ddp.yaml ================================================ _target_: lightning.pytorch.strategies.DDPStrategy find_unused_parameters: false ================================================ FILE: configs/strategy/fsdp.yaml ================================================ # TODO(yair): Currently not compatible with grad clipping _target_: lightning.pytorch.strategies.FSDPStrategy sharding_strategy: SHARD_GRAD_OP ================================================ FILE: dataloader.py ================================================ import functools import itertools import json import math import os import re import shutil import typing import urllib import zipfile from typing import Optional import datasets import einops import fsspec import numpy as np import requests import tokenizers import torch import torchvision from torchvision import transforms as th_transforms import transformers import utils LOGGER = utils.get_logger(__name__) class RawPixelsVisionTokenizer: def __init__(self, vocab_size, image_size, add_mask_token=True, add_special_tokens=True): self.pad_token_id = None self.pad_token = None if add_mask_token: self.mask_token = vocab_size self.mask_token_id = vocab_size self.vocab_size = vocab_size + 1 # mask token else: self.vocab_size = vocab_size if add_special_tokens: self.bos_token_id = vocab_size self.bos_token = vocab_size self.eos_token_id = vocab_size + 1 self.eos_token = vocab_size + 1 # mask token, bos_token, eos_token self.vocab_size = self.vocab_size + 2 else: self.vocab_size = self.vocab_size self.image_size = image_size def __call__(self, x): return x def batch_decode(self, x): x = einops.rearrange(x, 'b (c h w) -> b c h w', c=3, h=self.image_size) x = x.to(dtype=torch.uint8) return x def decode(self, x): x = einops.rearrange(x, '(c h w) -> h w c', c=3, h=self.image_size) x = x.to(dtype=torch.uint8) return x def __len__(self): return self.vocab_size class DiscreteCIFAR10(torch.utils.data.Dataset): def __init__(self, cache_dir, train): self._dataset = torchvision.datasets.CIFAR10( root=cache_dir, train=train, download=True) transforms = [] if train: transforms += [th_transforms.RandomHorizontalFlip()] transforms += [th_transforms.Lambda( lambda
x: torch.from_numpy(np.array(x))), th_transforms.Lambda( lambda x: einops.rearrange(x, "h w c -> (c h w)")),] self.transform = th_transforms.Compose(transforms) def __len__(self): return len(self._dataset) def __getitem__(self, index): img, labels = self._dataset[index] img = self.transform(img) attention_mask = torch.ones_like(img) return {'input_ids': img.to(torch.long), 'labels': labels, 'attention_mask': attention_mask} def wt_detokenizer(string): # contractions string = string.replace("s '", "s'") string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) # number separators string = string.replace(" @-@ ", "-") string = string.replace(" @,@ ", ",") string = string.replace(" @.@ ", ".") # punctuation string = string.replace(" : ", ": ") string = string.replace(" ; ", "; ") string = string.replace(" . ", ". ") string = string.replace(" ! ", "! ") string = string.replace(" ? ", "? ") string = string.replace(" , ", ", ") # double brackets string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) # miscellaneous string = string.replace("= = = =", "====") string = string.replace("= = =", "===") string = string.replace("= =", "==") string = string.replace(" " + chr(176) + " ", chr(176)) string = string.replace(" \n", "\n") string = string.replace("\n ", "\n") string = string.replace(" N ", " 1 ") string = string.replace(" 's", "'s") return string def ptb_detokenizer(x): x = x.replace(" 's", "'s") x = x.replace("s ' ", "s' ") x = x.replace(" n't", "n't") x = x.replace(" \n ", "\n") x = x.replace("\\/", "/") for _ in range(10): x = x.replace(" N ", " 1 ") x = x.replace("$ 1", "$1") x = x.replace("# 1", "#1") x = x.replace("<unk>", "?") return x def lm1b_detokenizer(x): x = x.replace('http : / / ', 'http://') x = x.replace('https : / / ', 'https://') x = re.sub(r' \'(\w+)', r"'\1", x) x = re.sub(r' (\w+) \. ', r' \1. ', x) x = re.sub(r' (\w+) \.$', r' \1.', x) x = x.replace(' ? ', '? ') x = re.sub(r' \?$', '?', x) x = x.replace(' ! ', '! 
') x = re.sub(r' \!$', '!', x) x = x.replace(' , ', ', ') x = x.replace(' : ', ': ') x = x.replace(' ; ', '; ') x = x.replace(' / ', '/') x = re.sub(r'\" ([^\"]+) \"', r'"\1"', x) x = re.sub(r'\' ([^\']+) \'', r"'\1'", x) x = re.sub(r'\( ([^\(\)]+) \)', r"(\1)", x) x = re.sub(r'\[ ([^\[\]]+) \]', r"[\1]", x) x = x.replace('$ ', '$') x = x.replace('£ ', '£') return x def lambada_detokenizer(text): text = text.replace("“", '"') text = text.replace("”", '"') return '\n'+text.strip() def scientific_papers_detokenizer(x): x = wt_detokenizer(x) x = lm1b_detokenizer(x) return x class SyntheticTokenizer( transformers.PreTrainedTokenizer): def __init__( self, vocab_size, bos_token="[BOS]", eos_token="[EOS]", sep_token=None, cls_token=None, pad_token=None, mask_token=None, unk_token=None, **kwargs): self.tokens = [] for i in range (vocab_size - 2): # appending space for readability self.tokens.append(str(i) + " ") self._vocab_str_to_int = { '[BOS]': vocab_size - 2, '[EOS]': vocab_size - 1, ** {ch: i for i, ch in enumerate(self.tokens)}} self._vocab_int_to_str = { v: k for k, v in self._vocab_str_to_int.items()} super().__init__( bos_token=bos_token, eos_token=eos_token, sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, mask_token=mask_token, unk_token=unk_token, **kwargs) @property def vocab_size(self) -> int: return len(self._vocab_str_to_int) def _tokenize(self, text: str, **kwargs) -> typing.List[str]: return list(text.lower()) def _convert_token_to_id(self, token: str) -> int: return self._vocab_str_to_int.get( token, self._vocab_str_to_int['[UNK]']) def _convert_id_to_token(self, index: int) -> str: return self._vocab_int_to_str[index] def convert_tokens_to_string(self, tokens): return ''.join(tokens) def get_vocab(self) -> typing.Dict[str, int]: return self._vocab_str_to_int def _generate_synthetic_data(dataset_size, seq_len, vocab_size): dataset = np.zeros((dataset_size, seq_len), dtype=int) # tokens representing sequence boundary dataset[:, 0] = vocab_size - 2 # bos dataset[:, -1] = vocab_size - 1 # eos for i in range(dataset_size): # sample from 0, 1, ..., vocab_size - 3 temp = np.random.randint(vocab_size - 2) for j in reversed(range(1, seq_len - 1)): dataset[i, j] = temp if temp != 0: temp = temp // 4 else: temp = np.random.randint(vocab_size - 2) return dataset def generate_synthetic_dataset(train_dataset_size, validation_dataset_size, seq_len, vocab_size): np.random.seed(42) train_data = torch.from_numpy( _generate_synthetic_data(train_dataset_size, seq_len, vocab_size)) train_dataset = datasets.Dataset.from_dict({ 'input_ids': train_data, 'attention_mask': torch.ones_like(train_data), }) train_dataset.set_format(type='torch') np.random.seed(41) validation_data = torch.from_numpy( _generate_synthetic_data(validation_dataset_size, seq_len, vocab_size)) validation_dataset = datasets.Dataset.from_dict({ 'input_ids': validation_data, 'attention_mask': torch.ones_like(validation_data), }) validation_dataset.set_format(type='torch') return { 'train': train_dataset, 'validation': validation_dataset, } class Text8Tokenizer(transformers.PreTrainedTokenizer): def __init__( self, bos_token='[BOS]', eos_token='[EOS]', sep_token='[SEP]', cls_token='[CLS]', pad_token='[PAD]', mask_token='[MASK]', unk_token='[UNK]', **kwargs): self.characters = list('abcdefghijklmnopqrstuvwxyz ') self._vocab_str_to_int = { '[CLS]': 0, '[SEP]': 1, '[BOS]': 2, '[EOS]': 3, '[MASK]': 4, '[PAD]': 5, '[RESERVED]': 6, '[UNK]': 7, ** {ch: i + 8 for i, ch in enumerate(self.characters)}} self._vocab_int_to_str = { 
v: k for k, v in self._vocab_str_to_int.items()} super().__init__( bos_token=bos_token, eos_token=eos_token, sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, mask_token=mask_token, unk_token=unk_token, **kwargs) @property def vocab_size(self) -> int: return len(self._vocab_str_to_int) def _tokenize(self, text: str, **kwargs) -> typing.List[str]: return list(text.lower()) def _convert_token_to_id(self, token: str) -> int: return self._vocab_str_to_int.get( token, self._vocab_str_to_int['[UNK]']) def _convert_id_to_token(self, index: int) -> str: return self._vocab_int_to_str[index] def convert_tokens_to_string(self, tokens): return ''.join(tokens) def get_vocab(self) -> typing.Dict[str, int]: return self._vocab_str_to_int def get_lambada_test_dataset(): url = "https://openaipublic.blob.core.windows.net/gpt-2/data/lambada_test.jsonl" def read_jsonl_to_list(url): response = requests.get(url, stream=True) data_list = [] # Process each line in the response content for line in response.iter_lines(decode_unicode=True): if line: data = json.loads(line) data_list.append(data) return data_list lambada_data = read_jsonl_to_list(url) dataset = datasets.Dataset.from_list(lambada_data) return dataset def get_text8_dataset(cache_dir, max_seq_length=256, drop_last=True, crop_train=False): """Adapted from: https://github.com/google-research/google-research/blob/master/d3pm/text/datasets.py#L344 Args: cache_dir: str, path to cache directory. max_seq_length: int, maximum length of sequences. (default: 256, as in D3PM codebase.) drop_last: bool, whether to drop the last incomplete batch. (default: True, as in D3PM codebase.) crop_train: bool, whether to subsample contiguous subsequences from training examples. Serves to make sure transformer models with absolute position embeddings do not have incorrect position-wise marginals. (default: False, but necessary to match D3PM AR) Returns: dataset: datasets.DatasetDict, with keys 'train', 'validation', 'test'.
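Example (illustrative; the cache path is hypothetical):
  dataset = get_text8_dataset('/tmp/data', max_seq_length=256)
  print(dataset['train'][0]['text'])  # one 256-character chunk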
""" url = 'http://mattmahoney.net/dc/text8.zip' if not crop_train: cache_dir = f'{cache_dir}/text8' else: cache_dir = f'{cache_dir}/text8-crop-train' split_names = ['train', 'validation', 'test'] if not all([ utils.fsspec_exists(os.path.join(cache_dir, split)) for split in split_names ]): # Check if raw data exists raw_cache_dir = os.path.join(cache_dir, 'raw_data') if not all([ utils.fsspec_exists( os.path.join(raw_cache_dir, f'text8.{split}.txt')) for split in split_names ]): if not utils.fsspec_exists( os.path.join(raw_cache_dir, 'text8.zip')): utils.fsspec_mkdirs(raw_cache_dir, exist_ok=True) LOGGER.info('Downloading text8 from URL {}.'.format(url)) with (urllib.request.urlopen(url) as in_stream, open(os.path.join(raw_cache_dir, 'text8.zip'), 'wb') as out_file): shutil.copyfileobj(in_stream, out_file) with fsspec.open( os.path.join(raw_cache_dir, 'text8.zip'), 'rb') as f: rawdata = zipfile.ZipFile(f).read( 'text8').decode('utf-8') # Splits taken from D3PM codebase splits = { 'train': rawdata[:90000000], 'validation': rawdata[90000000: 95000000], 'test': rawdata[95000000:], } for split, data in splits.items(): _path = os.path.join(raw_cache_dir, f'text8.{split}.txt') with fsspec.open(_path, 'w') as f: f.write(data) else: splits = {} for split in split_names: _path = os.path.join(raw_cache_dir, f'text8.{split}.txt') with fsspec.open(_path, 'r') as f: splits[split] = f.read() # Chunk and save as datasets.DatasetDict def chunks(lst, n): """Yield successive n-sized chunks from lst.""" for i in range(0, len(lst), n): yield lst[i:i + n] dataset_dict = {} for k, v in splits.items(): if k == 'train' and crop_train == True: chunk_size = 2 * max_seq_length else: chunk_size = max_seq_length text = list(chunks(v, chunk_size)) if drop_last and len(text[-1]) < chunk_size: text = text[:-1] dataset_dict[k] = datasets.Dataset.from_dict({'text': text}) dataset = datasets.DatasetDict(dataset_dict) dataset.save_to_disk(cache_dir) else: dataset = datasets.load_from_disk(cache_dir) return dataset def _group_texts(examples, block_size, bos, eos): # Concatenate all texts. concatenated_examples = list(itertools.chain(* examples['input_ids'])) total_length = len(concatenated_examples) # TODO(yair): look into not dropping the remainder but rather padding it. # We drop the small remainder, and if the total_length < block_size - 2 # we exclude this batch and return an empty dict. # We could add padding if the model supported it instead of # this drop, you can customize this part to your needs. new_block_size = block_size - 2 # [BOS] and [EOS] to be added total_length = (total_length // new_block_size) * new_block_size # Split by chunks of max_len. 
result = {} _values = [] _attn_masks = [] for i in range(0, total_length, new_block_size): _values.append( [bos] + concatenated_examples[i : i + new_block_size] + [eos]) _attn_masks.append(torch.ones(block_size)) result['input_ids'] = _values result['attention_mask'] = _attn_masks return result def get_dataset(dataset_name, tokenizer, wrap, mode, cache_dir, insert_eos=True, block_size=1024, num_proc=len(os.sched_getaffinity(0)), streaming=False, revision : Optional[str]=None): if dataset_name == 'cifar10': assert mode in ('train', 'validation') return DiscreteCIFAR10(cache_dir=cache_dir, train=mode=='train') eos_tag = '' if not insert_eos: eos_tag = '_eosFalse' if wrap: filename = f'{dataset_name}_{mode}_bs{block_size}_wrapped{eos_tag}.dat' else: filename = f'{dataset_name}_{mode}_bs{block_size}_unwrapped{eos_tag}.dat' _path = os.path.join(cache_dir, filename) if utils.fsspec_exists(_path): LOGGER.info(f'Loading data from: {_path}') return datasets.load_from_disk(_path).with_format('torch') LOGGER.info(f'Generating new data at: {_path}') LOGGER.info(f'{streaming=}') crop_train = dataset_name == 'text8-crop' if mode == 'train' and crop_train: # double block size for sub-sampling block_size *= 2 if dataset_name == 'wikitext103': dataset = datasets.load_dataset( 'wikitext', name='wikitext-103-raw-v1', cache_dir=cache_dir, revision=revision) elif dataset_name == 'wikitext2': dataset = datasets.load_dataset( 'wikitext', name='wikitext-2-raw-v1', cache_dir=cache_dir, revision=revision) elif dataset_name == 'ptb': dataset = datasets.load_dataset( 'ptb_text_only', cache_dir=cache_dir, revision=revision) elif dataset_name == 'lambada': dataset = get_lambada_test_dataset() elif dataset_name == 'text8': assert wrap assert revision is None dataset = get_text8_dataset( cache_dir, max_seq_length=block_size) elif dataset_name == 'text8-crop': assert revision is None dataset = get_text8_dataset( cache_dir, max_seq_length=block_size, crop_train=True) elif dataset_name == 'openwebtext-train': dataset = datasets.load_dataset( 'openwebtext', split='train[:-100000]', cache_dir=cache_dir, revision=revision, streaming=False, num_proc=num_proc) elif dataset_name == 'openwebtext-valid': dataset = datasets.load_dataset( 'openwebtext', split='train[-100000:]', cache_dir=cache_dir, revision=revision, streaming=False, num_proc=num_proc) elif dataset_name == 'scientific_papers_arxiv': dataset = datasets.load_dataset( 'scientific_papers', 'arxiv', cache_dir=cache_dir, streaming=streaming, revision=revision) elif dataset_name == 'scientific_papers_pubmed': dataset = datasets.load_dataset( 'scientific_papers', 'pubmed', cache_dir=cache_dir, streaming=streaming, revision=revision) elif dataset_name == 'ag_news': dataset = datasets.load_dataset( 'ag_news', cache_dir=cache_dir, streaming=streaming, revision=revision) elif dataset_name == 'synthetic': assert streaming assert wrap # i.e., no pad tokens dataset = generate_synthetic_dataset( train_dataset_size=100000, validation_dataset_size=1024, seq_len=32, vocab_size=256, ) else: dataset = datasets.load_dataset( dataset_name, cache_dir=cache_dir, streaming=streaming, revision=revision) if dataset_name in ['lambada', 'openwebtext-train', 'openwebtext-valid']: data = dataset else: data = dataset[mode] if dataset_name == 'synthetic': # already tokenized, no further actions required return data if dataset_name.startswith('wikitext'): detokenizer = wt_detokenizer elif dataset_name == 'ptb': detokenizer = ptb_detokenizer elif dataset_name == 'lm1b': detokenizer = lm1b_detokenizer 
elif dataset_name == 'lambada': detokenizer = lambada_detokenizer elif dataset_name.startswith('scientific_papers'): detokenizer = scientific_papers_detokenizer else: detokenizer = None def _apply_detokenizer(detokenizer): def detok(text): for i, t in enumerate(text, 0): text[i] = detokenizer(t) return text return detok EOS = tokenizer.encode(tokenizer.eos_token)[0] BOS = tokenizer.encode(tokenizer.bos_token)[0] def preprocess_and_tokenize(example): if dataset_name == 'ptb': text = example['sentence'] elif 'scientific_papers' in dataset_name: text = example['article'] else: text = example['text'] if detokenizer is not None: text = _apply_detokenizer(detokenizer)(text) tokenizer.padding_side = 'right' tokenizer.truncation_side = 'right' if wrap: tokens = tokenizer(text, add_special_tokens=False, return_attention_mask=False, return_token_type_ids=False) if insert_eos: tokens = {'input_ids': [t + [EOS] for t in tokens['input_ids']]} # Still missing BOS, but will be added in group_texts else: tokens = tokenizer(text, max_length=block_size, padding='max_length', truncation=True, add_special_tokens=True, return_attention_mask=True, return_token_type_ids=True) return tokens if streaming: tokenized_dataset = data.map( preprocess_and_tokenize, batched=True) else: tokenized_dataset = data.map( preprocess_and_tokenize, batched=True, num_proc=num_proc, load_from_cache_file=True, desc='Tokenizing') if dataset_name == 'ptb': tokenized_dataset = tokenized_dataset.remove_columns( 'sentence') elif 'scientific_papers' in dataset_name: tokenized_dataset = tokenized_dataset.remove_columns([ 'article', 'abstract', 'section_names']) elif dataset_name == 'ag_news': tokenized_dataset = tokenized_dataset.remove_columns( ['text', 'label']) else: tokenized_dataset = tokenized_dataset.remove_columns( 'text') if not wrap: if not streaming: tokenized_dataset.save_to_disk(_path) return tokenized_dataset.with_format('torch') group_texts = functools.partial( _group_texts, block_size=block_size, bos=BOS, eos=EOS) if streaming: chunked_dataset = tokenized_dataset.map( group_texts, batched=True) else: chunked_dataset = tokenized_dataset.map( group_texts, batched=True, num_proc=num_proc, load_from_cache_file=True, desc='Grouping') chunked_dataset.save_to_disk(_path) chunked_dataset = chunked_dataset.with_format('torch') return chunked_dataset def get_tokenizer(config): if config.data.tokenizer_name_or_path == 'text8': tokenizer = Text8Tokenizer() elif config.data.tokenizer_name_or_path == 'bert-base-uncased': tokenizer = transformers.BertTokenizer.\ from_pretrained('bert-base-uncased') elif config.data.tokenizer_name_or_path == 'synthetic': tokenizer = SyntheticTokenizer(vocab_size=256) elif config.data.tokenizer_name_or_path == 'cifar10': return RawPixelsVisionTokenizer( vocab_size=256, image_size=32, add_special_tokens=False, add_mask_token='mdlm' in config.algo.name) else: tokenizer = transformers.AutoTokenizer.from_pretrained( config.data.tokenizer_name_or_path) if (isinstance(tokenizer, transformers.GPT2TokenizerFast) or isinstance(tokenizer, transformers.GPT2Tokenizer)): tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing( (tokenizer.bos_token, tokenizer.bos_token_id), (tokenizer.eos_token, tokenizer.eos_token_id)) # For wrapped batches: # [BOS] sent1 [EOS] sent2-fragment [EOS] # [BOS] sent2-fragment [EOS] sent3 [EOS] if tokenizer.bos_token is None: if tokenizer.cls_token is None: raise AttributeError( 'Tokenizer must have a bos_token or ' f'cls_token: {tokenizer}') tokenizer.bos_token = 
tokenizer.cls_token if tokenizer.eos_token is None: if tokenizer.sep_token is None: raise AttributeError( 'Tokenizer must have an eos_token ' f'or sep_token: {tokenizer}') tokenizer.eos_token = tokenizer.sep_token if tokenizer.pad_token is None: tokenizer.add_special_tokens({'pad_token': '[PAD]'}) return tokenizer def get_dataloaders(config, tokenizer, skip_train=False, skip_valid=False, valid_seed=None): num_gpus = torch.cuda.device_count() assert (config.loader.global_batch_size == (config.loader.batch_size * config.trainer.num_nodes * num_gpus * config.trainer.accumulate_grad_batches)) if config.loader.global_batch_size % ( num_gpus * config.trainer.accumulate_grad_batches) != 0: raise ValueError( f'Train Batch Size {config.loader.global_batch_size} ' f'not divisible by {num_gpus} gpus with accumulation ' f'{config.trainer.accumulate_grad_batches}.') if config.loader.eval_global_batch_size % num_gpus != 0: raise ValueError( f'Eval Batch Size {config.loader.eval_global_batch_size} ' f'not divisible by {num_gpus}.') if skip_train: train_set = None else: train_set = get_dataset( config.data.train, tokenizer, mode='train', wrap=config.data.wrap, insert_eos=config.data.insert_train_eos, cache_dir=config.data.cache_dir, block_size=config.model.length, streaming=config.data.streaming, num_proc=config.loader.num_workers, revision=config.data.get("train_revision", None)) if config.data.valid in ['text8', 'lm1b', 'ag_news']: validation_split = 'test' else: validation_split = 'validation' if skip_valid: valid_set = None else: valid_set = get_dataset( config.data.valid, tokenizer, wrap=config.data.wrap, mode=validation_split, cache_dir=config.data.cache_dir, insert_eos=config.data.insert_valid_eos, block_size=config.model.length, streaming=config.data.streaming, num_proc=config.loader.num_workers, revision=config.data.get("valid_revision", None)) if skip_train: train_loader = None else: train_loader = torch.utils.data.DataLoader( train_set, batch_size=config.loader.batch_size, num_workers=config.loader.num_workers, pin_memory=config.loader.pin_memory, shuffle=not config.data.streaming, persistent_workers=True) train_loader.tokenizer = tokenizer if skip_valid: valid_loader = None else: if valid_seed is None: shuffle_valid = False generator = None else: shuffle_valid = True generator = torch.Generator().manual_seed(valid_seed) valid_loader = torch.utils.data.DataLoader( valid_set, batch_size=config.loader.eval_batch_size, num_workers=config.loader.num_workers, pin_memory=config.loader.pin_memory, shuffle=shuffle_valid, generator=generator) # Will be used in generative perplexity calculation valid_loader.tokenizer = tokenizer return train_loader, valid_loader # Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py class RandomFaultTolerantSampler(torch.utils.data.RandomSampler): def __init__(self, *args, generator=None, **kwargs): # TD [2022-07-17]: We don't force the seed to be zero. We generate a random seed, # which should be reproducible if pl.seed_everything was called beforehand. # This means that changing the seed of the experiment will also change the # sampling order.
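# Resume protocol (a sketch of the intent, based on state_dict/load_state_dict
# below): the sampler checkpoints its RNG state together with a counter of
# already-yielded indices; after a restart, __iter__ draws a permutation from
# the restored generator and skips the first `counter` indices, so a mid-epoch
# restart does not revisit samples that were already consumed.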
if generator is None: seed = int(torch.empty((), dtype=torch.int64).random_().item()) generator = torch.Generator().manual_seed(seed) kwargs.pop('shuffle', None) super().__init__(*args, generator=generator, **kwargs) self.counter = 0 self.restarting = False def state_dict(self): return {'random_state': self.generator.get_state(), 'counter': self.counter} def load_state_dict(self, state_dict): self.generator.set_state(state_dict.get('random_state')) self.counter = state_dict['counter'] # self.start_counter = self.counter self.restarting = True # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per # epoch, and subsequent epoch will have very few batches. def __iter__(self) -> typing.Iterator[int]: n = len(self.data_source) self.state = self.generator.get_state() indices = torch.randperm(n, generator=self.generator).tolist() if not self.restarting: self.counter = 0 else: indices = indices[self.counter:] self.restarting = False for index in indices: self.counter += 1 yield index self.counter = 0 class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.counter = 0 self.restarting = False def state_dict(self): return {'epoch': self.epoch, 'counter': self.counter} def load_state_dict(self, state_dict): self.epoch = state_dict['epoch'] self.counter = state_dict['counter'] self.restarting = True # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per # epoch, and subsequent epoch will have very few batches. def __iter__(self): if self.shuffle: # deterministically shuffle based on epoch and seed g = torch.Generator() g.manual_seed(self.seed + self.epoch) indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] else: indices = list(range(len(self.dataset))) # type: ignore[arg-type] if not self.drop_last: # add extra samples to make it evenly divisible padding_size = self.total_size - len(indices) if padding_size <= len(indices): indices += indices[:padding_size] else: indices += (indices * math.ceil( padding_size / len(indices)))[:padding_size] else: # remove tail of data to make it evenly divisible. indices = indices[:self.total_size] assert len(indices) == self.total_size # subsample indices = indices[self.rank:self.total_size:self.num_replicas] assert len(indices) == self.num_samples if not self.restarting: self.counter = 0 else: indices = indices[self.counter:] self.restarting = False for index in indices: self.counter += 1 yield index self.counter = 0 ================================================ FILE: discrete_diffusion_harness.py ================================================ import torch from omegaconf import OmegaConf from lm_eval.api.model import LM from lm_eval.api.registry import register_model from lm_eval.__main__ import cli_evaluate from lm_eval.api.instance import Instance from datasets import Dataset from tqdm import tqdm import numpy as np import algo import dataloader """ # Instructions on running the eval ## What is the script doing? The script evaluates (or approximates) the log-likelihood of prefix + suffix. The script will load a checkpoint, and depending on the config, use the corresponding class (e.g. mdlm, duo, ar, etc). - For MCQ, the log-likelihood of all continuations given the same prefix is evaluated. The most likely continuation is selected as the "correct" answer according to the model. 
- For lambada_openai, the model is correct if the true continuation is generated as the argmax. For diffusion, we inject noise in the continuation, and check whether the true answer computed in a single forward pass is the most likely. This naturally favors AR models, as they run one forward pass per token, while diffusion models use a single pass for all tokens for simplicity. Therefore, I usually compare only on MCQ since it is fairer. To run the script, you need to install the lm-eval-harness package: ``` pip install git+https://github.com/EleutherAI/lm-evaluation-harness ``` ## Tasks that we tested with **MCQ**: arc_easy, arc_challenge, hellaswag, winogrande, boolq, openbookqa, race, social_iqa, mathqa, piqa **Likelihood based**: lambada_openai ## Batch size that fits at the small scale - boolq -> 64 - openbookqa -> 256 - race -> 32 - social_iqa -> 512 - winogrande -> 64 - mathqa -> 64 - lambada_openai -> 64 - arc_easy -> 256 - arc_challenge -> 256 - hellaswag -> 64 - piqa -> 64 ## Important flags --trust_remote_code -> some datasets execute code when loading from huggingface. Without this flag, the script might crash. --batch_size -> max. num elements to use in parallel to eval the likelihood (for diffusion). For simplicity, inputs are NOT padded and batched for AR, though it should be fairly easy to add. --tasks -> one or multiple comma-separated tasks to evaluate on. --model_args -> string (without spaces) that contains the arguments to pass to the evaluator (stuff like checkpoint path, number of MC samples to evaluate the likelihood, path to the sentencepiece tokenizer, etc). --output_path -> path to a json file where the evaluation results will be saved, instead of only being printed in the terminal (they will always be printed). --limit -> debugging flag; limit the number of examples to a fixed amount, instead of using the whole dataset ## Example commands ### Run a single task with an AR model python discrete_diffusion_harness.py \ --tasks arc_easy \ --batch_size 256 \ --model dlm \ --model_args checkpoint_path=/home/username/baselines/ar/1M.ckpt \ --output_path ./harness_results/ar/1M/arc_easy.ckpt ### Run a single task with an MDLM model, and 2048 MC samples to approximate the likelihood python discrete_diffusion_harness.py \ --tasks arc_easy \ --model dlm \ --batch_size 256 \ --model_args checkpoint_path=/home/username/baselines/mdlm/1M.ckpt,num_mc_samples=2048 \ --output_path ./harness_results/mdlm/1M/arc_easy.ckpt ### Debug with 20 examples only python discrete_diffusion_harness.py \ --tasks arc_easy \ --model dlm \ --batch_size 256 \ --model_args checkpoint_path=/home/username/baselines/mdlm/1M.ckpt,num_mc_samples=2048 \ --limit 20 \ --output_path ./harness_results/mdlm/1M/arc_easy.ckpt ### Run the benchmarks from "The Diffusion Duality (Chapter 2)": (Arc-e, Arc-c, HSwag, WinoG, PIQA, MathQA, OQA) -> just change the checkpoint to evaluate mdlm, ar, or duo for task_config in "arc_easy 256" "arc_challenge 256" "hellaswag 64" "winogrande 64" "piqa 64" "mathqa 64" "openbookqa 256"; do task=$(echo $task_config | cut -d' ' -f1) batch_size=$(echo $task_config | cut -d' ' -f2) python discrete_diffusion_harness.py \ --batch_size $batch_size \ --tasks $task \ --model dlm \ --model_args checkpoint_path=/path/to/checkpoint.ckpt \ --output_path ./harness_results/duo/$task.json done """ def requests_to_dataset(config, requests, tokenizer, num_proc): def _tokenize(e): eos_idx = tokenizer.eos_token_id bos_idx = tokenizer.bos_token_id prefix_tokens = tokenizer(e['prefix'],
return_attention_mask=False, add_special_tokens=False )['input_ids'] target_tokens = tokenizer(e['target'], return_attention_mask=False, add_special_tokens=False )['input_ids'] prefix_tokens = [bos_idx] + prefix_tokens target_tokens = target_tokens + [eos_idx] return { 'prefix_text': e['prefix'], 'target_text': e['target'], 'prefix': prefix_tokens, 'target': target_tokens, } ds = [{'prefix': req.args[0], 'target': req.args[1]} for req in requests] ds = Dataset.from_list(ds) ds = ds.map(_tokenize, num_proc=num_proc) ds = ds.with_format('torch') seq_lengths = [len(x['prefix']) + len(x['target']) for x in ds] num_larger = len([x for x in seq_lengths if x > config.model.length]) if num_larger > 0: print(f'\033[91mThere are some examples that are longer ' f'than the context length; they will be ignored ' f'during evaluation. Number of such sequences: ' f'{num_larger}\033[0m') return ds def _eval_suffix_nll_generators(config, module, prefix, suffix, batch_size, num_samples, loss_avg_mode): device = module.device assert num_samples % batch_size == 0 full_sentence = torch.cat([prefix, suffix], dim=-1 ).repeat(batch_size, 1).to(module.device) all_ts = module._sample_t(num_samples, accum_step=None) for idx in range(0, num_samples, batch_size): t = all_ts[idx:idx+batch_size].unsqueeze(-1) dalpha_t, alpha_t = module.noise(t) alpha_t = alpha_t.to(device) sigma = module._sigma_from_alphat(alpha_t) x0 = full_sentence.to(device) # Inject noise xt = module.q_xt(full_sentence, alpha_t).to(device) if loss_avg_mode == 'full': pass # nothing to do elif loss_avg_mode == 'suffix': xt[:, :len(prefix)] = prefix # recompute alpha_t based on number of masked tokens, # for conditioning of the backbone alpha_t = (xt == x0).float().mean(dim=1)[:, None] t = module.noise.get_t_for_alpha(alpha_t) # We need to get dalpha_t for the loss: alpha_t = alpha_t.to(device) sigma = module._sigma_from_alphat(alpha_t) else: raise ValueError(loss_avg_mode) yield xt, x0, t, sigma, alpha_t, dalpha_t def eval_suffix_nll(config, module, prefix, suffix, batch_size, num_samples, loss_avg_mode): if config.algo.name in ('mdlm', 'duo', 'duo_base', 'distillation', 'ot-finetune'): return eval_suffix_nll_diffusion(config, module, prefix, suffix, batch_size, num_samples, loss_avg_mode) elif config.algo.name == 'ar': return eval_suffix_nll_ar(config, module, prefix, suffix, batch_size, num_samples, loss_avg_mode) else: raise ValueError(config.algo.name) def eval_suffix_nll_ar(config, module, prefix, suffix, batch_size, num_samples, loss_avg_mode): x_cat = torch.cat([prefix, suffix[:-1]], dim=-1) x_cat = x_cat.reshape(1, -1).to(module.device) with torch.amp.autocast('cuda', dtype=torch.float32): out = module.backbone(x_cat, sigma=None) out[:, :, module.mask_index] = module.neg_infinity suffix_out = out[:, len(prefix) - 1:, :] suffix_logits = torch.log_softmax(suffix_out, dim=-1) index = suffix[None, :, None].to(module.device) nll = torch.gather(-suffix_logits, dim=-1, index=index).mean() return float(nll.cpu()) def eval_suffix_nll_diffusion(config, module, prefix, suffix, batch_size, num_samples, loss_avg_mode): all_losses = [] generator = _eval_suffix_nll_generators(config, module, prefix, suffix, batch_size, num_samples, loss_avg_mode) for xt, x0, t, sigma, alpha_t, dalpha_t in generator: log_x_theta = module(xt, sigma, labels=None) token_nll = module.nll_per_token(log_x_theta, xt, x0, alpha_t, dalpha_t) if loss_avg_mode == 'full': loss = float(token_nll.mean()) elif loss_avg_mode == 'suffix': loss = float(token_nll[:, len(prefix):].mean())
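# The score returned below is a Monte-Carlo estimate: num_samples noisy
# corruptions of the sequence are scored in batches of batch_size, and the
# resulting per-token NLLs are averaged ('full' averages over the whole
# sequence, 'suffix' over the positions after the prefix only).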
all_losses.append(loss) return float(np.mean(all_losses)) @register_model("dlm") class DiscreteDiffusionHarness(LM): def __init__(self, pretrained="NONE", max_length=1024, num_mc_samples=1024, batch_size=64, device="cuda", checkpoint_path=None, num_proc=8, loss_avg_mode='full', *args, **kwargs): super().__init__() # Whether to use the full sequence, or suffix only to # approximate the NLL. Full should be the correct way. assert loss_avg_mode in ('full', 'suffix') ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False) config = ckpt['hyper_parameters']['config'] # Backfill missing keys into legacy checkpoints if not hasattr(config.training, 'class_dropout_p'): OmegaConf.set_struct(config, False) config.training.class_dropout_p = 0.0 OmegaConf.set_struct(config, True) self.tokenizer = dataloader.get_tokenizer(config) if config.algo.name == 'mdlm': self.model = algo.MDLM(config, self.tokenizer) elif config.algo.name in ('duo', 'duo_base', 'distillation', 'ot-finetune'): self.model = algo.DUO_BASE(config, self.tokenizer) elif config.algo.name == 'ar': self.model = algo.AR(config, self.tokenizer) else: raise ValueError(f'Implement for {config.algo.name}') self.config = config self.num_proc = num_proc self.num_mc_samples = num_mc_samples self.batch_size = int(batch_size) self.device = device self.loss_avg_mode = loss_avg_mode self.model.load_state_dict(ckpt['state_dict']) self.model.to(device) self.model.eval() def suffix_greedy_prediction(self, prefix, target): if self.config.algo.name == 'mdlm': return self._suffix_greedy_prediction_mdlm(prefix, target) elif self.config.algo.name in ('duo', 'duo_base', 'distillation', 'ot-finetune'): return self._suffix_greedy_prediction_duo_base(prefix, target) elif self.config.algo.name == 'ar': return self._suffix_greedy_prediction_ar(prefix, target) else: raise ValueError(self.config.algo.name) def _suffix_greedy_prediction_ar(self, prefix, target): x_cat = torch.cat([prefix, target[:-1]], dim=-1).reshape(1, -1).to(self.device) # Follows generate_samples in AR (algo.py) out = self.model.backbone(x_cat, sigma=None) out[:, :, self.model.mask_index] = self.model.neg_infinity out = out.log_softmax(-1) preds_suffix = out[:, len(prefix) - 1:, :] greedy_preds = preds_suffix.argmax(-1).flatten() return (greedy_preds.cpu() == target).all().item() def _suffix_greedy_prediction_mdlm(self, prefix, target): mask_idx = self.model.mask_index eos_idx = self.tokenizer.eos_token_id # Note: because of the preprocessing, we know that the # last token is an eos token. 
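# Greedy decoding for MDLM, in sketch: every suffix position except the final
# EOS is replaced by the mask token, one denoising pass predicts all masked
# positions jointly, and the prediction counts as correct only if the argmax
# at every suffix position matches the target.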
noisy_target = [mask_idx] * (len(target) - 1) + [eos_idx] noisy_target = torch.tensor(noisy_target, device=self.device) prefix = prefix.to(self.device) seq = torch.concatenate([prefix, noisy_target], dim=-1).reshape(1, -1) sigma = torch.zeros(size=(seq.shape[0], 1), device=self.device) logits = self.model(seq, sigma, labels=None) assert logits.shape[0] == 1 suffix_logits = logits[0, len(prefix):] target_preds = suffix_logits.argmax(-1).cpu() correct = target_preds == target correct = correct.all() return bool(correct) def _suffix_greedy_prediction_duo_base(self, prefix, target): noisy_suffix = torch.randint( 0, self.model.vocab_size, size=target.shape, dtype=target.dtype, device=self.device) prefix = prefix.to(self.device) # shape: (1, len) noisy_seq = torch.concatenate([prefix, noisy_suffix])[None, :] # Set the EOS token at the end noisy_seq[0, -1] = target[-1] clean_seq = torch.concatenate([prefix, target.to(self.device)])[None, :] alpha_t = (noisy_seq == clean_seq).float().mean(dim=1)[:, None] sigma = self.model._sigma_from_alphat(alpha_t) t = self.model.noise.get_t_for_alpha(alpha_t) logits = self.model(noisy_seq, sigma, labels=None) assert logits.shape[0] == 1 suffix_logits = logits[0, len(prefix):] target_preds = suffix_logits.argmax(-1).cpu() correct = target_preds == target correct = correct.all() return bool(correct) @torch.no_grad() def loglikelihood(self, requests: list[Instance]) \ -> list[tuple[float, bool]]: dataset = requests_to_dataset(self.config, requests, self.tokenizer, self.num_proc) out = [] for elem in tqdm(dataset, 'Computing likelihood...'): prefix = elem['prefix'] target = elem['target'] if len(prefix) + len(target) > self.model.config.model.length: # If the request is too long, skip it. ll = 0.0 is_target_greedy_dec = False out.append((ll, is_target_greedy_dec)) print("SKIPPING") continue ll = -eval_suffix_nll(self.config, self.model, prefix, target, self.batch_size, self.num_mc_samples, self.loss_avg_mode) is_target_greedy_dec = self.suffix_greedy_prediction( prefix, target) out.append((ll, bool(is_target_greedy_dec))) return out def loglikelihood_rolling( self, requests: list[Instance] ) -> list[tuple[float, bool]]: raise NotImplementedError def generate_until(self, context, max_length, stop, **generation_kwargs): raise NotImplementedError if __name__ == "__main__": cli_evaluate() ================================================ FILE: main.py ================================================ import json import os import fsspec import hydra import lightning as L from lightning.fabric import Fabric import omegaconf import rich.syntax import rich.tree import torch from torch.utils.data.distributed import DistributedSampler from torchmetrics.image.fid import FrechetInceptionDistance from torchmetrics.image.inception import InceptionScore from tqdm import tqdm, trange import algo import dataloader import utils omegaconf.OmegaConf.register_new_resolver( 'cwd', os.getcwd) omegaconf.OmegaConf.register_new_resolver( 'device_count', torch.cuda.device_count) omegaconf.OmegaConf.register_new_resolver( 'eval', eval) omegaconf.OmegaConf.register_new_resolver( 'div_up', lambda x, y: (x + y - 1) // y) def _load_from_checkpoint(diffusion_model, config, tokenizer): if 'hf' in config.algo.backbone: return diffusion_model( config, tokenizer=tokenizer).to('cuda') return diffusion_model.load_from_checkpoint( config.eval.checkpoint_path, tokenizer=tokenizer, config=config) @L.pytorch.utilities.rank_zero_only def _print_config( config: omegaconf.DictConfig, resolve: bool = True, 
save_cfg: bool = True) -> None: """Prints content of DictConfig using Rich library and its tree structure. Args: config (DictConfig): Configuration composed by Hydra. resolve (bool): Whether to resolve reference fields of DictConfig. save_cfg (bool): Whether to save the configuration tree to a file. """ style = 'dim' tree = rich.tree.Tree('CONFIG', style=style, guide_style=style) fields = config.keys() for field in fields: branch = tree.add(field, style=style, guide_style=style) config_section = config.get(field) branch_content = str(config_section) if isinstance(config_section, omegaconf.DictConfig): branch_content = omegaconf.OmegaConf.to_yaml( config_section, resolve=resolve) branch.add(rich.syntax.Syntax(branch_content, 'yaml')) rich.print(tree) if save_cfg: with fsspec.open( '{}/config_tree.txt'.format( config.checkpointing.save_dir), 'w') as fp: rich.print(tree, file=fp) @L.pytorch.utilities.rank_zero_only def _print_batch(config, train_ds, valid_ds, tokenizer, k=64): for dl_type, dl in [ ('train', train_ds), ('valid', valid_ds)]: print(f'Printing {dl_type} dataloader batch.') batch = next(iter(dl)) print('Batch input_ids.shape', batch['input_ids'].shape) if config.data.modality == 'text': first = batch['input_ids'][0, :k] last = batch['input_ids'][0, -k:] print(f'First {k} tokens:', tokenizer.decode(first)) print('ids:', first) print(f'Last {k} tokens:', tokenizer.decode(last)) print('ids:', last) def _generate_samples(diffusion_model, config, logger, tokenizer): logger.info('Starting Sample Eval.') model = _load_from_checkpoint( diffusion_model=diffusion_model, config=config, tokenizer=tokenizer) model.metrics.gen_ppl.reset() model.metrics.sample_entropy.reset() if config.eval.disable_ema: logger.info('Disabling EMA.') model.ema = None stride_length = config.sampling.stride_length num_strides = config.sampling.num_strides all_samples = [] for _ in trange(config.sampling.num_sample_batches): if config.sampling.semi_ar: _, intermediate_samples, _ = model.restore_model_and_semi_ar_sample( stride_length=stride_length, num_strides=num_strides, dt=1 / config.sampling.steps) text_samples = intermediate_samples[-1] # Note: Samples generated using the semi-ar method # need to be processed before computing generative perplexity # since these samples contain numerous <|endoftext|> tokens # and diffusion.compute_generative_perplexity() discards # any text after the first EOS token. else: samples = model.restore_model_and_sample( num_steps=config.sampling.steps) model.metrics.record_entropy(samples) text_samples = model.tokenizer.batch_decode(samples) model.metrics.record_generative_perplexity( text_samples, config.model.length, model.device) all_samples.extend(list(text_samples)) generative_ppl = 0. entropy = 0.
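# Aggregated metrics are only computed for standard (non-semi-AR) sampling;
# per the note above, semi-AR samples contain numerous <|endoftext|> tokens
# and would need post-processing before generative perplexity is meaningful.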
if not config.sampling.semi_ar: generative_ppl = model.metrics.gen_ppl.compute().item() entropy = model.metrics.sample_entropy.compute().item() logger.info(f'Generative perplexity: {generative_ppl}') logger.info(f'Sample entropy: {entropy}') samples_path = config.eval.generated_samples_path with fsspec.open(samples_path, 'w') as f: json.dump({'generative_ppl': generative_ppl, 'entropy': entropy, 'generated_seqs': all_samples}, f, indent=4) logger.info(f'Samples saved at: {samples_path}',) def _eval_ppl(diffusion_model, config, logger, tokenizer): logger.info('Starting Perplexity Eval.') model = _load_from_checkpoint( diffusion_model=diffusion_model, config=config, tokenizer=tokenizer) if config.eval.disable_ema: logger.info('Disabling EMA.') model.ema = None wandb_logger = None if config.get('wandb', None) is not None: wandb_logger = L.pytorch.loggers.WandbLogger( config=omegaconf.OmegaConf.to_object(config), ** config.wandb) callbacks = [] if 'callbacks' in config: for _, callback in config.callbacks.items(): callbacks.append(hydra.utils.instantiate(callback)) trainer = hydra.utils.instantiate( config.trainer, default_root_dir=os.getcwd(), callbacks=callbacks, strategy=hydra.utils.instantiate(config.strategy), logger=wandb_logger) _, valid_ds = dataloader.get_dataloaders( config, tokenizer, skip_train=True, valid_seed=config.seed) trainer.validate(model, valid_ds) def _train(diffusion_model, config, logger, tokenizer): logger.info('Starting Training.') wandb_logger = None if config.get('wandb', None) is not None: wandb_logger = L.pytorch.loggers.WandbLogger( config=omegaconf.OmegaConf.to_object(config), **config.wandb) if (config.checkpointing.resume_from_ckpt and config.checkpointing.resume_ckpt_path is not None and utils.fsspec_exists( config.checkpointing.resume_ckpt_path)): ckpt_path = config.checkpointing.resume_ckpt_path else: ckpt_path = None # Lightning callbacks callbacks = [] if 'callbacks' in config: for _, callback in config.callbacks.items(): callbacks.append(hydra.utils.instantiate(callback)) train_ds, valid_ds = dataloader.get_dataloaders( config, tokenizer) _print_batch(config, train_ds, valid_ds, tokenizer) if config.training.finetune_path != '': assert utils.fsspec_exists(config.training.finetune_path) model = diffusion_model.load_from_checkpoint( config.training.finetune_path, tokenizer=tokenizer, config=config) else: model = diffusion_model(config, tokenizer=valid_ds.tokenizer) trainer = hydra.utils.instantiate( config.trainer, default_root_dir=os.getcwd(), callbacks=callbacks, strategy=hydra.utils.instantiate(config.strategy), logger=wandb_logger) trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path) def _eval_fid(diffusion_model, config, logger, tokenizer): logger.info('Preparing data and model for FID eval.') fabric = Fabric(accelerator=config.trainer.accelerator, devices=config.trainer.devices, num_nodes=config.trainer.num_nodes) fabric.launch() seed = config.seed + fabric.global_rank L.seed_everything(seed) print(f'(Rank {fabric.global_rank}): seed: {seed}') model = _load_from_checkpoint( diffusion_model=diffusion_model, config=config, tokenizer=tokenizer) model.to(fabric.device) if config.eval.disable_ema: logger.info('Disabling EMA.') model.ema = None model._eval_mode() assert config.data.train == 'cifar10', \ 'FID eval only implemented for CIFAR-10' # Like in flow matching papers: FID against train loader, _ = dataloader.get_dataloaders(config, tokenizer=tokenizer, skip_valid=True) sampler = DistributedSampler( loader.dataset, 
num_replicas=fabric.world_size, rank=fabric.global_rank, shuffle=False) loader = torch.utils.data.DataLoader( loader.dataset, batch_size=config.loader.eval_batch_size, sampler=sampler, num_workers=loader.num_workers if hasattr(loader, 'num_workers') else 0, pin_memory=getattr(loader, 'pin_memory', False)) # Check each GPU must generate the same number of images assert len(loader) == len(loader.dataset) // loader.batch_size // fabric.world_size, \ f'{len(loader)=}, {len(loader.dataset)=}, {loader.batch_size=}, {fabric.world_size=}' fid_calculator = FrechetInceptionDistance( normalize=False).to(fabric.device) is_calculator = InceptionScore( normalize=False).to(fabric.device) desc = f'(Rank {fabric.global_rank}) Sampling...' for batch in tqdm(loader, desc=desc): real_samples = batch['input_ids'] # Generate images with labels matching the true data labels = batch['labels'] gen_samples = model.generate_samples( num_samples=real_samples.shape[0], num_steps=config.sampling.steps, labels=labels) # Reshape 1D seq -> 2D image gen_samples = model.tokenizer.batch_decode(gen_samples) real_samples = model.tokenizer.batch_decode( real_samples).to(fabric.device) fid_calculator.update(gen_samples, real=False) fid_calculator.update(real_samples, real=True) is_calculator.update(gen_samples) fabric.barrier() logger.info('Done sampling. Computing FID & IS...') fid = fid_calculator.compute() incep_score = is_calculator.compute() if fabric.global_rank == 0: logger.info(f'FID: {fid}') logger.info(f'IS: {incep_score}') fabric.barrier() @hydra.main(version_base=None, config_path='configs', config_name='config') def main(config): """Main entry point for training.""" L.seed_everything(config.seed) _print_config(config, resolve=True, save_cfg=True) logger = utils.get_logger(__name__) tokenizer = dataloader.get_tokenizer(config) if config.algo.name == 'ar': diffusion_model = algo.AR elif config.algo.name == 'mdlm': diffusion_model = algo.MDLM elif config.algo.name == 'duo_base': diffusion_model = algo.DUO_BASE elif config.algo.name == 'd3pm': diffusion_model = algo.D3PMAbsorb elif config.algo.name == 'sedd': diffusion_model = algo.SEDDAbsorb elif config.algo.name == 'duo': diffusion_model = algo.DUO elif config.algo.name == 'distillation': diffusion_model = algo.Distillation elif config.algo.name == 'ot-finetune': diffusion_model = algo.OptimalTransportFinetune else: raise ValueError( f'Invalid algorithm name: {config.algo.name}') kwargs = {'diffusion_model': diffusion_model, 'config': config, 'tokenizer': tokenizer, 'logger': logger} if config.mode == 'sample_eval': _generate_samples(**kwargs) elif config.mode == 'ppl_eval': _eval_ppl(**kwargs) elif config.mode == 'fid_eval': _eval_fid(**kwargs) else: _train(**kwargs) if __name__ == '__main__': main() ================================================ FILE: metrics.py ================================================ import math import os import typing import torch import torch.nn.functional as F import torchmetrics import transformers LOG2 = math.log(2) class NLL(torchmetrics.aggregation.MeanMetric): def update(self, value:typing.Union[float, torch.Tensor], weight:typing.Union[float, torch.Tensor]=1.0) -> None: """Update state with data. Args: value: Either a float or tensor containing data. Additional tensor dimensions will be flattened weight: Either a float or tensor containing weights for calculating the average. Shape of weight should be able to broadcast with the shape of `value`. Default to `1.0` corresponding to simple harmonic average. 
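Note: in this codebase `weight` is typically the number of tokens
contributing to `value`, so `compute()` yields a per-token average NLL;
`BPD` rescales that average by 1/log(2) and `Perplexity` exponentiates it.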
""" # broadcast weight to value shape if not isinstance(value, torch.Tensor): value = torch.as_tensor(value, dtype=self.dtype, device=self.device) if (weight is not None and not isinstance(weight, torch.Tensor)): weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device) weight = torch.broadcast_to(weight, value.shape) value, weight = self._cast_and_nan_check_input(value, weight) if value.numel() == 0: return self.mean_value += value.sum() self.weight += weight.sum() class BPD(NLL): def compute(self) -> torch.Tensor: """Computes the bits per dimension. Returns: bpd """ return self.mean_value / self.weight / LOG2 class Perplexity(NLL): def compute(self) -> torch.Tensor: """Computes the Perplexity. Returns: Perplexity """ return torch.exp(self.mean_value / self.weight) class Metrics: def __init__(self, gen_ppl_eval_model_name_or_path=None, eval_ppl_batch_size=None) -> None: metrics = torchmetrics.MetricCollection({ 'nll': NLL(), 'bpd': BPD(), 'ppl': Perplexity()}) metrics.set_dtype(torch.float64) self.train_nlls = metrics.clone(prefix='train/') self.train_aux = BPD() self.valid_nlls = metrics.clone(prefix='val/') self.valid_aux = BPD() self.gen_ppl = Perplexity() self.sample_entropy = torchmetrics.aggregation.MeanMetric() self.eval_ppl_batch_size = eval_ppl_batch_size self.gen_ppl_eval_model_name_or_path = gen_ppl_eval_model_name_or_path self.tokenizer = transformers.AutoTokenizer.\ from_pretrained(gen_ppl_eval_model_name_or_path) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id def to(self, *args, **kwargs): self.gen_ppl = self.gen_ppl.to(*args, **kwargs) self.sample_entropy = self.sample_entropy.to(*args, **kwargs) self.train_nlls = self.train_nlls.to(*args, **kwargs) self.train_aux = self.train_aux.to(*args, **kwargs) self.valid_nlls = self.valid_nlls.to(*args, **kwargs) self.valid_aux = self.valid_aux.to(*args, **kwargs) def reset(self): self.gen_ppl.reset() self.sample_entropy.reset() self.train_nlls.reset() self.train_aux.reset() self.valid_nlls.reset() self.valid_aux.reset() def update_train(self, nll, aux_loss, num_tokens): self.train_nlls.update(nll, num_tokens) self.train_aux.update(aux_loss, num_tokens) def update_valid(self, nll, aux_loss, num_tokens): self.valid_nlls.update(nll, num_tokens) self.valid_aux.update(aux_loss, num_tokens) @torch.no_grad() def _eval_retokenize(self, text_samples, max_length, device): """Retokenizes samples for the eval model. Args: text_samples: List of sentences generated by the model. 
Returns:
      samples: Samples re-tokenized for the eval model
      attn_mask: Attention mask for the eval model
      eval_context_size: Size of the context for the eval model
    """
    tokenizer_kwargs = {
      'return_tensors': 'pt',
      'return_token_type_ids': False,
      'return_attention_mask': True,
      'truncation': True,
      'padding': True,
      'max_length': max_length,
    }
    if 'llama2' in self.gen_ppl_eval_model_name_or_path:
      eval_context_size = 4096
    else:
      eval_context_size = 1024
    samples = self.tokenizer(text_samples, **tokenizer_kwargs)
    attn_mask = samples['attention_mask']
    samples = samples['input_ids']
    if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
      attn_mask = attn_mask.to(device)
      samples = samples.to(device)
    return samples, attn_mask, eval_context_size

  @torch.no_grad()
  def record_entropy(self, tokens):
    for sample in tokens:
      _, counts = torch.unique(
        sample, return_counts=True, sorted=False)
      entropy = torch.special.entr(
        counts.float() / counts.sum()).sum().item()
      self.sample_entropy.update(entropy)

  @torch.no_grad()
  def record_generative_perplexity(
      self,
      text_samples: typing.List[str],
      max_length: int,
      retokenize: bool = True,
      device='cuda') -> None:
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    eval_model = transformers.AutoModelForCausalLM.from_pretrained(
      self.gen_ppl_eval_model_name_or_path).eval()
    if 'llama2' not in self.gen_ppl_eval_model_name_or_path:
      eval_model = eval_model.to(device)
    # Re-tokenize using eval model's tokenizer
    if retokenize:
      (samples, attn_mask,
       eval_context_size) = self._eval_retokenize(
        text_samples, max_length=max_length, device=device)
    else:
      samples = text_samples
      attn_mask = torch.ones(samples.shape).to(device)
      eval_context_size = samples.shape[-1]
    batch_size = min(self.eval_ppl_batch_size, samples.shape[0])
    # Ceil division so a trailing partial batch is not silently dropped.
    num_batches = (samples.shape[0] + batch_size - 1) // batch_size
    for i in range(num_batches):
      _samples = torch.split(
        samples[i * batch_size: (i + 1) * batch_size],
        eval_context_size,
        dim=-1)
      _attn_mask = torch.split(
        attn_mask[i * batch_size: (i + 1) * batch_size],
        eval_context_size,
        dim=-1)
      for (sample_chunk, attn_mask_chunk) in zip(_samples, _attn_mask):
        logits = eval_model(sample_chunk, attention_mask=attn_mask_chunk)
        logits = logits[0].transpose(-1, -2)
        nlls = F.cross_entropy(logits[..., :-1],
                               sample_chunk[..., 1:],
                               reduction='none')
        # Score every non-EOS target token plus the first EOS; everything
        # after the first EOS is padding and is masked out of the average.
        first_eos = (
          sample_chunk == self.tokenizer.eos_token_id).cumsum(-1) == 1
        token_mask = sample_chunk != self.tokenizer.eos_token_id
        valid_tokens = first_eos[..., 1:] + token_mask[..., 1:]
        self.gen_ppl.update(nlls * valid_tokens, valid_tokens)


================================================
FILE: models/__init__.py
================================================
from . import dit
from . import ema
from .
import unet ================================================ FILE: models/dit.py ================================================ import math import typing import einops import flash_attn import flash_attn.layers.rotary import huggingface_hub import omegaconf import torch import torch.nn as nn import torch.nn.functional as F # Flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) torch._C._jit_override_can_fuse_on_cpu(True) torch._C._jit_override_can_fuse_on_gpu(True) def bias_dropout_add_scale( x: torch.Tensor, bias: typing.Optional[torch.Tensor], scale: torch.Tensor, residual: typing.Optional[torch.Tensor], prob: float, training: bool) -> torch.Tensor: if bias is not None: out = scale * F.dropout(x + bias, p=prob, training=training) else: out = scale * F.dropout(x, p=prob, training=training) if residual is not None: out = residual + out return out def get_bias_dropout_add_scale(training): def _bias_dropout_add(x, bias, scale, residual, prob): return bias_dropout_add_scale( x, bias, scale, residual, prob, training) return _bias_dropout_add # function overload def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return x * (1 + scale) + shift @torch.jit.script def bias_dropout_add_scale_fused_train( x: torch.Tensor, bias: typing.Optional[torch.Tensor], scale: torch.Tensor, residual: typing.Optional[torch.Tensor], prob: float) -> torch.Tensor: return bias_dropout_add_scale( x, bias, scale, residual, prob, True) @torch.jit.script def bias_dropout_add_scale_fused_inference( x: torch.Tensor, bias: typing.Optional[torch.Tensor], scale: torch.Tensor, residual: typing.Optional[torch.Tensor], prob: float) -> torch.Tensor: return bias_dropout_add_scale( x, bias, scale, residual, prob, False) @torch.jit.script def modulate_fused(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return modulate(x, shift, scale) class Rotary(torch.nn.Module): def __init__(self, dim, base=10_000): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) self.seq_len_cached = None self.cos_cached = None self.sin_cached = None def forward(self, x, seq_dim=1): seq_len = x.shape[seq_dim] if seq_len != self.seq_len_cached: self.seq_len_cached = seq_len t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone()) emb = torch.cat((freqs, freqs), dim=-1).to(x.device) # dims are: batch, seq_len, qkv, head, dim self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1) self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1) # This makes the transformation on v an identity. self.cos_cached[:,:,2,:,:].fill_(1.) self.sin_cached[:,:,2,:,:].fill_(0.) 
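      # Cache layout note: cos/sin are shaped (1, seq_len, 3, 1, rot_dim) so
      # they broadcast against qkv's (batch, seq_len, 3, heads, head_dim);
      # slot 2 of the qkv axis is pinned to (cos=1, sin=0) above so q and k
      # are rotated while v passes through unchanged.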
return self.cos_cached, self.sin_cached def rotate_half(x): x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin): with torch.amp.autocast('cuda', enabled=False): cos, sin = rotary_cos_sin cos = cos.to(qkv.dtype) sin = sin.to(qkv.dtype) cos = cos[0,:,0,0,:cos.shape[-1]//2] sin = sin[0,:,0,0,:sin.shape[-1]//2] q, k, v = qkv.chunk(3, dim=2) q = flash_attn.layers.rotary.apply_rotary_emb_torch( q.squeeze(dim=2), cos, sin) k = flash_attn.layers.rotary.apply_rotary_emb_torch( k.squeeze(dim=2), cos, sin) v = v.squeeze(dim=2) return q, k, v def apply_rotary_pos_emb(qkv, cos, sin): cos = cos[0,:,0,0,:cos.shape[-1]//2] sin = sin[0,:,0,0,:sin.shape[-1]//2] return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin) def regular_attention_multi_headed(q, k, v): # Assuming qkv is a tensor with shape [batch, seq_len, 3, num_heads, head_dim] # where the 3 represents Q, K, V packed in that order attention_output = F.scaled_dot_product_attention( query=q.transpose(1, 2), key=k.transpose(1, 2), value=v.transpose(1, 2), attn_mask=None, dropout_p=0.0, is_causal=False) # [batch_size, seq_len, num_heads, head_dim] attention_output = attention_output.transpose(1, 2) return einops.rearrange(attention_output, 'b s h d -> b s (h d)') ################################################################################# # Layers # ################################################################################# class LayerNorm(nn.Module): def __init__(self, dim): super().__init__() self.weight = nn.Parameter(torch.ones([dim])) self.dim = dim def forward(self, x): with torch.amp.autocast('cuda', enabled=False): x = F.layer_norm(x.float(), [self.dim]) return x * self.weight[None, None, :] def residual_linear(x, W, x_skip, residual_scale): """x_skip + residual_scale * W @ x""" dim_out, dim_in = W.shape[0], W.shape[1] return torch.addmm( x_skip.view(-1, dim_out), x.view(-1, dim_in), W.T, alpha=residual_scale).view(*x.shape[:-1], dim_out) ################################################################################# # Embedding Layers for Timesteps and Class Labels # ################################################################################# class TimestepEmbedder(nn.Module): """ Embeds scalar timesteps into vector representations. """ def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size, bias=True), nn.SiLU(), nn.Linear(hidden_size, hidden_size, bias=True)) self.frequency_embedding_size = frequency_embedding_size @staticmethod def timestep_embedding(t, dim, max_period=10000): """ Create sinusoidal timestep embeddings. :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. 
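    Example (illustrative, dim=4): for t = [1.0],
    freqs = exp(-ln(10000) * [0, 1] / 2) = [1.0, 0.01], so the embedding is
    [cos(1.0), cos(0.01), sin(1.0), sin(0.01)].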
""" # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py half = dim // 2 freqs = torch.exp( - math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat( [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) t_emb = self.mlp(t_freq) return t_emb class LabelEmbedder(nn.Module): """Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. """ def __init__(self, num_classes, cond_size): super().__init__() self.embedding_table = nn.Embedding(num_classes + 1, cond_size) self.num_classes = num_classes # TODO think of initializing with 0.02 std deviation like in original DiT paper def forward(self, labels): embeddings = self.embedding_table(labels) return embeddings ################################################################################# # Core Model # ################################################################################# class DDiTBlockCausal(nn.Module): def __init__(self, dim, n_heads, mlp_ratio=4, dropout=0.1): super().__init__() self.n_heads = n_heads self.norm1 = LayerNorm(dim) self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False) self.attn_out = nn.Linear(dim, dim, bias=False) self.dropout1 = nn.Dropout(dropout) self.norm2 = LayerNorm(dim) self.mlp = nn.Sequential( nn.Linear(dim, mlp_ratio * dim, bias=True), nn.GELU(approximate='tanh'), nn.Linear(mlp_ratio * dim, dim, bias=True)) self.dropout2 = nn.Dropout(dropout) self.dropout = dropout def _get_bias_dropout_scale(self): if self.training: return bias_dropout_add_scale_fused_train else: return bias_dropout_add_scale_fused_inference def forward(self, x, rotary_cos_sin, **kwargs): del kwargs batch_size, seq_len = x.shape[0], x.shape[1] bias_dropout_scale_fn = self._get_bias_dropout_scale() # attention operation x_skip = x x = self.norm1(x) qkv = self.attn_qkv(x) qkv = einops.rearrange( qkv, 'b s (three h d) -> b s three h d', three=3, h=self.n_heads) with torch.amp.autocast('cuda', enabled=False): cos, sin = rotary_cos_sin qkv = apply_rotary_pos_emb( qkv, cos.to(qkv.dtype), sin.to(qkv.dtype) ) qkv = einops.rearrange(qkv, 'b s ... 
-> (b s) ...') cu_seqlens = torch.arange( 0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32, device=qkv.device) x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func( qkv, cu_seqlens, seq_len, 0.0, causal=True) x = einops.rearrange(x, '(b s) h d -> b s (h d)', b=batch_size) scale = torch.ones(1, device=x.device, dtype=x.dtype) x = bias_dropout_scale_fn( self.attn_out(x), None, scale, x_skip, self.dropout) # mlp operation x = bias_dropout_scale_fn( self.mlp(self.norm2(x)), None, scale, x, self.dropout) return x class DDiTBlock(nn.Module): def __init__(self, dim, n_heads, adaLN, cond_dim=None, mlp_ratio=4, dropout=0.1): super().__init__() self.n_heads = n_heads self.adaLN = adaLN self.norm1 = LayerNorm(dim) self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False) self.attn_out = nn.Linear(dim, dim, bias=False) self.dropout1 = nn.Dropout(dropout) self.norm2 = LayerNorm(dim) self.mlp = nn.Sequential( nn.Linear(dim, mlp_ratio * dim, bias=True), nn.GELU(approximate='tanh'), nn.Linear(mlp_ratio * dim, dim, bias=True)) self.dropout2 = nn.Dropout(dropout) self.dropout = dropout if self.adaLN: self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim) self.adaLN_modulation.weight.data.zero_() self.adaLN_modulation.bias.data.zero_() def _get_bias_dropout_scale(self): if self.training: return bias_dropout_add_scale_fused_train else: return bias_dropout_add_scale_fused_inference def forward(self, x, rotary_cos_sin, c=None): bias_dropout_scale_fn = self._get_bias_dropout_scale() x_skip = x x = self.norm1(x) if self.adaLN: # self.adaLN_modulation(c): (128, 1536) # self.adaLN_modulation(c)[:, None]: (128, 1, 1536) # "" .chunk(6, dim=2) returns 6 tuples of shapes (128, 1, 256) (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2) x = modulate_fused(x, shift_msa, scale_msa) qkv = einops.rearrange( self.attn_qkv(x), 'b s (three h d) -> b s three h d', three=3, h=self.n_heads) q, k, v = split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin) x = regular_attention_multi_headed(q, k, v) if self.adaLN: x = bias_dropout_scale_fn(self.attn_out(x), None, gate_msa, x_skip, self.dropout) x = bias_dropout_scale_fn( self.mlp(modulate_fused( self.norm2(x), shift_mlp, scale_mlp)), None, gate_mlp, x, self.dropout) else: scale = torch.ones(1, device=x.device, dtype=x.dtype) x = bias_dropout_scale_fn( self.attn_out(x), None, scale, x_skip, self.dropout) x = bias_dropout_scale_fn( self.mlp(self.norm2(x)), None, scale, x, self.dropout) return x class EmbeddingLayer(nn.Module): def __init__(self, dim, vocab_dim): super().__init__() self.embedding = nn.Parameter(torch.empty((vocab_dim, dim))) torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5)) def forward(self, x, weights=None): if weights is not None: bs, seq_len, k = x.shape flat_x = x.reshape(-1, k) flat_w = weights.reshape(-1, k).float() bag = F.embedding_bag(flat_x, self.embedding.float(), per_sample_weights=flat_w, mode='sum') return bag.view(bs, seq_len, -1) elif x.ndim == 2: return self.embedding[x] assert x.ndim == 3 return torch.einsum( "blv,ve->ble", torch.nn.functional.softmax(x, dim=-1).float(), self.embedding.float()).to(x.dtype) class DDiTFinalLayer(nn.Module): def __init__(self, hidden_size, out_channels, cond_dim, adaLN): super().__init__() self.norm_final = LayerNorm(hidden_size) self.linear = nn.Linear(hidden_size, out_channels) self.linear.weight.data.zero_() self.linear.bias.data.zero_() self.adaLN = adaLN if self.adaLN: self.adaLN_modulation = nn.Linear(cond_dim, 2 * 
hidden_size, bias=True) self.adaLN_modulation.weight.data.zero_() self.adaLN_modulation.bias.data.zero_() def forward(self, x, c): x = self.norm_final(x) if self.adaLN: shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2) x = modulate_fused(x, shift, scale) x = self.linear(x) return x class DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin): def __init__(self, config, vocab_size: int): super().__init__() if type(config) == dict: config = omegaconf.OmegaConf.create(config) self.causal = config.algo.causal_attention self.adaLN = not self.causal self.config = config self.vocab_size = vocab_size dim = config.model.hidden_size cond_dim = config.model.cond_dim self.vocab_embed = EmbeddingLayer(dim, vocab_size) if not self.causal: self.sigma_map = TimestepEmbedder(cond_dim) self.rotary_emb = Rotary(dim // config.model.n_heads) blocks = [] for _ in range(config.model.n_blocks): if self.causal: block = DDiTBlockCausal( dim=dim, n_heads=config.model.n_heads, dropout=config.model.dropout) else: block = DDiTBlock( dim=dim, n_heads=config.model.n_heads, cond_dim=cond_dim, adaLN=self.adaLN, dropout=config.model.dropout) blocks.append(block) self.blocks = nn.ModuleList(blocks) self.output_layer = DDiTFinalLayer( hidden_size=dim, out_channels=vocab_size, cond_dim=cond_dim, adaLN=self.adaLN) self.scale_by_sigma = config.model.scale_by_sigma def _get_bias_dropout_scale(self): if self.training: return bias_dropout_add_scale_fused_train else: return bias_dropout_add_scale_fused_inference def forward(self, x, sigma, class_cond=None, weights=None): assert class_cond is None, 'Not implemented for DiT' x = self.vocab_embed(x, weights) if self.causal: t_cond = None else: t_cond = F.silu(self.sigma_map(sigma)) rotary_cos_sin = self.rotary_emb(x) with torch.amp.autocast('cuda', dtype=torch.bfloat16): for i in range(len(self.blocks)): x = self.blocks[i](x, rotary_cos_sin, c=t_cond) x = self.output_layer(x, c=t_cond) return x ================================================ FILE: models/ema.py ================================================ import torch class ExponentialMovingAverage: """ Maintains (exponential) moving average of a set of parameters. """ def __init__(self, parameters, decay, use_num_updates=True): """ Args: parameters: Iterable of `torch.nn.Parameter`; usually the result of `model.parameters()`. decay: The exponential decay. use_num_updates: Whether to use number of updates when computing averages. """ if decay < 0.0 or decay > 1.0: raise ValueError('Decay must be between 0 and 1') self.decay = decay self.num_updates = 0 if use_num_updates else None self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] self.collected_params = [] def move_shadow_params_to_device(self, device): self.shadow_params = [i.to(device) for i in self.shadow_params] def update(self, parameters): """ Update currently maintained parameters. Call this every time the parameters are updated, such as the result of the `optimizer.step()` call. Args: parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to initialize this object. 
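        Example (illustrative): with decay=0.999 and use_num_updates=True,
        the first call uses min(0.999, (1 + 1) / (10 + 1)) ~= 0.18, so early
        updates track the raw parameters closely before the effective decay
        warms up toward 0.999.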
""" decay = self.decay if self.num_updates is not None: self.num_updates += 1 decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) one_minus_decay = 1.0 - decay with torch.no_grad(): parameters = [p for p in parameters if p.requires_grad] for s_param, param in zip(self.shadow_params, parameters): s_param.sub_(one_minus_decay * (s_param - param)) def copy_to(self, parameters): """ Copy current parameters into given collection of parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving averages. """ parameters = [p for p in parameters if p.requires_grad] for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: param.data.copy_(s_param.data) def store(self, parameters): """ Save the current parameters for restoring later. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. """ self.collected_params = [param.clone() for param in parameters] def restore(self, parameters): """ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without affecting the original optimization process. Store the parameters before the `copy_to` method. After validation (or model saving), use this to restore the former parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. """ for c_param, param in zip(self.collected_params, parameters): param.data.copy_(c_param.data) def state_dict(self): return dict(decay=self.decay, num_updates=self.num_updates, shadow_params=self.shadow_params) def load_state_dict(self, state_dict): self.decay = state_dict['decay'] self.num_updates = state_dict['num_updates'] self.shadow_params = state_dict['shadow_params'] ================================================ FILE: models/unet.py ================================================ import torch import torch.nn as nn import math import torch.nn.functional as F import numpy as np import omegaconf import transformers from einops import rearrange from .dit import LabelEmbedder, EmbeddingLayer # From https://github.com/yang-song/score_sde_pytorch/ which is from # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py def transformer_timestep_embedding(timesteps, embedding_dim, max_positions=10000): assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 half_dim = embedding_dim // 2 # magic number 10000 is from transformers emb = math.log(max_positions) / (half_dim - 1) # emb = math.log(2.) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = F.pad(emb, (0, 1), mode='constant') assert emb.shape == (timesteps.shape[0], embedding_dim) return emb # Code modified from https://github.com/yang-song/score_sde_pytorch def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device='cpu'): """Ported from JAX. 
""" def _compute_fans(shape, in_axis=1, out_axis=0): receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] fan_in = shape[in_axis] * receptive_field_size fan_out = shape[out_axis] * receptive_field_size return fan_in, fan_out def init(shape, dtype=dtype, device=device): fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) if mode == "fan_in": denominator = fan_in elif mode == "fan_out": denominator = fan_out elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2 else: raise ValueError( "invalid mode for variance scaling initializer: {}".format(mode)) variance = scale / denominator if distribution == "normal": return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) elif distribution == "uniform": return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance) else: raise ValueError("invalid distribution for variance scaling initializer") return init def default_init(scale=1.): """The same initialization used in DDPM.""" scale = 1e-10 if scale == 0 else scale return variance_scaling(scale, 'fan_avg', 'uniform') class NiN(nn.Module): def __init__(self, in_ch, out_ch, init_scale=0.1): super().__init__() self.W = nn.Parameter(default_init(scale=init_scale)((in_ch, out_ch)), requires_grad=True) self.b = nn.Parameter(torch.zeros(out_ch), requires_grad=True) def forward(self, x, # ["batch", "in_ch", "H", "W"] ): x = x.permute(0, 2, 3, 1) # x (batch, H, W, in_ch) y = torch.einsum('bhwi,ik->bhwk', x, self.W) + self.b # y (batch, H, W, out_ch) return y.permute(0, 3, 1, 2) class AttnBlock(nn.Module): """Channel-wise self-attention block.""" def __init__(self, channels, skip_rescale=True): super().__init__() self.skip_rescale = skip_rescale self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels//4, 32), num_channels=channels, eps=1e-6) self.NIN_0 = NiN(channels, channels) self.NIN_1 = NiN(channels, channels) self.NIN_2 = NiN(channels, channels) self.NIN_3 = NiN(channels, channels, init_scale=0.) def forward(self, x, # ["batch", "channels", "H", "W"] ): B, C, H, W = x.shape h = self.GroupNorm_0(x) q = self.NIN_0(h) k = self.NIN_1(h) v = self.NIN_2(h) w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) w = torch.reshape(w, (B, H, W, H * W)) w = F.softmax(w, dim=-1) w = torch.reshape(w, (B, H, W, H, W)) h = torch.einsum('bhwij,bcij->bchw', w, v) h = self.NIN_3(h) if self.skip_rescale: return (x + h) / np.sqrt(2.) 
else: return x + h class ResBlock(nn.Module): def __init__(self, in_ch, out_ch, temb_dim=None, dropout=0.1, skip_rescale=True): super().__init__() self.in_ch = in_ch self.out_ch = out_ch self.skip_rescale = skip_rescale self.act = nn.functional.silu self.groupnorm0 = nn.GroupNorm( num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6 ) self.conv0 = nn.Conv2d( in_ch, out_ch, kernel_size=3, padding=1 ) if temb_dim is not None: self.dense0 = nn.Linear(temb_dim, out_ch) nn.init.zeros_(self.dense0.bias) self.groupnorm1 = nn.GroupNorm( num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6 ) self.dropout0 = nn.Dropout(dropout) self.conv1 = nn.Conv2d( out_ch, out_ch, kernel_size=3, padding=1 ) if out_ch != in_ch: self.nin = NiN(in_ch, out_ch) def forward(self, x, # ["batch", "in_ch", "H", "W"] temb=None, # ["batch", "temb_dim"] ): assert x.shape[1] == self.in_ch h = self.groupnorm0(x) h = self.act(h) h = self.conv0(h) if temb is not None: h += self.dense0(self.act(temb))[:, :, None, None] h = self.groupnorm1(h) h = self.act(h) h = self.dropout0(h) h = self.conv1(h) if h.shape[1] != self.in_ch: x = self.nin(x) assert x.shape == h.shape if self.skip_rescale: return (x + h) / np.sqrt(2.) else: return x + h class Downsample(nn.Module): def __init__(self, channels): super().__init__() self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=0) def forward(self, x, # ["batch", "ch", "inH", "inW"] ): B, C, H, W = x.shape x = nn.functional.pad(x, (0, 1, 0, 1)) x= self.conv(x) assert x.shape == (B, C, H // 2, W // 2) return x class Upsample(nn.Module): def __init__(self, channels): super().__init__() self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1) def forward(self, x, # ["batch", "ch", "inH", "inW"] ): B, C, H, W = x.shape h = F.interpolate(x, (H*2, W*2), mode='nearest') h = self.conv(h) assert h.shape == (B, C, H*2, W*2) return h class UNet(nn.Module): def __init__(self, config, vocab_size=None): super().__init__() if type(config) == dict: config = omegaconf.OmegaConf.create(config) assert config.model.name == 'unet' self.ch = config.model.ch self.num_res_blocks = config.model.num_res_blocks self.num_scales = config.model.num_scales self.ch_mult = config.model.ch_mult assert self.num_scales == len(self.ch_mult) self.input_channels = config.model.input_channels self.output_channels = 2 * config.model.input_channels self.scale_count_to_put_attn = config.model.scale_count_to_put_attn self.data_min_max = [0, vocab_size] # config.model.data_min_max # tuple of min and max value of input so it can be rescaled to [-1, 1] self.dropout = config.model.dropout self.skip_rescale = config.model.skip_rescale self.time_conditioning = config.model.time_conditioning # Whether to add in time embeddings self.time_scale_factor = config.model.time_scale_factor # scale to make the range of times be 0 to 1000 self.time_embed_dim = config.model.time_embed_dim self.vocab_size = vocab_size self.size = config.model.size self.length = config.model.length # truncated logistic self.fix_logistic = config.model.fix_logistic self.act = nn.functional.silu if self.time_conditioning: self.temb_modules = [] self.temb_modules.append(nn.Linear(self.time_embed_dim, self.time_embed_dim*4)) nn.init.zeros_(self.temb_modules[-1].bias) self.temb_modules.append(nn.Linear(self.time_embed_dim*4, self.time_embed_dim*4)) nn.init.zeros_(self.temb_modules[-1].bias) self.temb_modules = nn.ModuleList(self.temb_modules) self.expanded_time_dim = 4 * self.time_embed_dim if self.time_conditioning else None 
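    # Conditioning note: when time_conditioning is enabled, the sinusoidal
    # features (time_embed_dim) are expanded 4x by temb_modules and injected
    # per-channel inside every ResBlock via its dense0 projection.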
self.input_conv = nn.Conv2d( in_channels=self.input_channels, out_channels=self.ch, kernel_size=3, padding=1 ) h_cs = [self.ch] in_ch = self.ch # Downsampling self.downsampling_modules = [] for scale_count in range(self.num_scales): for res_count in range(self.num_res_blocks): out_ch = self.ch * self.ch_mult[scale_count] self.downsampling_modules.append( ResBlock(in_ch, out_ch, temb_dim=self.expanded_time_dim, dropout=self.dropout, skip_rescale=self.skip_rescale) ) in_ch = out_ch h_cs.append(in_ch) if scale_count == self.scale_count_to_put_attn: self.downsampling_modules.append( AttnBlock(in_ch, skip_rescale=self.skip_rescale) ) if scale_count != self.num_scales - 1: self.downsampling_modules.append(Downsample(in_ch)) h_cs.append(in_ch) self.downsampling_modules = nn.ModuleList(self.downsampling_modules) # Middle self.middle_modules = [] self.middle_modules.append( ResBlock(in_ch, in_ch, temb_dim=self.expanded_time_dim, dropout=self.dropout, skip_rescale=self.skip_rescale) ) self.middle_modules.append( AttnBlock(in_ch, skip_rescale=self.skip_rescale) ) self.middle_modules.append( ResBlock(in_ch, in_ch, temb_dim=self.expanded_time_dim, dropout=self.dropout, skip_rescale=self.skip_rescale) ) self.middle_modules = nn.ModuleList(self.middle_modules) # Upsampling self.upsampling_modules = [] for scale_count in reversed(range(self.num_scales)): for res_count in range(self.num_res_blocks+1): out_ch = self.ch * self.ch_mult[scale_count] self.upsampling_modules.append( ResBlock(in_ch + h_cs.pop(), out_ch, temb_dim=self.expanded_time_dim, dropout=self.dropout, skip_rescale=self.skip_rescale ) ) in_ch = out_ch if scale_count == self.scale_count_to_put_attn: self.upsampling_modules.append( AttnBlock(in_ch, skip_rescale=self.skip_rescale) ) if scale_count != 0: self.upsampling_modules.append(Upsample(in_ch)) self.upsampling_modules = nn.ModuleList(self.upsampling_modules) assert len(h_cs) == 0 # output self.output_modules = [] self.output_modules.append( nn.GroupNorm(min(in_ch//4, 32), in_ch, eps=1e-6) ) self.output_modules.append( nn.Conv2d(in_ch, self.output_channels, kernel_size=3, padding=1) ) self.output_modules = nn.ModuleList(self.output_modules) self.cond_map = LabelEmbedder( config.data.num_classes, self.time_embed_dim*4) def _center_data(self, x): out = (x - self.data_min_max[0]) / (self.data_min_max[1] - self.data_min_max[0]) # [0, 1] return 2 * out - 1 # to put it in [-1, 1] def _time_embedding(self, timesteps): if self.time_conditioning: temb = transformer_timestep_embedding( timesteps * self.time_scale_factor, self.time_embed_dim ) temb = self.temb_modules[0](temb) temb = self.temb_modules[1](self.act(temb)) else: temb = None return temb def _do_input_conv(self, h): h = self.input_conv(h) hs = [h] return h, hs def _do_downsampling(self, h, hs, temb): m_idx = 0 for scale_count in range(self.num_scales): for res_count in range(self.num_res_blocks): h = self.downsampling_modules[m_idx](h, temb) m_idx += 1 if scale_count == self.scale_count_to_put_attn: h = self.downsampling_modules[m_idx](h) m_idx += 1 hs.append(h) if scale_count != self.num_scales - 1: h = self.downsampling_modules[m_idx](h) hs.append(h) m_idx += 1 assert m_idx == len(self.downsampling_modules) return h, hs def _do_middle(self, h, temb): m_idx = 0 h = self.middle_modules[m_idx](h, temb) m_idx += 1 h = self.middle_modules[m_idx](h) m_idx += 1 h = self.middle_modules[m_idx](h, temb) m_idx += 1 assert m_idx == len(self.middle_modules) return h def _do_upsampling(self, h, hs, temb): m_idx = 0 for scale_count in 
reversed(range(self.num_scales)):
      for res_count in range(self.num_res_blocks+1):
        h = self.upsampling_modules[m_idx](
          torch.cat([h, hs.pop()], dim=1), temb)
        m_idx += 1
      if scale_count == self.scale_count_to_put_attn:
        h = self.upsampling_modules[m_idx](h)
        m_idx += 1
      if scale_count != 0:
        h = self.upsampling_modules[m_idx](h)
        m_idx += 1
    assert len(hs) == 0
    assert m_idx == len(self.upsampling_modules)
    return h

  def _do_output(self, h):
    h = self.output_modules[0](h)
    h = self.act(h)
    h = self.output_modules[1](h)
    return h

  def _logistic_output_res(self,
                           h,  # ["B", "twoC", "H", "W"]
                           centered_x_in,  # ["B", "C", "H", "W"]
                           ):
    B, twoC, H, W = h.shape
    C = twoC // 2
    h[:, 0:C, :, :] = torch.tanh(centered_x_in + h[:, 0:C, :, :])
    return h

  def _log_minus_exp(self, a, b, eps=1e-6):
    """Compute log (exp(a) - exp(b)) for (b < a)."""
    return a + torch.log1p(-torch.exp(b - a) + eps)

  def forward(self, x, sigma, class_cond=None):
    img_size = self.size
    h = rearrange(x, "b (c h w) -> b c h w",
                  h=img_size, w=img_size, c=3)
    h = self._center_data(h)
    centered_x_in = h
    temb = self._time_embedding(sigma)
    if class_cond is not None:
      if self.cond_map is None:
        raise ValueError("Conditioning variable provided, "
                         "but Model was not initialized "
                         "with condition embedding layer.")
      else:
        assert class_cond.shape == (x.shape[0],)
        temb = temb + self.cond_map(class_cond)
    h, hs = self._do_input_conv(h)
    h, hs = self._do_downsampling(h, hs, temb)
    h = self._do_middle(h, temb)
    h = self._do_upsampling(h, hs, temb)
    h = self._do_output(h)  # h (B, 2*C, H, W)
    h = self._logistic_output_res(h, centered_x_in)
    h = self._truncated_logistic_output(h)  # (B, D, S)
    return h


================================================
FILE: models/unit_test_attention.py
================================================
import unittest

import torch
# from flash_attn import flash_attention
import torch.nn.functional as F


def attention_inner_heads_flash(qkv, num_heads):
  """Computes attention with heads inside of qkv in the channel dimension using FlashAttention.

  Args:
    qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where:
      H = number of heads,
      C = number of channels per head.
    num_heads: number of heads.

  Returns:
    Attention output of shape (B, H*C, T).
  """
  # bs, width, length = qkv.shape
  # ch = width // (3 * num_heads)
  # # Split into (q, k, v) of shape (B, H*C, T).
  # q, k, v = qkv.chunk(3, dim=1)
  # # Rescale q and k. This makes them contiguous in memory.
# scale = ch ** (-1 / 4) # scale with 4th root = scaling output by sqrt # q = q * scale # k = k * scale # # Reshape q, k, v to (B, H, T, C) for FlashAttention # q = q.view(bs, num_heads, ch, length).permute(0, 1, 3, 2) # (B, H, T, C) # k = k.view(bs, num_heads, ch, length).permute(0, 1, 3, 2) # (B, H, T, C) # v = v.view(bs, num_heads, ch, length).permute(0, 1, 3, 2) # (B, H, T, C) # # Compute attention using FlashAttention # out = flash_attention(q, k, v) # (B, H, T, C) # # Reshape back to (B, H*C, T) # out = out.permute(0, 1, 3, 2).reshape(bs, num_heads * ch, length) # return out bs, width, length = qkv.shape ch = width // (3 * num_heads) # Split into (q, k, v) and reshape directly to (B, H, T, C) q, k, v = qkv.chunk(3, dim=1) q = q.view(bs, num_heads, ch, length).transpose(2, 3) # (B, H, T, C) k = k.view(bs, num_heads, ch, length).transpose(2, 3) # (B, H, T, C) v = v.view(bs, num_heads, ch, length).transpose(2, 3) # (B, H, T, C) # Compute scaled dot-product attention out = F.scaled_dot_product_attention(q, k, v) # (B, H, T, C) # Reshape back to (B, H*C, T) in one step out = out.transpose(2, 3).reshape(bs, num_heads * ch, length) return out class TestAttentionInnerHeadsFlash(unittest.TestCase): def setUp(self): # Set up common test variables self.batch_size = 2 self.num_heads = 4 self.seq_len = 8 self.channels_per_head = 16 self.qkv_dim = 3 * self.num_heads * self.channels_per_head # Create a random qkv tensor self.qkv = torch.randn(self.batch_size, self.qkv_dim, self.seq_len, dtype=torch.float32) def attention_inner_heads_old(self, qkv, num_heads): """Original implementation of attention for reference.""" bs, width, length = qkv.shape ch = width // (3 * num_heads) q, k, v = qkv.chunk(3, dim=1) scale = ch ** (-1 / 4) q = q * scale k = k * scale q = q.view(bs * num_heads, ch, length) k = k.view(bs * num_heads, ch, length) v = v.reshape(bs * num_heads, ch, length) weight = torch.einsum("bct,bcs->bts", q, k) weight = torch.softmax(weight.float(), dim=-1).to(weight.dtype) out = torch.einsum("bts,bcs->bct", weight, v) return out.reshape(bs, num_heads * ch, length) def test_attention_inner_heads_flash(self): # Compute the output using the old and flash attention implementations output_old = self.attention_inner_heads_old(self.qkv, self.num_heads) output_flash = attention_inner_heads_flash(self.qkv, self.num_heads) # Verify the shapes are the same self.assertEqual(output_old.shape, output_flash.shape) # Verify that the outputs are close (numerically similar) self.assertTrue(torch.allclose(output_old, output_flash, atol=1e-5)) if __name__ == "__main__": unittest.main() ================================================ FILE: requirements.txt ================================================ # conda install nvidia/label/cuda-12.4.0::cuda-toolkit datasets==2.15.0 einops==0.7.0 fsspec git-lfs==1.6 h5py==3.10.0 hydra-core==1.3.2 ipdb==0.13.13 lightning==2.2.1 notebook==7.1.1 nvitop==1.3.2 omegaconf==2.3.0 packaging==23.2 pandas==2.2.1 rich==13.7.1 seaborn==0.13.2 scikit-learn==1.4.0 transformers==4.38.2 triton==2.2.0 torch==2.3.1 torchaudio==2.3.1 torchmetrics==1.6.1 torchvision==0.18.1 wandb timm ocifs hf_transfer huggingface-hub # Install flash attention only after installing the above modules via pip install -r requirements.txt # flash_attn==2.7.4.post1 ================================================ FILE: scripts/distil_owt.sh ================================================ #!/bin/bash #SBATCH -J posterior # Job name #SBATCH -o watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total 
number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=64000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|3090]" #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:1 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon preemption export HYDRA_FULL_ERROR=1 finetune_path=/path/to/duo.ckpt srun python -u -m main \ mode=train \ loader.batch_size=2 \ loader.eval_batch_size=2 \ data=openwebtext-split \ model=small \ algo=distillation \ training.finetune_path=$finetune_path \ sampling.num_sample_batches=10 \ sampling.steps=32 \ eval.compute_generative_perplexity=True \ algo.T=512 \ lr_scheduler.num_warmup_steps=500 \ trainer.val_check_interval=1000 \ trainer.max_steps=50000 \ loader.global_batch_size=128 \ training.ema=0.999 \ algo.update_teacher_every=10000 \ optim.lr=6e-5 \ trainer.limit_val_batches=8 \ algo.teacher_ema=False \ algo.linear_growth_dt=false \ +wandb.offline=true ================================================ FILE: scripts/eval_lm1b_duo.sh ================================================ #!/bin/bash #SBATCH -J eval_mdlm # Job name #SBATCH -o watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=100000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=gpu # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --ntasks-per-node=4 #SBATCH --gres=gpu:4 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon preemption checkpoint_path=/share/kuleshov/ssahoo/flow-ode/flow-ode-W2ZcFy-small-conjugate-lm1b-wrap-1M-gmin-350-gmax-175-3/checkpoints/7-100000.ckpt export HYDRA_FULL_ERROR=1 srun python -u -m main \ mode=ppl_eval \ loader.batch_size=64 \ loader.eval_batch_size=64 \ data=lm1b-wrap \ model=small \ model.length=128 \ algo=duo_base \ eval.checkpoint_path=$checkpoint_path \ sampling.num_sample_batches=0 \ +wandb.offline=true ================================================ FILE: scripts/eval_owt_ar.sh ================================================ #!/bin/bash #SBATCH -J eval_ar # Job name #SBATCH -o watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=100000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --ntasks-per-node=4 #SBATCH --gres=gpu:4 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon preemption checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-AgBZrc-small-ar-param-ar_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt export HYDRA_FULL_ERROR=1 srun python -u -m main \ mode=ppl_eval \ loader.batch_size=16 \ loader.eval_batch_size=16 \ data=openwebtext-split \ algo=ar \ model.length=1024 \ eval.checkpoint_path=$checkpoint_path \ +wandb.offline=true ================================================ FILE: scripts/eval_owt_duo.sh ================================================ #!/bin/bash #SBATCH -J owt_duo_anneal # Job name #SBATCH -o 
watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=100000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:1 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon preemption checkpoint_path=/share/kuleshov/ssahoo/flow-ode/flow-ode-VlCQLK-small-conjugate-OWT-anneal/checkpoints/last.ckpt export HYDRA_FULL_ERROR=1 srun python -u -m main \ mode=ppl_eval \ loader.batch_size=8 \ loader.eval_batch_size=8 \ data=openwebtext-split \ model=small \ algo=duo_base \ eval.checkpoint_path=$checkpoint_path \ sampling.num_sample_batches=0 \ +wandb.offline=true ================================================ FILE: scripts/eval_owt_mdlm.sh ================================================ #!/bin/bash #SBATCH -J eval_mdlm # Job name #SBATCH -o watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=100000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=gpu # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --ntasks-per-node=4 #SBATCH --gres=gpu:4 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon preemption checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/mdlm.ckpt export HYDRA_FULL_ERROR=1 srun python -u -m main \ mode=ppl_eval \ loader.batch_size=16 \ loader.eval_batch_size=16 \ data=openwebtext-split \ model=small \ algo=mdlm \ eval.checkpoint_path=$checkpoint_path \ sampling.num_sample_batches=0 \ +wandb.offline=true ================================================ FILE: scripts/eval_owt_sedd.sh ================================================ #!/bin/bash #SBATCH -J eval_sedd # Job name #SBATCH -o watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=100000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --ntasks-per-node=4 #SBATCH --gres=gpu:4 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon preemption checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-nBm2gE-small-param-sedd_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt export HYDRA_FULL_ERROR=1 srun python -u -m main \ mode=ppl_eval \ loader.batch_size=16 \ loader.eval_batch_size=16 \ data=openwebtext-split \ model=small \ algo=sedd \ eval.checkpoint_path=$checkpoint_path \ sampling.num_sample_batches=0 \ +wandb.offline=true ================================================ FILE: scripts/fid_cifar10_duo_ancestral_cosine.sh ================================================ export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 python -u -m main \ mode=fid_eval \ sampling.steps=64 \ sampling.guid_weight=1.0 \ data=cifar10 \ data.cache_dir= \ model=unet \ noise=cosine \ algo=duo_base \ algo.backbone=unet \ trainer.num_nodes=1 \ loader.eval_batch_size=500 \ eval.checkpoint_path= 
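# Note: data.cache_dir and eval.checkpoint_path above are intentionally left
# blank; fill them in before running. TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 lets
# torch.load restore full checkpoint objects on PyTorch versions where
# weights_only=True is the default.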
================================================ FILE: scripts/fid_cifar10_duo_base_ancestral_cosine.sh ================================================ export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 python -u -m main \ mode=fid_eval \ sampling.steps=64 \ sampling.guid_weight=1.0 \ data=cifar10 \ data.cache_dir= \ model=unet \ noise=cosine \ algo=duo_base \ algo.backbone=unet \ trainer.num_nodes=1 \ loader.eval_batch_size=500 \ eval.checkpoint_path= ================================================ FILE: scripts/fid_cifar10_mdlm_ancestral_cosine.sh ================================================ python -u -m main \ mode=fid_eval \ sampling.steps=64 \ sampling.guid_weight=1.0 \ sampling.predictor=ancestral_cache \ data=cifar10 \ data.cache_dir= \ model=unet \ noise=cosine \ algo=mdlm \ algo.backbone=unet \ loader.eval_batch_size=500 \ eval.checkpoint_path= ================================================ FILE: scripts/gen_ppl_lm1b_ar.sh ================================================ #!/bin/bash #SBATCH -J sample_ar # Job name #SBATCH -o watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=32000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=gpu # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:1 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon preemption checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/lm1b-ar/last_copy.ckpt export HYDRA_FULL_ERROR=1 srun python -u -m main \ mode=sample_eval \ loader.batch_size=2 \ loader.eval_batch_size=64 \ data=lm1b-wrap \ algo=ar \ model=small \ model.length=128 \ eval.checkpoint_path=$checkpoint_path \ sampling.num_sample_batches=15 \ +wandb.offline=true ================================================ FILE: scripts/gen_ppl_lm1b_duo.sh ================================================ #!/bin/bash #SBATCH -J sample_ar # Job name #SBATCH -o watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=32000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=gpu # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:1 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon preemption checkpoint_path=/share/kuleshov/ssahoo/flow-ode/6eTwW0-distil-kl-bwd-32/checkpoints/0-1000.ckpt steps=32 export HYDRA_FULL_ERROR=1 srun python -u -m main \ mode=sample_eval \ loader.batch_size=2 \ loader.eval_batch_size=64 \ data=lm1b-wrap \ algo=duo_base \ model=small \ model.length=128 \ eval.checkpoint_path=/share/kuleshov/ssahoo/flow-ode/6eTwW0-distil7-kl-bwd/checkpoints/last.ckpt \ sampling.num_sample_batches=15 \ sampling.steps=$steps \ +wandb.offline=true \ sampling.noise_removal=greedy ================================================ FILE: scripts/gen_ppl_owt_ar.sh ================================================ #!/bin/bash #SBATCH -J sample_ar # Job name #SBATCH -o watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=16000 # server memory requested (per node) #SBATCH -t 24:00:00 # Time 
limit (hh:mm:ss)
#SBATCH --partition=kuleshov,gpu  # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1  # Type/number of GPUs needed
#SBATCH --open-mode=append  # Do not overwrite logs
#SBATCH --requeue  # Requeue upon preemption

checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-AgBZrc-small-ar-param-ar_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints
seed=1

export HYDRA_FULL_ERROR=1
srun python -u -m main \
  mode=sample_eval \
  seed=$seed \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=ar \
  model=small \
  eval.checkpoint_path=$checkpoint_path/last.ckpt \
  sampling.num_sample_batches=100 \
  +wandb.offline=true \
  eval.generated_samples_path=$checkpoint_path/$seed-ckpt-last.json


================================================
FILE: scripts/gen_ppl_owt_duo.sh
================================================
#!/bin/bash
#SBATCH -J an_owt_duo  # Job name
#SBATCH -o watch_folder/%x_%j.out  # log file (out & err)
#SBATCH -N 1  # Total number of nodes requested
#SBATCH --get-user-env  # retrieve the users login environment
#SBATCH --mem=16000  # server memory requested (per node)
#SBATCH -t 24:00:00  # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov,gpu  # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1  # Type/number of GPUs needed
#SBATCH --open-mode=append  # Do not overwrite logs
#SBATCH --requeue  # Requeue upon preemption

export HYDRA_FULL_ERROR=1

while [[ "$#" -gt 0 ]]; do
  case $1 in
    --steps) steps="$2"; shift ;;
    --seed) seed="$2"; shift ;;
    --ckpt) ckpt="$2"; shift ;;
    *) echo "Unknown parameter: $1"; exit 1 ;;
  esac
  shift
done

checkpoint_path=/share/kuleshov/ssahoo/flow-ode/distil-distil-vjrpZb-distillation-OWT/checkpoints
# Keep any --ckpt passed on the command line; otherwise use the default.
ckpt=${ckpt:-0-50000}
steps=${steps:-32}
seed=${seed:-1}
echo " Steps: $steps"
echo " Seed: $seed"
echo " ckpt: $ckpt"

srun python -u -m main \
  mode=sample_eval \
  seed=$seed \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=duo_base \
  model=small \
  eval.checkpoint_path=$checkpoint_path/$ckpt.ckpt \
  sampling.num_sample_batches=100 \
  sampling.steps=$steps \
  +wandb.offline=true \
  eval.generated_samples_path=$checkpoint_path/samples_ancestral/$seed-$steps-ckpt-$ckpt.json


================================================
FILE: scripts/gen_ppl_owt_mdlm.sh
================================================
#!/bin/bash
#SBATCH -J sample_mdlm  # Job name
#SBATCH -o watch_folder/%x_%j.out  # log file (out & err)
#SBATCH -N 1  # Total number of nodes requested
#SBATCH --get-user-env  # retrieve the users login environment
#SBATCH --mem=64000  # server memory requested (per node)
#SBATCH -t 960:00:00  # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov  # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1  # Type/number of GPUs needed
#SBATCH --open-mode=append  # Do not overwrite logs
#SBATCH --requeue  # Requeue upon preemption

checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/mdlm.ckpt
export HYDRA_FULL_ERROR=1
srun python -u -m main \
  mode=sample_eval \
  loader.batch_size=16 \
  loader.eval_batch_size=16 \
  data=openwebtext-split \
  model=small \
  algo=mdlm \
  eval.checkpoint_path=$checkpoint_path \
  sampling.num_sample_batches=4 \
  +wandb.offline=true


================================================
FILE: scripts/gen_ppl_owt_sedd.sh
================================================
#!/bin/bash
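# Usage sketch: both flags below are optional (defaults: steps=32, seed=1):
#   sbatch scripts/gen_ppl_owt_sedd.sh --steps 64 --seed 2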
#SBATCH -J sedd_samples  # Job name
#SBATCH -o watch_folder/%x_%j.out  # log file (out & err)
#SBATCH -N 1  # Total number of nodes requested
#SBATCH --get-user-env  # retrieve the users login environment
#SBATCH --mem=16000  # server memory requested (per node)
#SBATCH -t 24:00:00  # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov,gpu  # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1  # Type/number of GPUs needed
#SBATCH --open-mode=append  # Do not overwrite logs
#SBATCH --requeue  # Requeue upon preemption

while [[ "$#" -gt 0 ]]; do
  case $1 in
    --steps) steps="$2"; shift ;;
    --seed) seed="$2"; shift ;;
    *) echo "Unknown parameter: $1"; exit 1 ;;
  esac
  shift
done
steps=${steps:-32}
seed=${seed:-1}
echo " Steps: $steps"
echo " Seed: $seed"

export HYDRA_FULL_ERROR=1
checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-nBm2gE-small-param-sedd_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints
ckpt=last
srun python -u -m main \
  mode=sample_eval \
  seed=$seed \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=sedd \
  model=small \
  eval.checkpoint_path=$checkpoint_path/$ckpt.ckpt \
  sampling.num_sample_batches=100 \
  sampling.steps=$steps \
  sampling.predictor=analytic \
  eval.generated_samples_path=$checkpoint_path/$seed-$steps-ckpt-$ckpt.json \
  +wandb.offline=true


================================================
FILE: scripts/psi_samplers/cifar10/duo_constant_remdm.sh
================================================
# DUO psi-sampler with constant-remdm-eta mode (ReMDM loop)
NUM_STEPS=256
ETA=0.01
NOISE=cosine
CHECKPOINT_PATH=
DATA_CACHE_DIR=
EVAL_BATCH_SIZE=500

python -u -m main \
  mode=fid_eval \
  data=cifar10 \
  data.cache_dir=$DATA_CACHE_DIR \
  model=unet \
  algo=duo_base \
  algo.backbone=unet \
  noise=$NOISE \
  sampling.predictor=psi \
  sampling.steps=$NUM_STEPS \
  sampling.guid_weight=1.0 \
  eval.checkpoint_path=$CHECKPOINT_PATH \
  loader.eval_batch_size=$EVAL_BATCH_SIZE \
  sampling.psi.time_profile=linear-constant-linear-0.9-inv \
  sampling.psi.high_mode=pure-posterior \
  sampling.psi.middle_mode=constant-remdm-$ETA \
  sampling.psi.low_mode=pure-posterior \
  sampling.psi.high_frac=0.45 \
  sampling.psi.middle_frac=0.5


================================================
FILE: scripts/psi_samplers/cifar10/duo_max_capped_remdm.sh
================================================
# DUO psi-sampler with max-capped-eta mode (ReMDM cap)
NUM_STEPS=256
ETA=0.005
NOISE=cosine
CHECKPOINT_PATH=
DATA_CACHE_DIR=
EVAL_BATCH_SIZE=500

python -u -m main \
  mode=fid_eval \
  data=cifar10 \
  data.cache_dir=$DATA_CACHE_DIR \
  model=unet \
  algo=duo_base \
algo.backbone=unet \ noise=$NOISE \ sampling.predictor=psi \ sampling.steps=$NUM_STEPS \ sampling.guid_weight=1.0 \ eval.checkpoint_path=$CHECKPOINT_PATH \ loader.eval_batch_size=$EVAL_BATCH_SIZE \ sampling.psi.time_profile=linear \ sampling.psi.high_mode=max-rescale-$ETA \ sampling.psi.middle_mode=max-rescale-$ETA \ sampling.psi.low_mode=max-rescale-$ETA \ sampling.psi.high_frac=0.0 \ sampling.psi.middle_frac=0.0 ================================================ FILE: scripts/psi_samplers/cifar10/duo_psi_pc.sh ================================================ # DUO psi-sampler with constant kappa in pc phase # Kappa controls the posterior/PC mix: 1 = pure posterior, 0 = pure PC # t in [1.0, 0.5] -> pure-posterior # t in [0.5, 0.1] -> constant-kappa # t in [0.1, 0.0] -> pure-posterior NUM_STEPS=256 KAPPA=0.95 NOISE=cosine HIGH_FRAC=0.5 MIDDLE_FRAC=0.4 CHECKPOINT_PATH= DATA_CACHE_DIR= EVAL_BATCH_SIZE=500 python -u -m main \ mode=fid_eval \ data=cifar10 \ data.cache_dir=$DATA_CACHE_DIR \ model=unet \ algo=duo_base \ algo.backbone=unet \ noise=$NOISE \ sampling.predictor=psi \ sampling.steps=$NUM_STEPS \ sampling.guid_weight=1.0 \ eval.checkpoint_path=$CHECKPOINT_PATH \ loader.eval_batch_size=$EVAL_BATCH_SIZE \ sampling.psi.time_profile=linear \ sampling.psi.high_mode=pure-posterior \ sampling.psi.middle_mode=constant-$KAPPA \ sampling.psi.low_mode=pure-posterior \ sampling.psi.high_frac=$HIGH_FRAC \ sampling.psi.middle_frac=$MIDDLE_FRAC ================================================ FILE: scripts/psi_samplers/cifar10/mdlm_constant_remdm.sh ================================================ # MDLM psi-sampler with constant-remdm-eta mode (ReMDM loop) — CIFAR-10 FID eval NUM_STEPS=256 ETA=0.01 NOISE=cosine CHECKPOINT_PATH= DATA_CACHE_DIR= EVAL_BATCH_SIZE=500 python -u -m main \ mode=fid_eval \ data=cifar10 \ data.cache_dir=$DATA_CACHE_DIR \ model=unet \ algo=mdlm \ algo.backbone=unet \ noise=$NOISE \ sampling.predictor=psi \ sampling.steps=$NUM_STEPS \ sampling.guid_weight=1.0 \ eval.checkpoint_path=$CHECKPOINT_PATH \ loader.eval_batch_size=$EVAL_BATCH_SIZE \ sampling.psi.time_profile=linear-constant-linear-0.9-inv \ sampling.psi.high_mode=pure-posterior \ sampling.psi.middle_mode=constant-remdm-$ETA \ sampling.psi.low_mode=pure-posterior \ sampling.psi.high_frac=0.45 \ sampling.psi.middle_frac=0.5 ================================================ FILE: scripts/psi_samplers/cifar10/mdlm_max_capped_remdm.sh ================================================ # MDLM psi-sampler with max-capped-eta mode (ReMDM cap) NUM_STEPS=256 ETA=0.005 NOISE=cosine CHECKPOINT_PATH= DATA_CACHE_DIR= EVAL_BATCH_SIZE=500 python -u -m main \ mode=fid_eval \ data=cifar10 \ data.cache_dir=$DATA_CACHE_DIR \ model=unet \ algo=mdlm \ algo.backbone=unet \ noise=$NOISE \ sampling.predictor=psi \ sampling.steps=$NUM_STEPS \ sampling.guid_weight=1.0 \ eval.checkpoint_path=$CHECKPOINT_PATH \ loader.eval_batch_size=$EVAL_BATCH_SIZE \ sampling.psi.time_profile=linear \ sampling.psi.high_mode=max-capped-$ETA \ sampling.psi.middle_mode=max-capped-$ETA \ sampling.psi.low_mode=max-capped-$ETA \ sampling.psi.high_frac=0.0 \ sampling.psi.middle_frac=0.0 ================================================ FILE: scripts/psi_samplers/cifar10/mdlm_max_rescale_eta.sh ================================================ # MDLM psi-sampler with max-rescale-eta mode NUM_STEPS=256 ETA=0.01 NOISE=cosine CHECKPOINT_PATH= DATA_CACHE_DIR= EVAL_BATCH_SIZE=500 python -u -m main \ mode=fid_eval \ data=cifar10 \ data.cache_dir=$DATA_CACHE_DIR \ model=unet \ 
algo=mdlm \ algo.backbone=unet \ noise=$NOISE \ sampling.predictor=psi \ sampling.steps=$NUM_STEPS \ sampling.guid_weight=1.0 \ eval.checkpoint_path=$CHECKPOINT_PATH \ loader.eval_batch_size=$EVAL_BATCH_SIZE \ sampling.psi.time_profile=linear \ sampling.psi.high_mode=max-rescale-$ETA \ sampling.psi.middle_mode=max-rescale-$ETA \ sampling.psi.low_mode=max-rescale-$ETA \ sampling.psi.high_frac=0.0 \ sampling.psi.middle_frac=0.0 ================================================ FILE: scripts/psi_samplers/cifar10/mdlm_psi_pc.sh ================================================ # MDLM psi-sampler with constant kappa during pc phase # Kappa controls the posterior/PC mix: 1 = pure posterior, 0 = pure PC # t in [1.0, 0.1] -> constant kappa # t in [0.1, 0.0] -> pure posterior NUM_STEPS=256 KAPPA=0.99 NOISE=cosine HIGH_FRAC=0.0 MIDDLE_FRAC=0.9 CHECKPOINT_PATH= DATA_CACHE_DIR= EVAL_BATCH_SIZE=500 python -u -m main \ mode=fid_eval \ data=cifar10 \ data.cache_dir=$DATA_CACHE_DIR \ model=unet \ algo=mdlm \ algo.backbone=unet \ noise=$NOISE \ sampling.predictor=psi \ sampling.steps=$NUM_STEPS \ sampling.guid_weight=1.0 \ eval.checkpoint_path=$CHECKPOINT_PATH \ loader.eval_batch_size=$EVAL_BATCH_SIZE \ sampling.psi.time_profile=linear \ sampling.psi.high_mode=pure-posterior \ sampling.psi.middle_mode=constant-$KAPPA \ sampling.psi.low_mode=pure-posterior \ sampling.psi.high_frac=$HIGH_FRAC \ sampling.psi.middle_frac=$MIDDLE_FRAC ================================================ FILE: scripts/psi_samplers/owt/duo_loop_remdm.sh ================================================ # DUO psi-sampler with constant-remdm-eta mode (ReMDM loop) NUM_STEPS=256 ETA=0.01 NUCLEUS_P=0.95 NOISE=log-linear CHECKPOINT_PATH=??? DATA_CACHE_DIR=??? EVAL_BATCH_SIZE=16 NUM_SAMPLE_BATCHES=32 python -u -m main \ mode=sample_eval \ data=openwebtext-split \ data.cache_dir=$DATA_CACHE_DIR \ model=small \ algo=duo_base \ noise=$NOISE \ sampling.predictor=psi \ sampling.steps=$NUM_STEPS \ sampling.p_nucleus=$NUCLEUS_P \ sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \ eval.checkpoint_path=$CHECKPOINT_PATH \ loader.eval_batch_size=$EVAL_BATCH_SIZE \ sampling.psi.time_profile=linear-constant-linear-0.9-inv \ sampling.psi.high_mode=pure-posterior \ sampling.psi.middle_mode=constant-remdm-$ETA \ sampling.psi.low_mode=pure-posterior \ sampling.psi.high_frac=0.45 \ sampling.psi.middle_frac=0.5 ================================================ FILE: scripts/psi_samplers/owt/duo_max_capped_remdm.sh ================================================ # DUO psi-sampler with max-capped-eta mode (ReMDM cap) NUM_STEPS=256 ETA=0.01 NUCLEUS_P=0.9 NOISE=log-linear CHECKPOINT_PATH=??? DATA_CACHE_DIR=??? 
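For reference, a minimal Python sketch of how a max-capped-$ETA mode turns the cap into per-step kappas, mirroring Diffusion._mode_to_psi_kappas in trainer_base.py (the log-linear schedule alpha(t) = 1 - t is assumed here, with its eps term dropped for brevity):

import torch

def max_capped_kappas(timesteps, eta):
  # timesteps decrease from 1 toward 0, so alpha increases toward 1
  alphas = 1 - timesteps
  alpha_t, alpha_s = alphas[:-1], alphas[1:]  # noisier step t, cleaner step s
  # remasking budget: eta, capped by what the posterior can absorb
  sigma = torch.minimum(torch.full_like(alpha_t, eta),
                        (1 - alpha_s) / alpha_t)
  sigma = torch.where(alpha_t == 0, torch.full_like(sigma, eta), sigma)
  # kappa = 1 recovers the pure ancestral posterior; smaller kappa mixes
  # in more of the predictor-corrector (remasking) kernel
  return torch.clip(1 - sigma / (1 - alpha_s), 0, 1)

kappas = max_capped_kappas(torch.linspace(1.0, 1e-5, 257), eta=0.01)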
EVAL_BATCH_SIZE=16 NUM_SAMPLE_BATCHES=32 python -u -m main \ mode=sample_eval \ data=openwebtext-split \ data.cache_dir=$DATA_CACHE_DIR \ model=small \ algo=duo_base \ noise=$NOISE \ sampling.predictor=psi \ sampling.steps=$NUM_STEPS \ sampling.p_nucleus=$NUCLEUS_P \ sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \ eval.checkpoint_path=$CHECKPOINT_PATH \ loader.eval_batch_size=$EVAL_BATCH_SIZE \ sampling.psi.time_profile=linear \ sampling.psi.high_mode=max-capped-$ETA \ sampling.psi.middle_mode=max-capped-$ETA \ sampling.psi.low_mode=max-capped-$ETA \ sampling.psi.high_frac=0.0 \ sampling.psi.middle_frac=0.0 ================================================ FILE: scripts/psi_samplers/owt/duo_max_rescale_eta.sh ================================================ # DUO psi-sampler with max-rescale-eta mode (ReMDM rescale) NUM_STEPS=256 ETA=0.05 NUCLEUS_P=0.9 NOISE=log-linear CHECKPOINT_PATH=??? DATA_CACHE_DIR=??? EVAL_BATCH_SIZE=16 NUM_SAMPLE_BATCHES=32 python -u -m main \ mode=sample_eval \ data=openwebtext-split \ data.cache_dir=$DATA_CACHE_DIR \ model=small \ algo=duo_base \ noise=$NOISE \ sampling.predictor=psi \ sampling.steps=$NUM_STEPS \ sampling.p_nucleus=$NUCLEUS_P \ sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \ eval.checkpoint_path=$CHECKPOINT_PATH \ loader.eval_batch_size=$EVAL_BATCH_SIZE \ sampling.psi.time_profile=linear \ sampling.psi.high_mode=max-rescale-$ETA \ sampling.psi.middle_mode=max-rescale-$ETA \ sampling.psi.low_mode=max-rescale-$ETA \ sampling.psi.high_frac=0.0 \ sampling.psi.middle_frac=0.0 ================================================ FILE: scripts/psi_samplers/owt/mdlm_loop_remdm.sh ================================================ # MDLM psi-sampler with constant-remdm-eta mode (ReMDM loop) NUM_STEPS=256 ETA=0.01 NUCLEUS_P=0.95 NOISE=log-linear CHECKPOINT_PATH=??? DATA_CACHE_DIR=??? EVAL_BATCH_SIZE=16 NUM_SAMPLE_BATCHES=32 python -u -m main \ mode=sample_eval \ data=openwebtext-split \ data.cache_dir=$DATA_CACHE_DIR \ model=small \ algo=mdlm \ noise=$NOISE \ sampling.predictor=psi \ sampling.steps=$NUM_STEPS \ sampling.p_nucleus=$NUCLEUS_P \ sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \ eval.checkpoint_path=$CHECKPOINT_PATH \ loader.eval_batch_size=$EVAL_BATCH_SIZE \ sampling.psi.time_profile=linear-constant-linear-0.9-inv \ sampling.psi.high_mode=pure-posterior \ sampling.psi.middle_mode=constant-remdm-$ETA \ sampling.psi.low_mode=pure-posterior \ sampling.psi.high_frac=0.45 \ sampling.psi.middle_frac=0.5 ================================================ FILE: scripts/psi_samplers/owt/mdlm_max_capped_remdm.sh ================================================ # MDLM psi-sampler with max-capped-eta mode (ReMDM cap) NUM_STEPS=256 ETA=0.01 NUCLEUS_P=0.9 NOISE=log-linear CHECKPOINT_PATH=??? DATA_CACHE_DIR=??? 
EVAL_BATCH_SIZE=16 NUM_SAMPLE_BATCHES=32 python -u -m main \ mode=sample_eval \ data=openwebtext-split \ data.cache_dir=$DATA_CACHE_DIR \ model=small \ algo=mdlm \ noise=$NOISE \ sampling.predictor=psi \ sampling.steps=$NUM_STEPS \ sampling.p_nucleus=$NUCLEUS_P \ sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \ eval.checkpoint_path=$CHECKPOINT_PATH \ loader.eval_batch_size=$EVAL_BATCH_SIZE \ sampling.psi.time_profile=linear \ sampling.psi.high_mode=max-capped-$ETA \ sampling.psi.middle_mode=max-capped-$ETA \ sampling.psi.low_mode=max-capped-$ETA \ sampling.psi.high_frac=0.0 \ sampling.psi.middle_frac=0.0 ================================================ FILE: scripts/psi_samplers/owt/mdlm_max_rescale_eta.sh ================================================ # MDLM psi-sampler with max-rescale-eta mode (ReMDM rescale) NUM_STEPS=256 ETA=0.05 NUCLEUS_P=0.9 NOISE=log-linear CHECKPOINT_PATH=??? DATA_CACHE_DIR=??? EVAL_BATCH_SIZE=16 NUM_SAMPLE_BATCHES=32 python -u -m main \ mode=sample_eval \ data=openwebtext-split \ data.cache_dir=$DATA_CACHE_DIR \ model=small \ algo=mdlm \ noise=$NOISE \ sampling.predictor=psi \ sampling.steps=$NUM_STEPS \ sampling.p_nucleus=$NUCLEUS_P \ sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \ eval.checkpoint_path=$CHECKPOINT_PATH \ loader.eval_batch_size=$EVAL_BATCH_SIZE \ sampling.psi.time_profile=linear \ sampling.psi.high_mode=max-rescale-$ETA \ sampling.psi.middle_mode=max-rescale-$ETA \ sampling.psi.low_mode=max-rescale-$ETA \ sampling.psi.high_frac=0.0 \ sampling.psi.middle_frac=0.0 ================================================ FILE: scripts/train_cifar10_duo_base_cosine.sh ================================================ python -u -m main \ data=cifar10 \ data.cache_dir= \ model=unet \ algo=duo_base \ algo.backbone=unet \ noise=cosine \ loader.global_batch_size=128 \ loader.batch_size=32 \ loader.eval_batch_size=32 \ loader.num_workers=8 \ trainer.val_check_interval=2500 \ trainer.max_steps=1_500_000 \ lr_scheduler.num_warmup_steps=5000 \ eval.generate_samples=False \ optim.lr=2e-4 \ callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \ wandb.name=duo_base_1_5M_d3pm_like_cosine \ hydra.run.dir=./outputs/cifar10/duo_base_1_5M_d3pm_like_cosine \ ================================================ FILE: scripts/train_cifar10_duo_cosine.sh ================================================ python -u -m main \ data=cifar10 \ data.cache_dir= \ model=unet \ algo=duo_base \ algo.backbone=unet \ noise=cosine \ loader.global_batch_size=128 \ loader.batch_size=32 \ loader.eval_batch_size=32 \ loader.num_workers=8 \ trainer.val_check_interval=2500 \ trainer.max_steps=1_500_000 \ lr_scheduler.num_warmup_steps=5000 \ eval.generate_samples=False \ optim.lr=2e-4 \ callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \ wandb.name=duo_base_1_5M_d3pm_like_cosine \ hydra.run.dir=./outputs/cifar10/duo_base_1_5M_d3pm_like_cosine \ ================================================ FILE: scripts/train_cifar10_mdlm_cosine.sh ================================================ python -u -m main \ data=cifar10 \ data.cache_dir= \ model=unet \ algo=mdlm \ algo.backbone=unet \ noise=cosine \ loader.global_batch_size=128 \ loader.batch_size=32 \ loader.eval_batch_size=32 \ loader.num_workers=8 \ trainer.val_check_interval=2500 \ trainer.max_steps=1_500_000 \ lr_scheduler.num_warmup_steps=5000 \ eval.generate_samples=False \ optim.lr=2e-4 \ callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \ wandb.name=mdlm_1_5M_d3pm_like_cosine \ 
hydra.run.dir=./outputs/cifar10/mdlm_1_5M_d3pm_like_cosine \ ================================================ FILE: scripts/train_lm1b_ar.sh ================================================ #!/bin/bash #SBATCH -J train_ar_lm1b # Job name #SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=32000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --constraint="gpu-mid|gpu-high" #SBATCH --ntasks-per-node=8 #SBATCH --gres=gpu:8 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon pre-emption # To enable preemption re-loading, set `hydra.run.dir` or # `checkpointing.save_dir` explicitly. srun python -u -m main \ loader.batch_size=64 \ loader.eval_batch_size=64 \ algo=ar \ data=lm1b \ wandb.name=ar-lm1b-small \ model=small \ model.length=128 ================================================ FILE: scripts/train_lm1b_ar_sentencepacking.sh ================================================ #!/bin/bash #SBATCH -J train_ar_lm1b # Job name #SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=32000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --constraint="gpu-mid|gpu-high" #SBATCH --ntasks-per-node=2 #SBATCH --gres=gpu:2 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon pre-emption # To enable preemption re-loading, set `hydra.run.dir` or # `checkpointing.save_dir` explicitly. srun python -u -m main \ loader.batch_size=256 \ loader.eval_batch_size=256 \ algo=ar \ data=lm1b-wrap \ wandb.name=ar-lm1b-wrap-small \ model=small \ model.length=128 ================================================ FILE: scripts/train_lm1b_d3pm.sh ================================================ #!/bin/bash #SBATCH -J train_d3pm # Job name #SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=32000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --constraint="gpu-mid|gpu-high" #SBATCH --ntasks-per-node=8 #SBATCH --gres=gpu:8 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon pre-emption # To enable preemption re-loading, set `hydra.run.dir` or # `checkpointing.save_dir` explicitly. 
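For orientation, with algo=d3pm the trainer runs in discrete time (algo.T > 0), and Diffusion.nll in trainer_base.py snaps each sampled timestep onto the grid {1/T, ..., 1}; a minimal sketch of that quantization (T = 1000 is an illustrative value):

import torch

T = 1000                       # algo.T (illustrative value)
t = torch.rand(4)              # continuous t ~ U(0, 1)
t = (t * T).to(torch.int) / T  # floor onto {0, 1/T, ..., (T-1)/T}
t = t + 1 / T                  # shift into {1/T, 2/T, ..., 1}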
srun python -u -m main \ loader.batch_size=64 \ loader.eval_batch_size=64 \ model=small \ data=lm1b \ wandb.name=d3pm-lm1b \ algo=d3pm \ model.length=128 \ eval.compute_generative_perplexity=False ================================================ FILE: scripts/train_lm1b_duo.sh ================================================ #!/bin/bash #SBATCH -J duo-lm1b # Job name #SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=64000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --constraint="gpu-mid|gpu-high" #SBATCH --ntasks-per-node=8 #SBATCH --gres=gpu:8 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon pre-emption # To enable preemption re-loading, set `hydra.run.dir` or # `checkpointing.save_dir` explicitly. # Note: we use algo.curriculum.mode=poly9 in the duo2 paper for speed srun python -u -m main \ loader.batch_size=64 \ loader.eval_batch_size=64 \ data=lm1b \ wandb.name=duo-lm1b \ model=small \ algo=duo \ model.length=128 \ algo.curriculum.mode=simple \ algo.curriculum.gumbel_tau_log10_start=-3.0 \ algo.curriculum.gumbel_tau_log10_end=-3.0 \ algo.curriculum.gamma_min=-3.5 \ algo.curriculum.gamma_max=-1.75 \ algo.curriculum.start=0 \ algo.curriculum.end=500000 ================================================ FILE: scripts/train_lm1b_duo_sentencepacking.sh ================================================ #!/bin/bash #SBATCH -J duo-lm1b # Job name #SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=64000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --constraint="gpu-mid|gpu-high" #SBATCH --ntasks-per-node=8 #SBATCH --gres=gpu:8 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon pre-emption # To enable preemption re-loading, set `hydra.run.dir` or # `checkpointing.save_dir` explicitly. 
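# The curriculum overrides below are interpreted by algo.py (not shown
# here): start/end bound the curriculum to steps [0, 500000], the
# gumbel_tau_log10 pair pins the Gumbel-softmax temperature at 10^-3
# throughout, and gamma_min/gamma_max bound the annealed gamma range
# used by mode=simple.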
# Note: we use algo.curriculum.mode=poly9 in the duo2 paper for speed srun python -u -m main \ loader.batch_size=64 \ loader.eval_batch_size=64 \ data=lm1b-wrap \ wandb.name=duo-lm1b \ model=small \ algo=duo \ model.length=128 \ algo.curriculum.mode=simple \ algo.curriculum.gumbel_tau_log10_start=-3.0 \ algo.curriculum.gumbel_tau_log10_end=-3.0 \ algo.curriculum.gamma_min=-3.5 \ algo.curriculum.gamma_max=-1.75 \ algo.curriculum.start=0 \ algo.curriculum.end=500000 ================================================ FILE: scripts/train_lm1b_mdlm.sh ================================================ #!/bin/bash #SBATCH -J lm1b_mdlm # Job name #SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=32000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --constraint="gpu-mid|gpu-high" #SBATCH --ntasks-per-node=2 #SBATCH --gres=gpu:2 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon pre-emption # To enable preemption re-loading, set `hydra.run.dir` or # `checkpointing.save_dir` explicitly. srun python -u -m main \ loader.batch_size=64 \ loader.eval_batch_size=64 \ data=lm1b \ wandb.name=mdlm-lm1b \ model=small \ algo=mdlm \ model.length=128 ================================================ FILE: scripts/train_lm1b_mdlm_sentencepacking.sh ================================================ #!/bin/bash #SBATCH -J lm1b_mdlm # Job name #SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=32000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --constraint="gpu-mid|gpu-high" #SBATCH --ntasks-per-node=2 #SBATCH --gres=gpu:2 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon pre-emption # To enable preemption re-loading, set `hydra.run.dir` or # `checkpointing.save_dir` explicitly. srun python -u -m main \ loader.batch_size=64 \ loader.eval_batch_size=64 \ data=lm1b-wrap \ wandb.name=mdlm-lm1b-wrap-small \ model=small \ algo=mdlm \ model.length=128 ================================================ FILE: scripts/train_owt_duo.sh ================================================ #!/bin/bash #SBATCH -J duo-lm1b # Job name #SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=64000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --constraint="gpu-mid|gpu-high" #SBATCH --ntasks-per-node=8 #SBATCH --gres=gpu:8 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon pre-emption # To enable preemption re-loading, set `hydra.run.dir` or # `checkpointing.save_dir` explicitly. 
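The preemption re-loading mentioned above is handled in TrainerBase.on_train_start (trainer_base.py), which swaps in a fault-tolerant sampler from dataloader.py and rewinds it; a minimal sketch of the rewind performed on requeue:

# Values recovered from the Lightning checkpoint in on_load_checkpoint
# (illustrative numbers).
fast_forward_epochs = 0       # completed epochs
fast_forward_batches = 1200   # completed batches in the current epoch
batch_size = 32               # config.loader.batch_size

# The sampler is rewound by sample count, not batch count.
sampler_state = {'epoch': fast_forward_epochs,
                 'counter': fast_forward_batches * batch_size}
# dl_sampler.load_state_dict(sampler_state)  # a FaultTolerantDistributedSampler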
# Note: we use algo.curriculum.mode=poly9 in the duo2 paper for speed
srun python -u -m main \
  loader.batch_size=32 \
  loader.eval_batch_size=32 \
  data=openwebtext-split \
  wandb.name=duo-owt \
  model=small \
  algo=duo \
  model.length=1024 \
  algo.curriculum.mode=simple \
  algo.curriculum.gumbel_tau_log10_start=-3.0 \
  algo.curriculum.gumbel_tau_log10_end=-3.0 \
  algo.curriculum.gamma_min=-3.55 \
  algo.curriculum.gamma_max=-1.85 \
  algo.curriculum.start=0 \
  algo.curriculum.end=500000


================================================
FILE: scripts/train_owt_duo_finetune.sh
================================================
#!/bin/bash
#SBATCH -J duo-base                   # Job name
#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the users login environment
#SBATCH --mem=64000                   # server memory requested (per node)
#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov          # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon pre-emption

# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.

finetune_path=/path/to/intermediate_duo_500k.ckpt

# Assuming the finetune_path corresponds to the DUO model
# trained for 500K steps with curriculum learning, we train the
# model for 500K more steps.
srun python -u -m main \
  loader.batch_size=64 \
  loader.eval_batch_size=64 \
  data=openwebtext-split \
  wandb.name=duo-owt-finetune \
  model=small \
  algo=duo_base \
  model.length=1024 \
  training.finetune_path=$finetune_path \
  sampling.num_sample_batches=0 \
  trainer.max_steps=500000


================================================
FILE: scripts/train_owt_mdlm.sh
================================================
#!/bin/bash
#SBATCH -J train_mdlm                 # Job name
#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)
#SBATCH -N 1                          # Total number of nodes requested
#SBATCH --get-user-env                # retrieve the users login environment
#SBATCH --mem=32000                   # server memory requested (per node)
#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)
#SBATCH --partition=kuleshov          # Request partition
#SBATCH --constraint="[a5000|a6000|a100|3090]"
#SBATCH --constraint="gpu-mid|gpu-high"
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1                  # Type/number of GPUs needed
#SBATCH --open-mode=append            # Do not overwrite logs
#SBATCH --requeue                     # Requeue upon pre-emption

# To enable preemption re-loading, set `hydra.run.dir` or
# `checkpointing.save_dir` explicitly.
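As a rough sketch of what the training.finetune_path override in train_owt_duo_finetune.sh above amounts to (the actual loading logic lives elsewhere in the repo, so the checkpoint key layout here is an assumption): the backbone is initialized from the 500K-step curriculum checkpoint before training continues.

import torch

# Hypothetical illustration: extract backbone weights from a Lightning
# checkpoint; the 'backbone.' prefix matches TrainerBase's attribute name.
state = torch.load('/path/to/intermediate_duo_500k.ckpt',
                   map_location='cpu')
backbone_state = {k.removeprefix('backbone.'): v
                  for k, v in state['state_dict'].items()
                  if k.startswith('backbone.')}
# model.backbone.load_state_dict(backbone_state)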
srun python -u -m main \ loader.batch_size=2 \ loader.eval_batch_size=2 \ model=small \ data=openwebtext-split \ wandb.name=mdlm-owt \ algo=mdlm \ model.length=1024 \ eval.compute_generative_perplexity=False \ +wandb.offline=True ================================================ FILE: scripts/train_owt_sedd.sh ================================================ #!/bin/bash #SBATCH -J train_sedd # Job name #SBATCH -o watch_folder/%x_%j.out # output file (%j expands to jobID) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=32000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --constraint="gpu-mid|gpu-high" #SBATCH --ntasks-per-node=4 #SBATCH --gres=gpu:4 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon pre-emption # To enable preemption re-loading, set `hydra.run.dir` or # `checkpointing.save_dir` explicitly. srun python -u -m main \ loader.batch_size=16 \ loader.eval_batch_size=16 \ model=small \ data=openwebtext-split \ wandb.name=sedd-owt \ algo=sedd \ model.length=1024 \ eval.compute_generative_perplexity=True \ sampling.predictor=analytic ================================================ FILE: scripts/zero_shot_ar.sh ================================================ #!/bin/bash #SBATCH -J zeroshot_ar # Job name #SBATCH -o watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=32000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=gpu # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:1 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon preemption checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-AgBZrc-small-ar-param-ar_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt export HYDRA_FULL_ERROR=1 datasets=("ag_news" "scientific_papers_pubmed" "scientific_papers_arxiv" "lambada" "wikitext2" "wikitext103" "ptb" "lm1b-gpt2") for data in "${datasets[@]}"; do echo "$data" srun python -u -m main \ mode=ppl_eval \ loader.batch_size=16 \ loader.eval_batch_size=16 \ data="$data" \ data.insert_valid_eos=False \ algo=ar \ model.length=1024 \ eval.checkpoint_path=$checkpoint_path \ +wandb.offline=true done ================================================ FILE: scripts/zero_shot_duo.sh ================================================ #!/bin/bash #SBATCH -J zeroshot_duo # Job name #SBATCH -o watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=32000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100]" #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:1 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon preemption checkpoint_path=/share/kuleshov/ssahoo/flow-ode/flow-ode-VlCQLK-small-conjugate-OWT-anneal/checkpoints/last.ckpt export HYDRA_FULL_ERROR=1 datasets=("ag_news" "scientific_papers_pubmed" 
"scientific_papers_arxiv" "lambada" "wikitext2" "wikitext103" "ptb" "lm1b-gpt2") for data in "${datasets[@]}"; do echo "$data" srun python -u -m main \ mode=ppl_eval \ loader.batch_size=16 \ loader.eval_batch_size=16 \ loader.eval_global_batch_size=128 \ data="$data" \ data.insert_valid_eos=False \ model=small \ algo=duo_base \ model.length=1024 \ eval.checkpoint_path=$checkpoint_path \ sampling.num_sample_batches=0 \ +wandb.offline=true done ================================================ FILE: scripts/zero_shot_mdlm.sh ================================================ #!/bin/bash #SBATCH -J zeroshot_mdlm_noeos # Job name #SBATCH -o watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=32000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:1 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon preemption checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diff-clean-s-owt-no-t-mQ4fQG-param-subs_data-openwebtext-split/checkpoints/61-1000000.ckpt export HYDRA_FULL_ERROR=1 datasets=("ag_news" "scientific_papers_pubmed" "scientific_papers_arxiv" "lambada" "wikitext2" "wikitext103" "ptb" "lm1b-gpt2") for data in "${datasets[@]}"; do echo "$data" srun python -u -m main \ mode=ppl_eval \ loader.batch_size=16 \ loader.eval_batch_size=16 \ loader.eval_global_batch_size=128 \ data="$data" \ data.insert_valid_eos=False \ model=small \ algo=mdlm \ model.length=1024 \ eval.checkpoint_path=$checkpoint_path \ +wandb.offline=true done ================================================ FILE: scripts/zero_shot_sedd.sh ================================================ #!/bin/bash #SBATCH -J zeroshot_sedd # Job name #SBATCH -o watch_folder/%x_%j.out # log file (out & err) #SBATCH -N 1 # Total number of nodes requested #SBATCH --get-user-env # retrieve the users login environment #SBATCH --mem=32000 # server memory requested (per node) #SBATCH -t 960:00:00 # Time limit (hh:mm:ss) #SBATCH --partition=kuleshov # Request partition #SBATCH --constraint="[a5000|a6000|a100|3090]" #SBATCH --ntasks-per-node=1 #SBATCH --gres=gpu:1 # Type/number of GPUs needed #SBATCH --open-mode=append # Do not overwrite logs #SBATCH --requeue # Requeue upon preemption checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-nBm2gE-small-param-sedd_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt export HYDRA_FULL_ERROR=1 datasets=("ag_news" "scientific_papers_pubmed" "scientific_papers_arxiv" "lambada" "wikitext2" "wikitext103" "ptb" "lm1b-gpt2") for data in "${datasets[@]}"; do echo "$data" srun python -u -m main \ mode=ppl_eval \ loader.batch_size=16 \ loader.eval_batch_size=16 \ loader.eval_global_batch_size=128 \ data="$data" \ data.insert_valid_eos=False \ model=small \ algo=sedd \ model.length=1024 \ eval.checkpoint_path=$checkpoint_path \ +wandb.offline=true done ================================================ FILE: trainer_base.py ================================================ import itertools from dataclasses import dataclass import hydra.utils import lightning as L import numpy as np import torch import torch.nn.functional as F import transformers import dataloader import metrics import models import utils @dataclass class Loss: 
loss: torch.FloatTensor nlls: torch.FloatTensor prior_loss: torch.FloatTensor num_tokens: torch.FloatTensor class LogLinear(torch.nn.Module): def __init__(self, eps): super().__init__() self.eps = eps # 1e-3 by default, to be consistent with SEDD: https://github.com/louaaron/Score-Entropy-Discrete-Diffusion/blob/0605786da5ccb5747545e26d66fdf477187598b6/noise_lib.py#L56 def forward(self, t): t = (1 - self.eps) * t alpha_t = 1 - t dalpha_t = - torch.ones_like(alpha_t) * (1 - self.eps) return dalpha_t, alpha_t def get_t_for_alpha(self, alpha_t): return 1 - alpha_t class Cosine(torch.nn.Module): def __init__(self, eps): super().__init__() self.eps = eps self.half_pi = torch.pi / 2 def forward(self, t): t = (1 - self.eps) * t alpha_t = 1 - torch.cos(self.half_pi * (1 - t)) dalpha_t = - torch.sin(self.half_pi * (1 - t)) * self.half_pi return dalpha_t, alpha_t def get_t_for_alpha(self, alpha_t): is_tensor = torch.is_tensor(alpha_t) if not is_tensor: alpha_t = torch.tensor([alpha_t]) t = 1 - 2 / torch.pi * torch.acos(1 - alpha_t) if not is_tensor: t = t.cpu().item() return t def sample_categorical(categorical_probs): gumbel_norm = ( 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()) return (categorical_probs / gumbel_norm).argmax(dim=-1) def _unsqueeze(x, reference): return x.view( * x.shape, * ((1,) * (len(reference.shape) - len(x.shape)))) class TrainerBase(L.LightningModule): def __init__( self, config, tokenizer: transformers.PreTrainedTokenizer, vocab_size=None): super().__init__() self.save_hyperparameters() self.config = config if hasattr(self.config.algo, 'ignore_bos'): self.ignore_bos = config.algo.ignore_bos else: self.ignore_bos = False if hasattr(self.config.algo, 'loss_type'): self.loss_type = config.algo.loss_type self.tokenizer = tokenizer if vocab_size is None: self.vocab_size = len(self.tokenizer) else: self.vocab_size = vocab_size self.sampler = self.config.sampling.predictor self.antithetic_sampling = self.config.training.antithetic_sampling self.parameterization = self.config.algo.parameterization if self.config.algo.backbone == 'dit': self.backbone = models.dit.DIT( self.config, vocab_size=self.vocab_size) elif self.config.algo.backbone == 'unet': self.backbone = models.unet.UNet(self.config, self.vocab_size) elif self.config.algo.backbone == 'dimamba': self.backbone = models.dimamba.DiMamba( self.config, vocab_size=self.vocab_size, pad_token_id=self.tokenizer.pad_token_id) elif self.config.algo.backbone == 'hf_dit': self.backbone = transformers.AutoModelForMaskedLM.from_pretrained( config.eval.checkpoint_path, trust_remote_code=True) self.T = self.config.algo.T self.num_tokens = self.config.model.length self.p_nucleus = self.config.sampling.p_nucleus # Noise Schedule if config.noise.type == 'log-linear': self.noise = LogLinear(config.noise.eps) elif config.noise.type == 'cosine': self.noise = Cosine(config.noise.eps) else: raise ValueError(config.noise.type) # Class-conditional training arguments self.num_classes = config.data.get('num_classes', None) self.class_conditional = self.num_classes is not None self.class_cond_dropout = config.training.class_dropout_p self.metrics = metrics.Metrics( gen_ppl_eval_model_name_or_path=\ self.config.eval.gen_ppl_eval_model_name_or_path, eval_ppl_batch_size=\ self.config.eval.perplexity_batch_size) if self.config.training.ema > 0: self.ema = models.ema.ExponentialMovingAverage( self._get_parameters(), decay=self.config.training.ema) else: self.ema = None self.lr = self.config.optim.lr self.sampling_eps = 
self.config.training.sampling_eps self.time_conditioning = self.config.algo.time_conditioning self.neg_infinity = -1000000.0 self.fast_forward_epochs = None self.fast_forward_batches = None def _validate_configuration(self): assert self.config.algo.backbone in {'dit', 'hf_dit', 'unet'} if self.config.algo.parameterization == 'ar': assert not self.config.algo.time_conditioning assert self.config.prior.type == 'none' if self.time_conditioning: assert self.config.sampling.predictor != 'ancestral_cache' if self.parameterization in {'score', 'mean'}: assert self.time_conditioning if self.T > 0: assert self.parameterization != 'score' def to(self, *args, **kwargs): self = super().to(*args, **kwargs) self.metrics.to(*args, **kwargs) return self def q_xt(self, x, alpha_t): raise NotImplementedError def _get_parameters(self): return itertools.chain(self.backbone.parameters(), self.noise.parameters()) def _eval_mode(self): if self.ema: self.ema.store(self._get_parameters()) self.ema.copy_to(self._get_parameters()) self.backbone.eval() self.noise.eval() def _train_mode(self): if self.ema: self.ema.restore(self._get_parameters()) self.backbone.train() self.noise.train() def on_load_checkpoint(self, checkpoint): if self.ema: self.ema.load_state_dict(checkpoint['ema']) # Copied from: # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41 self.fast_forward_epochs = checkpoint['loops'][ 'fit_loop']['epoch_progress']['current']['completed'] self.fast_forward_batches = checkpoint['loops'][ 'fit_loop']['epoch_loop.batch_progress'][ 'current']['completed'] def on_save_checkpoint(self, checkpoint): if self.ema: checkpoint['ema'] = self.ema.state_dict() # Copied from: # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py # ['epoch_loop.batch_progress']['total']['completed'] # is 1 iteration behind, so we're using the optimizer's progress. checkpoint['loops']['fit_loop'][ 'epoch_loop.batch_progress']['total'][ 'completed'] = checkpoint['loops']['fit_loop'][ 'epoch_loop.automatic_optimization.optim_progress'][ 'optimizer']['step']['total'][ 'completed'] * self.trainer.accumulate_grad_batches checkpoint['loops']['fit_loop'][ 'epoch_loop.batch_progress']['current'][ 'completed'] = checkpoint['loops']['fit_loop'][ 'epoch_loop.automatic_optimization.optim_progress'][ 'optimizer']['step']['current'][ 'completed'] * self.trainer.accumulate_grad_batches # _batches_that_stepped tracks the number of global steps, # not the number of local steps, so we don't multiply with # self.trainer.accumulate_grad_batches here. 
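    # Worked example: with accumulate_grad_batches = 4 and 250 completed
    # optimizer steps, the two assignments above record 250 * 4 = 1000
    # completed batches, while the assignment below keeps
    # _batches_that_stepped at the 250 global steps.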
    checkpoint['loops']['fit_loop'][
      'epoch_loop.state_dict'][
        '_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
          'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['total']['completed']
    if 'sampler' not in checkpoint.keys():
      checkpoint['sampler'] = {}
    if hasattr(self.trainer.train_dataloader.sampler, 'state_dict'):
      sampler_state_dict = self.trainer.\
        train_dataloader.sampler.state_dict()
      checkpoint['sampler'][
        'random_state'] = sampler_state_dict.get(
          'random_state', None)
    else:
      checkpoint['sampler']['random_state'] = None

  def on_train_start(self):
    if self.ema:
      self.ema.move_shadow_params_to_device(self.device)
    # Adapted from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
    distributed = (
      self.trainer._accelerator_connector.use_distributed_sampler
      and self.trainer._accelerator_connector.is_distributed)
    if distributed:
      sampler_cls = dataloader.FaultTolerantDistributedSampler
    else:
      sampler_cls = dataloader.RandomFaultTolerantSampler
    updated_dls = []
    for dl in self.trainer.fit_loop._combined_loader.flattened:
      if hasattr(dl.sampler, 'shuffle'):
        dl_sampler = sampler_cls(
          dl.dataset, shuffle=dl.sampler.shuffle)
      else:
        dl_sampler = sampler_cls(dl.dataset)
      if (distributed
          and self.fast_forward_epochs is not None
          and self.fast_forward_batches is not None):
        dl_sampler.load_state_dict({
          'epoch': self.fast_forward_epochs,
          'counter': (self.fast_forward_batches
                      * self.config.loader.batch_size)})
      updated_dls.append(
        torch.utils.data.DataLoader(
          dl.dataset,
          batch_size=self.config.loader.batch_size,
          num_workers=self.config.loader.num_workers,
          pin_memory=self.config.loader.pin_memory,
          sampler=dl_sampler,
          shuffle=False,
          persistent_workers=True))
    self.trainer.fit_loop._combined_loader.flattened = updated_dls

  def optimizer_step(self, *args, **kwargs):
    super().optimizer_step(*args, **kwargs)
    if self.ema:
      self.ema.update(self._get_parameters())

  def _process_sigma(self, sigma):
    raise NotImplementedError

  def _process_model_output(self, model_output, xt, sigma):
    raise NotImplementedError

  def forward(self, xt, sigma, labels=None, weights=None,
              nn_input_idxs=None):
    if nn_input_idxs is None:
      nn_input_idxs = xt
    sigma = self._process_sigma(sigma)
    with torch.amp.autocast('cuda', dtype=torch.float32):
      model_output = self.backbone(
        x=nn_input_idxs, sigma=sigma, class_cond=labels,
        weights=weights)
    return self._process_model_output(
      model_output=model_output, xt=xt, sigma=sigma)

  def on_train_epoch_start(self):
    self.metrics.reset()
    assert self.metrics.train_nlls.nll.mean_value == 0
    assert self.metrics.train_nlls.nll.weight == 0

  def training_step(self, batch, batch_idx):
    current_accumulation_step = (
      batch_idx % self.trainer.accumulate_grad_batches)
    losses = self._loss(batch['input_ids'],
                        batch.get('labels', None),
                        batch['attention_mask'],
                        current_accumulation_step,
                        train_mode=True)
    self.metrics.update_train(losses.nlls, losses.prior_loss,
                              losses.num_tokens)
    self.log(name='trainer/loss',
             value=losses.loss.item(),
             on_step=True,
             on_epoch=False,
             sync_dist=True)
    return losses.loss

  def on_train_epoch_end(self):
    for k, v in self.metrics.train_nlls.items():
      self.log(name=k,
               value=v.compute(),
               on_step=False,
               on_epoch=True,
               sync_dist=True)

  def on_validation_epoch_start(self):
    self.metrics.reset()
    self._eval_mode()
    assert self.metrics.valid_nlls.nll.mean_value == 0
    assert self.metrics.valid_nlls.nll.weight == 0

  def validation_step(self, batch, batch_idx):
    del batch_idx
    losses = self._loss(batch['input_ids'],
                        batch.get('labels', None),
                        batch['attention_mask'])
    self.metrics.update_valid(losses.nlls,
                              losses.prior_loss,
                              losses.num_tokens)
    return losses.loss

  def on_validation_epoch_end(self):
    for k, v in self.metrics.valid_nlls.items():
      self.log(name=k,
               value=v.compute(),
               on_step=False,
               on_epoch=True,
               sync_dist=True)
    if ((self.config.eval.compute_perplexity_on_sanity
         or not self.trainer.sanity_checking)
        and self.config.eval.generate_samples):
      samples, text_samples = None, None
      for _ in range(
          self.config.sampling.num_sample_batches):
        samples = self.generate_samples(
          num_samples=self.config.loader.eval_batch_size)
        self.metrics.record_entropy(samples)
        # Decode the samples to be re-tokenized by eval model
        text_samples = self.tokenizer.batch_decode(samples)
        if self.config.eval.compute_generative_perplexity:
          self.metrics.record_generative_perplexity(
            text_samples, self.num_tokens, self.device)
      if text_samples is not None:
        if self.trainer.global_rank == 0 and hasattr(
            self.trainer.logger, 'log_table'):
          # Log the last generated samples
          text_samples = text_samples[
            : self.config.sampling.num_sample_log]
          self.trainer.logger.log_table(
            key=f'samples@global_step{self.global_step}',
            columns=['Generated Samples'],
            data=[[s] for s in text_samples])
      if self.config.eval.compute_generative_perplexity:
        self.log('val/gen_ppl',
                 self.metrics.gen_ppl.compute(),
                 on_epoch=True,
                 on_step=False,
                 sync_dist=True)
      self.log('val/sample_entropy',
               self.metrics.sample_entropy.compute(),
               on_epoch=True,
               on_step=False,
               sync_dist=True)
    self._train_mode()

  def configure_optimizers(self):
    optimizer = torch.optim.AdamW(
      self._get_parameters(),
      lr=self.config.optim.lr,
      betas=(self.config.optim.beta1,
             self.config.optim.beta2),
      eps=self.config.optim.eps,
      weight_decay=self.config.optim.weight_decay)
    scheduler = hydra.utils.instantiate(
      self.config.lr_scheduler, optimizer=optimizer)
    scheduler_dict = {'scheduler': scheduler,
                      'interval': 'step',
                      'monitor': 'val/loss',
                      'name': 'trainer/lr'}
    return [optimizer], [scheduler_dict]

  def generate_samples(self, num_samples, num_steps, eps):
    raise NotImplementedError

  def restore_model_and_sample(self, num_steps, eps=1e-5):
    """Generate samples from the model."""
    # Lightning auto-casting is not working in this method for some reason
    self._eval_mode()
    samples = self.generate_samples(
      num_samples=self.config.loader.eval_batch_size,
      num_steps=num_steps,
      eps=eps)
    self._train_mode()
    return samples

  def _process_model_input(self, x0, valid_tokens):
    raise NotImplementedError

  def nll(self, input_tokens, labels, output_tokens,
          current_accumulation_step=None, train_mode=False):
    raise NotImplementedError

  def _loss(self, x0, labels, valid_tokens,
            current_accumulation_step=None, train_mode=False):
    (input_tokens, output_tokens,
     valid_tokens) = self._process_model_input(
       x0, valid_tokens)
    loss = self.nll(input_tokens, labels, output_tokens,
                    current_accumulation_step, train_mode)
    assert loss.ndim == 2
    if self.ignore_bos:
      # Drop the BOS position from both the loss and the token mask.
      loss = loss[:, 1:]
      valid_tokens = valid_tokens[:, 1:]
    nlls = (loss * valid_tokens).sum()
    num_tokens = valid_tokens.sum()
    token_nll = nlls / num_tokens
    return Loss(loss=token_nll,
                nlls=nlls,
                prior_loss=0.0,
                num_tokens=num_tokens)


class Diffusion(TrainerBase):
  def _validate_configuration(self):
    super()._validate_configuration()
    assert self.config.sampling.noise_removal in {
      'none', 'ancestral', 'greedy'}
    assert self.loss_type in {'elbo', 'low_var'}
    if self.config.sampling.noise_removal == 'greedy':
      assert self.sampler != 'analytic'
      assert self.parameterization in {'mean', 'subs'}

  def _process_model_input(self, x0,
                           valid_tokens):
    return x0, None, valid_tokens

  def _process_sigma(self, sigma):
    assert sigma.ndim == 2
    sigma = sigma.mean(-1).squeeze()
    if sigma.ndim == 0:
      sigma = sigma.unsqueeze(0)
    if not self.time_conditioning:
      sigma = torch.zeros_like(sigma)
    assert sigma.ndim == 1, sigma.shape
    return sigma

  def _sample_t(self, n, accum_step):
    if accum_step is not None:  # During training
      batch_dim = n
      n = self.config.loader.global_batch_size
    _eps_t = torch.rand(n, device=self.device)
    if self.antithetic_sampling:
      offset = torch.arange(n, device=self.device) / n
      _eps_t = (_eps_t / n + offset) % 1
    t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
    if accum_step is not None:
      t = t.chunk(self.trainer.num_nodes)[self.trainer.node_rank]
      t = t.chunk(self.trainer.num_devices)[self.trainer.local_rank]
      t = t.chunk(self.trainer.accumulate_grad_batches)[
        accum_step]
      # corner case for the last datapoint
      t = t[:batch_dim]
    return t

  def _sigma_from_alphat(self, alpha_t):
    return -torch.log(alpha_t)

  def _reconstruction_loss(self, x0):
    t0 = torch.zeros(1, x0.shape[0], dtype=self.dtype,
                     device=self.device)
    sigma_t0 = self._sigma_from_alphat(self.noise(t0)[1])
    model_output_t0 = self.forward(x0, sigma_t0)
    return - torch.gather(input=model_output_t0,
                          dim=-1,
                          index=x0[:, :, None]).squeeze(-1)

  def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
                    dalpha_t, low_var):
    raise NotImplementedError

  def nll(self, x0, labels, output_tokens,
          current_accumulation_step=None, train_mode=False):
    del output_tokens
    t = self._sample_t(x0.shape[0], current_accumulation_step)
    assert t.shape[0] == x0.shape[0]
    if self.T > 0:
      t = (t * self.T).to(torch.int)
      t = t / self.T  # t \in {1/T, 2/T, ..., 1}
      t += (1 / self.T)
    dalpha_t, alpha_t = self.noise(t)
    alpha_t = alpha_t.unsqueeze(-1)
    dalpha_t = dalpha_t.unsqueeze(-1)
    assert alpha_t.ndim == 2
    assert dalpha_t.ndim == 2
    sigma = self._sigma_from_alphat(alpha_t)
    xt = self.q_xt(x0, alpha_t)
    # Handle class-conditional training, with class dropout
    if self.class_conditional:
      assert labels is not None
      rand = torch.rand(size=labels.shape,
                        dtype=torch.float32,
                        device=self.device)
      # num_classes represent the absence of class-conditioning
      labels = torch.where(rand < self.class_cond_dropout,
                           self.num_classes, labels)
    else:
      assert labels is None
    log_x_theta = self.forward(xt, sigma=sigma, labels=labels)
    utils.print_nans(log_x_theta, 'model_output')
    return self.nll_per_token(
      log_x_theta=log_x_theta,
      xt=xt,
      x0=x0,
      alpha_t=alpha_t,
      dalpha_t=dalpha_t,
      low_var=train_mode and self.loss_type == 'low_var')

  def _get_score(self, **kwargs):
    del kwargs
    raise NotImplementedError

  def _denoiser_update(self, x, t):
    raise NotImplementedError

  def _analytic_update(self, x, t, dt):
    raise NotImplementedError

  def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t):
    """From the clean x0, or denoiser predictions, implement the
    mathematical expression for q(z_s | z_t, x)"""
    raise NotImplementedError

  def _forward_process(self, q_x0, alpha_s):
    """Apply forward noising: q(x_s | x_0), where x_0 is a
    K-dimensional vector"""
    raise NotImplementedError

  def _get_ancestral_posterior(self, xt, sigma, labels, alpha_s,
                               alpha_t, p_x0):
    """Top-level call from generate_samples. Compute the standard
    posterior sampling distribution from the noisy sequence xt"""
    gamma = self.config.sampling.guid_weight
    # If class cond but gamma is None or 0, same as sampling
    # from the class unconditional case.
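    # When gamma lies strictly between 0 and 1, the guided branch below
    # mixes the two posteriors geometrically before renormalizing:
    # q_guided is proportional to
    # exp(gamma * log q_cond + (1 - gamma) * log q_uncond).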
    if not self.class_conditional or gamma in (None, 0.0):
      if labels is not None:
        labels = torch.full_like(labels, self.num_classes)
      return self._get_posterior_from_xt(xt, sigma, labels,
                                         alpha_s, alpha_t, p_x0)
    elif gamma == 1.0:
      # Simply sample from the cond. posterior only.
      assert labels is not None
      return self._get_posterior_from_xt(xt, sigma, labels,
                                         alpha_s, alpha_t, p_x0)
    else:
      # Case gamma not in {0, 1}, and class conditional
      # -> mix conditional and unconditional predictions.
      return self._get_guided_posterior_from_xt(
        xt, sigma, labels, gamma, alpha_s, alpha_t, p_x0)

  def _get_posterior_from_xt(self, xt, sigma, labels, alpha_s,
                             alpha_t, p_x0=None):
    """From xt, compute a single denoiser prediction, cast to the
    correct dtype, and call _posterior_from_x0."""
    if p_x0 is None:
      log_x0_pred = self.forward(xt, sigma, labels)
      if self.config.sampling.use_float64:
        log_x0_pred = log_x0_pred.to(torch.float64)
      if self.p_nucleus < 1:
        log_x0_pred = utils.top_k_top_p_filtering(
          log_x0_pred, top_p=self.p_nucleus)
      p_x0 = log_x0_pred.exp()
    return p_x0, self._posterior_from_x0(x0=p_x0, xt=xt,
                                         alpha_s=alpha_s,
                                         alpha_t=alpha_t)

  def _get_guided_posterior_from_xt(self, xt, sigma, labels, gamma,
                                    alpha_s, alpha_t, p_x0=None):
    """From xt, combine the class cond / uncond predictions of the
    denoiser. Call _get_posterior_from_xt."""
    # unpack the cache
    if p_x0 is None:
      p_x0_cond = p_x0_uncond = None
    else:
      p_x0_cond, p_x0_uncond = p_x0
    p_x0_cond, cond_posterior = self._get_posterior_from_xt(
      xt, sigma, labels, alpha_s, alpha_t, p_x0_cond)
    log_cond_posterior = cond_posterior.log()
    # NOTE: conditioning on self.num_classes represents the
    # class-unconditional predictions.
    p_x0_uncond, uncond_posterior = self._get_posterior_from_xt(
      xt, sigma, torch.full_like(labels, self.num_classes),
      alpha_s, alpha_t, p_x0_uncond)
    log_uncond_posterior = uncond_posterior.log()
    un_normalized_posterior = gamma * log_cond_posterior \
      + (1 - gamma) * log_uncond_posterior
    # Handle cases where the posterior is zero (eg after
    # nucleus sampling)
    is_inf_mask = torch.logical_or(
      log_cond_posterior.isinf(),
      log_uncond_posterior.isinf())
    un_normalized_posterior[is_inf_mask] = self.neg_infinity
    return ((p_x0_cond, p_x0_uncond),
            un_normalized_posterior.softmax(-1))

  def _ancestral_update(self, x, t, labels, dt, p_x0=None,
                        noise_removal_step=False):
    _, alpha_t = self.noise(t)
    if noise_removal_step:
      alpha_s = torch.ones_like(alpha_t)
    else:
      _, alpha_s = self.noise(t - dt)
    assert alpha_t.ndim == 2, f'{alpha_t.ndim=}'
    sigma = self._sigma_from_alphat(alpha_t)
    p_x0, q_xs = self._get_ancestral_posterior(x, sigma, labels,
                                               alpha_s, alpha_t,
                                               p_x0)
    return p_x0, sample_categorical(q_xs)

  def _psi_update(self, x, t, labels, dt, kappa, p_x0=None,
                  noise_removal_step=False):
    _, alpha_t = self.noise(t)
    if noise_removal_step:
      alpha_s = torch.ones_like(alpha_t)
    else:
      _, alpha_s = self.noise(t - dt)
    alpha_0 = torch.ones_like(alpha_t)
    sigma = self._sigma_from_alphat(alpha_t)
    # Standard posterior q(x_s | x_t)
    p_x0, q_xs = self._get_ancestral_posterior(
      x, sigma, labels, alpha_s, alpha_t, p_x0)
    # Posterior targeting t=0, reuse predictions p_x0
    _, q_x0 = self._get_ancestral_posterior(
      x, sigma, labels, alpha_0, alpha_t, p_x0)
    # PC: forward-noise q_x0 back to time s
    pc_q_xs = self._forward_process(q_x0, alpha_s)
    q_sample = kappa * q_xs + (1 - kappa) * pc_q_xs
    return p_x0, sample_categorical(q_sample)

  def _get_sampling_time_profile(self, eps, num_steps):
    profile = self.config.sampling.psi.time_profile
    num_steps += 1
    if profile == 'linear' \
        or
self.config.sampling.predictor != 'psi': # Default: linearly decrease return torch.linspace(1, eps, num_steps) if not profile.startswith('linear-constant-linear'): raise ValueError(profile) c = float(profile.split('-')[3]) if 'inv' in profile: c = self.noise.get_t_for_alpha(c) psi_cfg = self.config.sampling.psi n_hi = round(psi_cfg.high_frac * num_steps) n_mid = round(psi_cfg.middle_frac * num_steps) return torch.cat([ torch.linspace(1, c, n_hi), torch.full((n_mid,), c), torch.linspace(c, eps, num_steps - n_hi - n_mid)]) def _mode_to_psi_kappas(self, mode, timesteps): n = len(timesteps) if mode == 'pure-posterior': return torch.ones(n) if mode == 'pure-pc': return torch.zeros(n) eta = float(mode.split('-')[-1]) if (mode.startswith('constant-') and not mode.startswith('constant-remdm')): return torch.full((n,), eta) # Noise-schedule-dependent modes (ReMDM-like) _, all_alphas = self.noise(timesteps) alpha_t, alpha_s = all_alphas[:-1], all_alphas[1:] eta_t = torch.tensor([eta]) if mode.startswith('max-capped-'): sigma = torch.minimum( eta_t.expand_as(alpha_t), (1 - alpha_s) / alpha_t) sigma = torch.where(alpha_t == 0, eta, sigma) elif mode.startswith('max-rescale-'): sigma_max = torch.minimum( eta_t.expand_as(alpha_t), (1 - alpha_s) / alpha_t) sigma = torch.where(alpha_t > 0, sigma_max, 1) * eta elif mode.startswith('constant-remdm'): sigma = eta_t else: raise ValueError(mode) kappas = torch.clip(1 - sigma / (1 - alpha_s), 0, 1) if len(kappas) > 0: kappas = torch.cat([kappas, torch.ones(1)]) return kappas def _get_kappas(self, timesteps): cfg = self.config.sampling.psi n = len(timesteps) n_hi = round(cfg.high_frac * n) n_mid = round(cfg.middle_frac * n) kappas = torch.cat([ self._mode_to_psi_kappas(cfg.high_mode, timesteps[:n_hi]), self._mode_to_psi_kappas(cfg.middle_mode, timesteps[n_hi:n_hi + n_mid]), self._mode_to_psi_kappas(cfg.low_mode, timesteps[n_hi + n_mid:])]) assert (kappas >= 0).all() and (kappas <= 1).all() assert len(kappas) == n return kappas @torch.no_grad() def generate_samples(self, num_samples, labels=None, num_steps=None, eps=1e-5): """Generate samples from the model.""" # Lightning auto-casting is not working in this method for some reason if num_steps is None: num_steps = self.config.sampling.steps x = self.prior_sample(num_samples, self.num_tokens) use_psi_sampler = self.config.sampling.predictor == 'psi' timesteps = self._get_sampling_time_profile(eps, num_steps) if use_psi_sampler: kappas = self._get_kappas(timesteps).to(self.device) p_x0_cache = None if labels is not None: labels = labels.to(self.device) for i in range(num_steps): t = timesteps[i] * torch.ones( x.shape[0], 1, device=self.device) dt = timesteps[i] - timesteps[i+1] if self.sampler == 'ancestral': _, x = self._ancestral_update( x=x, t=t, labels=labels, dt=dt, p_x0=None) elif self.sampler == 'ancestral_cache': p_x0_cache, x_next = self._ancestral_update( x=x, t=t, labels=labels, dt=dt, p_x0=p_x0_cache) if (not torch.allclose(x_next, x) or self.time_conditioning): # Disable caching p_x0_cache = None x = x_next elif self.sampler == 'psi': _, x = self._psi_update(x=x, t=t, kappa=kappas[i], labels=labels, dt=dt, p_x0=None) elif self.sampler == 'analytic': assert labels is None, 'class-conditional sampling ' \ 'is not implemented with the analytic sampler' x = self._analytic_update(x=x,t=t, dt=dt) else: raise ValueError(self.sampler) t0 = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device) if self.config.sampling.noise_removal == 'ancestral': if self.sampler == 'analytic': x = 
self._denoiser_update(x=x, t=t0)
      else:
        _, x = self._ancestral_update(x=x,
                                      t=t0,
                                      labels=labels,
                                      dt=None,
                                      p_x0=p_x0_cache,
                                      noise_removal_step=True)
    elif self.config.sampling.noise_removal == 'greedy':
      sigma = self._sigma_from_alphat(self.noise(t0)[1])
      x = self.forward(xt=x, sigma=sigma).argmax(dim=-1)
    return x

  @torch.no_grad()
  def _semi_ar_sampler(
      self, n_samples, stride_length, num_strides, dt=0.001):
    # TODO(subham): Test this method after refactoring.
    ones = torch.ones(n_samples, dtype=self.dtype,
                      device=self.device)
    num_steps = int(1 / dt)
    sampling_steps = 0
    intermediate_tokens = []
    target = None
    for _ in range(num_strides + 1):
      p_x0_cache = None
      x = self.prior_sample(n_samples, self.num_tokens)
      if target is not None:
        x[:, : -stride_length] = target
      for i in range(num_steps + 1):
        p_x0_cache, x_next = self._ancestral_update(
          x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
        if (not torch.allclose(x_next, x)
            or self.time_conditioning):
          p_x0_cache = None
          sampling_steps += 1
        x = x_next
      x = self.forward(x, 0 * ones).argmax(dim=-1)
      intermediate_tokens.append(
        x[:, :stride_length].cpu().numpy())
      target = x[:, stride_length:]
    intermediate_tokens.append(target.cpu().numpy())
    intermediate_text_samples = []
    sequence_lengths = ((
      np.concatenate(intermediate_tokens, axis=1)[:, 1:]
      == self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)
    for i in range(2, len(intermediate_tokens) + 1):
      intermediate_text_samples.append(
        self.tokenizer.batch_decode(
          np.concatenate(intermediate_tokens[:i], axis=1)))
    return (sampling_steps, intermediate_text_samples,
            sequence_lengths)

  def restore_model_and_semi_ar_sample(
      self, stride_length, num_strides, dt=0.001):
    """Generate samples from the model."""
    # Lightning auto-casting is not working in this method for some reason
    # TODO(subham): Test this method after refactoring.
    self._eval_mode()
    (sampling_steps, samples,
     sequence_lengths) = self._semi_ar_sampler(
       n_samples=self.config.loader.eval_batch_size,
       stride_length=stride_length,
       num_strides=num_strides,
       dt=dt)
    self._train_mode()
    return sampling_steps, samples, sequence_lengths


class AbsorbingState(Diffusion):
  def __init__(self, config, tokenizer):
    # NOTE: Ideally, we should do
    # vocab_size = len(tokenizer), so that we account
    # for the special tokens added in dataloader.py.
    # But we use tokenizer.vocab_size so as to be
    # consistent with the prior checkpoints.
    vocab_size = tokenizer.vocab_size
    if (not hasattr(tokenizer, 'mask_token')
        or tokenizer.mask_token is None):
      self.mask_index = vocab_size
      vocab_size += 1
    else:
      self.mask_index = tokenizer.mask_token_id
    self.subs_masking = config.algo.subs_masking
    super().__init__(config, tokenizer, vocab_size=vocab_size)
    self.save_hyperparameters()

  def _validate_configuration(self):
    super()._validate_configuration()
    if self.parameterization in {'score', 'mean'}:
      assert self.time_conditioning
    assert not (self.parameterization == 'mean' and self.T == 0)
    if self.T > 0:
      assert self.parameterization in {'mean', 'subs'}
    if self.subs_masking:
      assert self.parameterization == 'mean'

  def q_xt(self, x, alpha_t):
    """Computes the noisy sample xt.

    Args:
      x: int torch.Tensor with shape (batch_size,
        diffusion_model_input_length), input.
      alpha_t: float torch.Tensor with shape (batch_size, 1).
""" move_indices = torch.rand( * x.shape, device=x.device) < 1 - alpha_t xt = torch.where(move_indices, self.mask_index, x) if self.ignore_bos: xt[:, 0] = x[:, 0] return xt def prior_sample(self, *batch_dims): return self.mask_index * torch.ones( * batch_dims, dtype=torch.int64, device=self.device) def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t): """From the clean x0, or denoiser predictions, implement the mathematical expression for q(z_s | z_t, x)""" assert x0.dtype == torch.float64, 'Requires float64 prec.' # should be one-hot on clean tokens orig_mask = xt[:, :, None] != self.mask_index orig_mask = orig_mask.expand(-1, -1, x0.shape[-1]) orig_output_on_clean = x0[orig_mask] q_xs = ((alpha_s - alpha_t) / (1 - alpha_t))[..., None] * x0 q_xs[..., self.mask_index] = (1 - alpha_s) / (1 - alpha_t) q_xs[orig_mask] = orig_output_on_clean return q_xs def _forward_process(self, x0, alpha_s): out = alpha_s[..., None] * x0 out[..., self.mask_index] = 1 - alpha_s return out def _staggered_score(self, score, dsigma): score = score.clone() extra_const = (1 - dsigma.exp()) * score.sum(dim=-1) score *= dsigma.exp()[:, None] score[..., self.mask_index] += extra_const return score def _analytic_update(self, x, t, dt): sigma_t = self._sigma_from_alphat(self.noise(t)[1]) sigma_s = self._sigma_from_alphat(self.noise(t - dt)[1]) dsigma = sigma_t - sigma_s score = self._get_score(x, sigma_t) if self.config.sampling.use_float64: score = score.to(torch.float64) stag_score = self._staggered_score(score, dsigma) probs = stag_score * self._transp_transition(x, dsigma) return sample_categorical(probs) def _denoiser_update(self, x, t): sigma = self._sigma_from_alphat(self.noise(t)[1]) score = self._get_score(x, sigma) if self.config.sampling.use_float64: score = score.to(torch.float64) stag_score = self._staggered_score(score, sigma) probs = stag_score * self._transp_transition(x, sigma) probs[..., self.mask_index] = 0 samples = sample_categorical(probs) return samples def _transp_transition(self, i, sigma): sigma = _unsqueeze(sigma, reference=i[..., None]) edge = torch.exp(-sigma) * F.one_hot( i, num_classes=self.vocab_size) edge += torch.where(i == self.mask_index, 1 - torch.exp(-sigma).squeeze(-1), 0)[..., None] return edge class UniformState(Diffusion): def _validate_configuration(self): super()._validate_configuration() assert self.time_conditioning assert self.parameterization == 'mean' if self.config.algo.name != 'distillation': assert self.T == 0 def _forward_process(self, x0, alpha_s): return (alpha_s[..., None] * x0 + (1 - alpha_s[..., None]) / self.vocab_size) def q_xt(self, x, alpha_t): """Computes the noisy sample xt. Args: x: int torch.Tensor with shape (batch_size, diffusion_model_input_length), input. move_chance: float torch.Tensor with shape (batch_size, 1). """ move_indices = torch.rand( *x.shape, device=x.device) < 1 - alpha_t uniform_tensor = torch.randint( 0, self.vocab_size, x.shape, device=x.device) xt = torch.where(move_indices, uniform_tensor, x) if self.ignore_bos: xt[:, 0] = x[:, 0] return xt def prior_sample(self, *batch_dims): return torch.randint( 0, self.vocab_size, batch_dims, dtype=torch.int64, device=self.device) ================================================ FILE: utils.py ================================================ """Console logger utilities. 

================================================
FILE: utils.py
================================================
"""Console logger utilities.

Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
"""
import argparse
import logging
import os
import pickle
import time
from typing import Optional

import fsspec
import lightning
import numpy as np
import torch
from scipy.integrate import quad
from scipy.optimize import curve_fit
from scipy.stats import norm
from timm.scheduler import CosineLRScheduler


def count_parameters(model):
  return sum(p.numel() for p in model.parameters()
             if p.requires_grad)


def fsspec_exists(filename):
  """Check if a file exists using fsspec."""
  fs, _ = fsspec.core.url_to_fs(filename)
  return fs.exists(filename)


def fsspec_listdir(dirname):
  """Listdir in manner compatible with fsspec."""
  fs, _ = fsspec.core.url_to_fs(dirname)
  return fs.ls(dirname)


def fsspec_mkdirs(dirname, exist_ok=True):
  """Mkdirs in manner compatible with fsspec."""
  fs, _ = fsspec.core.url_to_fs(dirname)
  fs.makedirs(dirname, exist_ok=exist_ok)


def print_nans(tensor, name):
  if torch.isnan(tensor).any():
    print(name, tensor)


class LRHalveScheduler:
  def __init__(self, warmup_steps, n_halve_steps):
    self.warmup_steps = warmup_steps
    self.n_halve_steps = n_halve_steps

  def __call__(self, current_step):
    if current_step < self.warmup_steps:
      return current_step / self.warmup_steps
    return 0.5 ** ((current_step - self.warmup_steps)
                   // self.n_halve_steps)


class CosineDecayWarmupLRScheduler(
    CosineLRScheduler,
    torch.optim.lr_scheduler._LRScheduler):
  """Wrap timm.scheduler.CosineLRScheduler.

  Enables calling scheduler.step() without passing in epoch.
  Supports resuming as well.
  Adapted from:
  https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._last_epoch = -1
    self.step(epoch=0)

  def step(self, epoch=None):
    if epoch is None:
      self._last_epoch += 1
    else:
      self._last_epoch = epoch
    # We call either step or step_update, depending on whether
    # we're using the scheduler every epoch or every step.
    # Otherwise, lightning will always call step (i.e., meant
    # for each epoch), and if we set the scheduler interval to
    # "step", then the learning rate update will be wrong.
    if self.t_in_epochs:
      super().step(epoch=self._last_epoch)
    else:
      super().step_update(num_updates=self._last_epoch)
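
# Illustrative usage sketch (an added note, not from the original
# source): LRHalveScheduler is a plain callable returning an LR
# multiplier, so it can be plugged into torch's LambdaLR, e.g.
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
#   lr_lambda = LRHalveScheduler(warmup_steps=1000,
#                                n_halve_steps=50_000)
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
#
# After warmup, the multiplier is
# 0.5 ** ((step - warmup_steps) // n_halve_steps).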

class LoggingContext:
  """Context manager for selective logging."""
  def __init__(self, logger, level=None, handler=None, close=True):
    self.logger = logger
    self.level = level
    self.handler = handler
    self.close = close

  def __enter__(self):
    if self.level is not None:
      self.old_level = self.logger.level
      self.logger.setLevel(self.level)
    if self.handler:
      self.logger.addHandler(self.handler)

  def __exit__(self, et, ev, tb):
    if self.level is not None:
      self.logger.setLevel(self.old_level)
    if self.handler:
      self.logger.removeHandler(self.handler)
    if self.handler and self.close:
      self.handler.close()


class GradientInspectionCallback(lightning.Callback):
  def __init__(self, num_grads_log):
    self.num_grads_log = num_grads_log

  def on_before_optimizer_step(self, trainer, pl_module, optimizer):
    gradients = []
    for name, param in pl_module.backbone.blocks.named_parameters():
      gradients.append(param.grad.view(-1))
    if gradients:
      grads = torch.cat(gradients)
      if not hasattr(pl_module, 'grad_accum_buffer'):
        pl_module.grad_step = torch.tensor(
          0, device=pl_module.device)
        pl_module.grad_accum_buffer = torch.zeros(
          self.num_grads_log, grads.shape[0],
          device=pl_module.device)
      pl_module.grad_accum_buffer[pl_module.grad_step] = grads
      pl_module.grad_step += 1
    if (hasattr(pl_module, 'grad_accum_buffer')
        and pl_module.grad_step == self.num_grads_log):
      grads = pl_module.grad_accum_buffer
      # Mean (over parameters) of the per-parameter std
      # across the accumulated steps.
      grad_var = grads.std(0).mean()
      pl_module.log(name='trainer/grad_var',
                    value=grad_var.item(),
                    on_step=True,
                    on_epoch=False,
                    sync_dist=True)
      # TODO: save the grads tensor as a numpy array
      # and visualize mean, median, top-k.
      pl_module.grad_accum_buffer.zero_()
      pl_module.grad_step = 0


def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
  """Initializes multi-GPU-friendly python logger."""
  logger = logging.getLogger(name)
  logger.setLevel(level)
  # This ensures all logging levels get marked with the rank zero
  # decorator; otherwise logs would get multiplied for each GPU
  # process in a multi-GPU setup.
  for level in ('debug', 'info', 'warning', 'error',
                'exception', 'fatal', 'critical'):
    setattr(logger, level,
            lightning.pytorch.utilities.rank_zero_only(
              getattr(logger, level)))
  return logger
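
# Illustrative usage sketch (an added note, not from the original
# source): LoggingContext temporarily changes a logger's level, e.g.
#
#   logger = get_logger(__name__)
#   with LoggingContext(logger, level=logging.ERROR):
#     run_noisy_third_party_code()  # INFO/WARNING suppressed here
#   logger.info('visible again')
#
# (run_noisy_third_party_code is a hypothetical placeholder.)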
""" if dim != -1: logits = torch.transpose(logits, dim, -1) assert top_k < logits.size(dim) if top_k > 0: # Remove all tokens with a probability less than # the last token of the top-k values, _ = torch.topk(logits, k=top_k, dim=-1) to_remove_mask = ( logits < torch.min(values, dim=-1, keepdim=True)[0] ) # min returns a tuple (values, indices) logits[to_remove_mask] = filter_value if top_p > 0.0: sorted_logits, sorted_indices = torch.sort( logits, descending=True, dim=-1) cum_probs = torch.cumsum( torch.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cum_probs > top_p # Ensures at least one token is kept sorted_indices_to_remove[..., 1:] = \ sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 mask_to_remove = torch.empty_like( sorted_indices_to_remove) mask_to_remove.scatter_(dim=-1, index=sorted_indices, src=sorted_indices_to_remove) logits[mask_to_remove] = filter_value if dim != -1: logits = torch.transpose(logits, dim, -1) # Re-normalize as in ReMDM. Alternatively, # could apply a log-softmax. This assumes that the input # tensor `logits` has been pre-processed with log_softmax. probs = logits.exp() Z = probs.sum(-1, keepdim=True) logits = (probs / Z).log() return logits def _discrete_prob_map(gamma_t, N=10): snr_sqrt = np.exp(-gamma_t / 2) def value(x): cdf = norm.cdf(x, scale=1) ** (N - 1) pdf = norm.pdf(x, loc=snr_sqrt, scale=1) return pdf * cdf return value def _discrete_prob_grad(gamma_t, N=10): snr_sqrt = np.exp(-gamma_t / 2) def value(x): coef = -0.5 * snr_sqrt * (x - snr_sqrt) cdf = norm.cdf(x, scale=1) ** (N - 1) pdf = norm.pdf(x, loc=snr_sqrt, scale=1) return coef * pdf * cdf return value def _cache_prob_usdm_in_partition( vocab_size=30522, partition_index=0, num_partitions=1, log10_num_points=5): print(f'Caching partition:{partition_index} / {num_partitions}') path = 'integral' gamma_min = -5 gamma_max = -1 num_points = 10 ** log10_num_points p_cache = [] grad_p_cache = [] start_time = time.time() gammas = np.linspace(gamma_min, gamma_max, num_points) n = num_points // num_partitions for gamma in gammas[partition_index * n: (partition_index + 1) * n]: pt, _ = quad(_discrete_prob_map(gamma, vocab_size), -np.inf, np.inf) p_cache.append(pt) grad_pt, _ = quad(_discrete_prob_grad(gamma, vocab_size), -np.inf, np.inf) grad_p_cache.append(grad_pt) if len(p_cache) % 100 == 0: print('{}% completed. 

def _cache_prob_usdm_in_partition(
    vocab_size=30522,
    partition_index=0,
    num_partitions=1,
    log10_num_points=5):
  print(f'Caching partition: {partition_index} / {num_partitions}')
  path = 'integral'
  gamma_min = -5
  gamma_max = -1
  num_points = 10 ** log10_num_points
  p_cache = []
  grad_p_cache = []
  start_time = time.time()
  gammas = np.linspace(gamma_min, gamma_max, num_points)
  n = num_points // num_partitions
  for gamma in gammas[partition_index * n:
                      (partition_index + 1) * n]:
    pt, _ = quad(_discrete_prob_map(gamma, vocab_size),
                 -np.inf, np.inf)
    p_cache.append(pt)
    grad_pt, _ = quad(_discrete_prob_grad(gamma, vocab_size),
                      -np.inf, np.inf)
    grad_p_cache.append(grad_pt)
    if len(p_cache) % 100 == 0:
      print('{}% completed. Time elapsed: {:.2f} mins'.format(
        int(100 * len(p_cache) / num_points),
        (time.time() - start_time) / 60))
  filename = os.path.join(
    path, '{}_{}_{}-{}.pkl'.format(
      vocab_size, log10_num_points,
      partition_index, num_partitions))
  with open(filename, 'wb') as f:
    pickle.dump({
      'vocab_size': vocab_size,
      'gamma_min': gamma_min,
      'gamma_max': gamma_max,
      'num_points': num_points,
      'pt': np.asarray(p_cache),
      'grad_pt': np.asarray(grad_p_cache)}, f)


def test_cache_prob_usdm_in_partition(
    partition_index=0,
    num_partitions=1,
    vocab_size=30522,
    log10_num_points=5):
  path = 'integral/{}_{}_{}-{}.pkl'.format(
    vocab_size, log10_num_points,
    partition_index, num_partitions)
  with open(path, 'rb') as f:
    data = pickle.load(f)
  num_points = data['num_points']

  def _get_index(x):
    return round((num_points - 1) * (x - data['gamma_min'])
                 / (data['gamma_max'] - data['gamma_min']))

  pt_errors = []
  grad_pt_errors = []
  gammas = np.linspace(data['gamma_min'],
                       data['gamma_max'],
                       num_points)
  n = num_points // num_partitions
  for gamma in gammas[partition_index * n:
                      (partition_index + 1) * n]:
    pt, _ = quad(
      _discrete_prob_map(gamma, data['vocab_size']),
      -np.inf, np.inf)
    grad_pt, _ = quad(
      _discrete_prob_grad(gamma, data['vocab_size']),
      -np.inf, np.inf)
    idx = _get_index(gamma)
    pt_errors.append((pt - data['pt'][idx]) ** 2)
    grad_pt_errors.append((grad_pt - data['grad_pt'][idx]) ** 2)
  print('Integral MSE: {} Integral Squared: {:.4f}'.format(
    np.mean(pt_errors), np.mean(data['pt'] ** 2)))
  print('Integral Grad MSE: {} Integral Grad Squared: {:.4f}'.format(
    np.mean(grad_pt_errors), np.mean(data['grad_pt'] ** 2)))


def compute_duo_series_coefficients(num_coefficients, vocab_size):
  def integrand_m(z, n, K):
    z = np.float64(z)
    return z ** n * norm.pdf(z) * norm.cdf(z) ** (K - 1)

  def integrand_i(z, n, K):
    z = np.float64(z)
    return z ** (n + 1) * norm.pdf(z) * norm.cdf(z) ** (K - 1)

  arange = np.cumprod(
    np.arange(1, num_coefficients).astype(np.float64))
  factorials = np.concatenate([[1.0], arange], dtype=np.float64)
  lo, hi = -100.0, 100.0
  coefficients_m = []
  coefficients_i = []
  for n in range(num_coefficients):
    f = lambda z: integrand_m(np.float64(z), np.float64(n),
                              vocab_size)
    g = lambda z: integrand_i(np.float64(z), np.float64(n),
                              vocab_size)
    m, _ = quad(f, lo, hi)
    i, _ = quad(g, lo, hi)
    coefficients_m.append(m / factorials[n])
    coefficients_i.append(i / factorials[n])
  return (torch.tensor(coefficients_m)[None],
          torch.tensor(coefficients_i)[None])
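
# Illustrative usage sketch (an added note, not from the original
# source), mirroring the call pattern in compute_duo_operator_approx
# below:
#
#   m, i = compute_duo_series_coefficients(64, vocab_size=30522)
#   powers = torch.arange(64, dtype=torch.float64)[None]
#   gammas = torch.linspace(-5.0, -1.0, 100)
#   alpha, dalpha = compute_duo_gamma_to_alpha_dalpha_series(
#     gammas, m, i, powers, 30522, gamma_min=-5.0, gamma_max=-1.0)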

def compute_duo_gamma_to_alpha_dalpha_series(
    gamma_t, coefficients_m, coefficients_i, power_arange,
    vocab_size, gamma_min, gamma_max):
  gamma_t = gamma_t.to(torch.float64)[:, None]
  sigmoid_neg_gamma = torch.sigmoid(-gamma_t)
  sigmoid_gamma = torch.sigmoid(gamma_t)
  alpha_t_squared = sigmoid_neg_gamma
  alpha_t = alpha_t_squared.sqrt()
  one_minus_alpha_t_squared = 1 - alpha_t_squared
  mu_t = alpha_t / one_minus_alpha_t_squared.sqrt()
  arange = power_arange.to(device=gamma_t.device)
  mu_t_pow = mu_t ** arange
  exp_term = (-mu_t ** 2 / 2).exp().squeeze(-1)
  vocab_scale = vocab_size / (vocab_size - 1)
  # Compute alpha
  sum_term_alpha = (mu_t_pow * coefficients_m).sum(-1)
  alpha_usdm = (sum_term_alpha * exp_term
                - 1 / vocab_size) * vocab_scale
  # Compute alpha'
  sum_term_dalpha = (
    mu_t_pow * (coefficients_i - mu_t * coefficients_m)).sum(-1)
  dalpha_usdm = exp_term * sum_term_dalpha * vocab_scale
  final_scale = - (sigmoid_gamma.squeeze(-1)
                   * sigmoid_neg_gamma.squeeze(-1) ** 0.5
                   * 0.5 * (gamma_max - gamma_min))
  dalpha_usdm = (dalpha_usdm
                 / one_minus_alpha_t_squared.squeeze(-1) ** 1.5
                 * final_scale)
  return alpha_usdm.squeeze(-1), dalpha_usdm.squeeze(-1)


def duo_t_to_alpha_dalpha_sigm_corrected(
    t, a: float, b: float, c: float, d: float,
    e: float, f: float, alpha: float):
  # Shared quantities
  sigm_bc = (torch.tanh(b * t + c) + 1) / 2
  sigm_ef = (torch.tanh(e * t + f) + 1) / 2
  # Compute alpha_t
  base = a * sigm_bc + d
  edge_gate = 1 - 4 * sigm_ef * (1 - sigm_ef)
  edge_correction = alpha * (t - 0.5) * edge_gate
  alpha_t = base + edge_correction
  # Compute d_alpha_t
  dbase = a * b * sigm_bc * (1 - sigm_bc)
  dgate = -4 * e * sigm_ef * (1 - sigm_ef) * (1 - 2 * sigm_ef)
  dcorrection = alpha * edge_gate + alpha * (t - 0.5) * dgate
  dalpha_t = dbase + dcorrection
  return alpha_t, dalpha_t


def duo_to_alpha_dalpha_sigmoid(t: torch.Tensor, a: float,
                                b: float, c: float, d: float):
  sigm_bc = (torch.tanh(b * t + c) + 1) / 2
  alpha_t = a * sigm_bc + d
  dalpha_t = a * b * sigm_bc * (1 - sigm_bc)
  return alpha_t, dalpha_t


def duo_to_alpha_dalpha_poly(t: torch.Tensor,
                             *coefficients: float):
  alpha_t = coefficients[0]  # a0 term
  for i, a in enumerate(coefficients[1:], 1):
    alpha_t = alpha_t + a * t ** i
  dalpha_t = coefficients[1]  # a1 term
  for i, a in enumerate(coefficients[2:], 2):
    dalpha_t = dalpha_t + i * a * t ** (i - 1)
  return alpha_t, dalpha_t


def compute_duo_operator_approx(num_coefficients, vocab_size,
                                gamma_min, gamma_max,
                                fct_name='sigmoid'):
  series_m, series_i = compute_duo_series_coefficients(
    num_coefficients, vocab_size)
  ts = torch.linspace(0, 1, steps=100_000)
  gammas = gamma_min + ts * (gamma_max - gamma_min)
  power_arange = torch.arange(
    num_coefficients, dtype=torch.float64)[None]
  alpha_approx = compute_duo_gamma_to_alpha_dalpha_series(
    gammas, series_m, series_i, power_arange,
    vocab_size, gamma_min, gamma_max)[0].float()
  t_np = ts.numpy()
  y_np = alpha_approx.numpy()

  def sigmoid(x):
    return (np.tanh(x) + 1) / 2

  if fct_name == 'sigmoid':
    def func(t, a, b, c, d):
      return a * sigmoid(b * t + c) + d
    p0 = [0.5, 2.0, -1.0, 0.5]
  elif fct_name == 'sigmoid-edge-corrected':
    def func(t, a, b, c, d, e, f, alpha):
      base = a * sigmoid(b * t + c) + d
      edge_gate = 1 - 4 * sigmoid(e * t + f) \
        * sigmoid(-e * t - f)
      edge_correction = alpha * (t - 0.5) * edge_gate
      return base + edge_correction
    p0 = [0.5, 2.0, -1.0, 0.1, 3.0, 0.0, 0.1]
  elif fct_name == 'poly3':
    def func(t, a0, a1, a2, a3):
      return a0 + a1 * t + a2 * t ** 2 + a3 * t ** 3
    p0 = [0.1] * 4
  elif fct_name == 'poly5':
    def func(t, a0, a1, a2, a3, a4, a5):
      return (a0 + a1 * t + a2 * t ** 2 + a3 * t ** 3
              + a4 * t ** 4 + a5 * t ** 5)
    p0 = [0.1] * 6
  elif fct_name == 'poly7':
    def func(t, a0, a1, a2, a3, a4, a5, a6, a7):
      return (a0 + a1 * t + a2 * t ** 2 + a3 * t ** 3
              + a4 * t ** 4 + a5 * t ** 5
              + a6 * t ** 6 + a7 * t ** 7)
    p0 = [0.1] * 8
  elif fct_name == 'poly9':
    def func(t, a0, a1, a2, a3, a4, a5, a6, a7, a8, a9):
      return (a0 + a1 * t + a2 * t ** 2 + a3 * t ** 3
              + a4 * t ** 4 + a5 * t ** 5 + a6 * t ** 6
              + a7 * t ** 7 + a8 * t ** 8 + a9 * t ** 9)
    p0 = [0.1] * 10
  else:
    raise ValueError(fct_name)
  popt, _ = curve_fit(func, t_np, y_np, p0=p0, maxfev=10000)
  preds = func(t_np, *popt)
  return list(popt), y_np, preds, t_np
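
# Illustrative sanity check (an added note, not from the original
# source): the analytic derivative returned by duo_to_alpha_dalpha_poly
# can be spot-checked with a finite difference:
#
#   t = torch.linspace(0.1, 0.9, 5)
#   coeffs = (1.0, -0.5, 0.25)  # alpha(t) = 1 - 0.5 t + 0.25 t**2
#   a, da = duo_to_alpha_dalpha_poly(t, *coeffs)
#   a_eps, _ = duo_to_alpha_dalpha_poly(t + 1e-4, *coeffs)
#   assert torch.allclose(da, (a_eps - a) / 1e-4, atol=1e-2)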

def _sample_k_int(bs: int, l: int, k: int, max_value: int,
                  device: torch.device):
  # Robert Floyd's algorithm:
  # https://www.nowherenearithaca.com/2013/05/robert-floyds-tiny-and-beautiful.html
  out = torch.empty(size=(bs, l, k), dtype=torch.int64,
                    device=device)
  for t, i in enumerate(range(max_value - k, max_value)):
    j = torch.randint(0, i + 1, size=(bs, l), device=device)
    if t > 0:
      # Does j already appear in previously chosen slots?
      dup = (out[..., :t] == j[..., None]).any(dim=-1)
      # Write j where it is new, otherwise write i.
      out[..., t] = torch.where(dup, i, j)
    else:
      out[..., 0] = j
  return out


def _sample_topk_gaussian(N: int,
                          sigma: Optional[torch.Tensor] = None,
                          l: int = 0,
                          k: int = 0,
                          batch: int = None,
                          device: str = None,
                          dtype: torch.dtype = torch.float64):
  """Sample the top-k order statistics of N zero-mean Gaussians.

  Operates in logspace for stability.
  """
  if sigma is None:
    assert batch is not None
    assert device is not None
    assert dtype is not None
  else:
    batch = sigma.shape[0]
    device = sigma.device
    dtype = sigma.dtype
  log_u = torch.log(torch.rand(batch, l, k,
                               device=device, dtype=dtype))
  divisors = torch.arange(N, N - k, -1,
                          device=device, dtype=dtype)  # (k,)
  log_rj = log_u / divisors  # (batch, l, k)
  log_prod = torch.cumsum(log_rj, dim=-1)  # (batch, l, k)
  uniforms = torch.exp(log_prod)  # (batch, l, k)
  # Convert to Gaussian and rescale.
  topk = torch.special.ndtri(uniforms)
  if sigma is not None:
    topk = topk * sigma[:, None, None]
  return topk


def _sample_topk_and_extra(N: int, alpha: torch.Tensor,
                           sigma: torch.Tensor, l: int, k: int):
  """Sample the top-k order statistics between N - 1 zero-mean
  Gaussians and a single Gaussian with mean alpha.
  """
  top_k_others = _sample_topk_gaussian(N - 1, sigma, l, k)
  extra = alpha[:, None] + torch.randn(
    size=(alpha.shape[0], l), device=alpha.device
  ) * sigma[:, None]  # (bs, l)
  min_values = top_k_others[:, :, -1]
  is_extra_in_topk = (extra > min_values)  # bs x l
  top_k_others[:, :, -1][is_extra_in_topk] = \
    extra[is_extra_in_topk]
  return extra, top_k_others, is_extra_in_topk


def _log_mean_exp_trunc_normal(c: torch.Tensor,
                               sigma: torch.Tensor):
  """Compute log(E[exp(X) | X < c]) for X ~ N(0, sigma^2).

  Closed-form expression:
    mu = exp(sigma**2 / 2) * Phi((c - sigma**2) / sigma)
         / Phi(c / sigma)
  where Phi is the standard normal CDF.
  Operates in log space for stability.
  """
  log_num = torch.special.log_ndtr((c - sigma ** 2) / sigma)
  log_den = torch.special.log_ndtr(c / sigma)
  return sigma ** 2 / 2.0 + log_num - log_den
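
# Illustrative sanity check (an added note, not from the original
# source): _sample_topk_gaussian should match a brute-force top-k of
# N standard normals in distribution:
#
#   N, k, batch = 1000, 4, 100_000
#   brute = torch.randn(batch, N,
#                       dtype=torch.float64).topk(k, dim=-1).values
#   fast = _sample_topk_gaussian(N, l=1, k=k, batch=batch,
#                                device='cpu',
#                                dtype=torch.float64)[:, 0]
#   print(brute.mean(0), fast.mean(0))  # should agree to MC error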

def sample_tempered_softmax_topk(
    extra_index: torch.Tensor,
    alpha: torch.Tensor,
    sigma: torch.Tensor,
    l: int,
    k: int,
    vocab_size: int,
    # 1 / T. If low temperature, inverse will be large.
    inverse_temperature: float = 1.0):
  assert alpha.ndim == 1
  assert sigma.ndim == 1
  # float64 needed for numerical precision
  alpha = alpha.to(torch.float64)
  sigma = sigma.to(torch.float64)
  # Sample the top k between (vocab_size - 1) zero-mean
  # Gaussians, and a single Gaussian with mean alpha.
  extra, top_k, is_extra_in_topk = _sample_topk_and_extra(
    vocab_size, alpha, sigma, l, k)
  min_rv = torch.where(is_extra_in_topk,
                       top_k[:, :, -2],
                       top_k[:, :, -1])  # (bs, l)
  scaled_sigma = sigma[:, None] * inverse_temperature  # (bs, 1)
  scaled_c = min_rv * inverse_temperature  # (bs, l)
  log_mu = _log_mean_exp_trunc_normal(scaled_c, scaled_sigma)
  log_topk = top_k * inverse_temperature
  # Approximate contribution of the unknown variables.
  count = torch.where(is_extra_in_topk,
                      vocab_size - k,
                      vocab_size - k - 1).to(log_mu.dtype)
  log_tail = torch.log(count) + log_mu
  # Contribution of the extra variable, when NOT in the top-k.
  log_extra = extra * inverse_temperature
  extra_not_in_topk = ~is_extra_in_topk
  log_extra_masked = torch.full_like(log_tail, float('-inf'))
  log_extra_masked[extra_not_in_topk] = \
    log_extra[extra_not_in_topk]
  log_contribs = torch.cat([
    log_topk,
    log_tail[..., None],
    log_extra_masked[..., None]], dim=-1)  # (bs, l, k+2)
  log_denom = torch.logsumexp(log_contribs, dim=-1,
                              keepdim=True)  # (bs, l, 1)
  softmax_approx = torch.exp(log_topk - log_denom)  # (bs, l, k)
  # If the sum over the k categories is zero, set a one-hot on
  # the largest entry. Fixes division by zero in the normalizer.
  normalizer = softmax_approx.sum(dim=-1, keepdim=True)
  zero_sum = normalizer == 0.0
  softmax_approx = torch.where(zero_sum, 0.0, softmax_approx)
  softmax_approx[..., 0][zero_sum[..., 0]] = 1.0
  indices = _sample_k_int(alpha.shape[0], l, k,
                          # Note the -1:
                          max_value=vocab_size - 1,
                          device=alpha.device)
  # Ensure x0 (the true token) is not generated.
  indices[indices >= extra_index[..., None]] += 1
  indices[..., -1][is_extra_in_topk] = \
    extra_index[is_extra_in_topk]
  xt_usdm = torch.where(is_extra_in_topk,
                        extra_index,
                        indices[..., 0])
  return softmax_approx, indices, xt_usdm


if __name__ == "__main__":
  # Usage: python utils.py --vocab_size=N
  parser = argparse.ArgumentParser(
    description='Caches the integral appearing in the '
                'Diffusion Transformation operator.')
  parser.add_argument(
    '--vocab_size', type=int,
    default=50257,  # For the gpt2 tokenizer
    help='Vocabulary size (default: 50257)')
  parser.add_argument(
    '--partition_index', type=int, default=0,
    help='Helps parallelize caching')
  parser.add_argument(
    '--num_partitions', type=int, default=1,
    help='Helps parallelize caching')
  parser.add_argument(
    '--log10_num_points', type=int, default=5,
    help=('The integral is a function that needs to be '
          'evaluated for inputs in the range [-5, 1]. '
          'This argument is the base-10 logarithm of the '
          'number of discretization bins.'))
  args = parser.parse_args()
  # Computing the integral over [-5, 1] can be slow, so one
  # might prefer splitting it into `num_partitions` bins,
  # computing each separately, and merging them later.
  _cache_prob_usdm_in_partition(
    partition_index=args.partition_index,
    num_partitions=args.num_partitions,
    vocab_size=args.vocab_size,
    log10_num_points=args.log10_num_points)
  test_cache_prob_usdm_in_partition(
    partition_index=args.partition_index,
    num_partitions=args.num_partitions,
    vocab_size=args.vocab_size,
    log10_num_points=args.log10_num_points)
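
# Illustrative invocation sketch (an added note, not from the original
# source): the cache can be built in parallel shards, e.g. with four
# partitions:
#
#   python utils.py --vocab_size=30522 --num_partitions=4 --partition_index=0
#   python utils.py --vocab_size=30522 --num_partitions=4 --partition_index=1
#   python utils.py --vocab_size=30522 --num_partitions=4 --partition_index=2
#   python utils.py --vocab_size=30522 --num_partitions=4 --partition_index=3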