[
  {
    "path": ".gitignore",
    "content": ".DS_Store\n\n.hf_cache\ntest/\noutputs/\nwandb/\nwatch_folder/\nnotes.md\ngrid_search.sh\n*.ipynb\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for 
libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n.idea/\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2025 Subham Sekhar Sahoo\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n# The Diffusion Duality Series\n\n</div>\n\n## [Chapter I (ICML 2025)](https://arxiv.org/abs/2506.10892)\n\nBy [Subham Sekhar Sahoo](https://s-sahoo.github.io), [Justin Deschenaux](https://jdeschena.com), [Aaron Gokaslan](https://skylion007.github.io),\n[Guanghan Wang](https://tech.cornell.edu/people/guanghan-wang/), [Justin Chiu](https://justinchiu.netlify.app), [Volodymyr Kuleshov](https://www.cs.cornell.edu/~kuleshov/)\n\n[![GitHub](https://img.shields.io/badge/GitHub-Repo-181717?logo=github&logoColor=white)](https://github.com/s-sahoo/duo/tree/ch-1)\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Sf7R-dqdR6gq-H8nyZ9E3ZkyvqMTqcwq?usp=sharing)\n[![YouTube](https://img.shields.io/badge/YouTube-%23FF0000.svg?logo=YouTube&logoColor=white)](https://youtu.be/FCO-nnqHOqQ?si=4eGnj5zbRgyCYWwI)\n[![deploy](https://img.shields.io/badge/Blog%20%20-8A2BE2)](http://s-sahoo.github.io/duo)\n[![arXiv](https://img.shields.io/badge/arXiv-2406.07524-red.svg)](https://arxiv.org/abs/2506.10892)\n[![deploy](https://img.shields.io/badge/🤗-Huggingface-blue)](https://huggingface.co/collections/s-sahoo/duo-67f9ff8fde919224e5fbd875)\n\n**Unlocks few-step generation in discrete diffusion-LLMs via the underlying Gaussian diffusion.**\n\n<div align=\"center\">\n  <img src=\"https://github.com/s-sahoo/duo/blob/gh-pages/static/images/duo_schematic.png\" width=\"60%\">\n</div>\n\n## [Chapter II: Ψ-Samplers and Efficient Curriculum (ICLR 2026)](https://arxiv.org/abs/2602.21185)\nBy  [Justin Deschenaux](https://jdeschena.com), [Caglar Gulcehre](https://www.caglar.ai),\n[Subham Sekhar Sahoo](https://s-sahoo.github.io)\n\n[![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1uFSzrfG0KXhGcohRIfWIM2Y7V9Q7cQNA?usp=sharing)\n[![deploy](https://img.shields.io/badge/Blog%20%20-8A2BE2)](http://s-sahoo.github.io/duo-ch2)\n[![arXiv](https://img.shields.io/badge/arXiv-2602.21185-red.svg)](https://arxiv.org/abs/2602.21185)\n<!-- [![deploy](https://img.shields.io/badge/🤗-Huggingface-blue)](https://huggingface.co/collections/s-sahoo/duo-67f9ff8fde919224e5fbd875) -->\n\n**Uniform-state beats Masked diffusion on text and image generation!**\n<div align=\"center\">\n  <img src=\"https://github.com/s-sahoo/duo-ch2/blob/gh-pages/static/images/psi-samplers-figure.png\" width=\"90%\">\n</div>\n\nThis repository contains the code for the two papers in the Diffusion Duality series. It includes:\n- **Duo / $\text{Duo}^\text{++}$** sampling (ancestral, ReMDM, $\Psi$-samplers, greedy-tail) — [Sampling & Eval](#sampling--eval)\n- Original and efficient curriculum training strategies — [Training](#training)\n- Discrete Consistency Distillation (DCD) — [Distillation](#discrete-consistency-distillation)\n- Baselines (AR, MDLM, SEDD, D3PM) — [Baselines](#baselines)\n\n[Getting Started](#getting-started) | [Checkpoints](#checkpoints) | [Citation](#acknowledgements--citation)\n\n# Getting Started\n<a name=\"getting_started\"></a>\n\nTo get started, create a conda environment containing the required dependencies.\n\n```bash\nconda create -n duo python=3.12\nconda activate duo\nconda install nvidia/label/cuda-12.4.0::cuda-toolkit\npip install -r requirements.txt\npip install flash_attn==2.7.4.post1\n```\n\n# Checkpoints\n<a name=\"checkpoints\"></a>\n\n* **Duo** (Language Modeling): Trained on OpenWebText for `1M` training steps (distilled / base):\n  * [Huggingface](https://huggingface.co/collections/s-sahoo/duo-67f9ff8fde919224e5fbd875)🤗.\n  * [Google Drive folder](https://drive.google.com/drive/folders/1JpqFM8XRvifwIkjWPfMyuDvu41r1yk0t?usp=share_link) as the 
HF checkpoints can't be finetuned.\n* **Duo** (Image Modeling): Trained on CIFAR-10\n  * [Huggingface (contains the raw checkpoints)](https://huggingface.co/jdeschena/duo2-cifar10)\n* **Baselines** (SEDD, MDLM, AR): Trained on OpenWebText\n  * [Google Drive folder](https://drive.google.com/drive/folders/16LuuptK7Xfk-vzhQYZBZ0SA-B-BFluau?usp=sharing) — download `ar.ckpt`, `mdlm.ckpt`, `sedd.ckpt`.\n\n# Training\n<a name=\"training\"></a>\n\nThis repo implements the original Duo curriculum, as well as the fast $\\text{Duo}^\\text{++}$ curriculum. By default, the training scripts use the original curriculum. To enable the efficient curriculum, simply replace `algo.curriculum.mode=simple` by `algo.curriculum.mode=poly9` (see comments in each training script).\n\nTo train $\\text{Duo}^\\text{++}$, use the following scripts:\n* LM1B\n  * w/ sentencepacking (same as in D3PM)\n    * Training script: [`scripts/train_lm1b_duo_sentencepacking.sh`](./scripts/train_lm1b_duo_sentencepacking.sh)\n    * [Wandb run](https://api.wandb.ai/links/kuleshov-group/huwt0ek3) \n  * w/o sentencepacking (same as in MDLM, SEDD)\n    * Training script: [`scripts/train_lm1b_duo.sh`](./scripts/train_lm1b_duo.sh)\n    * [Wandb run](https://api.wandb.ai/links/sahoo-diffusion/lkv5z3tm)\n    \n* OWT: [`scripts/train_owt_duo.sh`](./scripts/train_owt_duo.sh).\n* CIFAR-10:\n  * Duo: [`scripts/train_cifar10_duo_cosine.sh`](./scripts/train_cifar10_duo_cosine.sh)\n  * MDLM: [`scripts/train_cifar10_mdlm_cosine.sh`](./scripts/train_cifar10_mdlm_cosine.sh)\n  * Both scripts default to a cosine noise schedule. To use log-linear instead, set `noise=log-linear`.\n\n**Notes:**\n* Run `mkdir watch_folder` to create a directory to store slurm logs,\n  and then run any script in [`scripts/`](scripts) as a slurm job: `sbatch scripts/ABC_XYZ.sh`\n* Control the batch size per GPU using the argument `loader.batch_size`. 
If `loader.batch_size * num_gpus < loader.global_batch_size`, PyTorch Lightning resorts to gradient accumulation. \n\n# Discrete Consistency Distillation\n<a name=\"distillation\"></a>\n\nTo distill a model using the Discrete Consistency Distillation (`Alg. 1` in the [Duo paper](https://arxiv.org/abs/2506.10892)), use [`scripts/distil_owt.sh`](scripts/distil_owt.sh).\n\n\n# Sampling & Eval\n<a name=\"sampling\"></a>\n\n## Likelihood\nTo compute test perplexity on the validation set of OWT use [`scripts/eval_owt_duo.sh`](scripts/eval_owt_duo.sh) and for zero shot perplexities use [`scripts/zero_shot_duo.sh`](scripts/zero_shot_duo.sh).\n\n## Sampling\nYou can sample with ancestral sampling using the scripts in [`scripts/gen_ppl_*.sh`](scripts/). To sample with the PC samplers such as ReMDM and our $\\Psi$-samplers, use the scripts in [`scripts/psi_samplers`](scripts/psi_samplers). This directory contains examples for sampling text and images.\n\nTo use the \"Greedy-tail sampler\" (equivalent to nucleus sampling in AR models; see `Sec. 4.2` in the paper), set `sampling.noise_removal=greedy`. 
Using the default `sampling.noise_removal=ancestral` will produce more diverse samples (higher entropy) but with worse generative perplexity.\n\nTo sample from a HuggingFace checkpoint (text only), run the following command:\n```bash\npython main.py \\\n  mode=sample_eval \\\n  loader.batch_size=2 \\\n  loader.eval_batch_size=8 \\\n  data=openwebtext-split \\\n  algo=duo_base \\\n  algo.backbone=hf_dit \\\n  eval.checkpoint_path=s-sahoo/duo-distilled \\\n  sampling.steps=8 \\\n  sampling.num_sample_batches=1 \\\n  sampling.noise_removal=greedy \\\n  +wandb.offline=true \n```\n\nTo use the example scripts with raw checkpoints (see [Checkpoints](#checkpoints)), download them and set the checkpoint path in the script.\n\n# Baselines\n<a name=\"baselines\"></a>\nDownload the baseline checkpoints (see [Checkpoints](#checkpoints)) and specify the paths appropriately in the respective shell scripts:\n* [`scripts/eval_owt_*.sh`](scripts/) for computing validation perplexity on OWT.\n* [`scripts/gen_ppl_*.sh`](scripts/) for generating text samples and evaluating them.\n* [`scripts/zero_shot_*.sh`](scripts/) for computing zero shot perplexities.\n* [`scripts/train_*.sh`](scripts/) for training the models.\n\n# Acknowledgements & Citation\nThis repository was built off of [MDLM's Github repository](https://github.com/kuleshov-group/mdlm). 
Cite our papers using:\n```\n@inproceedings{\n    sahoo2025the,\n    title={The Diffusion Duality},\n    author={Subham Sekhar Sahoo and Justin Deschenaux and Aaron Gokaslan and Guanghan Wang and Justin T Chiu and Volodymyr Kuleshov},\n    booktitle={Forty-second International Conference on Machine Learning},\n    year={2025},\n    url={https://openreview.net/forum?id=9P9Y8FOSOk}\n}\n\n@inproceedings{\n    deschenaux2026the,\n    title={The Diffusion Duality, Chapter {II}: \\${\\textbackslash}Psi\\$-Samplers and Efficient Curriculum},\n    author={Justin Deschenaux and Caglar Gulcehre and Subham Sekhar Sahoo},\n    booktitle={The Fourteenth International Conference on Learning Representations},\n    year={2026},\n    url={https://openreview.net/forum?id=RSIoYWIzaP}\n}\n```\n"
  },
  {
    "path": "algo.py",
    "content": "import os\nimport collections\nimport copy\nimport pickle\nfrom typing import Optional\n\nimport fsspec\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nimport trainer_base\nimport utils\n\n\nclass AR(trainer_base.TrainerBase):\n  def __init__(self, config, tokenizer):\n    vocab_size = tokenizer.vocab_size\n    if (not hasattr(tokenizer, 'mask_token')\n        or tokenizer.mask_token is None):\n      self.mask_index = vocab_size\n      vocab_size += 1\n    else:\n      self.mask_index = tokenizer.mask_token_id\n    super().__init__(config, tokenizer,\n                     vocab_size=vocab_size)\n    self.save_hyperparameters()\n    self._validate_configuration()\n\n  def _validate_configuration(self):\n    super()._validate_configuration()\n    assert not self.config.algo.time_conditioning\n    assert self.config.prior.type == 'none'\n\n  def _process_model_input(self, x0, valid_tokens):\n    input_tokens = x0[:, :-1]\n    output_tokens = x0[:, 1:]\n    valid_tokens = valid_tokens[:, 1:]\n    return input_tokens, output_tokens, valid_tokens\n\n  def nll(self, input_tokens, output_tokens,\n          current_accumulation_step):\n    del current_accumulation_step\n    output = self.backbone(input_tokens, None)\n    output[:, :, self.mask_index] = self.neg_infinity\n    output = output.log_softmax(-1)\n    return - output.gather(\n      -1, output_tokens[:, :, None])[:, :, 0]\n\n  def generate_samples(self, num_samples, **kwargs):\n    # precompute token buffer\n    num_pred_tokens = self.num_tokens - 1\n    x = torch.zeros(\n      (num_samples, num_pred_tokens + 1),\n      dtype=torch.long,\n      device=self.device)\n    x[:, 0] = self.tokenizer.bos_token_id\n    # precompute noise\n    noise = (torch.distributions.Gumbel(0, 1)\n             .sample((num_samples, num_pred_tokens, self.vocab_size))\n             .to(self.device))\n    if self.config.sampling.use_float64:\n      noise = noise.to(torch.float64)\n    for i in 
range(num_pred_tokens):\n      output = self.backbone(x[:, :i + 1], None)\n      output[:, :, self.mask_index] = self.neg_infinity\n      output = output.log_softmax(-1)\n      y = (output[:, -1, :] + noise[:, i, :]).argmax(-1)\n      x[:, i + 1] = y\n    return x\n\n  def _process_sigma(self, sigma):\n    del sigma\n    return None\n\n\nclass MDLM(trainer_base.AbsorbingState):\n  def __init__(self, config, tokenizer):\n    super().__init__(config, tokenizer)\n    self._validate_configuration()\n\n  def _validate_configuration(self):\n    assert self.sampler != 'ancestral', \\\n      'sampling.predictor=ancestral is not desirable because ' \\\n      'it is slow. Please set sampling.predictor=ancestral_cache'\n\n  def _process_model_output(self, model_output, xt, sigma):\n    del sigma\n    model_output[:, :, self.mask_index] += self.neg_infinity\n    \n    # Normalize the model_output such that x.exp() is\n    # a probability distribution over vocab_size.\n    model_output = model_output - torch.logsumexp(\n      model_output, dim=-1, keepdim=True)\n    # Apply updates directly in the logits matrix.\n    # For the logits of the unmasked tokens, set all values\n    # to -infinity except for the indices corresponding to\n    # the unmasked tokens.\n    unmasked_indices = (xt != self.mask_index)\n    model_output[unmasked_indices] = self.neg_infinity\n    model_output[unmasked_indices, xt[unmasked_indices]] = 0\n    return model_output\n\n  def nll_per_token(self, log_x_theta, xt, x0, alpha_t,\n                    dalpha_t, low_var=False):\n    del xt\n    log_p_theta = torch.gather(\n      input=log_x_theta,\n      dim=-1,\n      index=x0[:, :, None]).squeeze(-1)\n    return log_p_theta * dalpha_t / (1 - alpha_t)\n\n  def _get_score(self, x, sigma):\n    model_output = self.forward(x, sigma)\n    # score(x, t) = p_t(y) / p_t(x)\n    # => log score(x, t) = log p_t(y) - log p_t(x)\n    \n    # case 1: x = masked\n    #   (i) y = unmasked\n    #     log score(x, t) = 
log p_\\theta(x)|_y + log k\n    #     where k = exp(- sigma) / (1 - exp(- sigma))\n    #   (ii) y = masked\n    #     log score(x, t) = 0\n\n    # case 2: x = unmasked\n    #   (i) y != masked, y != x\n    #     log score(x_i, t) = - inf\n    #   (ii) y = x \n    #     log score(x_i, t) = 0\n    #   (iii) y = masked token\n    #     log score(x_i, t) = - log k\n    #     where k = exp(- sigma) / (1 - exp(- sigma))\n    \n    log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)\n    assert log_k.ndim == 1\n    \n    masked_score = model_output + log_k[:, None, None]\n    masked_score[:, :, self.mask_index] = 0\n\n    unmasked_score = self.neg_infinity * torch.ones_like(\n      model_output)\n    unmasked_score = torch.scatter(\n      unmasked_score,\n      -1,\n      x[..., None],\n      torch.zeros_like(unmasked_score[..., :1]))\n    unmasked_score[:, :, self.mask_index] = - (\n      log_k[:, None] * torch.ones_like(x))\n    \n    masked_indices = (x == self.mask_index).to(\n      model_output.dtype)[:, :, None]\n    model_output = (\n      masked_score * masked_indices\n      + unmasked_score * (1 - masked_indices))\n    return model_output.exp()\n\n\nclass D3PMAbsorb(trainer_base.AbsorbingState):\n  def __init__(self, config, tokenizer):\n    super().__init__(config, tokenizer)\n    self._validate_configuration()\n\n  def _validate_configuration(self):\n    super()._validate_configuration()\n    assert self.noise.type == 'log-linear'\n    assert self.parameterization == 'mean'\n\n  def _process_model_output(self, model_output, xt, sigma):\n    del xt \n    del sigma\n    if self.subs_masking:\n      model_output[:, :, self.mask_index] += self.neg_infinity\n    return model_output.log_softmax(dim=-1)\n\n  def nll_per_token(self, log_x_theta, xt, x0, alpha_t,\n                    dalpha_t, low_var=False):\n    del dalpha_t\n    assert not low_var\n    dt = 1 / self.T\n    t = 1 - alpha_t  # Only valid for log-linear schedule.\n    t = t.clamp(0., 1.0 - 1e-4)\n    
alpha_t = alpha_t + torch.zeros_like(xt)\n    alpha_s = t - dt + torch.zeros_like(xt)\n    assert alpha_s.shape == xt.shape\n    assert alpha_t.shape == xt.shape\n    log_x_theta_at_x0 = torch.gather(\n      log_x_theta, -1, x0[:, :, None]).squeeze(-1)\n    log_x_theta_at_m = log_x_theta[:, :, self.mask_index]\n    x_theta_at_m = log_x_theta_at_m.exp()\n    \n    term_1_coef = dt / t\n    term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)\n    term_1_log_dr = log_x_theta_at_x0\n    \n    term_2_coef = 1 - dt / t\n    term_2_log_nr = term_1_log_nr\n    term_2_log_dr = torch.log(\n      alpha_s * x_theta_at_m / (t - dt) + 1)\n    L_vb_masked = (\n      term_1_coef * (term_1_log_nr - term_1_log_dr)\n      + term_2_coef * (term_2_log_nr - term_2_log_dr))\n\n    diffusion_loss = self.T * L_vb_masked * (xt == self.mask_index)\n    return self._reconstruction_loss(x0) + diffusion_loss\n\n\nclass SEDDAbsorb(trainer_base.AbsorbingState):\n  def __init__(self, config, tokenizer):\n    super().__init__(config, tokenizer)\n    self._validate_configuration()\n\n  def _validate_configuration(self):\n    super()._validate_configuration()\n    assert self.config.sampling.predictor == 'analytic'\n\n  def _get_score(self, x, sigma):\n    return self.forward(x, sigma).exp()\n\n  def _process_model_output(self, model_output, xt, sigma):\n    esigm1_log = torch.where(\n      sigma < 0.5,\n      torch.expm1(sigma),\n      sigma.exp() - 1).log().to(model_output.dtype)\n    # logits shape\n    # (batch_size, context_length, vocab_size)\n    model_output = (model_output\n                    - esigm1_log[:, None, None]\n                    - np.log(model_output.shape[-1] - 1))\n    # The below scatter operation sets the log score\n    # for the input word to 0.\n    model_output = torch.scatter(\n      model_output, -1, xt[..., None],\n      torch.zeros_like(model_output[..., :1]))\n    return model_output\n\n  def nll_per_token(self, log_x_theta, xt, x0, alpha_t,\n                   
 dalpha_t, low_var=False):\n    \"\"\"Computes the SEDD loss for the Absorbing State Diffusion.\n\n    Args:\n      log_x_theta: float torch.Tensor with shape (batch_size,\n          context_length, vocab_size),\n          log score, output of the denoising network.\n      xt: int torch.Tensor with shape (batch_size,\n          context_length), input.\n      x0: int torch.Tensor with shape (batch_size,\n          context_length), input.\n      alpha_t: float torch.Tensor with shape (batch_size, 1),\n          signal level.\n      dalpha_t: float or float torch.Tensor with shape (batch_size, 1),\n          time derivative of signal level.\n      low_var: bool, low variance loss during training.\n    \n    Returns:\n      loss with shape (batch_size, context_length).\n    \"\"\"\n    assert not low_var\n    masked_indices = xt == self.mask_index\n    sigma = self._sigma_from_alphat(alpha_t)\n    dsigma = - dalpha_t / alpha_t\n\n    expsig_minus_1 = torch.expm1(sigma).expand_as(xt)\n    q_ratio = 1 / expsig_minus_1[masked_indices]\n\n    words_that_were_masked = x0[masked_indices]\n\n    neg_term = q_ratio * torch.gather(\n      log_x_theta[masked_indices],\n      -1,\n      words_that_were_masked[..., None]).squeeze(-1)\n    score = log_x_theta[masked_indices].exp()\n    if self.mask_index == self.vocab_size - 1:\n      pos_term = score[:, :-1].sum(dim=-1)\n    else:\n      pos_term = score[:, : self.mask_index].sum(\n        dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)\n    const = q_ratio * (q_ratio.log() - 1)\n\n    entropy = torch.zeros(* xt.shape, device=xt.device)\n    entropy[masked_indices] += pos_term - neg_term + const\n    return dsigma * entropy\n\n\nclass DUO_BASE(trainer_base.UniformState):\n  def __init__(self, config, tokenizer):\n    super().__init__(config, tokenizer)\n    self._validate_configuration()\n\n  def on_save_checkpoint(self, checkpoint):\n    
checkpoint['state_dict'] = collections.OrderedDict(\n      (k, v) for k, v in checkpoint['state_dict'].items()\n      if not k.startswith('teacher'))\n    super().on_save_checkpoint(checkpoint)\n\n  def on_load_checkpoint(self, checkpoint):\n    checkpoint['state_dict'] = collections.OrderedDict(\n      (k, v) for k, v in checkpoint['state_dict'].items()\n      if not k.startswith('teacher'))\n    super().on_load_checkpoint(checkpoint)\n\n  def _process_model_output(self, model_output, xt, sigma):\n    del xt, sigma\n    return model_output.log_softmax(dim=-1)\n\n  def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t):\n    \"\"\"Computes the posterior / approximate posterior.\n\n    Args:\n      x0: Either clean input `x0` (one-hot),\n        or model's predicted `x_theta` of shape (B, L, V).\n      xt: The noisy latent (as indices) of shape (B, L).\n      alpha_s: Noise level at s of shape (B, [L | 1], 1).\n      alpha_t: Noise level at t of shape (B, [L | 1], 1).\n\n    Returns:\n      Posterior / approximate posterior of shape (B, L, V).\n    \"\"\"\n    if self.config.sampling.use_float64:\n      x0 = x0.to(torch.float64)\n    if alpha_s.ndim == 2:\n      alpha_s = alpha_s.unsqueeze(-1)\n    if alpha_t.ndim == 2:\n      alpha_t = alpha_t.unsqueeze(-1)\n    alpha_ts = alpha_t / alpha_s\n    d_alpha = alpha_s - alpha_t\n    xt_one_hot = F.one_hot(xt, self.vocab_size).to(\n      self.dtype).to(self.device)\n    return (\n      (alpha_t * self.vocab_size * x0 * xt_one_hot + (\n        alpha_ts - alpha_t) * xt_one_hot + d_alpha * x0 + (\n          1 - alpha_ts) * (1 - alpha_s) / self.vocab_size) / (\n            alpha_t * self.vocab_size * torch.gather(\n              x0, -1, xt[..., None]) + (1 - alpha_t)))\n\n  def nll_per_token(self, log_x_theta, xt, x0, alpha_t,\n                    dalpha_t, low_var=False):\n    assert alpha_t.ndim == 2\n    assert x0.ndim == 2\n    assert xt.ndim == 2\n    assert not torch.is_tensor(dalpha_t) or dalpha_t.ndim == 2\n    
x_reconst = log_x_theta.exp()\n    x_bar_theta = self.vocab_size * alpha_t[\n        :, :, None] * x_reconst + 1 - alpha_t[:, :, None]\n    coeff = dalpha_t / (self.vocab_size * alpha_t)\n    x_eq_xt = (x0 == xt).float()\n    x_neq_xt = 1 - x_eq_xt\n    xbar_xt = (1 - alpha_t) + self.vocab_size * alpha_t * x_eq_xt\n    xbar_theta_xt = torch.gather(\n      x_bar_theta, -1, xt.unsqueeze(-1)).squeeze(-1)\n    xbar_theta_x = torch.gather(\n      x_bar_theta, -1, x0.unsqueeze(-1)).squeeze(-1)\n    term1 = self.vocab_size * (1 / xbar_xt\n                                - 1 / xbar_theta_xt)\n    \n    const = (1 - alpha_t) / (self.vocab_size * alpha_t\n                             + 1 - alpha_t)\n    term2_coefs = x_eq_xt * const + x_neq_xt\n    term2_offset = ((self.vocab_size - 1) * const * x_eq_xt\n                    - (1 / const) * x_neq_xt) * const.log()\n    term2_theta = - term2_coefs * (\n      x_bar_theta.log().sum(-1)\n      - self.vocab_size * xbar_theta_xt.log())\n    term2_theta = (\n      term2_theta\n      - self.vocab_size * alpha_t / (1 - alpha_t) * (\n        xbar_theta_x.log() - xbar_theta_xt.log()) * x_neq_xt)\n    term2 = term2_theta + term2_offset\n    diffusion_loss = coeff * (term1 - term2)\n    assert diffusion_loss.ndim == 2\n    return diffusion_loss\n\n\nclass Integral(torch.autograd.Function):\n  \"\"\"\n  torch module calculating UDLM's p_t \n  \"\"\"\n\n  @staticmethod\n  def forward(ctx, gamma_t, data):\n    gamma_max = data['gamma_max']\n    gamma_min = data['gamma_min']\n    if (gamma_t.max() > gamma_max) or (\n      gamma_t.min() < gamma_min):\n      print('max:{} {}'.format(gamma_t.max(), gamma_max))\n      print('min:{} {}'.format(gamma_t.min(), gamma_min))\n      gamma_t = torch.clip(gamma_t, gamma_min, gamma_max)\n    indices = torch.round(\n      (data['num_points'] - 1) * (gamma_t - gamma_min) / (\n          gamma_max - gamma_min)).long()\n    grad_pt = data['grad_pt']\n    ctx.grad_pt = grad_pt[indices]\n    \n    pt = 
data['pt'][indices]\n    assert pt.shape == gamma_t.shape\n    return pt\n\n  @staticmethod\n  def backward(ctx, grad_output):\n    return ctx.grad_pt * grad_output, None\n\n\nclass DUO(DUO_BASE):\n  def __init__(self, config, tokenizer):\n    super().__init__(config, tokenizer)\n    self.gamma_min = self.config.algo.curriculum.gamma_min\n    self.gamma_max = self.config.algo.curriculum.gamma_max\n    self.gumbel_tau_log10_start = \\\n          self.config.algo.curriculum.gumbel_tau_log10_start\n    self.gumbel_tau_log10_end = \\\n            self.config.algo.curriculum.gumbel_tau_log10_end\n    self.curriculum_start = self.config.algo.curriculum.start\n    self.curriculum_end = self.config.algo.curriculum.end\n    self.loss_type = self.config.algo.loss_type\n    self._initialize_curriculum_coefficients()\n    self._validate_configuration()\n    \n  def _initialize_curriculum_coefficients(self):\n    if self.config.algo.curriculum.mode in {'simple', \n      'efficient_cached'}:\n      self._init_curriculum_cached()\n    elif self.config.algo.curriculum.mode == 'series':\n      self._init_curriculum_series()\n    elif self.config.algo.curriculum.mode in {'sigmoid',\n      'sigmoid-edge-corrected', 'poly3', 'poly5', 'poly7',\n      'poly9'}:\n      self._init_curriculum_approx()\n    else:\n      raise ValueError(self.config.algo.curriculum.mode)\n\n  def _init_curriculum_cached(self):\n    fpath = self.config.algo.curriculum.integral_cache_path\n    with fsspec.open(fpath, 'rb') as f:\n      self.integral_cache = pickle.load(f)\n    self.integral_cache['pt'] = torch.from_numpy(\n      self.integral_cache['pt'])\n    self.integral_cache['grad_pt'] = torch.from_numpy(\n      self.integral_cache['grad_pt'])\n\n  def _init_curriculum_series(self):\n    m, i = utils.compute_duo_series_coefficients(\n      self.config.algo.curriculum.n_series_terms, \n      self.vocab_size)\n\n    self.register_buffer('coefficients_m', m, \n                         persistent=False)\n    
self.register_buffer('coefficients_i', i,\n                         persistent=False)\n    self.register_buffer('power_arange',\n      torch.arange(self.config.algo.curriculum.n_series_terms,\n        dtype=torch.float64)[None], persistent=False)\n\n  def _init_curriculum_approx(self):\n    fname = f'{self.config.algo.curriculum.mode}.npy'\n    fpath = os.path.join(self.config.algo.curriculum.cache_dir, \n                         fname)\n    if not os.path.exists(fpath):\n      # Compute the coefficients on the fly\n      coefficients, _, _, _ = utils.compute_duo_operator_approx(\n        num_coefficients=self.config.algo.curriculum.n_series_terms,\n        vocab_size=self.vocab_size,\n        gamma_min=self.gamma_min,\n        gamma_max=self.gamma_max,\n        fct_name=self.config.algo.curriculum.mode)\n      # Tuples are for torch compile, tuples are immutable\n      coefficients = tuple(coefficients)\n      parent_dir = os.path.dirname(fpath)\n      os.makedirs(parent_dir, exist_ok=True)\n      np.save(fpath, coefficients)\n    else:\n      coefficients = tuple(np.load(fpath).tolist())\n    mode = self.config.algo.curriculum.mode\n    if mode == 'sigmoid':\n      fn = utils.duo_to_alpha_dalpha_sigmoid\n    elif mode == 'sigmoid-edge-corrected':\n      fn = utils.duo_t_to_alpha_dalpha_sigm_corrected\n    elif mode in ('poly3', 'poly5', 'poly7', 'poly9'):\n      fn = utils.duo_to_alpha_dalpha_poly\n    else:\n      raise ValueError(mode)\n    fn = torch.compile(fn)\n    self._t_to_alpha_dalpha_compiled = \\\n      lambda t: fn(t, *coefficients)\n\n  def to(self, *args, **kwargs):\n    self = super().to(*args, **kwargs)\n    self.integral_cache['pt'] = self.integral_cache[\n      'pt'].to(*args, **kwargs)\n    self.integral_cache['grad_pt'] = self.integral_cache[\n      'grad_pt'].to(*args, **kwargs)\n    return self\n\n  def cuda(self, device=None):\n    self = super().cuda(device=device)\n    if hasattr(self, 'integral_cache'):\n      self.integral_cache['pt'] = 
self.integral_cache[\n        'pt'].cuda(device=device)\n      self.integral_cache['grad_pt'] = self.integral_cache[\n        'grad_pt'].cuda(device=device)\n    return self\n\n  def cpu(self):\n    self = super().cpu()\n    if hasattr(self, 'integral_cache'):\n      self.integral_cache['pt'] = self.integral_cache[\n        'pt'].cpu()\n      self.integral_cache['grad_pt'] = self.integral_cache[\n        'grad_pt'].cpu()\n    return self\n\n  def to(self, *args, **kwargs):\n    self = super().to(*args, **kwargs)\n    if hasattr(self, 'integral_cache'):\n      self.integral_cache['pt'] = self.integral_cache[\n        'pt'].to(*args, **kwargs)\n      self.integral_cache['grad_pt'] = self.integral_cache[\n        'grad_pt'].to(*args, **kwargs)\n    return self\n\n  def _compute_gumbel_tau_inverse(self):\n    start = self.gumbel_tau_log10_start\n    end = self.gumbel_tau_log10_end\n    delta = end - start\n    if self.global_step < self.curriculum_start:\n      tau = start\n    elif self.global_step < self.curriculum_end:\n      frac = (self.global_step - self.curriculum_start) / (\n        self.curriculum_end - self.curriculum_start)\n      tau = start + frac * delta\n    else:\n      tau = -10\n    return 10 ** (-tau)\n\n  def training_step(self, batch, batch_idx):\n    self.log(name='gumbel_tau_log10',\n             value=1 / self._compute_gumbel_tau_inverse(),\n             on_step=True,\n             on_epoch=False,\n             sync_dist=True)\n    return super().training_step(batch, batch_idx)\n\n  def _gamma_to_alpha_dalpha(self, gamma_t, t):\n    if self.config.algo.curriculum.mode in ('simple', \n      'efficient_cached'):\n      return self._gamma_to_alpha_dalpha_cached(gamma_t)\n    elif self.config.algo.curriculum.mode == 'series':\n      return utils.compute_duo_gamma_to_alpha_dalpha_series(\n        gamma_t, self.coefficients_m, self.coefficients_i,\n        self.power_arange, self.vocab_size, self.gamma_min,\n        self.gamma_max)\n    elif 
self.config.algo.curriculum.mode in ('sigmoid',\n      'sigmoid-edge-corrected', 'poly3', 'poly5', 'poly7',\n      'poly9'):\n      return self._t_to_alpha_dalpha_compiled(t)\n    else:\n      raise ValueError(self.config.algo.curriculum.mode)\n\n  def _gamma_to_alphat_integral(self, gamma_t):\n    integral = Integral.apply(gamma_t, self.integral_cache)\n    return (self.vocab_size * integral - 1) / (\n      self.vocab_size - 1)\n\n  def _gamma_to_alpha_dalpha_cached(self, gamma_t):\n    gamma_t_prime = self.gamma_max - self.gamma_min\n    usdm_alpha_t = DUO._gamma_to_alphat_integral(self, gamma_t)\n    T = 1000\n    usdm_dalpha_t = gamma_t_prime * T * (\n      DUO._gamma_to_alphat_integral(self, gamma_t + 1 / T) \n      - usdm_alpha_t)\n    return usdm_alpha_t, usdm_dalpha_t\n\n  def _prior_loss(self):\n    alpha_1 = self._gamma_to_alphat_integral(\n      torch.tensor(self.gamma_max))\n    loss = ((alpha_1 + (1 - alpha_1) / self.vocab_size) \\\n           * torch.log((self.vocab_size - 1) * alpha_1 + 1) \\\n           + (1 - 1 / self.vocab_size) * (1 - alpha_1) \\\n           * torch.log(1 - alpha_1))\n    return loss.item()\n\n  def _q_xt_gaussian(self, x, gamma_t):\n    \"\"\"Computes the noisy sample xt.\"\"\"\n    assert gamma_t.ndim == 1\n    assert x.ndim == 3\n    gamma_t = gamma_t.unsqueeze(-1).unsqueeze(-1)\n    alpha_t = torch.sigmoid(-gamma_t).sqrt()\n    sigma_t = torch.sigmoid(gamma_t).sqrt()\n    epsilon = torch.randn(x.shape, dtype=torch.float32,\n                          device=self.device)\n    return alpha_t * x + sigma_t * epsilon\n\n  def nll(self, x0, output_tokens,\n          current_accumulation_step=None, train_mode=False):\n    use_true_nll = (self.global_step > self.curriculum_end\n                    or not train_mode)\n    if use_true_nll:\n      return super().nll(x0, output_tokens,\n                         current_accumulation_step)\n    del output_tokens\n    t = self._sample_t(x0.shape[0], current_accumulation_step)\n    gamma_t = 
self.gamma_min + t * (self.gamma_max\n                                    - self.gamma_min)\n    usdm_alpha_t, usdm_dalpha_t = \\\n      self._gamma_to_alpha_dalpha(gamma_t, t)\n\n    usdm_alpha_t = usdm_alpha_t.unsqueeze(-1)\n    assert usdm_alpha_t.ndim == 2\n    usdm_dalpha_t = usdm_dalpha_t.unsqueeze(-1)\n    sigma = self._sigma_from_alphat(usdm_alpha_t)\n    # Default Duo curriculum\n    if self.config.algo.curriculum.mode == 'simple':\n      x0_one_hot = F.one_hot(x0, self.vocab_size)\n      xt = self._q_xt_gaussian(x0_one_hot, gamma_t)\n      xt = xt * self._compute_gumbel_tau_inverse()\n      xt_usdm = xt.argmax(-1)\n      log_x_theta = self.forward(xt, sigma=sigma)\n    else:  # Efficient variant\n      softmax_approx, topk_indices, xt_usdm = \\\n        utils.sample_tempered_softmax_topk(\n        extra_index=x0,\n        alpha=torch.sigmoid(-gamma_t).sqrt(),\n        sigma=torch.sigmoid(gamma_t).sqrt(),\n        l=x0.shape[1],\n        k=self.config.algo.curriculum.top_k,\n        vocab_size=self.vocab_size,\n        inverse_temperature=self._compute_gumbel_tau_inverse())\n      log_x_theta = self.forward(topk_indices, sigma=sigma, \n                                 weights=softmax_approx)\n\n    return self.nll_per_token(log_x_theta=log_x_theta,\n                              xt=xt_usdm,\n                              x0=x0,\n                              alpha_t=usdm_alpha_t,\n                              dalpha_t=usdm_dalpha_t,\n                              low_var=False)\n\n\nclass Distillation(DUO):\n  def __init__(self, config, tokenizer):\n    super().__init__(config, tokenizer)\n    self.update_teacher_every = config.algo.update_teacher_every\n    self.save_hyperparameters()\n    self.teacher = None\n    self.teacher_ema = config.algo.teacher_ema\n    self.linear_growth_dt = config.algo.linear_growth_dt\n    self.linear_growth_min = config.algo.linear_growth_min\n    self.linear_growth_max = config.algo.linear_growth_max\n\n  def 
_validate_configuration(self):\n    assert os.path.exists(\n      self.config.algo.integral_cache_path), (\n        'The integral cache (Eq. 10 in the paper) for '\n        f'the {self.config.data.tokenizer_name_or_path} '\n        ' tokenizer doesnt exist at '\n        f'{self.config.algo.integral_cache_path}. '\n        'Please generate it by running the utils.py script, '\n        'and ensure the correct path is specified using the '\n        'algo.integral_cache_path flag.')\n    assert self.loss_type in {\n      'kl-fwd', 'kl-bwd', 'posterior', 'kl-posterior'}\n\n  def _maybe_update_teacher_weights(self):\n    if self.global_step % self.update_teacher_every != 0:\n      return\n    if self.teacher_ema:\n      self.ema.copy_to(self.teacher.parameters())\n    else:\n      for better_param, current_param in zip(\n        self.backbone.parameters(), self.teacher.parameters()):\n        if current_param.requires_grad:\n          current_param.data.copy_(better_param.data)\n\n  @torch.no_grad()\n  def _teacher_logits(self, xt, sigma):\n    if self.teacher is None:\n      self.teacher = copy.deepcopy(self.backbone)\n    self._maybe_update_teacher_weights()\n\n    sigma = self._process_sigma(sigma)\n    with torch.amp.autocast('cuda', dtype=torch.float32):\n      model_output = self.teacher(xt, sigma)\n    logits = self._process_model_output(\n      model_output=model_output, xt=xt, sigma=sigma)\n    return logits.detach()\n\n  def _sample_trajectory(self, x0, gamma_t, gamma_s):\n    \"\"\"Computes the noisy sample xt.\"\"\"\n    assert gamma_t.ndim == 1\n    assert gamma_s.ndim == 1\n    assert x0.ndim == 2\n    x0 = F.one_hot(x0, self.vocab_size).to(\n      self.dtype).to(self.device)\n    gamma_t = gamma_t.unsqueeze(-1).unsqueeze(-1)\n    alpha_t = torch.sigmoid(-gamma_t).sqrt()\n    sigma_t = torch.sigmoid(gamma_t).sqrt()\n\n    gamma_s = gamma_s.unsqueeze(-1).unsqueeze(-1)\n    alpha_s = torch.sigmoid(-gamma_s).sqrt()\n    sigma_s = 
torch.sigmoid(gamma_s).sqrt()\n    \n    epsilon = torch.randn(x0.shape, dtype=torch.float32,\n                          device=self.device)\n    xt = alpha_t * x0 + sigma_t * epsilon\n    xs = alpha_s * x0 + sigma_s * epsilon\n    return xt, xs\n\n  def _compute_dt(self):\n    if self.linear_growth_dt:\n      scale = self.global_step / self.trainer.max_steps\n      return self.linear_growth_min + scale * (\n        self.linear_growth_max -  self.linear_growth_min)\n    n = self.global_step // self.update_teacher_every\n    return 2 ** n / self.T\n\n  def nll(self, x0, output_tokens,\n          current_accumulation_step=None, train_mode=None):\n    del output_tokens, train_mode\n    t = self._sample_t(x0.shape[0], current_accumulation_step)\n    dt = self._compute_dt()\n    t = torch.clip(t + dt, 0, 1)\n\n    gamma_t = self.gamma_min + t * (self.gamma_max\n                                    - self.gamma_min)\n    gamma_s = self.gamma_min + (\n      t - dt) * (self.gamma_max - self.gamma_min)\n\n    usdm_alpha_t = self._gamma_to_alphat_integral(gamma_t)\n    usdm_alpha_t = usdm_alpha_t.unsqueeze(-1)\n    assert usdm_alpha_t.ndim == 2\n    usdm_alpha_s = self._gamma_to_alphat_integral(gamma_s)\n    usdm_alpha_s = usdm_alpha_s.unsqueeze(-1)\n    assert usdm_alpha_s.ndim == 2\n\n    xt, xs = self._sample_trajectory(x0, gamma_t, gamma_s)\n    xt_discrete = xt.argmax(-1)\n    xs_discrete = xs.argmax(-1)\n    log_x_theta_student = self.forward(\n      xt_discrete, sigma=self._sigma_from_alphat(usdm_alpha_t))\n    log_x_theta_teacher = self._teacher_logits(\n      xs_discrete, sigma=self._sigma_from_alphat(usdm_alpha_s))\n    if self.config.training.loss_precision == 'float64':\n      log_x_theta_student = log_x_theta_student.to(torch.float64)\n      log_x_theta_teacher = log_x_theta_teacher.to(torch.float64)\n    if self.loss_type == 'kl-fwd':\n      return (log_x_theta_teacher.exp() * (\n        log_x_theta_teacher - log_x_theta_student)).sum(-1)\n    elif 
self.loss_type == 'kl-bwd':\n      return (log_x_theta_student.exp() * (\n        log_x_theta_student - log_x_theta_teacher)).sum(-1)\n    \n  def training_step(self, batch, batch_idx):\n    self.log(name='dt',\n             value=self._compute_dt(),\n             on_step=True,\n             on_epoch=False,\n             sync_dist=True)\n    return super().training_step(batch, batch_idx)\n"
  },
  {
    "path": "configs/algo/ar.yaml",
    "content": "name: ar\nbackbone: dit\nparameterization: ar\ntime_conditioning: False\ncausal_attention: True\n# Irrelevant flags\nT: 0\nsubs_masking: False\nignore_bos: False\n"
  },
  {
    "path": "configs/algo/d3pm.yaml",
    "content": "name: d3pm\nbackbone: dit  # dit / dimamba\nparameterization: mean\ntime_conditioning: True\nT: 1000 \nsubs_masking: False  # True / False\ncausal_attention: False\nignore_bos: False\nloss_type: elbo  # elbo, low_var\n"
  },
  {
    "path": "configs/algo/distillation.yaml",
    "content": "name: distillation\nbackbone: dit  # dit / dimamba / hf_dit\nparameterization: mean\ntime_conditioning: True\nsubs_masking: False\ncausal_attention: False\ngumbel_tau_log10_start: -1\ngumbel_tau_log10_end: -1\ncurriculum_start: -1\ncurriculum_end: -1\n\nintegral_cache_path: ${hydra:runtime.cwd}/integral/${data.tokenizer_name_or_path}.pkl\nloss_type: kl-bwd  # kl-fwd, kl-bwd, posterior\nupdate_teacher_every: 10_000\nignore_bos: False\nT: 64\ngamma_min: -4\ngamma_max: -1\nteacher_ema: False\nposterior_loss_weight: 0.0\nlinear_growth_dt: False\nlinear_growth_min: 0.001\nlinear_growth_max: 0.25"
  },
  {
    "path": "configs/algo/duo.yaml",
    "content": "name: duo\nbackbone: dit  # dit / dimamba / hf_dit\nparameterization: mean\ntime_conditioning: True\nT: 0  # 0 (continuous time) / 1000 \nsubs_masking: False\ncausal_attention: False\nloss_type: elbo\nignore_bos: False\n\n# Curriculum arguments\ncurriculum:\n  # Simple is the original duo curriculum, poly9 is the one we use in duo2 for speed\n  mode: simple  # simple / efficient_cached / series / sigmoid / sigmoid-edge-corrected / poly3 / poly5 / poly7 / poly9\n  # Arguments for the simple & efficient variant\n  gumbel_tau_log10_start: -1.0\n  gumbel_tau_log10_end: -2.0\n  start: 100_000\n  end: 200_000\n  gamma_min: -3.5\n  gamma_max: -1.75\n  # Argument for the simple and efficient_cached variants\n  integral_cache_path: ${hydra:runtime.cwd}/integral/${data.tokenizer_name_or_path}.pkl\n  # Arguments for the efficient variant only\n  top_k: -1\n  cache_dir: './transform_approx_coefficients'\n  n_series_terms: 150\n"
  },
  {
    "path": "configs/algo/duo_base.yaml",
    "content": "name: duo_base\nbackbone: dit  # dit / dimamba / hf_dit\nparameterization: mean\ntime_conditioning: True\nT: 0  # 0 (continuous time) / 1000 \nsubs_masking: False\ncausal_attention: False\nignore_bos: False\nloss_type: elbo  # elbo, low_var\n"
  },
  {
    "path": "configs/algo/mdlm.yaml",
    "content": "name: mdlm\nbackbone: dit  # dit / dimamba / hf_dit\nparameterization: subs\ntime_conditioning: False\nT: 0  # 0 (continuous time) / 1000 \nsubs_masking: False\ncausal_attention: False\nignore_bos: False\nloss_type: elbo\n"
  },
  {
    "path": "configs/algo/ot-finetune.yaml",
    "content": "name: ot-finetune\nbackbone: dit  # dit / dimamba / hf_dit\nparameterization: mean\ntime_conditioning: True\nT: 0  # 0 (continuous time) / 1000 \nsubs_masking: False\ncausal_attention: False\nignore_bos: False\ndelta_ts: 0.5"
  },
  {
    "path": "configs/algo/sedd.yaml",
    "content": "name: sedd\nbackbone: dit  # dit / dimamba\nparameterization: score \ntime_conditioning: True\nT: 0  # 0 (continuous time) / 1000 \nsubs_masking: False\ncausal_attention: False\nignore_bos: False\nloss_type: elbo  # elbo, low_var\n"
  },
  {
    "path": "configs/callbacks/checkpoint_every_n_steps.yaml",
    "content": "checkpoint_every_n_steps:\n  _target_: lightning.pytorch.callbacks.ModelCheckpoint\n  save_top_k: -1 # Do not save any \"best\" models; this callback is being used to save every n train steps\n  save_last: True # save model as ${save_dir}/checkpoints/last.ckpt\n  dirpath: ${checkpointing.save_dir}/checkpoints\n  verbose: True\n  auto_insert_metric_name: False\n  every_n_train_steps: 500\n"
  },
  {
    "path": "configs/callbacks/checkpoint_monitor.yaml",
    "content": "checkpoint_monitor:\n  _target_: lightning.pytorch.callbacks.ModelCheckpoint\n  monitor: val/nll # name of the logged metric which determines when model is improving\n  mode: min # can be \"max\" or \"min\"\n  save_top_k: 1 # save k best models (determined by above metric)\n  save_last: False # True = additionally always save model from last epoch\n  dirpath: ${checkpointing.save_dir}/checkpoints\n  filename: best\n  auto_insert_metric_name: False\n  verbose: True\n"
  },
  {
    "path": "configs/callbacks/grad_record.yaml",
    "content": "grad_record:\n  _target_: utils.GradientInspectionCallback\n  num_grads_log: 4"
  },
  {
    "path": "configs/callbacks/learning_rate_monitor.yaml",
    "content": "learning_rate_monitor:\n  _target_: lightning.pytorch.callbacks.LearningRateMonitor\n  logging_interval: step\n"
  },
  {
    "path": "configs/config.yaml",
    "content": "defaults:\n  - _self_\n  - /callbacks: [checkpoint_every_n_steps, checkpoint_monitor, learning_rate_monitor]\n  - /data: openwebtext\n  - /model: small\n  - /strategy: ddp\n  - /noise: log-linear\n  - /lr_scheduler: constant_warmup\n  - /prior: none\n  - /algo: duo_base\n\nmode: train  # train / ppl_eval / sample_eval\n\nseed: 1\n\nloader:\n  global_batch_size: 512\n  eval_global_batch_size: ${.global_batch_size}\n  # Note: batch_size and eval_batch_size are **per machine**\n  batch_size: ${div_up:${.global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}\n  eval_batch_size: ${div_up:${.eval_global_batch_size}, ${eval:${trainer.devices} * ${trainer.num_nodes}}}\n  num_workers: ${eval:\"len(__import__('os').sched_getaffinity(0))\"}\n  pin_memory: True\n\nsampling:\n  predictor: ancestral  # ancestral_cache (only for MDLM), ancestral, analytic, psi (for psi-samplers)\n  steps: 1000\n  noise_removal: ancestral  # 'ancestral', 'greedy', 'none'\n  use_float64: True\n  p_nucleus: 1.0\n  num_sample_batches: 2  # Total samples: `num_gpus` * `loader.eval_batch_size` * num_sample_batches\n  num_sample_log: 2\n  semi_ar: False\n  stride_length: 1\n  num_strides: 1\n  guid_weight: null  # null -> no classifier-free guidance \n  psi:\n    time_profile: linear  # linear / linear-constant-linear-alpha / linear-constant-linear-alpha-inv\n    # Note: kappa = 1 -> pure posterior, kappa = 0 -> pure PC\n    # modes: pure-posterior / pure-pc / constant-eta / constant-remdm-eta / max-capped-eta / max-rescale-eta\n    high_mode: pure-posterior\n    middle_mode: pure-posterior\n    low_mode: pure-posterior\n    # Fraction of [0, 1] spent in the high / middle modes (the rest being in low mode)\n    # Example: high_frac=0.2, middle_frac=0.6. 
Then,\n    #  - For t in [1.0, 0.8] -> use high_mode\n    #  - For t in [0.8, 0.2] -> use middle_mode\n    #  - For t in [0.2, 0.0] -> use low_mode\n    high_frac: 0.2\n    middle_frac: 0.6\n\ntraining:\n  ema: 0.9999\n  antithetic_sampling: True\n  importance_sampling: False\n  sampling_eps: 1e-3\n  change_of_variables: False\n  loss_precision: 'bf16'  # bf16, float32, float64\n  finetune_path: ''\n  class_dropout_p: 0.1  # only used with class-conditional datasets (eg cifar10)\n\neval:\n  checkpoint_path: ''  # Used to evaluate a checkpoint after training.\n  disable_ema: False\n  compute_generative_perplexity: False\n  perplexity_batch_size: 8\n  compute_perplexity_on_sanity: False\n  gen_ppl_eval_model_name_or_path: gpt2-large  # gpt2-large, meta-llama/Llama-2-7b-hf\n  generate_samples: True\n  generated_samples_path: ${cwd:}/samples.json\n\noptim:\n  weight_decay: 0\n  lr: 3e-4\n  beta1: 0.9\n  beta2: 0.999\n  eps: 1e-8\n\ntrainer:\n  _target_: lightning.Trainer\n  accelerator: cuda\n  num_nodes: 1\n  devices: ${device_count:}\n  accumulate_grad_batches: ${div_up:${loader.global_batch_size}, ${eval:${trainer.devices} * ${loader.batch_size} * ${trainer.num_nodes}}}\n  gradient_clip_val: 1.0\n  precision: 'bf16'\n  num_sanity_val_steps: 2\n  max_steps: 1_000_000\n  log_every_n_steps: 100\n  limit_train_batches: 1.0   # train on full dataset, can be used to toggle quick run\n  limit_val_batches: 1.0     # validate on full dataset, can be used to toggle quick run\n  val_check_interval: 5000\n  check_val_every_n_epoch: null  # Disable end-of-epoch validation. 
Instead, validate every trainer.val_check_interval steps.\n\nwandb:\n  project: duo\n  group: null\n  job_type: null\n  name: null\n  id: ${.name}_${seed}\n  tags:\n    - ${noise.type}\n    - ${data.train}\n    - ${data.valid}\n    - ${algo.name}\n\nhydra:\n  run:\n    dir: ./outputs/${data.train}/${now:%Y.%m.%d}/${now:%H%M%S}\n  job:\n    chdir: true\n\ncheckpointing:\n  # Use custom `save_dir` if, e.g., saving to S3 bucket, otherwise leave this parameter as is\n  save_dir: ${cwd:}\n  # Note: `checkpoints` path should correspond to `checkpoint_every_n_steps.dirpath`\n  resume_from_ckpt: true\n  resume_ckpt_path: ${.save_dir}/checkpoints/last.ckpt\n"
  },
  {
    "path": "configs/data/ag_news.yaml",
    "content": "train: ag_news\nvalid: ag_news\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/cifar10.yaml",
    "content": "train: cifar10\nvalid: cifar10\nmodality: image\ntokenizer_name_or_path: cifar10\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nnum_classes: 10  # 10 / null\nstreaming: False\nsize: 1024\nlength: 3072\nwrap: False\ninsert_train_eos: False\ninsert_valid_eos: False\ninsert_train_spacial: False"
  },
  {
    "path": "configs/data/fineweb-edu.yaml",
    "content": "train: HuggingFaceFW/fineweb-edu\nvalid: openwebtext-valid  #wikitext103\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: True\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/lambada.yaml",
    "content": "train: lambada\nvalid: lambada\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/lm1b-gpt2.yaml",
    "content": "train: lm1b\nvalid: lm1b\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/lm1b-streaming.yaml",
    "content": "train: lm1b\nvalid: lm1b\nmodality: text\ntokenizer_name_or_path: bert-base-uncased\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: False\nstreaming: True\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/lm1b-wrap.yaml",
    "content": "train: lm1b\nvalid: lm1b\nmodality: text\ntokenizer_name_or_path: bert-base-uncased\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/lm1b.yaml",
    "content": "train: lm1b\nvalid: lm1b\nmodality: text\ntokenizer_name_or_path: bert-base-uncased\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: False\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/openwebtext-split.yaml",
    "content": "train: openwebtext-train\nvalid: openwebtext-valid\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/openwebtext-streaming.yaml",
    "content": "train: openwebtext\nvalid: wikitext103\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /tmp/data\nwrap: True\nstreaming: True\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/openwebtext.yaml",
    "content": "train: openwebtext\nvalid: wikitext103\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/ptb.yaml",
    "content": "train: ptb\nvalid: ptb\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/scientific_papers_arxiv.yaml",
    "content": "train: scientific_papers_arxiv\nvalid: scientific_papers_arxiv\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/scientific_papers_pubmed.yaml",
    "content": "train: scientific_papers_pubmed\nvalid: scientific_papers_pubmed\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/synthetic.yaml",
    "content": "train: synthetic\nvalid: synthetic\nmodality: text\ntokenizer_name_or_path: synthetic\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: True\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/text8-crop.yaml",
    "content": "# TODO: When using this dataset, set model.length = 256 to match D3PM setup\ntrain: text8-crop\nvalid: text8\nmodality: text\ntokenizer_name_or_path: text8\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/text8.yaml",
    "content": "# TODO: When using this dataset, set model.length = 256 to match D3PM setup\ntrain: text8\nvalid: text8\nmodality: text\ntokenizer_name_or_path: text8\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/wikitext103.yaml",
    "content": "train: wikitext103\nvalid: wikitext103\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/data/wikitext2.yaml",
    "content": "train: wikitext2\nvalid: wikitext2\nmodality: text\ntokenizer_name_or_path: gpt2\ncache_dir: /share/kuleshov/ssahoo/textdiffusion/data\nwrap: True\nstreaming: False\ninsert_train_eos: True\ninsert_valid_eos: True"
  },
  {
    "path": "configs/lr_scheduler/constant_warmup.yaml",
    "content": "_target_: transformers.get_constant_schedule_with_warmup\nnum_warmup_steps: 2500"
  },
  {
    "path": "configs/lr_scheduler/cosine_decay_warmup.yaml",
    "content": "_target_: utils.CosineDecayWarmupLRScheduler\nt_in_epochs: False\nt_initial: ${eval:${trainer.max_steps}-${.warmup_t}}\nwarmup_prefix: True\nwarmup_lr_init: 1e-6\nwarmup_t: ${eval:0.1*${trainer.max_steps}}\nlr_min: 1e-6"
  },
  {
    "path": "configs/lr_scheduler/step_scheduler.yaml",
    "content": "_target_: torch.optim.lr_scheduler.LambdaLR\nlr_lambda:\n  _target_: utils.LRHalveScheduler\n  warmup_steps: 500\n  n_halve_steps: 10_000"
  },
  {
    "path": "configs/model/medium.yaml",
    "content": "name: medium\ntype: ddit\nhidden_size: 1024\ncond_dim: 128\nlength: 1024\nn_blocks: 24\nn_heads: 16\nscale_by_sigma: True\ndropout: 0.1\ntie_word_embeddings: False\nvocab_lookup: True"
  },
  {
    "path": "configs/model/small.yaml",
    "content": "name: small\ntype: ddit\nhidden_size: 768\ncond_dim: 128\nlength: 1024\nn_blocks: 12\nn_heads: 12\nscale_by_sigma: True\ndropout: 0.1\ntie_word_embeddings: False\nvocab_lookup: True"
  },
  {
    "path": "configs/model/tiny-dimamba.yaml",
    "content": "name: tiny\ntype: dimamba\nhidden_size: 512\ncond_dim: 128\nlength: 1024\nn_blocks: 14\nn_heads: 8\nscale_by_sigma: True\ndropout: 0.1\ntemb_strategy: adaln \ntie_word_embeddings: False"
  },
  {
    "path": "configs/model/tiny.yaml",
    "content": "name: tiny\ntype: ddit\nhidden_size: 256\ncond_dim: 128\nlength: 1024\nn_blocks: 8\nn_heads: 8\nscale_by_sigma: True\ndropout: 0.1\ntie_word_embeddings: False\nvocab_lookup: True"
  },
  {
    "path": "configs/model/unet.yaml",
    "content": "name: unet\ntype: unet\nch: 128 \nnum_res_blocks: 2\nnum_scales: 4\nch_mult: [1, 2, 2, 2]\ninput_channels: 3\noutput_channels: -1 # determined by vocab_size\nscale_count_to_put_attn: 1 # at 16 res\ndata_min_max: [0, 255] # No need currently\ndropout: 0.1\nskip_rescale: True\ntime_conditioning: True # Whether to add in time embeddings\ntime_scale_factor: 1000 \ntime_embed_dim: ${.ch}\nfix_logistic: False\nsize: ${data.size}\ncond_dim: ${.ch}\nlength: ${data.length}"
  },
  {
    "path": "configs/noise/cosine.yaml",
    "content": "type: cosine\neps: 1e-3"
  },
  {
    "path": "configs/noise/log-linear.yaml",
    "content": "type: log-linear\neps: 1e-3"
  },
  {
    "path": "configs/prior/none.yaml",
    "content": "type: none\nlatent_width: 0\nlatent_height: 0"
  },
  {
    "path": "configs/strategy/ddp.yaml",
    "content": "_target_: lightning.pytorch.strategies.DDPStrategy\nfind_unused_parameters: false\n"
  },
  {
    "path": "configs/strategy/fsdp.yaml",
    "content": "# TODO(yair): Currenly not compatible with grad clipping\n_target_: lightning.pytorch.strategies.FSDPStrategy\nsharding_strategy: SHARD_GRAD_OP\n"
  },
  {
    "path": "dataloader.py",
    "content": "import functools\nimport itertools\nimport json\nimport math\nimport os\nimport re\nimport shutil\nimport typing\nimport urllib\nimport zipfile\nfrom typing import Optional\n\nimport datasets\nimport einops\nimport fsspec\nimport numpy as np\nimport requests\nimport tokenizers\nimport torch\nimport torchvision\nfrom torchvision import transforms as th_transforms\n\nimport transformers\n\nimport utils\n\nLOGGER = utils.get_logger(__name__)\n\n\n\nclass RawPixelsVisionTokenizer:\n  def __init__(self, vocab_size, image_size,\n               add_mask_token=True, add_special_tokens=True):\n    self.pad_token_id = None\n    self.pad_token = None\n    if add_mask_token:\n      self.mask_token = vocab_size\n      self.mask_token_id = vocab_size\n      self.vocab_size = vocab_size + 1  # mask token\n    else:\n      self.vocab_size = vocab_size\n    if add_special_tokens:\n      self.bos_token_id = vocab_size\n      self.bos_token = vocab_size\n      self.eos_token_id = vocab_size + 1\n      self.eos_token = vocab_size + 1\n      # mask token, bos_token, eos_token\n      self.vocab_size = self.vocab_size + 2  \n    else:\n      self.vocab_size = self.vocab_size\n    self.image_size = image_size\n\n  def __call__(self, x):\n    return x\n\n  def batch_decode(self, x):\n    x = einops.rearrange(x, 'b (c h w) -> b c h w', c=3,\n                         h=self.image_size)\n    x = x.to(dtype=torch.uint8)\n    return x\n\n  def decode(self, x):\n    x = einops.rearrange(x, '(c h w) -> h w c', c=3,\n                         h=self.image_size)\n    x = x.to(dtype=torch.uint8)\n    return x\n  \n  def __len__(self):\n    return self.vocab_size\n  \n\nclass DiscreteCIFAR10(torch.utils.data.Dataset):\n  def __init__(self, cache_dir, train):\n    self._dataset = torchvision.datasets.CIFAR10(\n      root=cache_dir, train=train, download=True)\n\n    transforms = []\n    if train:\n      transforms += [th_transforms.RandomHorizontalFlip()]\n      \n    transforms += 
[th_transforms.Lambda(\n          lambda x: torch.from_numpy(np.array(x))),\n                  th_transforms.Lambda(\n          lambda x: einops.rearrange(x, \"h w c -> (c h w)\")),]\n\n    self.transform = th_transforms.Compose(transforms)\n\n  def __len__(self):\n    return len(self._dataset)\n\n  def __getitem__(self, index):\n    img, labels = self._dataset[index]\n\n    img = self.transform(img)\n    attention_mask = torch.ones_like(img)\n    return {'input_ids': img.to(torch.long), 'labels': labels,\n            'attention_mask': attention_mask}\n\n\ndef wt_detokenizer(string):\n  # contractions\n  string = string.replace(\"s '\", \"s'\")\n  string = re.sub(r\"/' [0-9]/\", r\"/'[0-9]/\", string)\n  # number separators\n  string = string.replace(\" @-@ \", \"-\")\n  string = string.replace(\" @,@ \", \",\")\n  string = string.replace(\" @.@ \", \".\")\n  # punctuation\n  string = string.replace(\" : \", \": \")\n  string = string.replace(\" ; \", \"; \")\n  string = string.replace(\" . \", \". \")\n  string = string.replace(\" ! \", \"! \")\n  string = string.replace(\" ? \", \"? 
\")\n  string = string.replace(\" , \", \", \")\n  # double brackets\n  string = re.sub(r\"\\(\\s*([^\\)]*?)\\s*\\)\", r\"(\\1)\", string)\n  string = re.sub(r\"\\[\\s*([^\\]]*?)\\s*\\]\", r\"[\\1]\", string)\n  string = re.sub(r\"{\\s*([^}]*?)\\s*}\", r\"{\\1}\", string)\n  string = re.sub(r\"\\\"\\s*([^\\\"]*?)\\s*\\\"\", r'\"\\1\"', string)\n  string = re.sub(r\"'\\s*([^']*?)\\s*'\", r\"'\\1'\", string)\n  # miscellaneous\n  string = string.replace(\"= = = =\", \"====\")\n  string = string.replace(\"= = =\", \"===\")\n  string = string.replace(\"= =\", \"==\")\n  string = string.replace(\" \" + chr(176) + \" \", chr(176))\n  string = string.replace(\" \\n\", \"\\n\")\n  string = string.replace(\"\\n \", \"\\n\")\n  string = string.replace(\" N \", \" 1 \")\n  string = string.replace(\" 's\", \"'s\")\n  return string\n\ndef ptb_detokenizer(x):\n  x = x.replace(\" 's\", \"'s\")\n  x = x.replace(\"s ' \", \"s' \")\n  x = x.replace(\" n't\", \"n't\")\n  x = x.replace(\" \\n \", \"\\n\")\n  x = x.replace(\"\\\\/\", \"/\")\n  for _ in range(10):\n      x = x.replace(\" N \", \" 1 \")\n  x = x.replace(\"$ 1\", \"$1\")\n  x = x.replace(\"# 1\", \"#1\")\n  x = x.replace(\"<unk>\", \"?\")\n  return x\n\n\ndef lm1b_detokenizer(x):\n  x = x.replace('http : / / ', 'http://')\n  x = x.replace('https : / / ', 'https://')\n  x = re.sub(r' \\'(\\w+)', r\"'\\1\", x)\n  x = re.sub(r' (\\w+) \\. ', r' \\1. ', x)\n  x = re.sub(r' (\\w+) \\.$', r' \\1.', x)\n  x = x.replace(' ? ', '? ')\n  x = re.sub(r' \\?$', '?', x)\n  x = x.replace(' ! ', '! 
')\n  x = re.sub(r' \\!$', '!', x)\n  x = x.replace(' , ', ', ')\n  x = x.replace(' : ', ': ')\n  x = x.replace(' ; ', '; ')\n  x = x.replace(' / ', '/')\n  x = re.sub(r'\\\" ([^\\\"]+) \\\"', r'\"\\1\"', x)\n  x = re.sub(r'\\' ([^\\']+) \\'', r\"'\\1'\", x)\n  x = re.sub(r'\\( ([^\\(\\)]+) \\)', r\"(\\1)\", x)\n  x = re.sub(r'\\[ ([^\\[\\]]+) \\]', r\"[\\1]\", x)\n  x = x.replace('$ ', '$')\n  x = x.replace('£ ', '£')\n  return x\n\n\ndef lambada_detokenizer(text):\n  text = text.replace(\"“\", '\"')\n  text = text.replace(\"”\", '\"')\n  return '\\n'+text.strip()\n\n\ndef scientific_papers_detokenizer(x):\n  x = wt_detokenizer(x)\n  x = lm1b_detokenizer(x)\n  return x\n\n\nclass SyntheticTokenizer(\n  transformers.PreTrainedTokenizer):\n  \n  def __init__(\n    self,\n    vocab_size,\n    bos_token=\"[BOS]\",\n    eos_token=\"[EOS]\",\n    sep_token=None,\n    cls_token=None,\n    pad_token=None,\n    mask_token=None,\n    unk_token=None,\n    **kwargs):\n    \n    self.tokens = []\n    \n    for i in range (vocab_size - 2):\n      # appending space for readability\n      self.tokens.append(str(i) + \" \")\n    \n    self._vocab_str_to_int = {\n      '[BOS]': vocab_size - 2,\n      '[EOS]': vocab_size - 1,\n      ** {ch: i for i, ch in enumerate(self.tokens)}}\n    \n    self._vocab_int_to_str = {\n      v: k for k, v in self._vocab_str_to_int.items()}\n    \n    super().__init__(\n      bos_token=bos_token,\n      eos_token=eos_token,\n      sep_token=sep_token,\n      cls_token=cls_token,\n      pad_token=pad_token,\n      mask_token=mask_token,\n      unk_token=unk_token,\n      **kwargs)\n\n  @property\n  def vocab_size(self) -> int:\n    return len(self._vocab_str_to_int)\n\n  def _tokenize(self, text: str, **kwargs) -> typing.List[str]:\n    return list(text.lower())\n\n  def _convert_token_to_id(self, token: str) -> int:\n    return self._vocab_str_to_int.get(\n      token, self._vocab_str_to_int['[UNK]'])\n\n  def _convert_id_to_token(self, index: int) -> 
str:\n    return self._vocab_int_to_str[index]\n\n  def convert_tokens_to_string(self, tokens):\n    return ''.join(tokens)\n\n  def get_vocab(self) -> typing.Dict[str, int]:\n    return self._vocab_str_to_int\n\n\ndef _generate_synthetic_data(dataset_size, \n                             seq_len, vocab_size):\n  dataset = np.zeros((dataset_size, seq_len), dtype=int)\n  # tokens representing sequence boundary\n  dataset[:, 0] = vocab_size - 2  # bos\n  dataset[:, -1] = vocab_size - 1  # eos\n\n  for i in range(dataset_size):\n    # sample from 0, 1, ..., vocab_size - 3\n    temp = np.random.randint(vocab_size - 2)\n    for j in reversed(range(1, seq_len - 1)):\n      dataset[i, j] = temp\n      if temp != 0:\n        temp = temp // 4\n      else:\n        temp = np.random.randint(vocab_size - 2)\n\n  return dataset\n\n\ndef generate_synthetic_dataset(train_dataset_size, \n                               validation_dataset_size, \n                               seq_len, vocab_size):\n  np.random.seed(42)\n  train_data = torch.from_numpy(\n    _generate_synthetic_data(train_dataset_size, \n                             seq_len, vocab_size))\n  train_dataset = datasets.Dataset.from_dict({\n    'input_ids': train_data, \n    'attention_mask': torch.ones_like(train_data),\n  })\n  train_dataset.set_format(type='torch')\n\n  np.random.seed(41)\n  validation_data = torch.from_numpy(\n    _generate_synthetic_data(validation_dataset_size, \n                             seq_len, vocab_size))\n  validation_dataset = datasets.Dataset.from_dict({\n    'input_ids': validation_data, \n    'attention_mask': torch.ones_like(validation_data),\n  })\n  validation_dataset.set_format(type='torch')\n\n  return {\n    'train': train_dataset,\n    'validation': validation_dataset,\n  }\n\n\nclass Text8Tokenizer(transformers.PreTrainedTokenizer):\n  def __init__(\n    self,\n    bos_token='[BOS]',\n    eos_token='[EOS]',\n    sep_token='[SEP]',\n    cls_token='[CLS]',\n    
pad_token='[PAD]',\n    mask_token='[MASK]',\n    unk_token='[UNK]',\n    **kwargs):\n    self.characters = list('abcdefghijklmnopqrstuvwxyz ')\n    self._vocab_str_to_int = {\n      '[CLS]': 0,\n      '[SEP]': 1,\n      '[BOS]': 2,\n      '[EOS]': 3,\n      '[MASK]': 4,\n      '[PAD]': 5,\n      '[RESERVED]': 6,\n      '[UNK]': 7,\n      ** {ch: i + 8 for i, ch in enumerate(self.characters)}}\n    self._vocab_int_to_str = {\n      v: k for k, v in self._vocab_str_to_int.items()}\n    super().__init__(\n      bos_token=bos_token,\n      eos_token=eos_token,\n      sep_token=sep_token,\n      cls_token=cls_token,\n      pad_token=pad_token,\n      mask_token=mask_token,\n      unk_token=unk_token,\n      **kwargs)\n\n  @property\n  def vocab_size(self) -> int:\n    return len(self._vocab_str_to_int)\n\n  def _tokenize(self, text: str, **kwargs) -> typing.List[str]:\n    return list(text.lower())\n\n  def _convert_token_to_id(self, token: str) -> int:\n    return self._vocab_str_to_int.get(\n      token, self._vocab_str_to_int['[UNK]'])\n\n  def _convert_id_to_token(self, index: int) -> str:\n    return self._vocab_int_to_str[index]\n\n  def convert_tokens_to_string(self, tokens):\n    return ''.join(tokens)\n\n  def get_vocab(self) -> typing.Dict[str, int]:\n    return self._vocab_str_to_int\n\n\ndef get_lambada_test_dataset():\n    url = \"https://openaipublic.blob.core.windows.net/gpt-2/data/lambada_test.jsonl\"\n\n    def read_jsonl_to_list(url):\n      response = requests.get(url, stream=True)\n      data_list = []\n\n      # Process each line in the response content\n      for line in response.iter_lines(decode_unicode=True):\n        if line:\n          data = json.loads(line)\n          data_list.append(data)\n\n      return data_list\n\n    lambada_data = read_jsonl_to_list(url)\n    dataset = datasets.Dataset.from_list(lambada_data)\n    return dataset\n\n\ndef get_text8_dataset(cache_dir, max_seq_length=256,\n                      drop_last=True, 
crop_train=False):\n  \"\"\"Adapted from:\n    https://github.com/google-research/google-research/blob/master/d3pm/text/datasets.py#L344\n\n    Args:\n      cache_dir: str, path to cache directory.\n      max_seq_length: int, maximum length of sequences.\n          (default: 256, as in D3PM codebase.)\n      drop_last: bool, whether to drop the last incomplete\n          batch. (default: True, as in D3PM codebase.)\n      crop_train: bool, whether to subsample contiguous\n          subsequences from training example. serves to\n          make sure transformer models with absolute position\n          embeddings do not have incorrect position-wise\n          marginals. (default: False, but necessary to match D3PM AR)\n\n    Returns:\n      dataset: dataset.DatasetDict, with keys 'train',\n          'valid', 'test'.\n  \"\"\"\n  url = 'http://mattmahoney.net/dc/text8.zip'\n  if not crop_train:\n    cache_dir = f'{cache_dir}/text8'\n  else:\n    cache_dir = f'{cache_dir}/text8-crop-train'\n  split_names = ['train', 'validation', 'test']\n  if not all([\n    utils.fsspec_exists(os.path.join(cache_dir, split))\n    for split in split_names\n  ]):\n    # Check if raw data exists\n    raw_cache_dir = os.path.join(cache_dir, 'raw_data')\n    if not all([\n      utils.fsspec_exists(\n        os.path.join(raw_cache_dir, f'text8.{split}.txt'))\n      for split in split_names\n    ]):\n      if not utils.fsspec_exists(\n        os.path.join(raw_cache_dir, 'text8.zip')):\n        utils.fsspec_mkdirs(raw_cache_dir, exist_ok=True)\n        LOGGER.info('Downloading text8 from URL {}.'.format(url))\n        with (urllib.request.urlopen(url) as in_stream,\n              open(os.path.join(raw_cache_dir, 'text8.zip'),\n                   'wb') as out_file):\n          shutil.copyfileobj(in_stream, out_file)\n\n      with fsspec.open(\n        os.path.join(raw_cache_dir, 'text8.zip'),\n        'rb') as f:\n        rawdata = zipfile.ZipFile(f).read(\n          
'text8').decode('utf-8')\n\n      # Splits taken from D3PM codebase\n      splits = {\n        'train': rawdata[:90000000],\n        'validation': rawdata[90000000: 95000000],\n        'test': rawdata[95000000:],\n      }\n\n      for split, data in splits.items():\n        _path = os.path.join(raw_cache_dir,\n                             f'text8.{split}.txt')\n        with fsspec.open(_path, 'w') as f:\n          f.write(data)\n    else:\n      splits = {}\n      for split in split_names:\n        _path = os.path.join(raw_cache_dir,\n                             f'text8.{split}.txt')\n        with fsspec.open(_path, 'r') as f:\n          splits[split] = f.read()\n\n    # Chunk and save as datasets.DatasetDict\n    def chunks(lst, n):\n      \"\"\"Yield successive n-sized chunks from lst.\"\"\"\n      for i in range(0, len(lst), n):\n        yield lst[i:i + n]\n\n    dataset_dict = {}\n    for k, v in splits.items():\n      if k == 'train' and crop_train == True:\n        chunk_size = 2 * max_seq_length\n      else:\n        chunk_size = max_seq_length\n      text = list(chunks(v, chunk_size))\n      if drop_last and len(text[-1]) < chunk_size:\n        text = text[:-1]\n      dataset_dict[k] = datasets.Dataset.from_dict({'text': text})\n    dataset = datasets.DatasetDict(dataset_dict)\n    dataset.save_to_disk(cache_dir)\n  else:\n    dataset = datasets.load_from_disk(cache_dir)\n\n  return dataset\n\n\ndef _group_texts(examples, block_size, bos, eos):\n  # Concatenate all texts.\n  concatenated_examples = list(itertools.chain(* examples['input_ids']))\n  total_length = len(concatenated_examples)\n  # TODO(yair): look into not dropping the remainder but rather padding it.\n  # We drop the small remainder, and if the total_length < block_size - 2\n  # we exclude this batch and return an empty dict.\n  # We could add padding if the model supported it instead of\n  # this drop, you can customize this part to your needs.\n  new_block_size = block_size - 2  # [BOS] and 
[EOS] to be added\n  total_length = (total_length // new_block_size) * new_block_size\n  # Split by chunks of max_len.\n  result = {}\n  _values = []\n  _attn_masks = []\n  for i in range(0, total_length, new_block_size):\n    _values.append(\n      [bos]\n      + concatenated_examples[i : i + new_block_size]\n      + [eos])\n    _attn_masks.append(torch.ones(block_size))\n  result['input_ids'] = _values\n  result['attention_mask'] = _attn_masks\n  return result\n\n\ndef get_dataset(dataset_name,\n                tokenizer,\n                wrap,\n                mode,\n                cache_dir,\n                insert_eos=True,\n                block_size=1024,\n                num_proc=len(os.sched_getaffinity(0)),\n                streaming=False,\n                revision : Optional[str]=None):\n  if dataset_name == 'cifar10':\n    assert mode in ('train', 'validation')\n    return DiscreteCIFAR10(cache_dir=cache_dir, \n                           train=mode=='train')\n  eos_tag = ''\n  if not insert_eos:\n    eos_tag = '_eosFalse'\n  if wrap:\n    filename = f'{dataset_name}_{mode}_bs{block_size}_wrapped{eos_tag}.dat'\n  else:\n    filename = f'{dataset_name}_{mode}_bs{block_size}_unwrapped{eos_tag}.dat'\n  _path = os.path.join(cache_dir, filename)\n  \n  if utils.fsspec_exists(_path):\n    LOGGER.info(f'Loading data from: {_path}')\n    return datasets.load_from_disk(_path).with_format('torch')\n  LOGGER.info(f'Generating new data at: {_path}')\n  LOGGER.info(f'{streaming=}')  \n\n  crop_train = dataset_name == 'text8-crop'\n  if mode == 'train' and crop_train:\n    # double block size for sub-sampling\n    block_size *= 2\n  \n  if dataset_name == 'wikitext103':\n    dataset = datasets.load_dataset(\n      'wikitext',\n      name='wikitext-103-raw-v1',\n      cache_dir=cache_dir,\n      revision=revision)\n  elif dataset_name == 'wikitext2':\n    dataset = datasets.load_dataset(\n      'wikitext',\n      name='wikitext-2-raw-v1',\n      
cache_dir=cache_dir,\n      revision=revision)\n  elif dataset_name == 'ptb':\n    dataset = datasets.load_dataset(\n      'ptb_text_only',\n      cache_dir=cache_dir,\n      revision=revision)\n  elif dataset_name == 'lambada':\n    dataset = get_lambada_test_dataset()\n  elif dataset_name == 'text8':\n    assert wrap\n    assert revision is None\n    dataset = get_text8_dataset(\n      cache_dir, max_seq_length=block_size)\n  elif dataset_name == 'text8-crop':\n    assert revision is None\n    dataset = get_text8_dataset(\n      cache_dir, max_seq_length=block_size, crop_train=True)\n  elif dataset_name == 'openwebtext-train':\n    dataset = datasets.load_dataset(\n      'openwebtext',\n      split='train[:-100000]',\n      cache_dir=cache_dir,\n      revision=revision,\n      streaming=False,\n      num_proc=num_proc)\n  elif dataset_name == 'openwebtext-valid':\n    dataset = datasets.load_dataset(\n      'openwebtext',\n      split='train[-100000:]',\n      cache_dir=cache_dir,\n      revision=revision,\n      streaming=False,\n      num_proc=num_proc)\n  elif dataset_name == 'scientific_papers_arxiv':\n    dataset = datasets.load_dataset(\n      'scientific_papers', 'arxiv',\n      cache_dir=cache_dir,\n      streaming=streaming,\n      revision=revision)\n  elif dataset_name == 'scientific_papers_pubmed':\n    dataset = datasets.load_dataset(\n      'scientific_papers', 'pubmed',\n      cache_dir=cache_dir,\n      streaming=streaming,\n      revision=revision)\n  elif dataset_name == 'ag_news':\n    dataset = datasets.load_dataset(\n      'ag_news',\n      cache_dir=cache_dir,\n      streaming=streaming,\n      revision=revision)\n  elif dataset_name == 'synthetic':\n    assert streaming\n    assert wrap  # i.e., no pad tokens\n    dataset = generate_synthetic_dataset(\n      train_dataset_size=100000,\n      validation_dataset_size=1024,\n      seq_len=32,\n      vocab_size=256,\n    )\n  else:\n    dataset = datasets.load_dataset(\n      dataset_name,\n    
  cache_dir=cache_dir,\n      streaming=streaming,\n      revision=revision)\n\n  if dataset_name in ['lambada', 'openwebtext-train',\n                      'openwebtext-valid']:\n    data = dataset\n  else:\n    data = dataset[mode]\n    if dataset_name == 'synthetic':\n      # already tokenized, no further actions required\n      return data\n\n  if dataset_name.startswith('wikitext'):\n    detokenizer = wt_detokenizer\n  elif dataset_name == 'ptb':\n    detokenizer = ptb_detokenizer\n  elif dataset_name == 'lm1b':\n    detokenizer = lm1b_detokenizer\n  elif dataset_name == 'lambada':\n    detokenizer = lambada_detokenizer\n  elif dataset_name.startswith('scientific_papers'):\n    detokenizer = scientific_papers_detokenizer\n  else:\n    detokenizer = None\n\n  def _apply_detokenizer(detokenizer):\n    def detok(text):\n      for i, t in enumerate(text, 0):\n        text[i] = detokenizer(t)\n      return text\n    return detok\n  \n  EOS = tokenizer.encode(tokenizer.eos_token)[0]\n  BOS = tokenizer.encode(tokenizer.bos_token)[0]\n\n  def preprocess_and_tokenize(example):\n    if dataset_name == 'ptb':\n      text = example['sentence']\n    elif 'scientific_papers' in dataset_name:\n      text = example['article']\n    else:\n      text = example['text']\n    \n    if detokenizer is not None:\n      text = _apply_detokenizer(detokenizer)(text)\n\n    tokenizer.padding_side = 'right'\n    tokenizer.truncation_side = 'right'\n\n    if wrap:\n      tokens = tokenizer(text,\n                         add_special_tokens=False,\n                         return_attention_mask=False,\n                         return_token_type_ids=False)\n      if insert_eos:\n        tokens = {'input_ids':\n                  [t + [EOS] for t in tokens['input_ids']]}\n      # Still missing BOS, but will be added in group_texts\n    else:\n      tokens = tokenizer(text,\n                         max_length=block_size,\n                         padding='max_length',\n                         
truncation=True,\n                         add_special_tokens=True,\n                         return_attention_mask=True,\n                         return_token_type_ids=True)\n    return tokens\n\n  if streaming:\n    tokenized_dataset = data.map(\n      preprocess_and_tokenize,\n      batched=True)\n  else:\n    tokenized_dataset = data.map(\n      preprocess_and_tokenize,\n      batched=True,\n      num_proc=num_proc,\n      load_from_cache_file=True,\n      desc='Tokenizing')\n  if dataset_name == 'ptb':\n    tokenized_dataset = tokenized_dataset.remove_columns(\n      'sentence')\n  elif 'scientific_papers' in dataset_name:\n    tokenized_dataset = tokenized_dataset.remove_columns([\n      'article', 'abstract', 'section_names'])\n  elif dataset_name == 'ag_news':\n    tokenized_dataset = tokenized_dataset.remove_columns(\n      ['text', 'label'])\n  else:\n    tokenized_dataset = tokenized_dataset.remove_columns(\n      'text')\n\n  if not wrap:\n    if not streaming:\n      tokenized_dataset.save_to_disk(_path)\n    return tokenized_dataset.with_format('torch')\n\n  group_texts = functools.partial(\n    _group_texts, block_size=block_size, bos=BOS, eos=EOS)\n  if streaming:\n    chunked_dataset = tokenized_dataset.map(\n      group_texts,\n      batched=True)\n  else:\n    chunked_dataset = tokenized_dataset.map(\n      group_texts,\n      batched=True,\n      num_proc=num_proc,\n      load_from_cache_file=True,\n      desc='Grouping')\n    chunked_dataset.save_to_disk(_path)\n  chunked_dataset = chunked_dataset.with_format('torch')\n  return chunked_dataset\n\n\ndef get_tokenizer(config):\n  if config.data.tokenizer_name_or_path == 'text8':\n    tokenizer = Text8Tokenizer()\n  elif config.data.tokenizer_name_or_path == 'bert-base-uncased':\n    tokenizer = transformers.BertTokenizer.\\\n      from_pretrained('bert-base-uncased')\n  elif config.data.tokenizer_name_or_path == 'synthetic':\n    tokenizer = SyntheticTokenizer(vocab_size=256)\n  elif 
config.data.tokenizer_name_or_path == 'cifar10':\n    return RawPixelsVisionTokenizer(\n      vocab_size=256, image_size=32, add_special_tokens=False, \n      add_mask_token='mdlm' in config.algo.name)\n  else:\n    tokenizer = transformers.AutoTokenizer.from_pretrained(\n      config.data.tokenizer_name_or_path)\n\n  if (isinstance(tokenizer, transformers.GPT2TokenizerFast)\n      or isinstance(tokenizer, transformers.GPT2Tokenizer)):\n    tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing(\n      (tokenizer.bos_token, tokenizer.bos_token_id),\n      (tokenizer.eos_token, tokenizer.eos_token_id))\n\n  # For wrapped batches:\n  #  [BOS] sent1 [EOS] sent2-fragment [EOS]\n  #  [BOS] sent2-fragment [EOS] sent3 [EOS]\n  if tokenizer.bos_token is None:\n    if tokenizer.cls_token is None:\n      raise AttributeError(\n        'Tokenizer must have a bos_token or '\n        f'cls_token: {tokenizer}')\n    tokenizer.bos_token = tokenizer.cls_token\n  if tokenizer.eos_token is None:\n    if tokenizer.sep_token is None:\n      raise AttributeError(\n        'Tokenizer must have a eos_token '\n        f'or sep_token: {tokenizer}')\n    tokenizer.eos_token = tokenizer.sep_token\n  if tokenizer.pad_token is None:\n    tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n\n  return tokenizer\n    \n\ndef get_dataloaders(config, tokenizer, skip_train=False,\n                    skip_valid=False, valid_seed=None):\n  num_gpus = torch.cuda.device_count()\n  assert (config.loader.global_batch_size\n          == (config.loader.batch_size\n              * config.trainer.num_nodes\n              * num_gpus\n              * config.trainer.accumulate_grad_batches))\n  if config.loader.global_batch_size % (\n    num_gpus * config.trainer.accumulate_grad_batches) != 0:\n    raise ValueError(\n      f'Train Batch Size {config.training.batch_size}'\n      f'not divisible by {num_gpus} gpus with accumulation '\n      f'{config.trainer.accumulate_grad_batches}.')\n  if 
config.loader.eval_global_batch_size % num_gpus != 0:\n    raise ValueError(\n      f'Eval Batch Size for {config.eval.batch_size} '\n      f'not divisible by {num_gpus}.')\n  if skip_train:\n    train_set = None\n  else:\n    train_set = get_dataset(\n      config.data.train,\n      tokenizer,\n      mode='train',\n      wrap=config.data.wrap,\n      insert_eos=config.data.insert_train_eos,\n      cache_dir=config.data.cache_dir,\n      block_size=config.model.length,\n      streaming=config.data.streaming,\n      num_proc=config.loader.num_workers,\n      revision=config.data.get(\"train_revision\", None))\n  \n  if config.data.valid in ['text8', 'lm1b', 'ag_news']:\n    validation_split = 'test'\n  else:\n    validation_split = 'validation'\n  if skip_valid:\n    valid_set = None\n  else:\n    valid_set = get_dataset(\n      config.data.valid,\n      tokenizer,\n      wrap=config.data.wrap,\n      mode=validation_split,\n      cache_dir=config.data.cache_dir,\n      insert_eos=config.data.insert_valid_eos,\n      block_size=config.model.length,\n      streaming=config.data.streaming,\n      num_proc=config.loader.num_workers,\n      revision=config.data.get(\"valid_revision\", None))\n\n  if skip_train:\n    train_loader = None\n  else:\n    train_loader = torch.utils.data.DataLoader(\n      train_set,\n      batch_size=config.loader.batch_size,\n      num_workers=config.loader.num_workers,\n      pin_memory=config.loader.pin_memory,\n      shuffle=not config.data.streaming,\n      persistent_workers=True)\n    train_loader.tokenizer = tokenizer\n  if skip_valid:\n    valid_loader = None\n  else:\n    if valid_seed is None:\n      shuffle_valid = False\n      generator = None\n    else:\n      shuffle_valid = True\n      generator = torch.Generator().manual_seed(valid_seed)\n    valid_loader = torch.utils.data.DataLoader(\n      valid_set,\n      batch_size=config.loader.eval_batch_size,\n      num_workers=config.loader.num_workers,\n      
pin_memory=config.loader.pin_memory,\n      shuffle=shuffle_valid,\n      generator=generator)\n    # Will be used in generative perplexity calculation\n    valid_loader.tokenizer = tokenizer\n\n  return train_loader, valid_loader\n\n\n# Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py\n\n\nclass RandomFaultTolerantSampler(torch.utils.data.RandomSampler):\n\n  def __init__(self, *args, generator=None, **kwargs):\n    # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,\n    # which should be reproducible if pl.seed_everything was called beforehand.\n    # This means that changing the seed of the experiment will also change the\n    # sampling order.\n    if generator is None:\n      seed = int(torch.empty((), dtype=torch.int64).random_().item())\n      generator = torch.Generator().manual_seed(seed)\n    kwargs.pop('shuffle', None)\n    super().__init__(*args, generator=generator, **kwargs)\n    self.counter = 0\n    self.restarting = False\n\n  def state_dict(self):\n    return {'random_state': self.generator.get_state(),\n            'counter': self.counter}\n\n  def load_state_dict(self, state_dict):\n    self.generator.set_state(state_dict.get('random_state'))\n    self.counter = state_dict['counter']\n    # self.start_counter = self.counter\n    self.restarting = True\n\n  # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per\n  # epoch, and subsequent epoch will have very few batches.\n\n  def __iter__(self) -> typing.Iterator[int]:\n    n = len(self.data_source)\n\n    self.state = self.generator.get_state()\n    indices = torch.randperm(n, generator=self.generator).tolist()\n\n    if not self.restarting:\n      self.counter = 0\n    else:\n      indices = indices[self.counter:]\n      self.restarting = False\n\n    for index in indices:\n      self.counter += 1\n      yield index\n\n    self.counter = 
0\n\n\nclass FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):\n\n  def __init__(self, *args, **kwargs):\n    super().__init__(*args, **kwargs)\n    self.counter = 0\n    self.restarting = False\n\n  def state_dict(self):\n    return {'epoch': self.epoch, 'counter': self.counter}\n\n  def load_state_dict(self, state_dict):\n    self.epoch = state_dict['epoch']\n    self.counter = state_dict['counter']\n    self.restarting = True\n\n  # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per\n  # epoch, and subsequent epoch will have very few batches.\n  def __iter__(self):\n    if self.shuffle:\n      # deterministically shuffle based on epoch and seed\n      g = torch.Generator()\n      g.manual_seed(self.seed + self.epoch)\n      indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]\n    else:\n      indices = list(range(len(self.dataset)))  # type: ignore[arg-type]\n\n    if not self.drop_last:\n      # add extra samples to make it evenly divisible\n      padding_size = self.total_size - len(indices)\n      if padding_size <= len(indices):\n        indices += indices[:padding_size]\n      else:\n        indices += (indices * math.ceil(\n          padding_size / len(indices)))[:padding_size]\n    else:\n      # remove tail of data to make it evenly divisible.\n      indices = indices[:self.total_size]\n    assert len(indices) == self.total_size\n\n    # subsample\n    indices = indices[self.rank:self.total_size:self.num_replicas]\n    assert len(indices) == self.num_samples\n\n    if not self.restarting:\n      self.counter = 0\n    else:\n      indices = indices[self.counter:]\n      self.restarting = False\n\n    for index in indices:\n      self.counter += 1\n      yield index\n\n    self.counter = 0\n"
  },
  {
    "path": "discrete_diffusion_harness.py",
    "content": "import torch\nfrom omegaconf import OmegaConf\n\nfrom lm_eval.api.model import LM\nfrom lm_eval.api.registry import register_model\nfrom lm_eval.__main__ import cli_evaluate\nfrom lm_eval.api.instance import Instance\nfrom datasets import Dataset\nfrom tqdm import tqdm\nimport numpy as np\nimport algo\nimport dataloader\n\n\"\"\"\n# Instructions on running the eval\n\n## What is the script doing?\nThe script evaluates (or approximates) the log-likelihood of\nprefix + suffix. The script will load a checkpoint, and \ndepending on the config, use the corresponding class \n(e.g. mdlm, duo, ar, etc).\n\n- For MCQ, the log-likelihood of all continuations given the \n  same prefix is evaluated. The most likely continuation is \n  selected as the \"correct\" answer according to the model.\n\n- For lambada_openai, the model is correct if the true \n  continuation is generated as the argmax. For diffusion, we \n  inject noise in the continuation, and check whether the true \n  answer computed in a single forward pass is the most likely. \n  This is naturally favoring AR models, as they run one forward \n  pass per token, while diffusion use a single pass for all \n  tokens for simplicity. 
Therefore, I usually only compare on \n  MCQ since it is more fair.\n\nTo run the script, you need to install the lm-eval-harness package:\n```\npip install git+https://github.com/EleutherAI/lm-evaluation-harness\n```\n\n## Tasks that we tested with\n**MCQ**: arc_easy, arc_challenge, hellaswag, winogrande, \n         boolq, openbookqa, race, social_iqa, mathqa, piqa\n**Likelihood based**: lambada_openai\n\n## Batch size that fits at the small scale\n    - boolq -> 64\n    - openbookqa -> 256\n    - race -> 32\n    - social_iqa -> 512\n    - winogrande -> 64\n    - mathqa -> 64\n    - lambada_openai -> 64\n    - arc_easy -> 256\n    - arc_challenge -> 256\n    - hellaswag -> 64\n    - piqa -> 64\n\n## Important flags\n  --trust_remote_code -> some datasets execute code when loading \n                         from huggingface. Without this flag, \n                         the script might crash.\n  --batch_size        -> max. num elements to use in parallel \n                         to eval the likelihood (for diffusion). 
\n                         For simplicity, inputs are NOT padded \n                         and batched for AR, though it should \n                         be fairly easy to add.\n  --tasks             -> one or multiple comma-separated tasks \n                         to evaluate on.\n  --model_args        -> string (without spaces) that contains \n                         the arguments to pass to the evaluator \n                         (stuff like checkpoints path, number \n                         of MC samples to evaluate the \n                         likelihood, path to the sentencepiece \n                         tokenizer, etc).\n  --output_path       -> path to a json file where the evaluation \n                         results will be saved, instead of \n                         only being printed in the terminal \n                         (they will always be printed).\n  --limit             -> debugging flag; limit the number of \n                         examples to a fixed amount, instead \n                         of using the whole dataset\n\n\n## Example commands\n\n\n### Run a single task with an AR model\npython discrete_diffusion_harness.py \\\n    --tasks arc_easy \\\n    --batch_size 256 \\\n    --model dlm \\\n    --model_args checkpoint_path=/home/username/baselines/ar/1M.ckpt \\\n    --output_path ./harness_results/ar/1M/arc_easy.ckpt\n\n### Run a single task with an MDLM model, and 2048 MC samples to approximate the likelihood\npython discrete_diffusion_harness.py \\\n    --tasks arc_easy \\\n    --model dlm \\\n    --batch_size 256 \\\n    --model_args checkpoint_path=/home/username/baselines/mdlm/1M.ckpt,num_mc_samples=2048 \\\n    --output_path ./harness_results/mdlm/1M/arc_easy.ckpt\n\n### Debug with 20 examples only\npython discrete_diffusion_harness.py \\\n    --tasks arc_easy \\\n    --model dlm \\\n    --batch_size 256 \\\n    --model_args checkpoint_path=/home/username/baselines/mdlm/1M.ckpt,num_mc_samples=2048 \\\n    --limit 20 
\\\n    --output_path ./harness_results/mdlm/1M/arc_easy.ckpt\n\n\n### Run the benchmarks from \"The Diffusion Duality (Chapter 2)\":\n    (Arc-e, Arc-c, HSwag, WinoG, PIQA, MathQA, OQA)\n-> just change the checkpoint to evaluate mdlm, ar, or duo\n\nfor task_config in \"arc_easy 256\" \"arc_challenge 256\" \"hellaswag 64\" \"winogrande 64\" \"piqa 64\" \"mathqa 64\" \"openbookqa 256\"; do\n  task=$(echo $task_config | cut -d' ' -f1)\n  batch_size=$(echo $task_config | cut -d' ' -f2)\n  python discrete_diffusion_harness.py \\\n    --batch_size $batch_size \\\n    --tasks $task \\\n    --model dlm \\\n    --model_args checkpoint_path=/path/to/checkpoint.ckpt \\\n    --output_path ./harness_results/duo/$task.json\ndone\n\n\"\"\"\n\n\ndef requests_to_dataset(config, requests, tokenizer, num_proc):\n  def _tokenize(e):\n    eos_idx = tokenizer.eos_token_id\n    bos_idx = tokenizer.bos_token_id\n    prefix_tokens = tokenizer(e['prefix'], \n                              return_attention_mask=False, \n                              add_special_tokens=False\n                              )['input_ids']\n    target_tokens = tokenizer(e['target'], \n                              return_attention_mask=False, \n                              add_special_tokens=False\n                              )['input_ids']\n    prefix_tokens = [bos_idx] + prefix_tokens\n    target_tokens = target_tokens + [eos_idx]\n    \n    return {\n        'prefix_text': e['prefix'],\n        'target_text': e['target'],\n        'prefix': prefix_tokens,\n        'target': target_tokens,\n    }\n  ds = []\n  ds = [{'prefix': req.args[0], 'target': req.args[1]} \n        for req in requests]\n  ds = Dataset.from_list(ds)\n  ds = ds.map(_tokenize, num_proc=num_proc)\n  ds = ds.with_format('torch')\n  seq_lenths = [len(x['prefix']) + len(x['target']) \n                for x in ds]\n  \n  num_larger = len([x for x in seq_lenths \n                    if x > config.model.length])\n  if num_larger > 0:\n    
print(f'\\033[91mThere are some examples that are longer '\n          f'than the context length, they will be ignored '\n          f'during evaluation. Number of such sequences: '\n          f'{num_larger}\\033[0m')\n\n  return ds\n\n\ndef _eval_suffix_nll_generators(config, module, prefix, \n  suffix, batch_size, num_samples, loss_avg_mode):\n  device = module.device\n  assert num_samples % batch_size == 0\n  full_sentence = torch.cat([prefix, suffix], dim=-1\n                  ).repeat(batch_size, 1).to(module.device)\n  all_ts = module._sample_t(num_samples, accum_step=None)\n  for idx in range(0, num_samples, batch_size):\n    t = all_ts[idx:idx+batch_size].unsqueeze(-1)\n    dalpha_t, alpha_t = module.noise(t)\n    alpha_t = alpha_t.to(device)\n    sigma = module._sigma_from_alphat(alpha_t)\n    x0 = full_sentence.to(device)\n    # Inject noise\n    xt = module.q_xt(full_sentence, alpha_t).to(device)\n    if loss_avg_mode == 'full':\n      pass  # nothing to do\n    elif loss_avg_mode == 'suffix':\n      xt[:, :len(prefix)] = prefix\n      # recompute alpha_t based on number of masked tokens, \n      #   for conditioning of the backbone\n      alpha_t = (xt == x0).float().mean(dim=1)[:, None]\n      t = module.noise.get_t_for_alpha(alpha_t)\n      # We need to get dalpha_t for the loss:\n      alpha_t = alpha_t.to(device)\n      sigma = module._sigma_from_alphat(alpha_t)\n    else:\n      raise ValueError(loss_avg_mode)\n\n    yield xt, x0, t, sigma, alpha_t, dalpha_t\n    \n\ndef eval_suffix_nll(config, module, prefix, suffix, batch_size, \n                    num_samples, loss_avg_mode):\n  if config.algo.name in ('mdlm', 'duo', 'duo_base', \n                          'distillation', 'ot-finetune'):\n    return eval_suffix_nll_diffusion(config, module, prefix, \n      suffix, batch_size, num_samples, loss_avg_mode)\n  elif config.algo.name == 'ar':\n    return eval_suffix_nll_ar(config, module, prefix, \n      suffix, batch_size, num_samples, 
loss_avg_mode)\n  else:\n    raise ValueError(config.algo.name)\n\n\ndef eval_suffix_nll_ar(config, module, prefix, suffix,\n                    batch_size, num_samples, loss_avg_mode):\n  x_cat = torch.cat([prefix, suffix[:-1]], dim=-1)\n  x_cat = x_cat.reshape(1, -1).to(module.device)\n\n  with torch.amp.autocast('cuda', dtype=torch.float32):\n    out = module.backbone(x_cat, sigma=None)\n\n  out[:, :, module.mask_index] = module.neg_infinity\n  suffix_out = out[:, len(prefix) - 1:, :]\n  suffix_logits = torch.log_softmax(suffix_out, dim=-1)\n  index = suffix[None, :, None].to(module.device)\n  nll = torch.gather(-suffix_logits, dim=-1, \n                     index=index).mean()\n  return float(nll.cpu())\n\ndef eval_suffix_nll_diffusion(config, module, prefix, suffix, \n  batch_size, num_samples, loss_avg_mode):\n  all_losses = []\n  generator =  _eval_suffix_nll_generators(config, module, \n    prefix, suffix, batch_size, num_samples, loss_avg_mode)\n  for xt, x0, t, sigma, alpha_t, dalpha_t in generator:\n    log_x_theta = module(xt, sigma, labels=None)\n    token_nll = module.nll_per_token(log_x_theta, xt, x0, \n                                     alpha_t, dalpha_t)\n    if loss_avg_mode == 'full':\n      loss = float(token_nll.mean())\n    elif loss_avg_mode == 'suffix':\n      loss = float(token_nll[:, len(prefix):].mean())\n    all_losses.append(loss)\n  return float(np.mean(all_losses))\n\n\n@register_model(\"dlm\")\nclass DiscreteDiffusionHarness(LM):\n  def __init__(self, pretrained=\"NONE\", max_length=1024,\n    num_mc_samples=1024, batch_size=64, device=\"cuda\",\n    checkpoint_path=None, num_proc=8, loss_avg_mode='full',\n    *args, **kwargs):\n    super().__init__()\n    # Whether to use the full sequence, or suffix only to \n    #  approximate the NLL. 
Full should be the correct way.\n    assert loss_avg_mode in ('full', 'suffix')\n    ckpt = torch.load(checkpoint_path, map_location='cpu', \n                      weights_only=False)\n    config = ckpt['hyper_parameters']['config']\n    # Backfill missing keys into legacy checkpoints\n    if not hasattr(config.training, 'class_dropout_p'):\n      OmegaConf.set_struct(config, False)\n      config.training.class_dropout_p = 0.0\n      OmegaConf.set_struct(config, True)\n    self.tokenizer = dataloader.get_tokenizer(config)\n    if config.algo.name == 'mdlm':\n      self.model = algo.MDLM(config, self.tokenizer)\n    elif config.algo.name in ('duo', 'duo_base', \n                              'distillation', 'ot-finetune'):\n      self.model = algo.DUO_BASE(config, self.tokenizer)\n    elif config.algo.name == 'ar':\n      self.model = algo.AR(config, self.tokenizer)\n    else:\n      raise ValueError(f'Implement for {config.algo.name}')\n    self.config = config\n    self.num_proc = num_proc\n    self.num_mc_samples = num_mc_samples\n    self.batch_size = int(batch_size)\n    self.device = device\n    self.loss_avg_mode = loss_avg_mode\n\n    self.model.load_state_dict(ckpt['state_dict'])\n    self.model.to(device)\n    self.model.eval()\n\n  def suffix_greedy_prediction(self, prefix, target):\n    if self.config.algo.name == 'mdlm':\n      return self._suffix_greedy_prediction_mdlm(prefix, \n                                                 target)\n    elif self.config.algo.name in ('duo', 'duo_base', \n                              'distillation', 'ot-finetune'):\n      return self._suffix_greedy_prediction_duo_base(prefix, \n                                                     target)\n    elif self.config.algo.name == 'ar':\n      return self._suffix_greedy_prediction_ar(prefix, target)\n    else:\n      raise ValueError(self.config.algo.name)\n    \n  def _suffix_greedy_prediction_ar(self, prefix, target):\n    x_cat = torch.cat([prefix, target[:-1]], \n      
dim=-1).reshape(1, -1).to(self.device)\n\n    # Follows generate_samples in AR (algo.py)\n    out = self.model.backbone(x_cat, sigma=None)\n    out[:, :, self.model.mask_index] = self.model.neg_infinity\n    out = out.log_softmax(-1)\n    preds_suffix = out[:, len(prefix) - 1:, :]\n    greedy_preds = preds_suffix.argmax(-1).flatten()\n    return (greedy_preds.cpu() == target).all().item()\n\n\n  def _suffix_greedy_prediction_mdlm(self, prefix, target):\n    mask_idx = self.model.mask_index\n    eos_idx = self.tokenizer.eos_token_id\n    # Note: because of the preprocessing, we know that the \n    #  last token is an eos token.\n    noisy_target = [mask_idx] * (len(target) - 1) + [eos_idx]\n    noisy_target = torch.tensor(noisy_target, \n                                device=self.device)\n    prefix = prefix.to(self.device)\n    seq = torch.concatenate([prefix, noisy_target], \n                            dim=-1).reshape(1, -1)\n    sigma = torch.zeros(size=(seq.shape[0], 1), \n                        device=self.device)\n    logits = self.model(seq, sigma, labels=None)\n    assert logits.shape[0] == 1\n    suffix_logits = logits[0, len(prefix):]\n    target_preds = suffix_logits.argmax(-1).cpu()\n    correct = target_preds == target\n    correct = correct.all()\n    return bool(correct)\n  \n  def _suffix_greedy_prediction_duo_base(self, prefix, target):\n    noisy_suffix = torch.randint(\n      0, self.model.vocab_size, size=target.shape, \n      dtype=target.dtype, device=self.device)\n\n    prefix = prefix.to(self.device)\n    # shape: (1, len)\n    noisy_seq = torch.concatenate([prefix, noisy_suffix])[None, :]\n    # Set the EOS token at the end\n    noisy_seq[0, -1] = target[-1]\n    clean_seq = torch.concatenate([prefix, target.to(self.device)])[None, :]\n    \n    alpha_t = (noisy_seq == clean_seq).float().mean(dim=1)[:, None]\n    sigma = self.model._sigma_from_alphat(alpha_t)\n    t = self.model.noise.get_t_for_alpha(alpha_t)\n    logits = 
self.model(noisy_seq, sigma, labels=None)\n    assert logits.shape[0] == 1\n    suffix_logits = logits[0, len(prefix):]\n    target_preds = suffix_logits.argmax(-1).cpu()\n\n    correct = target_preds == target\n    correct = correct.all()\n    return bool(correct)\n  \n  @torch.no_grad()\n  def loglikelihood(self, requests: list[Instance]) \\\n                                -> list[tuple[float, bool]]:\n    dataset = requests_to_dataset(self.config, requests, \n                                  self.tokenizer, self.num_proc)\n    out = []\n    for elem in tqdm(dataset, 'Computing likelihood...'):\n      prefix = elem['prefix']\n      target = elem['target']\n      \n      if len(prefix) + len(target) > self.model.config.model.length:\n        # If the request is too long, skip it.\n        ll = 0.0\n        is_target_greedy_dec = False\n        out.append((ll, is_target_greedy_dec))\n        print(\"SKIPPING\")\n        continue\n\n      ll = -eval_suffix_nll(self.config, self.model, prefix, \n        target, self.batch_size, self.num_mc_samples, \n        self.loss_avg_mode)\n      is_target_greedy_dec = self.suffix_greedy_prediction(\n        prefix, target)\n      out.append((ll, bool(is_target_greedy_dec)))\n    return out\n\n  def loglikelihood_rolling(\n        self, requests: list[Instance]\n    ) -> list[tuple[float, bool]]:\n    raise NotImplementedError\n  \n  def generate_until(self, context, max_length, stop, \n                     **generation_kwargs):\n    raise NotImplementedError\n\n\nif __name__ == \"__main__\":\n    cli_evaluate()"
  },
  {
    "path": "main.py",
    "content": "import json\nimport os\n\nimport fsspec\nimport hydra\nimport lightning as L\nfrom lightning.fabric import Fabric\nimport omegaconf\nimport rich.syntax\nimport rich.tree\nimport torch\nfrom torch.utils.data.distributed import DistributedSampler\nfrom torchmetrics.image.fid import FrechetInceptionDistance\nfrom torchmetrics.image.inception import InceptionScore\nfrom tqdm import tqdm, trange\n\nimport algo\nimport dataloader\nimport utils\n\nomegaconf.OmegaConf.register_new_resolver(\n  'cwd', os.getcwd)\nomegaconf.OmegaConf.register_new_resolver(\n  'device_count', torch.cuda.device_count)\nomegaconf.OmegaConf.register_new_resolver(\n  'eval', eval)\nomegaconf.OmegaConf.register_new_resolver(\n  'div_up', lambda x, y: (x + y - 1) // y)\n\n\ndef _load_from_checkpoint(diffusion_model, config, tokenizer):\n  if 'hf' in config.algo.backbone:\n    return diffusion_model(\n      config, tokenizer=tokenizer).to('cuda')\n  \n  return diffusion_model.load_from_checkpoint(\n    config.eval.checkpoint_path,\n    tokenizer=tokenizer,\n    config=config)\n\n\n@L.pytorch.utilities.rank_zero_only\ndef _print_config(\n  config: omegaconf.DictConfig,\n  resolve: bool = True,\n  save_cfg: bool = True) -> None:\n  \"\"\"Prints content of DictConfig using Rich library and its tree structure.\n  \n  Args:\n    config (DictConfig): Configuration composed by Hydra.\n    resolve (bool): Whether to resolve reference fields of DictConfig.\n    save_cfg (bool): Whether to save the configuration tree to a file.\n  \"\"\"\n\n  style = 'dim'\n  tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)\n\n  fields = config.keys()\n  for field in fields:\n    branch = tree.add(field, style=style, guide_style=style)\n\n    config_section = config.get(field)\n    branch_content = str(config_section)\n    if isinstance(config_section, omegaconf.DictConfig):\n      branch_content = omegaconf.OmegaConf.to_yaml(\n        config_section, resolve=resolve)\n\n    
branch.add(rich.syntax.Syntax(branch_content, 'yaml'))\n  rich.print(tree)\n  if save_cfg:\n    with fsspec.open(\n      '{}/config_tree.txt'.format(\n        config.checkpointing.save_dir), 'w') as fp:\n      rich.print(tree, file=fp)\n\n\n@L.pytorch.utilities.rank_zero_only\ndef _print_batch(config, train_ds, valid_ds, tokenizer, k=64):\n  for dl_type, dl in [\n    ('train', train_ds), ('valid', valid_ds)]:\n    print(f'Printing {dl_type} dataloader batch.')\n    batch = next(iter(dl))\n    print('Batch input_ids.shape', batch['input_ids'].shape)\n    if config.data.modality == 'text':\n      first = batch['input_ids'][0, :k]\n      last = batch['input_ids'][0, -k:]\n      print(f'First {k} tokens:', tokenizer.decode(first))\n      print('ids:', first)\n      print(f'Last {k} tokens:', tokenizer.decode(last))\n      print('ids:', last)\n\n\ndef _generate_samples(diffusion_model, config, logger,\n                      tokenizer):\n  logger.info('Starting Sample Eval.')\n  model = _load_from_checkpoint(\n    diffusion_model=diffusion_model,\n    config=config,\n    tokenizer=tokenizer)\n  model.metrics.gen_ppl.reset()\n  model.metrics.sample_entropy.reset()\n  if config.eval.disable_ema:\n    logger.info('Disabling EMA.')\n    model.ema = None\n  stride_length = config.sampling.stride_length\n  num_strides = config.sampling.num_strides\n  all_samples = []\n  for _ in trange(config.sampling.num_sample_batches):\n    if config.sampling.semi_ar:\n      _, intermediate_samples, _ = model.restore_model_and_semi_ar_sample(\n        stride_length=stride_length,\n        num_strides=num_strides,\n        dt=1 / config.sampling.steps)\n      text_samples = intermediate_samples[-1]\n      # Note: Samples generated using semi-ar method\n      # need to be processed before computing generative perplexity\n      # since these samples contain numerous <|endoftext|> tokens\n      # and diffusion.compute_generative_perplexity() discards\n      # any text after the first EOS 
token.\n    else:\n      samples = model.restore_model_and_sample(\n        num_steps=config.sampling.steps)\n      model.metrics.record_entropy(samples)\n      text_samples = model.tokenizer.batch_decode(samples)\n      model.metrics.record_generative_perplexity(\n        text_samples, config.model.length, model.device)\n      all_samples.extend(list(text_samples))\n  generative_ppl = 0.\n  entropy = 0.\n  if not config.sampling.semi_ar:\n    generative_ppl = model.metrics.gen_ppl.compute().item()\n    entropy = model.metrics.sample_entropy.compute().item()\n    logger.info(f'Generative perplexity: {generative_ppl}')\n    logger.info(f'Sample entropy: {entropy}')\n  samples_path = config.eval.generated_samples_path\n  with fsspec.open(samples_path, 'w') as f:\n    json.dump({'generative_ppl': generative_ppl,\n               'entropy': entropy,\n               'generated_seqs': all_samples}, f, indent=4)\n  logger.info(f'Samples saved at: {samples_path}',)\n\ndef _eval_ppl(diffusion_model, config, logger, tokenizer):\n  logger.info('Starting Perplexity Eval.')\n\n  model = _load_from_checkpoint(\n    diffusion_model=diffusion_model,\n    config=config,\n    tokenizer=tokenizer)\n  if config.eval.disable_ema:\n    logger.info('Disabling EMA.')\n    model.ema = None\n\n  wandb_logger = None\n  if config.get('wandb', None) is not None:\n    wandb_logger = L.pytorch.loggers.WandbLogger(\n      config=omegaconf.OmegaConf.to_object(config),\n      ** config.wandb)\n  callbacks = []\n  if 'callbacks' in config:\n    for _, callback in config.callbacks.items():\n      callbacks.append(hydra.utils.instantiate(callback))\n  trainer = hydra.utils.instantiate(\n    config.trainer,\n    default_root_dir=os.getcwd(),\n    callbacks=callbacks,\n    strategy=hydra.utils.instantiate(config.strategy),\n    logger=wandb_logger)\n  _, valid_ds = dataloader.get_dataloaders(\n    config, tokenizer, skip_train=True, valid_seed=config.seed)\n  trainer.validate(model, valid_ds)\n\n\ndef 
_train(diffusion_model, config, logger, tokenizer):\n  logger.info('Starting Training.')\n  wandb_logger = None\n  if config.get('wandb', None) is not None:\n    wandb_logger = L.pytorch.loggers.WandbLogger(\n      config=omegaconf.OmegaConf.to_object(config),\n      **config.wandb)\n\n  if (config.checkpointing.resume_from_ckpt\n      and config.checkpointing.resume_ckpt_path is not None\n      and utils.fsspec_exists(\n        config.checkpointing.resume_ckpt_path)):\n    ckpt_path = config.checkpointing.resume_ckpt_path\n  else:\n    ckpt_path = None\n\n  # Lightning callbacks\n  callbacks = []\n  if 'callbacks' in config:\n    for _, callback in config.callbacks.items():\n      callbacks.append(hydra.utils.instantiate(callback))\n\n  train_ds, valid_ds = dataloader.get_dataloaders(\n    config, tokenizer)\n  _print_batch(config, train_ds, valid_ds, tokenizer)\n\n  if config.training.finetune_path != '':\n    assert utils.fsspec_exists(config.training.finetune_path)\n    model = diffusion_model.load_from_checkpoint(\n      config.training.finetune_path,\n      tokenizer=tokenizer,\n      config=config)\n  else:\n    model = diffusion_model(config, tokenizer=valid_ds.tokenizer)\n\n  trainer = hydra.utils.instantiate(\n    config.trainer,\n    default_root_dir=os.getcwd(),\n    callbacks=callbacks,\n    strategy=hydra.utils.instantiate(config.strategy),\n    logger=wandb_logger)\n  trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path)\n\n\ndef _eval_fid(diffusion_model, config, logger, tokenizer):\n  logger.info('Preparing data and model for FID eval.')\n  fabric = Fabric(accelerator=config.trainer.accelerator,\n                  devices=config.trainer.devices,\n                  num_nodes=config.trainer.num_nodes)\n  \n  fabric.launch()\n  seed = config.seed + fabric.global_rank\n  L.seed_everything(seed)\n  print(f'(Rank {fabric.global_rank}): seed: {seed}')\n  model = _load_from_checkpoint(\n    diffusion_model=diffusion_model,\n    config=config,\n    
tokenizer=tokenizer)\n  model.to(fabric.device)\n  if config.eval.disable_ema:\n    logger.info('Disabling EMA.')\n    model.ema = None\n  model._eval_mode()\n\n  assert config.data.train == 'cifar10', \\\n                    'FID eval only implemented for CIFAR-10'\n\n  # Like in flow matching papers: FID against train\n  loader, _ = dataloader.get_dataloaders(config, \n    tokenizer=tokenizer, skip_valid=True)\n\n  sampler = DistributedSampler(\n    loader.dataset,\n    num_replicas=fabric.world_size,\n    rank=fabric.global_rank,\n    shuffle=False)\n  \n  loader = torch.utils.data.DataLoader(\n    loader.dataset,\n    batch_size=config.loader.eval_batch_size,\n    sampler=sampler,\n    num_workers=loader.num_workers if hasattr(loader, 'num_workers') else 0,\n    pin_memory=getattr(loader, 'pin_memory', False))\n  \n  # Check each GPU must generate the same number of images\n  assert len(loader) == len(loader.dataset) // loader.batch_size // fabric.world_size, \\\n     f'{len(loader)=}, {len(loader.dataset)=}, {loader.batch_size=}, {fabric.world_size=}'\n\n  fid_calculator = FrechetInceptionDistance(\n    normalize=False).to(fabric.device)\n  is_calculator = InceptionScore(\n    normalize=False).to(fabric.device)\n  \n  desc = f'(Rank {fabric.global_rank}) Sampling...'\n  for batch in tqdm(loader, desc=desc):\n    real_samples = batch['input_ids']\n    # Generate images with labels matching the true data\n    labels = batch['labels']\n\n    gen_samples = model.generate_samples(\n      num_samples=real_samples.shape[0], \n      num_steps=config.sampling.steps,\n      labels=labels)\n    # Reshape 1D seq -> 2D image\n    gen_samples = model.tokenizer.batch_decode(gen_samples)\n    real_samples = model.tokenizer.batch_decode(\n      real_samples).to(fabric.device)\n\n    fid_calculator.update(gen_samples, real=False)\n    fid_calculator.update(real_samples, real=True)\n    is_calculator.update(gen_samples)\n  \n  fabric.barrier()\n  logger.info('Done sampling. 
Computing FID & IS...')\n  fid = fid_calculator.compute()\n  incep_score = is_calculator.compute()\n\n  if fabric.global_rank == 0:\n    logger.info(f'FID: {fid}')\n    logger.info(f'IS: {incep_score}')\n  fabric.barrier()\n\n\n@hydra.main(version_base=None, config_path='configs',\n            config_name='config')\ndef main(config):\n  \"\"\"Main entry point for training.\"\"\"\n  L.seed_everything(config.seed)\n  _print_config(config, resolve=True, save_cfg=True)\n  \n  logger = utils.get_logger(__name__)\n  tokenizer = dataloader.get_tokenizer(config)\n  if config.algo.name == 'ar':\n    diffusion_model = algo.AR\n  elif config.algo.name == 'mdlm':\n    diffusion_model = algo.MDLM\n  elif config.algo.name == 'duo_base':\n    diffusion_model = algo.DUO_BASE\n  elif config.algo.name == 'd3pm':\n    diffusion_model = algo.D3PMAbsorb\n  elif config.algo.name == 'sedd':\n    diffusion_model = algo.SEDDAbsorb\n  elif config.algo.name == 'duo':\n    diffusion_model = algo.DUO\n  elif config.algo.name == 'distillation':\n    diffusion_model = algo.Distillation\n  elif config.algo.name == 'ot-finetune':\n    diffusion_model = algo.OptimalTransportFinetune\n  else:\n    raise ValueError(\n      f'Invalid algorithm name: {config.algo.name}')\n  kwargs = {'diffusion_model': diffusion_model,\n            'config': config,\n            'tokenizer': tokenizer,\n            'logger': logger}\n  if config.mode == 'sample_eval':\n    _generate_samples(**kwargs)\n  elif config.mode == 'ppl_eval':\n    _eval_ppl(**kwargs)\n  elif config.mode == 'fid_eval':\n    _eval_fid(**kwargs)\n  else:\n    _train(**kwargs)\n\n\nif __name__ == '__main__':\n  main()"
  },
  {
    "path": "metrics.py",
    "content": "import math\nimport os\nimport typing\n\nimport torch\nimport torch.nn.functional as F\nimport torchmetrics\nimport transformers\n\nLOG2 = math.log(2)\n\n\nclass NLL(torchmetrics.aggregation.MeanMetric):\n  def update(self,\n             value:typing.Union[float, torch.Tensor],\n             weight:typing.Union[float, torch.Tensor]=1.0) -> None:\n    \"\"\"Update state with data.\n\n    Args:\n      value: Either a float or tensor containing data.\n        Additional tensor dimensions will be flattened\n      weight: Either a float or tensor containing weights\n        for calculating the average. Shape of weight should\n        be able to broadcast with the shape of `value`.\n        Default to `1.0` corresponding to simple harmonic\n        average.\n    \"\"\"\n    # broadcast weight to value shape\n    if not isinstance(value, torch.Tensor):\n        value = torch.as_tensor(value, dtype=self.dtype,\n                                device=self.device)\n    if (weight is not None\n        and not isinstance(weight, torch.Tensor)):\n      weight = torch.as_tensor(weight,\n                               dtype=self.dtype,\n                               device=self.device)\n    weight = torch.broadcast_to(weight, value.shape)\n    value, weight = self._cast_and_nan_check_input(value,\n                                                   weight)\n\n    if value.numel() == 0:\n      return\n    self.mean_value += value.sum()\n    self.weight += weight.sum()\n\n\nclass BPD(NLL):\n  def compute(self) -> torch.Tensor:\n    \"\"\"Computes the bits per dimension.\n\n    Returns:\n      bpd\n    \"\"\"\n    return self.mean_value / self.weight / LOG2\n\n\nclass Perplexity(NLL):\n  def compute(self) -> torch.Tensor:\n    \"\"\"Computes the Perplexity.\n\n    Returns:\n     Perplexity\n    \"\"\"\n    return torch.exp(self.mean_value / self.weight)\n\n\nclass Metrics:\n  def __init__(self, gen_ppl_eval_model_name_or_path=None,\n               
eval_ppl_batch_size=None) -> None:\n    metrics = torchmetrics.MetricCollection({\n        'nll': NLL(), 'bpd': BPD(), 'ppl': Perplexity()})\n    metrics.set_dtype(torch.float64)\n    self.train_nlls = metrics.clone(prefix='train/')\n    self.train_aux = BPD()\n    self.valid_nlls = metrics.clone(prefix='val/')\n    self.valid_aux = BPD()\n    self.gen_ppl = Perplexity()\n    self.sample_entropy = torchmetrics.aggregation.MeanMetric()\n    self.eval_ppl_batch_size = eval_ppl_batch_size\n    self.gen_ppl_eval_model_name_or_path = gen_ppl_eval_model_name_or_path\n    self.tokenizer = transformers.AutoTokenizer.\\\n      from_pretrained(gen_ppl_eval_model_name_or_path)\n    if self.tokenizer.pad_token is None:\n      self.tokenizer.pad_token = self.tokenizer.eos_token\n      self.tokenizer.pad_token_id = self.tokenizer.eos_token_id\n\n  def to(self, *args, **kwargs):\n    self.gen_ppl = self.gen_ppl.to(*args, **kwargs)\n    self.sample_entropy = self.sample_entropy.to(*args, **kwargs)\n    self.train_nlls = self.train_nlls.to(*args, **kwargs)\n    self.train_aux = self.train_aux.to(*args, **kwargs)\n    self.valid_nlls = self.valid_nlls.to(*args, **kwargs)\n    self.valid_aux = self.valid_aux.to(*args, **kwargs)\n\n  def reset(self):\n    self.gen_ppl.reset()\n    self.sample_entropy.reset()\n    self.train_nlls.reset()\n    self.train_aux.reset()\n    self.valid_nlls.reset()\n    self.valid_aux.reset()\n\n  def update_train(self, nll, aux_loss, num_tokens):\n    self.train_nlls.update(nll, num_tokens)\n    self.train_aux.update(aux_loss, num_tokens)\n\n  def update_valid(self, nll, aux_loss, num_tokens):\n    self.valid_nlls.update(nll, num_tokens)\n    self.valid_aux.update(aux_loss, num_tokens)\n\n\n  @torch.no_grad()\n  def _eval_retokenize(self, text_samples, max_length,\n                       device):\n    \"\"\"Retokenizes samples for the eval model.\n    \n    Args:\n        text_samples: List of sentences generated by the model.\n    Returns:\n        
samples: Samples re-tokenized for the eval model\n        attn_mask: Attention mask for the eval model\n        eval_context_size: Size of the context for the eval model\n    \"\"\"\n    if 'llama2' in self.gen_ppl_eval_model_name_or_path:\n      tokenizer_kwargs = {\n        'text_samples': text_samples,\n        'return_tensors': 'pt',\n        'return_token_type_ids': False,\n        'return_attention_mask': True,\n        'truncation': True,\n        'padding': True,\n        'max_length': max_length,\n      }\n      eval_context_size = 4096\n    else:\n      tokenizer_kwargs = {\n        'return_tensors': 'pt',\n        'return_token_type_ids': False,\n        'return_attention_mask': True,\n        'truncation': True,\n        'padding': True,\n        'max_length': max_length,\n      }\n      eval_context_size = 1024\n    samples = self.tokenizer(text_samples,\n                             **tokenizer_kwargs)\n    attn_mask = samples['attention_mask']\n    samples = samples['input_ids']\n    if 'llama2' not in self.gen_ppl_eval_model_name_or_path:\n      attn_mask = attn_mask.to(device)\n      samples = samples.to(device)      \n    return samples, attn_mask, eval_context_size\n\n  @torch.no_grad()\n  def record_entropy(self, tokens):\n    for sample in tokens:\n      _, counts = torch.unique(\n        sample, return_counts=True, sorted=False)\n      entropy = torch.special.entr(\n        counts.float() / counts.sum()).sum().item()\n      self.sample_entropy.update(entropy)\n\n  @torch.no_grad()\n  def record_generative_perplexity(\n    self,\n    text_samples: typing.List[str],\n    max_length: int,\n    retokenize: bool = True,\n    device='cuda') -> None:\n    \n    os.environ['TOKENIZERS_PARALLELISM'] = 'false'\n    eval_model = transformers.AutoModelForCausalLM.from_pretrained(\n      self.gen_ppl_eval_model_name_or_path).eval()\n    if 'llama2' not in self.gen_ppl_eval_model_name_or_path:\n      eval_model = eval_model.to(device)\n    # Re-tokenize 
using eval model's tokenizer\n    if retokenize:\n      (samples, attn_mask,\n       eval_context_size) = self._eval_retokenize(\n         text_samples, max_length=max_length, device=device)\n    else:\n      samples = text_samples\n      attn_mask = torch.ones(samples.shape).to(device)\n      eval_context_size = samples.shape[-1]\n    batch_size = min(self.eval_ppl_batch_size,\n                     samples.shape[0])\n    num_batches = samples.shape[0] // batch_size\n    for i in range(num_batches):\n      _samples = torch.split(\n        samples[i * batch_size: (i + 1) * batch_size],\n        eval_context_size,\n        dim=-1)\n      _attn_mask = torch.split(\n        attn_mask[i * batch_size: (i + 1) * batch_size],\n        eval_context_size,\n        dim=-1)\n      for (sample_chunk, attn_mask_chunk) in zip(_samples,\n                                                 _attn_mask):\n        logits = eval_model(sample_chunk,\n                            attention_mask=attn_mask_chunk)\n        logits = logits[0].transpose(-1, -2)\n        nlls = F.cross_entropy(logits[..., :-1],\n                               sample_chunk[..., 1:],\n                               reduction='none')\n        first_eos = (\n          sample_chunk\n          == self.tokenizer.eos_token_id).cumsum(-1) == 1\n        token_mask = sample_chunk != self.tokenizer.eos_token_id\n        valid_tokens = first_eos[..., 1:] + token_mask[..., 1:]\n        self.gen_ppl.update(nlls * valid_tokens, valid_tokens)\n"
  },
  {
    "path": "models/__init__.py",
    "content": "from . import dit\nfrom . import ema\nfrom . import unet"
  },
  {
    "path": "models/dit.py",
    "content": "import math\nimport typing\n\nimport einops\nimport flash_attn\nimport flash_attn.layers.rotary\nimport huggingface_hub\nimport omegaconf\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n# Flags required to enable jit fusion kernels\ntorch._C._jit_set_profiling_mode(False)\ntorch._C._jit_set_profiling_executor(False)\ntorch._C._jit_override_can_fuse_on_cpu(True)\ntorch._C._jit_override_can_fuse_on_gpu(True)\n\n\ndef bias_dropout_add_scale(\n    x: torch.Tensor,\n    bias: typing.Optional[torch.Tensor],\n    scale: torch.Tensor,\n    residual: typing.Optional[torch.Tensor],\n    prob: float,\n    training: bool) -> torch.Tensor:\n  if bias is not None:\n    out = scale * F.dropout(x + bias, p=prob, training=training)\n  else:\n    out = scale * F.dropout(x, p=prob, training=training)\n\n  if residual is not None:\n    out = residual + out\n  return out\n\n\ndef get_bias_dropout_add_scale(training):\n  def _bias_dropout_add(x, bias, scale, residual, prob):\n    return bias_dropout_add_scale(\n      x, bias, scale, residual, prob, training)\n\n  return _bias_dropout_add\n\n\n# function overload\ndef modulate(x: torch.Tensor,\n             shift: torch.Tensor,\n             scale: torch.Tensor) -> torch.Tensor:\n  return x * (1 + scale) + shift\n\n\n@torch.jit.script\ndef bias_dropout_add_scale_fused_train(\n    x: torch.Tensor,\n    bias: typing.Optional[torch.Tensor],\n    scale: torch.Tensor,\n    residual: typing.Optional[torch.Tensor],\n    prob: float) -> torch.Tensor:\n  return bias_dropout_add_scale(\n    x, bias, scale, residual, prob, True)\n\n\n@torch.jit.script\ndef bias_dropout_add_scale_fused_inference(\n    x: torch.Tensor,\n    bias: typing.Optional[torch.Tensor],\n    scale: torch.Tensor,\n    residual: typing.Optional[torch.Tensor],\n    prob: float) -> torch.Tensor:\n  return bias_dropout_add_scale(\n    x, bias, scale, residual, prob, False)\n\n\n@torch.jit.script\ndef modulate_fused(x: torch.Tensor,\n         
          shift: torch.Tensor,\n                   scale: torch.Tensor) -> torch.Tensor:\n  return modulate(x, shift, scale)\n\n\nclass Rotary(torch.nn.Module):\n  def __init__(self, dim, base=10_000):\n    super().__init__()\n    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))\n    self.register_buffer('inv_freq', inv_freq)\n    self.seq_len_cached = None\n    self.cos_cached = None\n    self.sin_cached = None\n\n  def forward(self, x, seq_dim=1):\n    seq_len = x.shape[seq_dim]\n    if seq_len != self.seq_len_cached:\n      self.seq_len_cached = seq_len\n      t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)\n      freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq.clone())\n      emb = torch.cat((freqs, freqs), dim=-1).to(x.device)\n      # dims are: batch, seq_len, qkv, head, dim\n      self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1,1,3,1,1)\n      self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1,1,3,1,1)\n      # This makes the transformation on v an identity.\n      self.cos_cached[:,:,2,:,:].fill_(1.)\n      self.sin_cached[:,:,2,:,:].fill_(0.)\n\n    return self.cos_cached, self.sin_cached\n\n\ndef rotate_half(x):\n  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]\n  return torch.cat((-x2, x1), dim=-1)\n\n\ndef split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin):\n  with torch.amp.autocast('cuda', enabled=False):\n    cos, sin = rotary_cos_sin\n    cos = cos.to(qkv.dtype)\n    sin = sin.to(qkv.dtype)\n    cos = cos[0,:,0,0,:cos.shape[-1]//2]\n    sin = sin[0,:,0,0,:sin.shape[-1]//2]\n    q, k, v = qkv.chunk(3, dim=2)\n    q = flash_attn.layers.rotary.apply_rotary_emb_torch(\n      q.squeeze(dim=2), cos, sin)\n    k = flash_attn.layers.rotary.apply_rotary_emb_torch(\n      k.squeeze(dim=2), cos, sin)\n    v = v.squeeze(dim=2)\n  return q, k, v\n\n\ndef apply_rotary_pos_emb(qkv, cos, sin):\n  cos = cos[0,:,0,0,:cos.shape[-1]//2]\n  sin = sin[0,:,0,0,:sin.shape[-1]//2]\n  
return flash_attn.layers.rotary.apply_rotary_emb_qkv_(qkv, cos, sin)\n\n\ndef regular_attention_multi_headed(q, k, v):\n  # Assuming qkv is a tensor with shape [batch, seq_len, 3, num_heads, head_dim]\n  # where the 3 represents Q, K, V packed in that order\n  attention_output = F.scaled_dot_product_attention(\n    query=q.transpose(1, 2),\n    key=k.transpose(1, 2),\n    value=v.transpose(1, 2),\n    attn_mask=None,\n    dropout_p=0.0,\n    is_causal=False)\n  # [batch_size, seq_len, num_heads, head_dim]\n  attention_output = attention_output.transpose(1, 2)\n  return einops.rearrange(attention_output, 'b s h d -> b s (h d)')\n\n\n#################################################################################\n#                                  Layers                                       #\n#################################################################################\nclass LayerNorm(nn.Module):\n  def __init__(self, dim):\n    super().__init__()\n    self.weight = nn.Parameter(torch.ones([dim]))\n    self.dim = dim\n  def forward(self, x):\n    with torch.amp.autocast('cuda', enabled=False):\n      x = F.layer_norm(x.float(), [self.dim])\n    return x * self.weight[None, None, :]\n\n\ndef residual_linear(x, W, x_skip, residual_scale):\n  \"\"\"x_skip + residual_scale * W @ x\"\"\"\n  dim_out, dim_in = W.shape[0], W.shape[1]\n  return torch.addmm(\n    x_skip.view(-1, dim_out),\n    x.view(-1, dim_in),\n    W.T,\n    alpha=residual_scale).view(*x.shape[:-1], dim_out)\n\n\n#################################################################################\n#               Embedding Layers for Timesteps and Class Labels                 #\n#################################################################################\nclass TimestepEmbedder(nn.Module):\n  \"\"\"\n  Embeds scalar timesteps into vector representations.\n  \"\"\"\n  def __init__(self, hidden_size, frequency_embedding_size=256):\n    super().__init__()\n    self.mlp = nn.Sequential(\n      
nn.Linear(frequency_embedding_size, hidden_size, bias=True),\n      nn.SiLU(),\n      nn.Linear(hidden_size, hidden_size, bias=True))\n    self.frequency_embedding_size = frequency_embedding_size\n\n  @staticmethod\n  def timestep_embedding(t, dim, max_period=10000):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n    :param t: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an (N, D) Tensor of positional embeddings.\n    \"\"\"\n    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py\n    half = dim // 2\n    freqs = torch.exp(\n      - math.log(max_period)\n      * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)\n      / half)\n    args = t[:, None].float() * freqs[None]\n    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n    if dim % 2:\n      embedding = torch.cat(\n        [embedding,\n         torch.zeros_like(embedding[:, :1])], dim=-1)\n    return embedding\n\n  def forward(self, t):\n    t_freq = self.timestep_embedding(t, self.frequency_embedding_size)\n    t_emb = self.mlp(t_freq)\n    return t_emb\n\n\nclass LabelEmbedder(nn.Module):\n  \"\"\"Embeds class labels into vector representations.\n  \n  Also handles label dropout for classifier-free guidance.\n  \"\"\"\n  def __init__(self, num_classes, cond_size):\n    super().__init__()\n    self.embedding_table = nn.Embedding(num_classes + 1, cond_size)\n    self.num_classes = num_classes\n\n    # TODO think of initializing with 0.02 std deviation like in original DiT paper\n\n  def forward(self, labels):\n    embeddings = self.embedding_table(labels)\n    return embeddings\n    \n\n#################################################################################\n#                                 Core Model                                    
#\n#################################################################################\n\nclass DDiTBlockCausal(nn.Module):\n  def __init__(self, dim, n_heads, mlp_ratio=4, dropout=0.1):\n    super().__init__()\n    self.n_heads = n_heads\n\n    self.norm1 = LayerNorm(dim)\n    self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)\n    self.attn_out = nn.Linear(dim, dim, bias=False)\n    self.dropout1 = nn.Dropout(dropout)\n\n    self.norm2 = LayerNorm(dim)\n    self.mlp = nn.Sequential(\n      nn.Linear(dim, mlp_ratio * dim, bias=True),\n      nn.GELU(approximate='tanh'),\n      nn.Linear(mlp_ratio * dim, dim, bias=True))\n    self.dropout2 = nn.Dropout(dropout)\n    self.dropout = dropout\n\n  def _get_bias_dropout_scale(self):\n    if self.training:\n      return bias_dropout_add_scale_fused_train\n    else:\n      return bias_dropout_add_scale_fused_inference\n\n  def forward(self, x, rotary_cos_sin, **kwargs):\n    del kwargs\n    batch_size, seq_len = x.shape[0], x.shape[1]\n\n    bias_dropout_scale_fn = self._get_bias_dropout_scale()\n\n    # attention operation\n    x_skip = x\n    x = self.norm1(x)\n\n    qkv = self.attn_qkv(x)\n    qkv = einops.rearrange(\n      qkv,\n      'b s (three h d) -> b s three h d',\n      three=3,\n      h=self.n_heads)\n    with torch.amp.autocast('cuda', enabled=False):\n      cos, sin = rotary_cos_sin\n      qkv = apply_rotary_pos_emb(\n        qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)\n      )\n    qkv = einops.rearrange(qkv, 'b s ... 
-> (b s) ...')\n    cu_seqlens = torch.arange(\n      0, (batch_size + 1) * seq_len,\n      step=seq_len, dtype=torch.int32, device=qkv.device)\n    x = flash_attn.flash_attn_interface.flash_attn_varlen_qkvpacked_func(\n      qkv, cu_seqlens, seq_len, 0.0, causal=True)\n\n    x = einops.rearrange(x, '(b s) h d -> b s (h d)',\n                         b=batch_size)\n\n    scale = torch.ones(1, device=x.device, dtype=x.dtype)\n    x = bias_dropout_scale_fn(\n      self.attn_out(x), None, scale, x_skip, self.dropout)\n\n    # mlp operation\n    x = bias_dropout_scale_fn(\n      self.mlp(self.norm2(x)), None, scale, x, self.dropout)\n    return x\n\n\n\nclass DDiTBlock(nn.Module):\n  def __init__(self, dim, n_heads, adaLN,\n               cond_dim=None, mlp_ratio=4,\n               dropout=0.1):\n    super().__init__()\n    self.n_heads = n_heads\n    self.adaLN = adaLN\n\n    self.norm1 = LayerNorm(dim)\n    self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)\n    self.attn_out = nn.Linear(dim, dim, bias=False)\n    self.dropout1 = nn.Dropout(dropout)\n\n    self.norm2 = LayerNorm(dim)\n    self.mlp = nn.Sequential(\n      nn.Linear(dim, mlp_ratio * dim, bias=True),\n      nn.GELU(approximate='tanh'),\n      nn.Linear(mlp_ratio * dim, dim, bias=True))\n    self.dropout2 = nn.Dropout(dropout)\n    self.dropout = dropout\n\n    if self.adaLN:\n      self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim)\n      self.adaLN_modulation.weight.data.zero_()\n      self.adaLN_modulation.bias.data.zero_()\n\n\n  def _get_bias_dropout_scale(self):\n    if self.training:\n      return bias_dropout_add_scale_fused_train\n    else:\n      return bias_dropout_add_scale_fused_inference\n\n\n  def forward(self, x, rotary_cos_sin, c=None):\n\n    bias_dropout_scale_fn = self._get_bias_dropout_scale()\n\n    x_skip = x\n    x = self.norm1(x)\n\n    if self.adaLN:\n      # self.adaLN_modulation(c): (128, 1536)\n      # self.adaLN_modulation(c)[:, None]: (128, 1, 1536)\n      # \"\" 
.chunk(6, dim=2) returns 6 tuples of shapes (128, 1, 256)\n      (shift_msa, scale_msa, gate_msa, shift_mlp,\n       scale_mlp, gate_mlp) = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)\n      x = modulate_fused(x, shift_msa, scale_msa)\n\n    qkv = einops.rearrange(\n      self.attn_qkv(x),\n      'b s (three h d) -> b s three h d',\n      three=3,\n      h=self.n_heads)\n    q, k, v = split_and_apply_rotary_pos_emb(qkv, rotary_cos_sin)\n    \n    x = regular_attention_multi_headed(q, k, v)\n\n    if self.adaLN:\n      x = bias_dropout_scale_fn(self.attn_out(x),\n                                None,\n                                gate_msa,\n                                x_skip,\n                                self.dropout)\n      x = bias_dropout_scale_fn(\n        self.mlp(modulate_fused(\n          self.norm2(x), shift_mlp, scale_mlp)),\n        None, gate_mlp, x, self.dropout)\n    else:\n      scale = torch.ones(1, device=x.device, dtype=x.dtype)\n      x = bias_dropout_scale_fn(\n        self.attn_out(x), None, scale, x_skip, self.dropout)\n      x = bias_dropout_scale_fn(\n        self.mlp(self.norm2(x)), None, scale, x, self.dropout)\n    return x\n\n\nclass EmbeddingLayer(nn.Module):\n  def __init__(self, dim, vocab_dim):\n    super().__init__()\n    self.embedding = nn.Parameter(torch.empty((vocab_dim, dim)))\n    torch.nn.init.kaiming_uniform_(self.embedding, a=math.sqrt(5))\n\n  def forward(self, x, weights=None):\n    if weights is not None:\n      bs, seq_len, k = x.shape\n      flat_x = x.reshape(-1, k)\n      flat_w = weights.reshape(-1, k).float()\n      bag = F.embedding_bag(flat_x, self.embedding.float(),\n                            per_sample_weights=flat_w,\n                            mode='sum')\n      return bag.view(bs, seq_len, -1)\n    elif x.ndim == 2:\n      return self.embedding[x]\n    assert x.ndim == 3\n    return torch.einsum(\n      \"blv,ve->ble\",\n      torch.nn.functional.softmax(x, dim=-1).float(),\n      
self.embedding.float()).to(x.dtype)\n\n\nclass DDiTFinalLayer(nn.Module):\n  def __init__(self, hidden_size, out_channels, cond_dim,\n               adaLN):\n    super().__init__()\n    self.norm_final = LayerNorm(hidden_size)\n    self.linear = nn.Linear(hidden_size, out_channels)\n    self.linear.weight.data.zero_()\n    self.linear.bias.data.zero_()\n    self.adaLN = adaLN\n    if self.adaLN:\n      self.adaLN_modulation = nn.Linear(cond_dim,\n                                        2 * hidden_size,\n                                        bias=True)\n      self.adaLN_modulation.weight.data.zero_()\n      self.adaLN_modulation.bias.data.zero_()\n\n  def forward(self, x, c):\n    x = self.norm_final(x)\n    if self.adaLN:\n      shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)\n      x = modulate_fused(x, shift, scale)\n    x = self.linear(x)\n    return x\n\n\nclass DIT(nn.Module, huggingface_hub.PyTorchModelHubMixin):\n  def __init__(self, config, vocab_size: int):\n    super().__init__()\n    if type(config) == dict:\n      config = omegaconf.OmegaConf.create(config)\n    self.causal = config.algo.causal_attention\n    self.adaLN = not self.causal\n    self.config = config\n    self.vocab_size = vocab_size\n    dim = config.model.hidden_size\n    cond_dim = config.model.cond_dim\n    self.vocab_embed = EmbeddingLayer(dim, vocab_size)\n    if not self.causal:\n      self.sigma_map = TimestepEmbedder(cond_dim)\n    self.rotary_emb = Rotary(dim // config.model.n_heads)\n\n    blocks = []\n    for _ in range(config.model.n_blocks):\n      if self.causal:\n        block = DDiTBlockCausal(\n          dim=dim,\n          n_heads=config.model.n_heads,\n          dropout=config.model.dropout)\n      else:\n        block = DDiTBlock(\n          dim=dim,\n          n_heads=config.model.n_heads,\n          cond_dim=cond_dim,\n          adaLN=self.adaLN,\n          dropout=config.model.dropout)\n      blocks.append(block)\n    self.blocks = 
nn.ModuleList(blocks)\n\n    self.output_layer = DDiTFinalLayer(\n      hidden_size=dim,\n      out_channels=vocab_size,\n      cond_dim=cond_dim,\n      adaLN=self.adaLN)\n    self.scale_by_sigma = config.model.scale_by_sigma\n\n  def _get_bias_dropout_scale(self):\n    if self.training:\n      return bias_dropout_add_scale_fused_train\n    else:\n      return  bias_dropout_add_scale_fused_inference\n\n  def forward(self, x, sigma, class_cond=None, weights=None):\n    assert class_cond is None, 'Not implemented for DiT'\n    x = self.vocab_embed(x, weights)\n    if self.causal:\n      t_cond = None\n    else:\n      t_cond = F.silu(self.sigma_map(sigma))\n\n    rotary_cos_sin = self.rotary_emb(x)\n    with torch.amp.autocast('cuda', dtype=torch.bfloat16):\n      for i in range(len(self.blocks)):\n        x = self.blocks[i](x, rotary_cos_sin, c=t_cond)\n      x = self.output_layer(x, c=t_cond)\n\n    return x\n"
  },
  {
    "path": "models/ema.py",
    "content": "import torch\n\n\nclass ExponentialMovingAverage:\n  \"\"\"\n  Maintains (exponential) moving average of a set of parameters.\n  \"\"\"\n\n  def __init__(self, parameters, decay, use_num_updates=True):\n    \"\"\"\n    Args:\n        parameters: Iterable of `torch.nn.Parameter`; usually the result of\n            `model.parameters()`.\n        decay: The exponential decay.\n        use_num_updates: Whether to use number of updates when computing\n            averages.\n    \"\"\"\n    if decay < 0.0 or decay > 1.0:\n      raise ValueError('Decay must be between 0 and 1')\n    self.decay = decay\n    self.num_updates = 0 if use_num_updates else None\n    self.shadow_params = [p.clone().detach()\n                          for p in parameters if p.requires_grad]\n    self.collected_params = []\n\n  def move_shadow_params_to_device(self, device):\n    self.shadow_params = [i.to(device) for i in self.shadow_params]\n\n  def update(self, parameters):\n    \"\"\"\n    Update currently maintained parameters.\n\n    Call this every time the parameters are updated, such as the result of\n    the `optimizer.step()` call.\n\n    Args:\n        parameters: Iterable of `torch.nn.Parameter`; usually the same set of\n            parameters used to initialize this object.\n    \"\"\"\n    decay = self.decay\n    if self.num_updates is not None:\n      self.num_updates += 1\n      decay = min(decay, (1 + self.num_updates) /\n                  (10 + self.num_updates))\n    one_minus_decay = 1.0 - decay\n    with torch.no_grad():\n      parameters = [p for p in parameters if p.requires_grad]\n      for s_param, param in zip(self.shadow_params, parameters):\n        s_param.sub_(one_minus_decay * (s_param - param))\n\n  def copy_to(self, parameters):\n    \"\"\"\n    Copy current parameters into given collection of parameters.\n\n    Args:\n        parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            updated with the stored moving averages.\n  
  \"\"\"\n    parameters = [p for p in parameters if p.requires_grad]\n    for s_param, param in zip(self.shadow_params, parameters):\n      if param.requires_grad:\n        param.data.copy_(s_param.data)\n\n  def store(self, parameters):\n    \"\"\"\n    Save the current parameters for restoring later.\n\n    Args:\n        parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            temporarily stored.\n    \"\"\"\n    self.collected_params = [param.clone() for param in parameters]\n\n  def restore(self, parameters):\n    \"\"\"\n    Restore the parameters stored with the `store` method.\n    Useful to validate the model with EMA parameters without affecting the\n    original optimization process. Store the parameters before the\n    `copy_to` method. After validation (or model saving), use this to\n    restore the former parameters.\n\n    Args:\n        parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            updated with the stored parameters.\n    \"\"\"\n    for c_param, param in zip(self.collected_params, parameters):\n      param.data.copy_(c_param.data)\n\n  def state_dict(self):\n    return dict(decay=self.decay,\n                num_updates=self.num_updates,\n                shadow_params=self.shadow_params)\n\n  def load_state_dict(self, state_dict):\n    self.decay = state_dict['decay']\n    self.num_updates = state_dict['num_updates']\n    self.shadow_params = state_dict['shadow_params']\n"
  },
  {
    "path": "models/unet.py",
    "content": "import torch\nimport torch.nn as nn\nimport math\nimport torch.nn.functional as F\nimport numpy as np\nimport omegaconf\n\nimport transformers\nfrom einops import rearrange\nfrom .dit import LabelEmbedder, EmbeddingLayer\n\n\n# From https://github.com/yang-song/score_sde_pytorch/ which is from\n#  https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py\ndef transformer_timestep_embedding(timesteps, embedding_dim, max_positions=10000):\n  assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32\n  half_dim = embedding_dim // 2\n  # magic number 10000 is from transformers\n  emb = math.log(max_positions) / (half_dim - 1)\n  # emb = math.log(2.) / (half_dim - 1)\n  emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)\n  # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]\n  # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]\n  emb = timesteps.float()[:, None] * emb[None, :]\n  emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n  if embedding_dim % 2 == 1:  # zero pad\n    emb = F.pad(emb, (0, 1), mode='constant')\n  assert emb.shape == (timesteps.shape[0], embedding_dim)\n  return emb\n\n\n# Code modified from https://github.com/yang-song/score_sde_pytorch\ndef variance_scaling(scale, mode, distribution,\n                     in_axis=1, out_axis=0,\n                     dtype=torch.float32,\n                     device='cpu'):\n  \"\"\"Ported from JAX. 
\"\"\"\n\n  def _compute_fans(shape, in_axis=1, out_axis=0):\n    receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]\n    fan_in = shape[in_axis] * receptive_field_size\n    fan_out = shape[out_axis] * receptive_field_size\n    return fan_in, fan_out\n\n  def init(shape, dtype=dtype, device=device):\n    fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)\n    if mode == \"fan_in\":\n      denominator = fan_in\n    elif mode == \"fan_out\":\n      denominator = fan_out\n    elif mode == \"fan_avg\":\n      denominator = (fan_in + fan_out) / 2\n    else:\n      raise ValueError(\n        \"invalid mode for variance scaling initializer: {}\".format(mode))\n    variance = scale / denominator\n    if distribution == \"normal\":\n      return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)\n    elif distribution == \"uniform\":\n      return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)\n    else:\n      raise ValueError(\"invalid distribution for variance scaling initializer\")\n\n  return init\n\n\ndef default_init(scale=1.):\n  \"\"\"The same initialization used in DDPM.\"\"\"\n  scale = 1e-10 if scale == 0 else scale\n  return variance_scaling(scale, 'fan_avg', 'uniform')\n\n\nclass NiN(nn.Module):\n  def __init__(self, in_ch, out_ch, init_scale=0.1):\n    super().__init__()\n    self.W = nn.Parameter(default_init(scale=init_scale)((in_ch, out_ch)), requires_grad=True)\n    self.b = nn.Parameter(torch.zeros(out_ch), requires_grad=True)\n\n  def forward(self, x, #  [\"batch\", \"in_ch\", \"H\", \"W\"]\n    ):\n\n    x = x.permute(0, 2, 3, 1)\n    # x (batch, H, W, in_ch)\n    y = torch.einsum('bhwi,ik->bhwk', x, self.W) + self.b\n    # y (batch, H, W, out_ch)\n    return y.permute(0, 3, 1, 2)\n\nclass AttnBlock(nn.Module):\n  \"\"\"Channel-wise self-attention block.\"\"\"\n  def __init__(self, channels, skip_rescale=True):\n    super().__init__()\n    self.skip_rescale = 
skip_rescale\n    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels//4, 32),\n        num_channels=channels, eps=1e-6)\n    self.NIN_0 = NiN(channels, channels)\n    self.NIN_1 = NiN(channels, channels)\n    self.NIN_2 = NiN(channels, channels)\n    self.NIN_3 = NiN(channels, channels, init_scale=0.)\n\n  def forward(self, x, # [\"batch\", \"channels\", \"H\", \"W\"]\n    ):\n\n    B, C, H, W = x.shape\n    h = self.GroupNorm_0(x)\n    q = self.NIN_0(h)\n    k = self.NIN_1(h)\n    v = self.NIN_2(h)\n\n    w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))\n    w = torch.reshape(w, (B, H, W, H * W))\n    w = F.softmax(w, dim=-1)\n    w = torch.reshape(w, (B, H, W, H, W))\n    h = torch.einsum('bhwij,bcij->bchw', w, v)\n    h = self.NIN_3(h)\n\n    if self.skip_rescale:\n        return (x + h) / np.sqrt(2.)\n    else:\n        return x + h\n\n\nclass ResBlock(nn.Module):\n    def __init__(self, in_ch, out_ch, temb_dim=None, dropout=0.1, skip_rescale=True):\n        super().__init__()\n\n        self.in_ch = in_ch\n        self.out_ch = out_ch\n\n        self.skip_rescale = skip_rescale\n\n        self.act = nn.functional.silu\n        self.groupnorm0 = nn.GroupNorm(\n            num_groups=min(in_ch // 4, 32),\n            num_channels=in_ch, eps=1e-6\n        )\n        self.conv0 = nn.Conv2d(\n            in_ch, out_ch, kernel_size=3, padding=1\n        )\n\n        if temb_dim is not None:\n            self.dense0 = nn.Linear(temb_dim, out_ch)\n            nn.init.zeros_(self.dense0.bias)\n\n\n        self.groupnorm1 = nn.GroupNorm(\n            num_groups=min(out_ch // 4, 32),\n            num_channels=out_ch, eps=1e-6\n        )\n        self.dropout0 = nn.Dropout(dropout)\n\n        self.conv1 = nn.Conv2d(\n            out_ch, out_ch, kernel_size=3, padding=1\n        )\n        if out_ch != in_ch:\n            self.nin = NiN(in_ch, out_ch)\n\n    def forward(self, x, # [\"batch\", \"in_ch\", \"H\", \"W\"]\n                temb=None, #  
[\"batch\", \"temb_dim\"]\n        ):\n\n        assert x.shape[1] == self.in_ch\n\n        h = self.groupnorm0(x)\n        h = self.act(h)\n        h = self.conv0(h)\n\n        if temb is not None:\n            h += self.dense0(self.act(temb))[:, :, None, None]\n\n        h = self.groupnorm1(h)\n        h = self.act(h)\n        h = self.dropout0(h)\n        h = self.conv1(h)\n        if h.shape[1] != self.in_ch:\n            x = self.nin(x)\n\n        assert x.shape == h.shape\n\n        if self.skip_rescale:\n            return (x + h) / np.sqrt(2.)\n        else:\n            return x + h\n\nclass Downsample(nn.Module):\n    def __init__(self, channels):\n        super().__init__()\n        self.conv = nn.Conv2d(channels, channels, kernel_size=3, \n            stride=2, padding=0)\n\n    def forward(self, x, # [\"batch\", \"ch\", \"inH\", \"inW\"]\n        ):\n        B, C, H, W = x.shape\n        x = nn.functional.pad(x, (0, 1, 0, 1))\n        x= self.conv(x)\n\n        assert x.shape == (B, C, H // 2, W // 2)\n        return x\n\nclass Upsample(nn.Module):\n  def __init__(self, channels):\n    super().__init__()\n    self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)\n\n  def forward(self, x, # [\"batch\", \"ch\", \"inH\", \"inW\"]\n    ):\n    B, C, H, W = x.shape\n    h = F.interpolate(x, (H*2, W*2), mode='nearest')\n    h = self.conv(h)\n\n    assert h.shape == (B, C, H*2, W*2)\n    return h\n\n\nclass UNet(nn.Module):\n    def __init__(self, config, vocab_size=None):\n        super().__init__()\n        if type(config) == dict:\n            config = omegaconf.OmegaConf.create(config)\n        assert config.model.name == 'unet'\n        self.ch = config.model.ch\n        self.num_res_blocks = config.model.num_res_blocks\n        self.num_scales = config.model.num_scales\n        self.ch_mult = config.model.ch_mult\n        assert self.num_scales == len(self.ch_mult)\n        self.input_channels = config.model.input_channels\n        
self.output_channels = 2 * config.model.input_channels\n        self.scale_count_to_put_attn = config.model.scale_count_to_put_attn\n        self.data_min_max = [0, vocab_size] # config.model.data_min_max # tuple of min and max value of input so it can be rescaled to [-1, 1]\n        self.dropout = config.model.dropout\n        self.skip_rescale = config.model.skip_rescale\n        self.time_conditioning = config.model.time_conditioning # Whether to add in time embeddings\n        self.time_scale_factor = config.model.time_scale_factor # scale to make the range of times be 0 to 1000\n        self.time_embed_dim = config.model.time_embed_dim\n        self.vocab_size = vocab_size\n\n        self.size = config.model.size\n        self.length = config.model.length\n\n        # truncated logistic \n        self.fix_logistic = config.model.fix_logistic\n\n        self.act = nn.functional.silu\n\n        if self.time_conditioning:\n            self.temb_modules = []\n            self.temb_modules.append(nn.Linear(self.time_embed_dim, self.time_embed_dim*4))\n            nn.init.zeros_(self.temb_modules[-1].bias)\n            self.temb_modules.append(nn.Linear(self.time_embed_dim*4, self.time_embed_dim*4))\n            nn.init.zeros_(self.temb_modules[-1].bias)\n            self.temb_modules = nn.ModuleList(self.temb_modules)\n\n        self.expanded_time_dim = 4 * self.time_embed_dim if self.time_conditioning else None\n\n        self.input_conv = nn.Conv2d(\n            in_channels=self.input_channels, out_channels=self.ch,\n            kernel_size=3, padding=1\n        )\n\n        h_cs = [self.ch]\n        in_ch = self.ch\n\n        # Downsampling\n        self.downsampling_modules = []\n\n        for scale_count in range(self.num_scales):\n            for res_count in range(self.num_res_blocks):\n                out_ch = self.ch * self.ch_mult[scale_count]\n                self.downsampling_modules.append(\n                    ResBlock(in_ch, out_ch, 
temb_dim=self.expanded_time_dim,\n                        dropout=self.dropout, skip_rescale=self.skip_rescale)\n                )\n                in_ch = out_ch\n                h_cs.append(in_ch)\n                if scale_count == self.scale_count_to_put_attn:\n                    self.downsampling_modules.append(\n                        AttnBlock(in_ch, skip_rescale=self.skip_rescale)\n                    )\n\n            if scale_count != self.num_scales - 1:\n                self.downsampling_modules.append(Downsample(in_ch))\n                h_cs.append(in_ch)\n\n        self.downsampling_modules = nn.ModuleList(self.downsampling_modules)\n\n        # Middle\n        self.middle_modules = []\n\n        self.middle_modules.append(\n            ResBlock(in_ch, in_ch, temb_dim=self.expanded_time_dim,\n                dropout=self.dropout, skip_rescale=self.skip_rescale)\n        )\n        self.middle_modules.append(\n            AttnBlock(in_ch, skip_rescale=self.skip_rescale)\n        )\n        self.middle_modules.append(\n            ResBlock(in_ch, in_ch, temb_dim=self.expanded_time_dim,\n                dropout=self.dropout, skip_rescale=self.skip_rescale)\n        )\n        self.middle_modules = nn.ModuleList(self.middle_modules)\n\n        # Upsampling\n        self.upsampling_modules = []\n\n        for scale_count in reversed(range(self.num_scales)):\n            for res_count in range(self.num_res_blocks+1):\n                out_ch = self.ch * self.ch_mult[scale_count]\n                self.upsampling_modules.append(\n                    ResBlock(in_ch + h_cs.pop(), \n                        out_ch,\n                        temb_dim=self.expanded_time_dim,\n                        dropout=self.dropout,\n                        skip_rescale=self.skip_rescale\n                    )\n                )\n                in_ch = out_ch\n\n                if scale_count == self.scale_count_to_put_attn:\n                    
self.upsampling_modules.append(\n                        AttnBlock(in_ch, skip_rescale=self.skip_rescale)\n                    )\n            if scale_count != 0:\n                self.upsampling_modules.append(Upsample(in_ch))\n\n        self.upsampling_modules = nn.ModuleList(self.upsampling_modules)\n\n        assert len(h_cs) == 0\n\n        # output\n        self.output_modules = []\n        \n        self.output_modules.append(\n            nn.GroupNorm(min(in_ch//4, 32), in_ch, eps=1e-6)\n        )\n\n        self.output_modules.append(\n           nn.Conv2d(in_ch, self.output_channels, kernel_size=3, padding=1)\n        )\n        self.output_modules = nn.ModuleList(self.output_modules)\n\n        self.cond_map = LabelEmbedder(\n            config.data.num_classes,\n            self.time_embed_dim*4)\n\n    def _center_data(self, x):\n        out = (x - self.data_min_max[0]) / (self.data_min_max[1] - self.data_min_max[0]) # [0, 1]\n        return 2 * out - 1 # to put it in [-1, 1]\n\n    def _time_embedding(self, timesteps):\n        if self.time_conditioning:\n            temb = transformer_timestep_embedding(\n                timesteps * self.time_scale_factor, self.time_embed_dim\n            )\n            temb = self.temb_modules[0](temb)\n            temb = self.temb_modules[1](self.act(temb))\n        else:\n            temb = None\n\n        return temb\n\n    def _do_input_conv(self, h):\n        h = self.input_conv(h)\n        hs = [h]\n        return h, hs\n\n    def _do_downsampling(self, h, hs, temb):\n        m_idx = 0\n        for scale_count in range(self.num_scales):\n            for res_count in range(self.num_res_blocks):\n                h = self.downsampling_modules[m_idx](h, temb)\n                m_idx += 1\n                if scale_count == self.scale_count_to_put_attn:\n                    h = self.downsampling_modules[m_idx](h)\n                    m_idx += 1\n                hs.append(h)\n\n            if scale_count != 
self.num_scales - 1:\n                h = self.downsampling_modules[m_idx](h)\n                hs.append(h)\n                m_idx += 1\n\n        assert m_idx == len(self.downsampling_modules)\n\n        return h, hs\n\n    def _do_middle(self, h, temb):\n        m_idx = 0\n        h = self.middle_modules[m_idx](h, temb)\n        m_idx += 1\n        h = self.middle_modules[m_idx](h)\n        m_idx += 1\n        h = self.middle_modules[m_idx](h, temb)\n        m_idx += 1\n\n        assert m_idx == len(self.middle_modules)\n\n        return h\n\n    def _do_upsampling(self, h, hs, temb):\n        m_idx = 0\n        for scale_count in reversed(range(self.num_scales)):\n            for res_count in range(self.num_res_blocks+1):\n                h = self.upsampling_modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)\n                m_idx += 1\n\n                if scale_count == self.scale_count_to_put_attn:\n                    h = self.upsampling_modules[m_idx](h)\n                    m_idx += 1\n\n            if scale_count != 0:\n                h = self.upsampling_modules[m_idx](h)\n                m_idx += 1\n\n        assert len(hs) == 0\n        assert m_idx == len(self.upsampling_modules)\n\n        return h\n\n    def _do_output(self, h):\n\n        h = self.output_modules[0](h)\n        h = self.act(h)\n        h = self.output_modules[1](h)\n\n        return h\n\n    def _logistic_output_res(self,\n        h, #  [\"B\", \"twoC\", \"H\", \"W\"]\n        centered_x_in, # [\"B\", \"C\", \"H\", \"W\"]\n    ):\n        B, twoC, H, W = h.shape\n        C = twoC//2\n        h[:, 0:C, :, :] = torch.tanh(centered_x_in + h[:, 0:C, :, :])\n        return h\n\n    def _log_minus_exp(self, a, b, eps=1e-6):\n        \"\"\" \n            Compute log (exp(a) - exp(b)) for (b<a)\n            From https://arxiv.org/pdf/2107.03006.pdf\n        \"\"\"\n        return a + torch.log1p(-torch.exp(b-a) + eps)\n\n    \n    def _truncated_logistic_output(self, net_out):\n        
B, D = net_out.shape[0], self.length\n        C = 3\n        S = self.vocab_size\n\n        # Truncated logistic output from https://arxiv.org/pdf/2107.03006.pdf\n        mu = net_out[:, 0:C, :, :].unsqueeze(-1)\n        log_scale = net_out[:, C:, :, :].unsqueeze(-1)\n\n        inv_scale = torch.exp(- (log_scale - 2))\n\n        bin_width = 2. / S\n        bin_centers = torch.linspace(start=-1. + bin_width/2,\n            end=1. - bin_width/2,\n            steps=S,\n            device='cuda').view(1, 1, 1, 1, S)\n\n        sig_in_left = (bin_centers - bin_width/2 - mu) * inv_scale\n        bin_left_logcdf = F.logsigmoid(sig_in_left)\n        sig_in_right = (bin_centers + bin_width/2 - mu) * inv_scale\n        bin_right_logcdf = F.logsigmoid(sig_in_right)\n\n        logits_1 = self._log_minus_exp(bin_right_logcdf, bin_left_logcdf)\n        logits_2 = self._log_minus_exp(-sig_in_left + bin_left_logcdf, -sig_in_right + bin_right_logcdf)\n        if self.fix_logistic:\n            logits = torch.min(logits_1, logits_2)\n        else:\n            logits = logits_1\n\n        logits = logits.view(B,D,S)\n\n        return logits\n\n\n    def forward(self,\n        x, # [\"B\", \"C\", \"H\", \"W\"]\n        sigma=None, \n        class_cond=None,\n        weights=None\n    ):\n        assert weights == None\n        img_size = int(np.sqrt(self.size))\n\n        h = rearrange(x, \"b (c h w) -> b c h w\", h=img_size, \n                      w=img_size, c=3)\n        h = self._center_data(h)\n        centered_x_in = h\n\n        temb = self._time_embedding(sigma)\n        if class_cond is not None:\n            if self.cond_map is None:\n                raise ValueError(\"Conditioning variable provided, \"\n                                \"but Model was not initialized \"\n                                \"with condition embedding layer.\")\n            else:\n                assert class_cond.shape == (x.shape[0],)\n                temb = temb + 
self.cond_map(class_cond)\n\n        h, hs = self._do_input_conv(h)\n\n        h, hs = self._do_downsampling(h, hs, temb)\n\n        h = self._do_middle(h, temb)\n\n        h = self._do_upsampling(h, hs, temb)\n\n        h = self._do_output(h)\n\n        # h (B, 2*C, H, W)\n        h = self._logistic_output_res(h, centered_x_in)\n        h = self._truncated_logistic_output(h) # (B, D, S)\n\n        return h"
  },
  {
    "path": "models/unit_test_attention.py",
    "content": "import unittest\n\nimport torch\n\n# from flash_attn import flash_attention\nimport torch.nn.functional as F\n\n\ndef attention_inner_heads_flash(qkv, num_heads):\n    \"\"\"Computes attention with heads inside of qkv in the channel dimension using FlashAttention.\n\n    Args:\n        qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where:\n            H = number of heads,\n            C = number of channels per head.\n        num_heads: number of heads.\n\n    Returns:\n        Attention output of shape (B, H*C, T).\n    \"\"\"\n\n    # bs, width, length = qkv.shape\n    # ch = width // (3 * num_heads)\n\n    # # Split into (q, k, v) of shape (B, H*C, T).\n    # q, k, v = qkv.chunk(3, dim=1)\n\n    # # Rescale q and k. This makes them contiguous in memory.\n    # scale = ch ** (-1 / 4)  # scale with 4th root = scaling output by sqrt\n    # q = q * scale\n    # k = k * scale\n\n    # # Reshape q, k, v to (B, H, T, C) for FlashAttention\n    # q = q.view(bs, num_heads, ch, length).permute(0, 1, 3, 2)  # (B, H, T, C)\n    # k = k.view(bs, num_heads, ch, length).permute(0, 1, 3, 2)  # (B, H, T, C)\n    # v = v.view(bs, num_heads, ch, length).permute(0, 1, 3, 2)  # (B, H, T, C)\n\n    # # Compute attention using FlashAttention\n    # out = flash_attention(q, k, v)  # (B, H, T, C)\n\n    # # Reshape back to (B, H*C, T)\n    # out = out.permute(0, 1, 3, 2).reshape(bs, num_heads * ch, length)\n    # return out\n    bs, width, length = qkv.shape\n    ch = width // (3 * num_heads)\n\n    # Split into (q, k, v) and reshape directly to (B, H, T, C)\n    q, k, v = qkv.chunk(3, dim=1)\n    q = q.view(bs, num_heads, ch, length).transpose(2, 3)  # (B, H, T, C)\n    k = k.view(bs, num_heads, ch, length).transpose(2, 3)  # (B, H, T, C)\n    v = v.view(bs, num_heads, ch, length).transpose(2, 3)  # (B, H, T, C)\n\n    # Compute scaled dot-product attention\n    out = F.scaled_dot_product_attention(q, k, v)  # (B, H, T, C)\n\n    # Reshape back to (B, H*C, T) 
in one step\n    out = out.transpose(2, 3).reshape(bs, num_heads * ch, length)\n    return out\n\n\n\nclass TestAttentionInnerHeadsFlash(unittest.TestCase):\n    def setUp(self):\n        # Set up common test variables\n        self.batch_size = 2\n        self.num_heads = 4\n        self.seq_len = 8\n        self.channels_per_head = 16\n        self.qkv_dim = 3 * self.num_heads * self.channels_per_head\n\n        # Create a random qkv tensor\n        self.qkv = torch.randn(self.batch_size, self.qkv_dim, self.seq_len, dtype=torch.float32)\n\n    def attention_inner_heads_old(self, qkv, num_heads):\n        \"\"\"Original implementation of attention for reference.\"\"\"\n        bs, width, length = qkv.shape\n        ch = width // (3 * num_heads)\n        q, k, v = qkv.chunk(3, dim=1)\n        scale = ch ** (-1 / 4)\n        q = q * scale\n        k = k * scale\n        q = q.view(bs * num_heads, ch, length)\n        k = k.view(bs * num_heads, ch, length)\n        v = v.reshape(bs * num_heads, ch, length)\n        weight = torch.einsum(\"bct,bcs->bts\", q, k)\n        weight = torch.softmax(weight.float(), dim=-1).to(weight.dtype)\n        out = torch.einsum(\"bts,bcs->bct\", weight, v)\n        return out.reshape(bs, num_heads * ch, length)\n\n    def test_attention_inner_heads_flash(self):\n        # Compute the output using the old and flash attention implementations\n        output_old = self.attention_inner_heads_old(self.qkv, self.num_heads)\n        output_flash = attention_inner_heads_flash(self.qkv, self.num_heads)\n\n        # Verify the shapes are the same\n        self.assertEqual(output_old.shape, output_flash.shape)\n\n        # Verify that the outputs are close (numerically similar)\n        self.assertTrue(torch.allclose(output_old, output_flash, atol=1e-5))\n\nif __name__ == \"__main__\":\n    unittest.main()"
  },
  {
    "path": "requirements.txt",
    "content": "# conda install nvidia/label/cuda-12.4.0::cuda-toolkit\ndatasets==2.15.0\neinops==0.7.0\nfsspec\ngit-lfs==1.6\nh5py==3.10.0\nhydra-core==1.3.2\nipdb==0.13.13\nlightning==2.2.1\nnotebook==7.1.1\nnvitop==1.3.2\nomegaconf==2.3.0\npackaging==23.2\npandas==2.2.1\nrich==13.7.1\nseaborn==0.13.2\nscikit-learn==1.4.0\ntransformers==4.38.2\ntriton==2.2.0\ntorch==2.3.1\ntorchaudio==2.3.1\ntorchmetrics==1.6.1\ntorchvision==0.18.1\nwandb\ntimm\nocifs\nhf_transfer\nhuggingface-hub\n# Install flash attention only after installing the above modules via pip install -r requirements.txt\n# flash_attn==2.7.4.post1\n"
  },
  {
    "path": "scripts/distil_owt.sh",
    "content": "#!/bin/bash\n#SBATCH -J posterior                 # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=64000                  # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|3090]\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                 # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\nexport HYDRA_FULL_ERROR=1\nfinetune_path=/path/to/duo.ckpt\n\nsrun python -u -m main \\\n  mode=train \\\n  loader.batch_size=2 \\\n  loader.eval_batch_size=2 \\\n  data=openwebtext-split \\\n  model=small \\\n  algo=distillation \\\n  training.finetune_path=$finetune_path \\\n  sampling.num_sample_batches=10 \\\n  sampling.steps=32 \\\n  eval.compute_generative_perplexity=True \\\n  algo.T=512 \\\n  lr_scheduler.num_warmup_steps=500 \\\n  trainer.val_check_interval=1000 \\\n  trainer.max_steps=50000 \\\n  loader.global_batch_size=128 \\\n  training.ema=0.999 \\\n  algo.update_teacher_every=10000 \\\n  optim.lr=6e-5 \\\n  trainer.limit_val_batches=8 \\\n  algo.teacher_ema=False \\\n  algo.linear_growth_dt=false \\\n  +wandb.offline=true\n"
  },
  {
    "path": "scripts/eval_lm1b_duo.sh",
    "content": "#!/bin/bash\n#SBATCH -J eval_mdlm                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=100000                  # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=gpu               # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=4\n#SBATCH --gres=gpu:4                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/flow-ode/flow-ode-W2ZcFy-small-conjugate-lm1b-wrap-1M-gmin-350-gmax-175-3/checkpoints/7-100000.ckpt\n\nexport HYDRA_FULL_ERROR=1\n\nsrun python -u -m main \\\n  mode=ppl_eval \\\n  loader.batch_size=64 \\\n  loader.eval_batch_size=64 \\\n  data=lm1b-wrap \\\n  model=small \\\n  model.length=128 \\\n  algo=duo_base \\\n  eval.checkpoint_path=$checkpoint_path \\\n  sampling.num_sample_batches=0 \\\n  +wandb.offline=true"
  },
  {
    "path": "scripts/eval_owt_ar.sh",
    "content": "#!/bin/bash\n#SBATCH -J eval_ar                # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=100000                  # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=4\n#SBATCH --gres=gpu:4                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-AgBZrc-small-ar-param-ar_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt\n\nexport HYDRA_FULL_ERROR=1\n\nsrun python -u -m main \\\n  mode=ppl_eval \\\n  loader.batch_size=16 \\\n  loader.eval_batch_size=16 \\\n  data=openwebtext-split \\\n  algo=ar \\\n  model.length=1024 \\\n  eval.checkpoint_path=$checkpoint_path \\\n  +wandb.offline=true"
  },
  {
    "path": "scripts/eval_owt_duo.sh",
    "content": "#!/bin/bash\n#SBATCH -J owt_duo_anneal                    # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=100000                  # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/flow-ode/flow-ode-VlCQLK-small-conjugate-OWT-anneal/checkpoints/last.ckpt\n\nexport HYDRA_FULL_ERROR=1\n\nsrun python -u -m main \\\n  mode=ppl_eval \\\n  loader.batch_size=8 \\\n  loader.eval_batch_size=8 \\\n  data=openwebtext-split \\\n  model=small \\\n  algo=duo_base \\\n  eval.checkpoint_path=$checkpoint_path \\\n  sampling.num_sample_batches=0 \\\n  +wandb.offline=true"
  },
  {
    "path": "scripts/eval_owt_mdlm.sh",
    "content": "#!/bin/bash\n#SBATCH -J eval_mdlm                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=100000                  # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=gpu               # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=4\n#SBATCH --gres=gpu:4                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/textdiffusion/mdlm.ckpt\n\nexport HYDRA_FULL_ERROR=1\n\nsrun python -u -m main \\\n  mode=ppl_eval \\\n  loader.batch_size=16 \\\n  loader.eval_batch_size=16 \\\n  data=openwebtext-split \\\n  model=small \\\n  algo=mdlm \\\n  eval.checkpoint_path=$checkpoint_path \\\n  sampling.num_sample_batches=0 \\\n  +wandb.offline=true"
  },
  {
    "path": "scripts/eval_owt_sedd.sh",
    "content": "#!/bin/bash\n#SBATCH -J eval_sedd              # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=100000                  # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=4\n#SBATCH --gres=gpu:4                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-nBm2gE-small-param-sedd_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt\n\nexport HYDRA_FULL_ERROR=1\n\nsrun python -u -m main \\\n  mode=ppl_eval \\\n  loader.batch_size=16 \\\n  loader.eval_batch_size=16 \\\n  data=openwebtext-split \\\n  model=small \\\n  algo=sedd \\\n  eval.checkpoint_path=$checkpoint_path \\\n  sampling.num_sample_batches=0 \\\n  +wandb.offline=true"
  },
  {
    "path": "scripts/fid_cifar10_duo_ancestral_cosine.sh",
    "content": "export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1\npython -u -m main \\\n    mode=fid_eval \\\n    sampling.steps=64 \\\n    sampling.guid_weight=1.0 \\\n    data=cifar10 \\\n    data.cache_dir=<YOUR-CACHE-PATH> \\\n    model=unet \\\n    noise=cosine \\\n    algo=duo_base \\\n    algo.backbone=unet \\\n    trainer.num_nodes=1 \\\n    loader.eval_batch_size=500 \\\n    eval.checkpoint_path=<PATH-TO-THE-MDLM-CHECKPOINT>\n"
  },
  {
    "path": "scripts/fid_cifar10_duo_base_ancestral_cosine.sh",
    "content": "export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1\npython -u -m main \\\n    mode=fid_eval \\\n    sampling.steps=64 \\\n    sampling.guid_weight=1.0 \\\n    data=cifar10 \\\n    data.cache_dir=<YOUR-CACHE-PATH> \\\n    model=unet \\\n    noise=cosine \\\n    algo=duo_base \\\n    algo.backbone=unet \\\n    trainer.num_nodes=1 \\\n    loader.eval_batch_size=500 \\\n    eval.checkpoint_path=<PATH-TO-THE-MDLM-CHECKPOINT>\n"
  },
  {
    "path": "scripts/fid_cifar10_mdlm_ancestral_cosine.sh",
    "content": "python -u -m main \\\n    mode=fid_eval \\\n    sampling.steps=64 \\\n    sampling.guid_weight=1.0 \\\n    sampling.predictor=ancestral_cache \\\n    data=cifar10 \\\n    data.cache_dir=<YOUR-CACHE-PATH> \\\n    model=unet \\\n    noise=cosine \\\n    algo=mdlm \\\n    algo.backbone=unet \\\n    loader.eval_batch_size=500 \\\n    eval.checkpoint_path=<PATH-TO-THE-MDLM-CHECKPOINT>\n"
  },
  {
    "path": "scripts/gen_ppl_lm1b_ar.sh",
    "content": "#!/bin/bash\n#SBATCH -J sample_ar                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=gpu               # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/textdiffusion/lm1b-ar/last_copy.ckpt\n\nexport HYDRA_FULL_ERROR=1\n\nsrun python -u -m main \\\n  mode=sample_eval \\\n  loader.batch_size=2 \\\n  loader.eval_batch_size=64 \\\n  data=lm1b-wrap \\\n  algo=ar \\\n  model=small \\\n  model.length=128 \\\n  eval.checkpoint_path=$checkpoint_path \\\n  sampling.num_sample_batches=15 \\\n  +wandb.offline=true\n"
  },
  {
    "path": "scripts/gen_ppl_lm1b_duo.sh",
    "content": "#!/bin/bash\n#SBATCH -J sample_ar                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=gpu               # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/flow-ode/6eTwW0-distil-kl-bwd-32/checkpoints/0-1000.ckpt\nsteps=32\n\nexport HYDRA_FULL_ERROR=1\n\nsrun python -u -m main \\\n  mode=sample_eval \\\n  loader.batch_size=2 \\\n  loader.eval_batch_size=64 \\\n  data=lm1b-wrap \\\n  algo=duo_base \\\n  model=small \\\n  model.length=128 \\\n  eval.checkpoint_path=/share/kuleshov/ssahoo/flow-ode/6eTwW0-distil7-kl-bwd/checkpoints/last.ckpt \\\n  sampling.num_sample_batches=15 \\\n  sampling.steps=$steps \\\n  +wandb.offline=true \\\n  sampling.noise_removal=greedy\n"
  },
  {
    "path": "scripts/gen_ppl_owt_ar.sh",
    "content": "#!/bin/bash\n#SBATCH -J sample_ar                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=16000                   # server memory requested (per node)\n#SBATCH -t 24:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov,gpu          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-AgBZrc-small-ar-param-ar_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints\nseed=1\n\nexport HYDRA_FULL_ERROR=1\n\nsrun python -u -m main \\\n  mode=sample_eval \\\n  seed=$seed \\\n  loader.batch_size=2 \\\n  loader.eval_batch_size=8 \\\n  data=openwebtext-split \\\n  algo=ar \\\n  model=small \\\n  eval.checkpoint_path=$checkpoint_path/last.ckpt \\\n  sampling.num_sample_batches=100 \\\n  +wandb.offline=true \\\n  eval.generated_samples_path=$checkpoint_path/$seed-ckpt-last.json"
  },
  {
    "path": "scripts/gen_ppl_owt_duo.sh",
    "content": "#!/bin/bash\n#SBATCH -J an_owt_duo                    # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=16000                   # server memory requested (per node)\n#SBATCH -t 24:00:00                   # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov,gpu      # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\nexport HYDRA_FULL_ERROR=1\n\nwhile [[ \"$#\" -gt 0 ]]; do\n    case $1 in\n        --steps) steps=\"$2\"; shift ;;\n        --seed) seed=\"$2\"; shift ;;\n        --ckpt) ckpt=\"$2\"; shift ;;\n        *) echo \"Unknown parameter: $1\"; exit 1 ;;\n    esac\n    shift\ndone\n\ncheckpoint_path=/share/kuleshov/ssahoo/flow-ode/distil-distil-vjrpZb-distillation-OWT/checkpoints\nckpt=0-50000\n\nsteps=${steps:-32}\nseed=${seed:-1}\necho \"  Steps: $steps\"\necho \"  Seed: $seed\"\necho \"  ckpt: $ckpt\"\n\nsrun python -u -m main \\\n  mode=sample_eval \\\n  seed=$seed \\\n  loader.batch_size=2 \\\n  loader.eval_batch_size=8 \\\n  data=openwebtext-split \\\n  algo=duo_base \\\n  model=small \\\n  eval.checkpoint_path=$checkpoint_path/$ckpt.ckpt \\\n  sampling.num_sample_batches=100 \\\n  sampling.steps=$steps \\\n  +wandb.offline=true \\\n  eval.generated_samples_path=$checkpoint_path/samples_ancestral/$seed-$steps-ckpt-$ckpt.json"
  },
  {
    "path": "scripts/gen_ppl_owt_mdlm.sh",
    "content": "#!/bin/bash\n#SBATCH -J sample_mdlm                # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=64000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/textdiffusion/mdlm.ckpt\n\nexport HYDRA_FULL_ERROR=1\n\nsrun python -u -m main \\\n  mode=sample_eval \\\n  loader.batch_size=16 \\\n  loader.eval_batch_size=16 \\\n  data=openwebtext-split \\\n  model=small \\\n  algo=mdlm \\\n  eval.checkpoint_path=/share/kuleshov/ssahoo/textdiffusion/mdlm.ckpt \\\n  sampling.num_sample_batches=4 \\\n  +wandb.offline=true"
  },
  {
    "path": "scripts/gen_ppl_owt_sedd.sh",
    "content": "#!/bin/bash\n#SBATCH -J sedd_samples               # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=16000                   # server memory requested (per node)\n#SBATCH -t 24:00:00                   # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov,gpu      # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\nwhile [[ \"$#\" -gt 0 ]]; do\n    case $1 in\n        --steps) steps=\"$2\"; shift ;;\n        --seed) seed=\"$2\"; shift ;;\n        *) echo \"Unknown parameter: $1\"; exit 1 ;;\n    esac\n    shift\ndone\n\nsteps=${steps:-32}\nseed=${seed:-1}\n\necho \"  Steps: $steps\"\necho \"  Seed: $seed\"\n\nexport HYDRA_FULL_ERROR=1\ncheckpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-nBm2gE-small-param-sedd_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints\nckpt=last\n\nsrun python -u -m main \\\n  mode=sample_eval \\\n  seed=$seed \\\n  loader.batch_size=2 \\\n  loader.eval_batch_size=8 \\\n  data=openwebtext-split \\\n  algo=sedd \\\n  model=small \\\n  eval.checkpoint_path=$checkpoint_path/$ckpt.ckpt \\\n  sampling.num_sample_batches=0 \\\n  sampling.num_sample_batches=100 \\\n  sampling.steps=$steps \\\n  sampling.predictor=analytic \\\n  eval.generated_samples_path=$checkpoint_path/$seed-$steps-ckpt-$ckpt.json \\\n  +wandb.offline=true"
  },
  {
    "path": "scripts/psi_samplers/cifar10/duo_constant_remdm.sh",
    "content": "# DUO psi-sampler with constant-remdm-eta mode (ReMDM loop)\n\nNUM_STEPS=256\nETA=0.01\nNOISE=cosine\nCHECKPOINT_PATH=<PATH-TO-DUO-CHECKPOINT>\nDATA_CACHE_DIR=<YOUR-CACHE-PATH>\nEVAL_BATCH_SIZE=500\n\npython -u -m main \\\n    mode=fid_eval \\\n    data=cifar10 \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=unet \\\n    algo=duo_base \\\n    algo.backbone=unet \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.guid_weight=1.0 \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear-constant-linear-0.9-inv \\\n    sampling.psi.high_mode=pure-posterior \\\n    sampling.psi.middle_mode=constant-remdm-$ETA \\\n    sampling.psi.low_mode=pure-posterior \\\n    sampling.psi.high_frac=0.45 \\\n    sampling.psi.middle_frac=0.5\n"
  },
  {
    "path": "scripts/psi_samplers/cifar10/duo_max_capped_remdm.sh",
    "content": "# DUO psi-sampler with max-capped-eta mode (ReMDM cap)\n\nNUM_STEPS=256\nETA=0.005\nNOISE=cosine\nCHECKPOINT_PATH=<PATH-TO-DUO-CHECKPOINT>\nDATA_CACHE_DIR=<YOUR-CACHE-PATH>\nEVAL_BATCH_SIZE=500\n\npython -u -m main \\\n    mode=fid_eval \\\n    data=cifar10 \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=unet \\\n    algo=duo_base \\\n    algo.backbone=unet \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.guid_weight=1.0 \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear \\\n    sampling.psi.high_mode=max-capped-$ETA \\\n    sampling.psi.middle_mode=max-capped-$ETA \\\n    sampling.psi.low_mode=max-capped-$ETA \\\n    sampling.psi.high_frac=0.0 \\\n    sampling.psi.middle_frac=0.0\n"
  },
  {
    "path": "scripts/psi_samplers/cifar10/duo_max_rescale_eta.sh",
    "content": "# DUO psi-sampler with max-rescale-eta mode — CIFAR-10 FID eval\n\nNUM_STEPS=256\nETA=0.01\nNOISE=cosine\nCHECKPOINT_PATH=<PATH-TO-DUO-CHECKPOINT>\nDATA_CACHE_DIR=<YOUR-CACHE-PATH>\nEVAL_BATCH_SIZE=500\n\npython -u -m main \\\n    mode=fid_eval \\\n    data=cifar10 \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=unet \\\n    algo=duo_base \\\n    algo.backbone=unet \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.guid_weight=1.0 \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear \\\n    sampling.psi.high_mode=max-rescale-$ETA \\\n    sampling.psi.middle_mode=max-rescale-$ETA \\\n    sampling.psi.low_mode=max-rescale-$ETA \\\n    sampling.psi.high_frac=0.0 \\\n    sampling.psi.middle_frac=0.0\n"
  },
  {
    "path": "scripts/psi_samplers/cifar10/duo_psi_pc.sh",
    "content": "# DUO psi-sampler with constant kappa in pc phase\n# Kappa controls the posterior/PC mix: 1 = pure posterior, 0 = pure PC\n#   t in [1.0, 0.5] -> pure-posterior\n#   t in [0.5, 0.1] -> constant-kappa\n#   t in [0.1, 0.0] -> pure-posterior\n\nNUM_STEPS=256\nKAPPA=0.95\nNOISE=cosine\nHIGH_FRAC=0.5\nMIDDLE_FRAC=0.4\nCHECKPOINT_PATH=<PATH-TO-DUO-CHECKPOINT>\nDATA_CACHE_DIR=<YOUR-CACHE-PATH>\nEVAL_BATCH_SIZE=500\n\npython -u -m main \\\n    mode=fid_eval \\\n    data=cifar10 \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=unet \\\n    algo=duo_base \\\n    algo.backbone=unet \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.guid_weight=1.0 \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear \\\n    sampling.psi.high_mode=pure-posterior \\\n    sampling.psi.middle_mode=constant-$KAPPA \\\n    sampling.psi.low_mode=pure-posterior \\\n    sampling.psi.high_frac=$HIGH_FRAC \\\n    sampling.psi.middle_frac=$MIDDLE_FRAC\n"
  },
  {
    "path": "scripts/psi_samplers/cifar10/mdlm_constant_remdm.sh",
    "content": "# MDLM psi-sampler with constant-remdm-eta mode (ReMDM loop) — CIFAR-10 FID eval\n\nNUM_STEPS=256\nETA=0.01\nNOISE=cosine\nCHECKPOINT_PATH=<PATH-TO-MDLM-CHECKPOINT>\nDATA_CACHE_DIR=<YOUR-CACHE-PATH>\nEVAL_BATCH_SIZE=500\n\npython -u -m main \\\n    mode=fid_eval \\\n    data=cifar10 \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=unet \\\n    algo=mdlm \\\n    algo.backbone=unet \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.guid_weight=1.0 \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear-constant-linear-0.9-inv \\\n    sampling.psi.high_mode=pure-posterior \\\n    sampling.psi.middle_mode=constant-remdm-$ETA \\\n    sampling.psi.low_mode=pure-posterior \\\n    sampling.psi.high_frac=0.45 \\\n    sampling.psi.middle_frac=0.5\n"
  },
  {
    "path": "scripts/psi_samplers/cifar10/mdlm_max_capped_remdm.sh",
    "content": "# MDLM psi-sampler with max-capped-eta mode (ReMDM cap)\n\nNUM_STEPS=256\nETA=0.005\nNOISE=cosine\nCHECKPOINT_PATH=<PATH-TO-MDLM-CHECKPOINT>\nDATA_CACHE_DIR=<YOUR-CACHE-PATH>\nEVAL_BATCH_SIZE=500\n\npython -u -m main \\\n    mode=fid_eval \\\n    data=cifar10 \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=unet \\\n    algo=mdlm \\\n    algo.backbone=unet \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.guid_weight=1.0 \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear \\\n    sampling.psi.high_mode=max-capped-$ETA \\\n    sampling.psi.middle_mode=max-capped-$ETA \\\n    sampling.psi.low_mode=max-capped-$ETA \\\n    sampling.psi.high_frac=0.0 \\\n    sampling.psi.middle_frac=0.0\n"
  },
  {
    "path": "scripts/psi_samplers/cifar10/mdlm_max_rescale_eta.sh",
    "content": "# MDLM psi-sampler with max-rescale-eta mode\n\nNUM_STEPS=256\nETA=0.01\nNOISE=cosine\nCHECKPOINT_PATH=<PATH-TO-MDLM-CHECKPOINT>\nDATA_CACHE_DIR=<YOUR-CACHE-PATH>\nEVAL_BATCH_SIZE=500\n\npython -u -m main \\\n    mode=fid_eval \\\n    data=cifar10 \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=unet \\\n    algo=mdlm \\\n    algo.backbone=unet \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.guid_weight=1.0 \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear \\\n    sampling.psi.high_mode=max-rescale-$ETA \\\n    sampling.psi.middle_mode=max-rescale-$ETA \\\n    sampling.psi.low_mode=max-rescale-$ETA \\\n    sampling.psi.high_frac=0.0 \\\n    sampling.psi.middle_frac=0.0\n"
  },
  {
    "path": "scripts/psi_samplers/cifar10/mdlm_psi_pc.sh",
    "content": "# MDLM psi-sampler with constant kappa during pc phase\n# Kappa controls the posterior/PC mix: 1 = pure posterior, 0 = pure PC\n#   t in [1.0, 0.1] -> constant kappa\n#   t in [0.1, 0.0] -> pure posterior\n\nNUM_STEPS=256\nKAPPA=0.99\nNOISE=cosine\nHIGH_FRAC=0.0\nMIDDLE_FRAC=0.9\nCHECKPOINT_PATH=<PATH-TO-MDLM-CHECKPOINT>\nDATA_CACHE_DIR=<YOUR-CACHE-PATH>\nEVAL_BATCH_SIZE=500\n\npython -u -m main \\\n    mode=fid_eval \\\n    data=cifar10 \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=unet \\\n    algo=mdlm \\\n    algo.backbone=unet \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.guid_weight=1.0 \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear \\\n    sampling.psi.high_mode=pure-posterior \\\n    sampling.psi.middle_mode=constant-$KAPPA \\\n    sampling.psi.low_mode=pure-posterior \\\n    sampling.psi.high_frac=$HIGH_FRAC \\\n    sampling.psi.middle_frac=$MIDDLE_FRAC\n"
  },
  {
    "path": "scripts/psi_samplers/owt/duo_loop_remdm.sh",
    "content": "# DUO psi-sampler with constant-remdm-eta mode (ReMDM loop)\n\nNUM_STEPS=256\nETA=0.01\nNUCLEUS_P=0.95\nNOISE=log-linear\nCHECKPOINT_PATH=???\nDATA_CACHE_DIR=???\nEVAL_BATCH_SIZE=16\nNUM_SAMPLE_BATCHES=32\n\npython -u -m main \\\n    mode=sample_eval \\\n    data=openwebtext-split \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=small \\\n    algo=duo_base \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.p_nucleus=$NUCLEUS_P \\\n    sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear-constant-linear-0.9-inv \\\n    sampling.psi.high_mode=pure-posterior \\\n    sampling.psi.middle_mode=constant-remdm-$ETA \\\n    sampling.psi.low_mode=pure-posterior \\\n    sampling.psi.high_frac=0.45 \\\n    sampling.psi.middle_frac=0.5\n"
  },
  {
    "path": "scripts/psi_samplers/owt/duo_max_capped_remdm.sh",
    "content": "# DUO psi-sampler with max-capped-eta mode (ReMDM cap)\n\nNUM_STEPS=256\nETA=0.01\nNUCLEUS_P=0.9\nNOISE=log-linear\nCHECKPOINT_PATH=???\nDATA_CACHE_DIR=???\nEVAL_BATCH_SIZE=16\nNUM_SAMPLE_BATCHES=32\n\npython -u -m main \\\n    mode=sample_eval \\\n    data=openwebtext-split \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=small \\\n    algo=duo_base \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.p_nucleus=$NUCLEUS_P \\\n    sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear \\\n    sampling.psi.high_mode=max-capped-$ETA \\\n    sampling.psi.middle_mode=max-capped-$ETA \\\n    sampling.psi.low_mode=max-capped-$ETA \\\n    sampling.psi.high_frac=0.0 \\\n    sampling.psi.middle_frac=0.0\n"
  },
  {
    "path": "scripts/psi_samplers/owt/duo_max_rescale_eta.sh",
    "content": "# DUO psi-sampler with max-rescale-eta mode (ReMDM rescale)\n\nNUM_STEPS=256\nETA=0.05\nNUCLEUS_P=0.9\nNOISE=log-linear\nCHECKPOINT_PATH=???\nDATA_CACHE_DIR=???\nEVAL_BATCH_SIZE=16\nNUM_SAMPLE_BATCHES=32\n\npython -u -m main \\\n    mode=sample_eval \\\n    data=openwebtext-split \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=small \\\n    algo=duo_base \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.p_nucleus=$NUCLEUS_P \\\n    sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear \\\n    sampling.psi.high_mode=max-rescale-$ETA \\\n    sampling.psi.middle_mode=max-rescale-$ETA \\\n    sampling.psi.low_mode=max-rescale-$ETA \\\n    sampling.psi.high_frac=0.0 \\\n    sampling.psi.middle_frac=0.0\n"
  },
  {
    "path": "scripts/psi_samplers/owt/mdlm_loop_remdm.sh",
    "content": "# MDLM psi-sampler with constant-remdm-eta mode (ReMDM loop)\n\nNUM_STEPS=256\nETA=0.01\nNUCLEUS_P=0.95\nNOISE=log-linear\nCHECKPOINT_PATH=???\nDATA_CACHE_DIR=???\nEVAL_BATCH_SIZE=16\nNUM_SAMPLE_BATCHES=32\n\npython -u -m main \\\n    mode=sample_eval \\\n    data=openwebtext-split \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=small \\\n    algo=mdlm \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.p_nucleus=$NUCLEUS_P \\\n    sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear-constant-linear-0.9-inv \\\n    sampling.psi.high_mode=pure-posterior \\\n    sampling.psi.middle_mode=constant-remdm-$ETA \\\n    sampling.psi.low_mode=pure-posterior \\\n    sampling.psi.high_frac=0.45 \\\n    sampling.psi.middle_frac=0.5\n"
  },
  {
    "path": "scripts/psi_samplers/owt/mdlm_max_capped_remdm.sh",
    "content": "# MDLM psi-sampler with max-capped-eta mode (ReMDM cap)\n\nNUM_STEPS=256\nETA=0.01\nNUCLEUS_P=0.9\nNOISE=log-linear\nCHECKPOINT_PATH=???\nDATA_CACHE_DIR=???\nEVAL_BATCH_SIZE=16\nNUM_SAMPLE_BATCHES=32\n\npython -u -m main \\\n    mode=sample_eval \\\n    data=openwebtext-split \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=small \\\n    algo=mdlm \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.p_nucleus=$NUCLEUS_P \\\n    sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear \\\n    sampling.psi.high_mode=max-capped-$ETA \\\n    sampling.psi.middle_mode=max-capped-$ETA \\\n    sampling.psi.low_mode=max-capped-$ETA \\\n    sampling.psi.high_frac=0.0 \\\n    sampling.psi.middle_frac=0.0\n"
  },
  {
    "path": "scripts/psi_samplers/owt/mdlm_max_rescale_eta.sh",
    "content": "# MDLM psi-sampler with max-rescale-eta mode (ReMDM rescale)\n\nNUM_STEPS=256\nETA=0.05\nNUCLEUS_P=0.9\nNOISE=log-linear\nCHECKPOINT_PATH=???\nDATA_CACHE_DIR=???\nEVAL_BATCH_SIZE=16\nNUM_SAMPLE_BATCHES=32\n\npython -u -m main \\\n    mode=sample_eval \\\n    data=openwebtext-split \\\n    data.cache_dir=$DATA_CACHE_DIR \\\n    model=small \\\n    algo=mdlm \\\n    noise=$NOISE \\\n    sampling.predictor=psi \\\n    sampling.steps=$NUM_STEPS \\\n    sampling.p_nucleus=$NUCLEUS_P \\\n    sampling.num_sample_batches=$NUM_SAMPLE_BATCHES \\\n    eval.checkpoint_path=$CHECKPOINT_PATH \\\n    loader.eval_batch_size=$EVAL_BATCH_SIZE \\\n    sampling.psi.time_profile=linear \\\n    sampling.psi.high_mode=max-rescale-$ETA \\\n    sampling.psi.middle_mode=max-rescale-$ETA \\\n    sampling.psi.low_mode=max-rescale-$ETA \\\n    sampling.psi.high_frac=0.0 \\\n    sampling.psi.middle_frac=0.0\n"
  },
  {
    "path": "scripts/train_cifar10_duo_base_cosine.sh",
    "content": "python -u -m main \\\n    data=cifar10 \\\n    data.cache_dir=<YOUR-CACHE-PATH> \\\n    model=unet \\\n    algo=duo_base \\\n    algo.backbone=unet \\\n    noise=cosine \\\n    loader.global_batch_size=128 \\\n    loader.batch_size=32 \\\n    loader.eval_batch_size=32 \\\n    loader.num_workers=8 \\\n    trainer.val_check_interval=2500 \\\n    trainer.max_steps=1_500_000 \\\n    lr_scheduler.num_warmup_steps=5000 \\\n    eval.generate_samples=False \\\n    optim.lr=2e-4 \\\n    callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \\\n    wandb.name=duo_base_1_5M_d3pm_like_cosine \\\n    hydra.run.dir=./outputs/cifar10/duo_base_1_5M_d3pm_like_cosine \\\n\n"
  },
  {
    "path": "scripts/train_cifar10_duo_cosine.sh",
    "content": "python -u -m main \\\n    data=cifar10 \\\n    data.cache_dir=<YOUR-CACHE-PATH> \\\n    model=unet \\\n    algo=duo_base \\\n    algo.backbone=unet \\\n    noise=cosine \\\n    loader.global_batch_size=128 \\\n    loader.batch_size=32 \\\n    loader.eval_batch_size=32 \\\n    loader.num_workers=8 \\\n    trainer.val_check_interval=2500 \\\n    trainer.max_steps=1_500_000 \\\n    lr_scheduler.num_warmup_steps=5000 \\\n    eval.generate_samples=False \\\n    optim.lr=2e-4 \\\n    callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \\\n    wandb.name=duo_base_1_5M_d3pm_like_cosine \\\n    hydra.run.dir=./outputs/cifar10/duo_base_1_5M_d3pm_like_cosine \\\n\n"
  },
  {
    "path": "scripts/train_cifar10_mdlm_cosine.sh",
    "content": "python -u -m main \\\n    data=cifar10 \\\n    data.cache_dir=<YOUR-CACHE-PATH> \\\n    model=unet \\\n    algo=mdlm \\\n    algo.backbone=unet \\\n    noise=cosine \\\n    loader.global_batch_size=128 \\\n    loader.batch_size=32 \\\n    loader.eval_batch_size=32 \\\n    loader.num_workers=8 \\\n    trainer.val_check_interval=2500 \\\n    trainer.max_steps=1_500_000 \\\n    lr_scheduler.num_warmup_steps=5000 \\\n    eval.generate_samples=False \\\n    optim.lr=2e-4 \\\n    callbacks.checkpoint_every_n_steps.every_n_train_steps=5_000 \\\n    wandb.name=mdlm_1_5M_d3pm_like_cosine \\\n    hydra.run.dir=./outputs/cifar10/mdlm_1_5M_d3pm_like_cosine \\\n\n"
  },
  {
    "path": "scripts/train_lm1b_ar.sh",
    "content": "#!/bin/bash\n#SBATCH -J train_ar_lm1b              # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --constraint=\"gpu-mid|gpu-high\"\n#SBATCH --ntasks-per-node=8\n#SBATCH --gres=gpu:8                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon pre-emption\n\n# To enable preemption re-loading, set `hydra.run.dir` or \n# `checkpointing.save_dir` explicitly.\nsrun python -u -m main \\\n  loader.batch_size=64 \\\n  loader.eval_batch_size=64 \\\n  algo=ar \\\n  data=lm1b \\\n  wandb.name=ar-lm1b-small \\\n  model=small \\\n  model.length=128 "
  },
  {
    "path": "scripts/train_lm1b_ar_sentencepacking.sh",
    "content": "#!/bin/bash\n#SBATCH -J train_ar_lm1b              # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --constraint=\"gpu-mid|gpu-high\"\n#SBATCH --ntasks-per-node=2\n#SBATCH --gres=gpu:2                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon pre-emption\n\n# To enable preemption re-loading, set `hydra.run.dir` or \n# `checkpointing.save_dir` explicitly.\nsrun python -u -m main \\\n  loader.batch_size=256 \\\n  loader.eval_batch_size=256 \\\n  algo=ar \\\n  data=lm1b-wrap \\\n  wandb.name=ar-lm1b-wrap-small \\\n  model=small \\\n  model.length=128 "
  },
  {
    "path": "scripts/train_lm1b_d3pm.sh",
    "content": "#!/bin/bash\n#SBATCH -J train_d3pm                 # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --constraint=\"gpu-mid|gpu-high\"\n#SBATCH --ntasks-per-node=8\n#SBATCH --gres=gpu:8                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon pre-emption\n\n# To enable preemption re-loading, set `hydra.run.dir` or \n# `checkpointing.save_dir` explicitly.\nsrun python -u -m main \\\n  loader.batch_size=64 \\\n  loader.eval_batch_size=64 \\\n  model=small \\\n  data=lm1b \\\n  wandb.name=d3pm-lm1b \\\n  algo=d3pm \\\n  model.length=128 \\\n  eval.compute_generative_perplexity=False \n"
  },
  {
    "path": "scripts/train_lm1b_duo.sh",
    "content": "#!/bin/bash\n#SBATCH -J duo-lm1b                   # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=64000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --constraint=\"gpu-mid|gpu-high\"\n#SBATCH --ntasks-per-node=8\n#SBATCH --gres=gpu:8                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon pre-emption\n\n# To enable preemption re-loading, set `hydra.run.dir` or \n# `checkpointing.save_dir` explicitly.\n# Note: we use algo.curriculum.mode=poly9 in the duo2 paper for speed\nsrun python -u -m main \\\n  loader.batch_size=64 \\\n  loader.eval_batch_size=64 \\\n  data=lm1b \\\n  wandb.name=duo-lm1b \\\n  model=small \\\n  algo=duo \\\n  model.length=128 \\\n  algo.curriculum.mode=simple \\\n  algo.curriculum.gumbel_tau_log10_start=-3.0 \\\n  algo.curriculum.gumbel_tau_log10_end=-3.0 \\\n  algo.curriculum.gamma_min=-3.5 \\\n  algo.curriculum.gamma_max=-1.75 \\\n  algo.curriculum.start=0 \\\n  algo.curriculum.end=500000\n"
  },
  {
    "path": "scripts/train_lm1b_duo_sentencepacking.sh",
    "content": "#!/bin/bash\n#SBATCH -J duo-lm1b                   # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=64000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --constraint=\"gpu-mid|gpu-high\"\n#SBATCH --ntasks-per-node=8\n#SBATCH --gres=gpu:8                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon pre-emption\n\n# To enable preemption re-loading, set `hydra.run.dir` or \n# `checkpointing.save_dir` explicitly.\n# Note: we use algo.curriculum.mode=poly9 in the duo2 paper for speed\nsrun python -u -m main \\\n  loader.batch_size=64 \\\n  loader.eval_batch_size=64 \\\n  data=lm1b-wrap \\\n  wandb.name=duo-lm1b \\\n  model=small \\\n  algo=duo \\\n  model.length=128 \\\n  algo.curriculum.mode=simple \\\n  algo.curriculum.gumbel_tau_log10_start=-3.0 \\\n  algo.curriculum.gumbel_tau_log10_end=-3.0 \\\n  algo.curriculum.gamma_min=-3.5 \\\n  algo.curriculum.gamma_max=-1.75 \\\n  algo.curriculum.start=0 \\\n  algo.curriculum.end=500000\n"
  },
  {
    "path": "scripts/train_lm1b_mdlm.sh",
    "content": "#!/bin/bash\n#SBATCH -J lm1b_mdlm                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --constraint=\"gpu-mid|gpu-high\"\n#SBATCH --ntasks-per-node=2\n#SBATCH --gres=gpu:2                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon pre-emption\n\n# To enable preemption re-loading, set `hydra.run.dir` or \n# `checkpointing.save_dir` explicitly.\nsrun python -u -m main \\\n  loader.batch_size=64 \\\n  loader.eval_batch_size=64 \\\n  data=lm1b \\\n  wandb.name=mdlm-lm1b \\\n  model=small \\\n  algo=mdlm \\\n  model.length=128 \n"
  },
  {
    "path": "scripts/train_lm1b_mdlm_sentencepacking.sh",
    "content": "#!/bin/bash\n#SBATCH -J lm1b_mdlm                  # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --constraint=\"gpu-mid|gpu-high\"\n#SBATCH --ntasks-per-node=2\n#SBATCH --gres=gpu:2                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon pre-emption\n\n# To enable preemption re-loading, set `hydra.run.dir` or \n# `checkpointing.save_dir` explicitly.\nsrun python -u -m main \\\n  loader.batch_size=64 \\\n  loader.eval_batch_size=64 \\\n  data=lm1b-wrap \\\n  wandb.name=mdlm-lm1b-wrap-small \\\n  model=small \\\n  algo=mdlm \\\n  model.length=128 \n"
  },
  {
    "path": "scripts/train_owt_duo.sh",
    "content": "#!/bin/bash\n#SBATCH -J duo-lm1b                   # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=64000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --constraint=\"gpu-mid|gpu-high\"\n#SBATCH --ntasks-per-node=8\n#SBATCH --gres=gpu:8                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon pre-emption\n\n# To enable preemption re-loading, set `hydra.run.dir` or \n# `checkpointing.save_dir` explicitly.\n# Note: we use algo.curriculum.mode=poly9 in the duo2 paper for speed\nsrun python -u -m main \\\n  loader.batch_size=32 \\\n  loader.eval_batch_size=32 \\\n  data=openwebtext-split \\\n  wandb.name=duo-owt \\\n  model=small \\\n  algo=duo \\\n  model.length=1024 \\\n  algo.curriculum.mode=simple \\\n  algo.curriculum.gumbel_tau_log10_start=-3.0 \\\n  algo.curriculum.gumbel_tau_log10_end=-3.0 \\\n  algo.curriculum.gamma_min=-3.55 \\\n  algo.curriculum.gamma_max=-1.85 \\\n  algo.curriculum.start=0 \\\n  algo.curriculum.end=500000\n"
  },
  {
    "path": "scripts/train_owt_duo_finetune.sh",
    "content": "#!/bin/bash\n#SBATCH -J duo-base                   # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=64000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --constraint=\"gpu-mid|gpu-high\"\n#SBATCH --ntasks-per-node=8\n#SBATCH --gres=gpu:8                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon pre-emption\n\n# To enable preemption re-loading, set `hydra.run.dir` or \n# `checkpointing.save_dir` explicitly.\nfinetune_path=/path/to/intermediate_duo_500k.ckpt\n\n# Assuming the finetune_path corresponds to the DUO model\n# trained for 500K steps with curriculum learning, we train the\n# model for 500K more steps.\nsrun python -u -m main \\\n  loader.batch_size=64 \\\n  loader.eval_batch_size=64 \\\n  data=openwebtext-split \\\n  wandb.name=duo-owt-finetune \\\n  model=small \\\n  algo=duo_base \\\n  model.length=1024 \\\n  training.finetune_path=$finetune_path \\\n  sampling.num_sample_batches=0 \\\n  trainer.max_steps=500000 "
  },
  {
    "path": "scripts/train_owt_mdlm.sh",
    "content": "#!/bin/bash\n#SBATCH -J train_mdlm                 # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --constraint=\"gpu-mid|gpu-high\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon pre-emption\n\n# To enable preemption re-loading, set `hydra.run.dir` or \n# `checkpointing.save_dir` explicitly.\nsrun python -u -m main \\\n  loader.batch_size=2 \\\n  loader.eval_batch_size=2 \\\n  model=small \\\n  data=openwebtext-split \\\n  wandb.name=mdlm-owt \\\n  algo=mdlm \\\n  model.length=1024 \\\n  eval.compute_generative_perplexity=False \\\n  +wandb.offline=True"
  },
  {
    "path": "scripts/train_owt_sedd.sh",
    "content": "#!/bin/bash\n#SBATCH -J train_sedd                 # Job name\n#SBATCH -o watch_folder/%x_%j.out     # output file (%j expands to jobID)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --constraint=\"gpu-mid|gpu-high\"\n#SBATCH --ntasks-per-node=4\n#SBATCH --gres=gpu:4                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon pre-emption\n\n# To enable preemption re-loading, set `hydra.run.dir` or \n# `checkpointing.save_dir` explicitly.\nsrun python -u -m main \\\n  loader.batch_size=16 \\\n  loader.eval_batch_size=16 \\\n  model=small \\\n  data=openwebtext-split \\\n  wandb.name=sedd-owt \\\n  algo=sedd \\\n  model.length=1024 \\\n  eval.compute_generative_perplexity=True \\\n  sampling.predictor=analytic"
  },
  {
    "path": "scripts/zero_shot_ar.sh",
    "content": "#!/bin/bash\n#SBATCH -J zeroshot_ar                # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=gpu               # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-AgBZrc-small-ar-param-ar_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt\n\nexport HYDRA_FULL_ERROR=1\n\ndatasets=(\"ag_news\"\n          \"scientific_papers_pubmed\"\n          \"scientific_papers_arxiv\"\n          \"lambada\"\n          \"wikitext2\"\n          \"wikitext103\"\n          \"ptb\"\n          \"lm1b-gpt2\")\nfor data in \"${datasets[@]}\"; do\n  echo \"$data\"\n  srun python -u -m main \\\n    mode=ppl_eval \\\n    loader.batch_size=16 \\\n    loader.eval_batch_size=16 \\\n    data=\"$data\" \\\n    data.insert_valid_eos=False \\\n    algo=ar \\\n    model.length=1024 \\\n    eval.checkpoint_path=$checkpoint_path \\\n    +wandb.offline=true\ndone"
  },
  {
    "path": "scripts/zero_shot_duo.sh",
    "content": "#!/bin/bash\n#SBATCH -J zeroshot_duo               # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                   # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100]\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/flow-ode/flow-ode-VlCQLK-small-conjugate-OWT-anneal/checkpoints/last.ckpt\n\nexport HYDRA_FULL_ERROR=1\n\ndatasets=(\"ag_news\"\n          \"scientific_papers_pubmed\"\n          \"scientific_papers_arxiv\"\n          \"lambada\"\n          \"wikitext2\"\n          \"wikitext103\"\n          \"ptb\"\n          \"lm1b-gpt2\")\nfor data in \"${datasets[@]}\"; do\n  echo \"$data\"\n  srun python -u -m main \\\n    mode=ppl_eval \\\n    loader.batch_size=16 \\\n    loader.eval_batch_size=16 \\\n    loader.eval_global_batch_size=128 \\\n    data=\"$data\" \\\n    data.insert_valid_eos=False \\\n    model=small \\\n    algo=duo_base \\\n    model.length=1024 \\\n    eval.checkpoint_path=$checkpoint_path \\\n    sampling.num_sample_batches=0 \\\n    +wandb.offline=true\ndone"
  },
  {
    "path": "scripts/zero_shot_mdlm.sh",
    "content": "#!/bin/bash\n#SBATCH -J zeroshot_mdlm_noeos              # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                  # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diff-clean-s-owt-no-t-mQ4fQG-param-subs_data-openwebtext-split/checkpoints/61-1000000.ckpt\n\nexport HYDRA_FULL_ERROR=1\n\ndatasets=(\"ag_news\"\n          \"scientific_papers_pubmed\"\n          \"scientific_papers_arxiv\"\n          \"lambada\"\n          \"wikitext2\"\n          \"wikitext103\"\n          \"ptb\"\n          \"lm1b-gpt2\")\nfor data in \"${datasets[@]}\"; do\n  echo \"$data\"\n  srun python -u -m main \\\n    mode=ppl_eval \\\n    loader.batch_size=16 \\\n    loader.eval_batch_size=16 \\\n    loader.eval_global_batch_size=128 \\\n    data=\"$data\" \\\n    data.insert_valid_eos=False \\\n    model=small \\\n    algo=mdlm \\\n    model.length=1024 \\\n    eval.checkpoint_path=$checkpoint_path \\\n    +wandb.offline=true\ndone"
  },
  {
    "path": "scripts/zero_shot_sedd.sh",
    "content": "#!/bin/bash\n#SBATCH -J zeroshot_sedd              # Job name\n#SBATCH -o watch_folder/%x_%j.out     # log file (out & err)\n#SBATCH -N 1                          # Total number of nodes requested\n#SBATCH --get-user-env                # retrieve the users login environment\n#SBATCH --mem=32000                  # server memory requested (per node)\n#SBATCH -t 960:00:00                  # Time limit (hh:mm:ss)\n#SBATCH --partition=kuleshov          # Request partition\n#SBATCH --constraint=\"[a5000|a6000|a100|3090]\"\n#SBATCH --ntasks-per-node=1\n#SBATCH --gres=gpu:1                  # Type/number of GPUs needed\n#SBATCH --open-mode=append            # Do not overwrite logs\n#SBATCH --requeue                     # Requeue upon preemption\n\ncheckpoint_path=/share/kuleshov/ssahoo/textdiffusion/text-diffusion-exp-v4-nBm2gE-small-param-sedd_data-openwebtext-split_seqlen-1024_maxs-1300001_bs-512/checkpoints/last.ckpt\n\nexport HYDRA_FULL_ERROR=1\n\ndatasets=(\"ag_news\"\n          \"scientific_papers_pubmed\"\n          \"scientific_papers_arxiv\"\n          \"lambada\"\n          \"wikitext2\"\n          \"wikitext103\"\n          \"ptb\"\n          \"lm1b-gpt2\")\nfor data in \"${datasets[@]}\"; do\n  echo \"$data\"\n  srun python -u -m main \\\n    mode=ppl_eval \\\n    loader.batch_size=16 \\\n    loader.eval_batch_size=16 \\\n    loader.eval_global_batch_size=128 \\\n    data=\"$data\" \\\n    data.insert_valid_eos=False \\\n    model=small \\\n    algo=sedd \\\n    model.length=1024 \\\n    eval.checkpoint_path=$checkpoint_path \\\n    +wandb.offline=true\ndone"
  },
  {
    "path": "trainer_base.py",
    "content": "import itertools\nfrom dataclasses import dataclass\n\nimport hydra.utils\nimport lightning as L\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport transformers\n\nimport dataloader\nimport metrics\nimport models\nimport utils\n\n\n@dataclass\nclass Loss:\n  loss: torch.FloatTensor\n  nlls: torch.FloatTensor\n  prior_loss: torch.FloatTensor\n  num_tokens: torch.FloatTensor\n\n\nclass LogLinear(torch.nn.Module):\n  def __init__(self, eps):\n    super().__init__()\n    self.eps = eps # 1e-3 by default, to be consistent with SEDD: https://github.com/louaaron/Score-Entropy-Discrete-Diffusion/blob/0605786da5ccb5747545e26d66fdf477187598b6/noise_lib.py#L56\n\n  def forward(self, t):\n    t = (1 - self.eps) * t\n    alpha_t = 1 - t \n    dalpha_t = - torch.ones_like(alpha_t) * (1 - self.eps)\n    return dalpha_t, alpha_t\n\n  def get_t_for_alpha(self, alpha_t):\n    return 1 - alpha_t\n\n\nclass Cosine(torch.nn.Module):\n  def __init__(self, eps):\n    super().__init__()\n    self.eps = eps\n    self.half_pi = torch.pi / 2\n\n  def forward(self, t):\n    t = (1 - self.eps) * t\n    alpha_t = 1 - torch.cos(self.half_pi * (1 - t))\n    dalpha_t = - torch.sin(self.half_pi * (1 - t)) * self.half_pi\n    return dalpha_t, alpha_t\n\n  def get_t_for_alpha(self, alpha_t):\n    is_tensor = torch.is_tensor(alpha_t)\n    if not is_tensor:\n      alpha_t = torch.tensor([alpha_t])\n    t = 1 - 2 / torch.pi * torch.acos(1 - alpha_t)\n    if not is_tensor:\n      t = t.cpu().item()\n    return t\n\n\ndef sample_categorical(categorical_probs):\n  gumbel_norm = (\n    1e-10\n    - (torch.rand_like(categorical_probs) + 1e-10).log())\n  return (categorical_probs / gumbel_norm).argmax(dim=-1)\n\n\ndef _unsqueeze(x, reference):\n  return x.view(\n    * x.shape,\n    * ((1,) * (len(reference.shape) - len(x.shape))))\n\n\nclass TrainerBase(L.LightningModule):\n  def __init__(\n    self,\n    config,\n    tokenizer: transformers.PreTrainedTokenizer,\n    
vocab_size=None):\n    super().__init__()\n    self.save_hyperparameters()\n    self.config = config\n    if hasattr(self.config.algo, 'ignore_bos'):\n      self.ignore_bos = config.algo.ignore_bos\n    else:\n      self.ignore_bos = False\n    if hasattr(self.config.algo, 'loss_type'):\n      self.loss_type = config.algo.loss_type\n    self.tokenizer = tokenizer\n    if vocab_size is None:\n      self.vocab_size = len(self.tokenizer)\n    else:\n      self.vocab_size = vocab_size\n    self.sampler = self.config.sampling.predictor\n    self.antithetic_sampling = self.config.training.antithetic_sampling\n    self.parameterization = self.config.algo.parameterization\n    if self.config.algo.backbone == 'dit':\n      self.backbone = models.dit.DIT(\n        self.config, vocab_size=self.vocab_size)\n    elif self.config.algo.backbone == 'unet':\n      self.backbone = models.unet.UNet(self.config, self.vocab_size)\n    elif self.config.algo.backbone == 'dimamba':\n      self.backbone = models.dimamba.DiMamba(\n        self.config,\n        vocab_size=self.vocab_size,\n        pad_token_id=self.tokenizer.pad_token_id)\n    elif self.config.algo.backbone == 'hf_dit':\n      self.backbone = transformers.AutoModelForMaskedLM.from_pretrained(\n        config.eval.checkpoint_path, trust_remote_code=True)\n\n    self.T = self.config.algo.T\n    self.num_tokens = self.config.model.length\n    self.p_nucleus = self.config.sampling.p_nucleus\n    # Noise Schedule\n    if config.noise.type == 'log-linear':\n      self.noise = LogLinear(config.noise.eps)\n    elif config.noise.type == 'cosine':\n      self.noise = Cosine(config.noise.eps)\n    else:\n      raise ValueError(config.noise.type)\n    # Class-conditional training arguments\n    self.num_classes = config.data.get('num_classes', None)\n    self.class_conditional = self.num_classes is not None\n    self.class_cond_dropout = config.training.class_dropout_p \n\n    self.metrics = metrics.Metrics(\n      
gen_ppl_eval_model_name_or_path=\\\n        self.config.eval.gen_ppl_eval_model_name_or_path,\n      eval_ppl_batch_size=\\\n        self.config.eval.perplexity_batch_size)\n\n    if self.config.training.ema > 0:\n      self.ema = models.ema.ExponentialMovingAverage(\n        self._get_parameters(),\n        decay=self.config.training.ema)\n    else:\n      self.ema = None\n    \n    self.lr = self.config.optim.lr\n    self.sampling_eps = self.config.training.sampling_eps\n    self.time_conditioning = self.config.algo.time_conditioning\n    self.neg_infinity = -1000000.0\n    self.fast_forward_epochs = None\n    self.fast_forward_batches = None\n\n  def _validate_configuration(self):\n    assert self.config.algo.backbone in {'dit', 'hf_dit', \n                                         'unet'}\n    if self.config.algo.parameterization == 'ar':\n      assert not self.config.algo.time_conditioning\n      assert self.config.prior.type == 'none'\n    if self.time_conditioning:\n      assert self.config.sampling.predictor != 'ancestral_cache'\n\n    if self.parameterization in {'score', 'mean'}:\n      assert self.time_conditioning\n    if self.T > 0:\n      assert self.parameterization != 'score'\n\n  def to(self, *args, **kwargs):\n    self = super().to(*args, **kwargs) \n    self.metrics.to(*args, **kwargs)\n    return self\n\n  def q_xt(self, x, alpha_t):\n    raise NotImplementedError\n\n  def _get_parameters(self):\n    return itertools.chain(self.backbone.parameters(),\n                           self.noise.parameters())\n\n  def _eval_mode(self):\n    if self.ema:\n      self.ema.store(self._get_parameters())\n      self.ema.copy_to(self._get_parameters())\n    self.backbone.eval()\n    self.noise.eval()\n\n  def _train_mode(self):\n    if self.ema:\n      self.ema.restore(self._get_parameters())\n    self.backbone.train()\n    self.noise.train()\n\n  def on_load_checkpoint(self, checkpoint):\n    if self.ema:\n      self.ema.load_state_dict(checkpoint['ema'])\n   
 # Copied from:\n    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41\n    self.fast_forward_epochs = checkpoint['loops'][\n      'fit_loop']['epoch_progress']['current']['completed']\n    self.fast_forward_batches = checkpoint['loops'][\n      'fit_loop']['epoch_loop.batch_progress'][\n        'current']['completed']\n\n  def on_save_checkpoint(self, checkpoint):\n    if self.ema:\n      checkpoint['ema'] = self.ema.state_dict()\n    # Copied from:\n    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py\n    # ['epoch_loop.batch_progress']['total']['completed']\n    # is 1 iteration behind, so we're using the optimizer's progress.\n    checkpoint['loops']['fit_loop'][\n      'epoch_loop.batch_progress']['total'][\n        'completed'] = checkpoint['loops']['fit_loop'][\n          'epoch_loop.automatic_optimization.optim_progress'][\n            'optimizer']['step']['total'][\n              'completed'] * self.trainer.accumulate_grad_batches\n    checkpoint['loops']['fit_loop'][\n      'epoch_loop.batch_progress']['current'][\n        'completed'] = checkpoint['loops']['fit_loop'][\n          'epoch_loop.automatic_optimization.optim_progress'][\n            'optimizer']['step']['current'][\n              'completed'] * self.trainer.accumulate_grad_batches\n    # _batches_that_stepped tracks the number of global steps,\n    # not the number of local steps, so we don't multiply with\n    # self.trainer.accumulate_grad_batches here.\n    checkpoint['loops']['fit_loop'][\n      'epoch_loop.state_dict'][\n        '_batches_that_stepped'] = checkpoint['loops']['fit_loop'][\n          'epoch_loop.automatic_optimization.optim_progress'][\n            'optimizer']['step']['total']['completed']\n    if 'sampler' not in checkpoint.keys():\n      checkpoint['sampler'] = {}\n    if hasattr(self.trainer.train_dataloader.sampler,\n               'state_dict'):\n      sampler_state_dict = 
self.trainer.\\\n        train_dataloader.sampler.state_dict()\n      checkpoint['sampler'][\n        'random_state'] = sampler_state_dict.get(\n          'random_state', None)\n    else:\n      checkpoint['sampler']['random_state'] = None\n\n  def on_train_start(self):\n    if self.ema:\n      self.ema.move_shadow_params_to_device(self.device)\n    # Adapted from:\n    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py\n    distributed = (\n      self.trainer._accelerator_connector.use_distributed_sampler\n      and self.trainer._accelerator_connector.is_distributed)\n    if distributed:\n      sampler_cls = dataloader.FaultTolerantDistributedSampler\n    else:\n      sampler_cls = dataloader.RandomFaultTolerantSampler\n    updated_dls = []\n    for dl in self.trainer.fit_loop._combined_loader.flattened:\n      if hasattr(dl.sampler, 'shuffle'):\n        dl_sampler = sampler_cls(\n          dl.dataset, shuffle=dl.sampler.shuffle)\n      else:\n        dl_sampler = sampler_cls(dl.dataset)\n      if (distributed\n          and self.fast_forward_epochs is not None\n          and self.fast_forward_batches is not None):\n        dl_sampler.load_state_dict({\n          'epoch': self.fast_forward_epochs,\n          'counter': (self.fast_forward_batches\n                      * self.config.loader.batch_size)})\n      updated_dls.append(\n        torch.utils.data.DataLoader(\n          dl.dataset,\n          batch_size=self.config.loader.batch_size,\n          num_workers=self.config.loader.num_workers,\n          pin_memory=self.config.loader.pin_memory,\n          sampler=dl_sampler,\n          shuffle=False,\n          persistent_workers=True))\n    self.trainer.fit_loop._combined_loader.flattened = updated_dls\n\n  def optimizer_step(self, *args, **kwargs):\n    super().optimizer_step(*args, **kwargs)\n    if self.ema:\n      self.ema.update(self._get_parameters())\n\n  def _process_sigma(self, sigma):\n    raise 
NotImplementedError\n\n  def _process_model_output(self, model_output, xt, sigma):\n    raise NotImplementedError\n\n  def forward(self, xt, sigma, labels=None, weights=None,\n              nn_input_idxs=None):\n    if nn_input_idxs is None:\n      nn_input_idxs = xt\n\n    sigma = self._process_sigma(sigma)\n    with torch.amp.autocast('cuda', dtype=torch.float32):\n      model_output = self.backbone(\n        x=nn_input_idxs, sigma=sigma, class_cond=labels, \n        weights=weights)\n    return self._process_model_output(\n      model_output=model_output, xt=xt, sigma=sigma)\n\n  def on_train_epoch_start(self):\n    self.metrics.reset()\n    assert self.metrics.train_nlls.nll.mean_value == 0\n    assert self.metrics.train_nlls.nll.weight == 0\n\n  def training_step(self, batch, batch_idx):\n    current_accumulation_step = (\n      batch_idx % self.trainer.accumulate_grad_batches)\n    losses = self._loss(batch['input_ids'],\n                        batch.get('labels', None),\n                        batch['attention_mask'],\n                        current_accumulation_step,\n                        train_mode=True)\n    self.metrics.update_train(losses.nlls, losses.prior_loss,\n                              losses.num_tokens)\n    self.log(name='trainer/loss',\n             value=losses.loss.item(),\n             on_step=True,\n             on_epoch=False,\n             sync_dist=True)\n    return losses.loss\n\n  def on_train_epoch_end(self):\n    for k, v in self.metrics.valid_nlls.items():\n      self.log(name=k, value=v.compute(), on_step=False,\n               on_epoch=True, sync_dist=True)\n\n  def on_validation_epoch_start(self):\n    self.metrics.reset()\n    self._eval_mode()\n    assert self.metrics.valid_nlls.nll.mean_value == 0\n    assert self.metrics.valid_nlls.nll.weight == 0\n\n  def validation_step(self, batch, batch_idx):\n    del batch_idx\n    losses = self._loss(batch['input_ids'],\n                        batch.get('labels', None),\n       
                 batch['attention_mask'])\n    self.metrics.update_valid(losses.nlls, losses.prior_loss,\n                              losses.num_tokens)\n    return losses.loss\n\n  def on_validation_epoch_end(self):\n    for k, v in self.metrics.valid_nlls.items():\n      self.log(name=k,  value=v.compute(), on_step=False,\n               on_epoch=True, sync_dist=True)\n    if ((self.config.eval.compute_perplexity_on_sanity\n         or not self.trainer.sanity_checking)\n         and self.config.eval.generate_samples):\n      samples, text_samples = None, None\n      for _ in range(\n        self.config.sampling.num_sample_batches):\n        samples = self.generate_samples(\n          num_samples=self.config.loader.eval_batch_size)\n        \n        self.metrics.record_entropy(samples)\n        # Decode the samples to be re-tokenized by eval model\n        text_samples = self.tokenizer.batch_decode(samples)\n        if self.config.eval.compute_generative_perplexity:\n          self.metrics.record_generative_perplexity(\n            text_samples, self.num_tokens, self.device)\n      if text_samples is not None:\n        if self.trainer.global_rank == 0 and hasattr(\n          self.trainer.logger, 'log_table'):\n          # Log the last generated samples\n          text_samples = text_samples[\n            : self.config.sampling.num_sample_log]\n          self.trainer.logger.log_table(\n            key=f'samples@global_step{self.global_step}',\n            columns=['Generated Samples'],\n            data=[[s] for s in text_samples])\n        if self.config.eval.compute_generative_perplexity:\n          self.log('val/gen_ppl',\n                   self.metrics.gen_ppl.compute(),\n                   on_epoch=True,\n                   on_step=False,\n                   sync_dist=True)\n          self.log('val/sample_entropy',\n                   self.metrics.sample_entropy.compute(),\n                   on_epoch=True,\n                   on_step=False,\n              
     sync_dist=True)\n    self._train_mode()\n\n  def configure_optimizers(self):\n    optimizer = torch.optim.AdamW(\n      self._get_parameters(),\n      lr=self.config.optim.lr,\n      betas=(self.config.optim.beta1,\n             self.config.optim.beta2),\n      eps=self.config.optim.eps,\n      weight_decay=self.config.optim.weight_decay)\n\n    scheduler = hydra.utils.instantiate(\n      self.config.lr_scheduler, optimizer=optimizer)\n    scheduler_dict = {'scheduler': scheduler,\n                      'interval': 'step',\n                      'monitor': 'val/loss',\n                      'name': 'trainer/lr'}\n    return [optimizer], [scheduler_dict]\n\n  def generate_samples(self, num_samples, num_steps, eps):\n    raise NotImplementedError\n\n  def restore_model_and_sample(self, num_steps, eps=1e-5):\n    \"\"\"Generate samples from the model.\"\"\"\n    # Lightning auto-casting is not working in this method for some reason\n    self._eval_mode()\n    samples = self.generate_samples(\n      num_samples=self.config.loader.eval_batch_size,\n      num_steps=num_steps,\n      eps=eps)\n    self._train_mode()\n    return samples\n\n  def _process_model_input(self, x0, valid_tokens):\n    raise NotImplementedError\n\n  def nll(self, input_tokens, labels, output_tokens,\n          current_accumulation_step=None, train_mode=False):\n    raise NotImplementedError\n\n  def _loss(self, x0, labels, valid_tokens,\n            current_accumulation_step=None,\n            train_mode=False):\n    (input_tokens, output_tokens,\n     valid_tokens) = self._process_model_input(\n       x0, valid_tokens)\n    loss = self.nll(input_tokens, labels, output_tokens,\n                    current_accumulation_step, train_mode)\n    assert loss.ndim == 2\n    if self.ignore_bos:\n      loss[:, 1:] = loss[:, 1:]\n      valid_tokens[:, 1:] = valid_tokens[:, 1:]\n\n    nlls = (loss * valid_tokens).sum()\n    num_tokens = valid_tokens.sum()\n    token_nll = nlls / num_tokens\n\n    
return Loss(loss=token_nll,\n                nlls=nlls,\n                prior_loss=0.0,\n                num_tokens=num_tokens)\n\n\nclass Diffusion(TrainerBase):\n  def _validate_configuration(self):\n    super()._validate_configuration()\n    assert self.config.sampling.noise_removal in {\n      'none', 'ancestral', 'greedy'}\n    assert self.loss_type in {'elbo', 'low_var'}\n    if self.config.sampling.noise_removal == 'greedy':\n      assert self.sampler != 'analytic'\n      assert self.parameterization in {'mean', 'subs'}\n\n  def _process_model_input(self, x0, valid_tokens):\n    return x0, None, valid_tokens\n\n  def _process_sigma(self, sigma):\n    assert sigma.ndim == 2\n    sigma = sigma.mean(-1).squeeze()\n    if sigma.ndim == 0:\n      sigma = sigma.unsqueeze(0)\n    if not self.time_conditioning:\n      sigma = torch.zeros_like(sigma)\n    assert sigma.ndim == 1, sigma.shape\n    return sigma\n\n  def _sample_t(self, n, accum_step):\n    if accum_step is not None:\n      # During training\n      batch_dim = n\n      n = self.config.loader.global_batch_size\n    _eps_t = torch.rand(n, device=self.device)\n    if self.antithetic_sampling:\n      offset = torch.arange(n, device=self.device) / n\n      _eps_t = (_eps_t / n + offset) % 1\n    t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps\n    if accum_step is not None:\n      t = t.chunk(self.trainer.num_nodes)[self.trainer.node_rank]\n      t = t.chunk(self.trainer.num_devices)[self.trainer.local_rank]\n      t = t.chunk(self.trainer.accumulate_grad_batches)[\n        accum_step]\n      # corner case for the last datapoint\n      t = t[:batch_dim]\n    return t\n\n  def _sigma_from_alphat(self, alpha_t):\n    return -torch.log(alpha_t)\n\n  def _reconstruction_loss(self, x0):\n    t0 = torch.zeros(1, x0.shape[0], dtype=self.dtype,\n                     device=self.device)\n    sigma_t0 = self._sigma_from_alphat(self.noise(t0)[1])\n    model_output_t0 = self.forward(x0, sigma_t0)\n    return - 
torch.gather(input=model_output_t0,\n                          dim=-1,\n                          index=x0[:, :, None]).squeeze(-1)\n\n  def nll_per_token(self, model_output, xt, x0, alpha_t,\n                    dalpha_t, low_var):\n    raise NotImplementedError\n\n  def nll(self, x0, labels, output_tokens,\n          current_accumulation_step=None, train_mode=False):\n    del output_tokens\n    t = self._sample_t(x0.shape[0],\n                       current_accumulation_step)\n    assert t.shape[0] == x0.shape[0]\n    if self.T > 0:\n      t = (t * self.T).to(torch.int)\n      t = t / self.T\n      # t \\in {1/T, 2/T, ..., 1}\n      t += (1 / self.T)\n    \n    dalpha_t, alpha_t = self.noise(t)\n    alpha_t = alpha_t.unsqueeze(-1)\n    dalpha_t = dalpha_t.unsqueeze(-1)\n    assert alpha_t.ndim == 2\n    assert dalpha_t.ndim == 2\n    sigma = self._sigma_from_alphat(alpha_t)\n\n    xt = self.q_xt(x0, alpha_t)\n    # Handle class-conditional training, with class dropout\n    if self.class_conditional:\n      assert labels is not None\n      rand = torch.rand(size=labels.shape, dtype=torch.float32, \n                      device=self.device)\n      # num_classes represent the absence of class-conditioning\n      labels = torch.where(rand < self.class_cond_dropout, \n                           self.num_classes, labels)\n    else:\n      assert labels is None\n    log_x_theta = self.forward(xt, sigma=sigma, labels=labels)\n    utils.print_nans(log_x_theta, 'model_output')\n    return self.nll_per_token(\n      log_x_theta=log_x_theta,\n      xt=xt,\n      x0=x0,\n      alpha_t=alpha_t,\n      dalpha_t=dalpha_t,\n      low_var=train_mode and self.loss_type == 'low_var')\n\n  def _get_score(self, **kwargs):\n    del kwargs\n    raise NotImplementedError\n\n  def _denoiser_update(self, x, t):\n    raise NotImplementedError\n\n  def _analytic_update(self, x, t, dt):\n    raise NotImplementedError\n\n  def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t):\n    \"\"\"From 
the clean x0, or denoiser predictions, implement \n    the mathematical expression for q(z_s | z_t, x)\"\"\"\n    raise NotImplementedError\n\n  def _forward_process(self, q_x0, alpha_s):\n    \"\"\"Apply forward noising: q(x_s | x_0), where x_0 is a \n       K-dimensional vector\"\"\"\n    raise NotImplementedError\n\n  def _get_ancestral_posterior(self, xt, sigma, labels, \n                              alpha_s, alpha_t, p_x0):\n    \"\"\"Top-level call from generate_samples. Compute the \n       standard posterior sampling distribution from the \n       noisy sequence xt\"\"\"\n    gamma = self.config.sampling.guid_weight\n    # If class cond but gamma is None or 0, same as sampling \n    #  from the class unconditional case.\n    if not self.class_conditional or gamma in (None, 0.0):\n      if labels is not None:\n        labels = torch.full_like(labels, self.num_classes)\n      return self._get_posterior_from_xt(xt, sigma, labels, \n                                         alpha_s, alpha_t, p_x0)\n    elif gamma == 1.0:  \n      # Simply sample from the cond. 
posterior only.\n      assert labels is not None\n      return self._get_posterior_from_xt(xt, sigma, labels, \n                                         alpha_s, alpha_t, p_x0)\n    else:\n      # Case gamma not in {0, 1}, and class conditional \n      #  -> mix conditional and unconditional predictions.\n      return self._get_guided_posterior_from_xt(self, xt,\n        sigma, labels, gamma, alpha_s, alpha_t, p_x0)\n\n  def _get_posterior_from_xt(self, xt, sigma, labels, alpha_s, \n                             alpha_t, p_x0=None):\n    \"\"\"From xt, compute a single denoiser predictions, cast \n       to correct dtype, and call _posterior_from_x0.\"\"\"\n    if p_x0 is None:\n      log_x0_pred = self.forward(xt, sigma, labels)\n\n      if self.config.sampling.use_float64:\n        log_x0_pred = log_x0_pred.to(torch.float64)\n\n      if self.p_nucleus < 1:\n        log_x0_pred = utils.top_k_top_p_filtering(\n          log_x0_pred, top_p=self.p_nucleus)\n      p_x0 = log_x0_pred.exp()\n\n    return p_x0, self._posterior_from_x0(x0=p_x0,xt=xt,\n      alpha_s=alpha_s, alpha_t=alpha_t)\n\n  def _get_guided_posterior_from_xt(self, xt, sigma, labels, \n    gamma, alpha_s, alpha_t, p_x0=None):\n    \"\"\"From xt, combine the class cond / uncond predictions \n       of the denoiser. 
Call _get_posterior_from_xt.\"\"\"\n    # unpack the cache\n    if p_x0 is None:\n      p_x0_cond = p_x0_uncond = None\n    else:\n      p_x0_cond, p_x0_uncond = p_x0\n    \n    p_x0_cond, cond_posterior = self._get_posterior_from_xt(\n      xt, sigma, labels, alpha_s, alpha_t, p_x0_cond)\n    log_cond_posterior = cond_posterior.log()\n    # NOTE: conditioning on self.num_classes represents the\n    #  class-unconditional predictions.\n    p_x0_uncond, uncond_posterior = self._get_posterior_from_xt(\n      xt, sigma, torch.full_like(labels, self.num_classes), \n      alpha_s, alpha_t, p_x0_uncond)\n    log_uncond_posterior = uncond_posterior.log()\n\n    un_normalized_posterior = gamma * log_cond_posterior \\\n                       + (1 - gamma) * log_uncond_posterior\n    # Handle cases where the posterior is zero (eg after \n    #  nucleus sampling)\n    is_inf_mask = torch.logical_or(\n      log_cond_posterior.isinf(), log_uncond_posterior.isinf())\n    un_normalized_posterior[is_inf_mask] = self.neg_infinity\n    \n    return ((p_x0_cond, p_x0_uncond), \n            un_normalized_posterior.softmax(-1))\n\n  def _ancestral_update(self, x, t, labels, dt, p_x0=None, \n                        noise_removal_step=False):\n    _, alpha_t = self.noise(t)\n    if noise_removal_step:\n      alpha_s = torch.ones_like(alpha_t)\n    else:\n      _, alpha_s = self.noise(t - dt)\n    assert alpha_t.ndim == 2\n\n    sigma = self._sigma_from_alphat(alpha_t)\n    assert alpha_t.ndim == 2, f'{alpha_t.ndim=}'\n\n    p_x0, q_xs = self._get_ancestral_posterior(x, sigma, \n      labels, alpha_s, alpha_t, p_x0)\n\n    return p_x0, sample_categorical(q_xs)\n\n  def _psi_update(self, x, t, labels, dt, kappa, p_x0=None,\n                  noise_removal_step=False):\n    _, alpha_t = self.noise(t)\n    if noise_removal_step:\n      alpha_s = torch.ones_like(alpha_t)\n    else:\n      _, alpha_s = self.noise(t - dt)\n    alpha_0 = torch.ones_like(alpha_t)\n    sigma = 
self._sigma_from_alphat(alpha_t)\n\n    # Standard posterior q(x_s | x_t)\n    p_x0, q_xs = self._get_ancestral_posterior(\n      x, sigma, labels, alpha_s, alpha_t, p_x0)\n\n    # Posterior targeting t=0, reuse predictions p_x0\n    _, q_x0 = self._get_ancestral_posterior(\n      x, sigma, labels, alpha_0, alpha_t, p_x0)\n\n    # PC: forward-noise q_x0 back to time s\n    pc_q_xs = self._forward_process(q_x0, alpha_s)\n\n    q_sample = kappa * q_xs + (1 - kappa) * pc_q_xs\n    return p_x0, sample_categorical(q_sample)\n\n  def _get_sampling_time_profile(self, eps, num_steps):\n    profile = self.config.sampling.psi.time_profile\n    num_steps += 1\n    if profile == 'linear' \\\n      or self.config.sampling.predictor != 'psi':\n      # Default: linearly decrease\n      return torch.linspace(1, eps, num_steps)\n    if not profile.startswith('linear-constant-linear'):\n      raise ValueError(profile)\n    c = float(profile.split('-')[3])\n    if 'inv' in profile:\n      c = self.noise.get_t_for_alpha(c)\n    psi_cfg = self.config.sampling.psi\n    n_hi = round(psi_cfg.high_frac * num_steps)\n    n_mid = round(psi_cfg.middle_frac * num_steps)\n    return torch.cat([\n      torch.linspace(1, c, n_hi),\n      torch.full((n_mid,), c),\n      torch.linspace(c, eps, num_steps - n_hi - n_mid)])\n\n  def _mode_to_psi_kappas(self, mode, timesteps):\n    n = len(timesteps)\n    if mode == 'pure-posterior':\n      return torch.ones(n)\n    if mode == 'pure-pc':\n      return torch.zeros(n)\n\n    eta = float(mode.split('-')[-1])\n    if (mode.startswith('constant-')\n        and not mode.startswith('constant-remdm')):\n      return torch.full((n,), eta)\n\n    # Noise-schedule-dependent modes (ReMDM-like)\n    _, all_alphas = self.noise(timesteps)\n    alpha_t, alpha_s = all_alphas[:-1], all_alphas[1:]\n    eta_t = torch.tensor([eta])\n    if mode.startswith('max-capped-'):\n      sigma = torch.minimum(\n        eta_t.expand_as(alpha_t), (1 - alpha_s) / alpha_t)\n      sigma 
= torch.where(alpha_t == 0, eta, sigma)\n    elif mode.startswith('max-rescale-'):\n      sigma_max = torch.minimum(\n        eta_t.expand_as(alpha_t), (1 - alpha_s) / alpha_t)\n      sigma = torch.where(alpha_t > 0, sigma_max, 1) * eta\n    elif mode.startswith('constant-remdm'):\n      sigma = eta_t\n    else:\n      raise ValueError(mode)\n    kappas = torch.clip(1 - sigma / (1 - alpha_s), 0, 1)\n    if len(kappas) > 0:\n      kappas = torch.cat([kappas, torch.ones(1)])\n    return kappas\n\n  def _get_kappas(self, timesteps):\n    cfg = self.config.sampling.psi\n    n = len(timesteps)\n    n_hi = round(cfg.high_frac * n)\n    n_mid = round(cfg.middle_frac * n)\n    kappas = torch.cat([\n      self._mode_to_psi_kappas(cfg.high_mode, timesteps[:n_hi]),\n      self._mode_to_psi_kappas(cfg.middle_mode,\n        timesteps[n_hi:n_hi + n_mid]),\n      self._mode_to_psi_kappas(cfg.low_mode,\n        timesteps[n_hi + n_mid:])])\n    assert (kappas >= 0).all() and (kappas <= 1).all()\n    assert len(kappas) == n\n    return kappas\n\n  @torch.no_grad()\n  def generate_samples(self, num_samples, labels=None,\n                       num_steps=None, eps=1e-5):\n    \"\"\"Generate samples from the model.\"\"\"\n    # Lightning auto-casting is not working in this method for some reason\n    if num_steps is None:\n      num_steps = self.config.sampling.steps\n    x = self.prior_sample(num_samples, self.num_tokens)\n    use_psi_sampler = self.config.sampling.predictor == 'psi'\n    timesteps = self._get_sampling_time_profile(eps, \n                                                num_steps)\n    if use_psi_sampler:\n      kappas = self._get_kappas(timesteps).to(self.device)\n\n    p_x0_cache = None\n    if labels is not None:\n      labels = labels.to(self.device)\n\n    for i in range(num_steps):\n      t = timesteps[i] * torch.ones(\n        x.shape[0], 1, device=self.device)\n      dt = timesteps[i] - timesteps[i+1]\n      if self.sampler == 'ancestral':\n        _, x = 
self._ancestral_update(\n          x=x, t=t, labels=labels, dt=dt, p_x0=None)\n      elif self.sampler == 'ancestral_cache':\n        p_x0_cache, x_next = self._ancestral_update(\n          x=x, t=t, labels=labels, dt=dt, p_x0=p_x0_cache)\n        if (not torch.allclose(x_next, x)\n            or self.time_conditioning):\n          # Disable caching\n          p_x0_cache = None\n        x = x_next\n      elif self.sampler == 'psi':\n        _, x = self._psi_update(x=x, t=t, kappa=kappas[i],\n                                labels=labels, dt=dt, \n                                p_x0=None)\n      elif self.sampler == 'analytic':\n        assert labels is None, 'class-conditional sampling ' \\\n          'is not implemented with the analytic sampler'\n        x = self._analytic_update(x=x,t=t, dt=dt)\n      else:\n        raise ValueError(self.sampler)\n\n    t0 = timesteps[-1] * torch.ones(x.shape[0], 1,\n                                    device=self.device)\n    if self.config.sampling.noise_removal == 'ancestral':\n      if self.sampler == 'analytic':\n        x = self._denoiser_update(x=x, t=t0)\n      else:\n        _, x = self._ancestral_update(x=x, t=t0, labels=labels,\n          dt=None, p_x0=p_x0_cache, noise_removal_step=True)\n    elif self.config.sampling.noise_removal == 'greedy':\n      sigma = self._sigma_from_alphat(self.noise(t0)[1])\n      x = self.forward(xt=x, sigma=sigma).argmax(dim=-1)\n    return x\n\n  @torch.no_grad\n  def _semi_ar_sampler(\n    self, n_samples, stride_length, num_strides, dt=0.001):\n    # TODO(subham): Test this method after refactoring.\n    ones = torch.ones(n_samples, dtype=self.dtype,\n                      device=self.device)\n\n    num_steps = int(1 / dt)\n    sampling_steps = 0\n    intermediate_tokens = []\n    target = None\n    for _ in range(num_strides + 1):\n      p_x0_cache = None\n      x = self.prior_sample(n_samples, self.num_tokens)\n      if target is not None:\n        x[:, : -stride_length] = target\n 
     for i in range(num_steps + 1):\n        p_x0_cache, x_next = self._ancestral_update(\n          x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)\n        if (not torch.allclose(x_next, x)\n            or self.time_conditioning):\n          p_x0_cache = None\n          sampling_steps += 1\n        x = x_next\n      x = self.forward(x, 0 * ones).argmax(dim=-1)\n      intermediate_tokens.append(\n        x[:, :stride_length].cpu().numpy())\n      target = x[:, stride_length:]\n    \n    intermediate_tokens.append(target.cpu().numpy())\n    intermediate_text_samples = []\n    sequence_lengths = ((\n      np.concatenate(intermediate_tokens, axis=1)[:, 1:]\n      == self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)\n    for i in range(2, len(intermediate_tokens) + 1):\n      intermediate_text_samples.append(\n        self.tokenizer.batch_decode(\n          np.concatenate(intermediate_tokens[:i], axis=1)))\n    return (sampling_steps, intermediate_text_samples,\n            sequence_lengths)\n\n  def restore_model_and_semi_ar_sample(\n      self, stride_length, num_strides, dt=0.001):\n    \"\"\"Generate samples from the model.\"\"\"\n    # Lightning auto-casting is not working in this method for some reason\n    # TODO(subham): Test this method after refactoring.\n    self._eval_mode()\n    (sampling_steps, samples,\n     sequence_lengths) = self._semi_ar_sampler(\n      n_samples=self.config.loader.eval_batch_size,\n      stride_length=stride_length,\n      num_strides=num_strides, \n      dt=dt)\n    self._train_mode()\n    return sampling_steps, samples, sequence_lengths\n\n\nclass AbsorbingState(Diffusion):\n  def __init__(self, config, tokenizer):\n    # NOTE: Ideally, we should do \n    # vocab_size = len(tokenizer), so that we account\n    # for the special tokens added in dataloader.py.\n    # But we use tokenizer.vocab_size so as to to be\n    # consistent with the prior checkpoints.\n    vocab_size = tokenizer.vocab_size\n    if (not 
hasattr(tokenizer, 'mask_token')\n        or tokenizer.mask_token is None):\n      self.mask_index = vocab_size\n      vocab_size += 1\n    else:\n      self.mask_index = tokenizer.mask_token_id\n    self.subs_masking = config.algo.subs_masking\n    super().__init__(config, tokenizer,\n                     vocab_size=vocab_size)\n    self.save_hyperparameters()\n\n  def _validate_configuration(self):\n    super()._validate_configuration()\n    if self.parameterization in {'score', 'mean'}:\n      assert self.time_conditioning\n    assert not (self.parameterization == 'mean'\n                and self.T == 0)\n    if self.T > 0:\n      assert self.parameterization in {'mean', 'subs'}\n    if self.subs_masking:\n      assert self.parameterization == 'mean'\n\n  def q_xt(self, x, alpha_t):\n    \"\"\"Computes the noisy sample xt.\n\n    Args:\n      x: int torch.Tensor with shape (batch_size,\n          diffusion_model_input_length), input. \n      alpha_t: float torch.Tensor with shape (batch_size, 1).\n    \"\"\"\n    move_indices = torch.rand(\n      * x.shape, device=x.device) < 1 - alpha_t\n    xt = torch.where(move_indices, self.mask_index, x)\n    if self.ignore_bos:\n      xt[:, 0] = x[:, 0]\n    return xt\n\n  def prior_sample(self, *batch_dims):\n    return self.mask_index * torch.ones(\n      * batch_dims, dtype=torch.int64, device=self.device)\n\n  def _posterior_from_x0(self, x0, xt, alpha_s, alpha_t):\n    \"\"\"From the clean x0, or denoiser predictions, implement \n    the mathematical expression for q(z_s | z_t, x)\"\"\"\n    assert x0.dtype == torch.float64, 'Requires float64 prec.'\n    # should be one-hot on clean tokens\n    orig_mask = xt[:, :, None] != self.mask_index\n    orig_mask = orig_mask.expand(-1, -1, x0.shape[-1])\n    orig_output_on_clean = x0[orig_mask]\n\n    q_xs = ((alpha_s - alpha_t) / (1 - alpha_t))[..., None] * x0\n    q_xs[..., self.mask_index] = (1 - alpha_s) / (1 - alpha_t)\n    q_xs[orig_mask] = orig_output_on_clean\n    
return q_xs\n\n  def _forward_process(self, x0, alpha_s):\n    out = alpha_s[..., None] * x0\n    out[..., self.mask_index] = 1 - alpha_s\n    return out\n\n  def _staggered_score(self, score, dsigma):\n    score = score.clone()\n    extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)\n    score *= dsigma.exp()[:, None]\n    score[..., self.mask_index] += extra_const\n    return score\n\n  def _analytic_update(self, x, t, dt):\n    sigma_t = self._sigma_from_alphat(self.noise(t)[1])\n    sigma_s = self._sigma_from_alphat(self.noise(t - dt)[1])\n    dsigma = sigma_t - sigma_s\n    score = self._get_score(x, sigma_t)\n    if self.config.sampling.use_float64:\n      score = score.to(torch.float64)\n    stag_score = self._staggered_score(score, dsigma)\n    probs = stag_score * self._transp_transition(x, dsigma)\n    return sample_categorical(probs)\n\n  def _denoiser_update(self, x, t):\n    sigma = self._sigma_from_alphat(self.noise(t)[1])\n    score = self._get_score(x, sigma)\n    if self.config.sampling.use_float64:\n      score = score.to(torch.float64)\n    stag_score = self._staggered_score(score, sigma)\n    probs = stag_score * self._transp_transition(x, sigma)\n    probs[..., self.mask_index] = 0\n    samples = sample_categorical(probs)\n    return samples\n\n  def _transp_transition(self, i, sigma):\n    sigma = _unsqueeze(sigma, reference=i[..., None])\n    edge = torch.exp(-sigma) * F.one_hot(\n      i, num_classes=self.vocab_size)\n    edge += torch.where(i == self.mask_index,\n                        1 - torch.exp(-sigma).squeeze(-1),\n                        0)[..., None]\n    return edge\n\n\nclass UniformState(Diffusion):\n  def _validate_configuration(self):\n    super()._validate_configuration()\n    assert self.time_conditioning\n    assert self.parameterization == 'mean'\n    if self.config.algo.name != 'distillation':\n      assert self.T == 0\n\n  def _forward_process(self, x0, alpha_s):\n    return (alpha_s[..., None] * x0\n            + (1 - 
alpha_s[..., None]) / self.vocab_size)\n\n  def q_xt(self, x, alpha_t):\n    \"\"\"Computes the noisy sample xt.\n\n    Args:\n      x: int torch.Tensor with shape (batch_size,\n          diffusion_model_input_length), input.\n      move_chance: float torch.Tensor with shape\n        (batch_size, 1).\n    \"\"\"\n    move_indices = torch.rand(\n      *x.shape, device=x.device) < 1 - alpha_t\n    uniform_tensor = torch.randint(\n      0, self.vocab_size, x.shape, device=x.device)\n    xt = torch.where(move_indices, uniform_tensor, x)\n    if self.ignore_bos:\n      xt[:, 0] = x[:, 0]\n    return xt\n\n  def prior_sample(self, *batch_dims):\n    return torch.randint(\n      0, self.vocab_size, batch_dims, dtype=torch.int64,\n      device=self.device)"
  },
  {
    "path": "utils.py",
    "content": "\"\"\"Console logger utilities.\n\nCopied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py\nCopied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging\n\"\"\"\n\nimport argparse\nimport logging\nimport os\nimport pickle\nimport time\nfrom typing import Optional\n\nimport fsspec\nimport lightning\nimport numpy as np\nimport torch\nfrom scipy.integrate import quad\nfrom scipy.optimize import curve_fit\nfrom scipy.stats import norm\nfrom timm.scheduler import CosineLRScheduler\n\n\ndef count_parameters(model):\n  return sum(p.numel()\n             for p in model.parameters()\n             if p.requires_grad)\n\ndef fsspec_exists(filename):\n  \"\"\"Check if a file exists using fsspec.\"\"\"\n  fs, _ = fsspec.core.url_to_fs(filename)\n  return fs.exists(filename)\n\n\ndef fsspec_listdir(dirname):\n  \"\"\"Listdir in manner compatible with fsspec.\"\"\"\n  fs, _ = fsspec.core.url_to_fs(dirname)\n  return fs.ls(dirname)\n\n\ndef fsspec_mkdirs(dirname, exist_ok=True):\n  \"\"\"Mkdirs in manner compatible with fsspec.\"\"\"\n  fs, _ = fsspec.core.url_to_fs(dirname)\n  fs.makedirs(dirname, exist_ok=exist_ok)\n\n\ndef print_nans(tensor, name):\n  if torch.isnan(tensor).any():\n    print(name, tensor)\n\n\nclass LRHalveScheduler:\n  def __init__(self, warmup_steps, n_halve_steps):\n    self.warmup_steps = warmup_steps\n    self.n_halve_steps = n_halve_steps\n  \n  def __call__(self, current_step):\n    if current_step < self.warmup_steps:\n      return current_step / self.warmup_steps\n    return 0.5 ** ((current_step - self.warmup_steps)\n                   // self.n_halve_steps)\n\n\nclass CosineDecayWarmupLRScheduler(\n  CosineLRScheduler,\n  torch.optim.lr_scheduler._LRScheduler):\n  \"\"\"Wrap timm.scheduler.CosineLRScheduler\n  Enables calling scheduler.step() without passing in epoch.\n  Supports resuming as well.\n  Adapted from:\n    
https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py\n  \"\"\"\n\n  def __init__(self, *args, **kwargs):\n    super().__init__(*args, **kwargs)\n    self._last_epoch = -1\n    self.step(epoch=0)\n\n  def step(self, epoch=None):\n    if epoch is None:\n      self._last_epoch += 1\n    else:\n      self._last_epoch = epoch\n    # We call either step or step_update, depending on\n    # whether we're using the scheduler every epoch or every\n    # step.\n    # Otherwise, lightning will always call step (i.e.,\n    # meant for each epoch), and if we set scheduler\n    # interval to \"step\", then the learning rate update will\n    # be wrong.\n    if self.t_in_epochs:\n      super().step(epoch=self._last_epoch)\n    else:\n      super().step_update(num_updates=self._last_epoch)\n\n\nclass LoggingContext:\n  \"\"\"Context manager for selective logging.\"\"\"\n  def __init__(self, logger, level=None, handler=None, close=True):\n    self.logger = logger\n    self.level = level\n    self.handler = handler\n    self.close = close\n\n  def __enter__(self):\n    if self.level is not None:\n      self.old_level = self.logger.level\n      self.logger.setLevel(self.level)\n    if self.handler:\n      self.logger.addHandler(self.handler)\n\n  def __exit__(self, et, ev, tb):\n    if self.level is not None:\n      self.logger.setLevel(self.old_level)\n    if self.handler:\n      self.logger.removeHandler(self.handler)\n    if self.handler and self.close:\n      self.handler.close()\n\n\nclass GradientInspectionCallback(lightning.Callback):\n    def __init__(self, num_grads_log):\n        self.num_grads_log = 10\n\n    def on_before_optimizer_step(self, trainer, pl_module, optimizer):\n      gradients = []\n      for name, param in pl_module.backbone.blocks.named_parameters():\n          gradients.append(param.grad.view(-1))\n\n      if gradients:\n        grads = torch.cat((gradients))\n        if not hasattr(pl_module, 'grad_accum_buffer'):\n          
pl_module.grad_step = torch.tensor(\n            0, device=pl_module.device)\n          pl_module.grad_accum_buffer = torch.zeros(\n            self.num_grads_log,\n            grads.shape[0],\n            device=pl_module.device)\n        pl_module.grad_accum_buffer[pl_module.grad_step] = grads\n        pl_module.grad_step += 1\n\n      if (hasattr(pl_module, 'grad_accum_buffer') \n          and pl_module.grad_step == self.num_grads_log):\n        grads = pl_module.grad_accum_buffer\n        grad_var = grads.std(0).mean()\n        pl_module.log(name='trainer/grad_var',\n                      value=grad_var.item(),\n                      on_step=True,\n                      on_epoch=False,\n                      sync_dist=True)\n        # import ipdb; ipdb.set_trace()\n        # should save the grads tensor as a numpy array\n        # and visualize mean, median, top-k\n        pl_module.grad_accum_buffer.zero_()\n        pl_module.grad_step = 0\n\n\ndef get_logger(name=__name__, level=logging.INFO) -> logging.Logger:\n  \"\"\"Initializes multi-GPU-friendly python logger.\"\"\"\n\n  logger = logging.getLogger(name)\n  logger.setLevel(level)\n\n  # this ensures all logging levels get marked with the rank zero decorator\n  # otherwise logs would get multiplied for each GPU process in multi-GPU setup\n  for level in ('debug', 'info', 'warning', 'error',\n                'exception', 'fatal', 'critical'):\n    setattr(logger,\n            level,\n            lightning.pytorch.utilities.rank_zero_only(\n              getattr(logger, level)))\n\n  return logger\n\n\n# Copied from https://github.com/jdeschena/sdtt/blob/bbc54d5b3c5fcffd79602cff17ed34dde1f3eff6/src/sdtt/core/sampling/utils.py#L10\ndef top_k_top_p_filtering(\n    logits,\n    top_k=0,\n    top_p=0.0,\n    filter_value=-float(\"Inf\"),\n    dim=-1):\n    \"\"\"Filter a distribution of logits using top-k/top-p (nucleus) filtering.\n    Adapted from 
https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317\n\n    Args:\n      logits (Tensor): Tensor of logits\n      top_k (int, optional): Number of top values to keep.\n          Deactivated if k is 0. Defaults to 0.\n      top_p (float, optional): Cumulative mass to retain.\n          Deactivated if p = 0. Defaults to 0.0.\n      filter_value (float, optional): Fill value to replace\n          the entries removed by top-k/top-p filtering.\n          Defaults to -float('Inf').\n      dim (int, optional): Dimension of the filtering. Defaults to -1.\n\n    Returns:\n        logits: Tensor whose axis `dim` was filtered.\n    \"\"\"\n    if dim != -1:\n      logits = torch.transpose(logits, dim, -1)\n\n    assert top_k < logits.size(dim)\n    if top_k > 0:\n      # Remove all tokens with a probability less than\n      # the last token of the top-k\n      values, _ = torch.topk(logits, k=top_k, dim=-1)\n      to_remove_mask = (\n          logits < torch.min(values, dim=-1, keepdim=True)[0]\n      )  # min returns a tuple (values, indices)\n      logits[to_remove_mask] = filter_value\n\n    if top_p > 0.0:\n      sorted_logits, sorted_indices = torch.sort(\n        logits, descending=True, dim=-1)\n      cum_probs = torch.cumsum(\n        torch.softmax(sorted_logits, dim=-1), dim=-1)\n\n      sorted_indices_to_remove = cum_probs > top_p\n      # Ensures at least one token is kept\n      sorted_indices_to_remove[..., 1:] = \\\n        sorted_indices_to_remove[..., :-1].clone()\n      sorted_indices_to_remove[..., 0] = 0\n\n      mask_to_remove = torch.empty_like(\n        sorted_indices_to_remove)\n      mask_to_remove.scatter_(dim=-1,\n                              index=sorted_indices,\n                              src=sorted_indices_to_remove)\n      logits[mask_to_remove] = filter_value\n\n    if dim != -1:\n      logits = torch.transpose(logits, dim, -1)\n    \n    # Re-normalize as in ReMDM. Alternatively, \n    #  could apply a log-softmax. 
This assumes that the input\n    #  tensor `logits` has been pre-processed with log_softmax.\n    probs = logits.exp()\n    Z = probs.sum(-1, keepdim=True)\n    logits = (probs / Z).log()\n    return logits\n\n\ndef _discrete_prob_map(gamma_t, N=10):\n  snr_sqrt = np.exp(-gamma_t / 2)\n  def value(x):\n    cdf = norm.cdf(x, scale=1) ** (N - 1)\n    pdf = norm.pdf(x, loc=snr_sqrt, scale=1)\n    return pdf * cdf\n  return value\n\n\ndef _discrete_prob_grad(gamma_t, N=10):\n  snr_sqrt = np.exp(-gamma_t / 2)\n  def value(x):\n    coef = -0.5 * snr_sqrt * (x - snr_sqrt)\n    cdf = norm.cdf(x, scale=1) ** (N - 1)\n    pdf = norm.pdf(x, loc=snr_sqrt, scale=1)\n    return coef * pdf * cdf\n  return value\n\n\ndef _cache_prob_usdm_in_partition(\n  vocab_size=30522, partition_index=0, num_partitions=1,\n  log10_num_points=5):\n  print(f'Caching partition:{partition_index} / {num_partitions}')\n  path = 'integral'\n  gamma_min = -5\n  gamma_max = -1\n  num_points = 10 ** log10_num_points\n  p_cache = []\n  grad_p_cache = []\n  start_time = time.time()\n  gammas = np.linspace(gamma_min, gamma_max, num_points)\n  n = num_points // num_partitions\n  for gamma in gammas[partition_index * n:\n                      (partition_index + 1) * n]:\n    pt, _ = quad(_discrete_prob_map(gamma, vocab_size),\n                 -np.inf, np.inf)\n    p_cache.append(pt)\n    grad_pt, _ = quad(_discrete_prob_grad(gamma, vocab_size),\n                      -np.inf, np.inf)\n    grad_p_cache.append(grad_pt)\n    if len(p_cache) % 100 == 0:\n      print('{}% completed. 
Time elapsed:{:.2f} mins'.format(\n        int(100 * len(p_cache) / num_points),\n        (time.time() - start_time) / 60))\n\n  filename = os.path.join(\n    path, '{}_{}_{}-{}.pkl'.format(\n      vocab_size, log10_num_points, partition_index,\n      num_partitions))\n  with open(filename, 'wb') as f:\n    pickle.dump({\n      'vocab_size': vocab_size,\n      'gamma_min': gamma_min,\n      'gamma_max': gamma_max,\n      'num_points': num_points,\n      'pt': np.asarray(p_cache),\n      'grad_pt': np.asarray(grad_p_cache)}, f)\n\n\ndef test_cache_prob_usdm_in_partition(\n  partition_index=0, num_partitions=1, vocab_size=30522,\n  log10_num_points=5):\n  path = 'integral/{}_{}_{}-{}.pkl'.format(\n    vocab_size, log10_num_points, partition_index,\n    num_partitions)\n  with open(path, 'rb') as f:\n    data = pickle.load(f)\n  num_points = data['num_points']\n  def _get_index(x):\n    return round((num_points - 1) * (x - data['gamma_min']) / (\n      data['gamma_max'] - data['gamma_min']))\n\n  pt_errors = []\n  grad_pt_errors = []\n  gammas = np.linspace(data['gamma_min'],\n                       data['gamma_max'],\n                       num_points)\n  n = num_points // num_partitions\n  for gamma in gammas[partition_index * n:\n                      (partition_index + 1) * n]:\n    pt, _ = quad(\n      _discrete_prob_map(gamma, data['vocab_size']),\n      -np.inf, np.inf)\n    grad_pt, _ = quad(\n      _discrete_prob_grad(gamma, data['vocab_size']),\n      -np.inf, np.inf)\n    idx = _get_index(gamma)\n    print(idx)\n    pt_errors.append((pt - data['pt'][idx]) ** 2)\n    grad_pt_errors.append((grad_pt - data['grad_pt'][idx]) ** 2)\n  print('Integral MSE:{} Integral Squared:{:.4f}'.format(\n    np.mean(pt_errors), np.mean(data['pt'] ** 2)))\n  print('Integral Grad MSE:{} Integral Grad Squared:{:.4f}'.format(\n    np.mean(grad_pt_errors), np.mean(data['grad_pt'] ** 2)))\n\n\ndef compute_duo_series_coefficients(num_coefficients, \n                                   
 vocab_size):\n  def integrand_m(z, n, K):\n      z = np.float64(z)\n      return z**n * norm.pdf(z) * norm.cdf(z)**(K-1)\n    \n  def integrand_i(z, n, K):\n    z = np.float64(z)\n    return z ** (n+1) * norm.pdf(z) * norm.cdf(z) ** (K-1)\n  \n  arange = np.cumprod(np.arange(1, num_coefficients\n                                ).astype(np.float64))\n  factorials = np.concatenate([[1.0], arange], \n                              dtype=np.float64)\n  lo = np.array([-100], dtype=np.float64)\n  hi = np.array([100], dtype=np.float64)\n  coefficients_m = []\n  coefficients_i = []\n\n  for n in range(num_coefficients):\n    f = lambda z: integrand_m(np.float64(z), np.float64(n), \n                              vocab_size)\n    g = lambda z: integrand_i(np.float64(z), np.float64(n), \n                              vocab_size)\n    m, _ = quad(f, lo, hi)\n    i, _ = quad(g, lo, hi)\n    coefficients_m.append(m / factorials[n])\n    coefficients_i.append(i / factorials[n])\n\n  return (torch.tensor(coefficients_m)[None], \n          torch.tensor(coefficients_i)[None])\n\n\ndef compute_duo_gamma_to_alpha_dalpha_series(\n    gamma_t, coefficients_m, coefficients_i, power_arange, \n    vocab_size, gamma_min, gamma_max):\n  gamma_t = gamma_t.to(torch.float64)[:, None]\n    \n  sigmoid_neg_gamma = torch.sigmoid(-gamma_t)\n  sigmoid_gamma = torch.sigmoid(gamma_t)\n  alpha_t_squared = sigmoid_neg_gamma\n  alpha_t = alpha_t_squared.sqrt()\n\n  one_minus_alpha_t_squared = 1 - alpha_t_squared\n  mu_t = alpha_t / one_minus_alpha_t_squared.sqrt()\n  \n  arange = power_arange.to(device=gamma_t.device)\n  mu_t_pow = mu_t ** arange\n  \n  exp_term = (-mu_t**2/2).exp().squeeze(-1)\n  vocab_scale = vocab_size / (vocab_size - 1)\n  \n  # Compute alpha\n  sum_term_alpha = (mu_t_pow * coefficients_m).sum(-1)\n  alpha_usdm = (sum_term_alpha * exp_term - 1 \\\n              / vocab_size) * vocab_scale\n\n  # Compute alpha'\n  sum_term_dalpha = (\n    mu_t_pow * (coefficients_i - mu_t * 
coefficients_m)).sum(-1)\n  dalpha_usdm = exp_term * sum_term_dalpha * vocab_scale\n  \n  final_scale = - (sigmoid_gamma.squeeze(-1) \n                  * sigmoid_neg_gamma.squeeze(-1) ** 0.5 * \n                  0.5 * (gamma_max - gamma_min))\n  dalpha_usdm = dalpha_usdm \\\n              / one_minus_alpha_t_squared.squeeze(-1) ** 1.5 \\\n              * final_scale\n  return alpha_usdm.squeeze(-1), dalpha_usdm.squeeze(-1)\n\n\ndef duo_t_to_alpha_dalpha_sigm_corrected(\n  t, a: float, b: float, c: float, d: float, e: float, \n  f: float, alpha: float):\n  # Shared quantities\n  sigm_bc = (torch.tanh(b * t + c) + 1) / 2\n  sigm_ef = (torch.tanh(e * t + f) + 1) / 2\n\n  # Compute alpha_t\n  base = a * sigm_bc + d\n  edge_gate = 1 - 4 * sigm_ef * (1 - sigm_ef)\n  edge_correction = alpha * (t - 0.5) * edge_gate\n  alpha_t = base + edge_correction\n\n  # Compute d_alpha_t\n  dbase = a * b * sigm_bc * (1 - sigm_bc)\n  dgate = -4 * e * sigm_ef * (1 - sigm_ef) * (1 - 2 * sigm_ef)\n  dcorrection = alpha * edge_gate + alpha * (t - 0.5) * dgate\n  dalpha_t = dbase + dcorrection\n  return alpha_t, dalpha_t\n\n\ndef duo_to_alpha_dalpha_sigmoid(t: torch.Tensor, a: float, \n                                b: float, c: float, d: float):\n  sigm_bc = (torch.tanh(b * t + c) + 1) / 2\n  alpha_t = a * sigm_bc + d\n  dalpha_t = a * b * sigm_bc * (1 - sigm_bc)\n  return alpha_t, dalpha_t\n\n\ndef duo_to_alpha_dalpha_poly(t: torch.Tensor, \n                              *coefficients: float):\n  alpha_t = coefficients[0]  # a0 term\n  for i, a in enumerate(coefficients[1:], 1):\n    alpha_t = alpha_t + a * t**i\n    \n  dalpha_t = coefficients[1]  # a1 term\n  for i, a in enumerate(coefficients[2:], 2):\n    dalpha_t = dalpha_t + i * a * t**(i-1)\n    \n  return alpha_t, dalpha_t\n\n\ndef compute_duo_operator_approx(num_coefficients, vocab_size, \n                                gamma_min, gamma_max, \n                                fct_name='sigmoid'):\n  series_m, series_i = 
compute_duo_series_coefficients(\n    num_coefficients, vocab_size)\n  ts = torch.linspace(0, 1, steps=100_000)\n  gammas = gamma_min + ts * (gamma_max - gamma_min)\n  power_arange = torch.arange(num_coefficients, \n                              dtype=torch.float64)[None]\n  alpha_approx = compute_duo_gamma_to_alpha_dalpha_series(\n    gammas, series_m, series_i, power_arange, vocab_size, \n    gamma_min, gamma_max)[0].float()\n  \n  t_np = ts.numpy()\n  y_np = alpha_approx.numpy()\n\n  def sigmoid(x):\n    return (np.tanh(x) + 1) / 2\n  \n  if fct_name == 'sigmoid':\n    def func(t, a, b, c, d):\n      return a * sigmoid(b * t + c) + d\n    p0 = [0.5, 2.0, -1.0, 0.5]\n\n  elif fct_name == 'sigmoid-edge-corrected':\n    def func(t, a, b, c, d, e, f, alpha):\n      base = a * sigmoid(b * t + c) + d\n      edge_gate = 1 - 4 * sigmoid(e * t + f) \\\n                  * sigmoid(-e * t - f)\n      edge_correction = alpha * (t - 0.5) * edge_gate\n      return base + edge_correction\n    p0 = [0.5, 2.0, -1.0, 0.1, 3.0, 0.0, 0.1]\n  elif fct_name == 'poly3':\n    def func(t, a0, a1, a2, a3):\n      return a0 + a1*t + a2*t**2 + a3*t**3\n    p0 = [0.1] * 4\n  elif fct_name == 'poly5':\n    def func(t, a0, a1, a2, a3, a4, a5):\n      return a0 + a1*t + a2*t**2 + a3*t**3 + a4*t**4 + a5*t**5\n    p0 = [0.1] * 6\n  elif fct_name == 'poly7':\n    def func(t, a0, a1, a2, a3, a4, a5, a6, a7):\n      return (a0 + a1*t + a2*t**2 + a3*t**3 + \n              a4*t**4 + a5*t**5 + a6*t**6 + a7*t**7)\n    p0 = [0.1] * 8\n  elif fct_name == 'poly9':\n    def func(t, a0, a1, a2, a3, a4, a5, a6, a7, a8, a9):\n      return (a0 + a1*t + a2*t**2 + a3*t**3 + a4*t**4 + \n              a5*t**5 + a6*t**6 + a7*t**7 + a8*t**8 + a9*t**9)\n    p0 = [0.1] * 10\n  else:\n    raise ValueError(fct_name)\n  \n  popt, _ = curve_fit(func, t_np, y_np, p0=p0, maxfev=10000)\n  preds = func(t_np, *popt)\n  return list(popt), y_np, preds, t_np\n\n\ndef _sample_k_int(bs: int, l: int, k: int, max_value: int, \n       
          device: torch.device):\n  # Robert Floyd's algorithm: \n  #  https://www.nowherenearithaca.com/2013/05/robert-floyds-tiny-and-beautiful.html\n  out = torch.empty(size=(bs, l, k), dtype=torch.int64, \n                    device=device)\n  for t, i in enumerate(range(max_value - k, max_value)):\n    j = torch.randint(0, i + 1, size=(bs, l), device=device)\n    if t > 0:\n      # Does j already appear in previously chosen slots?\n      dup = (out[..., :t] == j[..., None]).any(dim=-1)\n      # write j where it is new, otherwise write i\n      out[..., t] = torch.where(dup, i, j)\n    else:\n      out[..., 0] = j\n  return out\n\n\ndef _sample_topk_gaussian(N: int, \n  sigma: Optional[torch.Tensor]=None, l: int=0, k: int=0, \n  batch: int=None, device: str=None, \n  dtype: torch.dtype=torch.float64):\n  \"\"\"\n  Sample from the order statistic of N Gaussians with zero\n  mean (top k). Operate in logspace for stability.\n  \"\"\"\n  if sigma is None:\n    assert batch is not None\n    assert device is not None\n    assert dtype is not None\n  else:\n    batch = sigma.shape[0]\n    device = sigma.device\n    dtype = sigma.dtype\n  log_u = torch.log(torch.rand(batch, l, k, device=device, \n                               dtype=dtype))\n  divisors = torch.arange(N, N - k, -1, device=device, \n                          dtype=dtype)  # (k,)\n  log_rj = log_u / divisors  # (batch, l, k)\n  log_prod = torch.cumsum(log_rj, dim=-1)  # (batch, l, k)\n  uniforms = torch.exp(log_prod)  # (batch, l, k)\n  # convert to Gaussian and rescale\n  topk = torch.special.ndtri(uniforms)\n  if sigma is not None:\n    topk = topk * sigma[:, None, None]\n  return topk\n\n\ndef _sample_topk_and_extra(N: int, alpha: torch.Tensor, \n  sigma: torch.Tensor, l: int, k: int):\n  \"\"\"\n  Sample the top k order statistics between N - 1 zero mean \n  Gaussians, and a single Gaussian with mean alpha.\n  \"\"\"\n  top_k_others = _sample_topk_gaussian(N - 1, sigma, l, k)\n  extra = alpha[:, None] 
+ torch.randn(\n    size=(alpha.shape[0], l), device=alpha.device\n    ) * sigma[:, None]  # (bs, l)\n  min_values = top_k_others[:, :, -1]\n  is_extra_in_topk = (extra > min_values)  # bs x l\n  top_k_others[:, :, -1][is_extra_in_topk] = extra[is_extra_in_topk]\n  return extra, top_k_others, is_extra_in_topk\n\n\ndef _log_mean_exp_trunc_normal(c: torch.Tensor, \n                              sigma: torch.Tensor):\n  \"\"\"\n  Compute log(E[exp(X) | X < c] for X ~ N(0, sigma^2).\n  Closed-form expression:\n    mu = exp(sigma**2 / 2) \n         * Phi((c - sigma**2) / sigma) \n         / Phi(c / sigma)\n  where Phi is the standard normal CDF. Operate in log space\n  for stability.\n  \"\"\"\n  log_num = torch.special.log_ndtr((c - sigma**2) / sigma)\n  log_den = torch.special.log_ndtr(c / sigma)\n  return sigma**2 / 2.0 + log_num - log_den\n\n\ndef sample_tempered_softmax_topk(\n  extra_index: torch.Tensor, \n  alpha: torch.Tensor, \n  sigma: torch.Tensor, \n  l: int, \n  k: int, \n  vocab_size: int,\n  # 1 / T. 
If low temperature, inverse will be large\n  inverse_temperature: float = 1.0):\n  assert alpha.ndim == 1\n  assert sigma.ndim == 1\n  # float64 needed for numerical precision\n  alpha = alpha.to(torch.float64)\n  sigma = sigma.to(torch.float64)\n  # Sample the top k between (vocab_size - 1) zero-mean\n  #  Gaussians, and a single Gaussian with mean alpha.\n  extra, top_k, is_extra_in_topk = _sample_topk_and_extra(\n    vocab_size, alpha, sigma, l, k)\n  min_rv = torch.where(is_extra_in_topk, top_k[:, :, -2],\n                         top_k[:, :, -1])  # (bs, l)\n  \n  scaled_sigma = sigma[:, None] * inverse_temperature  # (bs, 1)\n  scaled_c = min_rv * inverse_temperature  # (bs, l)\n\n  log_mu = _log_mean_exp_trunc_normal(scaled_c, scaled_sigma)\n  log_topk = top_k * inverse_temperature\n\n  # Approximate contribution of the unknown variables\n  count = torch.where(is_extra_in_topk, vocab_size - k,\n            vocab_size - k - 1).to(log_mu.dtype)\n  log_tail = torch.log(count) + log_mu\n\n  # Contribution of the extra variable, when NOT in the topk\n  log_extra = extra * inverse_temperature\n  extra_not_in_topk = ~is_extra_in_topk\n  log_extra_masked = torch.full_like(log_tail, float('-inf'))\n  log_extra_masked[extra_not_in_topk] = \\\n                                log_extra[extra_not_in_topk]\n\n  log_contribs = torch.cat([\n      log_topk, \n      log_tail[..., None], \n      log_extra_masked[..., None]], \n    dim=-1)  # (bs, l, k+2)\n  log_denom = torch.logsumexp(log_contribs, dim=-1, \n                              keepdim=True)  # (bs, l, 1)\n  softmax_approx = torch.exp(log_topk - log_denom)  # (bs, l, k)\n  # If sum over k categories is zero, just set to one-hot on \n  #  the largest. 
Fix div by zero\n  normalizer = softmax_approx.sum(dim=-1, keepdim=True)\n  zero_sum = normalizer == 0.0\n  softmax_approx = torch.where(zero_sum, 0.0, softmax_approx)\n  softmax_approx[..., 0][zero_sum[..., 0]] = 1.0\n  \n  indices = _sample_k_int(alpha.shape[0], l, k, \n                         # Note the -1:\n                         max_value=vocab_size - 1,\n                         device=alpha.device)\n  # Ensure x0 (true token) is not generated\n  indices[indices >= extra_index[..., None]] += 1\n  indices[..., -1][is_extra_in_topk] = \\\n                              extra_index[is_extra_in_topk]\n  xt_usdm = torch.where(is_extra_in_topk, extra_index, \n                        indices[..., 0])\n  return softmax_approx, indices, xt_usdm\n\n\nif __name__ == \"__main__\":\n  # Usage: python utils.py --vocab_size=N\n  parser = argparse.ArgumentParser(\n    description='Caches the integral appearing in the '\n                'Diffusion Transformation operator.')\n  parser.add_argument(\n    '--vocab_size',\n    type=int,\n    default=50257,  # For the gpt2 tokenizer\n    help='Vocabulary size (default: 50257)')\n  parser.add_argument(\n    '--partition_index',\n    type=int,\n    default=0,\n    help='Helps parallelize caching')\n  parser.add_argument(\n    '--num_partitions',\n    type=int,\n    default=1,\n    help='Helps parallelize caching')\n  parser.add_argument(\n    '--log10_num_points',\n    type=int,\n    default=5,\n    help=('The integral is function that needs to be '\n          'evaluated for inputs with a range [-5, 1]. 
'\n          'This argument represents the logarithm base 10 '\n          'of number of bins of discretization.'))\n  args = parser.parse_args()\n\n  # Computing the integral over [-5, 1] can be slow,\n  # so one might prefer splitting it into `num_partitions`\n  # bins and compute each separately and merge them later.\n  _cache_prob_usdm_in_partition(\n    partition_index=args.partition_index,\n    num_partitions=args.num_partitions,\n    vocab_size=args.vocab_size,\n    log10_num_points=args.log10_num_points)\n  \n  test_cache_prob_usdm_in_partition(\n    partition_index=args.partition_index,\n    num_partitions=args.num_partitions,\n    vocab_size=args.vocab_size,\n    log10_num_points=args.log10_num_points)"
  }
]