Full Code of sapientinc/HRM for AI

main 42410daaaf6a cached
25 files
124.9 KB
31.4k tokens
109 symbols
1 requests
Download .txt
Repository: sapientinc/HRM
Branch: main
Commit: 42410daaaf6a
Files: 25
Total size: 124.9 KB

Directory structure:
gitextract_7bpy0bpn/

├── .gitignore
├── .gitmodules
├── .vscode/
│   ├── launch.json
│   └── settings.json
├── LICENSE
├── README.md
├── arc_eval.ipynb
├── assets/
│   └── npyjs.js
├── config/
│   ├── arch/
│   │   └── hrm_v1.yaml
│   └── cfg_pretrain.yaml
├── dataset/
│   ├── build_arc_dataset.py
│   ├── build_maze_dataset.py
│   ├── build_sudoku_dataset.py
│   └── common.py
├── evaluate.py
├── models/
│   ├── common.py
│   ├── hrm/
│   │   └── hrm_act_v1.py
│   ├── layers.py
│   ├── losses.py
│   └── sparse_embedding.py
├── pretrain.py
├── puzzle_dataset.py
├── puzzle_visualizer.html
├── requirements.txt
└── utils/
    └── functions.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# WandB
/wandb/
# checkpoints
/checkpoints/
# cache
/cache/
# data
/data/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

================================================
FILE: .gitmodules
================================================
[submodule "dataset/raw-data/ConceptARC"]
	path = dataset/raw-data/ConceptARC
	url = git@github.com:victorvikram/ConceptARC.git
[submodule "dataset/raw-data/ARC-AGI"]
	path = dataset/raw-data/ARC-AGI
	url = git@github.com:fchollet/ARC-AGI.git
[submodule "dataset/raw-data/ARC-AGI-2"]
	path = dataset/raw-data/ARC-AGI-2
	url = git@github.com:arcprize/ARC-AGI-2.git


================================================
FILE: .vscode/launch.json
================================================
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python Debugger: Current File",
            "type": "debugpy",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal"
        },
        {
            "name": "Debug: Single GPU",
            "type": "debugpy",
            "request": "launch",
            "program": "pretrain.py",
            "args": [],
            "env": {
                "OMP_NUM_THREADS": "1",
                "DISABLE_COMPILE": "true"
            }
        }
    ]
}

================================================
FILE: .vscode/settings.json
================================================
{
    "python.analysis.typeCheckingMode": "standard"
}

================================================
FILE: LICENSE
================================================

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: README.md
================================================
# Hierarchical Reasoning Model

![](./assets/hrm.png)

Reasoning, the process of devising and executing complex goal-oriented action sequences, remains a critical challenge in AI.
Current large language models (LLMs) primarily employ Chain-of-Thought (CoT) techniques, which suffer from brittle task decomposition, extensive data requirements, and high latency. Inspired by the hierarchical and multi-timescale processing in the human brain, we propose the Hierarchical Reasoning Model (HRM), a novel recurrent architecture that attains significant computational depth while maintaining both training stability and efficiency.
HRM executes sequential reasoning tasks in a single forward pass without explicit supervision of the intermediate process, through two interdependent recurrent modules: a high-level module responsible for slow, abstract planning, and a low-level module handling rapid, detailed computations. With only 27 million parameters, HRM achieves exceptional performance on complex reasoning tasks using only 1000 training samples. The model operates without pre-training or CoT data, yet achieves nearly perfect performance on challenging tasks including complex Sudoku puzzles and optimal path finding in large mazes.
Furthermore, HRM outperforms much larger models with significantly longer context windows on the Abstraction and Reasoning Corpus (ARC), a key benchmark for measuring artificial general intelligence capabilities.
These results underscore HRM’s potential as a transformative advancement toward universal computation and general-purpose reasoning systems.

**Join our Discord Community: [https://discord.gg/sapient](https://discord.gg/sapient)**


## Quick Start Guide 🚀

### Prerequisites ⚙️

Ensure PyTorch and CUDA are installed. The repo needs CUDA extensions to be built. If not present, run the following commands:

```bash
# Install CUDA 12.6
CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run

wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL
sudo sh cuda_installer.run --silent --toolkit --override

export CUDA_HOME=/usr/local/cuda-12.6

# Install PyTorch with CUDA 12.6
PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu126

pip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL

# Additional packages for building extensions
pip3 install packaging ninja wheel setuptools setuptools-scm
```

Then install FlashAttention. For Hopper GPUs, install FlashAttention 3

```bash
git clone git@github.com:Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

For Ampere or earlier GPUs, install FlashAttention 2

```bash
pip3 install flash-attn
```

## Install Python Dependencies 🐍

```bash
pip install -r requirements.txt
```

## W&B Integration 📈

This project uses [Weights & Biases](https://wandb.ai/) for experiment tracking and metric visualization. Ensure you're logged in:

```bash
wandb login
```

## Run Experiments

### Quick Demo: Sudoku Solver 💻🗲

Train a master-level Sudoku AI capable of solving extremely difficult puzzles on a modern laptop GPU. 🧩

```bash
# Download and build Sudoku dataset
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000  --subsample-size 1000 --num-aug 1000

# Start training (single GPU, smaller batch size)
OMP_NUM_THREADS=8 python pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 global_batch_size=384 lr=7e-5 puzzle_emb_lr=7e-5 weight_decay=1.0 puzzle_emb_weight_decay=1.0
```

Runtime: ~10 hours on a RTX 4070 laptop GPU

## Trained Checkpoints 🚧

 - [ARC-AGI-2](https://huggingface.co/sapientinc/HRM-checkpoint-ARC-2)
 - [Sudoku 9x9 Extreme (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-sudoku-extreme)
 - [Maze 30x30 Hard (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-maze-30x30-hard)

To use the checkpoints, see Evaluation section below.

## Full-scale Experiments 🔵

Experiments below assume an 8-GPU setup.

### Dataset Preparation

```bash
# Initialize submodules
git submodule update --init --recursive

# ARC-1
python dataset/build_arc_dataset.py  # ARC offical + ConceptARC, 960 examples
# ARC-2
python dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000  # ARC-2 official, 1120 examples

# Sudoku-Extreme
python dataset/build_sudoku_dataset.py  # Full version
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000  --subsample-size 1000 --num-aug 1000  # 1000 examples

# Maze
python dataset/build_maze_dataset.py  # 1000 examples
```

### Dataset Visualization

Explore the puzzles visually:

* Open `puzzle_visualizer.html` in your browser.
* Upload the generated dataset folder located in `data/...`.

## Launch experiments

### Small-sample (1K)

ARC-1:

```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py 
```

*Runtime:* ~24 hours

ARC-2:

```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/arc-2-aug-1000
```

*Runtime:* ~24 hours (checkpoint after 8 hours is often sufficient)

Sudoku Extreme (1k):

```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
```

*Runtime:* ~10 minutes

Maze 30x30 Hard (1k):

```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/maze-30x30-hard-1k epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
```

*Runtime:* ~1 hour

### Full Sudoku-Hard

```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-hard-full epochs=100 eval_interval=10 lr_min_ratio=0.1 global_batch_size=2304 lr=3e-4 puzzle_emb_lr=3e-4 weight_decay=0.1 puzzle_emb_weight_decay=0.1 arch.loss.loss_type=softmax_cross_entropy arch.L_cycles=8 arch.halt_max_steps=8 arch.pos_encodings=learned
```

*Runtime:* ~2 hours

## Evaluation

Evaluate your trained models:

* Check `eval/exact_accuracy` in W&B.
* For ARC-AGI, follow these additional steps:

```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 evaluate.py checkpoint=<CHECKPOINT_PATH>
```

* Then use the provided `arc_eval.ipynb` notebook to finalize and inspect your results.

## Notes

 - Small-sample learning typically exhibits accuracy variance of around ±2 points.
 - For Sudoku-Extreme (1,000-example dataset), late-stage overfitting may cause numerical instability during training and Q-learning. It is advisable to use early stopping once the training accuracy approaches 100%.

## Citation 📜

```bibtex
@misc{wang2025hierarchicalreasoningmodel,
      title={Hierarchical Reasoning Model}, 
      author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
      year={2025},
      eprint={2506.21734},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2506.21734}, 
}
```


================================================
FILE: arc_eval.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from glob import glob\n",
    "import hashlib\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from numba import njit\n",
    "\n",
    "from dataset.common import inverse_dihedral_transform\n",
    "\n",
    "\n",
    "DATASET_PATH = \"data/arc-aug-1000\"  # ARC-1\n",
    "# DATASET_PATH = \"data/arc-2-aug-1000\"  # ARC-2\n",
    "\n",
    "CHECKPOINT_PATH = \"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\"\n",
    "\n",
    "\n",
    "PAD_PUZZLE_IDENTIFIER = 0\n",
    "\n",
    "# Visualization\n",
    "ARC_COLOR_MAP = mcolors.ListedColormap([\n",
    "    \"#000000\",  # symbol_0: black\n",
    "    \"#0074D9\",  # symbol_1: blue\n",
    "    \"#FF4136\",  # symbol_2: red\n",
    "    \"#2ECC40\",  # symbol_3: green\n",
    "    \"#FFDC00\",  # symbol_4: yellow\n",
    "    \"#AAAAAA\",  # symbol_5: grey\n",
    "    \"#F012BE\",  # symbol_6: fuschia\n",
    "    \"#FF851B\",  # symbol_7: orange\n",
    "    \"#7FDBFF\",  # symbol_8: teal\n",
    "    \"#870C25\"   # symbol_9: brown\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\n",
    "    # Load puzzle identifiers\n",
    "    with open(os.path.join(dataset_path, \"identifiers.json\"), \"r\") as f:\n",
    "        identifier_map = json.load(f)\n",
    "        \n",
    "    # Load preds\n",
    "    all_preds = {}\n",
    "    for filename in glob(f\"{checkpoint_path}_all_preds.*\"):\n",
    "        preds = torch.load(filename)\n",
    "        for k, v in preds.items():\n",
    "            all_preds.setdefault(k, [])\n",
    "            all_preds[k].append(v)\n",
    "            \n",
    "        del preds\n",
    "\n",
    "    all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n",
    "    \n",
    "    # Remove paddings\n",
    "    mask = all_preds[\"puzzle_identifiers\"] != PAD_PUZZLE_IDENTIFIER\n",
    "    all_preds = {k: v[mask] for k, v in all_preds.items()}\n",
    "\n",
    "    return identifier_map, all_preds\n",
    "\n",
    "\n",
    "def inverse_aug(name: str, grid: np.ndarray):\n",
    "    if \"_\" not in name:\n",
    "        return grid\n",
    "\n",
    "    trans_id, perm = name.split(\"_\")[-2:]\n",
    "    trans_id = int(trans_id[1:])  # Remove \"t\" letter\n",
    "    inv_perm = np.argsort(list(perm))\n",
    "    \n",
    "    return inv_perm[inverse_dihedral_transform(grid, trans_id)]\n",
    "\n",
    "\n",
    "def grid_hash(grid: np.ndarray):\n",
    "    return hash((grid.tobytes(), grid.shape))\n",
    "\n",
    "\n",
    "@njit\n",
    "def crop(grid: np.ndarray):\n",
    "    # Find maximum-sized rectangle without any EOS token inside.\n",
    "    grid = grid.reshape(30, 30)\n",
    "\n",
    "    max_area = 0\n",
    "    max_size = (0, 0)\n",
    "    nr, nc = grid.shape\n",
    "    \n",
    "    num_c = nc\n",
    "    for num_r in range(1, nr + 1):\n",
    "        # Scan for maximum c\n",
    "        for c in range(1, num_c + 1):\n",
    "            x = grid[num_r - 1, c - 1]\n",
    "            if (x < 2) | (x > 11):\n",
    "                num_c = c - 1\n",
    "                break\n",
    "        \n",
    "        area = num_r * num_c\n",
    "        if area > max_area:\n",
    "            max_area = area\n",
    "            max_size = (num_r, num_c)\n",
    "\n",
    "    return grid[:max_size[0], :max_size[1]] - 2\n",
    "\n",
    "\n",
    "def test(visualize, Ks=[1, 2, 10, 100, 1000]):\n",
    "    identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\n",
    "    \n",
    "    global_hmap = {}\n",
    "    \n",
    "    # Get puzzles and corresponding answers\n",
    "    puzzle_labels = {}\n",
    "    for identifier, input, label in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], all_preds[\"labels\"]):\n",
    "        name = identifier_map[identifier]\n",
    "        if \"_\" not in name:   # Not-augmented\n",
    "            puzzle_labels.setdefault(name, {})\n",
    "            \n",
    "            input = crop(input.numpy())\n",
    "            label = crop(label.numpy())\n",
    "\n",
    "            input_hash = grid_hash(input)\n",
    "            label_hash = grid_hash(label)\n",
    "\n",
    "            global_hmap[input_hash] = input\n",
    "            global_hmap[label_hash] = label\n",
    "\n",
    "            assert input_hash not in puzzle_labels[name]\n",
    "            puzzle_labels[name][input_hash] = label_hash\n",
    "            \n",
    "    print (\"Number of puzzles\", len(puzzle_labels))\n",
    "    \n",
    "    # Argmax prediction\n",
    "    preds = all_preds[\"logits\"].argmax(-1)\n",
    "\n",
    "    # Collate\n",
    "    pred_answers = {}\n",
    "    for identifier, input, pred, q in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], preds, all_preds[\"q_halt_logits\"].sigmoid()):\n",
    "        name = identifier_map[identifier]\n",
    "        orig_name = name.split(\"_\")[0]\n",
    "        \n",
    "        input = input.numpy()\n",
    "        input_hash = grid_hash(inverse_aug(name, crop(input)))\n",
    "        assert input_hash in puzzle_labels[orig_name]\n",
    "        \n",
    "        pred = inverse_aug(name, crop(pred.numpy()))\n",
    "        pred_hash = grid_hash(pred)\n",
    "        global_hmap[pred_hash] = pred\n",
    "        \n",
    "        pred_answers.setdefault(orig_name, {})\n",
    "        pred_answers[orig_name].setdefault(input_hash, [])\n",
    "        pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\n",
    "\n",
    "    # test-1\n",
    "    if visualize:\n",
    "        num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\n",
    "        fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\n",
    "        \n",
    "        fig_id = 0\n",
    "    \n",
    "    correct = [0 for _ in range(len(Ks))]\n",
    "    for name, tests in puzzle_labels.items():\n",
    "        num_test_correct = [0 for _ in range(len(Ks))]\n",
    "        for input_hash, label_hash in tests.items():\n",
    "            p = pred_answers[name][input_hash]\n",
    "            p_map = {}\n",
    "            \n",
    "            for h, q in p:\n",
    "                p_map.setdefault(h, [0, 0])\n",
    "                p_map[h][0] += 1\n",
    "                p_map[h][1] += q\n",
    "                \n",
    "            for h, stats in p_map.items():\n",
    "                stats[1] /= stats[0]\n",
    "                \n",
    "            p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\n",
    "\n",
    "            # 2-vote\n",
    "            for i, k in enumerate(Ks):\n",
    "                ok = False\n",
    "                for h, stats in p_map[:k]:\n",
    "                    ok |= h == label_hash\n",
    "                    \n",
    "                num_test_correct[i] += ok\n",
    "\n",
    "            if visualize:\n",
    "                # Show input and ground truth\n",
    "                axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\n",
    "                axes[fig_id, 0].set_title(f\"{name}\\nInput\")\n",
    "                axes[fig_id, 0].axis('off')\n",
    "                \n",
    "                axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\n",
    "                axes[fig_id, 1].set_title(f\"{name}\\nAnswer\")\n",
    "                axes[fig_id, 1].axis('off')\n",
    "                \n",
    "                trial_id = 2\n",
    "                for h, stats in p_map[:2]:\n",
    "                    ans = global_hmap[h]\n",
    "                    \n",
    "                    axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\n",
    "                    axes[fig_id, trial_id].set_title(f\"{name}\\nTrial {trial_id}\")\n",
    "                    axes[fig_id, trial_id].axis('off')\n",
    "                    \n",
    "                    trial_id += 1\n",
    "                \n",
    "                fig_id += 1\n",
    "            \n",
    "        # Total correctness\n",
    "        for i in range(len(Ks)):\n",
    "            correct[i] += num_test_correct[i] == len(tests)\n",
    "\n",
    "    for i, k in enumerate(Ks):\n",
    "        print (f\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\")\n",
    "\n",
    "\n",
    "test(visualize=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


================================================
FILE: assets/npyjs.js
================================================
class npyjs {

    constructor(opts) {
        if (opts && !('convertFloat16' in opts)) {
            console.warn([
                "npyjs constructor now accepts {convertFloat16?: boolean}.",
                "For usage, go to https://github.com/jhuapl-boss/npyjs."
            ].join(" "));
        }

        this.convertFloat16 = opts?.convertFloat16 ?? true;

        this.dtypes = {
            "<u1": {
                name: "uint8",
                size: 8,
                arrayConstructor: Uint8Array,
            },
            "|u1": {
                name: "uint8",
                size: 8,
                arrayConstructor: Uint8Array,
            },
            "<u2": {
                name: "uint16",
                size: 16,
                arrayConstructor: Uint16Array,
            },
            "|i1": {
                name: "int8",
                size: 8,
                arrayConstructor: Int8Array,
            },
            "<i2": {
                name: "int16",
                size: 16,
                arrayConstructor: Int16Array,
            },
            "<u4": {
                name: "uint32",
                size: 32,
                arrayConstructor: Uint32Array,
            },
            "<i4": {
                name: "int32",
                size: 32,
                arrayConstructor: Int32Array,
            },
            "<u8": {
                name: "uint64",
                size: 64,
                arrayConstructor: BigUint64Array,
            },
            "<i8": {
                name: "int64",
                size: 64,
                arrayConstructor: BigInt64Array,
            },
            "<f4": {
                name: "float32",
                size: 32,
                arrayConstructor: Float32Array
            },
            "<f8": {
                name: "float64",
                size: 64,
                arrayConstructor: Float64Array
            },
            "<f2": {
                name: "float16",
                size: 16,
                arrayConstructor: Uint16Array,
                converter: this.convertFloat16 ? this.float16ToFloat32Array : undefined
            },
        };
    }

    float16ToFloat32Array(float16Array) {
        const length = float16Array.length;
        const float32Array = new Float32Array(length);
        
        for (let i = 0; i < length; i++) {
            float32Array[i] = npyjs.float16ToFloat32(float16Array[i]);
        }
        
        return float32Array;
    }

    static float16ToFloat32(float16) {
        // Extract the parts of the float16
        const sign = (float16 >> 15) & 0x1;
        const exponent = (float16 >> 10) & 0x1f;
        const fraction = float16 & 0x3ff;

        // Handle special cases
        if (exponent === 0) {
            if (fraction === 0) {
                // Zero
                return sign ? -0 : 0;
            }
            // Denormalized number
            return (sign ? -1 : 1) * Math.pow(2, -14) * (fraction / 0x400);
        } else if (exponent === 0x1f) {
            if (fraction === 0) {
                // Infinity
                return sign ? -Infinity : Infinity;
            }
            // NaN
            return NaN;
        }

        // Normalized number
        return (sign ? -1 : 1) * Math.pow(2, exponent - 15) * (1 + fraction / 0x400);
    }

    parse(arrayBufferContents) {
        // const version = arrayBufferContents.slice(6, 8); // Uint8-encoded
        const headerLength = new DataView(arrayBufferContents.slice(8, 10)).getUint8(0);
        const offsetBytes = 10 + headerLength;

        const hcontents = new TextDecoder("utf-8").decode(
            new Uint8Array(arrayBufferContents.slice(10, 10 + headerLength))
        );
        const header = JSON.parse(
            hcontents
                .toLowerCase() // True -> true
                .replace(/'/g, '"')
                .replace("(", "[")
                .replace(/,*\),*/g, "]")
        );
        const shape = header.shape;
        const dtype = this.dtypes[header.descr];

        if (!dtype) {
            console.error(`Unsupported dtype: ${header.descr}`);
            return null;
        }

        const nums = new dtype.arrayConstructor(
            arrayBufferContents,
            offsetBytes
        );

        // Convert float16 to float32 if converter exists
        const data = dtype.converter ? dtype.converter.call(this, nums) : nums;

        return {
            dtype: dtype.name,
            data: data,
            shape,
            fortranOrder: header.fortran_order
        };
    }

    async load(filename, callback, fetchArgs) {
        /*
        Loads an array from a stream of bytes.
        */
        fetchArgs = fetchArgs || {};
        let arrayBuf;
        // If filename is ArrayBuffer
        if (filename instanceof ArrayBuffer) {
            arrayBuf = filename;
        }
        // If filename is a file path
        else {
            const resp = await fetch(filename, { ...fetchArgs });
            arrayBuf = await resp.arrayBuffer();
        }
        const result = this.parse(arrayBuf);
        if (callback) {
            return callback(result);
        }
        return result;
    }
}


================================================
FILE: config/arch/hrm_v1.yaml
================================================
name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 2
L_cycles: 2

H_layers: 4
L_layers: 4

hidden_size: 512
num_heads: 8  # min(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope


================================================
FILE: config/cfg_pretrain.yaml
================================================
# ARC training config

defaults:
  - arch: hrm_v1
  - _self_

hydra:
  output_subdir: null

# Data path
data_path: data/arc-aug-1000

# Hyperparams - Training
global_batch_size: 768

epochs: 100000
eval_interval: 10000
checkpoint_every_eval: True

lr: 1e-4
lr_min_ratio: 1.0
lr_warmup_steps: 2000

# Standard hyperparameter settings for LM, as used in Llama
beta1: 0.9
beta2: 0.95
weight_decay: 0.1
puzzle_emb_weight_decay: 0.1

# Hyperparams - Puzzle embeddings training
puzzle_emb_lr: 1e-2


================================================
FILE: dataset/build_arc_dataset.py
================================================
from typing import List, Optional, Tuple, Dict
from dataclasses import dataclass
from pathlib import Path
import os
import json
import hashlib
import numpy as np
from glob import glob

from argdantic import ArgParser
from pydantic import BaseModel

from common import PuzzleDatasetMetadata, dihedral_transform


cli = ArgParser()


class DataProcessConfig(BaseModel):
    # ARC-1
    dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"]
    output_dir: str = "data/arc-aug-1000"
    
    # ARC-2
    # dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI-2/data"]
    # output_dir: str = "data/arc-2-aug-1000"

    seed: int = 42
    num_aug: int = 1000
    
    
ARCMaxGridSize = 30
ARCAugmentRetriesFactor = 5
    

@dataclass
class ARCPuzzle:
    id: str

    examples: List[Tuple[np.ndarray, np.ndarray]]

    
def arc_grid_to_np(grid: List[List[int]]):
    arr = np.array(grid)

    # Shape check
    assert arr.ndim == 2
    assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
    # Element check
    assert np.all((arr >= 0) & (arr <= 9))
    return arr.astype(np.uint8)


def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
    # PAD: 0, <eos>: 1, digits: 2 ... 11
    # Compute random top-left pad
    if do_translation:
        pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
        pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
    else:
        pad_r = pad_c = 0

    # Pad grid
    result = []
    for grid in [inp, out]:
        nrow, ncol = grid.shape
        grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)

        # Add <eos>
        eos_row, eos_col = pad_r + nrow, pad_c + ncol
        if eos_row < ARCMaxGridSize:
            grid[eos_row, pad_c:eos_col] = 1
        if eos_col < ARCMaxGridSize:
            grid[pad_r:eos_row, eos_col] = 1

        result.append(grid.flatten())

    return result


def puzzle_hash(puzzle: dict):
    # Hash the puzzle for checking equivalence
    def _grid_hash(grid: np.ndarray):
        buffer = [x.to_bytes(1) for x in grid.shape]
        buffer.append(grid.tobytes())
        
        return hashlib.sha256(b"".join(buffer)).hexdigest()
    
    hashes = []
    for example_type, example in puzzle.items():
        for input, label in example.examples:
            hashes.append(f"{_grid_hash(input)}|{_grid_hash(label)}")
            
    hashes.sort()
    return hashlib.sha256("|".join(hashes).encode()).hexdigest()


def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
    # Remove "name"
    name = puzzle.pop("name", default_name)
    
    # Convert
    dests = set(dest_mapping.values())
    converted = {dest: ARCPuzzle(name, []) for dest in dests}
    for example_type, examples in puzzle.items():
        dest = dest_mapping[example_type]
        converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])

    group = [converted]
    
    # Augment
    if aug_count > 0:
        hashes = {puzzle_hash(converted)}

        for _trial in range(ARCAugmentRetriesFactor * aug_count):
            # Augment plan
            trans_id = np.random.randint(0, 8)
            mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))])  # Permute colors, Excluding "0" (black)
            
            aug_repr = f"t{trans_id}_{''.join(str(x) for x in mapping)}"

            def _map_grid(grid: np.ndarray):
                return dihedral_transform(mapping[grid], trans_id)
            
            # Check duplicate
            augmented = {dest: ARCPuzzle(f"{puzzle.id}_{aug_repr}", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
            h = puzzle_hash(augmented)
            if h not in hashes:
                hashes.add(h)
                group.append(augmented)
                
            if len(group) >= aug_count + 1:
                break
            
        if len(group) < aug_count + 1:
            print (f"[Puzzle {name}] augmentation not full, only {len(group)}")

    # Append
    for dest in dests:
        # Convert the examples
        dest_split, dest_set = dest

        results.setdefault(dest_split, {})
        results[dest_split].setdefault(dest_set, [])
        results[dest_split][dest_set].append([converted[dest] for converted in group])


def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig):
    train_examples_dest = ("train", "all")
    test_examples_map = {
        "evaluation": [(1.0, ("test", "all"))],
        "_default": [(1.0, ("train", "all"))]
    }
    
    total_puzzles = 0
    for subdir in os.scandir(dataset_path):
        if subdir.is_dir():
            # Load all puzzles in this directory
            puzzles = []
            for filename in glob(os.path.join(subdir.path, "*.json")):
                with open(filename, "r") as f:
                    puzzles.append((Path(filename).stem, json.load(f)))
                    
            # Shuffle puzzles
            np.random.shuffle(puzzles)
            
            # Assign by fraction
            for idx, (default_name, puzzle) in enumerate(puzzles):
                fraction = idx / len(puzzles)
                test_examples_dest = None
                for f, dest in test_examples_map.get(subdir.name, test_examples_map["_default"]):
                    if fraction < f:
                        test_examples_dest = dest
                        break
                        
                assert test_examples_dest is not None
                
                convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
                total_puzzles += 1

    print (f"[{dataset_path}] total puzzles: {total_puzzles}")


def convert_dataset(config: DataProcessConfig):
    np.random.seed(config.seed)
    
    # Read dataset
    data = {}
    for dataset_dir in config.dataset_dirs:
        load_puzzles_arcagi(data, dataset_dir, config)
    
    # Map global puzzle identifiers
    num_identifiers = 1  # 0 is blank
    identifier_map = {}
    for split_name, split in data.items():
        for subset_name, subset in split.items():
            for group in subset:
                for puzzle in group:
                    if puzzle.id not in identifier_map:
                        identifier_map[puzzle.id] = num_identifiers
                        num_identifiers += 1

    print (f"Total puzzle IDs (including <blank>): {num_identifiers}")

    # Save
    for split_name, split in data.items():
        os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)
        
        # Translational augmentations
        enable_translational_augment = split_name == "train"

        # Statistics
        total_examples = 0
        total_puzzles = 0
        total_groups = 0
        
        for subset_name, subset in split.items():
            # Construct subset
            results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
            results["puzzle_indices"].append(0)
            results["group_indices"].append(0)
            
            example_id = 0
            puzzle_id = 0
            
            for group in subset:
                for puzzle in group:
                    # Push puzzle
                    no_aug_id = np.random.randint(0, len(puzzle.examples))
                    for _idx_ex, (inp, out) in enumerate(puzzle.examples):
                        inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)
                            
                        results["inputs"].append(inp)
                        results["labels"].append(out)
                        example_id += 1
                        
                        total_examples += 1

                    results["puzzle_indices"].append(example_id)
                    results["puzzle_identifiers"].append(identifier_map[puzzle.id])
                    
                    puzzle_id += 1
                    
                    total_puzzles += 1
                    
                # Push group
                results["group_indices"].append(puzzle_id)
                total_groups += 1
            
            for k, v in results.items():
                if k in {"inputs", "labels"}:
                    v = np.stack(v, 0)
                else:
                    v = np.array(v, dtype=np.int32)
                
                np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)
        
        # Metadata
        metadata = PuzzleDatasetMetadata(
            seq_len=ARCMaxGridSize * ARCMaxGridSize,
            vocab_size=10 + 2,  # PAD + EOS + "0" ... "9"
            
            pad_id=0,
            ignore_label_id=0,
            
            blank_identifier_id=0,
            num_puzzle_identifiers=num_identifiers,
            
            total_groups=total_groups,
            mean_puzzle_examples=total_examples / total_puzzles,
            sets=list(split.keys())
        )

        # Save metadata as JSON.
        with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
            json.dump(metadata.model_dump(), f)
            
    # Save IDs mapping
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        ids_mapping = {v: k for k, v in identifier_map.items()}
        
        json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)


@cli.command(singleton=True)
def main(config: DataProcessConfig):
    convert_dataset(config)


if __name__ == "__main__":
    cli()


================================================
FILE: dataset/build_maze_dataset.py
================================================
from typing import Optional
import math
import os
import csv
import json
import numpy as np

from argdantic import ArgParser
from pydantic import BaseModel
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from common import PuzzleDatasetMetadata, dihedral_transform


CHARSET = "# SGo"


cli = ArgParser()


class DataProcessConfig(BaseModel):
    source_repo: str = "sapientinc/maze-30x30-hard-1k"
    output_dir: str = "data/maze-30x30-hard-1k"

    subsample_size: Optional[int] = None
    aug: bool = False


def convert_subset(set_name: str, config: DataProcessConfig):
    # Read CSV
    all_chars = set()
    grid_size = None
    inputs = []
    labels = []
    
    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:  # type: ignore
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for source, q, a, rating in reader:
            all_chars.update(q)
            all_chars.update(a)

            if grid_size is None:
                n = int(len(q) ** 0.5)
                grid_size = (n, n)
                
            inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
            labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset
    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0
    
    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)
    
    for inp, out in zip(tqdm(inputs), labels):
        # Dihedral transformations for augmentation
        for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
            results["inputs"].append(dihedral_transform(inp, aug_idx))
            results["labels"].append(dihedral_transform(out, aug_idx))
            example_id += 1
            puzzle_id += 1
            
            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)
            
        # Push group
        results["group_indices"].append(puzzle_id)
            
    # Char mappings
    assert len(all_chars - set(CHARSET)) == 0
    
    char2id = np.zeros(256, np.uint8)
    char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1

    # To Numpy
    def _seq_to_numpy(seq):
        arr = np.vstack([char2id[s.reshape(-1)] for s in seq])
        
        return arr
    
    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),
        
        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=int(math.prod(grid_size)),  # type: ignore
        vocab_size=len(CHARSET) + 1,  # PAD + Charset
        
        pad_id=0,
        ignore_label_id=0,
        
        blank_identifier_id=0,
        num_puzzle_identifiers=1,
        
        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)
    
    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)
        
    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
        
    # Save IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)


@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
    convert_subset("train", config)
    convert_subset("test", config)


if __name__ == "__main__":
    cli()


================================================
FILE: dataset/build_sudoku_dataset.py
================================================
from typing import Optional
import os
import csv
import json
import numpy as np

from argdantic import ArgParser
from pydantic import BaseModel
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from common import PuzzleDatasetMetadata


cli = ArgParser()


class DataProcessConfig(BaseModel):
    source_repo: str = "sapientinc/sudoku-extreme"
    output_dir: str = "data/sudoku-extreme-full"

    subsample_size: Optional[int] = None
    min_difficulty: Optional[int] = None
    num_aug: int = 0


def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
    # Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged
    digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))
    
    # Randomly decide whether to transpose.
    transpose_flag = np.random.rand() < 0.5

    # Generate a valid row permutation:
    # - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows.
    bands = np.random.permutation(3)
    row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])

    # Similarly for columns (stacks).
    stacks = np.random.permutation(3)
    col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])

    # Build an 81->81 mapping. For each new cell at (i, j)
    # (row index = i // 9, col index = i % 9),
    # its value comes from old row = row_perm[i//9] and old col = col_perm[i%9].
    mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])

    def apply_transformation(x: np.ndarray) -> np.ndarray:
        # Apply transpose flag
        if transpose_flag:
            x = x.T
        # Apply the position mapping.
        new_board = x.flatten()[mapping].reshape(9, 9).copy()
        # Apply digit mapping
        return digit_map[new_board]

    return apply_transformation(board), apply_transformation(solution)


def convert_subset(set_name: str, config: DataProcessConfig):
    # Read CSV
    inputs = []
    labels = []
    
    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for source, q, a, rating in reader:
            if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
                assert len(q) == 81 and len(a) == 81
                
                inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
                labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset
    num_augments = config.num_aug if set_name == "train" else 0

    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0
    
    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)
    
    for orig_inp, orig_out in zip(tqdm(inputs), labels):
        for aug_idx in range(1 + num_augments):
            # First index is not augmented
            if aug_idx == 0:
                inp, out = orig_inp, orig_out
            else:
                inp, out = shuffle_sudoku(orig_inp, orig_out)

            # Push puzzle (only single example)
            results["inputs"].append(inp)
            results["labels"].append(out)
            example_id += 1
            puzzle_id += 1
            
            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)
            
        # Push group
        results["group_indices"].append(puzzle_id)
        
    # To Numpy
    def _seq_to_numpy(seq):
        arr = np.concatenate(seq).reshape(len(seq), -1)
        
        assert np.all((arr >= 0) & (arr <= 9))
        return arr + 1
    
    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),
        
        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=81,
        vocab_size=10 + 1,  # PAD + "0" ... "9"
        
        pad_id=0,
        ignore_label_id=0,
        
        blank_identifier_id=0,
        num_puzzle_identifiers=1,
        
        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)
    
    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata.model_dump(), f)
        
    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
        
    # Save IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)


@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
    convert_subset("train", config)
    convert_subset("test", config)


if __name__ == "__main__":
    cli()


================================================
FILE: dataset/common.py
================================================
from typing import List, Optional

import pydantic
import numpy as np


# Global list mapping each dihedral transform id to its inverse.
# Index corresponds to the original tid, and the value is its inverse.
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]


class PuzzleDatasetMetadata(pydantic.BaseModel):
    pad_id: int
    ignore_label_id: Optional[int]
    blank_identifier_id: int
    
    vocab_size: int
    seq_len: int
    num_puzzle_identifiers: int
    
    total_groups: int
    mean_puzzle_examples: float

    sets: List[str]


def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    """8 dihedral symmetries by rotate, flip and mirror"""
    
    if tid == 0:
        return arr  # identity
    elif tid == 1:
        return np.rot90(arr, k=1)
    elif tid == 2:
        return np.rot90(arr, k=2)
    elif tid == 3:
        return np.rot90(arr, k=3)
    elif tid == 4:
        return np.fliplr(arr)       # horizontal flip
    elif tid == 5:
        return np.flipud(arr)       # vertical flip
    elif tid == 6:
        return arr.T                # transpose (reflection along main diagonal)
    elif tid == 7:
        return np.fliplr(np.rot90(arr, k=1))  # anti-diagonal reflection
    else:
        return arr
    
    
def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])


================================================
FILE: evaluate.py
================================================
from typing import List
import yaml
import os

import torch
import torch.distributed as dist

import pydantic
from omegaconf import OmegaConf
from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader


class EvalConfig(pydantic.BaseModel):
    checkpoint: str
    
    save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"]


def launch():
    eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli()))  # type: ignore
    
    RANK = 0
    WORLD_SIZE = 1
    # Initialize distributed training if in distributed environment (e.g. torchrun)
    if "LOCAL_RANK" in os.environ:
        # Initialize distributed, default device and dtype
        dist.init_process_group(backend="nccl")

        RANK = dist.get_rank()
        WORLD_SIZE = dist.get_world_size()

        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f:
        config = PretrainConfig(**yaml.safe_load(f))

        config.eval_save_outputs = eval_cfg.save_outputs
        config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint)

    # Dataloader
    train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
    eval_loader,  eval_metadata  = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)

    # Models
    train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)
    # Try unwrap torch.compile
    try:
        train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True)
    except:
        train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True)
    
    train_state.step = 0
    ckpt_filename = os.path.basename(eval_cfg.checkpoint)
    if ckpt_filename.startswith("step_"):
        train_state.step = int(ckpt_filename.removeprefix("step_"))

    # Evaluate
    print ("Starting evaluation")
    
    train_state.model.eval()
    metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)

    if metrics is not None:
        print (metrics)


if __name__ == "__main__":
    launch()


================================================
FILE: models/common.py
================================================
import math

import torch
from torch import nn


def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor
    # This function is a PyTorch version of jax truncated normal init (default init method in flax)
    # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
    # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199

    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2

            c = (2 * math.pi) ** -0.5
            pdf_u = c * math.exp(-0.5 * lower ** 2)
            pdf_l = c * math.exp(-0.5 * upper ** 2)
            comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)

            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            tensor.clip_(lower * comp_std, upper * comp_std)

    return tensor


================================================
FILE: models/hrm/hrm_act_v1.py
================================================
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math

import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel

from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding


@dataclass
class HierarchicalReasoningModel_ACTV1InnerCarry:
    z_H: torch.Tensor
    z_L: torch.Tensor


@dataclass
class HierarchicalReasoningModel_ACTV1Carry:
    inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
    
    steps: torch.Tensor
    halted: torch.Tensor
    
    current_data: Dict[str, torch.Tensor]


class HierarchicalReasoningModel_ACTV1Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    
    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"


class HierarchicalReasoningModel_ACTV1Block(nn.Module):
    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()

        self.self_attn = Attention(
            hidden_size=config.hidden_size,
            head_dim=config.hidden_size // config.num_heads,
            num_heads=config.num_heads,
            num_key_value_heads=config.num_heads,
            causal=False
        )
        self.mlp = SwiGLU(
            hidden_size=config.hidden_size,
            expansion=config.expansion,
        )
        self.norm_eps = config.rms_norm_eps

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        # Post Norm
        # Self Attention
        hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
        # Fully Connected
        hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
        return hidden_states


class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
    def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
        super().__init__()

        self.layers = torch.nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
        # Input injection (add)
        hidden_states = hidden_states + input_injection
        # Layers
        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states, **kwargs)

        return hidden_states


class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
    def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O
        self.embed_scale  = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale

        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head      = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head       = CastedLinear(self.config.hidden_size, 2, bias=True)

        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size)  # ceil div
        if self.config.puzzle_emb_ndim > 0:
            # Zero init puzzle embeddings
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # LM Blocks
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
        else:
            raise NotImplementedError()

        # Reasoning Layers
        self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
        self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
        
        # Initial states
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q head special init
        # Init Q to (almost) zero for faster learning during bootstrapping
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)  # type: ignore

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        # Token embedding
        embedding = self.embed_tokens(input.to(torch.int32))

        # Puzzle embeddings
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
            
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))

            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        # Position embeddings
        if self.config.pos_encodings == "learned":
            # scale by 1/sqrt(2) to maintain forward variance
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        # Scale
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )
        
    def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
        return HierarchicalReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        seq_info = dict(
            cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
        )

        # Input encoding
        input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        # Forward iterations
        with torch.no_grad():
            z_H, z_L = carry.z_H, carry.z_L

            for _H_step in range(self.config.H_cycles):
                for _L_step in range(self.config.L_cycles):
                    if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
                        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)

                if not (_H_step == self.config.H_cycles - 1):
                    z_H = self.H_level(z_H, z_L, **seq_info)

        assert not z_H.requires_grad and not z_L.requires_grad

        # 1-step grad
        z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
        z_H = self.H_level(z_H, z_L, **seq_info)

        # LM Outputs
        new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach())  # New carry no grad
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]

        # Q head
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
        
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])


class HierarchicalReasoningModel_ACTV1(nn.Module):
    """ACT wrapper."""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
        self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)

    @property
    def puzzle_emb(self):
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        batch_size = batch["inputs"].shape[0]

        return HierarchicalReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),  # Empty is expected, it will be reseted in first pass as all sequences are halted.
            
            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),  # Default to halted
            
            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )
        
    def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        # Update data, carry (removing halted sequences)
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
        
        new_steps = torch.where(carry.halted, 0, carry.steps)

        new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }
        
        with torch.no_grad():
            # Step
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps
            
            halted = is_last_step

            # if training, and ACT is enabled
            if self.training and (self.config.halt_max_steps > 1):
                # Halt signal
                # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
                halted = halted | (q_halt_logits > q_continue_logits)

                # Exploration
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)

                halted = halted & (new_steps >= min_halt_steps)

                # Compute target Q
                # NOTE: No replay buffer and target networks for computing target Q-value.
                # As batch_size is large, there're many parallel envs.
                # Similar concept as PQN https://arxiv.org/abs/2407.04811
                next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
                
                outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))

        return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs


================================================
FILE: models/layers.py
================================================
from typing import Tuple

import torch
from torch import nn
import torch.nn.functional as F

try:
    from flash_attn_interface import flash_attn_func  # type: ignore[import]
except ImportError:
    # Fallback to FlashAttention 2
    from flash_attn import flash_attn_func  # type: ignore[import]

from models.common import trunc_normal_init_


CosSin = Tuple[torch.Tensor, torch.Tensor]


def _find_multiple(a, b):
    return (-(a // -b)) * b


def rotate_half(x: torch.Tensor):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [bs, seq_len, num_heads, head_dim]
    # cos, sin: [seq_len, head_dim]
    orig_dtype = q.dtype
    q = q.to(cos.dtype)
    k = k.to(cos.dtype)

    q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
    k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))

    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)


class CastedLinear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool):
        super().__init__()
        # Truncated LeCun normal init
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
        )
        self.bias = None
        if bias:
            # Zero init bias
            self.bias = nn.Parameter(torch.zeros((out_features, )))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)


class CastedEmbedding(nn.Module):
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 init_std: float,
                 cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Truncated LeCun normal init
        self.embedding_weight = nn.Parameter(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
        )
        
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.embedding(input, self.embedding_weight.to(self.cast_to))


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings, base, device=None):
        super().__init__()

        # RoPE
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)

    def forward(self):
        return self.cos_cached, self.sin_cached


class Attention(nn.Module):
    def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.causal = causal

        self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape

        # hidden_states: [bs, seq_len, num_heads, head_dim]
        qkv = self.qkv_proj(hidden_states)

        # Split head
        qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query = qkv[:, :, :self.num_heads]
        key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
        value = qkv[:, :, self.num_heads + self.num_key_value_heads:]

        # RoPE
        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # flash attn
        attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
        if isinstance(attn_output, tuple):  # fa2 and fa3 compatibility
            attn_output = attn_output[0]

        attn_output = attn_output.view(batch_size, seq_len, self.output_size)  # type: ignore
        return self.o_proj(attn_output)


class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj    = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)


def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    variance = hidden_states.square().mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)


================================================
FILE: models/losses.py
================================================
from typing import Any, Tuple, Dict, Sequence, Optional

import torch
import torch.nn.functional as F
from torch import nn


IGNORE_LABEL_ID = -100


def s(x, epsilon=1e-30):
    return torch.where(
        x<0,
        1/(1-x+ epsilon),
        x + 1
    )


def log_stablemax(x, dim=-1):
    s_x = s(x)
    return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))


def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)

    valid_mask = labels != ignore_index
    transformed_labels = torch.where(valid_mask, labels, 0)
    prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)

    return -torch.where(valid_mask, prediction_logprobs, 0)


def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
    # Cast logits to f32
    # Flatten logits
    return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)


class ACTLossHead(nn.Module):
    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        self.loss_fn = globals()[loss_type]
        
    def initial_carry(self, *args, **kwargs):
        return self.model.initial_carry(*args, **kwargs)  # type: ignore

    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
        # Model logits
        # B x SeqLen x D
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]

        # Correctness
        with torch.no_grad():
            mask = labels != IGNORE_LABEL_ID
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts
            
            # Metrics (halted)
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),
                
                "accuracy":       torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),

                "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
                "steps":          torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Losses
        # FIXME: Assuming the batch is always full
        lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()
        q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")

        metrics.update({
            "lm_loss": lm_loss.detach(),
            "q_halt_loss": q_halt_loss.detach(),
        })

        # Q continue (bootstrapping target loss)
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")

            metrics["q_continue_loss"] = q_continue_loss.detach()

        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()


================================================
FILE: models/sparse_embedding.py
================================================
from typing import Union

import torch
from torch import nn
import torch.distributed as dist
from torch.optim.optimizer import Optimizer, ParamsT

from models.common import trunc_normal_init_


class CastedSparseEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Real Weights
        # Truncated LeCun normal init
        self.weights = nn.Buffer(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
        )

        # Local weights and IDs
        # Local embeddings, with gradient, not persistent
        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
        # Local embedding IDs, not persistent
        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Test mode, no gradient
            return self.weights[inputs].to(self.cast_to)
            
        # Training mode, fill puzzle embedding from weights
        with torch.no_grad():
            self.local_weights.copy_(self.weights[inputs])
            self.local_ids.copy_(inputs)

        return self.local_weights.to(self.cast_to)


class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
    def __init__(
        self,
        params: ParamsT,

        world_size: int,
        lr: Union[float, torch.Tensor] = 1e-3,
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            world_size=world_size
        )
        super().__init__(params, defaults)

    @torch.no_grad
    def step(self, closure=None):  # type: ignore
        for group in self.param_groups:
            # Find the sparse embedding weights
            local_weights_grad = None
            local_ids = None
            weights = None
            
            assert len(group["params"]) == 3
            for p in group["params"]:
                if p.requires_grad:
                    local_weights_grad = p.grad
                elif p.ndim == 1:
                    local_ids = p
                elif p.ndim == 2:
                    weights = p
                else:
                    assert False
                
            assert local_weights_grad is not None
            assert local_ids is not None
            assert weights is not None
        
            # Apply SignSGD
            # Adam ≈ SignSGD if gradient is very sparse
            _sparse_emb_signsgd_dist(
                local_weights_grad,
                local_ids,
                weights,
                
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                world_size=group["world_size"]
            )


def _sparse_emb_signsgd_dist(
    local_weights_grad: torch.Tensor,
    local_ids: torch.Tensor,
    weights: torch.Tensor,
    
    lr: float,
    weight_decay: float,
    world_size: int
) -> None:
    N, D = local_weights_grad.shape
    
    # All-gather
    all_weights_grad = local_weights_grad
    all_ids = local_ids

    if world_size > 1:
        all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
        all_ids = torch.empty(world_size * N,               dtype=local_ids.dtype,          device=local_ids.device)
    
        dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
        dist.all_gather_into_tensor(all_ids,          local_ids)

    # Unique
    grad_ids, inv = all_ids.unique(return_inverse=True)

    grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
    grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)

    # SignSGD with decoupled weight decay
    p = weights[grad_ids]

    p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)

    # Write updated slices back
    weights[grad_ids] = p


================================================
FILE: pretrain.py
================================================
from typing import Optional, Any, Sequence, List
from dataclasses import dataclass
import os
import math
import yaml
import shutil

import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader

import tqdm
import wandb
import coolname
import hydra
import pydantic
from omegaconf import DictConfig
from adam_atan2 import AdamATan2

from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
from utils.functions import load_model_class, get_model_source_path
from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed


class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    
    name: str


class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')

    name: str
    loss: LossConfig


class PretrainConfig(pydantic.BaseModel):
    # Config
    arch: ArchConfig
    # Data
    data_path: str

    # Hyperparams
    global_batch_size: int
    epochs: int

    lr: float
    lr_min_ratio: float
    lr_warmup_steps: int

    weight_decay: float
    beta1: float
    beta2: float

    # Puzzle embedding
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float

    # Names
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    checkpoint_path: Optional[str] = None

    # Extras
    seed: int = 0
    checkpoint_every_eval: bool = False
    eval_interval: Optional[int] = None
    eval_save_outputs: List[str] = []


@dataclass
class TrainState:
    model: nn.Module
    optimizers: Sequence[torch.optim.Optimizer]
    optimizer_lrs: Sequence[float]
    carry: Any

    step: int
    total_steps: int


def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
    dataset = PuzzleDataset(PuzzleDatasetConfig(
        seed=config.seed,

        dataset_path=config.data_path,

        rank=rank,
        num_replicas=world_size,
        
        **kwargs
    ), split=split)
    dataloader = DataLoader(
        dataset,
        batch_size=None,

        num_workers=1,
        prefetch_factor=8,

        pin_memory=True,
        persistent_workers=True
    )
    return dataloader, dataset.metadata


def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
    model_cfg = dict(
        **config.arch.__pydantic_extra__,  # type: ignore

        batch_size=config.global_batch_size // world_size,

        vocab_size=train_metadata.vocab_size,
        seq_len=train_metadata.seq_len,
        num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
        causal=False  # Non-autoregressive
    )

    # Instantiate model with loss head
    model_cls = load_model_class(config.arch.name)
    loss_head_cls = load_model_class(config.arch.loss.name)

    with torch.device("cuda"):
        model: nn.Module = model_cls(model_cfg)
        model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__)  # type: ignore
        if "DISABLE_COMPILE" not in os.environ:
            model = torch.compile(model, dynamic=False)  # type: ignore

        # Broadcast parameters from rank 0
        if world_size > 1:
            with torch.no_grad():
                for param in list(model.parameters()) + list(model.buffers()):
                    dist.broadcast(param, src=0)

    # Optimizers and lr
    optimizers = [
        CastedSparseEmbeddingSignSGD_Distributed(
            model.model.puzzle_emb.buffers(),  # type: ignore
            
            lr=0,  # Needs to be set by scheduler
            weight_decay=config.puzzle_emb_weight_decay,

            world_size=world_size
        ),
        AdamATan2(
            model.parameters(),

            lr=0,  # Needs to be set by scheduler
            weight_decay=config.weight_decay,
            betas=(config.beta1, config.beta2)
        )
    ]
    optimizer_lrs = [
        config.puzzle_emb_lr,
        config.lr
    ]

    return model, optimizers, optimizer_lrs


def cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
):
    if current_step < num_warmup_steps:
        return base_lr * float(current_step) / float(max(1, num_warmup_steps))

    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))


def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
    # Estimated total training steps
    total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)

    # Model
    model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size)

    return TrainState(
        step=0,
        total_steps=total_steps,

        model=model,
        optimizers=optimizers,
        optimizer_lrs=optimizer_lrs,
        carry=None
    )


def save_train_state(config: PretrainConfig, train_state: TrainState):
    # FIXME: Only saved model.
    if config.checkpoint_path is None:
        return

    os.makedirs(config.checkpoint_path, exist_ok=True)
    torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))


def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
    return cosine_schedule_with_warmup_lr_lambda(
        current_step=train_state.step,
        base_lr=base_lr,
        num_warmup_steps=round(config.lr_warmup_steps),
        num_training_steps=train_state.total_steps,
        min_ratio=config.lr_min_ratio
    )


def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
    train_state.step += 1
    if train_state.step > train_state.total_steps:  # At most train_total_steps
        return

    # To device
    batch = {k: v.cuda() for k, v in batch.items()}

    # Init carry if it is None
    if train_state.carry is None:
        with torch.device("cuda"):
            train_state.carry = train_state.model.initial_carry(batch)  # type: ignore

    # Forward
    train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])

    ((1 / global_batch_size) * loss).backward()

    # Allreduce
    if world_size > 1:
        for param in train_state.model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)
            
    # Apply optimizer
    lr_this_step = None    
    for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
        lr_this_step = compute_lr(base_lr, config, train_state)

        for param_group in optim.param_groups:
            param_group['lr'] = lr_this_step
            
        optim.step()
        optim.zero_grad()

    # Reduce metrics
    if len(metrics):
        assert not any(v.requires_grad for v in metrics.values())

        metric_keys = list(sorted(metrics.keys()))  # Sort keys to guarantee all processes use the same order.
        # Reduce and reconstruct
        metric_values = torch.stack([metrics[k] for k in metric_keys])
        if world_size > 1:
            dist.reduce(metric_values, dst=0)

        if rank == 0:
            metric_values = metric_values.cpu().numpy()
            reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}
            
            # Postprocess
            count = max(reduced_metrics["count"], 1)  # Avoid NaNs
            reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}

            reduced_metrics["train/lr"] = lr_this_step
            return reduced_metrics


def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
    with torch.inference_mode():
        set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}
        
        all_preds = {}

        metric_keys = []
        metric_values = None
        metric_global_batch_size = [0 for _ in range(len(set_ids))]
        
        carry = None
        for set_name, batch, global_batch_size in eval_loader:
            # To device
            batch = {k: v.cuda() for k, v in batch.items()}
            with torch.device("cuda"):
                carry = train_state.model.initial_carry(batch)  # type: ignore

            # Forward
            while True:
                carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs)
                
                if all_finish:
                    break

            for collection in (batch, preds):
                for k, v in collection.items():
                    if k in config.eval_save_outputs:
                        all_preds.setdefault(k, [])
                        all_preds[k].append(v.cpu())  # Move to CPU for saving GPU memory
                        
            del carry, preds, batch, all_finish

            # Aggregate
            set_id = set_ids[set_name]
            
            if metric_values is None:
                metric_keys = list(sorted(metrics.keys()))  # Sort keys to guarantee all processes use the same order.
                metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda")
                
            metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])
            metric_global_batch_size[set_id] += global_batch_size

        if len(all_preds) and config.checkpoint_path is not None:
            all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}

            os.makedirs(config.checkpoint_path, exist_ok=True)
            torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}"))

        # Logging
        # Reduce to rank 0
        if metric_values is not None:
            if world_size > 1:
                dist.reduce(metric_values, dst=0)
            
            if rank == 0:
                reduced_metrics = metric_values.cpu().numpy()
                reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)}
                                   for set_id, set_name in enumerate(set_ids)}
                
                # Postprocess
                for set_name, metrics in reduced_metrics.items():
                    count = metrics.pop("count")
                    reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()}

                return reduced_metrics


def save_code_and_config(config: PretrainConfig):
    if config.checkpoint_path is None or wandb.run is None:
        return

    os.makedirs(config.checkpoint_path, exist_ok=True)

    # Copy code
    code_list = [
        get_model_source_path(config.arch.name),
        get_model_source_path(config.arch.loss.name)
    ]
    for code_file in code_list:
        if code_file is not None:
            code_name = os.path.basename(code_file)

            shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))

    # Dump config as yaml
    config_file = os.path.join(config.checkpoint_path, "all_config.yaml")
    with open(config_file, "wt") as f:
        yaml.dump(config.model_dump(), f)

    # Log code
    wandb.run.log_code(config.checkpoint_path)


def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
    objects = [None]
    if rank == 0:
        config = PretrainConfig(**hydra_config)  # type: ignore

        # Naming
        if config.project_name is None:
            config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch"
        if config.run_name is None:
            config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
        if config.checkpoint_path is None:
            config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)

        objects = [config]

    if world_size > 1:
        dist.broadcast_object_list(objects, src=0)

    return objects[0]  # type: ignore


@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
def launch(hydra_config: DictConfig):
    RANK = 0
    WORLD_SIZE = 1

    # Initialize distributed training if in distributed environment (e.g. torchrun)
    if "LOCAL_RANK" in os.environ:
        # Initialize distributed, default device and dtype
        dist.init_process_group(backend="nccl")

        RANK = dist.get_rank()
        WORLD_SIZE = dist.get_world_size()

        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
        
    # Load sync'ed config
    config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)

    # Seed RNGs to ensure consistency
    torch.random.manual_seed(config.seed + RANK)

    # Dataset
    train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
    total_iters = config.epochs // train_epochs_per_iter

    assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."

    train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
    eval_loader,  eval_metadata  = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)

    # Train state
    train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)

    # Progress bar and logger
    progress_bar = None
    if RANK == 0:
        progress_bar = tqdm.tqdm(total=train_state.total_steps)

        wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True))  # type: ignore
        wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
        save_code_and_config(config)

    # Training Loop
    for _iter_id in range(total_iters):
        print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")

        ############ Train Iter
        train_state.model.train()
        for set_name, batch, global_batch_size in train_loader:
            metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)

            if RANK == 0 and metrics is not None:
                wandb.log(metrics, step=train_state.step)
                progress_bar.update(train_state.step - progress_bar.n)  # type: ignore

        ############ Evaluation
        train_state.model.eval()
        metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)

        if RANK == 0 and metrics is not None:
            wandb.log(metrics, step=train_state.step)
            
        ############ Checkpointing
        if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
            save_train_state(config, train_state)

    # finalize
    if dist.is_initialized():
        dist.destroy_process_group()
    wandb.finish()


if __name__ == "__main__":
    launch()


================================================
FILE: puzzle_dataset.py
================================================
import os
import json

import numpy as np
import pydantic

import torch
from torch.utils.data import IterableDataset, get_worker_info

from models.losses import IGNORE_LABEL_ID
from dataset.common import PuzzleDatasetMetadata


def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
    # Pack examples into a full batch
    batch = []
    batch_puzzle_indices = []
    current_size = 0

    while (start_index < group_order.size) and (current_size < global_batch_size):
        # Pick a group and a puzzle from that group
        group_id = group_order[start_index]
        puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
        start_index += 1

        # Get range of the puzzle
        puzzle_start = puzzle_indices[puzzle_id]
        puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)

        append_size = min(puzzle_size, global_batch_size - current_size)

        # Put into batch
        batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
        batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False))

        current_size += append_size

    return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)


class PuzzleDatasetConfig(pydantic.BaseModel):
    seed: int
    dataset_path: str
    global_batch_size: int
    test_set_mode: bool

    epochs_per_iter: int  # Batch X epochs in an iteration to reduce overhead.

    rank: int
    num_replicas: int


class PuzzleDataset(IterableDataset):
    def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
        super().__init__()
        self.config = config
        self.split = split
        self.metadata = self._load_metadata()
        
        # Checks
        assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}."
        self.local_batch_size = self.config.global_batch_size // self.config.num_replicas

        # State
        self._data = None
        self._iters = 0

    def _load_metadata(self) -> PuzzleDatasetMetadata:
        with open(os.path.join(self.config.dataset_path, self.split, "dataset.json"), "r") as f:
            return PuzzleDatasetMetadata(**json.load(f))

    def _lazy_load_dataset(self):
        if self._data is not None:
            return

        field_mmap_modes = {
            "inputs": "r",
            "labels": "r",

            # Keep indices in memory
            "puzzle_identifiers": None,
            "puzzle_indices": None,
            "group_indices": None
        }

        # Load data
        self._data = {}
        for set_name in self.metadata.sets:
            # Load subset
            self._data[set_name] = {
                field_name: np.load(os.path.join(self.config.dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
                for field_name, mmap_mode in field_mmap_modes.items()
            }

    def _collate_batch(self, batch):
        # Convert dtype
        batch = {k: v.astype(np.int32) for k, v in batch.items()}

        # Convert ignore label IDs
        if self.metadata.ignore_label_id is not None:
            batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID

        # Pad
        if batch["puzzle_identifiers"].size < self.local_batch_size:
            pad_size = self.local_batch_size - batch["puzzle_identifiers"].size

            pad_values = {
                "inputs": self.metadata.pad_id,
                "labels": IGNORE_LABEL_ID,

                "puzzle_identifiers": self.metadata.blank_identifier_id
            }
            batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}

        # To tensor
        return {k: torch.from_numpy(v) for k, v in batch.items()}
    
    def _iter_test(self):
        for set_name, dataset in self._data.items():  # type: ignore
            total_examples = len(dataset["inputs"])

            # Load examples one by one
            start_index = 0
            while start_index < total_examples:
                # Compute indices
                end_index = min(total_examples, start_index + self.config.global_batch_size)
                
                local_start = start_index + self.config.rank * self.local_batch_size
                local_end   = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)
                
                # Get batch of examples, and also puzzle IDs
                puzzle_indices = []
                puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1
                for i in range(local_start, local_end):
                    while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]:
                        puzzle_index += 1

                    puzzle_indices.append(puzzle_index)
                
                batch = self._collate_batch({
                    "inputs": dataset["inputs"][local_start: local_end],
                    "labels": dataset["labels"][local_start: local_end],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices]
                })

                yield set_name, batch, end_index - start_index
                
                # Advance to next batch
                start_index += self.config.global_batch_size

    def _iter_train(self):
        for set_name, dataset in self._data.items():  # type: ignore
            # Increase epoch count
            self._iters += 1

            # Randomly shuffle groups
            rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))

            group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)])
            start_index = 0
            
            while start_index < group_order.size:
                start_index, batch_indices, batch_puzzle_indices = _sample_batch(
                    rng,
                    group_order=group_order,
                    puzzle_indices=dataset["puzzle_indices"],
                    group_indices=dataset["group_indices"],
                    start_index=start_index,
                    global_batch_size=self.config.global_batch_size,
                )

                # Select current rank and collate
                global_effective_batch_size = batch_puzzle_indices.size  # Global effective batch size, excluding pads

                # Drop last batch
                if global_effective_batch_size < self.config.global_batch_size:
                    break

                batch_indices        = batch_indices       [self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
                batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
                batch = self._collate_batch({
                    "inputs": dataset["inputs"][batch_indices],
                    "labels": dataset["labels"][batch_indices],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
                })

                yield set_name, batch, global_effective_batch_size
                
    def __iter__(self):
        worker_info = get_worker_info()
        assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported."
        
        self._lazy_load_dataset()
        
        # Iterate using specified mode
        if self.config.test_set_mode:
            yield from self._iter_test()
        else:
            yield from self._iter_train()


================================================
FILE: puzzle_visualizer.html
================================================
<!DOCTYPE html>
<html>
<head>
  <meta charset="UTF-8" />
  <title>ARC‐Converted Dataset Visualizer (Upload Local Folder)</title>
  <style>
    body {
      font-family: sans-serif;
      margin: 16px;
    }
    .selector-area {
      margin-bottom: 1rem;
    }
    .grid-canvas {
      margin: 4px;
      border: 1px solid #ccc;
    }
    .example-container {
      display: inline-block;
      margin: 0 16px 16px 0;
      vertical-align: top;
    }
    .puzzle-display {
      margin-top: 1rem;
    }
    .puzzle-id {
      font-weight: bold;
      margin-bottom: 0.5rem;
    }
    #groupList, #puzzleList {
      margin: 1rem 0;
    }
    .group-item, .puzzle-item {
      cursor: pointer;
      margin: 4px 8px 4px 0;
      padding: 2px 6px;
      border: 1px solid #aaa;
      display: inline-block;
    }
    .group-item:hover, .puzzle-item:hover {
      background: #eef;
    }
  </style>
</head>
<body>
<h1>ARC‐Converted Dataset Visualizer (Local Directory)</h1>

<div class="selector-area">
  <!-- 1) Directory input with webkitdirectory, mozdirectory -->
  <label>Upload ARC Folder:</label>
  <input type="file" id="folderInput"
         webkitdirectory mozdirectory multiple
         onchange="onFolderSelected(event)" />
  <br><br>

  <!-- 2) We'll enable set/subset selection after user chooses a folder and data is validated -->
  <label>Set:</label>
  <select id="setSelect" disabled>
    <option value="train">train</option>
    <option value="test">test</option>
  </select>

  <label> Subset:</label>
  <select id="subsetSelect" disabled>
    <option value="all">all</option>
  </select>

  <button id="loadBtn" disabled>Load</button>
</div>

<div>
  <div id="groupList"></div>
  <div id="puzzleList"></div>
  <div class="puzzle-display" id="puzzleView"></div>
</div>

<!-- 
   3) Use local 'assets/npyjs.js' from your project folder instead of a CDN.
   Make sure 'assets/npyjs.js' is the unbundled or UMD version that doesn't
   contain "import" statements. 
-->
<script src="assets/npyjs.js"></script>

<script>
/***************************************************************************
 * Global Maps & Variables
 ***************************************************************************/

// Map from "train/all__inputs.npy" => File, etc.
let filesByPath = {};

// Once loaded, we store typed arrays for the chosen set/subset
let inputsArr, labelsArr;
let puzzleIndicesArr, groupIndicesArr, puzzleIdentifiersArr;
let identifiersJson;

// The shape of inputs is [N_examples, seqLen], so we discover seqLen & gridSize
let seqLen = 0;
let gridSize = 0;


/***************************************************************************
 * 1) Handle folder selection: read all files, find identifiers.json,
 *    remove topmost folder from each file path, validate.
 ***************************************************************************/
function onFolderSelected(event) {
  filesByPath = {};
  const fileList = event.target.files;
  if (!fileList || fileList.length === 0) {
    alert("No files selected!");
    return;
  }

  // We'll gather all webkitRelativePaths
  const paths = [];
  for (let i = 0; i < fileList.length; i++) {
    // Typically "arc-aug-10/train/all__inputs.npy", etc.
    const file = fileList[i];
    const relPath = file.webkitRelativePath || file.mozRelativePath || file.name;
    paths.push(relPath);
  }

  // 1. Check if we have "identifiers.json" somewhere.
  const idPath = paths.find(p => p.endsWith("identifiers.json"));
  if (!idPath) {
    alert("Error: No 'identifiers.json' found in the uploaded folder.");
    return;
  }

  // 2. Derive the top-level directory from that file's path
  //    e.g. if idPath = "arc-aug-10/identifiers.json", topDir = "arc-aug-10"
  //    If there's no slash, topDir = "" => do nothing
  let topDir = "";
  const lastSlash = idPath.lastIndexOf("/");
  if (lastSlash >= 0) {
    topDir = idPath.substring(0, lastSlash);
  }

  // 3. Rebuild filesByPath with the top folder removed.
  //    For example, if topDir = "arc-aug-10", then "arc-aug-10/train/all__inputs.npy"
  //    becomes "train/all__inputs.npy"
  for (let i = 0; i < fileList.length; i++) {
    const file = fileList[i];
    let relPath = file.webkitRelativePath || file.mozRelativePath || file.name;
    // If relPath starts with "arc-aug-10/", remove that prefix
    if (topDir && relPath.startsWith(topDir + "/")) {
      relPath = relPath.substring(topDir.length + 1);
    }
    filesByPath[relPath] = file;
  }

  // Enable set/subset selection and "Load"
  document.getElementById("setSelect").disabled = false;
  document.getElementById("subsetSelect").disabled = false;
  document.getElementById("loadBtn").disabled = false;
}

// When user clicks "Load," parse the .npy for the chosen set/subset
document.getElementById("loadBtn").addEventListener("click", async () => {
  document.getElementById("groupList").innerHTML = "";
  document.getElementById("puzzleList").innerHTML = "";
  document.getElementById("puzzleView").innerHTML = "";

  const setName = document.getElementById("setSelect").value;        // e.g. "train"
  const subsetName = document.getElementById("subsetSelect").value;  // e.g. "all"

  try {
    await loadDataset(setName, subsetName);
    buildGroupList(); // show groups
  } catch (err) {
    console.error(err);
    alert("Error while loading dataset: " + err);
  }
});


/***************************************************************************
 * 2) Load .npy from local files using Npyjs + FileReader (ArrayBuffer)
 ***************************************************************************/
async function loadDataset(setName, subsetName) {
  const prefix = `${setName}/${subsetName}__`;
  // e.g. "train/all__inputs.npy"
  const inputsPath  = prefix + "inputs.npy";
  const labelsPath  = prefix + "labels.npy";
  const pIdxPath    = prefix + "puzzle_indices.npy";
  const gIdxPath    = prefix + "group_indices.npy";
  const pIdsPath    = prefix + "puzzle_identifiers.npy";
  const identifiersPath = "identifiers.json";

  // Check existence
  const needed = [inputsPath, labelsPath, pIdxPath, gIdxPath, pIdsPath, identifiersPath];
  for (const f of needed) {
    if (!filesByPath[f]) {
      throw new Error(`Missing file: ${f}`);
    }
  }

  // parseNpy => read from File -> ArrayBuffer -> Npyjs => typed array
  const inputsNpy       = await parseNpy(filesByPath[inputsPath]);
  const labelsNpy       = await parseNpy(filesByPath[labelsPath]);
  const puzzleIndicesNpy= await parseNpy(filesByPath[pIdxPath]);
  const groupIndicesNpy = await parseNpy(filesByPath[gIdxPath]);
  const puzzleIdsNpy    = await parseNpy(filesByPath[pIdsPath]);

  inputsArr            = inputsNpy.data;
  labelsArr            = labelsNpy.data;
  puzzleIndicesArr     = puzzleIndicesNpy.data;
  groupIndicesArr      = groupIndicesNpy.data;
  puzzleIdentifiersArr = puzzleIdsNpy.data;

  // shape e.g. [N_examples, seqLen]
  seqLen   = inputsNpy.shape[1];
  gridSize = Math.sqrt(seqLen);

  // read JSON
  identifiersJson = await readJsonFile(filesByPath[identifiersPath]);
}

/***************************************************************************
 * parseNpy => read a File as ArrayBuffer, parse with npyjs
 ***************************************************************************/
function parseNpy(file) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onload = async () => {
      try {
        const arrayBuffer = reader.result;
        const npy = new npyjs();
        resolve(await npy.parse(arrayBuffer));
      } catch (err) {
        reject(err);
      }
    };
    reader.onerror = err => reject(err);
    reader.readAsArrayBuffer(file);
  });
}

/***************************************************************************
 * readJsonFile => read a local JSON file into object
 ***************************************************************************/
function readJsonFile(file) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onload = () => {
      try {
        const obj = JSON.parse(reader.result);
        resolve(obj);
      } catch (err) {
        reject(err);
      }
    };
    reader.onerror = (err) => reject(err);
    reader.readAsText(file);
  });
}

/***************************************************************************
 * 3) Build group list in UI
 ***************************************************************************/
function buildGroupList() {
  document.getElementById("groupList").innerHTML = "<h3>Groups</h3>";
  const groupListDiv = document.getElementById("groupList");

  const nGroups = groupIndicesArr.length - 1;
  for (let g = 0; g < nGroups; g++) {
    const div = document.createElement("span");
    div.className = "group-item";
    div.textContent = `Group ${g}`;
    div.onclick = () => onSelectGroup(g);
    groupListDiv.appendChild(div);
  }
}

/***************************************************************************
 * onSelectGroup => show puzzles in that group
 ***************************************************************************/
function onSelectGroup(groupIndex) {
  document.getElementById("puzzleList").innerHTML = "";
  document.getElementById("puzzleView").innerHTML = "";

  const puzzleListDiv = document.getElementById("puzzleList");
  puzzleListDiv.innerHTML = `<h4>Puzzles in Group ${groupIndex}</h4>`;

  const firstPuzzle = groupIndicesArr[groupIndex];
  const lastPuzzle  = groupIndicesArr[groupIndex + 1];

  for (let p = firstPuzzle; p < lastPuzzle; p++) {
    const puzzleIntId = puzzleIdentifiersArr[p];
    const puzzleStrId = (puzzleIntId < identifiersJson.length)
                        ? identifiersJson[puzzleIntId]
                        : "<unknown>";

    const div = document.createElement("span");
    div.className = "puzzle-item";
    div.textContent = `Puzzle #${p} [ID=${puzzleIntId}: ${puzzleStrId}]`;
    div.onclick = () => onSelectPuzzle(p);
    puzzleListDiv.appendChild(div);
  }
}

/***************************************************************************
 * onSelectPuzzle => show each example
 ***************************************************************************/
function onSelectPuzzle(puzzleIndex) {
  const puzzleView = document.getElementById("puzzleView");
  puzzleView.innerHTML = "";

  // puzzle ID
  const puzzleIntId = puzzleIdentifiersArr[puzzleIndex];
  const puzzleStrId = (puzzleIntId < identifiersJson.length)
                      ? identifiersJson[puzzleIntId]
                      : "<unknown>";

  const titleDiv = document.createElement("div");
  titleDiv.className = "puzzle-id";
  titleDiv.textContent = `Puzzle #${puzzleIndex} — ID: ${puzzleStrId}`;
  puzzleView.appendChild(titleDiv);

  // Examples are [puzzleIndicesArr[p], puzzleIndicesArr[p+1])
  const firstExample = puzzleIndicesArr[puzzleIndex];
  const lastExample  = puzzleIndicesArr[puzzleIndex + 1];

  for (let e = firstExample; e < lastExample; e++) {
    const inputSeq  = slice1D(inputsArr,  e*seqLen, (e+1)*seqLen);
    const outputSeq = slice1D(labelsArr, e*seqLen, (e+1)*seqLen);

    const inputGrid  = decodeGrid(inputSeq);
    const outputGrid = decodeGrid(outputSeq);

    const exDiv = document.createElement("div");
    exDiv.className = "example-container";
    exDiv.appendChild(document.createTextNode(`Example ${e}`));
    exDiv.appendChild(document.createElement("br"));

    exDiv.appendChild(renderGrid(inputGrid));
    exDiv.appendChild(renderGrid(outputGrid));

    puzzleView.appendChild(exDiv);
  }
}

/***************************************************************************
 * slice1D => typed array slicing
 ***************************************************************************/
function slice1D(arr, start, end) {
  const result = new Uint32Array(end - start);
  for (let i = start; i < end; i++) {
    result[i - start] = Number(arr[i]);
  }
  return result;
}

/***************************************************************************
 * decodeGrid => turn the flattened seq of length=gridSize^2 into 2D
 ***************************************************************************/
function decodeGrid(seq) {
  const grid = [];
  let idx = 0;
  for (let r = 0; r < gridSize; r++) {
    const row = [];
    for (let c = 0; c < gridSize; c++) {
      row.push(seq[idx]);
      idx++;
    }
    grid.push(row);
  }
  return grid;
}

/***************************************************************************
 * renderGrid => draws a 2D grid to <canvas>
 ***************************************************************************/
function renderGrid(grid2d) {
  const rows = grid2d.length;
  const cols = grid2d[0].length;
  const scale = 10;

  const canvas = document.createElement("canvas");
  canvas.width  = cols * scale;
  canvas.height = rows * scale;
  canvas.className = "grid-canvas";
  const ctx = canvas.getContext("2d");

  for (let r = 0; r < rows; r++) {
    for (let c = 0; c < cols; c++) {
      const val = grid2d[r][c];
      ctx.fillStyle = indexToColor(val);
      ctx.fillRect(c * scale, r * scale, scale, scale);
    }
  }
  return canvas;
}

/***************************************************************************
 * indexToColor => color palette: 
 *   0 => pad => white
 *   1 => eos => light gray
 *   2..11 => original color(0..9)
 ***************************************************************************/
function indexToColor(value) {
  if (value === 0) return "#FFFFFF"; // pad => white
  if (value === 1) return "#DDDDDD"; // eos => light gray

  // shift by 2 => original color in [0..9]
  const colorIdx = value - 2;
  const palette = [
    "#000000", // color0 => black
    "#FF0000", // color1 => red
    "#00FF00", // color2 => green
    "#0000FF", // color3 => blue
    "#FFFF00", // color4 => yellow
    "#FFA500", // color5 => orange
    "#800080", // color6 => purple
    "#00FFFF", // color7 => cyan
    "#FFC0CB", // color8 => pink
    "#808080"  // color9 => gray
  ];
  if (colorIdx >= 0 && colorIdx < palette.length) {
    return palette[colorIdx];
  }
  return "#FFFFFF"; // fallback
}
</script>
</body>
</html>


================================================
FILE: requirements.txt
================================================
torch
adam-atan2
einops
tqdm
coolname
pydantic
argdantic
wandb
omegaconf
hydra-core
huggingface_hub


================================================
FILE: utils/functions.py
================================================
import importlib
import inspect


def load_model_class(identifier: str, prefix: str = "models."):
    module_path, class_name = identifier.split('@')

    # Import the module
    module = importlib.import_module(prefix + module_path)
    cls = getattr(module, class_name)
    
    return cls


def get_model_source_path(identifier: str, prefix: str = "models."):
    module_path, class_name = identifier.split('@')

    module = importlib.import_module(prefix + module_path)
    return inspect.getsourcefile(module)
Download .txt
gitextract_7bpy0bpn/

├── .gitignore
├── .gitmodules
├── .vscode/
│   ├── launch.json
│   └── settings.json
├── LICENSE
├── README.md
├── arc_eval.ipynb
├── assets/
│   └── npyjs.js
├── config/
│   ├── arch/
│   │   └── hrm_v1.yaml
│   └── cfg_pretrain.yaml
├── dataset/
│   ├── build_arc_dataset.py
│   ├── build_maze_dataset.py
│   ├── build_sudoku_dataset.py
│   └── common.py
├── evaluate.py
├── models/
│   ├── common.py
│   ├── hrm/
│   │   └── hrm_act_v1.py
│   ├── layers.py
│   ├── losses.py
│   └── sparse_embedding.py
├── pretrain.py
├── puzzle_dataset.py
├── puzzle_visualizer.html
├── requirements.txt
└── utils/
    └── functions.py
Download .txt
SYMBOL INDEX (109 symbols across 14 files)

FILE: assets/npyjs.js
  class npyjs (line 1) | class npyjs {
    method constructor (line 3) | constructor(opts) {
    method float16ToFloat32Array (line 78) | float16ToFloat32Array(float16Array) {
    method float16ToFloat32 (line 89) | static float16ToFloat32(float16) {
    method parse (line 116) | parse(arrayBufferContents) {
    method load (line 155) | async load(filename, callback, fetchArgs) {

FILE: dataset/build_arc_dataset.py
  class DataProcessConfig (line 19) | class DataProcessConfig(BaseModel):
  class ARCPuzzle (line 37) | class ARCPuzzle:
  function arc_grid_to_np (line 43) | def arc_grid_to_np(grid: List[List[int]]):
  function np_grid_to_seq_translational_augment (line 54) | def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarra...
  function puzzle_hash (line 81) | def puzzle_hash(puzzle: dict):
  function convert_single_arc_puzzle (line 98) | def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: ...
  function load_puzzles_arcagi (line 148) | def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataPr...
  function convert_dataset (line 184) | def convert_dataset(config: DataProcessConfig):
  function main (line 286) | def main(config: DataProcessConfig):

FILE: dataset/build_maze_dataset.py
  class DataProcessConfig (line 22) | class DataProcessConfig(BaseModel):
  function convert_subset (line 30) | def convert_subset(set_name: str, config: DataProcessConfig):
  function preprocess_data (line 136) | def preprocess_data(config: DataProcessConfig):

FILE: dataset/build_sudoku_dataset.py
  class DataProcessConfig (line 18) | class DataProcessConfig(BaseModel):
  function shuffle_sudoku (line 27) | def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
  function convert_subset (line 60) | def convert_subset(set_name: str, config: DataProcessConfig):
  function preprocess_data (line 163) | def preprocess_data(config: DataProcessConfig):

FILE: dataset/common.py
  class PuzzleDatasetMetadata (line 12) | class PuzzleDatasetMetadata(pydantic.BaseModel):
  function dihedral_transform (line 27) | def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
  function inverse_dihedral_transform (line 50) | def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:

FILE: evaluate.py
  class EvalConfig (line 13) | class EvalConfig(pydantic.BaseModel):
  function launch (line 19) | def launch():

FILE: models/common.py
  function trunc_normal_init_ (line 7) | def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: fl...

FILE: models/hrm/hrm_act_v1.py
  class HierarchicalReasoningModel_ACTV1InnerCarry (line 16) | class HierarchicalReasoningModel_ACTV1InnerCarry:
  class HierarchicalReasoningModel_ACTV1Carry (line 22) | class HierarchicalReasoningModel_ACTV1Carry:
  class HierarchicalReasoningModel_ACTV1Config (line 31) | class HierarchicalReasoningModel_ACTV1Config(BaseModel):
  class HierarchicalReasoningModel_ACTV1Block (line 60) | class HierarchicalReasoningModel_ACTV1Block(nn.Module):
    method __init__ (line 61) | def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> ...
    method forward (line 77) | def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> tor...
  class HierarchicalReasoningModel_ACTV1ReasoningModule (line 86) | class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
    method __init__ (line 87) | def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
    method forward (line 92) | def forward(self, hidden_states: torch.Tensor, input_injection: torch....
  class HierarchicalReasoningModel_ACTV1_Inner (line 102) | class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
    method __init__ (line 103) | def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> ...
    method _input_embeddings (line 146) | def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: t...
    method empty_carry (line 168) | def empty_carry(self, batch_size: int):
    method reset_carry (line 174) | def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalRea...
    method forward (line 180) | def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, b...
  class HierarchicalReasoningModel_ACTV1 (line 216) | class HierarchicalReasoningModel_ACTV1(nn.Module):
    method __init__ (line 219) | def __init__(self, config_dict: dict):
    method puzzle_emb (line 225) | def puzzle_emb(self):
    method initial_carry (line 228) | def initial_carry(self, batch: Dict[str, torch.Tensor]):
    method forward (line 240) | def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch:...

FILE: models/layers.py
  function _find_multiple (line 19) | def _find_multiple(a, b):
  function rotate_half (line 23) | def rotate_half(x: torch.Tensor):
  function apply_rotary_pos_emb (line 30) | def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Te...
  class CastedLinear (line 43) | class CastedLinear(nn.Module):
    method __init__ (line 44) | def __init__(self,
    method forward (line 58) | def forward(self, input: torch.Tensor) -> torch.Tensor:
  class CastedEmbedding (line 62) | class CastedEmbedding(nn.Module):
    method __init__ (line 63) | def __init__(self,
    method forward (line 76) | def forward(self, input: torch.Tensor) -> torch.Tensor:
  class RotaryEmbedding (line 80) | class RotaryEmbedding(nn.Module):
    method __init__ (line 81) | def __init__(self, dim, max_position_embeddings, base, device=None):
    method forward (line 94) | def forward(self):
  class Attention (line 98) | class Attention(nn.Module):
    method __init__ (line 99) | def __init__(self, hidden_size, head_dim, num_heads, num_key_value_hea...
    method forward (line 112) | def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> tor...
  class SwiGLU (line 138) | class SwiGLU(nn.Module):
    method __init__ (line 139) | def __init__(self, hidden_size: int, expansion: float):
    method forward (line 146) | def forward(self, x):
  function rms_norm (line 151) | def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> to...

FILE: models/losses.py
  function s (line 11) | def s(x, epsilon=1e-30):
  function log_stablemax (line 19) | def log_stablemax(x, dim=-1):
  function stablemax_cross_entropy (line 24) | def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
  function softmax_cross_entropy (line 34) | def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
  class ACTLossHead (line 40) | class ACTLossHead(nn.Module):
    method __init__ (line 41) | def __init__(self, model: nn.Module, loss_type: str):
    method initial_carry (line 46) | def initial_carry(self, *args, **kwargs):
    method forward (line 49) | def forward(

FILE: models/sparse_embedding.py
  class CastedSparseEmbedding (line 11) | class CastedSparseEmbedding(nn.Module):
    method __init__ (line 12) | def __init__(self, num_embeddings: int, embedding_dim: int, batch_size...
    method forward (line 28) | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
  class CastedSparseEmbeddingSignSGD_Distributed (line 41) | class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
    method __init__ (line 42) | def __init__(
    method step (line 63) | def step(self, closure=None):  # type: ignore
  function _sparse_emb_signsgd_dist (line 98) | def _sparse_emb_signsgd_dist(

FILE: pretrain.py
  class LossConfig (line 26) | class LossConfig(pydantic.BaseModel):
  class ArchConfig (line 32) | class ArchConfig(pydantic.BaseModel):
  class PretrainConfig (line 39) | class PretrainConfig(pydantic.BaseModel):
  class TrainState (line 74) | class TrainState:
  function create_dataloader (line 84) | def create_dataloader(config: PretrainConfig, split: str, rank: int, wor...
  function create_model (line 108) | def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMe...
  function cosine_schedule_with_warmup_lr_lambda (line 162) | def cosine_schedule_with_warmup_lr_lambda(
  function init_train_state (line 172) | def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatas...
  function save_train_state (line 190) | def save_train_state(config: PretrainConfig, train_state: TrainState):
  function compute_lr (line 199) | def compute_lr(base_lr: float, config: PretrainConfig, train_state: Trai...
  function train_batch (line 209) | def train_batch(config: PretrainConfig, train_state: TrainState, batch: ...
  function evaluate (line 266) | def evaluate(config: PretrainConfig, train_state: TrainState, eval_loade...
  function save_code_and_config (line 333) | def save_code_and_config(config: PretrainConfig):
  function load_synced_config (line 359) | def load_synced_config(hydra_config: DictConfig, rank: int, world_size: ...
  function launch (line 381) | def launch(hydra_config: DictConfig):

FILE: puzzle_dataset.py
  function _sample_batch (line 14) | def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puz...
  class PuzzleDatasetConfig (line 41) | class PuzzleDatasetConfig(pydantic.BaseModel):
  class PuzzleDataset (line 53) | class PuzzleDataset(IterableDataset):
    method __init__ (line 54) | def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
    method _load_metadata (line 68) | def _load_metadata(self) -> PuzzleDatasetMetadata:
    method _lazy_load_dataset (line 72) | def _lazy_load_dataset(self):
    method _collate_batch (line 95) | def _collate_batch(self, batch):
    method _iter_test (line 118) | def _iter_test(self):
    method _iter_train (line 151) | def _iter_train(self):
    method __iter__ (line 189) | def __iter__(self):

FILE: utils/functions.py
  function load_model_class (line 5) | def load_model_class(identifier: str, prefix: str = "models."):
  function get_model_source_path (line 15) | def get_model_source_path(identifier: str, prefix: str = "models."):
Condensed preview — 25 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (134K chars).
[
  {
    "path": ".gitignore",
    "chars": 3152,
    "preview": "# WandB\n/wandb/\n# checkpoints\n/checkpoints/\n# cache\n/cache/\n# data\n/data/\n\n# Byte-compiled / optimized / DLL files\n__pyc"
  },
  {
    "path": ".gitmodules",
    "chars": 364,
    "preview": "[submodule \"dataset/raw-data/ConceptARC\"]\n\tpath = dataset/raw-data/ConceptARC\n\turl = git@github.com:victorvikram/Concept"
  },
  {
    "path": ".vscode/launch.json",
    "chars": 778,
    "preview": "{\n    // Use IntelliSense to learn about possible attributes.\n    // Hover to view descriptions of existing attributes.\n"
  },
  {
    "path": ".vscode/settings.json",
    "chars": 54,
    "preview": "{\n    \"python.analysis.typeCheckingMode\": \"standard\"\n}"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "\n                                 Apache License\n                           Version 2.0, January 2004\n                  "
  },
  {
    "path": "README.md",
    "chars": 7130,
    "preview": "# Hierarchical Reasoning Model\n\n![](./assets/hrm.png)\n\nReasoning, the process of devising and executing complex goal-ori"
  },
  {
    "path": "arc_eval.ipynb",
    "chars": 9317,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": "
  },
  {
    "path": "assets/npyjs.js",
    "chars": 5216,
    "preview": "class npyjs {\n\n    constructor(opts) {\n        if (opts && !('convertFloat16' in opts)) {\n            console.warn([\n   "
  },
  {
    "path": "config/arch/hrm_v1.yaml",
    "chars": 349,
    "preview": "name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1\nloss:\n  name: losses@ACTLossHead\n  loss_type: stablemax_cross_entr"
  },
  {
    "path": "config/cfg_pretrain.yaml",
    "chars": 492,
    "preview": "# ARC training config\n\ndefaults:\n  - arch: hrm_v1\n  - _self_\n\nhydra:\n  output_subdir: null\n\n# Data path\ndata_path: data/"
  },
  {
    "path": "dataset/build_arc_dataset.py",
    "chars": 10084,
    "preview": "from typing import List, Optional, Tuple, Dict\nfrom dataclasses import dataclass\nfrom pathlib import Path\nimport os\nimpo"
  },
  {
    "path": "dataset/build_maze_dataset.py",
    "chars": 4461,
    "preview": "from typing import Optional\nimport math\nimport os\nimport csv\nimport json\nimport numpy as np\n\nfrom argdantic import ArgPa"
  },
  {
    "path": "dataset/build_sudoku_dataset.py",
    "chars": 5753,
    "preview": "from typing import Optional\nimport os\nimport csv\nimport json\nimport numpy as np\n\nfrom argdantic import ArgParser\nfrom py"
  },
  {
    "path": "dataset/common.py",
    "chars": 1381,
    "preview": "from typing import List, Optional\n\nimport pydantic\nimport numpy as np\n\n\n# Global list mapping each dihedral transform id"
  },
  {
    "path": "evaluate.py",
    "chars": 2490,
    "preview": "from typing import List\nimport yaml\nimport os\n\nimport torch\nimport torch.distributed as dist\n\nimport pydantic\nfrom omega"
  },
  {
    "path": "models/common.py",
    "chars": 1216,
    "preview": "import math\n\nimport torch\nfrom torch import nn\n\n\ndef trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: f"
  },
  {
    "path": "models/hrm/hrm_act_v1.py",
    "chars": 12161,
    "preview": "from typing import Tuple, List, Dict, Optional\nfrom dataclasses import dataclass\nimport math\n\nimport torch\nimport torch."
  },
  {
    "path": "models/layers.py",
    "chars": 5654,
    "preview": "from typing import Tuple\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\ntry:\n    from flash_attn_in"
  },
  {
    "path": "models/losses.py",
    "chars": 3804,
    "preview": "from typing import Any, Tuple, Dict, Sequence, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import "
  },
  {
    "path": "models/sparse_embedding.py",
    "chars": 4364,
    "preview": "from typing import Union\n\nimport torch\nfrom torch import nn\nimport torch.distributed as dist\nfrom torch.optim.optimizer "
  },
  {
    "path": "pretrain.py",
    "chars": 15607,
    "preview": "from typing import Optional, Any, Sequence, List\nfrom dataclasses import dataclass\nimport os\nimport math\nimport yaml\nimp"
  },
  {
    "path": "puzzle_dataset.py",
    "chars": 7980,
    "preview": "import os\nimport json\n\nimport numpy as np\nimport pydantic\n\nimport torch\nfrom torch.utils.data import IterableDataset, ge"
  },
  {
    "path": "puzzle_visualizer.html",
    "chars": 14113,
    "preview": "<!DOCTYPE html>\n<html>\n<head>\n  <meta charset=\"UTF-8\" />\n  <title>ARC‐Converted Dataset Visualizer (Upload Local Folder)"
  },
  {
    "path": "requirements.txt",
    "chars": 100,
    "preview": "torch\nadam-atan2\neinops\ntqdm\ncoolname\npydantic\nargdantic\nwandb\nomegaconf\nhydra-core\nhuggingface_hub\n"
  },
  {
    "path": "utils/functions.py",
    "chars": 516,
    "preview": "import importlib\nimport inspect\n\n\ndef load_model_class(identifier: str, prefix: str = \"models.\"):\n    module_path, class"
  }
]

About this extraction

This page contains the full source code of the sapientinc/HRM GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 25 files (124.9 KB), approximately 31.4k tokens, and a symbol index with 109 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!