Repository: sapientinc/HRM Branch: main Commit: 42410daaaf6a Files: 25 Total size: 124.9 KB Directory structure: gitextract_7bpy0bpn/ ├── .gitignore ├── .gitmodules ├── .vscode/ │ ├── launch.json │ └── settings.json ├── LICENSE ├── README.md ├── arc_eval.ipynb ├── assets/ │ └── npyjs.js ├── config/ │ ├── arch/ │ │ └── hrm_v1.yaml │ └── cfg_pretrain.yaml ├── dataset/ │ ├── build_arc_dataset.py │ ├── build_maze_dataset.py │ ├── build_sudoku_dataset.py │ └── common.py ├── evaluate.py ├── models/ │ ├── common.py │ ├── hrm/ │ │ └── hrm_act_v1.py │ ├── layers.py │ ├── losses.py │ └── sparse_embedding.py ├── pretrain.py ├── puzzle_dataset.py ├── puzzle_visualizer.html ├── requirements.txt └── utils/ └── functions.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # WandB /wandb/ # checkpoints /checkpoints/ # cache /cache/ # data /data/ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ ================================================ FILE: .gitmodules ================================================ [submodule "dataset/raw-data/ConceptARC"] path = dataset/raw-data/ConceptARC url = git@github.com:victorvikram/ConceptARC.git [submodule "dataset/raw-data/ARC-AGI"] path = dataset/raw-data/ARC-AGI url = git@github.com:fchollet/ARC-AGI.git [submodule "dataset/raw-data/ARC-AGI-2"] path = dataset/raw-data/ARC-AGI-2 url = git@github.com:arcprize/ARC-AGI-2.git ================================================ FILE: .vscode/launch.json ================================================ { // Use IntelliSense to learn about possible attributes. // Hover to view descriptions of existing attributes. 
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ { "name": "Python Debugger: Current File", "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal" }, { "name": "Debug: Single GPU", "type": "debugpy", "request": "launch", "program": "pretrain.py", "args": [], "env": { "OMP_NUM_THREADS": "1", "DISABLE_COMPILE": "true" } } ] } ================================================ FILE: .vscode/settings.json ================================================ { "python.analysis.typeCheckingMode": "standard" } ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. 
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # Hierarchical Reasoning Model ![](./assets/hrm.png) Reasoning, the process of devising and executing complex goal-oriented action sequences, remains a critical challenge in AI. Current large language models (LLMs) primarily employ Chain-of-Thought (CoT) techniques, which suffer from brittle task decomposition, extensive data requirements, and high latency. Inspired by the hierarchical and multi-timescale processing in the human brain, we propose the Hierarchical Reasoning Model (HRM), a novel recurrent architecture that attains significant computational depth while maintaining both training stability and efficiency. HRM executes sequential reasoning tasks in a single forward pass without explicit supervision of the intermediate process, through two interdependent recurrent modules: a high-level module responsible for slow, abstract planning, and a low-level module handling rapid, detailed computations. With only 27 million parameters, HRM achieves exceptional performance on complex reasoning tasks using only 1000 training samples. The model operates without pre-training or CoT data, yet achieves nearly perfect performance on challenging tasks including complex Sudoku puzzles and optimal path finding in large mazes. 
Furthermore, HRM outperforms much larger models with significantly longer context windows on the Abstraction and Reasoning Corpus (ARC), a key benchmark for measuring artificial general intelligence capabilities. These results underscore HRM’s potential as a transformative advancement toward universal computation and general-purpose reasoning systems. **Join our Discord Community: [https://discord.gg/sapient](https://discord.gg/sapient)** ## Quick Start Guide 🚀 ### Prerequisites ⚙️ Ensure PyTorch and CUDA are installed. The repo needs CUDA extensions to be built. If not present, run the following commands: ```bash # Install CUDA 12.6 CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL sudo sh cuda_installer.run --silent --toolkit --override export CUDA_HOME=/usr/local/cuda-12.6 # Install PyTorch with CUDA 12.6 PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu126 pip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL # Additional packages for building extensions pip3 install packaging ninja wheel setuptools setuptools-scm ``` Then install FlashAttention. For Hopper GPUs, install FlashAttention 3 ```bash git clone git@github.com:Dao-AILab/flash-attention.git cd flash-attention/hopper python setup.py install ``` For Ampere or earlier GPUs, install FlashAttention 2 ```bash pip3 install flash-attn ``` ## Install Python Dependencies 🐍 ```bash pip install -r requirements.txt ``` ## W&B Integration 📈 This project uses [Weights & Biases](https://wandb.ai/) for experiment tracking and metric visualization. Ensure you're logged in: ```bash wandb login ``` ## Run Experiments ### Quick Demo: Sudoku Solver 💻🗲 Train a master-level Sudoku AI capable of solving extremely difficult puzzles on a modern laptop GPU. 
🧩 ```bash # Download and build Sudoku dataset python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 # Start training (single GPU, smaller batch size) OMP_NUM_THREADS=8 python pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 global_batch_size=384 lr=7e-5 puzzle_emb_lr=7e-5 weight_decay=1.0 puzzle_emb_weight_decay=1.0 ``` Runtime: ~10 hours on an RTX 4070 laptop GPU ## Trained Checkpoints 🚧 - [ARC-AGI-2](https://huggingface.co/sapientinc/HRM-checkpoint-ARC-2) - [Sudoku 9x9 Extreme (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-sudoku-extreme) - [Maze 30x30 Hard (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-maze-30x30-hard) To use the checkpoints, see the Evaluation section below. ## Full-scale Experiments 🔵 Experiments below assume an 8-GPU setup. ### Dataset Preparation ```bash # Initialize submodules git submodule update --init --recursive # ARC-1 python dataset/build_arc_dataset.py # ARC official + ConceptARC, 960 examples # ARC-2 python dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000 # ARC-2 official, 1120 examples # Sudoku-Extreme python dataset/build_sudoku_dataset.py # Full version python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 # 1000 examples # Maze python dataset/build_maze_dataset.py # 1000 examples ``` ### Dataset Visualization Explore the puzzles visually: * Open `puzzle_visualizer.html` in your browser. * Upload the generated dataset folder located in `data/...`. 
## Launch experiments ### Small-sample (1K) ARC-1: ```bash OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py ``` *Runtime:* ~24 hours ARC-2: ```bash OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/arc-2-aug-1000 ``` *Runtime:* ~24 hours (checkpoint after 8 hours is often sufficient) Sudoku Extreme (1k): ```bash OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 ``` *Runtime:* ~10 minutes Maze 30x30 Hard (1k): ```bash OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/maze-30x30-hard-1k epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 ``` *Runtime:* ~1 hour ### Full Sudoku-Hard ```bash OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-hard-full epochs=100 eval_interval=10 lr_min_ratio=0.1 global_batch_size=2304 lr=3e-4 puzzle_emb_lr=3e-4 weight_decay=0.1 puzzle_emb_weight_decay=0.1 arch.loss.loss_type=softmax_cross_entropy arch.L_cycles=8 arch.halt_max_steps=8 arch.pos_encodings=learned ``` *Runtime:* ~2 hours ## Evaluation Evaluate your trained models: * Check `eval/exact_accuracy` in W&B. * For ARC-AGI, follow these additional steps: ```bash OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 evaluate.py checkpoint=<CHECKPOINT_PATH> ``` * Then use the provided `arc_eval.ipynb` notebook to finalize and inspect your results. ## Notes - Small-sample learning typically exhibits accuracy variance of around ±2 points. - For Sudoku-Extreme (1,000-example dataset), late-stage overfitting may cause numerical instability during training and Q-learning. It is advisable to use early stopping once the training accuracy approaches 100%. 
## Citation 📜 ```bibtex @misc{wang2025hierarchicalreasoningmodel, title={Hierarchical Reasoning Model}, author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori}, year={2025}, eprint={2506.21734}, archivePrefix={arXiv}, primaryClass={cs.AI}, url={https://arxiv.org/abs/2506.21734}, } ``` ================================================ FILE: arc_eval.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "from glob import glob\n", "import hashlib\n", "import matplotlib.pyplot as plt\n", "import matplotlib.colors as mcolors\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "import numpy as np\n", "from numba import njit\n", "\n", "from dataset.common import inverse_dihedral_transform\n", "\n", "\n", "DATASET_PATH = \"data/arc-aug-1000\" # ARC-1\n", "# DATASET_PATH = \"data/arc-2-aug-1000\" # ARC-2\n", "\n", "CHECKPOINT_PATH = \"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\"\n", "\n", "\n", "PAD_PUZZLE_IDENTIFIER = 0\n", "\n", "# Visualization\n", "ARC_COLOR_MAP = mcolors.ListedColormap([\n", " \"#000000\", # symbol_0: black\n", " \"#0074D9\", # symbol_1: blue\n", " \"#FF4136\", # symbol_2: red\n", " \"#2ECC40\", # symbol_3: green\n", " \"#FFDC00\", # symbol_4: yellow\n", " \"#AAAAAA\", # symbol_5: grey\n", " \"#F012BE\", # symbol_6: fuschia\n", " \"#FF851B\", # symbol_7: orange\n", " \"#7FDBFF\", # symbol_8: teal\n", " \"#870C25\" # symbol_9: brown\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\n", " # Load puzzle identifiers\n", " with open(os.path.join(dataset_path, \"identifiers.json\"), \"r\") as f:\n", " identifier_map = json.load(f)\n", " \n", " # Load 
preds\n", " all_preds = {}\n", " for filename in glob(f\"{checkpoint_path}_all_preds.*\"):\n", " preds = torch.load(filename)\n", " for k, v in preds.items():\n", " all_preds.setdefault(k, [])\n", " all_preds[k].append(v)\n", " \n", " del preds\n", "\n", " all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n", " \n", " # Remove paddings\n", " mask = all_preds[\"puzzle_identifiers\"] != PAD_PUZZLE_IDENTIFIER\n", " all_preds = {k: v[mask] for k, v in all_preds.items()}\n", "\n", " return identifier_map, all_preds\n", "\n", "\n", "def inverse_aug(name: str, grid: np.ndarray):\n", " if \"_\" not in name:\n", " return grid\n", "\n", " trans_id, perm = name.split(\"_\")[-2:]\n", " trans_id = int(trans_id[1:]) # Remove \"t\" letter\n", " inv_perm = np.argsort(list(perm))\n", " \n", " return inv_perm[inverse_dihedral_transform(grid, trans_id)]\n", "\n", "\n", "def grid_hash(grid: np.ndarray):\n", " return hash((grid.tobytes(), grid.shape))\n", "\n", "\n", "@njit\n", "def crop(grid: np.ndarray):\n", " # Find maximum-sized rectangle without any EOS token inside.\n", " grid = grid.reshape(30, 30)\n", "\n", " max_area = 0\n", " max_size = (0, 0)\n", " nr, nc = grid.shape\n", " \n", " num_c = nc\n", " for num_r in range(1, nr + 1):\n", " # Scan for maximum c\n", " for c in range(1, num_c + 1):\n", " x = grid[num_r - 1, c - 1]\n", " if (x < 2) | (x > 11):\n", " num_c = c - 1\n", " break\n", " \n", " area = num_r * num_c\n", " if area > max_area:\n", " max_area = area\n", " max_size = (num_r, num_c)\n", "\n", " return grid[:max_size[0], :max_size[1]] - 2\n", "\n", "\n", "def test(visualize, Ks=[1, 2, 10, 100, 1000]):\n", " identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\n", " \n", " global_hmap = {}\n", " \n", " # Get puzzles and corresponding answers\n", " puzzle_labels = {}\n", " for identifier, input, label in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], all_preds[\"labels\"]):\n", " name = 
identifier_map[identifier]\n", " if \"_\" not in name: # Not-augmented\n", " puzzle_labels.setdefault(name, {})\n", " \n", " input = crop(input.numpy())\n", " label = crop(label.numpy())\n", "\n", " input_hash = grid_hash(input)\n", " label_hash = grid_hash(label)\n", "\n", " global_hmap[input_hash] = input\n", " global_hmap[label_hash] = label\n", "\n", " assert input_hash not in puzzle_labels[name]\n", " puzzle_labels[name][input_hash] = label_hash\n", " \n", " print (\"Number of puzzles\", len(puzzle_labels))\n", " \n", " # Argmax prediction\n", " preds = all_preds[\"logits\"].argmax(-1)\n", "\n", " # Collate\n", " pred_answers = {}\n", " for identifier, input, pred, q in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], preds, all_preds[\"q_halt_logits\"].sigmoid()):\n", " name = identifier_map[identifier]\n", " orig_name = name.split(\"_\")[0]\n", " \n", " input = input.numpy()\n", " input_hash = grid_hash(inverse_aug(name, crop(input)))\n", " assert input_hash in puzzle_labels[orig_name]\n", " \n", " pred = inverse_aug(name, crop(pred.numpy()))\n", " pred_hash = grid_hash(pred)\n", " global_hmap[pred_hash] = pred\n", " \n", " pred_answers.setdefault(orig_name, {})\n", " pred_answers[orig_name].setdefault(input_hash, [])\n", " pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\n", "\n", " # test-1\n", " if visualize:\n", " num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\n", " fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\n", " \n", " fig_id = 0\n", " \n", " correct = [0 for _ in range(len(Ks))]\n", " for name, tests in puzzle_labels.items():\n", " num_test_correct = [0 for _ in range(len(Ks))]\n", " for input_hash, label_hash in tests.items():\n", " p = pred_answers[name][input_hash]\n", " p_map = {}\n", " \n", " for h, q in p:\n", " p_map.setdefault(h, [0, 0])\n", " p_map[h][0] += 1\n", " p_map[h][1] += q\n", " \n", " for h, stats in p_map.items():\n", " stats[1] /= stats[0]\n", " \n", " p_map 
= sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\n", "\n", " # 2-vote\n", " for i, k in enumerate(Ks):\n", " ok = False\n", " for h, stats in p_map[:k]:\n", " ok |= h == label_hash\n", " \n", " num_test_correct[i] += ok\n", "\n", " if visualize:\n", " # Show input and ground truth\n", " axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\n", " axes[fig_id, 0].set_title(f\"{name}\\nInput\")\n", " axes[fig_id, 0].axis('off')\n", " \n", " axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\n", " axes[fig_id, 1].set_title(f\"{name}\\nAnswer\")\n", " axes[fig_id, 1].axis('off')\n", " \n", " trial_id = 2\n", " for h, stats in p_map[:2]:\n", " ans = global_hmap[h]\n", " \n", " axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\n", " axes[fig_id, trial_id].set_title(f\"{name}\\nTrial {trial_id}\")\n", " axes[fig_id, trial_id].axis('off')\n", " \n", " trial_id += 1\n", " \n", " fig_id += 1\n", " \n", " # Total correctness\n", " for i in range(len(Ks)):\n", " correct[i] += num_test_correct[i] == len(tests)\n", "\n", " for i, k in enumerate(Ks):\n", " print (f\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\")\n", "\n", "\n", "test(visualize=False)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.10" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: assets/npyjs.js ================================================ class npyjs { constructor(opts) { if (opts && !('convertFloat16' in opts)) { console.warn([ "npyjs constructor now accepts {convertFloat16?: boolean}.", "For usage, go to https://github.com/jhuapl-boss/npyjs." ].join(" ")); } this.convertFloat16 = opts?.convertFloat16 ?? 
true; this.dtypes = { "> 15) & 0x1; const exponent = (float16 >> 10) & 0x1f; const fraction = float16 & 0x3ff; // Handle special cases if (exponent === 0) { if (fraction === 0) { // Zero return sign ? -0 : 0; } // Denormalized number return (sign ? -1 : 1) * Math.pow(2, -14) * (fraction / 0x400); } else if (exponent === 0x1f) { if (fraction === 0) { // Infinity return sign ? -Infinity : Infinity; } // NaN return NaN; } // Normalized number return (sign ? -1 : 1) * Math.pow(2, exponent - 15) * (1 + fraction / 0x400); } parse(arrayBufferContents) { // const version = arrayBufferContents.slice(6, 8); // Uint8-encoded const headerLength = new DataView(arrayBufferContents.slice(8, 10)).getUint8(0); const offsetBytes = 10 + headerLength; const hcontents = new TextDecoder("utf-8").decode( new Uint8Array(arrayBufferContents.slice(10, 10 + headerLength)) ); const header = JSON.parse( hcontents .toLowerCase() // True -> true .replace(/'/g, '"') .replace("(", "[") .replace(/,*\),*/g, "]") ); const shape = header.shape; const dtype = this.dtypes[header.descr]; if (!dtype) { console.error(`Unsupported dtype: ${header.descr}`); return null; } const nums = new dtype.arrayConstructor( arrayBufferContents, offsetBytes ); // Convert float16 to float32 if converter exists const data = dtype.converter ? dtype.converter.call(this, nums) : nums; return { dtype: dtype.name, data: data, shape, fortranOrder: header.fortran_order }; } async load(filename, callback, fetchArgs) { /* Loads an array from a stream of bytes. 
*/ fetchArgs = fetchArgs || {}; let arrayBuf; // If filename is ArrayBuffer if (filename instanceof ArrayBuffer) { arrayBuf = filename; } // If filename is a file path else { const resp = await fetch(filename, { ...fetchArgs }); arrayBuf = await resp.arrayBuffer(); } const result = this.parse(arrayBuf); if (callback) { return callback(result); } return result; } } ================================================ FILE: config/arch/hrm_v1.yaml ================================================ name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1 loss: name: losses@ACTLossHead loss_type: stablemax_cross_entropy halt_exploration_prob: 0.1 halt_max_steps: 16 H_cycles: 2 L_cycles: 2 H_layers: 4 L_layers: 4 hidden_size: 512 num_heads: 8 # min(2, hidden_size // 64) expansion: 4 puzzle_emb_ndim: ${.hidden_size} pos_encodings: rope ================================================ FILE: config/cfg_pretrain.yaml ================================================ # ARC training config defaults: - arch: hrm_v1 - _self_ hydra: output_subdir: null # Data path data_path: data/arc-aug-1000 # Hyperparams - Training global_batch_size: 768 epochs: 100000 eval_interval: 10000 checkpoint_every_eval: True lr: 1e-4 lr_min_ratio: 1.0 lr_warmup_steps: 2000 # Standard hyperparameter settings for LM, as used in Llama beta1: 0.9 beta2: 0.95 weight_decay: 0.1 puzzle_emb_weight_decay: 0.1 # Hyperparams - Puzzle embeddings training puzzle_emb_lr: 1e-2 ================================================ FILE: dataset/build_arc_dataset.py ================================================ from typing import List, Optional, Tuple, Dict from dataclasses import dataclass from pathlib import Path import os import json import hashlib import numpy as np from glob import glob from argdantic import ArgParser from pydantic import BaseModel from common import PuzzleDatasetMetadata, dihedral_transform cli = ArgParser() class DataProcessConfig(BaseModel): # ARC-1 dataset_dirs: List[str] = 
["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"] output_dir: str = "data/arc-aug-1000" # ARC-2 # dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI-2/data"] # output_dir: str = "data/arc-2-aug-1000" seed: int = 42 num_aug: int = 1000 ARCMaxGridSize = 30 ARCAugmentRetriesFactor = 5 @dataclass class ARCPuzzle: id: str examples: List[Tuple[np.ndarray, np.ndarray]] def arc_grid_to_np(grid: List[List[int]]): arr = np.array(grid) # Shape check assert arr.ndim == 2 assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize # Element check assert np.all((arr >= 0) & (arr <= 9)) return arr.astype(np.uint8) def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool): # PAD: 0, : 1, digits: 2 ... 11 # Compute random top-left pad if do_translation: pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1) pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1) else: pad_r = pad_c = 0 # Pad grid result = [] for grid in [inp, out]: nrow, ncol = grid.shape grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0) # Add eos_row, eos_col = pad_r + nrow, pad_c + ncol if eos_row < ARCMaxGridSize: grid[eos_row, pad_c:eos_col] = 1 if eos_col < ARCMaxGridSize: grid[pad_r:eos_row, eos_col] = 1 result.append(grid.flatten()) return result def puzzle_hash(puzzle: dict): # Hash the puzzle for checking equivalence def _grid_hash(grid: np.ndarray): buffer = [x.to_bytes(1) for x in grid.shape] buffer.append(grid.tobytes()) return hashlib.sha256(b"".join(buffer)).hexdigest() hashes = [] for example_type, example in puzzle.items(): for input, label in example.examples: hashes.append(f"{_grid_hash(input)}|{_grid_hash(label)}") hashes.sort() return hashlib.sha256("|".join(hashes).encode()).hexdigest() def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: 
Dict[str, Tuple[str, str]]): # Remove "name" name = puzzle.pop("name", default_name) # Convert dests = set(dest_mapping.values()) converted = {dest: ARCPuzzle(name, []) for dest in dests} for example_type, examples in puzzle.items(): dest = dest_mapping[example_type] converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples]) group = [converted] # Augment if aug_count > 0: hashes = {puzzle_hash(converted)} for _trial in range(ARCAugmentRetriesFactor * aug_count): # Augment plan trans_id = np.random.randint(0, 8) mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))]) # Permute colors, Excluding "0" (black) aug_repr = f"t{trans_id}_{''.join(str(x) for x in mapping)}" def _map_grid(grid: np.ndarray): return dihedral_transform(mapping[grid], trans_id) # Check duplicate augmented = {dest: ARCPuzzle(f"{puzzle.id}_{aug_repr}", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()} h = puzzle_hash(augmented) if h not in hashes: hashes.add(h) group.append(augmented) if len(group) >= aug_count + 1: break if len(group) < aug_count + 1: print (f"[Puzzle {name}] augmentation not full, only {len(group)}") # Append for dest in dests: # Convert the examples dest_split, dest_set = dest results.setdefault(dest_split, {}) results[dest_split].setdefault(dest_set, []) results[dest_split][dest_set].append([converted[dest] for converted in group]) def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig): train_examples_dest = ("train", "all") test_examples_map = { "evaluation": [(1.0, ("test", "all"))], "_default": [(1.0, ("train", "all"))] } total_puzzles = 0 for subdir in os.scandir(dataset_path): if subdir.is_dir(): # Load all puzzles in this directory puzzles = [] for filename in glob(os.path.join(subdir.path, "*.json")): with open(filename, "r") as f: 
puzzles.append((Path(filename).stem, json.load(f))) # Shuffle puzzles np.random.shuffle(puzzles) # Assign by fraction for idx, (default_name, puzzle) in enumerate(puzzles): fraction = idx / len(puzzles) test_examples_dest = None for f, dest in test_examples_map.get(subdir.name, test_examples_map["_default"]): if fraction < f: test_examples_dest = dest break assert test_examples_dest is not None convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest}) total_puzzles += 1 print (f"[{dataset_path}] total puzzles: {total_puzzles}") def convert_dataset(config: DataProcessConfig): np.random.seed(config.seed) # Read dataset data = {} for dataset_dir in config.dataset_dirs: load_puzzles_arcagi(data, dataset_dir, config) # Map global puzzle identifiers num_identifiers = 1 # 0 is blank identifier_map = {} for split_name, split in data.items(): for subset_name, subset in split.items(): for group in subset: for puzzle in group: if puzzle.id not in identifier_map: identifier_map[puzzle.id] = num_identifiers num_identifiers += 1 print (f"Total puzzle IDs (including ): {num_identifiers}") # Save for split_name, split in data.items(): os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True) # Translational augmentations enable_translational_augment = split_name == "train" # Statistics total_examples = 0 total_puzzles = 0 total_groups = 0 for subset_name, subset in split.items(): # Construct subset results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} results["puzzle_indices"].append(0) results["group_indices"].append(0) example_id = 0 puzzle_id = 0 for group in subset: for puzzle in group: # Push puzzle no_aug_id = np.random.randint(0, len(puzzle.examples)) for _idx_ex, (inp, out) in enumerate(puzzle.examples): inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id) 
results["inputs"].append(inp) results["labels"].append(out) example_id += 1 total_examples += 1 results["puzzle_indices"].append(example_id) results["puzzle_identifiers"].append(identifier_map[puzzle.id]) puzzle_id += 1 total_puzzles += 1 # Push group results["group_indices"].append(puzzle_id) total_groups += 1 for k, v in results.items(): if k in {"inputs", "labels"}: v = np.stack(v, 0) else: v = np.array(v, dtype=np.int32) np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v) # Metadata metadata = PuzzleDatasetMetadata( seq_len=ARCMaxGridSize * ARCMaxGridSize, vocab_size=10 + 2, # PAD + EOS + "0" ... "9" pad_id=0, ignore_label_id=0, blank_identifier_id=0, num_puzzle_identifiers=num_identifiers, total_groups=total_groups, mean_puzzle_examples=total_examples / total_puzzles, sets=list(split.keys()) ) # Save metadata as JSON. with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f: json.dump(metadata.model_dump(), f) # Save IDs mapping with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: ids_mapping = {v: k for k, v in identifier_map.items()} json.dump([ids_mapping.get(i, "") for i in range(num_identifiers)], f) @cli.command(singleton=True) def main(config: DataProcessConfig): convert_dataset(config) if __name__ == "__main__": cli() ================================================ FILE: dataset/build_maze_dataset.py ================================================ from typing import Optional import math import os import csv import json import numpy as np from argdantic import ArgParser from pydantic import BaseModel from tqdm import tqdm from huggingface_hub import hf_hub_download from common import PuzzleDatasetMetadata, dihedral_transform CHARSET = "# SGo" cli = ArgParser() class DataProcessConfig(BaseModel): source_repo: str = "sapientinc/maze-30x30-hard-1k" output_dir: str = "data/maze-30x30-hard-1k" subsample_size: Optional[int] = None aug: bool = False def convert_subset(set_name: str, 
config: DataProcessConfig): # Read CSV all_chars = set() grid_size = None inputs = [] labels = [] with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: # type: ignore reader = csv.reader(csvfile) next(reader) # Skip header for source, q, a, rating in reader: all_chars.update(q) all_chars.update(a) if grid_size is None: n = int(len(q) ** 0.5) grid_size = (n, n) inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size)) labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size)) # If subsample_size is specified for the training set, # randomly sample the desired number of examples. if set_name == "train" and config.subsample_size is not None: total_samples = len(inputs) if config.subsample_size < total_samples: indices = np.random.choice(total_samples, size=config.subsample_size, replace=False) inputs = [inputs[i] for i in indices] labels = [labels[i] for i in indices] # Generate dataset results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} puzzle_id = 0 example_id = 0 results["puzzle_indices"].append(0) results["group_indices"].append(0) for inp, out in zip(tqdm(inputs), labels): # Dihedral transformations for augmentation for aug_idx in range(8 if (set_name == "train" and config.aug) else 1): results["inputs"].append(dihedral_transform(inp, aug_idx)) results["labels"].append(dihedral_transform(out, aug_idx)) example_id += 1 puzzle_id += 1 results["puzzle_indices"].append(example_id) results["puzzle_identifiers"].append(0) # Push group results["group_indices"].append(puzzle_id) # Char mappings assert len(all_chars - set(CHARSET)) == 0 char2id = np.zeros(256, np.uint8) char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1 # To Numpy def _seq_to_numpy(seq): arr = np.vstack([char2id[s.reshape(-1)] for s in seq]) return arr results = { "inputs": _seq_to_numpy(results["inputs"]), "labels": 
_seq_to_numpy(results["labels"]), "group_indices": np.array(results["group_indices"], dtype=np.int32), "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32), "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32), } # Metadata metadata = PuzzleDatasetMetadata( seq_len=int(math.prod(grid_size)), # type: ignore vocab_size=len(CHARSET) + 1, # PAD + Charset pad_id=0, ignore_label_id=0, blank_identifier_id=0, num_puzzle_identifiers=1, total_groups=len(results["group_indices"]) - 1, mean_puzzle_examples=1, sets=["all"] ) # Save metadata as JSON. save_dir = os.path.join(config.output_dir, set_name) os.makedirs(save_dir, exist_ok=True) with open(os.path.join(save_dir, "dataset.json"), "w") as f: json.dump(metadata.model_dump(), f) # Save data for k, v in results.items(): np.save(os.path.join(save_dir, f"all__{k}.npy"), v) # Save IDs mapping (for visualization only) with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: json.dump([""], f) @cli.command(singleton=True) def preprocess_data(config: DataProcessConfig): convert_subset("train", config) convert_subset("test", config) if __name__ == "__main__": cli() ================================================ FILE: dataset/build_sudoku_dataset.py ================================================ from typing import Optional import os import csv import json import numpy as np from argdantic import ArgParser from pydantic import BaseModel from tqdm import tqdm from huggingface_hub import hf_hub_download from common import PuzzleDatasetMetadata cli = ArgParser() class DataProcessConfig(BaseModel): source_repo: str = "sapientinc/sudoku-extreme" output_dir: str = "data/sudoku-extreme-full" subsample_size: Optional[int] = None min_difficulty: Optional[int] = None num_aug: int = 0 def shuffle_sudoku(board: np.ndarray, solution: np.ndarray): # Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged digit_map = np.pad(np.random.permutation(np.arange(1, 
10)), (1, 0)) # Randomly decide whether to transpose. transpose_flag = np.random.rand() < 0.5 # Generate a valid row permutation: # - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows. bands = np.random.permutation(3) row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands]) # Similarly for columns (stacks). stacks = np.random.permutation(3) col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks]) # Build an 81->81 mapping. For each new cell at (i, j) # (row index = i // 9, col index = i % 9), # its value comes from old row = row_perm[i//9] and old col = col_perm[i%9]. mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)]) def apply_transformation(x: np.ndarray) -> np.ndarray: # Apply transpose flag if transpose_flag: x = x.T # Apply the position mapping. new_board = x.flatten()[mapping].reshape(9, 9).copy() # Apply digit mapping return digit_map[new_board] return apply_transformation(board), apply_transformation(solution) def convert_subset(set_name: str, config: DataProcessConfig): # Read CSV inputs = [] labels = [] with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: reader = csv.reader(csvfile) next(reader) # Skip header for source, q, a, rating in reader: if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty): assert len(q) == 81 and len(a) == 81 inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0')) labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0')) # If subsample_size is specified for the training set, # randomly sample the desired number of examples. 
if set_name == "train" and config.subsample_size is not None: total_samples = len(inputs) if config.subsample_size < total_samples: indices = np.random.choice(total_samples, size=config.subsample_size, replace=False) inputs = [inputs[i] for i in indices] labels = [labels[i] for i in indices] # Generate dataset num_augments = config.num_aug if set_name == "train" else 0 results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} puzzle_id = 0 example_id = 0 results["puzzle_indices"].append(0) results["group_indices"].append(0) for orig_inp, orig_out in zip(tqdm(inputs), labels): for aug_idx in range(1 + num_augments): # First index is not augmented if aug_idx == 0: inp, out = orig_inp, orig_out else: inp, out = shuffle_sudoku(orig_inp, orig_out) # Push puzzle (only single example) results["inputs"].append(inp) results["labels"].append(out) example_id += 1 puzzle_id += 1 results["puzzle_indices"].append(example_id) results["puzzle_identifiers"].append(0) # Push group results["group_indices"].append(puzzle_id) # To Numpy def _seq_to_numpy(seq): arr = np.concatenate(seq).reshape(len(seq), -1) assert np.all((arr >= 0) & (arr <= 9)) return arr + 1 results = { "inputs": _seq_to_numpy(results["inputs"]), "labels": _seq_to_numpy(results["labels"]), "group_indices": np.array(results["group_indices"], dtype=np.int32), "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32), "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32), } # Metadata metadata = PuzzleDatasetMetadata( seq_len=81, vocab_size=10 + 1, # PAD + "0" ... "9" pad_id=0, ignore_label_id=0, blank_identifier_id=0, num_puzzle_identifiers=1, total_groups=len(results["group_indices"]) - 1, mean_puzzle_examples=1, sets=["all"] ) # Save metadata as JSON. 
save_dir = os.path.join(config.output_dir, set_name) os.makedirs(save_dir, exist_ok=True) with open(os.path.join(save_dir, "dataset.json"), "w") as f: json.dump(metadata.model_dump(), f) # Save data for k, v in results.items(): np.save(os.path.join(save_dir, f"all__{k}.npy"), v) # Save IDs mapping (for visualization only) with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: json.dump([""], f) @cli.command(singleton=True) def preprocess_data(config: DataProcessConfig): convert_subset("train", config) convert_subset("test", config) if __name__ == "__main__": cli() ================================================ FILE: dataset/common.py ================================================ from typing import List, Optional import pydantic import numpy as np # Global list mapping each dihedral transform id to its inverse. # Index corresponds to the original tid, and the value is its inverse. DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7] class PuzzleDatasetMetadata(pydantic.BaseModel): pad_id: int ignore_label_id: Optional[int] blank_identifier_id: int vocab_size: int seq_len: int num_puzzle_identifiers: int total_groups: int mean_puzzle_examples: float sets: List[str] def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray: """8 dihedral symmetries by rotate, flip and mirror""" if tid == 0: return arr # identity elif tid == 1: return np.rot90(arr, k=1) elif tid == 2: return np.rot90(arr, k=2) elif tid == 3: return np.rot90(arr, k=3) elif tid == 4: return np.fliplr(arr) # horizontal flip elif tid == 5: return np.flipud(arr) # vertical flip elif tid == 6: return arr.T # transpose (reflection along main diagonal) elif tid == 7: return np.fliplr(np.rot90(arr, k=1)) # anti-diagonal reflection else: return arr def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray: return dihedral_transform(arr, DIHEDRAL_INVERSE[tid]) ================================================ FILE: evaluate.py ================================================ from 
typing import List import yaml import os import torch import torch.distributed as dist import pydantic from omegaconf import OmegaConf from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader class EvalConfig(pydantic.BaseModel): checkpoint: str save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"] def launch(): eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli())) # type: ignore RANK = 0 WORLD_SIZE = 1 # Initialize distributed training if in distributed environment (e.g. torchrun) if "LOCAL_RANK" in os.environ: # Initialize distributed, default device and dtype dist.init_process_group(backend="nccl") RANK = dist.get_rank() WORLD_SIZE = dist.get_world_size() torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f: config = PretrainConfig(**yaml.safe_load(f)) config.eval_save_outputs = eval_cfg.save_outputs config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint) # Dataloader train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) # Models train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE) # Try unwrap torch.compile try: train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True) except: train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True) train_state.step = 0 ckpt_filename = os.path.basename(eval_cfg.checkpoint) if ckpt_filename.startswith("step_"): train_state.step = 
int(ckpt_filename.removeprefix("step_")) # Evaluate print ("Starting evaluation") train_state.model.eval() metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE) if metrics is not None: print (metrics) if __name__ == "__main__": launch() ================================================ FILE: models/common.py ================================================ import math import torch from torch import nn def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0): # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor # This function is a PyTorch version of jax truncated normal init (default init method in flax) # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848 # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199 with torch.no_grad(): if std == 0: tensor.zero_() else: sqrt2 = math.sqrt(2) a = math.erf(lower / sqrt2) b = math.erf(upper / sqrt2) z = (b - a) / 2 c = (2 * math.pi) ** -0.5 pdf_u = c * math.exp(-0.5 * lower ** 2) pdf_l = c * math.exp(-0.5 * upper ** 2) comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2) tensor.uniform_(a, b) tensor.erfinv_() tensor.mul_(sqrt2 * comp_std) tensor.clip_(lower * comp_std, upper * comp_std) return tensor ================================================ FILE: models/hrm/hrm_act_v1.py ================================================ from typing import Tuple, List, Dict, Optional from dataclasses import dataclass import math import torch import torch.nn.functional as F from torch import nn from pydantic import BaseModel from models.common import trunc_normal_init_ from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear from models.sparse_embedding import CastedSparseEmbedding @dataclass class 
HierarchicalReasoningModel_ACTV1InnerCarry: z_H: torch.Tensor z_L: torch.Tensor @dataclass class HierarchicalReasoningModel_ACTV1Carry: inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry steps: torch.Tensor halted: torch.Tensor current_data: Dict[str, torch.Tensor] class HierarchicalReasoningModel_ACTV1Config(BaseModel): batch_size: int seq_len: int puzzle_emb_ndim: int = 0 num_puzzle_identifiers: int vocab_size: int H_cycles: int L_cycles: int H_layers: int L_layers: int # Transformer config hidden_size: int expansion: float num_heads: int pos_encodings: str rms_norm_eps: float = 1e-5 rope_theta: float = 10000.0 # Halting Q-learning config halt_max_steps: int halt_exploration_prob: float forward_dtype: str = "bfloat16" class HierarchicalReasoningModel_ACTV1Block(nn.Module): def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None: super().__init__() self.self_attn = Attention( hidden_size=config.hidden_size, head_dim=config.hidden_size // config.num_heads, num_heads=config.num_heads, num_key_value_heads=config.num_heads, causal=False ) self.mlp = SwiGLU( hidden_size=config.hidden_size, expansion=config.expansion, ) self.norm_eps = config.rms_norm_eps def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: # Post Norm # Self Attention hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps) # Fully Connected hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps) return hidden_states class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module): def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]): super().__init__() self.layers = torch.nn.ModuleList(layers) def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor: # Input injection (add) hidden_states = hidden_states + input_injection # Layers for layer in self.layers: 
hidden_states = layer(hidden_states=hidden_states, **kwargs) return hidden_states class HierarchicalReasoningModel_ACTV1_Inner(nn.Module): def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None: super().__init__() self.config = config self.forward_dtype = getattr(torch, self.config.forward_dtype) # I/O self.embed_scale = math.sqrt(self.config.hidden_size) embed_init_std = 1.0 / self.embed_scale self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False) self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True) self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div if self.config.puzzle_emb_ndim > 0: # Zero init puzzle embeddings self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype) # LM Blocks if self.config.pos_encodings == "rope": self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, base=self.config.rope_theta) elif self.config.pos_encodings == "learned": self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) else: raise NotImplementedError() # Reasoning Layers self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)]) self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)]) # Initial states self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), 
persistent=True) self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) # Q head special init # Init Q to (almost) zero for faster learning during bootstrapping with torch.no_grad(): self.q_head.weight.zero_() self.q_head.bias.fill_(-5) # type: ignore def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor): # Token embedding embedding = self.embed_tokens(input.to(torch.int32)) # Puzzle embeddings if self.config.puzzle_emb_ndim > 0: puzzle_embedding = self.puzzle_emb(puzzle_identifiers) pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1] if pad_count > 0: puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count)) embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2) # Position embeddings if self.config.pos_encodings == "learned": # scale by 1/sqrt(2) to maintain forward variance embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype)) # Scale return self.embed_scale * embedding def empty_carry(self, batch_size: int): return HierarchicalReasoningModel_ACTV1InnerCarry( z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), ) def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry): return HierarchicalReasoningModel_ACTV1InnerCarry( z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H), z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L), ) def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: seq_info = dict( 
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None, ) # Input encoding input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) # Forward iterations with torch.no_grad(): z_H, z_L = carry.z_H, carry.z_L for _H_step in range(self.config.H_cycles): for _L_step in range(self.config.L_cycles): if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)): z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) if not (_H_step == self.config.H_cycles - 1): z_H = self.H_level(z_H, z_L, **seq_info) assert not z_H.requires_grad and not z_L.requires_grad # 1-step grad z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) z_H = self.H_level(z_H, z_L, **seq_info) # LM Outputs new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad output = self.lm_head(z_H)[:, self.puzzle_emb_len:] # Q head q_logits = self.q_head(z_H[:, 0]).to(torch.float32) return new_carry, output, (q_logits[..., 0], q_logits[..., 1]) class HierarchicalReasoningModel_ACTV1(nn.Module): """ACT wrapper.""" def __init__(self, config_dict: dict): super().__init__() self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict) self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config) @property def puzzle_emb(self): return self.inner.puzzle_emb def initial_carry(self, batch: Dict[str, torch.Tensor]): batch_size = batch["inputs"].shape[0] return HierarchicalReasoningModel_ACTV1Carry( inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted. 
steps=torch.zeros((batch_size, ), dtype=torch.int32), halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted current_data={k: torch.empty_like(v) for k, v in batch.items()} ) def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]: # Update data, carry (removing halted sequences) new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry) new_steps = torch.where(carry.halted, 0, carry.steps) new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()} # Forward inner model new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data) outputs = { "logits": logits, "q_halt_logits": q_halt_logits, "q_continue_logits": q_continue_logits } with torch.no_grad(): # Step new_steps = new_steps + 1 is_last_step = new_steps >= self.config.halt_max_steps halted = is_last_step # if training, and ACT is enabled if self.training and (self.config.halt_max_steps > 1): # Halt signal # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes halted = halted | (q_halt_logits > q_continue_logits) # Exploration min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1) halted = halted & (new_steps >= min_halt_steps) # Compute target Q # NOTE: No replay buffer and target networks for computing target Q-value. # As batch_size is large, there're many parallel envs. 
# Similar concept as PQN https://arxiv.org/abs/2407.04811 next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1] outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits))) return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs ================================================ FILE: models/layers.py ================================================ from typing import Tuple import torch from torch import nn import torch.nn.functional as F try: from flash_attn_interface import flash_attn_func # type: ignore[import] except ImportError: # Fallback to FlashAttention 2 from flash_attn import flash_attn_func # type: ignore[import] from models.common import trunc_normal_init_ CosSin = Tuple[torch.Tensor, torch.Tensor] def _find_multiple(a, b): return (-(a // -b)) * b def rotate_half(x: torch.Tensor): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): # q, k: [bs, seq_len, num_heads, head_dim] # cos, sin: [seq_len, head_dim] orig_dtype = q.dtype q = q.to(cos.dtype) k = k.to(cos.dtype) q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2)) k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2)) return q_embed.to(orig_dtype), k_embed.to(orig_dtype) class CastedLinear(nn.Module): def __init__(self, in_features: int, out_features: int, bias: bool): super().__init__() # Truncated LeCun normal init self.weight = nn.Parameter( trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5)) ) self.bias = None if bias: # Zero init bias self.bias = nn.Parameter(torch.zeros((out_features, ))) def forward(self, input: torch.Tensor) -> torch.Tensor: 
return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None) class CastedEmbedding(nn.Module): def __init__(self, num_embeddings: int, embedding_dim: int, init_std: float, cast_to: torch.dtype): super().__init__() self.cast_to = cast_to # Truncated LeCun normal init self.embedding_weight = nn.Parameter( trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std) ) def forward(self, input: torch.Tensor) -> torch.Tensor: return F.embedding(input, self.embedding_weight.to(self.cast_to)) class RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings, base, device=None): super().__init__() # RoPE inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device) freqs = torch.outer(t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.cos_cached = nn.Buffer(emb.cos(), persistent=False) self.sin_cached = nn.Buffer(emb.sin(), persistent=False) def forward(self): return self.cos_cached, self.sin_cached class Attention(nn.Module): def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False): super().__init__() self.hidden_size = hidden_size self.head_dim = head_dim self.output_size = head_dim * num_heads self.num_heads = num_heads self.num_key_value_heads = num_key_value_heads self.causal = causal self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False) self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False) def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, seq_len, _ = hidden_states.shape # hidden_states: [bs, seq_len, num_heads, head_dim] qkv = self.qkv_proj(hidden_states) # Split head qkv = qkv.view(batch_size, 
seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) query = qkv[:, :, :self.num_heads] key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads] value = qkv[:, :, self.num_heads + self.num_key_value_heads:] # RoPE if cos_sin is not None: cos, sin = cos_sin query, key = apply_rotary_pos_emb(query, key, cos, sin) # flash attn attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal) if isinstance(attn_output, tuple): # fa2 and fa3 compatibility attn_output = attn_output[0] attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore return self.o_proj(attn_output) class SwiGLU(nn.Module): def __init__(self, hidden_size: int, expansion: float): super().__init__() inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256) self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False) self.down_proj = CastedLinear(inter, hidden_size, bias=False) def forward(self, x): gate, up = self.gate_up_proj(x).chunk(2, dim=-1) return self.down_proj(F.silu(gate) * up) def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.square().mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) return hidden_states.to(input_dtype) ================================================ FILE: models/losses.py ================================================ from typing import Any, Tuple, Dict, Sequence, Optional import torch import torch.nn.functional as F from torch import nn IGNORE_LABEL_ID = -100 def s(x, epsilon=1e-30): return torch.where( x<0, 1/(1-x+ epsilon), x + 1 ) def log_stablemax(x, dim=-1): s_x = s(x) return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True)) def stablemax_cross_entropy(logits, labels, ignore_index: int = -100): logprobs = log_stablemax(logits.to(torch.float64), dim=-1) valid_mask = labels != ignore_index 
transformed_labels = torch.where(valid_mask, labels, 0) prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1) return -torch.where(valid_mask, prediction_logprobs, 0) def softmax_cross_entropy(logits, labels, ignore_index: int = -100): # Cast logits to f32 # Flatten logits return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape) class ACTLossHead(nn.Module): def __init__(self, model: nn.Module, loss_type: str): super().__init__() self.model = model self.loss_fn = globals()[loss_type] def initial_carry(self, *args, **kwargs): return self.model.initial_carry(*args, **kwargs) # type: ignore def forward( self, return_keys: Sequence[str], # Model args **model_kwargs, ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]: # Model logits # B x SeqLen x D new_carry, outputs = self.model(**model_kwargs) labels = new_carry.current_data["labels"] # Correctness with torch.no_grad(): mask = labels != IGNORE_LABEL_ID loss_counts = mask.sum(-1) loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels) seq_is_correct = is_correct.sum(-1) == loss_counts # Metrics (halted) valid_metrics = new_carry.halted & (loss_counts > 0) metrics = { "count": valid_metrics.sum(), "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(), "exact_accuracy": (valid_metrics & seq_is_correct).sum(), "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(), "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(), } # Losses # FIXME: Assuming the batch is always full lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum() q_halt_loss = 
F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum") metrics.update({ "lm_loss": lm_loss.detach(), "q_halt_loss": q_halt_loss.detach(), }) # Q continue (bootstrapping target loss) q_continue_loss = 0 if "target_q_continue" in outputs: q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum") metrics["q_continue_loss"] = q_continue_loss.detach() # Filter outputs for return detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs} return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all() ================================================ FILE: models/sparse_embedding.py ================================================ from typing import Union import torch from torch import nn import torch.distributed as dist from torch.optim.optimizer import Optimizer, ParamsT from models.common import trunc_normal_init_ class CastedSparseEmbedding(nn.Module): def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype): super().__init__() self.cast_to = cast_to # Real Weights # Truncated LeCun normal init self.weights = nn.Buffer( trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True ) # Local weights and IDs # Local embeddings, with gradient, not persistent self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False) # Local embedding IDs, not persistent self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False) def forward(self, inputs: torch.Tensor) -> torch.Tensor: if not self.training: # Test mode, no gradient return self.weights[inputs].to(self.cast_to) # Training mode, fill puzzle embedding from weights with torch.no_grad(): self.local_weights.copy_(self.weights[inputs]) 
self.local_ids.copy_(inputs) return self.local_weights.to(self.cast_to) class CastedSparseEmbeddingSignSGD_Distributed(Optimizer): def __init__( self, params: ParamsT, world_size: int, lr: Union[float, torch.Tensor] = 1e-3, weight_decay: float = 1e-2, ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") defaults = dict( lr=lr, weight_decay=weight_decay, world_size=world_size ) super().__init__(params, defaults) @torch.no_grad def step(self, closure=None): # type: ignore for group in self.param_groups: # Find the sparse embedding weights local_weights_grad = None local_ids = None weights = None assert len(group["params"]) == 3 for p in group["params"]: if p.requires_grad: local_weights_grad = p.grad elif p.ndim == 1: local_ids = p elif p.ndim == 2: weights = p else: assert False assert local_weights_grad is not None assert local_ids is not None assert weights is not None # Apply SignSGD # Adam ≈ SignSGD if gradient is very sparse _sparse_emb_signsgd_dist( local_weights_grad, local_ids, weights, lr=group["lr"], weight_decay=group["weight_decay"], world_size=group["world_size"] ) def _sparse_emb_signsgd_dist( local_weights_grad: torch.Tensor, local_ids: torch.Tensor, weights: torch.Tensor, lr: float, weight_decay: float, world_size: int ) -> None: N, D = local_weights_grad.shape # All-gather all_weights_grad = local_weights_grad all_ids = local_ids if world_size > 1: all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device) all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device) dist.all_gather_into_tensor(all_weights_grad, local_weights_grad) dist.all_gather_into_tensor(all_ids, local_ids) # Unique grad_ids, inv = all_ids.unique(return_inverse=True) grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device) 
grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad) # SignSGD with decoupled weight decay p = weights[grad_ids] p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr) # Write updated slices back weights[grad_ids] = p ================================================ FILE: pretrain.py ================================================ from typing import Optional, Any, Sequence, List from dataclasses import dataclass import os import math import yaml import shutil import torch import torch.distributed as dist from torch import nn from torch.utils.data import DataLoader import tqdm import wandb import coolname import hydra import pydantic from omegaconf import DictConfig from adam_atan2 import AdamATan2 from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata from utils.functions import load_model_class, get_model_source_path from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed class LossConfig(pydantic.BaseModel): model_config = pydantic.ConfigDict(extra='allow') name: str class ArchConfig(pydantic.BaseModel): model_config = pydantic.ConfigDict(extra='allow') name: str loss: LossConfig class PretrainConfig(pydantic.BaseModel): # Config arch: ArchConfig # Data data_path: str # Hyperparams global_batch_size: int epochs: int lr: float lr_min_ratio: float lr_warmup_steps: int weight_decay: float beta1: float beta2: float # Puzzle embedding puzzle_emb_lr: float puzzle_emb_weight_decay: float # Names project_name: Optional[str] = None run_name: Optional[str] = None checkpoint_path: Optional[str] = None # Extras seed: int = 0 checkpoint_every_eval: bool = False eval_interval: Optional[int] = None eval_save_outputs: List[str] = [] @dataclass class TrainState: model: nn.Module optimizers: Sequence[torch.optim.Optimizer] optimizer_lrs: Sequence[float] carry: Any step: int total_steps: int def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs): dataset = 
PuzzleDataset(PuzzleDatasetConfig( seed=config.seed, dataset_path=config.data_path, rank=rank, num_replicas=world_size, **kwargs ), split=split) dataloader = DataLoader( dataset, batch_size=None, num_workers=1, prefetch_factor=8, pin_memory=True, persistent_workers=True ) return dataloader, dataset.metadata def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int): model_cfg = dict( **config.arch.__pydantic_extra__, # type: ignore batch_size=config.global_batch_size // world_size, vocab_size=train_metadata.vocab_size, seq_len=train_metadata.seq_len, num_puzzle_identifiers=train_metadata.num_puzzle_identifiers, causal=False # Non-autoregressive ) # Instantiate model with loss head model_cls = load_model_class(config.arch.name) loss_head_cls = load_model_class(config.arch.loss.name) with torch.device("cuda"): model: nn.Module = model_cls(model_cfg) model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore if "DISABLE_COMPILE" not in os.environ: model = torch.compile(model, dynamic=False) # type: ignore # Broadcast parameters from rank 0 if world_size > 1: with torch.no_grad(): for param in list(model.parameters()) + list(model.buffers()): dist.broadcast(param, src=0) # Optimizers and lr optimizers = [ CastedSparseEmbeddingSignSGD_Distributed( model.model.puzzle_emb.buffers(), # type: ignore lr=0, # Needs to be set by scheduler weight_decay=config.puzzle_emb_weight_decay, world_size=world_size ), AdamATan2( model.parameters(), lr=0, # Needs to be set by scheduler weight_decay=config.weight_decay, betas=(config.beta1, config.beta2) ) ] optimizer_lrs = [ config.puzzle_emb_lr, config.lr ] return model, optimizers, optimizer_lrs def cosine_schedule_with_warmup_lr_lambda( current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5 ): if current_step < num_warmup_steps: return base_lr * float(current_step) / float(max(1, 
num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))) def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int): # Estimated total training steps total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size) # Model model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size) return TrainState( step=0, total_steps=total_steps, model=model, optimizers=optimizers, optimizer_lrs=optimizer_lrs, carry=None ) def save_train_state(config: PretrainConfig, train_state: TrainState): # FIXME: Only saved model. if config.checkpoint_path is None: return os.makedirs(config.checkpoint_path, exist_ok=True) torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}")) def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState): return cosine_schedule_with_warmup_lr_lambda( current_step=train_state.step, base_lr=base_lr, num_warmup_steps=round(config.lr_warmup_steps), num_training_steps=train_state.total_steps, min_ratio=config.lr_min_ratio ) def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int): train_state.step += 1 if train_state.step > train_state.total_steps: # At most train_total_steps return # To device batch = {k: v.cuda() for k, v in batch.items()} # Init carry if it is None if train_state.carry is None: with torch.device("cuda"): train_state.carry = train_state.model.initial_carry(batch) # type: ignore # Forward train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[]) ((1 / global_batch_size) * loss).backward() # Allreduce if world_size > 1: 
for param in train_state.model.parameters(): if param.grad is not None: dist.all_reduce(param.grad) # Apply optimizer lr_this_step = None for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs): lr_this_step = compute_lr(base_lr, config, train_state) for param_group in optim.param_groups: param_group['lr'] = lr_this_step optim.step() optim.zero_grad() # Reduce metrics if len(metrics): assert not any(v.requires_grad for v in metrics.values()) metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. # Reduce and reconstruct metric_values = torch.stack([metrics[k] for k in metric_keys]) if world_size > 1: dist.reduce(metric_values, dst=0) if rank == 0: metric_values = metric_values.cpu().numpy() reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)} # Postprocess count = max(reduced_metrics["count"], 1) # Avoid NaNs reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()} reduced_metrics["train/lr"] = lr_this_step return reduced_metrics def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int): with torch.inference_mode(): set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)} all_preds = {} metric_keys = [] metric_values = None metric_global_batch_size = [0 for _ in range(len(set_ids))] carry = None for set_name, batch, global_batch_size in eval_loader: # To device batch = {k: v.cuda() for k, v in batch.items()} with torch.device("cuda"): carry = train_state.model.initial_carry(batch) # type: ignore # Forward while True: carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs) if all_finish: break for collection in (batch, preds): for k, v in collection.items(): if k in config.eval_save_outputs: all_preds.setdefault(k, []) 
all_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory del carry, preds, batch, all_finish # Aggregate set_id = set_ids[set_name] if metric_values is None: metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda") metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys]) metric_global_batch_size[set_id] += global_batch_size if len(all_preds) and config.checkpoint_path is not None: all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()} os.makedirs(config.checkpoint_path, exist_ok=True) torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}")) # Logging # Reduce to rank 0 if metric_values is not None: if world_size > 1: dist.reduce(metric_values, dst=0) if rank == 0: reduced_metrics = metric_values.cpu().numpy() reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)} for set_id, set_name in enumerate(set_ids)} # Postprocess for set_name, metrics in reduced_metrics.items(): count = metrics.pop("count") reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()} return reduced_metrics def save_code_and_config(config: PretrainConfig): if config.checkpoint_path is None or wandb.run is None: return os.makedirs(config.checkpoint_path, exist_ok=True) # Copy code code_list = [ get_model_source_path(config.arch.name), get_model_source_path(config.arch.loss.name) ] for code_file in code_list: if code_file is not None: code_name = os.path.basename(code_file) shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name)) # Dump config as yaml config_file = os.path.join(config.checkpoint_path, "all_config.yaml") with open(config_file, "wt") as f: yaml.dump(config.model_dump(), f) # Log code wandb.run.log_code(config.checkpoint_path) def 
load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig: objects = [None] if rank == 0: config = PretrainConfig(**hydra_config) # type: ignore # Naming if config.project_name is None: config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch" if config.run_name is None: config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}" if config.checkpoint_path is None: config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name) objects = [config] if world_size > 1: dist.broadcast_object_list(objects, src=0) return objects[0] # type: ignore @hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None) def launch(hydra_config: DictConfig): RANK = 0 WORLD_SIZE = 1 # Initialize distributed training if in distributed environment (e.g. torchrun) if "LOCAL_RANK" in os.environ: # Initialize distributed, default device and dtype dist.init_process_group(backend="nccl") RANK = dist.get_rank() WORLD_SIZE = dist.get_world_size() torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) # Load sync'ed config config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE) # Seed RNGs to ensure consistency torch.random.manual_seed(config.seed + RANK) # Dataset train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs total_iters = config.epochs // train_epochs_per_iter assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs." 
train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) # Train state train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE) # Progress bar and logger progress_bar = None if RANK == 0: progress_bar = tqdm.tqdm(total=train_state.total_steps) wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0) save_code_and_config(config) # Training Loop for _iter_id in range(total_iters): print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}") ############ Train Iter train_state.model.train() for set_name, batch, global_batch_size in train_loader: metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE) if RANK == 0 and metrics is not None: wandb.log(metrics, step=train_state.step) progress_bar.update(train_state.step - progress_bar.n) # type: ignore ############ Evaluation train_state.model.eval() metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE) if RANK == 0 and metrics is not None: wandb.log(metrics, step=train_state.step) ############ Checkpointing if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)): save_train_state(config, train_state) # finalize if dist.is_initialized(): dist.destroy_process_group() wandb.finish() if __name__ == "__main__": launch() ================================================ FILE: puzzle_dataset.py ================================================ import os import 
json import numpy as np import pydantic import torch from torch.utils.data import IterableDataset, get_worker_info from models.losses import IGNORE_LABEL_ID from dataset.common import PuzzleDatasetMetadata def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int): # Pack examples into a full batch batch = [] batch_puzzle_indices = [] current_size = 0 while (start_index < group_order.size) and (current_size < global_batch_size): # Pick a group and a puzzle from that group group_id = group_order[start_index] puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1]) start_index += 1 # Get range of the puzzle puzzle_start = puzzle_indices[puzzle_id] puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start) append_size = min(puzzle_size, global_batch_size - current_size) # Put into batch batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32)) batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False)) current_size += append_size return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices) class PuzzleDatasetConfig(pydantic.BaseModel): seed: int dataset_path: str global_batch_size: int test_set_mode: bool epochs_per_iter: int # Batch X epochs in an iteration to reduce overhead. rank: int num_replicas: int class PuzzleDataset(IterableDataset): def __init__(self, config: PuzzleDatasetConfig, split: str = "train"): super().__init__() self.config = config self.split = split self.metadata = self._load_metadata() # Checks assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}." 
self.local_batch_size = self.config.global_batch_size // self.config.num_replicas # State self._data = None self._iters = 0 def _load_metadata(self) -> PuzzleDatasetMetadata: with open(os.path.join(self.config.dataset_path, self.split, "dataset.json"), "r") as f: return PuzzleDatasetMetadata(**json.load(f)) def _lazy_load_dataset(self): if self._data is not None: return field_mmap_modes = { "inputs": "r", "labels": "r", # Keep indices in memory "puzzle_identifiers": None, "puzzle_indices": None, "group_indices": None } # Load data self._data = {} for set_name in self.metadata.sets: # Load subset self._data[set_name] = { field_name: np.load(os.path.join(self.config.dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode) for field_name, mmap_mode in field_mmap_modes.items() } def _collate_batch(self, batch): # Convert dtype batch = {k: v.astype(np.int32) for k, v in batch.items()} # Convert ignore label IDs if self.metadata.ignore_label_id is not None: batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID # Pad if batch["puzzle_identifiers"].size < self.local_batch_size: pad_size = self.local_batch_size - batch["puzzle_identifiers"].size pad_values = { "inputs": self.metadata.pad_id, "labels": IGNORE_LABEL_ID, "puzzle_identifiers": self.metadata.blank_identifier_id } batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()} # To tensor return {k: torch.from_numpy(v) for k, v in batch.items()} def _iter_test(self): for set_name, dataset in self._data.items(): # type: ignore total_examples = len(dataset["inputs"]) # Load examples one by one start_index = 0 while start_index < total_examples: # Compute indices end_index = min(total_examples, start_index + self.config.global_batch_size) local_start = start_index + self.config.rank * self.local_batch_size local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index) # Get 
batch of examples, and also puzzle IDs puzzle_indices = [] puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1 for i in range(local_start, local_end): while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]: puzzle_index += 1 puzzle_indices.append(puzzle_index) batch = self._collate_batch({ "inputs": dataset["inputs"][local_start: local_end], "labels": dataset["labels"][local_start: local_end], "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices] }) yield set_name, batch, end_index - start_index # Advance to next batch start_index += self.config.global_batch_size def _iter_train(self): for set_name, dataset in self._data.items(): # type: ignore # Increase epoch count self._iters += 1 # Randomly shuffle groups rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters)) group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)]) start_index = 0 while start_index < group_order.size: start_index, batch_indices, batch_puzzle_indices = _sample_batch( rng, group_order=group_order, puzzle_indices=dataset["puzzle_indices"], group_indices=dataset["group_indices"], start_index=start_index, global_batch_size=self.config.global_batch_size, ) # Select current rank and collate global_effective_batch_size = batch_puzzle_indices.size # Global effective batch size, excluding pads # Drop last batch if global_effective_batch_size < self.config.global_batch_size: break batch_indices = batch_indices [self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size] batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size] batch = self._collate_batch({ "inputs": dataset["inputs"][batch_indices], "labels": dataset["labels"][batch_indices], "puzzle_identifiers": 
dataset["puzzle_identifiers"][batch_puzzle_indices] }) yield set_name, batch, global_effective_batch_size def __iter__(self): worker_info = get_worker_info() assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported." self._lazy_load_dataset() # Iterate using specified mode if self.config.test_set_mode: yield from self._iter_test() else: yield from self._iter_train() ================================================ FILE: puzzle_visualizer.html ================================================ ARC‐Converted Dataset Visualizer (Upload Local Folder)

ARC‐Converted Dataset Visualizer (Local Directory)



================================================ FILE: requirements.txt ================================================ torch adam-atan2 einops tqdm coolname pydantic argdantic wandb omegaconf hydra-core huggingface_hub ================================================ FILE: utils/functions.py ================================================ import importlib import inspect def load_model_class(identifier: str, prefix: str = "models."): module_path, class_name = identifier.split('@') # Import the module module = importlib.import_module(prefix + module_path) cls = getattr(module, class_name) return cls def get_model_source_path(identifier: str, prefix: str = "models."): module_path, class_name = identifier.split('@') module = importlib.import_module(prefix + module_path) return inspect.getsourcefile(module)