[
  {
    "path": ".gitignore",
    "content": "# WandB\n/wandb/\n# checkpoints\n/checkpoints/\n# cache\n/cache/\n# data\n/data/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"dataset/raw-data/ConceptARC\"]\n\tpath = dataset/raw-data/ConceptARC\n\turl = git@github.com:victorvikram/ConceptARC.git\n[submodule \"dataset/raw-data/ARC-AGI\"]\n\tpath = dataset/raw-data/ARC-AGI\n\turl = git@github.com:fchollet/ARC-AGI.git\n[submodule \"dataset/raw-data/ARC-AGI-2\"]\n\tpath = dataset/raw-data/ARC-AGI-2\n\turl = git@github.com:arcprize/ARC-AGI-2.git\n"
  },
  {
    "path": ".vscode/launch.json",
    "content": "{\n    // Use IntelliSense to learn about possible attributes.\n    // Hover to view descriptions of existing attributes.\n    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387\n    \"version\": \"0.2.0\",\n    \"configurations\": [\n        {\n            \"name\": \"Python Debugger: Current File\",\n            \"type\": \"debugpy\",\n            \"request\": \"launch\",\n            \"program\": \"${file}\",\n            \"console\": \"integratedTerminal\"\n        },\n        {\n            \"name\": \"Debug: Single GPU\",\n            \"type\": \"debugpy\",\n            \"request\": \"launch\",\n            \"program\": \"pretrain.py\",\n            \"args\": [],\n            \"env\": {\n                \"OMP_NUM_THREADS\": \"1\",\n                \"DISABLE_COMPILE\": \"true\"\n            }\n        }\n    ]\n}"
  },
  {
    "path": ".vscode/settings.json",
    "content": "{\n    \"python.analysis.typeCheckingMode\": \"standard\"\n}"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License."
  },
  {
    "path": "README.md",
    "content": "# Hierarchical Reasoning Model\n\n![](./assets/hrm.png)\n\nReasoning, the process of devising and executing complex goal-oriented action sequences, remains a critical challenge in AI.\nCurrent large language models (LLMs) primarily employ Chain-of-Thought (CoT) techniques, which suffer from brittle task decomposition, extensive data requirements, and high latency. Inspired by the hierarchical and multi-timescale processing in the human brain, we propose the Hierarchical Reasoning Model (HRM), a novel recurrent architecture that attains significant computational depth while maintaining both training stability and efficiency.\nHRM executes sequential reasoning tasks in a single forward pass without explicit supervision of the intermediate process, through two interdependent recurrent modules: a high-level module responsible for slow, abstract planning, and a low-level module handling rapid, detailed computations. With only 27 million parameters, HRM achieves exceptional performance on complex reasoning tasks using only 1000 training samples. The model operates without pre-training or CoT data, yet achieves nearly perfect performance on challenging tasks including complex Sudoku puzzles and optimal path finding in large mazes.\nFurthermore, HRM outperforms much larger models with significantly longer context windows on the Abstraction and Reasoning Corpus (ARC), a key benchmark for measuring artificial general intelligence capabilities.\nThese results underscore HRM’s potential as a transformative advancement toward universal computation and general-purpose reasoning systems.\n\n**Join our Discord Community: [https://discord.gg/sapient](https://discord.gg/sapient)**\n\n\n## Quick Start Guide 🚀\n\n### Prerequisites ⚙️\n\nEnsure PyTorch and CUDA are installed. The repo needs CUDA extensions to be built. 
If not present, run the following commands:\n\n```bash\n# Install CUDA 12.6\nCUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run\n\nwget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL\nsudo sh cuda_installer.run --silent --toolkit --override\n\nexport CUDA_HOME=/usr/local/cuda-12.6\n\n# Install PyTorch with CUDA 12.6\nPYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu126\n\npip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL\n\n# Additional packages for building extensions\npip3 install packaging ninja wheel setuptools setuptools-scm\n```\n\nThen install FlashAttention. For Hopper GPUs, install FlashAttention 3\n\n```bash\ngit clone git@github.com:Dao-AILab/flash-attention.git\ncd flash-attention/hopper\npython setup.py install\n```\n\nFor Ampere or earlier GPUs, install FlashAttention 2\n\n```bash\npip3 install flash-attn\n```\n\n## Install Python Dependencies 🐍\n\n```bash\npip install -r requirements.txt\n```\n\n## W&B Integration 📈\n\nThis project uses [Weights & Biases](https://wandb.ai/) for experiment tracking and metric visualization. Ensure you're logged in:\n\n```bash\nwandb login\n```\n\n## Run Experiments\n\n### Quick Demo: Sudoku Solver 💻🗲\n\nTrain a master-level Sudoku AI capable of solving extremely difficult puzzles on a modern laptop GPU. 
🧩\n\n```bash\n# Download and build Sudoku dataset\npython dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000  --subsample-size 1000 --num-aug 1000\n\n# Start training (single GPU, smaller batch size)\nOMP_NUM_THREADS=8 python pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 global_batch_size=384 lr=7e-5 puzzle_emb_lr=7e-5 weight_decay=1.0 puzzle_emb_weight_decay=1.0\n```\n\nRuntime: ~10 hours on an RTX 4070 laptop GPU\n\n## Trained Checkpoints 🚧\n\n - [ARC-AGI-2](https://huggingface.co/sapientinc/HRM-checkpoint-ARC-2)\n - [Sudoku 9x9 Extreme (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-sudoku-extreme)\n - [Maze 30x30 Hard (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-maze-30x30-hard)\n\nTo use the checkpoints, see the Evaluation section below.\n\n## Full-scale Experiments 🔵\n\nExperiments below assume an 8-GPU setup.\n\n### Dataset Preparation\n\n```bash\n# Initialize submodules\ngit submodule update --init --recursive\n\n# ARC-1\npython dataset/build_arc_dataset.py  # ARC official + ConceptARC, 960 examples\n# ARC-2\npython dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000  # ARC-2 official, 1120 examples\n\n# Sudoku-Extreme\npython dataset/build_sudoku_dataset.py  # Full version\npython dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000  --subsample-size 1000 --num-aug 1000  # 1000 examples\n\n# Maze\npython dataset/build_maze_dataset.py  # 1000 examples\n```\n\n### Dataset Visualization\n\nExplore the puzzles visually:\n\n* Open `puzzle_visualizer.html` in your browser.\n* Upload the generated dataset folder located in `data/...`.\n\n## Launch experiments\n\n### Small-sample (1K)\n\nARC-1:\n\n```bash\nOMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py \n```\n\n*Runtime:* ~24 hours\n\nARC-2:\n\n```bash\nOMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py 
data_path=data/arc-2-aug-1000\n```\n\n*Runtime:* ~24 hours (checkpoint after 8 hours is often sufficient)\n\nSudoku Extreme (1k):\n\n```bash\nOMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0\n```\n\n*Runtime:* ~10 minutes\n\nMaze 30x30 Hard (1k):\n\n```bash\nOMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/maze-30x30-hard-1k epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0\n```\n\n*Runtime:* ~1 hour\n\n### Full Sudoku-Hard\n\n```bash\nOMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-hard-full epochs=100 eval_interval=10 lr_min_ratio=0.1 global_batch_size=2304 lr=3e-4 puzzle_emb_lr=3e-4 weight_decay=0.1 puzzle_emb_weight_decay=0.1 arch.loss.loss_type=softmax_cross_entropy arch.L_cycles=8 arch.halt_max_steps=8 arch.pos_encodings=learned\n```\n\n*Runtime:* ~2 hours\n\n## Evaluation\n\nEvaluate your trained models:\n\n* Check `eval/exact_accuracy` in W&B.\n* For ARC-AGI, follow these additional steps:\n\n```bash\nOMP_NUM_THREADS=8 torchrun --nproc-per-node 8 evaluate.py checkpoint=<CHECKPOINT_PATH>\n```\n\n* Then use the provided `arc_eval.ipynb` notebook to finalize and inspect your results.\n\n## Notes\n\n - Small-sample learning typically exhibits accuracy variance of around ±2 points.\n - For Sudoku-Extreme (1,000-example dataset), late-stage overfitting may cause numerical instability during training and Q-learning. 
It is advisable to use early stopping once the training accuracy approaches 100%.\n\n## Citation 📜\n\n```bibtex\n@misc{wang2025hierarchicalreasoningmodel,\n      title={Hierarchical Reasoning Model}, \n      author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},\n      year={2025},\n      eprint={2506.21734},\n      archivePrefix={arXiv},\n      primaryClass={cs.AI},\n      url={https://arxiv.org/abs/2506.21734}, \n}\n```\n"
  },
  {
    "path": "arc_eval.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import json\\n\",\n    \"from glob import glob\\n\",\n    \"import hashlib\\n\",\n    \"import matplotlib.pyplot as plt\\n\",\n    \"import matplotlib.colors as mcolors\\n\",\n    \"\\n\",\n    \"import torch\\n\",\n    \"import torch.nn.functional as F\\n\",\n    \"import numpy as np\\n\",\n    \"from numba import njit\\n\",\n    \"\\n\",\n    \"from dataset.common import inverse_dihedral_transform\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"DATASET_PATH = \\\"data/arc-aug-1000\\\"  # ARC-1\\n\",\n    \"# DATASET_PATH = \\\"data/arc-2-aug-1000\\\"  # ARC-2\\n\",\n    \"\\n\",\n    \"CHECKPOINT_PATH = \\\"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\\\"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"PAD_PUZZLE_IDENTIFIER = 0\\n\",\n    \"\\n\",\n    \"# Visualization\\n\",\n    \"ARC_COLOR_MAP = mcolors.ListedColormap([\\n\",\n    \"    \\\"#000000\\\",  # symbol_0: black\\n\",\n    \"    \\\"#0074D9\\\",  # symbol_1: blue\\n\",\n    \"    \\\"#FF4136\\\",  # symbol_2: red\\n\",\n    \"    \\\"#2ECC40\\\",  # symbol_3: green\\n\",\n    \"    \\\"#FFDC00\\\",  # symbol_4: yellow\\n\",\n    \"    \\\"#AAAAAA\\\",  # symbol_5: grey\\n\",\n    \"    \\\"#F012BE\\\",  # symbol_6: fuschia\\n\",\n    \"    \\\"#FF851B\\\",  # symbol_7: orange\\n\",\n    \"    \\\"#7FDBFF\\\",  # symbol_8: teal\\n\",\n    \"    \\\"#870C25\\\"   # symbol_9: brown\\n\",\n    \"])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\\n\",\n    \"    # Load puzzle identifiers\\n\",\n    \"    with open(os.path.join(dataset_path, \\\"identifiers.json\\\"), \\\"r\\\") as f:\\n\",\n    \"        
identifier_map = json.load(f)\\n\",\n    \"        \\n\",\n    \"    # Load preds\\n\",\n    \"    all_preds = {}\\n\",\n    \"    for filename in glob(f\\\"{checkpoint_path}_all_preds.*\\\"):\\n\",\n    \"        preds = torch.load(filename)\\n\",\n    \"        for k, v in preds.items():\\n\",\n    \"            all_preds.setdefault(k, [])\\n\",\n    \"            all_preds[k].append(v)\\n\",\n    \"            \\n\",\n    \"        del preds\\n\",\n    \"\\n\",\n    \"    all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\\n\",\n    \"    \\n\",\n    \"    # Remove paddings\\n\",\n    \"    mask = all_preds[\\\"puzzle_identifiers\\\"] != PAD_PUZZLE_IDENTIFIER\\n\",\n    \"    all_preds = {k: v[mask] for k, v in all_preds.items()}\\n\",\n    \"\\n\",\n    \"    return identifier_map, all_preds\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def inverse_aug(name: str, grid: np.ndarray):\\n\",\n    \"    if \\\"_\\\" not in name:\\n\",\n    \"        return grid\\n\",\n    \"\\n\",\n    \"    trans_id, perm = name.split(\\\"_\\\")[-2:]\\n\",\n    \"    trans_id = int(trans_id[1:])  # Remove \\\"t\\\" letter\\n\",\n    \"    inv_perm = np.argsort(list(perm))\\n\",\n    \"    \\n\",\n    \"    return inv_perm[inverse_dihedral_transform(grid, trans_id)]\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def grid_hash(grid: np.ndarray):\\n\",\n    \"    return hash((grid.tobytes(), grid.shape))\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"@njit\\n\",\n    \"def crop(grid: np.ndarray):\\n\",\n    \"    # Find maximum-sized rectangle without any EOS token inside.\\n\",\n    \"    grid = grid.reshape(30, 30)\\n\",\n    \"\\n\",\n    \"    max_area = 0\\n\",\n    \"    max_size = (0, 0)\\n\",\n    \"    nr, nc = grid.shape\\n\",\n    \"    \\n\",\n    \"    num_c = nc\\n\",\n    \"    for num_r in range(1, nr + 1):\\n\",\n    \"        # Scan for maximum c\\n\",\n    \"        for c in range(1, num_c + 1):\\n\",\n    \"            x = grid[num_r - 1, c - 1]\\n\",\n    \"        
    if (x < 2) | (x > 11):\\n\",\n    \"                num_c = c - 1\\n\",\n    \"                break\\n\",\n    \"        \\n\",\n    \"        area = num_r * num_c\\n\",\n    \"        if area > max_area:\\n\",\n    \"            max_area = area\\n\",\n    \"            max_size = (num_r, num_c)\\n\",\n    \"\\n\",\n    \"    return grid[:max_size[0], :max_size[1]] - 2\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def test(visualize, Ks=[1, 2, 10, 100, 1000]):\\n\",\n    \"    identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\\n\",\n    \"    \\n\",\n    \"    global_hmap = {}\\n\",\n    \"    \\n\",\n    \"    # Get puzzles and corresponding answers\\n\",\n    \"    puzzle_labels = {}\\n\",\n    \"    for identifier, input, label in zip(all_preds[\\\"puzzle_identifiers\\\"], all_preds[\\\"inputs\\\"], all_preds[\\\"labels\\\"]):\\n\",\n    \"        name = identifier_map[identifier]\\n\",\n    \"        if \\\"_\\\" not in name:   # Not-augmented\\n\",\n    \"            puzzle_labels.setdefault(name, {})\\n\",\n    \"            \\n\",\n    \"            input = crop(input.numpy())\\n\",\n    \"            label = crop(label.numpy())\\n\",\n    \"\\n\",\n    \"            input_hash = grid_hash(input)\\n\",\n    \"            label_hash = grid_hash(label)\\n\",\n    \"\\n\",\n    \"            global_hmap[input_hash] = input\\n\",\n    \"            global_hmap[label_hash] = label\\n\",\n    \"\\n\",\n    \"            assert input_hash not in puzzle_labels[name]\\n\",\n    \"            puzzle_labels[name][input_hash] = label_hash\\n\",\n    \"            \\n\",\n    \"    print (\\\"Number of puzzles\\\", len(puzzle_labels))\\n\",\n    \"    \\n\",\n    \"    # Argmax prediction\\n\",\n    \"    preds = all_preds[\\\"logits\\\"].argmax(-1)\\n\",\n    \"\\n\",\n    \"    # Collate\\n\",\n    \"    pred_answers = {}\\n\",\n    \"    for identifier, input, pred, q in zip(all_preds[\\\"puzzle_identifiers\\\"], 
all_preds[\\\"inputs\\\"], preds, all_preds[\\\"q_halt_logits\\\"].sigmoid()):\\n\",\n    \"        name = identifier_map[identifier]\\n\",\n    \"        orig_name = name.split(\\\"_\\\")[0]\\n\",\n    \"        \\n\",\n    \"        input = input.numpy()\\n\",\n    \"        input_hash = grid_hash(inverse_aug(name, crop(input)))\\n\",\n    \"        assert input_hash in puzzle_labels[orig_name]\\n\",\n    \"        \\n\",\n    \"        pred = inverse_aug(name, crop(pred.numpy()))\\n\",\n    \"        pred_hash = grid_hash(pred)\\n\",\n    \"        global_hmap[pred_hash] = pred\\n\",\n    \"        \\n\",\n    \"        pred_answers.setdefault(orig_name, {})\\n\",\n    \"        pred_answers[orig_name].setdefault(input_hash, [])\\n\",\n    \"        pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\\n\",\n    \"\\n\",\n    \"    # test-1\\n\",\n    \"    if visualize:\\n\",\n    \"        num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\\n\",\n    \"        fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\\n\",\n    \"        \\n\",\n    \"        fig_id = 0\\n\",\n    \"    \\n\",\n    \"    correct = [0 for _ in range(len(Ks))]\\n\",\n    \"    for name, tests in puzzle_labels.items():\\n\",\n    \"        num_test_correct = [0 for _ in range(len(Ks))]\\n\",\n    \"        for input_hash, label_hash in tests.items():\\n\",\n    \"            p = pred_answers[name][input_hash]\\n\",\n    \"            p_map = {}\\n\",\n    \"            \\n\",\n    \"            for h, q in p:\\n\",\n    \"                p_map.setdefault(h, [0, 0])\\n\",\n    \"                p_map[h][0] += 1\\n\",\n    \"                p_map[h][1] += q\\n\",\n    \"                \\n\",\n    \"            for h, stats in p_map.items():\\n\",\n    \"                stats[1] /= stats[0]\\n\",\n    \"                \\n\",\n    \"            p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\\n\",\n    \"\\n\",\n    \"          
  # 2-vote\\n\",\n    \"            for i, k in enumerate(Ks):\\n\",\n    \"                ok = False\\n\",\n    \"                for h, stats in p_map[:k]:\\n\",\n    \"                    ok |= h == label_hash\\n\",\n    \"                    \\n\",\n    \"                num_test_correct[i] += ok\\n\",\n    \"\\n\",\n    \"            if visualize:\\n\",\n    \"                # Show input and ground truth\\n\",\n    \"                axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\\n\",\n    \"                axes[fig_id, 0].set_title(f\\\"{name}\\\\nInput\\\")\\n\",\n    \"                axes[fig_id, 0].axis('off')\\n\",\n    \"                \\n\",\n    \"                axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\\n\",\n    \"                axes[fig_id, 1].set_title(f\\\"{name}\\\\nAnswer\\\")\\n\",\n    \"                axes[fig_id, 1].axis('off')\\n\",\n    \"                \\n\",\n    \"                trial_id = 2\\n\",\n    \"                for h, stats in p_map[:2]:\\n\",\n    \"                    ans = global_hmap[h]\\n\",\n    \"                    \\n\",\n    \"                    axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\\n\",\n    \"                    axes[fig_id, trial_id].set_title(f\\\"{name}\\\\nTrial {trial_id}\\\")\\n\",\n    \"                    axes[fig_id, trial_id].axis('off')\\n\",\n    \"                    \\n\",\n    \"                    trial_id += 1\\n\",\n    \"                \\n\",\n    \"                fig_id += 1\\n\",\n    \"            \\n\",\n    \"        # Total correctness\\n\",\n    \"        for i in range(len(Ks)):\\n\",\n    \"            correct[i] += num_test_correct[i] == len(tests)\\n\",\n    \"\\n\",\n    \"    for i, k in enumerate(Ks):\\n\",\n    \"        print (f\\\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"test(visualize=False)\"\n   ]\n  }\n ],\n \"metadata\": {\n  
\"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.12.10\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "assets/npyjs.js",
    "content": "class npyjs {\n\n    constructor(opts) {\n        if (opts && !('convertFloat16' in opts)) {\n            console.warn([\n                \"npyjs constructor now accepts {convertFloat16?: boolean}.\",\n                \"For usage, go to https://github.com/jhuapl-boss/npyjs.\"\n            ].join(\" \"));\n        }\n\n        this.convertFloat16 = opts?.convertFloat16 ?? true;\n\n        this.dtypes = {\n            \"<u1\": {\n                name: \"uint8\",\n                size: 8,\n                arrayConstructor: Uint8Array,\n            },\n            \"|u1\": {\n                name: \"uint8\",\n                size: 8,\n                arrayConstructor: Uint8Array,\n            },\n            \"<u2\": {\n                name: \"uint16\",\n                size: 16,\n                arrayConstructor: Uint16Array,\n            },\n            \"|i1\": {\n                name: \"int8\",\n                size: 8,\n                arrayConstructor: Int8Array,\n            },\n            \"<i2\": {\n                name: \"int16\",\n                size: 16,\n                arrayConstructor: Int16Array,\n            },\n            \"<u4\": {\n                name: \"uint32\",\n                size: 32,\n                arrayConstructor: Uint32Array,\n            },\n            \"<i4\": {\n                name: \"int32\",\n                size: 32,\n                arrayConstructor: Int32Array,\n            },\n            \"<u8\": {\n                name: \"uint64\",\n                size: 64,\n                arrayConstructor: BigUint64Array,\n            },\n            \"<i8\": {\n                name: \"int64\",\n                size: 64,\n                arrayConstructor: BigInt64Array,\n            },\n            \"<f4\": {\n                name: \"float32\",\n                size: 32,\n                arrayConstructor: Float32Array\n            },\n            \"<f8\": {\n                name: \"float64\",\n                size: 
64,\n                arrayConstructor: Float64Array\n            },\n            \"<f2\": {\n                name: \"float16\",\n                size: 16,\n                arrayConstructor: Uint16Array,\n                converter: this.convertFloat16 ? this.float16ToFloat32Array : undefined\n            },\n        };\n    }\n\n    float16ToFloat32Array(float16Array) {\n        const length = float16Array.length;\n        const float32Array = new Float32Array(length);\n        \n        for (let i = 0; i < length; i++) {\n            float32Array[i] = npyjs.float16ToFloat32(float16Array[i]);\n        }\n        \n        return float32Array;\n    }\n\n    static float16ToFloat32(float16) {\n        // Extract the parts of the float16\n        const sign = (float16 >> 15) & 0x1;\n        const exponent = (float16 >> 10) & 0x1f;\n        const fraction = float16 & 0x3ff;\n\n        // Handle special cases\n        if (exponent === 0) {\n            if (fraction === 0) {\n                // Zero\n                return sign ? -0 : 0;\n            }\n            // Denormalized number\n            return (sign ? -1 : 1) * Math.pow(2, -14) * (fraction / 0x400);\n        } else if (exponent === 0x1f) {\n            if (fraction === 0) {\n                // Infinity\n                return sign ? -Infinity : Infinity;\n            }\n            // NaN\n            return NaN;\n        }\n\n        // Normalized number\n        return (sign ? 
-1 : 1) * Math.pow(2, exponent - 15) * (1 + fraction / 0x400);\n    }\n\n    parse(arrayBufferContents) {\n        // const version = arrayBufferContents.slice(6, 8); // Uint8-encoded\n        const headerLength = new DataView(arrayBufferContents.slice(8, 10)).getUint8(0);\n        const offsetBytes = 10 + headerLength;\n\n        const hcontents = new TextDecoder(\"utf-8\").decode(\n            new Uint8Array(arrayBufferContents.slice(10, 10 + headerLength))\n        );\n        const header = JSON.parse(\n            hcontents\n                .toLowerCase() // True -> true\n                .replace(/'/g, '\"')\n                .replace(\"(\", \"[\")\n                .replace(/,*\\),*/g, \"]\")\n        );\n        const shape = header.shape;\n        const dtype = this.dtypes[header.descr];\n\n        if (!dtype) {\n            console.error(`Unsupported dtype: ${header.descr}`);\n            return null;\n        }\n\n        const nums = new dtype.arrayConstructor(\n            arrayBufferContents,\n            offsetBytes\n        );\n\n        // Convert float16 to float32 if converter exists\n        const data = dtype.converter ? 
dtype.converter.call(this, nums) : nums;\n\n        return {\n            dtype: dtype.name,\n            data: data,\n            shape,\n            fortranOrder: header.fortran_order\n        };\n    }\n\n    async load(filename, callback, fetchArgs) {\n        /*\n        Loads an array from a stream of bytes.\n        */\n        fetchArgs = fetchArgs || {};\n        let arrayBuf;\n        // If filename is ArrayBuffer\n        if (filename instanceof ArrayBuffer) {\n            arrayBuf = filename;\n        }\n        // If filename is a file path\n        else {\n            const resp = await fetch(filename, { ...fetchArgs });\n            arrayBuf = await resp.arrayBuffer();\n        }\n        const result = this.parse(arrayBuf);\n        if (callback) {\n            return callback(result);\n        }\n        return result;\n    }\n}\n"
  },
  {
    "path": "config/arch/hrm_v1.yaml",
    "content": "name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1\nloss:\n  name: losses@ACTLossHead\n  loss_type: stablemax_cross_entropy\n\nhalt_exploration_prob: 0.1\nhalt_max_steps: 16\n\nH_cycles: 2\nL_cycles: 2\n\nH_layers: 4\nL_layers: 4\n\nhidden_size: 512\nnum_heads: 8  # max(2, hidden_size // 64)\nexpansion: 4\n\npuzzle_emb_ndim: ${.hidden_size}\n\npos_encodings: rope\n"
  },
  {
    "path": "config/cfg_pretrain.yaml",
    "content": "# ARC training config\n\ndefaults:\n  - arch: hrm_v1\n  - _self_\n\nhydra:\n  output_subdir: null\n\n# Data path\ndata_path: data/arc-aug-1000\n\n# Hyperparams - Training\nglobal_batch_size: 768\n\nepochs: 100000\neval_interval: 10000\ncheckpoint_every_eval: True\n\nlr: 1e-4\nlr_min_ratio: 1.0\nlr_warmup_steps: 2000\n\n# Standard hyperparameter settings for LM, as used in Llama\nbeta1: 0.9\nbeta2: 0.95\nweight_decay: 0.1\npuzzle_emb_weight_decay: 0.1\n\n# Hyperparams - Puzzle embeddings training\npuzzle_emb_lr: 1e-2\n"
  },
  {
    "path": "dataset/build_arc_dataset.py",
    "content": "from typing import List, Optional, Tuple, Dict\nfrom dataclasses import dataclass\nfrom pathlib import Path\nimport os\nimport json\nimport hashlib\nimport numpy as np\nfrom glob import glob\n\nfrom argdantic import ArgParser\nfrom pydantic import BaseModel\n\nfrom common import PuzzleDatasetMetadata, dihedral_transform\n\n\ncli = ArgParser()\n\n\nclass DataProcessConfig(BaseModel):\n    # ARC-1\n    dataset_dirs: List[str] = [\"dataset/raw-data/ARC-AGI/data\", \"dataset/raw-data/ConceptARC/corpus\"]\n    output_dir: str = \"data/arc-aug-1000\"\n    \n    # ARC-2\n    # dataset_dirs: List[str] = [\"dataset/raw-data/ARC-AGI-2/data\"]\n    # output_dir: str = \"data/arc-2-aug-1000\"\n\n    seed: int = 42\n    num_aug: int = 1000\n    \n    \nARCMaxGridSize = 30\nARCAugmentRetriesFactor = 5\n    \n\n@dataclass\nclass ARCPuzzle:\n    id: str\n\n    examples: List[Tuple[np.ndarray, np.ndarray]]\n\n    \ndef arc_grid_to_np(grid: List[List[int]]):\n    arr = np.array(grid)\n\n    # Shape check\n    assert arr.ndim == 2\n    assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize\n    # Element check\n    assert np.all((arr >= 0) & (arr <= 9))\n    return arr.astype(np.uint8)\n\n\ndef np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):\n    # PAD: 0, <eos>: 1, digits: 2 ... 
11\n    # Compute random top-left pad\n    if do_translation:\n        pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)\n        pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)\n    else:\n        pad_r = pad_c = 0\n\n    # Pad grid\n    result = []\n    for grid in [inp, out]:\n        nrow, ncol = grid.shape\n        grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)\n\n        # Add <eos>\n        eos_row, eos_col = pad_r + nrow, pad_c + ncol\n        if eos_row < ARCMaxGridSize:\n            grid[eos_row, pad_c:eos_col] = 1\n        if eos_col < ARCMaxGridSize:\n            grid[pad_r:eos_row, eos_col] = 1\n\n        result.append(grid.flatten())\n\n    return result\n\n\ndef puzzle_hash(puzzle: dict):\n    # Hash the puzzle for checking equivalence\n    def _grid_hash(grid: np.ndarray):\n        buffer = [x.to_bytes(1) for x in grid.shape]\n        buffer.append(grid.tobytes())\n        \n        return hashlib.sha256(b\"\".join(buffer)).hexdigest()\n    \n    hashes = []\n    for example_type, example in puzzle.items():\n        for input, label in example.examples:\n            hashes.append(f\"{_grid_hash(input)}|{_grid_hash(label)}\")\n            \n    hashes.sort()\n    return hashlib.sha256(\"|\".join(hashes).encode()).hexdigest()\n\n\ndef convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):\n    # Remove \"name\"\n    name = puzzle.pop(\"name\", default_name)\n    \n    # Convert\n    dests = set(dest_mapping.values())\n    converted = {dest: ARCPuzzle(name, []) for dest in dests}\n    for example_type, examples in puzzle.items():\n        dest = dest_mapping[example_type]\n        converted[dest].examples.extend([(arc_grid_to_np(example[\"input\"]), arc_grid_to_np(example[\"output\"])) for example in examples])\n\n    group = 
[converted]\n    \n    # Augment\n    if aug_count > 0:\n        hashes = {puzzle_hash(converted)}\n\n        for _trial in range(ARCAugmentRetriesFactor * aug_count):\n            # Augment plan\n            trans_id = np.random.randint(0, 8)\n            mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))])  # Permute colors, Excluding \"0\" (black)\n            \n            aug_repr = f\"t{trans_id}_{''.join(str(x) for x in mapping)}\"\n\n            def _map_grid(grid: np.ndarray):\n                return dihedral_transform(mapping[grid], trans_id)\n            \n            # Check duplicate\n            augmented = {dest: ARCPuzzle(f\"{puzzle.id}_{aug_repr}\", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}\n            h = puzzle_hash(augmented)\n            if h not in hashes:\n                hashes.add(h)\n                group.append(augmented)\n                \n            if len(group) >= aug_count + 1:\n                break\n            \n        if len(group) < aug_count + 1:\n            print (f\"[Puzzle {name}] augmentation not full, only {len(group)}\")\n\n    # Append\n    for dest in dests:\n        # Convert the examples\n        dest_split, dest_set = dest\n\n        results.setdefault(dest_split, {})\n        results[dest_split].setdefault(dest_set, [])\n        results[dest_split][dest_set].append([converted[dest] for converted in group])\n\n\ndef load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig):\n    train_examples_dest = (\"train\", \"all\")\n    test_examples_map = {\n        \"evaluation\": [(1.0, (\"test\", \"all\"))],\n        \"_default\": [(1.0, (\"train\", \"all\"))]\n    }\n    \n    total_puzzles = 0\n    for subdir in os.scandir(dataset_path):\n        if subdir.is_dir():\n            # Load all puzzles in this directory\n            puzzles = []\n            for 
filename in glob(os.path.join(subdir.path, \"*.json\")):\n                with open(filename, \"r\") as f:\n                    puzzles.append((Path(filename).stem, json.load(f)))\n                    \n            # Shuffle puzzles\n            np.random.shuffle(puzzles)\n            \n            # Assign by fraction\n            for idx, (default_name, puzzle) in enumerate(puzzles):\n                fraction = idx / len(puzzles)\n                test_examples_dest = None\n                for f, dest in test_examples_map.get(subdir.name, test_examples_map[\"_default\"]):\n                    if fraction < f:\n                        test_examples_dest = dest\n                        break\n                        \n                assert test_examples_dest is not None\n                \n                convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {\"train\": train_examples_dest, \"test\": test_examples_dest})\n                total_puzzles += 1\n\n    print (f\"[{dataset_path}] total puzzles: {total_puzzles}\")\n\n\ndef convert_dataset(config: DataProcessConfig):\n    np.random.seed(config.seed)\n    \n    # Read dataset\n    data = {}\n    for dataset_dir in config.dataset_dirs:\n        load_puzzles_arcagi(data, dataset_dir, config)\n    \n    # Map global puzzle identifiers\n    num_identifiers = 1  # 0 is blank\n    identifier_map = {}\n    for split_name, split in data.items():\n        for subset_name, subset in split.items():\n            for group in subset:\n                for puzzle in group:\n                    if puzzle.id not in identifier_map:\n                        identifier_map[puzzle.id] = num_identifiers\n                        num_identifiers += 1\n\n    print (f\"Total puzzle IDs (including <blank>): {num_identifiers}\")\n\n    # Save\n    for split_name, split in data.items():\n        os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)\n        \n        # Translational augmentations\n    
    enable_translational_augment = split_name == \"train\"\n\n        # Statistics\n        total_examples = 0\n        total_puzzles = 0\n        total_groups = 0\n        \n        for subset_name, subset in split.items():\n            # Construct subset\n            results = {k: [] for k in [\"inputs\", \"labels\", \"puzzle_identifiers\", \"puzzle_indices\", \"group_indices\"]}\n            results[\"puzzle_indices\"].append(0)\n            results[\"group_indices\"].append(0)\n            \n            example_id = 0\n            puzzle_id = 0\n            \n            for group in subset:\n                for puzzle in group:\n                    # Push puzzle\n                    no_aug_id = np.random.randint(0, len(puzzle.examples))\n                    for _idx_ex, (inp, out) in enumerate(puzzle.examples):\n                        inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)\n                            \n                        results[\"inputs\"].append(inp)\n                        results[\"labels\"].append(out)\n                        example_id += 1\n                        \n                        total_examples += 1\n\n                    results[\"puzzle_indices\"].append(example_id)\n                    results[\"puzzle_identifiers\"].append(identifier_map[puzzle.id])\n                    \n                    puzzle_id += 1\n                    \n                    total_puzzles += 1\n                    \n                # Push group\n                results[\"group_indices\"].append(puzzle_id)\n                total_groups += 1\n            \n            for k, v in results.items():\n                if k in {\"inputs\", \"labels\"}:\n                    v = np.stack(v, 0)\n                else:\n                    v = np.array(v, dtype=np.int32)\n                \n                np.save(os.path.join(config.output_dir, split_name, 
f\"{subset_name}__{k}.npy\"), v)\n        \n        # Metadata\n        metadata = PuzzleDatasetMetadata(\n            seq_len=ARCMaxGridSize * ARCMaxGridSize,\n            vocab_size=10 + 2,  # PAD + EOS + \"0\" ... \"9\"\n            \n            pad_id=0,\n            ignore_label_id=0,\n            \n            blank_identifier_id=0,\n            num_puzzle_identifiers=num_identifiers,\n            \n            total_groups=total_groups,\n            mean_puzzle_examples=total_examples / total_puzzles,\n            sets=list(split.keys())\n        )\n\n        # Save metadata as JSON.\n        with open(os.path.join(config.output_dir, split_name, \"dataset.json\"), \"w\") as f:\n            json.dump(metadata.model_dump(), f)\n            \n    # Save IDs mapping\n    with open(os.path.join(config.output_dir, \"identifiers.json\"), \"w\") as f:\n        ids_mapping = {v: k for k, v in identifier_map.items()}\n        \n        json.dump([ids_mapping.get(i, \"<blank>\") for i in range(num_identifiers)], f)\n\n\n@cli.command(singleton=True)\ndef main(config: DataProcessConfig):\n    convert_dataset(config)\n\n\nif __name__ == \"__main__\":\n    cli()\n"
  },
  {
    "path": "dataset/build_maze_dataset.py",
    "content": "from typing import Optional\nimport math\nimport os\nimport csv\nimport json\nimport numpy as np\n\nfrom argdantic import ArgParser\nfrom pydantic import BaseModel\nfrom tqdm import tqdm\nfrom huggingface_hub import hf_hub_download\n\nfrom common import PuzzleDatasetMetadata, dihedral_transform\n\n\nCHARSET = \"# SGo\"\n\n\ncli = ArgParser()\n\n\nclass DataProcessConfig(BaseModel):\n    source_repo: str = \"sapientinc/maze-30x30-hard-1k\"\n    output_dir: str = \"data/maze-30x30-hard-1k\"\n\n    subsample_size: Optional[int] = None\n    aug: bool = False\n\n\ndef convert_subset(set_name: str, config: DataProcessConfig):\n    # Read CSV\n    all_chars = set()\n    grid_size = None\n    inputs = []\n    labels = []\n    \n    with open(hf_hub_download(config.source_repo, f\"{set_name}.csv\", repo_type=\"dataset\"), newline=\"\") as csvfile:  # type: ignore\n        reader = csv.reader(csvfile)\n        next(reader)  # Skip header\n        for source, q, a, rating in reader:\n            all_chars.update(q)\n            all_chars.update(a)\n\n            if grid_size is None:\n                n = int(len(q) ** 0.5)\n                grid_size = (n, n)\n                \n            inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))\n            labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))\n\n    # If subsample_size is specified for the training set,\n    # randomly sample the desired number of examples.\n    if set_name == \"train\" and config.subsample_size is not None:\n        total_samples = len(inputs)\n        if config.subsample_size < total_samples:\n            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)\n            inputs = [inputs[i] for i in indices]\n            labels = [labels[i] for i in indices]\n\n    # Generate dataset\n    results = {k: [] for k in [\"inputs\", \"labels\", \"puzzle_identifiers\", \"puzzle_indices\", \"group_indices\"]}\n 
   puzzle_id = 0\n    example_id = 0\n    \n    results[\"puzzle_indices\"].append(0)\n    results[\"group_indices\"].append(0)\n    \n    for inp, out in zip(tqdm(inputs), labels):\n        # Dihedral transformations for augmentation\n        for aug_idx in range(8 if (set_name == \"train\" and config.aug) else 1):\n            results[\"inputs\"].append(dihedral_transform(inp, aug_idx))\n            results[\"labels\"].append(dihedral_transform(out, aug_idx))\n            example_id += 1\n            puzzle_id += 1\n            \n            results[\"puzzle_indices\"].append(example_id)\n            results[\"puzzle_identifiers\"].append(0)\n            \n        # Push group\n        results[\"group_indices\"].append(puzzle_id)\n            \n    # Char mappings\n    assert len(all_chars - set(CHARSET)) == 0\n    \n    char2id = np.zeros(256, np.uint8)\n    char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1\n\n    # To Numpy\n    def _seq_to_numpy(seq):\n        arr = np.vstack([char2id[s.reshape(-1)] for s in seq])\n        \n        return arr\n    \n    results = {\n        \"inputs\": _seq_to_numpy(results[\"inputs\"]),\n        \"labels\": _seq_to_numpy(results[\"labels\"]),\n        \n        \"group_indices\": np.array(results[\"group_indices\"], dtype=np.int32),\n        \"puzzle_indices\": np.array(results[\"puzzle_indices\"], dtype=np.int32),\n        \"puzzle_identifiers\": np.array(results[\"puzzle_identifiers\"], dtype=np.int32),\n    }\n\n    # Metadata\n    metadata = PuzzleDatasetMetadata(\n        seq_len=int(math.prod(grid_size)),  # type: ignore\n        vocab_size=len(CHARSET) + 1,  # PAD + Charset\n        \n        pad_id=0,\n        ignore_label_id=0,\n        \n        blank_identifier_id=0,\n        num_puzzle_identifiers=1,\n        \n        total_groups=len(results[\"group_indices\"]) - 1,\n        mean_puzzle_examples=1,\n        sets=[\"all\"]\n    )\n\n    # Save metadata as JSON.\n    save_dir = 
os.path.join(config.output_dir, set_name)\n    os.makedirs(save_dir, exist_ok=True)\n    \n    with open(os.path.join(save_dir, \"dataset.json\"), \"w\") as f:\n        json.dump(metadata.model_dump(), f)\n        \n    # Save data\n    for k, v in results.items():\n        np.save(os.path.join(save_dir, f\"all__{k}.npy\"), v)\n        \n    # Save IDs mapping (for visualization only)\n    with open(os.path.join(config.output_dir, \"identifiers.json\"), \"w\") as f:\n        json.dump([\"<blank>\"], f)\n\n\n@cli.command(singleton=True)\ndef preprocess_data(config: DataProcessConfig):\n    convert_subset(\"train\", config)\n    convert_subset(\"test\", config)\n\n\nif __name__ == \"__main__\":\n    cli()\n"
  },
  {
    "path": "dataset/build_sudoku_dataset.py",
    "content": "from typing import Optional\nimport os\nimport csv\nimport json\nimport numpy as np\n\nfrom argdantic import ArgParser\nfrom pydantic import BaseModel\nfrom tqdm import tqdm\nfrom huggingface_hub import hf_hub_download\n\nfrom common import PuzzleDatasetMetadata\n\n\ncli = ArgParser()\n\n\nclass DataProcessConfig(BaseModel):\n    source_repo: str = \"sapientinc/sudoku-extreme\"\n    output_dir: str = \"data/sudoku-extreme-full\"\n\n    subsample_size: Optional[int] = None\n    min_difficulty: Optional[int] = None\n    num_aug: int = 0\n\n\ndef shuffle_sudoku(board: np.ndarray, solution: np.ndarray):\n    # Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged\n    digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))\n    \n    # Randomly decide whether to transpose.\n    transpose_flag = np.random.rand() < 0.5\n\n    # Generate a valid row permutation:\n    # - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows.\n    bands = np.random.permutation(3)\n    row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])\n\n    # Similarly for columns (stacks).\n    stacks = np.random.permutation(3)\n    col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])\n\n    # Build an 81->81 mapping. 
For each new cell at (i, j)\n    # (row index = i // 9, col index = i % 9),\n    # its value comes from old row = row_perm[i//9] and old col = col_perm[i%9].\n    mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])\n\n    def apply_transformation(x: np.ndarray) -> np.ndarray:\n        # Apply transpose flag\n        if transpose_flag:\n            x = x.T\n        # Apply the position mapping.\n        new_board = x.flatten()[mapping].reshape(9, 9).copy()\n        # Apply digit mapping\n        return digit_map[new_board]\n\n    return apply_transformation(board), apply_transformation(solution)\n\n\ndef convert_subset(set_name: str, config: DataProcessConfig):\n    # Read CSV\n    inputs = []\n    labels = []\n    \n    with open(hf_hub_download(config.source_repo, f\"{set_name}.csv\", repo_type=\"dataset\"), newline=\"\") as csvfile:\n        reader = csv.reader(csvfile)\n        next(reader)  # Skip header\n        for source, q, a, rating in reader:\n            if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):\n                assert len(q) == 81 and len(a) == 81\n                \n                inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))\n                labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))\n\n    # If subsample_size is specified for the training set,\n    # randomly sample the desired number of examples.\n    if set_name == \"train\" and config.subsample_size is not None:\n        total_samples = len(inputs)\n        if config.subsample_size < total_samples:\n            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)\n            inputs = [inputs[i] for i in indices]\n            labels = [labels[i] for i in indices]\n\n    # Generate dataset\n    num_augments = config.num_aug if set_name == \"train\" else 0\n\n    results = {k: [] for k in [\"inputs\", \"labels\", 
\"puzzle_identifiers\", \"puzzle_indices\", \"group_indices\"]}\n    puzzle_id = 0\n    example_id = 0\n    \n    results[\"puzzle_indices\"].append(0)\n    results[\"group_indices\"].append(0)\n    \n    for orig_inp, orig_out in zip(tqdm(inputs), labels):\n        for aug_idx in range(1 + num_augments):\n            # First index is not augmented\n            if aug_idx == 0:\n                inp, out = orig_inp, orig_out\n            else:\n                inp, out = shuffle_sudoku(orig_inp, orig_out)\n\n            # Push puzzle (only single example)\n            results[\"inputs\"].append(inp)\n            results[\"labels\"].append(out)\n            example_id += 1\n            puzzle_id += 1\n            \n            results[\"puzzle_indices\"].append(example_id)\n            results[\"puzzle_identifiers\"].append(0)\n            \n        # Push group\n        results[\"group_indices\"].append(puzzle_id)\n        \n    # To Numpy\n    def _seq_to_numpy(seq):\n        arr = np.concatenate(seq).reshape(len(seq), -1)\n        \n        assert np.all((arr >= 0) & (arr <= 9))\n        return arr + 1\n    \n    results = {\n        \"inputs\": _seq_to_numpy(results[\"inputs\"]),\n        \"labels\": _seq_to_numpy(results[\"labels\"]),\n        \n        \"group_indices\": np.array(results[\"group_indices\"], dtype=np.int32),\n        \"puzzle_indices\": np.array(results[\"puzzle_indices\"], dtype=np.int32),\n        \"puzzle_identifiers\": np.array(results[\"puzzle_identifiers\"], dtype=np.int32),\n    }\n\n    # Metadata\n    metadata = PuzzleDatasetMetadata(\n        seq_len=81,\n        vocab_size=10 + 1,  # PAD + \"0\" ... 
\"9\"\n        \n        pad_id=0,\n        ignore_label_id=0,\n        \n        blank_identifier_id=0,\n        num_puzzle_identifiers=1,\n        \n        total_groups=len(results[\"group_indices\"]) - 1,\n        mean_puzzle_examples=1,\n        sets=[\"all\"]\n    )\n\n    # Save metadata as JSON.\n    save_dir = os.path.join(config.output_dir, set_name)\n    os.makedirs(save_dir, exist_ok=True)\n    \n    with open(os.path.join(save_dir, \"dataset.json\"), \"w\") as f:\n        json.dump(metadata.model_dump(), f)\n        \n    # Save data\n    for k, v in results.items():\n        np.save(os.path.join(save_dir, f\"all__{k}.npy\"), v)\n        \n    # Save IDs mapping (for visualization only)\n    with open(os.path.join(config.output_dir, \"identifiers.json\"), \"w\") as f:\n        json.dump([\"<blank>\"], f)\n\n\n@cli.command(singleton=True)\ndef preprocess_data(config: DataProcessConfig):\n    convert_subset(\"train\", config)\n    convert_subset(\"test\", config)\n\n\nif __name__ == \"__main__\":\n    cli()\n"
  },
  {
    "path": "dataset/common.py",
    "content": "from typing import List, Optional\n\nimport pydantic\nimport numpy as np\n\n\n# Global list mapping each dihedral transform id to its inverse.\n# Index corresponds to the original tid, and the value is its inverse.\nDIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]\n\n\nclass PuzzleDatasetMetadata(pydantic.BaseModel):\n    pad_id: int\n    ignore_label_id: Optional[int]\n    blank_identifier_id: int\n    \n    vocab_size: int\n    seq_len: int\n    num_puzzle_identifiers: int\n    \n    total_groups: int\n    mean_puzzle_examples: float\n\n    sets: List[str]\n\n\ndef dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:\n    \"\"\"8 dihedral symmetries by rotate, flip and mirror\"\"\"\n    \n    if tid == 0:\n        return arr  # identity\n    elif tid == 1:\n        return np.rot90(arr, k=1)\n    elif tid == 2:\n        return np.rot90(arr, k=2)\n    elif tid == 3:\n        return np.rot90(arr, k=3)\n    elif tid == 4:\n        return np.fliplr(arr)       # horizontal flip\n    elif tid == 5:\n        return np.flipud(arr)       # vertical flip\n    elif tid == 6:\n        return arr.T                # transpose (reflection along main diagonal)\n    elif tid == 7:\n        return np.fliplr(np.rot90(arr, k=1))  # anti-diagonal reflection\n    else:\n        return arr\n    \n    \ndef inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:\n    return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])\n"
  },
  {
    "path": "evaluate.py",
    "content": "from typing import List\nimport yaml\nimport os\n\nimport torch\nimport torch.distributed as dist\n\nimport pydantic\nfrom omegaconf import OmegaConf\nfrom pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader\n\n\nclass EvalConfig(pydantic.BaseModel):\n    checkpoint: str\n    \n    save_outputs: List[str] = [\"inputs\", \"labels\", \"puzzle_identifiers\", \"logits\", \"q_halt_logits\", \"q_continue_logits\"]\n\n\ndef launch():\n    eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli()))  # type: ignore\n    \n    RANK = 0\n    WORLD_SIZE = 1\n    # Initialize distributed training if in distributed environment (e.g. torchrun)\n    if \"LOCAL_RANK\" in os.environ:\n        # Initialize distributed, default device and dtype\n        dist.init_process_group(backend=\"nccl\")\n\n        RANK = dist.get_rank()\n        WORLD_SIZE = dist.get_world_size()\n\n        torch.cuda.set_device(int(os.environ[\"LOCAL_RANK\"]))\n\n    with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), \"all_config.yaml\"), \"r\") as f:\n        config = PretrainConfig(**yaml.safe_load(f))\n\n        config.eval_save_outputs = eval_cfg.save_outputs\n        config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint)\n\n    # Dataloader\n    train_loader, train_metadata = create_dataloader(config, \"train\", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)\n    eval_loader,  eval_metadata  = create_dataloader(config, \"test\", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)\n\n    # Models\n    train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)\n    # Try unwrap torch.compile\n    try:\n        train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location=\"cuda\"), assign=True)\n    except:\n        
train_state.model.load_state_dict({k.removeprefix(\"_orig_mod.\"): v for k, v in torch.load(eval_cfg.checkpoint, map_location=\"cuda\").items()}, assign=True)\n    \n    train_state.step = 0\n    ckpt_filename = os.path.basename(eval_cfg.checkpoint)\n    if ckpt_filename.startswith(\"step_\"):\n        train_state.step = int(ckpt_filename.removeprefix(\"step_\"))\n\n    # Evaluate\n    print (\"Starting evaluation\")\n    \n    train_state.model.eval()\n    metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)\n\n    if metrics is not None:\n        print (metrics)\n\n\nif __name__ == \"__main__\":\n    launch()\n"
  },
  {
    "path": "models/common.py",
    "content": "import math\n\nimport torch\nfrom torch import nn\n\n\ndef trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):\n    # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor\n    # This function is a PyTorch version of jax truncated normal init (default init method in flax)\n    # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848\n    # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199\n\n    with torch.no_grad():\n        if std == 0:\n            tensor.zero_()\n        else:\n            sqrt2 = math.sqrt(2)\n            a = math.erf(lower / sqrt2)\n            b = math.erf(upper / sqrt2)\n            z = (b - a) / 2\n\n            c = (2 * math.pi) ** -0.5\n            pdf_u = c * math.exp(-0.5 * lower ** 2)\n            pdf_l = c * math.exp(-0.5 * upper ** 2)\n            comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)\n\n            tensor.uniform_(a, b)\n            tensor.erfinv_()\n            tensor.mul_(sqrt2 * comp_std)\n            tensor.clip_(lower * comp_std, upper * comp_std)\n\n    return tensor\n"
  },
  {
    "path": "models/hrm/hrm_act_v1.py",
    "content": "from typing import Tuple, List, Dict, Optional\nfrom dataclasses import dataclass\nimport math\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom pydantic import BaseModel\n\nfrom models.common import trunc_normal_init_\nfrom models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear\nfrom models.sparse_embedding import CastedSparseEmbedding\n\n\n@dataclass\nclass HierarchicalReasoningModel_ACTV1InnerCarry:\n    z_H: torch.Tensor\n    z_L: torch.Tensor\n\n\n@dataclass\nclass HierarchicalReasoningModel_ACTV1Carry:\n    inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry\n    \n    steps: torch.Tensor\n    halted: torch.Tensor\n    \n    current_data: Dict[str, torch.Tensor]\n\n\nclass HierarchicalReasoningModel_ACTV1Config(BaseModel):\n    batch_size: int\n    seq_len: int\n    puzzle_emb_ndim: int = 0\n    num_puzzle_identifiers: int\n    vocab_size: int\n\n    H_cycles: int\n    L_cycles: int\n\n    H_layers: int\n    L_layers: int\n\n    # Transformer config\n    hidden_size: int\n    expansion: float\n    num_heads: int\n    pos_encodings: str\n\n    rms_norm_eps: float = 1e-5\n    rope_theta: float = 10000.0\n    \n    # Halting Q-learning config\n    halt_max_steps: int\n    halt_exploration_prob: float\n\n    forward_dtype: str = \"bfloat16\"\n\n\nclass HierarchicalReasoningModel_ACTV1Block(nn.Module):\n    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:\n        super().__init__()\n\n        self.self_attn = Attention(\n            hidden_size=config.hidden_size,\n            head_dim=config.hidden_size // config.num_heads,\n            num_heads=config.num_heads,\n            num_key_value_heads=config.num_heads,\n            causal=False\n        )\n        self.mlp = SwiGLU(\n            hidden_size=config.hidden_size,\n            expansion=config.expansion,\n        )\n        self.norm_eps = config.rms_norm_eps\n\n    def 
forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:\n        # Post Norm\n        # Self Attention\n        hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)\n        # Fully Connected\n        hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)\n        return hidden_states\n\n\nclass HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):\n    def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):\n        super().__init__()\n\n        self.layers = torch.nn.ModuleList(layers)\n\n    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:\n        # Input injection (add)\n        hidden_states = hidden_states + input_injection\n        # Layers\n        for layer in self.layers:\n            hidden_states = layer(hidden_states=hidden_states, **kwargs)\n\n        return hidden_states\n\n\nclass HierarchicalReasoningModel_ACTV1_Inner(nn.Module):\n    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:\n        super().__init__()\n        self.config = config\n        self.forward_dtype = getattr(torch, self.config.forward_dtype)\n\n        # I/O\n        self.embed_scale  = math.sqrt(self.config.hidden_size)\n        embed_init_std = 1.0 / self.embed_scale\n\n        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)\n        self.lm_head      = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)\n        self.q_head       = CastedLinear(self.config.hidden_size, 2, bias=True)\n\n        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div\n        if self.config.puzzle_emb_ndim > 0:\n            # Zero init puzzle embeddings\n            self.puzzle_emb = 
CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,\n                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)\n\n        # LM Blocks\n        if self.config.pos_encodings == \"rope\":\n            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,\n                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,\n                                              base=self.config.rope_theta)\n        elif self.config.pos_encodings == \"learned\":\n            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)\n        else:\n            raise NotImplementedError()\n\n        # Reasoning Layers\n        self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])\n        self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])\n        \n        # Initial states\n        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)\n        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)\n\n        # Q head special init\n        # Init Q to (almost) zero for faster learning during bootstrapping\n        with torch.no_grad():\n            self.q_head.weight.zero_()\n            self.q_head.bias.fill_(-5)  # type: ignore\n\n    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):\n        # Token embedding\n        embedding = self.embed_tokens(input.to(torch.int32))\n\n        # Puzzle embeddings\n   
     if self.config.puzzle_emb_ndim > 0:\n            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)\n            \n            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]\n            if pad_count > 0:\n                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))\n\n            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)\n\n        # Position embeddings\n        if self.config.pos_encodings == \"learned\":\n            # scale by 1/sqrt(2) to maintain forward variance\n            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))\n\n        # Scale\n        return self.embed_scale * embedding\n\n    def empty_carry(self, batch_size: int):\n        return HierarchicalReasoningModel_ACTV1InnerCarry(\n            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),\n            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),\n        )\n        \n    def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):\n        return HierarchicalReasoningModel_ACTV1InnerCarry(\n            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),\n            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),\n        )\n\n    def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        seq_info = dict(\n            cos_sin=self.rotary_emb() if hasattr(self, \"rotary_emb\") else None,\n        )\n\n        # Input encoding\n        input_embeddings = self._input_embeddings(batch[\"inputs\"], batch[\"puzzle_identifiers\"])\n\n        # 
Forward iterations\n        with torch.no_grad():\n            z_H, z_L = carry.z_H, carry.z_L\n\n            for _H_step in range(self.config.H_cycles):\n                for _L_step in range(self.config.L_cycles):\n                    if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):\n                        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)\n\n                if not (_H_step == self.config.H_cycles - 1):\n                    z_H = self.H_level(z_H, z_L, **seq_info)\n\n        assert not z_H.requires_grad and not z_L.requires_grad\n\n        # 1-step grad\n        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)\n        z_H = self.H_level(z_H, z_L, **seq_info)\n\n        # LM Outputs\n        new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad\n        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]\n\n        # Q head\n        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)\n        \n        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])\n\n\nclass HierarchicalReasoningModel_ACTV1(nn.Module):\n    \"\"\"ACT wrapper.\"\"\"\n\n    def __init__(self, config_dict: dict):\n        super().__init__()\n        self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)\n        self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)\n\n    @property\n    def puzzle_emb(self):\n        return self.inner.puzzle_emb\n\n    def initial_carry(self, batch: Dict[str, torch.Tensor]):\n        batch_size = batch[\"inputs\"].shape[0]\n\n        return HierarchicalReasoningModel_ACTV1Carry(\n            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected, it will be reset in the first pass as all sequences are halted.\n            \n            steps=torch.zeros((batch_size, ), dtype=torch.int32),\n            halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted\n      
      \n            current_data={k: torch.empty_like(v) for k, v in batch.items()}\n        )\n        \n    def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:\n        # Update data, carry (removing halted sequences)\n        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)\n        \n        new_steps = torch.where(carry.halted, 0, carry.steps)\n\n        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}\n\n        # Forward inner model\n        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)\n\n        outputs = {\n            \"logits\": logits,\n            \"q_halt_logits\": q_halt_logits,\n            \"q_continue_logits\": q_continue_logits\n        }\n        \n        with torch.no_grad():\n            # Step\n            new_steps = new_steps + 1\n            is_last_step = new_steps >= self.config.halt_max_steps\n            \n            halted = is_last_step\n\n            # if training, and ACT is enabled\n            if self.training and (self.config.halt_max_steps > 1):\n                # Halt signal\n                # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes\n                halted = halted | (q_halt_logits > q_continue_logits)\n\n                # Exploration\n                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)\n\n                halted = halted & (new_steps >= min_halt_steps)\n\n                # Compute target Q\n                # NOTE: No replay buffer and target networks for computing target Q-value.\n                # As batch_size is 
large, there are many parallel envs.\n                # Similar concept to PQN https://arxiv.org/abs/2407.04811\n                next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]\n                \n                outputs[\"target_q_continue\"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))\n\n        return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs\n"
  },
  {
    "path": "models/layers.py",
    "content": "from typing import Tuple\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\ntry:\n    from flash_attn_interface import flash_attn_func  # type: ignore[import]\nexcept ImportError:\n    # Fallback to FlashAttention 2\n    from flash_attn import flash_attn_func  # type: ignore[import]\n\nfrom models.common import trunc_normal_init_\n\n\nCosSin = Tuple[torch.Tensor, torch.Tensor]\n\n\ndef _find_multiple(a, b):\n    return (-(a // -b)) * b\n\n\ndef rotate_half(x: torch.Tensor):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):\n    # q, k: [bs, seq_len, num_heads, head_dim]\n    # cos, sin: [seq_len, head_dim]\n    orig_dtype = q.dtype\n    q = q.to(cos.dtype)\n    k = k.to(cos.dtype)\n\n    q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))\n    k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))\n\n    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)\n\n\nclass CastedLinear(nn.Module):\n    def __init__(self,\n                 in_features: int,\n                 out_features: int,\n                 bias: bool):\n        super().__init__()\n        # Truncated LeCun normal init\n        self.weight = nn.Parameter(\n            trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))\n        )\n        self.bias = None\n        if bias:\n            # Zero init bias\n            self.bias = nn.Parameter(torch.zeros((out_features, )))\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)\n\n\nclass CastedEmbedding(nn.Module):\n    def __init__(self,\n                 num_embeddings: int,\n  
               embedding_dim: int,\n                 init_std: float,\n                 cast_to: torch.dtype):\n        super().__init__()\n        self.cast_to = cast_to\n\n        # Truncated LeCun normal init\n        self.embedding_weight = nn.Parameter(\n            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)\n        )\n        \n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return F.embedding(input, self.embedding_weight.to(self.cast_to))\n\n\nclass RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings, base, device=None):\n        super().__init__()\n\n        # RoPE\n        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))\n        t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)\n        freqs = torch.outer(t, inv_freq)\n\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)\n        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)\n\n    def forward(self):\n        return self.cos_cached, self.sin_cached\n\n\nclass Attention(nn.Module):\n    def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):\n        super().__init__()\n\n        self.hidden_size = hidden_size\n        self.head_dim = head_dim\n        self.output_size = head_dim * num_heads\n        self.num_heads = num_heads\n        self.num_key_value_heads = num_key_value_heads\n        self.causal = causal\n\n        self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)\n        self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)\n\n    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:\n        batch_size, 
seq_len, _ = hidden_states.shape\n\n        # hidden_states: [bs, seq_len, num_heads, head_dim]\n        qkv = self.qkv_proj(hidden_states)\n\n        # Split head\n        qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)\n        query = qkv[:, :, :self.num_heads]\n        key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]\n        value = qkv[:, :, self.num_heads + self.num_key_value_heads:]\n\n        # RoPE\n        if cos_sin is not None:\n            cos, sin = cos_sin\n            query, key = apply_rotary_pos_emb(query, key, cos, sin)\n\n        # flash attn\n        attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)\n        if isinstance(attn_output, tuple):  # fa2 and fa3 compatibility\n            attn_output = attn_output[0]\n\n        attn_output = attn_output.view(batch_size, seq_len, self.output_size)  # type: ignore\n        return self.o_proj(attn_output)\n\n\nclass SwiGLU(nn.Module):\n    def __init__(self, hidden_size: int, expansion: float):\n        super().__init__()\n        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)\n\n        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)\n        self.down_proj    = CastedLinear(inter, hidden_size, bias=False)\n\n    def forward(self, x):\n        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)\n        return self.down_proj(F.silu(gate) * up)\n\n\ndef rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:\n    input_dtype = hidden_states.dtype\n    hidden_states = hidden_states.to(torch.float32)\n\n    variance = hidden_states.square().mean(-1, keepdim=True)\n    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)\n    return hidden_states.to(input_dtype)\n"
  },
  {
    "path": "models/losses.py",
    "content": "from typing import Any, Tuple, Dict, Sequence, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nIGNORE_LABEL_ID = -100\n\n\ndef s(x, epsilon=1e-30):\n    return torch.where(\n        x<0,\n        1/(1-x+ epsilon),\n        x + 1\n    )\n\n\ndef log_stablemax(x, dim=-1):\n    s_x = s(x)\n    return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))\n\n\ndef stablemax_cross_entropy(logits, labels, ignore_index: int = -100):\n    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)\n\n    valid_mask = labels != ignore_index\n    transformed_labels = torch.where(valid_mask, labels, 0)\n    prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)\n\n    return -torch.where(valid_mask, prediction_logprobs, 0)\n\n\ndef softmax_cross_entropy(logits, labels, ignore_index: int = -100):\n    # Cast logits to f32\n    # Flatten logits\n    return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction=\"none\").view(labels.shape)\n\n\nclass ACTLossHead(nn.Module):\n    def __init__(self, model: nn.Module, loss_type: str):\n        super().__init__()\n        self.model = model\n        self.loss_fn = globals()[loss_type]\n        \n    def initial_carry(self, *args, **kwargs):\n        return self.model.initial_carry(*args, **kwargs)  # type: ignore\n\n    def forward(\n        self,\n        return_keys: Sequence[str],\n        # Model args\n        **model_kwargs,\n    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:\n        # Model logits\n        # B x SeqLen x D\n        new_carry, outputs = self.model(**model_kwargs)\n        labels = new_carry.current_data[\"labels\"]\n\n        # Correctness\n        with torch.no_grad():\n            mask = labels != IGNORE_LABEL_ID\n            loss_counts = mask.sum(-1)\n    
        loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division\n\n            is_correct = mask & (torch.argmax(outputs[\"logits\"], dim=-1) == labels)\n            seq_is_correct = is_correct.sum(-1) == loss_counts\n            \n            # Metrics (halted)\n            valid_metrics = new_carry.halted & (loss_counts > 0)\n            metrics = {\n                \"count\": valid_metrics.sum(),\n                \n                \"accuracy\":       torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),\n                \"exact_accuracy\": (valid_metrics & seq_is_correct).sum(),\n\n                \"q_halt_accuracy\": (valid_metrics & ((outputs[\"q_halt_logits\"] >= 0) == seq_is_correct)).sum(),\n                \"steps\":          torch.where(valid_metrics, new_carry.steps, 0).sum(),\n            }\n\n        # Losses\n        # FIXME: Assuming the batch is always full\n        lm_loss = (self.loss_fn(outputs[\"logits\"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()\n        q_halt_loss = F.binary_cross_entropy_with_logits(outputs[\"q_halt_logits\"], seq_is_correct.to(outputs[\"q_halt_logits\"].dtype), reduction=\"sum\")\n\n        metrics.update({\n            \"lm_loss\": lm_loss.detach(),\n            \"q_halt_loss\": q_halt_loss.detach(),\n        })\n\n        # Q continue (bootstrapping target loss)\n        q_continue_loss = 0\n        if \"target_q_continue\" in outputs:\n            q_continue_loss = F.binary_cross_entropy_with_logits(outputs[\"q_continue_logits\"], outputs[\"target_q_continue\"], reduction=\"sum\")\n\n            metrics[\"q_continue_loss\"] = q_continue_loss.detach()\n\n        # Filter outputs for return\n        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}\n\n        return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()\n"
  },
  {
    "path": "models/sparse_embedding.py",
    "content": "from typing import Union\n\nimport torch\nfrom torch import nn\nimport torch.distributed as dist\nfrom torch.optim.optimizer import Optimizer, ParamsT\n\nfrom models.common import trunc_normal_init_\n\n\nclass CastedSparseEmbedding(nn.Module):\n    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):\n        super().__init__()\n        self.cast_to = cast_to\n\n        # Real Weights\n        # Truncated LeCun normal init\n        self.weights = nn.Buffer(\n            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True\n        )\n\n        # Local weights and IDs\n        # Local embeddings, with gradient, not persistent\n        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)\n        # Local embedding IDs, not persistent\n        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)\n\n    def forward(self, inputs: torch.Tensor) -> torch.Tensor:\n        if not self.training:\n            # Test mode, no gradient\n            return self.weights[inputs].to(self.cast_to)\n            \n        # Training mode, fill puzzle embedding from weights\n        with torch.no_grad():\n            self.local_weights.copy_(self.weights[inputs])\n            self.local_ids.copy_(inputs)\n\n        return self.local_weights.to(self.cast_to)\n\n\nclass CastedSparseEmbeddingSignSGD_Distributed(Optimizer):\n    def __init__(\n        self,\n        params: ParamsT,\n\n        world_size: int,\n        lr: Union[float, torch.Tensor] = 1e-3,\n        weight_decay: float = 1e-2,\n    ):\n        if not 0.0 <= lr:\n            raise ValueError(f\"Invalid learning rate: {lr}\")\n        if not 0.0 <= weight_decay:\n            raise ValueError(f\"Invalid weight_decay value: {weight_decay}\")\n\n        defaults = dict(\n            lr=lr,\n            
weight_decay=weight_decay,\n            world_size=world_size\n        )\n        super().__init__(params, defaults)\n\n    @torch.no_grad\n    def step(self, closure=None):  # type: ignore\n        for group in self.param_groups:\n            # Find the sparse embedding weights\n            local_weights_grad = None\n            local_ids = None\n            weights = None\n            \n            assert len(group[\"params\"]) == 3\n            for p in group[\"params\"]:\n                if p.requires_grad:\n                    local_weights_grad = p.grad\n                elif p.ndim == 1:\n                    local_ids = p\n                elif p.ndim == 2:\n                    weights = p\n                else:\n                    assert False\n                \n            assert local_weights_grad is not None\n            assert local_ids is not None\n            assert weights is not None\n        \n            # Apply SignSGD\n            # Adam ≈ SignSGD if gradient is very sparse\n            _sparse_emb_signsgd_dist(\n                local_weights_grad,\n                local_ids,\n                weights,\n                \n                lr=group[\"lr\"],\n                weight_decay=group[\"weight_decay\"],\n                world_size=group[\"world_size\"]\n            )\n\n\ndef _sparse_emb_signsgd_dist(\n    local_weights_grad: torch.Tensor,\n    local_ids: torch.Tensor,\n    weights: torch.Tensor,\n    \n    lr: float,\n    weight_decay: float,\n    world_size: int\n) -> None:\n    N, D = local_weights_grad.shape\n    \n    # All-gather\n    all_weights_grad = local_weights_grad\n    all_ids = local_ids\n\n    if world_size > 1:\n        all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)\n        all_ids = torch.empty(world_size * N,               dtype=local_ids.dtype,          device=local_ids.device)\n    \n        dist.all_gather_into_tensor(all_weights_grad, 
local_weights_grad)\n        dist.all_gather_into_tensor(all_ids,          local_ids)\n\n    # Unique\n    grad_ids, inv = all_ids.unique(return_inverse=True)\n\n    grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)\n    grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)\n\n    # SignSGD with decoupled weight decay\n    p = weights[grad_ids]\n\n    p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)\n\n    # Write updated slices back\n    weights[grad_ids] = p\n"
  },
  {
    "path": "pretrain.py",
    "content": "from typing import Optional, Any, Sequence, List\nfrom dataclasses import dataclass\nimport os\nimport math\nimport yaml\nimport shutil\n\nimport torch\nimport torch.distributed as dist\nfrom torch import nn\nfrom torch.utils.data import DataLoader\n\nimport tqdm\nimport wandb\nimport coolname\nimport hydra\nimport pydantic\nfrom omegaconf import DictConfig\nfrom adam_atan2 import AdamATan2\n\nfrom puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata\nfrom utils.functions import load_model_class, get_model_source_path\nfrom models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed\n\n\nclass LossConfig(pydantic.BaseModel):\n    model_config = pydantic.ConfigDict(extra='allow')\n    \n    name: str\n\n\nclass ArchConfig(pydantic.BaseModel):\n    model_config = pydantic.ConfigDict(extra='allow')\n\n    name: str\n    loss: LossConfig\n\n\nclass PretrainConfig(pydantic.BaseModel):\n    # Config\n    arch: ArchConfig\n    # Data\n    data_path: str\n\n    # Hyperparams\n    global_batch_size: int\n    epochs: int\n\n    lr: float\n    lr_min_ratio: float\n    lr_warmup_steps: int\n\n    weight_decay: float\n    beta1: float\n    beta2: float\n\n    # Puzzle embedding\n    puzzle_emb_lr: float\n    puzzle_emb_weight_decay: float\n\n    # Names\n    project_name: Optional[str] = None\n    run_name: Optional[str] = None\n    checkpoint_path: Optional[str] = None\n\n    # Extras\n    seed: int = 0\n    checkpoint_every_eval: bool = False\n    eval_interval: Optional[int] = None\n    eval_save_outputs: List[str] = []\n\n\n@dataclass\nclass TrainState:\n    model: nn.Module\n    optimizers: Sequence[torch.optim.Optimizer]\n    optimizer_lrs: Sequence[float]\n    carry: Any\n\n    step: int\n    total_steps: int\n\n\ndef create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):\n    dataset = PuzzleDataset(PuzzleDatasetConfig(\n        seed=config.seed,\n\n        
dataset_path=config.data_path,\n\n        rank=rank,\n        num_replicas=world_size,\n        \n        **kwargs\n    ), split=split)\n    dataloader = DataLoader(\n        dataset,\n        batch_size=None,\n\n        num_workers=1,\n        prefetch_factor=8,\n\n        pin_memory=True,\n        persistent_workers=True\n    )\n    return dataloader, dataset.metadata\n\n\ndef create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):\n    model_cfg = dict(\n        **config.arch.__pydantic_extra__,  # type: ignore\n\n        batch_size=config.global_batch_size // world_size,\n\n        vocab_size=train_metadata.vocab_size,\n        seq_len=train_metadata.seq_len,\n        num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,\n        causal=False  # Non-autoregressive\n    )\n\n    # Instantiate model with loss head\n    model_cls = load_model_class(config.arch.name)\n    loss_head_cls = load_model_class(config.arch.loss.name)\n\n    with torch.device(\"cuda\"):\n        model: nn.Module = model_cls(model_cfg)\n        model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__)  # type: ignore\n        if \"DISABLE_COMPILE\" not in os.environ:\n            model = torch.compile(model, dynamic=False)  # type: ignore\n\n        # Broadcast parameters from rank 0\n        if world_size > 1:\n            with torch.no_grad():\n                for param in list(model.parameters()) + list(model.buffers()):\n                    dist.broadcast(param, src=0)\n\n    # Optimizers and lr\n    optimizers = [\n        CastedSparseEmbeddingSignSGD_Distributed(\n            model.model.puzzle_emb.buffers(),  # type: ignore\n            \n            lr=0,  # Needs to be set by scheduler\n            weight_decay=config.puzzle_emb_weight_decay,\n\n            world_size=world_size\n        ),\n        AdamATan2(\n            model.parameters(),\n\n            lr=0,  # Needs to be set by scheduler\n            
weight_decay=config.weight_decay,\n            betas=(config.beta1, config.beta2)\n        )\n    ]\n    optimizer_lrs = [\n        config.puzzle_emb_lr,\n        config.lr\n    ]\n\n    return model, optimizers, optimizer_lrs\n\n\ndef cosine_schedule_with_warmup_lr_lambda(\n    current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5\n):\n    if current_step < num_warmup_steps:\n        return base_lr * float(current_step) / float(max(1, num_warmup_steps))\n\n    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n    return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))\n\n\ndef init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):\n    # Estimated total training steps\n    total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)\n\n    # Model\n    model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size)\n\n    return TrainState(\n        step=0,\n        total_steps=total_steps,\n\n        model=model,\n        optimizers=optimizers,\n        optimizer_lrs=optimizer_lrs,\n        carry=None\n    )\n\n\ndef save_train_state(config: PretrainConfig, train_state: TrainState):\n    # FIXME: Only saved model.\n    if config.checkpoint_path is None:\n        return\n\n    os.makedirs(config.checkpoint_path, exist_ok=True)\n    torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f\"step_{train_state.step}\"))\n\n\ndef compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):\n    return cosine_schedule_with_warmup_lr_lambda(\n        current_step=train_state.step,\n        base_lr=base_lr,\n        num_warmup_steps=round(config.lr_warmup_steps),\n        
num_training_steps=train_state.total_steps,\n        min_ratio=config.lr_min_ratio\n    )\n\n\ndef train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):\n    train_state.step += 1\n    if train_state.step > train_state.total_steps:  # At most train_total_steps\n        return\n\n    # To device\n    batch = {k: v.cuda() for k, v in batch.items()}\n\n    # Init carry if it is None\n    if train_state.carry is None:\n        with torch.device(\"cuda\"):\n            train_state.carry = train_state.model.initial_carry(batch)  # type: ignore\n\n    # Forward\n    train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])\n\n    ((1 / global_batch_size) * loss).backward()\n\n    # Allreduce\n    if world_size > 1:\n        for param in train_state.model.parameters():\n            if param.grad is not None:\n                dist.all_reduce(param.grad)\n            \n    # Apply optimizer\n    lr_this_step = None    \n    for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):\n        lr_this_step = compute_lr(base_lr, config, train_state)\n\n        for param_group in optim.param_groups:\n            param_group['lr'] = lr_this_step\n            \n        optim.step()\n        optim.zero_grad()\n\n    # Reduce metrics\n    if len(metrics):\n        assert not any(v.requires_grad for v in metrics.values())\n\n        metric_keys = list(sorted(metrics.keys()))  # Sort keys to guarantee all processes use the same order.\n        # Reduce and reconstruct\n        metric_values = torch.stack([metrics[k] for k in metric_keys])\n        if world_size > 1:\n            dist.reduce(metric_values, dst=0)\n\n        if rank == 0:\n            metric_values = metric_values.cpu().numpy()\n            reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}\n            \n            # Postprocess\n            count = 
max(reduced_metrics[\"count\"], 1)  # Avoid NaNs\n            reduced_metrics = {f\"train/{k}\": v / (global_batch_size if k.endswith(\"loss\") else count) for k, v in reduced_metrics.items()}\n\n            reduced_metrics[\"train/lr\"] = lr_this_step\n            return reduced_metrics\n\n\ndef evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):\n    with torch.inference_mode():\n        set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}\n        \n        all_preds = {}\n\n        metric_keys = []\n        metric_values = None\n        metric_global_batch_size = [0 for _ in range(len(set_ids))]\n        \n        carry = None\n        for set_name, batch, global_batch_size in eval_loader:\n            # To device\n            batch = {k: v.cuda() for k, v in batch.items()}\n            with torch.device(\"cuda\"):\n                carry = train_state.model.initial_carry(batch)  # type: ignore\n\n            # Forward\n            while True:\n                carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs)\n                \n                if all_finish:\n                    break\n\n            for collection in (batch, preds):\n                for k, v in collection.items():\n                    if k in config.eval_save_outputs:\n                        all_preds.setdefault(k, [])\n                        all_preds[k].append(v.cpu())  # Move to CPU for saving GPU memory\n                        \n            del carry, preds, batch, all_finish\n\n            # Aggregate\n            set_id = set_ids[set_name]\n            \n            if metric_values is None:\n                metric_keys = list(sorted(metrics.keys()))  # Sort keys to guarantee all processes use the same order.\n                metric_values = torch.zeros((len(set_ids), len(metrics.values())), 
dtype=torch.float32, device=\"cuda\")\n                \n            metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])\n            metric_global_batch_size[set_id] += global_batch_size\n\n        if len(all_preds) and config.checkpoint_path is not None:\n            all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n\n            os.makedirs(config.checkpoint_path, exist_ok=True)\n            torch.save(all_preds, os.path.join(config.checkpoint_path, f\"step_{train_state.step}_all_preds.{rank}\"))\n\n        # Logging\n        # Reduce to rank 0\n        if metric_values is not None:\n            if world_size > 1:\n                dist.reduce(metric_values, dst=0)\n            \n            if rank == 0:\n                reduced_metrics = metric_values.cpu().numpy()\n                reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)}\n                                   for set_id, set_name in enumerate(set_ids)}\n                \n                # Postprocess\n                for set_name, metrics in reduced_metrics.items():\n                    count = metrics.pop(\"count\")\n                    reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()}\n\n                return reduced_metrics\n\n\ndef save_code_and_config(config: PretrainConfig):\n    if config.checkpoint_path is None or wandb.run is None:\n        return\n\n    os.makedirs(config.checkpoint_path, exist_ok=True)\n\n    # Copy code\n    code_list = [\n        get_model_source_path(config.arch.name),\n        get_model_source_path(config.arch.loss.name)\n    ]\n    for code_file in code_list:\n        if code_file is not None:\n            code_name = os.path.basename(code_file)\n\n            shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))\n\n    # Dump config as yaml\n    config_file = os.path.join(config.checkpoint_path, 
\"all_config.yaml\")\n    with open(config_file, \"wt\") as f:\n        yaml.dump(config.model_dump(), f)\n\n    # Log code\n    wandb.run.log_code(config.checkpoint_path)\n\n\ndef load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:\n    objects = [None]\n    if rank == 0:\n        config = PretrainConfig(**hydra_config)  # type: ignore\n\n        # Naming\n        if config.project_name is None:\n            config.project_name = f\"{os.path.basename(config.data_path).capitalize()} ACT-torch\"\n        if config.run_name is None:\n            config.run_name = f\"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}\"\n        if config.checkpoint_path is None:\n            config.checkpoint_path = os.path.join(\"checkpoints\", config.project_name, config.run_name)\n\n        objects = [config]\n\n    if world_size > 1:\n        dist.broadcast_object_list(objects, src=0)\n\n    return objects[0]  # type: ignore\n\n\n@hydra.main(config_path=\"config\", config_name=\"cfg_pretrain\", version_base=None)\ndef launch(hydra_config: DictConfig):\n    RANK = 0\n    WORLD_SIZE = 1\n\n    # Initialize distributed training if in distributed environment (e.g. 
torchrun)\n    if \"LOCAL_RANK\" in os.environ:\n        # Initialize distributed, default device and dtype\n        dist.init_process_group(backend=\"nccl\")\n\n        RANK = dist.get_rank()\n        WORLD_SIZE = dist.get_world_size()\n\n        torch.cuda.set_device(int(os.environ[\"LOCAL_RANK\"]))\n        \n    # Load sync'ed config\n    config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)\n\n    # Seed RNGs to ensure consistency\n    torch.random.manual_seed(config.seed + RANK)\n\n    # Dataset\n    train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs\n    total_iters = config.epochs // train_epochs_per_iter\n\n    assert config.epochs % train_epochs_per_iter == 0, \"Eval interval must be a divisor of total epochs.\"\n\n    train_loader, train_metadata = create_dataloader(config, \"train\", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)\n    eval_loader,  eval_metadata  = create_dataloader(config, \"test\", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)\n\n    # Train state\n    train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)\n\n    # Progress bar and logger\n    progress_bar = None\n    if RANK == 0:\n        progress_bar = tqdm.tqdm(total=train_state.total_steps)\n\n        wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True))  # type: ignore\n        wandb.log({\"num_params\": sum(x.numel() for x in train_state.model.parameters())}, step=0)\n        save_code_and_config(config)\n\n    # Training Loop\n    for _iter_id in range(total_iters):\n        print (f\"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}\")\n\n        ############ Train Iter\n        train_state.model.train()\n        
for set_name, batch, global_batch_size in train_loader:\n            metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)\n\n            if RANK == 0 and metrics is not None:\n                wandb.log(metrics, step=train_state.step)\n                progress_bar.update(train_state.step - progress_bar.n)  # type: ignore\n\n        ############ Evaluation\n        train_state.model.eval()\n        metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)\n\n        if RANK == 0 and metrics is not None:\n            wandb.log(metrics, step=train_state.step)\n            \n        ############ Checkpointing\n        if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):\n            save_train_state(config, train_state)\n\n    # finalize\n    if dist.is_initialized():\n        dist.destroy_process_group()\n    wandb.finish()\n\n\nif __name__ == \"__main__\":\n    launch()\n"
  },
  {
    "path": "puzzle_dataset.py",
    "content": "import os\nimport json\n\nimport numpy as np\nimport pydantic\n\nimport torch\nfrom torch.utils.data import IterableDataset, get_worker_info\n\nfrom models.losses import IGNORE_LABEL_ID\nfrom dataset.common import PuzzleDatasetMetadata\n\n\ndef _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):\n    # Pack examples into a full batch\n    batch = []\n    batch_puzzle_indices = []\n    current_size = 0\n\n    while (start_index < group_order.size) and (current_size < global_batch_size):\n        # Pick a group and a puzzle from that group\n        group_id = group_order[start_index]\n        puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])\n        start_index += 1\n\n        # Get range of the puzzle\n        puzzle_start = puzzle_indices[puzzle_id]\n        puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)\n\n        append_size = min(puzzle_size, global_batch_size - current_size)\n\n        # Put into batch\n        batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))\n        batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False))\n\n        current_size += append_size\n\n    return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)\n\n\nclass PuzzleDatasetConfig(pydantic.BaseModel):\n    seed: int\n    dataset_path: str\n    global_batch_size: int\n    test_set_mode: bool\n\n    epochs_per_iter: int  # Batch X epochs in an iteration to reduce overhead.\n\n    rank: int\n    num_replicas: int\n\n\nclass PuzzleDataset(IterableDataset):\n    def __init__(self, config: PuzzleDatasetConfig, split: str = \"train\"):\n        super().__init__()\n        self.config = config\n        self.split = split\n        self.metadata = self._load_metadata()\n        \n        # Checks\n        assert self.config.global_batch_size % 
self.config.num_replicas == 0, f\"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}.\"\n        self.local_batch_size = self.config.global_batch_size // self.config.num_replicas\n\n        # State\n        self._data = None\n        self._iters = 0\n\n    def _load_metadata(self) -> PuzzleDatasetMetadata:\n        with open(os.path.join(self.config.dataset_path, self.split, \"dataset.json\"), \"r\") as f:\n            return PuzzleDatasetMetadata(**json.load(f))\n\n    def _lazy_load_dataset(self):\n        if self._data is not None:\n            return\n\n        field_mmap_modes = {\n            \"inputs\": \"r\",\n            \"labels\": \"r\",\n\n            # Keep indices in memory\n            \"puzzle_identifiers\": None,\n            \"puzzle_indices\": None,\n            \"group_indices\": None\n        }\n\n        # Load data\n        self._data = {}\n        for set_name in self.metadata.sets:\n            # Load subset\n            self._data[set_name] = {\n                field_name: np.load(os.path.join(self.config.dataset_path, self.split, f\"{set_name}__{field_name}.npy\"), mmap_mode=mmap_mode)\n                for field_name, mmap_mode in field_mmap_modes.items()\n            }\n\n    def _collate_batch(self, batch):\n        # Convert dtype\n        batch = {k: v.astype(np.int32) for k, v in batch.items()}\n\n        # Convert ignore label IDs\n        if self.metadata.ignore_label_id is not None:\n            batch[\"labels\"][batch[\"labels\"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID\n\n        # Pad\n        if batch[\"puzzle_identifiers\"].size < self.local_batch_size:\n            pad_size = self.local_batch_size - batch[\"puzzle_identifiers\"].size\n\n            pad_values = {\n                \"inputs\": self.metadata.pad_id,\n                \"labels\": IGNORE_LABEL_ID,\n\n                \"puzzle_identifiers\": self.metadata.blank_identifier_id\n            }\n            
batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}\n\n        # To tensor\n        return {k: torch.from_numpy(v) for k, v in batch.items()}\n    \n    def _iter_test(self):\n        for set_name, dataset in self._data.items():  # type: ignore\n            total_examples = len(dataset[\"inputs\"])\n\n            # Load examples one by one\n            start_index = 0\n            while start_index < total_examples:\n                # Compute indices\n                end_index = min(total_examples, start_index + self.config.global_batch_size)\n                \n                local_start = start_index + self.config.rank * self.local_batch_size\n                local_end   = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)\n                \n                # Get batch of examples, and also puzzle IDs\n                puzzle_indices = []\n                puzzle_index = np.searchsorted(dataset[\"puzzle_indices\"], local_start, side=\"right\") - 1\n                for i in range(local_start, local_end):\n                    while puzzle_index + 1 < len(dataset[\"puzzle_indices\"]) and i >= dataset[\"puzzle_indices\"][puzzle_index + 1]:\n                        puzzle_index += 1\n\n                    puzzle_indices.append(puzzle_index)\n                \n                batch = self._collate_batch({\n                    \"inputs\": dataset[\"inputs\"][local_start: local_end],\n                    \"labels\": dataset[\"labels\"][local_start: local_end],\n                    \"puzzle_identifiers\": dataset[\"puzzle_identifiers\"][puzzle_indices]\n                })\n\n                yield set_name, batch, end_index - start_index\n                \n                # Advance to next batch\n                start_index += self.config.global_batch_size\n\n    def _iter_train(self):\n        for set_name, dataset in self._data.items():  # type: ignore\n            # 
Increase epoch count\n            self._iters += 1\n\n            # Randomly shuffle groups\n            rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))\n\n            group_order = np.concatenate([rng.permutation(dataset[\"group_indices\"].size - 1) for _i in range(self.config.epochs_per_iter)])\n            start_index = 0\n            \n            while start_index < group_order.size:\n                start_index, batch_indices, batch_puzzle_indices = _sample_batch(\n                    rng,\n                    group_order=group_order,\n                    puzzle_indices=dataset[\"puzzle_indices\"],\n                    group_indices=dataset[\"group_indices\"],\n                    start_index=start_index,\n                    global_batch_size=self.config.global_batch_size,\n                )\n\n                # Select current rank and collate\n                global_effective_batch_size = batch_puzzle_indices.size  # Global effective batch size, excluding pads\n\n                # Drop last batch\n                if global_effective_batch_size < self.config.global_batch_size:\n                    break\n\n                batch_indices        = batch_indices       [self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]\n                batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]\n                batch = self._collate_batch({\n                    \"inputs\": dataset[\"inputs\"][batch_indices],\n                    \"labels\": dataset[\"labels\"][batch_indices],\n                    \"puzzle_identifiers\": dataset[\"puzzle_identifiers\"][batch_puzzle_indices]\n                })\n\n                yield set_name, batch, global_effective_batch_size\n                \n    def __iter__(self):\n        worker_info = get_worker_info()\n        assert worker_info is None or worker_info.num_workers == 1, 
\"Multithreaded data loading is not currently supported.\"\n        \n        self._lazy_load_dataset()\n        \n        # Iterate using specified mode\n        if self.config.test_set_mode:\n            yield from self._iter_test()\n        else:\n            yield from self._iter_train()\n"
  },
  {
    "path": "puzzle_visualizer.html",
    "content": "<!DOCTYPE html>\n<html>\n<head>\n  <meta charset=\"UTF-8\" />\n  <title>ARC‐Converted Dataset Visualizer (Upload Local Folder)</title>\n  <style>\n    body {\n      font-family: sans-serif;\n      margin: 16px;\n    }\n    .selector-area {\n      margin-bottom: 1rem;\n    }\n    .grid-canvas {\n      margin: 4px;\n      border: 1px solid #ccc;\n    }\n    .example-container {\n      display: inline-block;\n      margin: 0 16px 16px 0;\n      vertical-align: top;\n    }\n    .puzzle-display {\n      margin-top: 1rem;\n    }\n    .puzzle-id {\n      font-weight: bold;\n      margin-bottom: 0.5rem;\n    }\n    #groupList, #puzzleList {\n      margin: 1rem 0;\n    }\n    .group-item, .puzzle-item {\n      cursor: pointer;\n      margin: 4px 8px 4px 0;\n      padding: 2px 6px;\n      border: 1px solid #aaa;\n      display: inline-block;\n    }\n    .group-item:hover, .puzzle-item:hover {\n      background: #eef;\n    }\n  </style>\n</head>\n<body>\n<h1>ARC‐Converted Dataset Visualizer (Local Directory)</h1>\n\n<div class=\"selector-area\">\n  <!-- 1) Directory input with webkitdirectory, mozdirectory -->\n  <label>Upload ARC Folder:</label>\n  <input type=\"file\" id=\"folderInput\"\n         webkitdirectory mozdirectory multiple\n         onchange=\"onFolderSelected(event)\" />\n  <br><br>\n\n  <!-- 2) We'll enable set/subset selection after user chooses a folder and data is validated -->\n  <label>Set:</label>\n  <select id=\"setSelect\" disabled>\n    <option value=\"train\">train</option>\n    <option value=\"test\">test</option>\n  </select>\n\n  <label> Subset:</label>\n  <select id=\"subsetSelect\" disabled>\n    <option value=\"all\">all</option>\n  </select>\n\n  <button id=\"loadBtn\" disabled>Load</button>\n</div>\n\n<div>\n  <div id=\"groupList\"></div>\n  <div id=\"puzzleList\"></div>\n  <div class=\"puzzle-display\" id=\"puzzleView\"></div>\n</div>\n\n<!-- \n   3) Use local 'assets/npyjs.js' from your project folder instead of a CDN.\n   
Make sure 'assets/npyjs.js' is the unbundled or UMD version that doesn't\n   contain \"import\" statements. \n-->\n<script src=\"assets/npyjs.js\"></script>\n\n<script>\n/***************************************************************************\n * Global Maps & Variables\n ***************************************************************************/\n\n// Map from \"train/all__inputs.npy\" => File, etc.\nlet filesByPath = {};\n\n// Once loaded, we store typed arrays for the chosen set/subset\nlet inputsArr, labelsArr;\nlet puzzleIndicesArr, groupIndicesArr, puzzleIdentifiersArr;\nlet identifiersJson;\n\n// The shape of inputs is [N_examples, seqLen], so we discover seqLen & gridSize\nlet seqLen = 0;\nlet gridSize = 0;\n\n\n/***************************************************************************\n * 1) Handle folder selection: read all files, find identifiers.json,\n *    remove topmost folder from each file path, validate.\n ***************************************************************************/\nfunction onFolderSelected(event) {\n  filesByPath = {};\n  const fileList = event.target.files;\n  if (!fileList || fileList.length === 0) {\n    alert(\"No files selected!\");\n    return;\n  }\n\n  // We'll gather all webkitRelativePaths\n  const paths = [];\n  for (let i = 0; i < fileList.length; i++) {\n    // Typically \"arc-aug-10/train/all__inputs.npy\", etc.\n    const file = fileList[i];\n    const relPath = file.webkitRelativePath || file.mozRelativePath || file.name;\n    paths.push(relPath);\n  }\n\n  // 1. Check if we have \"identifiers.json\" somewhere.\n  const idPath = paths.find(p => p.endsWith(\"identifiers.json\"));\n  if (!idPath) {\n    alert(\"Error: No 'identifiers.json' found in the uploaded folder.\");\n    return;\n  }\n\n  // 2. Derive the top-level directory from that file's path\n  //    e.g. 
if idPath = \"arc-aug-10/identifiers.json\", topDir = \"arc-aug-10\"\n  //    If there's no slash, topDir = \"\" => do nothing\n  let topDir = \"\";\n  const lastSlash = idPath.lastIndexOf(\"/\");\n  if (lastSlash >= 0) {\n    topDir = idPath.substring(0, lastSlash);\n  }\n\n  // 3. Rebuild filesByPath with the top folder removed.\n  //    For example, if topDir = \"arc-aug-10\", then \"arc-aug-10/train/all__inputs.npy\"\n  //    becomes \"train/all__inputs.npy\"\n  for (let i = 0; i < fileList.length; i++) {\n    const file = fileList[i];\n    let relPath = file.webkitRelativePath || file.mozRelativePath || file.name;\n    // If relPath starts with \"arc-aug-10/\", remove that prefix\n    if (topDir && relPath.startsWith(topDir + \"/\")) {\n      relPath = relPath.substring(topDir.length + 1);\n    }\n    filesByPath[relPath] = file;\n  }\n\n  // Enable set/subset selection and \"Load\"\n  document.getElementById(\"setSelect\").disabled = false;\n  document.getElementById(\"subsetSelect\").disabled = false;\n  document.getElementById(\"loadBtn\").disabled = false;\n}\n\n// When user clicks \"Load,\" parse the .npy for the chosen set/subset\ndocument.getElementById(\"loadBtn\").addEventListener(\"click\", async () => {\n  document.getElementById(\"groupList\").innerHTML = \"\";\n  document.getElementById(\"puzzleList\").innerHTML = \"\";\n  document.getElementById(\"puzzleView\").innerHTML = \"\";\n\n  const setName = document.getElementById(\"setSelect\").value;        // e.g. \"train\"\n  const subsetName = document.getElementById(\"subsetSelect\").value;  // e.g. 
\"all\"\n\n  try {\n    await loadDataset(setName, subsetName);\n    buildGroupList(); // show groups\n  } catch (err) {\n    console.error(err);\n    alert(\"Error while loading dataset: \" + err);\n  }\n});\n\n\n/***************************************************************************\n * 2) Load .npy from local files using Npyjs + FileReader (ArrayBuffer)\n ***************************************************************************/\nasync function loadDataset(setName, subsetName) {\n  const prefix = `${setName}/${subsetName}__`;\n  // e.g. \"train/all__inputs.npy\"\n  const inputsPath  = prefix + \"inputs.npy\";\n  const labelsPath  = prefix + \"labels.npy\";\n  const pIdxPath    = prefix + \"puzzle_indices.npy\";\n  const gIdxPath    = prefix + \"group_indices.npy\";\n  const pIdsPath    = prefix + \"puzzle_identifiers.npy\";\n  const identifiersPath = \"identifiers.json\";\n\n  // Check existence\n  const needed = [inputsPath, labelsPath, pIdxPath, gIdxPath, pIdsPath, identifiersPath];\n  for (const f of needed) {\n    if (!filesByPath[f]) {\n      throw new Error(`Missing file: ${f}`);\n    }\n  }\n\n  // parseNpy => read from File -> ArrayBuffer -> Npyjs => typed array\n  const inputsNpy       = await parseNpy(filesByPath[inputsPath]);\n  const labelsNpy       = await parseNpy(filesByPath[labelsPath]);\n  const puzzleIndicesNpy= await parseNpy(filesByPath[pIdxPath]);\n  const groupIndicesNpy = await parseNpy(filesByPath[gIdxPath]);\n  const puzzleIdsNpy    = await parseNpy(filesByPath[pIdsPath]);\n\n  inputsArr            = inputsNpy.data;\n  labelsArr            = labelsNpy.data;\n  puzzleIndicesArr     = puzzleIndicesNpy.data;\n  groupIndicesArr      = groupIndicesNpy.data;\n  puzzleIdentifiersArr = puzzleIdsNpy.data;\n\n  // shape e.g. 
[N_examples, seqLen]\n  seqLen   = inputsNpy.shape[1];\n  gridSize = Math.sqrt(seqLen);\n\n  // read JSON\n  identifiersJson = await readJsonFile(filesByPath[identifiersPath]);\n}\n\n/***************************************************************************\n * parseNpy => read a File as ArrayBuffer, parse with npyjs\n ***************************************************************************/\nfunction parseNpy(file) {\n  return new Promise((resolve, reject) => {\n    const reader = new FileReader();\n    reader.onload = async () => {\n      try {\n        const arrayBuffer = reader.result;\n        const npy = new npyjs();\n        resolve(await npy.parse(arrayBuffer));\n      } catch (err) {\n        reject(err);\n      }\n    };\n    reader.onerror = err => reject(err);\n    reader.readAsArrayBuffer(file);\n  });\n}\n\n/***************************************************************************\n * readJsonFile => read a local JSON file into object\n ***************************************************************************/\nfunction readJsonFile(file) {\n  return new Promise((resolve, reject) => {\n    const reader = new FileReader();\n    reader.onload = () => {\n      try {\n        const obj = JSON.parse(reader.result);\n        resolve(obj);\n      } catch (err) {\n        reject(err);\n      }\n    };\n    reader.onerror = (err) => reject(err);\n    reader.readAsText(file);\n  });\n}\n\n/***************************************************************************\n * 3) Build group list in UI\n ***************************************************************************/\nfunction buildGroupList() {\n  document.getElementById(\"groupList\").innerHTML = \"<h3>Groups</h3>\";\n  const groupListDiv = document.getElementById(\"groupList\");\n\n  const nGroups = groupIndicesArr.length - 1;\n  for (let g = 0; g < nGroups; g++) {\n    const div = document.createElement(\"span\");\n    div.className = \"group-item\";\n    div.textContent = `Group ${g}`;\n    
div.onclick = () => onSelectGroup(g);\n    groupListDiv.appendChild(div);\n  }\n}\n\n/***************************************************************************\n * onSelectGroup => show puzzles in that group\n ***************************************************************************/\nfunction onSelectGroup(groupIndex) {\n  document.getElementById(\"puzzleList\").innerHTML = \"\";\n  document.getElementById(\"puzzleView\").innerHTML = \"\";\n\n  const puzzleListDiv = document.getElementById(\"puzzleList\");\n  puzzleListDiv.innerHTML = `<h4>Puzzles in Group ${groupIndex}</h4>`;\n\n  const firstPuzzle = groupIndicesArr[groupIndex];\n  const lastPuzzle  = groupIndicesArr[groupIndex + 1];\n\n  for (let p = firstPuzzle; p < lastPuzzle; p++) {\n    const puzzleIntId = puzzleIdentifiersArr[p];\n    const puzzleStrId = (puzzleIntId < identifiersJson.length)\n                        ? identifiersJson[puzzleIntId]\n                        : \"<unknown>\";\n\n    const div = document.createElement(\"span\");\n    div.className = \"puzzle-item\";\n    div.textContent = `Puzzle #${p} [ID=${puzzleIntId}: ${puzzleStrId}]`;\n    div.onclick = () => onSelectPuzzle(p);\n    puzzleListDiv.appendChild(div);\n  }\n}\n\n/***************************************************************************\n * onSelectPuzzle => show each example\n ***************************************************************************/\nfunction onSelectPuzzle(puzzleIndex) {\n  const puzzleView = document.getElementById(\"puzzleView\");\n  puzzleView.innerHTML = \"\";\n\n  // puzzle ID\n  const puzzleIntId = puzzleIdentifiersArr[puzzleIndex];\n  const puzzleStrId = (puzzleIntId < identifiersJson.length)\n                      ? 
identifiersJson[puzzleIntId]\n                      : \"<unknown>\";\n\n  const titleDiv = document.createElement(\"div\");\n  titleDiv.className = \"puzzle-id\";\n  titleDiv.textContent = `Puzzle #${puzzleIndex} — ID: ${puzzleStrId}`;\n  puzzleView.appendChild(titleDiv);\n\n  // Examples are [puzzleIndicesArr[p], puzzleIndicesArr[p+1])\n  const firstExample = puzzleIndicesArr[puzzleIndex];\n  const lastExample  = puzzleIndicesArr[puzzleIndex + 1];\n\n  for (let e = firstExample; e < lastExample; e++) {\n    const inputSeq  = slice1D(inputsArr,  e*seqLen, (e+1)*seqLen);\n    const outputSeq = slice1D(labelsArr, e*seqLen, (e+1)*seqLen);\n\n    const inputGrid  = decodeGrid(inputSeq);\n    const outputGrid = decodeGrid(outputSeq);\n\n    const exDiv = document.createElement(\"div\");\n    exDiv.className = \"example-container\";\n    exDiv.appendChild(document.createTextNode(`Example ${e}`));\n    exDiv.appendChild(document.createElement(\"br\"));\n\n    exDiv.appendChild(renderGrid(inputGrid));\n    exDiv.appendChild(renderGrid(outputGrid));\n\n    puzzleView.appendChild(exDiv);\n  }\n}\n\n/***************************************************************************\n * slice1D => typed array slicing\n ***************************************************************************/\nfunction slice1D(arr, start, end) {\n  const result = new Uint32Array(end - start);\n  for (let i = start; i < end; i++) {\n    result[i - start] = Number(arr[i]);\n  }\n  return result;\n}\n\n/***************************************************************************\n * decodeGrid => turn the flattened seq of length=gridSize^2 into 2D\n ***************************************************************************/\nfunction decodeGrid(seq) {\n  const grid = [];\n  let idx = 0;\n  for (let r = 0; r < gridSize; r++) {\n    const row = [];\n    for (let c = 0; c < gridSize; c++) {\n      row.push(seq[idx]);\n      idx++;\n    }\n    grid.push(row);\n  }\n  return 
grid;\n}\n\n/***************************************************************************\n * renderGrid => draws a 2D grid to <canvas>\n ***************************************************************************/\nfunction renderGrid(grid2d) {\n  const rows = grid2d.length;\n  const cols = grid2d[0].length;\n  const scale = 10;\n\n  const canvas = document.createElement(\"canvas\");\n  canvas.width  = cols * scale;\n  canvas.height = rows * scale;\n  canvas.className = \"grid-canvas\";\n  const ctx = canvas.getContext(\"2d\");\n\n  for (let r = 0; r < rows; r++) {\n    for (let c = 0; c < cols; c++) {\n      const val = grid2d[r][c];\n      ctx.fillStyle = indexToColor(val);\n      ctx.fillRect(c * scale, r * scale, scale, scale);\n    }\n  }\n  return canvas;\n}\n\n/***************************************************************************\n * indexToColor => color palette: \n *   0 => pad => white\n *   1 => eos => light gray\n *   2..11 => original color(0..9)\n ***************************************************************************/\nfunction indexToColor(value) {\n  if (value === 0) return \"#FFFFFF\"; // pad => white\n  if (value === 1) return \"#DDDDDD\"; // eos => light gray\n\n  // shift by 2 => original color in [0..9]\n  const colorIdx = value - 2;\n  const palette = [\n    \"#000000\", // color0 => black\n    \"#FF0000\", // color1 => red\n    \"#00FF00\", // color2 => green\n    \"#0000FF\", // color3 => blue\n    \"#FFFF00\", // color4 => yellow\n    \"#FFA500\", // color5 => orange\n    \"#800080\", // color6 => purple\n    \"#00FFFF\", // color7 => cyan\n    \"#FFC0CB\", // color8 => pink\n    \"#808080\"  // color9 => gray\n  ];\n  if (colorIdx >= 0 && colorIdx < palette.length) {\n    return palette[colorIdx];\n  }\n  return \"#FFFFFF\"; // fallback\n}\n</script>\n</body>\n</html>\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch\nadam-atan2\neinops\ntqdm\ncoolname\npydantic\nargdantic\nwandb\nomegaconf\nhydra-core\nhuggingface_hub\n"
  },
  {
    "path": "utils/functions.py",
    "content": "import importlib\nimport inspect\n\n\ndef load_model_class(identifier: str, prefix: str = \"models.\"):\n    module_path, class_name = identifier.split('@')\n\n    # Import the module\n    module = importlib.import_module(prefix + module_path)\n    cls = getattr(module, class_name)\n    \n    return cls\n\n\ndef get_model_source_path(identifier: str, prefix: str = \"models.\"):\n    module_path, class_name = identifier.split('@')\n\n    module = importlib.import_module(prefix + module_path)\n    return inspect.getsourcefile(module)\n"
  }
]