Repository: sapientinc/HRM
Branch: main
Commit: 42410daaaf6a
Files: 25
Total size: 124.9 KB
Directory structure:
gitextract_7bpy0bpn/
├── .gitignore
├── .gitmodules
├── .vscode/
│ ├── launch.json
│ └── settings.json
├── LICENSE
├── README.md
├── arc_eval.ipynb
├── assets/
│ └── npyjs.js
├── config/
│ ├── arch/
│ │ └── hrm_v1.yaml
│ └── cfg_pretrain.yaml
├── dataset/
│ ├── build_arc_dataset.py
│ ├── build_maze_dataset.py
│ ├── build_sudoku_dataset.py
│ └── common.py
├── evaluate.py
├── models/
│ ├── common.py
│ ├── hrm/
│ │ └── hrm_act_v1.py
│ ├── layers.py
│ ├── losses.py
│ └── sparse_embedding.py
├── pretrain.py
├── puzzle_dataset.py
├── puzzle_visualizer.html
├── requirements.txt
└── utils/
└── functions.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# WandB
/wandb/
# checkpoints
/checkpoints/
# cache
/cache/
# data
/data/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
================================================
FILE: .gitmodules
================================================
[submodule "dataset/raw-data/ConceptARC"]
path = dataset/raw-data/ConceptARC
url = git@github.com:victorvikram/ConceptARC.git
[submodule "dataset/raw-data/ARC-AGI"]
path = dataset/raw-data/ARC-AGI
url = git@github.com:fchollet/ARC-AGI.git
[submodule "dataset/raw-data/ARC-AGI-2"]
path = dataset/raw-data/ARC-AGI-2
url = git@github.com:arcprize/ARC-AGI-2.git
================================================
FILE: .vscode/launch.json
================================================
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
},
{
"name": "Debug: Single GPU",
"type": "debugpy",
"request": "launch",
"program": "pretrain.py",
"args": [],
"env": {
"OMP_NUM_THREADS": "1",
"DISABLE_COMPILE": "true"
}
}
]
}
================================================
FILE: .vscode/settings.json
================================================
{
"python.analysis.typeCheckingMode": "standard"
}
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Hierarchical Reasoning Model

Reasoning, the process of devising and executing complex goal-oriented action sequences, remains a critical challenge in AI.
Current large language models (LLMs) primarily employ Chain-of-Thought (CoT) techniques, which suffer from brittle task decomposition, extensive data requirements, and high latency. Inspired by the hierarchical and multi-timescale processing in the human brain, we propose the Hierarchical Reasoning Model (HRM), a novel recurrent architecture that attains significant computational depth while maintaining both training stability and efficiency.
HRM executes sequential reasoning tasks in a single forward pass without explicit supervision of the intermediate process, through two interdependent recurrent modules: a high-level module responsible for slow, abstract planning, and a low-level module handling rapid, detailed computations. With only 27 million parameters, HRM achieves exceptional performance on complex reasoning tasks using only 1000 training samples. The model operates without pre-training or CoT data, yet achieves nearly perfect performance on challenging tasks including complex Sudoku puzzles and optimal path finding in large mazes.
Furthermore, HRM outperforms much larger models with significantly longer context windows on the Abstraction and Reasoning Corpus (ARC), a key benchmark for measuring artificial general intelligence capabilities.
These results underscore HRM’s potential as a transformative advancement toward universal computation and general-purpose reasoning systems.
**Join our Discord Community: [https://discord.gg/sapient](https://discord.gg/sapient)**
## Quick Start Guide 🚀
### Prerequisites ⚙️
Ensure PyTorch and CUDA are installed. The repo needs CUDA extensions to be built. If not present, run the following commands:
```bash
# Install CUDA 12.6
CUDA_URL=https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run
wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer.run $CUDA_URL
sudo sh cuda_installer.run --silent --toolkit --override
export CUDA_HOME=/usr/local/cuda-12.6
# Install PyTorch with CUDA 12.6
PYTORCH_INDEX_URL=https://download.pytorch.org/whl/cu126
pip3 install torch torchvision torchaudio --index-url $PYTORCH_INDEX_URL
# Additional packages for building extensions
pip3 install packaging ninja wheel setuptools setuptools-scm
```
Then install FlashAttention. For Hopper GPUs, install FlashAttention 3
```bash
git clone git@github.com:Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
For Ampere or earlier GPUs, install FlashAttention 2
```bash
pip3 install flash-attn
```
## Install Python Dependencies 🐍
```bash
pip install -r requirements.txt
```
## W&B Integration 📈
This project uses [Weights & Biases](https://wandb.ai/) for experiment tracking and metric visualization. Ensure you're logged in:
```bash
wandb login
```
## Run Experiments
### Quick Demo: Sudoku Solver 💻🗲
Train a master-level Sudoku AI capable of solving extremely difficult puzzles on a modern laptop GPU. 🧩
```bash
# Download and build Sudoku dataset
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000
# Start training (single GPU, smaller batch size)
OMP_NUM_THREADS=8 python pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 global_batch_size=384 lr=7e-5 puzzle_emb_lr=7e-5 weight_decay=1.0 puzzle_emb_weight_decay=1.0
```
Runtime: ~10 hours on a RTX 4070 laptop GPU
## Trained Checkpoints 🚧
- [ARC-AGI-2](https://huggingface.co/sapientinc/HRM-checkpoint-ARC-2)
- [Sudoku 9x9 Extreme (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-sudoku-extreme)
- [Maze 30x30 Hard (1000 examples)](https://huggingface.co/sapientinc/HRM-checkpoint-maze-30x30-hard)
To use the checkpoints, see Evaluation section below.
## Full-scale Experiments 🔵
Experiments below assume an 8-GPU setup.
### Dataset Preparation
```bash
# Initialize submodules
git submodule update --init --recursive
# ARC-1
python dataset/build_arc_dataset.py # ARC offical + ConceptARC, 960 examples
# ARC-2
python dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000 # ARC-2 official, 1120 examples
# Sudoku-Extreme
python dataset/build_sudoku_dataset.py # Full version
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 # 1000 examples
# Maze
python dataset/build_maze_dataset.py # 1000 examples
```
### Dataset Visualization
Explore the puzzles visually:
* Open `puzzle_visualizer.html` in your browser.
* Upload the generated dataset folder located in `data/...`.
## Launch experiments
### Small-sample (1K)
ARC-1:
```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py
```
*Runtime:* ~24 hours
ARC-2:
```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/arc-2-aug-1000
```
*Runtime:* ~24 hours (checkpoint after 8 hours is often sufficient)
Sudoku Extreme (1k):
```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-extreme-1k-aug-1000 epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
```
*Runtime:* ~10 minutes
Maze 30x30 Hard (1k):
```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/maze-30x30-hard-1k epochs=20000 eval_interval=2000 lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0
```
*Runtime:* ~1 hour
### Full Sudoku-Hard
```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/sudoku-hard-full epochs=100 eval_interval=10 lr_min_ratio=0.1 global_batch_size=2304 lr=3e-4 puzzle_emb_lr=3e-4 weight_decay=0.1 puzzle_emb_weight_decay=0.1 arch.loss.loss_type=softmax_cross_entropy arch.L_cycles=8 arch.halt_max_steps=8 arch.pos_encodings=learned
```
*Runtime:* ~2 hours
## Evaluation
Evaluate your trained models:
* Check `eval/exact_accuracy` in W&B.
* For ARC-AGI, follow these additional steps:
```bash
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 evaluate.py checkpoint=<CHECKPOINT_PATH>
```
* Then use the provided `arc_eval.ipynb` notebook to finalize and inspect your results.
## Notes
- Small-sample learning typically exhibits accuracy variance of around ±2 points.
- For Sudoku-Extreme (1,000-example dataset), late-stage overfitting may cause numerical instability during training and Q-learning. It is advisable to use early stopping once the training accuracy approaches 100%.
## Citation 📜
```bibtex
@misc{wang2025hierarchicalreasoningmodel,
title={Hierarchical Reasoning Model},
author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
year={2025},
eprint={2506.21734},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2506.21734},
}
```
================================================
FILE: arc_eval.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"from glob import glob\n",
"import hashlib\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.colors as mcolors\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"from numba import njit\n",
"\n",
"from dataset.common import inverse_dihedral_transform\n",
"\n",
"\n",
"DATASET_PATH = \"data/arc-aug-1000\" # ARC-1\n",
"# DATASET_PATH = \"data/arc-2-aug-1000\" # ARC-2\n",
"\n",
"CHECKPOINT_PATH = \"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\"\n",
"\n",
"\n",
"PAD_PUZZLE_IDENTIFIER = 0\n",
"\n",
"# Visualization\n",
"ARC_COLOR_MAP = mcolors.ListedColormap([\n",
" \"#000000\", # symbol_0: black\n",
" \"#0074D9\", # symbol_1: blue\n",
" \"#FF4136\", # symbol_2: red\n",
" \"#2ECC40\", # symbol_3: green\n",
" \"#FFDC00\", # symbol_4: yellow\n",
" \"#AAAAAA\", # symbol_5: grey\n",
" \"#F012BE\", # symbol_6: fuschia\n",
" \"#FF851B\", # symbol_7: orange\n",
" \"#7FDBFF\", # symbol_8: teal\n",
" \"#870C25\" # symbol_9: brown\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\n",
" # Load puzzle identifiers\n",
" with open(os.path.join(dataset_path, \"identifiers.json\"), \"r\") as f:\n",
" identifier_map = json.load(f)\n",
" \n",
" # Load preds\n",
" all_preds = {}\n",
" for filename in glob(f\"{checkpoint_path}_all_preds.*\"):\n",
" preds = torch.load(filename)\n",
" for k, v in preds.items():\n",
" all_preds.setdefault(k, [])\n",
" all_preds[k].append(v)\n",
" \n",
" del preds\n",
"\n",
" all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n",
" \n",
" # Remove paddings\n",
" mask = all_preds[\"puzzle_identifiers\"] != PAD_PUZZLE_IDENTIFIER\n",
" all_preds = {k: v[mask] for k, v in all_preds.items()}\n",
"\n",
" return identifier_map, all_preds\n",
"\n",
"\n",
"def inverse_aug(name: str, grid: np.ndarray):\n",
" if \"_\" not in name:\n",
" return grid\n",
"\n",
" trans_id, perm = name.split(\"_\")[-2:]\n",
" trans_id = int(trans_id[1:]) # Remove \"t\" letter\n",
" inv_perm = np.argsort(list(perm))\n",
" \n",
" return inv_perm[inverse_dihedral_transform(grid, trans_id)]\n",
"\n",
"\n",
"def grid_hash(grid: np.ndarray):\n",
" return hash((grid.tobytes(), grid.shape))\n",
"\n",
"\n",
"@njit\n",
"def crop(grid: np.ndarray):\n",
" # Find maximum-sized rectangle without any EOS token inside.\n",
" grid = grid.reshape(30, 30)\n",
"\n",
" max_area = 0\n",
" max_size = (0, 0)\n",
" nr, nc = grid.shape\n",
" \n",
" num_c = nc\n",
" for num_r in range(1, nr + 1):\n",
" # Scan for maximum c\n",
" for c in range(1, num_c + 1):\n",
" x = grid[num_r - 1, c - 1]\n",
" if (x < 2) | (x > 11):\n",
" num_c = c - 1\n",
" break\n",
" \n",
" area = num_r * num_c\n",
" if area > max_area:\n",
" max_area = area\n",
" max_size = (num_r, num_c)\n",
"\n",
" return grid[:max_size[0], :max_size[1]] - 2\n",
"\n",
"\n",
"def test(visualize, Ks=[1, 2, 10, 100, 1000]):\n",
" identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\n",
" \n",
" global_hmap = {}\n",
" \n",
" # Get puzzles and corresponding answers\n",
" puzzle_labels = {}\n",
" for identifier, input, label in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], all_preds[\"labels\"]):\n",
" name = identifier_map[identifier]\n",
" if \"_\" not in name: # Not-augmented\n",
" puzzle_labels.setdefault(name, {})\n",
" \n",
" input = crop(input.numpy())\n",
" label = crop(label.numpy())\n",
"\n",
" input_hash = grid_hash(input)\n",
" label_hash = grid_hash(label)\n",
"\n",
" global_hmap[input_hash] = input\n",
" global_hmap[label_hash] = label\n",
"\n",
" assert input_hash not in puzzle_labels[name]\n",
" puzzle_labels[name][input_hash] = label_hash\n",
" \n",
" print (\"Number of puzzles\", len(puzzle_labels))\n",
" \n",
" # Argmax prediction\n",
" preds = all_preds[\"logits\"].argmax(-1)\n",
"\n",
" # Collate\n",
" pred_answers = {}\n",
" for identifier, input, pred, q in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], preds, all_preds[\"q_halt_logits\"].sigmoid()):\n",
" name = identifier_map[identifier]\n",
" orig_name = name.split(\"_\")[0]\n",
" \n",
" input = input.numpy()\n",
" input_hash = grid_hash(inverse_aug(name, crop(input)))\n",
" assert input_hash in puzzle_labels[orig_name]\n",
" \n",
" pred = inverse_aug(name, crop(pred.numpy()))\n",
" pred_hash = grid_hash(pred)\n",
" global_hmap[pred_hash] = pred\n",
" \n",
" pred_answers.setdefault(orig_name, {})\n",
" pred_answers[orig_name].setdefault(input_hash, [])\n",
" pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\n",
"\n",
" # test-1\n",
" if visualize:\n",
" num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\n",
" fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\n",
" \n",
" fig_id = 0\n",
" \n",
" correct = [0 for _ in range(len(Ks))]\n",
" for name, tests in puzzle_labels.items():\n",
" num_test_correct = [0 for _ in range(len(Ks))]\n",
" for input_hash, label_hash in tests.items():\n",
" p = pred_answers[name][input_hash]\n",
" p_map = {}\n",
" \n",
" for h, q in p:\n",
" p_map.setdefault(h, [0, 0])\n",
" p_map[h][0] += 1\n",
" p_map[h][1] += q\n",
" \n",
" for h, stats in p_map.items():\n",
" stats[1] /= stats[0]\n",
" \n",
" p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\n",
"\n",
" # 2-vote\n",
" for i, k in enumerate(Ks):\n",
" ok = False\n",
" for h, stats in p_map[:k]:\n",
" ok |= h == label_hash\n",
" \n",
" num_test_correct[i] += ok\n",
"\n",
" if visualize:\n",
" # Show input and ground truth\n",
" axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\n",
" axes[fig_id, 0].set_title(f\"{name}\\nInput\")\n",
" axes[fig_id, 0].axis('off')\n",
" \n",
" axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\n",
" axes[fig_id, 1].set_title(f\"{name}\\nAnswer\")\n",
" axes[fig_id, 1].axis('off')\n",
" \n",
" trial_id = 2\n",
" for h, stats in p_map[:2]:\n",
" ans = global_hmap[h]\n",
" \n",
" axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\n",
" axes[fig_id, trial_id].set_title(f\"{name}\\nTrial {trial_id}\")\n",
" axes[fig_id, trial_id].axis('off')\n",
" \n",
" trial_id += 1\n",
" \n",
" fig_id += 1\n",
" \n",
" # Total correctness\n",
" for i in range(len(Ks)):\n",
" correct[i] += num_test_correct[i] == len(tests)\n",
"\n",
" for i, k in enumerate(Ks):\n",
" print (f\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\")\n",
"\n",
"\n",
"test(visualize=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: assets/npyjs.js
================================================
class npyjs {
constructor(opts) {
if (opts && !('convertFloat16' in opts)) {
console.warn([
"npyjs constructor now accepts {convertFloat16?: boolean}.",
"For usage, go to https://github.com/jhuapl-boss/npyjs."
].join(" "));
}
this.convertFloat16 = opts?.convertFloat16 ?? true;
this.dtypes = {
"<u1": {
name: "uint8",
size: 8,
arrayConstructor: Uint8Array,
},
"|u1": {
name: "uint8",
size: 8,
arrayConstructor: Uint8Array,
},
"<u2": {
name: "uint16",
size: 16,
arrayConstructor: Uint16Array,
},
"|i1": {
name: "int8",
size: 8,
arrayConstructor: Int8Array,
},
"<i2": {
name: "int16",
size: 16,
arrayConstructor: Int16Array,
},
"<u4": {
name: "uint32",
size: 32,
arrayConstructor: Uint32Array,
},
"<i4": {
name: "int32",
size: 32,
arrayConstructor: Int32Array,
},
"<u8": {
name: "uint64",
size: 64,
arrayConstructor: BigUint64Array,
},
"<i8": {
name: "int64",
size: 64,
arrayConstructor: BigInt64Array,
},
"<f4": {
name: "float32",
size: 32,
arrayConstructor: Float32Array
},
"<f8": {
name: "float64",
size: 64,
arrayConstructor: Float64Array
},
"<f2": {
name: "float16",
size: 16,
arrayConstructor: Uint16Array,
converter: this.convertFloat16 ? this.float16ToFloat32Array : undefined
},
};
}
float16ToFloat32Array(float16Array) {
const length = float16Array.length;
const float32Array = new Float32Array(length);
for (let i = 0; i < length; i++) {
float32Array[i] = npyjs.float16ToFloat32(float16Array[i]);
}
return float32Array;
}
static float16ToFloat32(float16) {
// Extract the parts of the float16
const sign = (float16 >> 15) & 0x1;
const exponent = (float16 >> 10) & 0x1f;
const fraction = float16 & 0x3ff;
// Handle special cases
if (exponent === 0) {
if (fraction === 0) {
// Zero
return sign ? -0 : 0;
}
// Denormalized number
return (sign ? -1 : 1) * Math.pow(2, -14) * (fraction / 0x400);
} else if (exponent === 0x1f) {
if (fraction === 0) {
// Infinity
return sign ? -Infinity : Infinity;
}
// NaN
return NaN;
}
// Normalized number
return (sign ? -1 : 1) * Math.pow(2, exponent - 15) * (1 + fraction / 0x400);
}
parse(arrayBufferContents) {
// const version = arrayBufferContents.slice(6, 8); // Uint8-encoded
const headerLength = new DataView(arrayBufferContents.slice(8, 10)).getUint8(0);
const offsetBytes = 10 + headerLength;
const hcontents = new TextDecoder("utf-8").decode(
new Uint8Array(arrayBufferContents.slice(10, 10 + headerLength))
);
const header = JSON.parse(
hcontents
.toLowerCase() // True -> true
.replace(/'/g, '"')
.replace("(", "[")
.replace(/,*\),*/g, "]")
);
const shape = header.shape;
const dtype = this.dtypes[header.descr];
if (!dtype) {
console.error(`Unsupported dtype: ${header.descr}`);
return null;
}
const nums = new dtype.arrayConstructor(
arrayBufferContents,
offsetBytes
);
// Convert float16 to float32 if converter exists
const data = dtype.converter ? dtype.converter.call(this, nums) : nums;
return {
dtype: dtype.name,
data: data,
shape,
fortranOrder: header.fortran_order
};
}
async load(filename, callback, fetchArgs) {
/*
Loads an array from a stream of bytes.
*/
fetchArgs = fetchArgs || {};
let arrayBuf;
// If filename is ArrayBuffer
if (filename instanceof ArrayBuffer) {
arrayBuf = filename;
}
// If filename is a file path
else {
const resp = await fetch(filename, { ...fetchArgs });
arrayBuf = await resp.arrayBuffer();
}
const result = this.parse(arrayBuf);
if (callback) {
return callback(result);
}
return result;
}
}
================================================
FILE: config/arch/hrm_v1.yaml
================================================
name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1
loss:
name: losses@ACTLossHead
loss_type: stablemax_cross_entropy
halt_exploration_prob: 0.1
halt_max_steps: 16
H_cycles: 2
L_cycles: 2
H_layers: 4
L_layers: 4
hidden_size: 512
num_heads: 8 # min(2, hidden_size // 64)
expansion: 4
puzzle_emb_ndim: ${.hidden_size}
pos_encodings: rope
================================================
FILE: config/cfg_pretrain.yaml
================================================
# ARC training config
defaults:
- arch: hrm_v1
- _self_
hydra:
output_subdir: null
# Data path
data_path: data/arc-aug-1000
# Hyperparams - Training
global_batch_size: 768
epochs: 100000
eval_interval: 10000
checkpoint_every_eval: True
lr: 1e-4
lr_min_ratio: 1.0
lr_warmup_steps: 2000
# Standard hyperparameter settings for LM, as used in Llama
beta1: 0.9
beta2: 0.95
weight_decay: 0.1
puzzle_emb_weight_decay: 0.1
# Hyperparams - Puzzle embeddings training
puzzle_emb_lr: 1e-2
================================================
FILE: dataset/build_arc_dataset.py
================================================
from typing import List, Optional, Tuple, Dict
from dataclasses import dataclass
from pathlib import Path
import os
import json
import hashlib
import numpy as np
from glob import glob
from argdantic import ArgParser
from pydantic import BaseModel
from common import PuzzleDatasetMetadata, dihedral_transform
cli = ArgParser()
class DataProcessConfig(BaseModel):
# ARC-1
dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"]
output_dir: str = "data/arc-aug-1000"
# ARC-2
# dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI-2/data"]
# output_dir: str = "data/arc-2-aug-1000"
seed: int = 42
num_aug: int = 1000
ARCMaxGridSize = 30
ARCAugmentRetriesFactor = 5
@dataclass
class ARCPuzzle:
id: str
examples: List[Tuple[np.ndarray, np.ndarray]]
def arc_grid_to_np(grid: List[List[int]]):
arr = np.array(grid)
# Shape check
assert arr.ndim == 2
assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
# Element check
assert np.all((arr >= 0) & (arr <= 9))
return arr.astype(np.uint8)
def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
# PAD: 0, <eos>: 1, digits: 2 ... 11
# Compute random top-left pad
if do_translation:
pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
else:
pad_r = pad_c = 0
# Pad grid
result = []
for grid in [inp, out]:
nrow, ncol = grid.shape
grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)
# Add <eos>
eos_row, eos_col = pad_r + nrow, pad_c + ncol
if eos_row < ARCMaxGridSize:
grid[eos_row, pad_c:eos_col] = 1
if eos_col < ARCMaxGridSize:
grid[pad_r:eos_row, eos_col] = 1
result.append(grid.flatten())
return result
def puzzle_hash(puzzle: dict):
# Hash the puzzle for checking equivalence
def _grid_hash(grid: np.ndarray):
buffer = [x.to_bytes(1) for x in grid.shape]
buffer.append(grid.tobytes())
return hashlib.sha256(b"".join(buffer)).hexdigest()
hashes = []
for example_type, example in puzzle.items():
for input, label in example.examples:
hashes.append(f"{_grid_hash(input)}|{_grid_hash(label)}")
hashes.sort()
return hashlib.sha256("|".join(hashes).encode()).hexdigest()
def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
# Remove "name"
name = puzzle.pop("name", default_name)
# Convert
dests = set(dest_mapping.values())
converted = {dest: ARCPuzzle(name, []) for dest in dests}
for example_type, examples in puzzle.items():
dest = dest_mapping[example_type]
converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])
group = [converted]
# Augment
if aug_count > 0:
hashes = {puzzle_hash(converted)}
for _trial in range(ARCAugmentRetriesFactor * aug_count):
# Augment plan
trans_id = np.random.randint(0, 8)
mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))]) # Permute colors, Excluding "0" (black)
aug_repr = f"t{trans_id}_{''.join(str(x) for x in mapping)}"
def _map_grid(grid: np.ndarray):
return dihedral_transform(mapping[grid], trans_id)
# Check duplicate
augmented = {dest: ARCPuzzle(f"{puzzle.id}_{aug_repr}", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
h = puzzle_hash(augmented)
if h not in hashes:
hashes.add(h)
group.append(augmented)
if len(group) >= aug_count + 1:
break
if len(group) < aug_count + 1:
print (f"[Puzzle {name}] augmentation not full, only {len(group)}")
# Append
for dest in dests:
# Convert the examples
dest_split, dest_set = dest
results.setdefault(dest_split, {})
results[dest_split].setdefault(dest_set, [])
results[dest_split][dest_set].append([converted[dest] for converted in group])
def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig):
train_examples_dest = ("train", "all")
test_examples_map = {
"evaluation": [(1.0, ("test", "all"))],
"_default": [(1.0, ("train", "all"))]
}
total_puzzles = 0
for subdir in os.scandir(dataset_path):
if subdir.is_dir():
# Load all puzzles in this directory
puzzles = []
for filename in glob(os.path.join(subdir.path, "*.json")):
with open(filename, "r") as f:
puzzles.append((Path(filename).stem, json.load(f)))
# Shuffle puzzles
np.random.shuffle(puzzles)
# Assign by fraction
for idx, (default_name, puzzle) in enumerate(puzzles):
fraction = idx / len(puzzles)
test_examples_dest = None
for f, dest in test_examples_map.get(subdir.name, test_examples_map["_default"]):
if fraction < f:
test_examples_dest = dest
break
assert test_examples_dest is not None
convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
total_puzzles += 1
print (f"[{dataset_path}] total puzzles: {total_puzzles}")
def convert_dataset(config: DataProcessConfig):
np.random.seed(config.seed)
# Read dataset
data = {}
for dataset_dir in config.dataset_dirs:
load_puzzles_arcagi(data, dataset_dir, config)
# Map global puzzle identifiers
num_identifiers = 1 # 0 is blank
identifier_map = {}
for split_name, split in data.items():
for subset_name, subset in split.items():
for group in subset:
for puzzle in group:
if puzzle.id not in identifier_map:
identifier_map[puzzle.id] = num_identifiers
num_identifiers += 1
print (f"Total puzzle IDs (including <blank>): {num_identifiers}")
# Save
for split_name, split in data.items():
os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)
# Translational augmentations
enable_translational_augment = split_name == "train"
# Statistics
total_examples = 0
total_puzzles = 0
total_groups = 0
for subset_name, subset in split.items():
# Construct subset
results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
results["puzzle_indices"].append(0)
results["group_indices"].append(0)
example_id = 0
puzzle_id = 0
for group in subset:
for puzzle in group:
# Push puzzle
no_aug_id = np.random.randint(0, len(puzzle.examples))
for _idx_ex, (inp, out) in enumerate(puzzle.examples):
inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)
results["inputs"].append(inp)
results["labels"].append(out)
example_id += 1
total_examples += 1
results["puzzle_indices"].append(example_id)
results["puzzle_identifiers"].append(identifier_map[puzzle.id])
puzzle_id += 1
total_puzzles += 1
# Push group
results["group_indices"].append(puzzle_id)
total_groups += 1
for k, v in results.items():
if k in {"inputs", "labels"}:
v = np.stack(v, 0)
else:
v = np.array(v, dtype=np.int32)
np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)
# Metadata
metadata = PuzzleDatasetMetadata(
seq_len=ARCMaxGridSize * ARCMaxGridSize,
vocab_size=10 + 2, # PAD + EOS + "0" ... "9"
pad_id=0,
ignore_label_id=0,
blank_identifier_id=0,
num_puzzle_identifiers=num_identifiers,
total_groups=total_groups,
mean_puzzle_examples=total_examples / total_puzzles,
sets=list(split.keys())
)
# Save metadata as JSON.
with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
json.dump(metadata.model_dump(), f)
# Save IDs mapping
with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
ids_mapping = {v: k for k, v in identifier_map.items()}
json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)
@cli.command(singleton=True)
def main(config: DataProcessConfig):
convert_dataset(config)
if __name__ == "__main__":
cli()
================================================
FILE: dataset/build_maze_dataset.py
================================================
from typing import Optional
import math
import os
import csv
import json
import numpy as np
from argdantic import ArgParser
from pydantic import BaseModel
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from common import PuzzleDatasetMetadata, dihedral_transform
CHARSET = "# SGo"
cli = ArgParser()
class DataProcessConfig(BaseModel):
source_repo: str = "sapientinc/maze-30x30-hard-1k"
output_dir: str = "data/maze-30x30-hard-1k"
subsample_size: Optional[int] = None
aug: bool = False
def convert_subset(set_name: str, config: DataProcessConfig):
# Read CSV
all_chars = set()
grid_size = None
inputs = []
labels = []
with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: # type: ignore
reader = csv.reader(csvfile)
next(reader) # Skip header
for source, q, a, rating in reader:
all_chars.update(q)
all_chars.update(a)
if grid_size is None:
n = int(len(q) ** 0.5)
grid_size = (n, n)
inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size))
labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size))
# If subsample_size is specified for the training set,
# randomly sample the desired number of examples.
if set_name == "train" and config.subsample_size is not None:
total_samples = len(inputs)
if config.subsample_size < total_samples:
indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
inputs = [inputs[i] for i in indices]
labels = [labels[i] for i in indices]
# Generate dataset
results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
puzzle_id = 0
example_id = 0
results["puzzle_indices"].append(0)
results["group_indices"].append(0)
for inp, out in zip(tqdm(inputs), labels):
# Dihedral transformations for augmentation
for aug_idx in range(8 if (set_name == "train" and config.aug) else 1):
results["inputs"].append(dihedral_transform(inp, aug_idx))
results["labels"].append(dihedral_transform(out, aug_idx))
example_id += 1
puzzle_id += 1
results["puzzle_indices"].append(example_id)
results["puzzle_identifiers"].append(0)
# Push group
results["group_indices"].append(puzzle_id)
# Char mappings
assert len(all_chars - set(CHARSET)) == 0
char2id = np.zeros(256, np.uint8)
char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1
# To Numpy
def _seq_to_numpy(seq):
arr = np.vstack([char2id[s.reshape(-1)] for s in seq])
return arr
results = {
"inputs": _seq_to_numpy(results["inputs"]),
"labels": _seq_to_numpy(results["labels"]),
"group_indices": np.array(results["group_indices"], dtype=np.int32),
"puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
"puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
}
# Metadata
metadata = PuzzleDatasetMetadata(
seq_len=int(math.prod(grid_size)), # type: ignore
vocab_size=len(CHARSET) + 1, # PAD + Charset
pad_id=0,
ignore_label_id=0,
blank_identifier_id=0,
num_puzzle_identifiers=1,
total_groups=len(results["group_indices"]) - 1,
mean_puzzle_examples=1,
sets=["all"]
)
# Save metadata as JSON.
save_dir = os.path.join(config.output_dir, set_name)
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, "dataset.json"), "w") as f:
json.dump(metadata.model_dump(), f)
# Save data
for k, v in results.items():
np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
# Save IDs mapping (for visualization only)
with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
json.dump(["<blank>"], f)
@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
convert_subset("train", config)
convert_subset("test", config)
if __name__ == "__main__":
cli()
================================================
FILE: dataset/build_sudoku_dataset.py
================================================
from typing import Optional
import os
import csv
import json
import numpy as np
from argdantic import ArgParser
from pydantic import BaseModel
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from common import PuzzleDatasetMetadata
cli = ArgParser()
class DataProcessConfig(BaseModel):
source_repo: str = "sapientinc/sudoku-extreme"
output_dir: str = "data/sudoku-extreme-full"
subsample_size: Optional[int] = None
min_difficulty: Optional[int] = None
num_aug: int = 0
def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
# Create a random digit mapping: a permutation of 1..9, with zero (blank) unchanged
digit_map = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))
# Randomly decide whether to transpose.
transpose_flag = np.random.rand() < 0.5
# Generate a valid row permutation:
# - Shuffle the 3 bands (each band = 3 rows) and for each band, shuffle its 3 rows.
bands = np.random.permutation(3)
row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])
# Similarly for columns (stacks).
stacks = np.random.permutation(3)
col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])
# Build an 81->81 mapping. For each new cell at (i, j)
# (row index = i // 9, col index = i % 9),
# its value comes from old row = row_perm[i//9] and old col = col_perm[i%9].
mapping = np.array([row_perm[i // 9] * 9 + col_perm[i % 9] for i in range(81)])
def apply_transformation(x: np.ndarray) -> np.ndarray:
# Apply transpose flag
if transpose_flag:
x = x.T
# Apply the position mapping.
new_board = x.flatten()[mapping].reshape(9, 9).copy()
# Apply digit mapping
return digit_map[new_board]
return apply_transformation(board), apply_transformation(solution)
def convert_subset(set_name: str, config: DataProcessConfig):
# Read CSV
inputs = []
labels = []
with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
reader = csv.reader(csvfile)
next(reader) # Skip header
for source, q, a, rating in reader:
if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
assert len(q) == 81 and len(a) == 81
inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
# If subsample_size is specified for the training set,
# randomly sample the desired number of examples.
if set_name == "train" and config.subsample_size is not None:
total_samples = len(inputs)
if config.subsample_size < total_samples:
indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
inputs = [inputs[i] for i in indices]
labels = [labels[i] for i in indices]
# Generate dataset
num_augments = config.num_aug if set_name == "train" else 0
results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
puzzle_id = 0
example_id = 0
results["puzzle_indices"].append(0)
results["group_indices"].append(0)
for orig_inp, orig_out in zip(tqdm(inputs), labels):
for aug_idx in range(1 + num_augments):
# First index is not augmented
if aug_idx == 0:
inp, out = orig_inp, orig_out
else:
inp, out = shuffle_sudoku(orig_inp, orig_out)
# Push puzzle (only single example)
results["inputs"].append(inp)
results["labels"].append(out)
example_id += 1
puzzle_id += 1
results["puzzle_indices"].append(example_id)
results["puzzle_identifiers"].append(0)
# Push group
results["group_indices"].append(puzzle_id)
# To Numpy
def _seq_to_numpy(seq):
arr = np.concatenate(seq).reshape(len(seq), -1)
assert np.all((arr >= 0) & (arr <= 9))
return arr + 1
results = {
"inputs": _seq_to_numpy(results["inputs"]),
"labels": _seq_to_numpy(results["labels"]),
"group_indices": np.array(results["group_indices"], dtype=np.int32),
"puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
"puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
}
# Metadata
metadata = PuzzleDatasetMetadata(
seq_len=81,
vocab_size=10 + 1, # PAD + "0" ... "9"
pad_id=0,
ignore_label_id=0,
blank_identifier_id=0,
num_puzzle_identifiers=1,
total_groups=len(results["group_indices"]) - 1,
mean_puzzle_examples=1,
sets=["all"]
)
# Save metadata as JSON.
save_dir = os.path.join(config.output_dir, set_name)
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, "dataset.json"), "w") as f:
json.dump(metadata.model_dump(), f)
# Save data
for k, v in results.items():
np.save(os.path.join(save_dir, f"all__{k}.npy"), v)
# Save IDs mapping (for visualization only)
with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
json.dump(["<blank>"], f)
@cli.command(singleton=True)
def preprocess_data(config: DataProcessConfig):
convert_subset("train", config)
convert_subset("test", config)
if __name__ == "__main__":
cli()
================================================
FILE: dataset/common.py
================================================
from typing import List, Optional
import pydantic
import numpy as np
# Global list mapping each dihedral transform id to its inverse.
# Index corresponds to the original tid, and the value is its inverse.
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]
class PuzzleDatasetMetadata(pydantic.BaseModel):
pad_id: int
ignore_label_id: Optional[int]
blank_identifier_id: int
vocab_size: int
seq_len: int
num_puzzle_identifiers: int
total_groups: int
mean_puzzle_examples: float
sets: List[str]
def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
"""8 dihedral symmetries by rotate, flip and mirror"""
if tid == 0:
return arr # identity
elif tid == 1:
return np.rot90(arr, k=1)
elif tid == 2:
return np.rot90(arr, k=2)
elif tid == 3:
return np.rot90(arr, k=3)
elif tid == 4:
return np.fliplr(arr) # horizontal flip
elif tid == 5:
return np.flipud(arr) # vertical flip
elif tid == 6:
return arr.T # transpose (reflection along main diagonal)
elif tid == 7:
return np.fliplr(np.rot90(arr, k=1)) # anti-diagonal reflection
else:
return arr
def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
return dihedral_transform(arr, DIHEDRAL_INVERSE[tid])
================================================
FILE: evaluate.py
================================================
from typing import List
import yaml
import os
import torch
import torch.distributed as dist
import pydantic
from omegaconf import OmegaConf
from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader
class EvalConfig(pydantic.BaseModel):
checkpoint: str
save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"]
def launch():
eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli())) # type: ignore
RANK = 0
WORLD_SIZE = 1
# Initialize distributed training if in distributed environment (e.g. torchrun)
if "LOCAL_RANK" in os.environ:
# Initialize distributed, default device and dtype
dist.init_process_group(backend="nccl")
RANK = dist.get_rank()
WORLD_SIZE = dist.get_world_size()
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f:
config = PretrainConfig(**yaml.safe_load(f))
config.eval_save_outputs = eval_cfg.save_outputs
config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint)
# Dataloader
train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
# Models
train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)
# Try unwrap torch.compile
try:
train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True)
except:
train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True)
train_state.step = 0
ckpt_filename = os.path.basename(eval_cfg.checkpoint)
if ckpt_filename.startswith("step_"):
train_state.step = int(ckpt_filename.removeprefix("step_"))
# Evaluate
print ("Starting evaluation")
train_state.model.eval()
metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)
if metrics is not None:
print (metrics)
if __name__ == "__main__":
launch()
================================================
FILE: models/common.py
================================================
import math
import torch
from torch import nn
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
# NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct, the std dev is not actually the std dev of initialized tensor
# This function is a PyTorch version of jax truncated normal init (default init method in flax)
# https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
# https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199
with torch.no_grad():
if std == 0:
tensor.zero_()
else:
sqrt2 = math.sqrt(2)
a = math.erf(lower / sqrt2)
b = math.erf(upper / sqrt2)
z = (b - a) / 2
c = (2 * math.pi) ** -0.5
pdf_u = c * math.exp(-0.5 * lower ** 2)
pdf_l = c * math.exp(-0.5 * upper ** 2)
comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)
tensor.uniform_(a, b)
tensor.erfinv_()
tensor.mul_(sqrt2 * comp_std)
tensor.clip_(lower * comp_std, upper * comp_std)
return tensor
================================================
FILE: models/hrm/hrm_act_v1.py
================================================
from typing import Tuple, List, Dict, Optional
from dataclasses import dataclass
import math
import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
@dataclass
class HierarchicalReasoningModel_ACTV1InnerCarry:
z_H: torch.Tensor
z_L: torch.Tensor
@dataclass
class HierarchicalReasoningModel_ACTV1Carry:
inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry
steps: torch.Tensor
halted: torch.Tensor
current_data: Dict[str, torch.Tensor]
class HierarchicalReasoningModel_ACTV1Config(BaseModel):
batch_size: int
seq_len: int
puzzle_emb_ndim: int = 0
num_puzzle_identifiers: int
vocab_size: int
H_cycles: int
L_cycles: int
H_layers: int
L_layers: int
# Transformer config
hidden_size: int
expansion: float
num_heads: int
pos_encodings: str
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
# Halting Q-learning config
halt_max_steps: int
halt_exploration_prob: float
forward_dtype: str = "bfloat16"
class HierarchicalReasoningModel_ACTV1Block(nn.Module):
def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
super().__init__()
self.self_attn = Attention(
hidden_size=config.hidden_size,
head_dim=config.hidden_size // config.num_heads,
num_heads=config.num_heads,
num_key_value_heads=config.num_heads,
causal=False
)
self.mlp = SwiGLU(
hidden_size=config.hidden_size,
expansion=config.expansion,
)
self.norm_eps = config.rms_norm_eps
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
# Post Norm
# Self Attention
hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps)
# Fully Connected
hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps)
return hidden_states
class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:
# Input injection (add)
hidden_states = hidden_states + input_injection
# Layers
for layer in self.layers:
hidden_states = layer(hidden_states=hidden_states, **kwargs)
return hidden_states
class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
super().__init__()
self.config = config
self.forward_dtype = getattr(torch, self.config.forward_dtype)
# I/O
self.embed_scale = math.sqrt(self.config.hidden_size)
embed_init_std = 1.0 / self.embed_scale
self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
if self.config.puzzle_emb_ndim > 0:
# Zero init puzzle embeddings
self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)
# LM Blocks
if self.config.pos_encodings == "rope":
self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
base=self.config.rope_theta)
elif self.config.pos_encodings == "learned":
self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
else:
raise NotImplementedError()
# Reasoning Layers
self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
# Initial states
self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
# Q head special init
# Init Q to (almost) zero for faster learning during bootstrapping
with torch.no_grad():
self.q_head.weight.zero_()
self.q_head.bias.fill_(-5) # type: ignore
def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
# Token embedding
embedding = self.embed_tokens(input.to(torch.int32))
# Puzzle embeddings
if self.config.puzzle_emb_ndim > 0:
puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
if pad_count > 0:
puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
# Position embeddings
if self.config.pos_encodings == "learned":
# scale by 1/sqrt(2) to maintain forward variance
embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
# Scale
return self.embed_scale * embedding
def empty_carry(self, batch_size: int):
return HierarchicalReasoningModel_ACTV1InnerCarry(
z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
)
def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):
return HierarchicalReasoningModel_ACTV1InnerCarry(
z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
)
def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
seq_info = dict(
cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None,
)
# Input encoding
input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])
# Forward iterations
with torch.no_grad():
z_H, z_L = carry.z_H, carry.z_L
for _H_step in range(self.config.H_cycles):
for _L_step in range(self.config.L_cycles):
if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)):
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
if not (_H_step == self.config.H_cycles - 1):
z_H = self.H_level(z_H, z_L, **seq_info)
assert not z_H.requires_grad and not z_L.requires_grad
# 1-step grad
z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info)
z_H = self.H_level(z_H, z_L, **seq_info)
# LM Outputs
new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad
output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
# Q head
q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class HierarchicalReasoningModel_ACTV1(nn.Module):
"""ACT wrapper."""
def __init__(self, config_dict: dict):
super().__init__()
self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict)
self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config)
@property
def puzzle_emb(self):
return self.inner.puzzle_emb
def initial_carry(self, batch: Dict[str, torch.Tensor]):
batch_size = batch["inputs"].shape[0]
return HierarchicalReasoningModel_ACTV1Carry(
inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted.
steps=torch.zeros((batch_size, ), dtype=torch.int32),
halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted
current_data={k: torch.empty_like(v) for k, v in batch.items()}
)
def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
# Update data, carry (removing halted sequences)
new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
new_steps = torch.where(carry.halted, 0, carry.steps)
new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()}
# Forward inner model
new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
outputs = {
"logits": logits,
"q_halt_logits": q_halt_logits,
"q_continue_logits": q_continue_logits
}
with torch.no_grad():
# Step
new_steps = new_steps + 1
is_last_step = new_steps >= self.config.halt_max_steps
halted = is_last_step
# if training, and ACT is enabled
if self.training and (self.config.halt_max_steps > 1):
# Halt signal
# NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
halted = halted | (q_halt_logits > q_continue_logits)
# Exploration
min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
halted = halted & (new_steps >= min_halt_steps)
# Compute target Q
# NOTE: No replay buffer and target networks for computing target Q-value.
# As batch_size is large, there're many parallel envs.
# Similar concept as PQN https://arxiv.org/abs/2407.04811
next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1]
outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits)))
return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs
================================================
FILE: models/layers.py
================================================
from typing import Tuple
import torch
from torch import nn
import torch.nn.functional as F
try:
from flash_attn_interface import flash_attn_func # type: ignore[import]
except ImportError:
# Fallback to FlashAttention 2
from flash_attn import flash_attn_func # type: ignore[import]
from models.common import trunc_normal_init_
CosSin = Tuple[torch.Tensor, torch.Tensor]
def _find_multiple(a, b):
return (-(a // -b)) * b
def rotate_half(x: torch.Tensor):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
# q, k: [bs, seq_len, num_heads, head_dim]
# cos, sin: [seq_len, head_dim]
orig_dtype = q.dtype
q = q.to(cos.dtype)
k = k.to(cos.dtype)
q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
class CastedLinear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool):
super().__init__()
# Truncated LeCun normal init
self.weight = nn.Parameter(
trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
)
self.bias = None
if bias:
# Zero init bias
self.bias = nn.Parameter(torch.zeros((out_features, )))
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)
class CastedEmbedding(nn.Module):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
init_std: float,
cast_to: torch.dtype):
super().__init__()
self.cast_to = cast_to
# Truncated LeCun normal init
self.embedding_weight = nn.Parameter(
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.embedding(input, self.embedding_weight.to(self.cast_to))
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings, base, device=None):
super().__init__()
# RoPE
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
freqs = torch.outer(t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
def forward(self):
return self.cos_cached, self.sin_cached
class Attention(nn.Module):
def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
super().__init__()
self.hidden_size = hidden_size
self.head_dim = head_dim
self.output_size = head_dim * num_heads
self.num_heads = num_heads
self.num_key_value_heads = num_key_value_heads
self.causal = causal
self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)
def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = hidden_states.shape
# hidden_states: [bs, seq_len, num_heads, head_dim]
qkv = self.qkv_proj(hidden_states)
# Split head
qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
query = qkv[:, :, :self.num_heads]
key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
value = qkv[:, :, self.num_heads + self.num_key_value_heads:]
# RoPE
if cos_sin is not None:
cos, sin = cos_sin
query, key = apply_rotary_pos_emb(query, key, cos, sin)
# flash attn
attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
if isinstance(attn_output, tuple): # fa2 and fa3 compatibility
attn_output = attn_output[0]
attn_output = attn_output.view(batch_size, seq_len, self.output_size) # type: ignore
return self.o_proj(attn_output)
class SwiGLU(nn.Module):
def __init__(self, hidden_size: int, expansion: float):
super().__init__()
inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
self.down_proj = CastedLinear(inter, hidden_size, bias=False)
def forward(self, x):
gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
return self.down_proj(F.silu(gate) * up)
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
return hidden_states.to(input_dtype)
================================================
FILE: models/losses.py
================================================
from typing import Any, Tuple, Dict, Sequence, Optional
import torch
import torch.nn.functional as F
from torch import nn
IGNORE_LABEL_ID = -100
def s(x, epsilon=1e-30):
return torch.where(
x<0,
1/(1-x+ epsilon),
x + 1
)
def log_stablemax(x, dim=-1):
s_x = s(x)
return torch.log(s_x/torch.sum(s_x, dim=dim, keepdim=True))
def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
logprobs = log_stablemax(logits.to(torch.float64), dim=-1)
valid_mask = labels != ignore_index
transformed_labels = torch.where(valid_mask, labels, 0)
prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)
return -torch.where(valid_mask, prediction_logprobs, 0)
def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
# Cast logits to f32
# Flatten logits
return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
class ACTLossHead(nn.Module):
def __init__(self, model: nn.Module, loss_type: str):
super().__init__()
self.model = model
self.loss_fn = globals()[loss_type]
def initial_carry(self, *args, **kwargs):
return self.model.initial_carry(*args, **kwargs) # type: ignore
def forward(
self,
return_keys: Sequence[str],
# Model args
**model_kwargs,
) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
# Model logits
# B x SeqLen x D
new_carry, outputs = self.model(**model_kwargs)
labels = new_carry.current_data["labels"]
# Correctness
with torch.no_grad():
mask = labels != IGNORE_LABEL_ID
loss_counts = mask.sum(-1)
loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1) # Avoid NaNs in division
is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
seq_is_correct = is_correct.sum(-1) == loss_counts
# Metrics (halted)
valid_metrics = new_carry.halted & (loss_counts > 0)
metrics = {
"count": valid_metrics.sum(),
"accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
"exact_accuracy": (valid_metrics & seq_is_correct).sum(),
"q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
"steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
}
# Losses
# FIXME: Assuming the batch is always full
lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()
q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")
metrics.update({
"lm_loss": lm_loss.detach(),
"q_halt_loss": q_halt_loss.detach(),
})
# Q continue (bootstrapping target loss)
q_continue_loss = 0
if "target_q_continue" in outputs:
q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")
metrics["q_continue_loss"] = q_continue_loss.detach()
# Filter outputs for return
detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}
return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
================================================
FILE: models/sparse_embedding.py
================================================
from typing import Union
import torch
from torch import nn
import torch.distributed as dist
from torch.optim.optimizer import Optimizer, ParamsT
from models.common import trunc_normal_init_
class CastedSparseEmbedding(nn.Module):
def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
super().__init__()
self.cast_to = cast_to
# Real Weights
# Truncated LeCun normal init
self.weights = nn.Buffer(
trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
)
# Local weights and IDs
# Local embeddings, with gradient, not persistent
self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
# Local embedding IDs, not persistent
self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
if not self.training:
# Test mode, no gradient
return self.weights[inputs].to(self.cast_to)
# Training mode, fill puzzle embedding from weights
with torch.no_grad():
self.local_weights.copy_(self.weights[inputs])
self.local_ids.copy_(inputs)
return self.local_weights.to(self.cast_to)
class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
def __init__(
self,
params: ParamsT,
world_size: int,
lr: Union[float, torch.Tensor] = 1e-3,
weight_decay: float = 1e-2,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
weight_decay=weight_decay,
world_size=world_size
)
super().__init__(params, defaults)
@torch.no_grad
def step(self, closure=None): # type: ignore
for group in self.param_groups:
# Find the sparse embedding weights
local_weights_grad = None
local_ids = None
weights = None
assert len(group["params"]) == 3
for p in group["params"]:
if p.requires_grad:
local_weights_grad = p.grad
elif p.ndim == 1:
local_ids = p
elif p.ndim == 2:
weights = p
else:
assert False
assert local_weights_grad is not None
assert local_ids is not None
assert weights is not None
# Apply SignSGD
# Adam ≈ SignSGD if gradient is very sparse
_sparse_emb_signsgd_dist(
local_weights_grad,
local_ids,
weights,
lr=group["lr"],
weight_decay=group["weight_decay"],
world_size=group["world_size"]
)
def _sparse_emb_signsgd_dist(
local_weights_grad: torch.Tensor,
local_ids: torch.Tensor,
weights: torch.Tensor,
lr: float,
weight_decay: float,
world_size: int
) -> None:
N, D = local_weights_grad.shape
# All-gather
all_weights_grad = local_weights_grad
all_ids = local_ids
if world_size > 1:
all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
dist.all_gather_into_tensor(all_ids, local_ids)
# Unique
grad_ids, inv = all_ids.unique(return_inverse=True)
grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)
# SignSGD with decoupled weight decay
p = weights[grad_ids]
p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)
# Write updated slices back
weights[grad_ids] = p
================================================
FILE: pretrain.py
================================================
from typing import Optional, Any, Sequence, List
from dataclasses import dataclass
import os
import math
import yaml
import shutil
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader
import tqdm
import wandb
import coolname
import hydra
import pydantic
from omegaconf import DictConfig
from adam_atan2 import AdamATan2
from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
from utils.functions import load_model_class, get_model_source_path
from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed
class LossConfig(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra='allow')
name: str
class ArchConfig(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra='allow')
name: str
loss: LossConfig
class PretrainConfig(pydantic.BaseModel):
# Config
arch: ArchConfig
# Data
data_path: str
# Hyperparams
global_batch_size: int
epochs: int
lr: float
lr_min_ratio: float
lr_warmup_steps: int
weight_decay: float
beta1: float
beta2: float
# Puzzle embedding
puzzle_emb_lr: float
puzzle_emb_weight_decay: float
# Names
project_name: Optional[str] = None
run_name: Optional[str] = None
checkpoint_path: Optional[str] = None
# Extras
seed: int = 0
checkpoint_every_eval: bool = False
eval_interval: Optional[int] = None
eval_save_outputs: List[str] = []
@dataclass
class TrainState:
model: nn.Module
optimizers: Sequence[torch.optim.Optimizer]
optimizer_lrs: Sequence[float]
carry: Any
step: int
total_steps: int
def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):
dataset = PuzzleDataset(PuzzleDatasetConfig(
seed=config.seed,
dataset_path=config.data_path,
rank=rank,
num_replicas=world_size,
**kwargs
), split=split)
dataloader = DataLoader(
dataset,
batch_size=None,
num_workers=1,
prefetch_factor=8,
pin_memory=True,
persistent_workers=True
)
return dataloader, dataset.metadata
def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
model_cfg = dict(
**config.arch.__pydantic_extra__, # type: ignore
batch_size=config.global_batch_size // world_size,
vocab_size=train_metadata.vocab_size,
seq_len=train_metadata.seq_len,
num_puzzle_identifiers=train_metadata.num_puzzle_identifiers,
causal=False # Non-autoregressive
)
# Instantiate model with loss head
model_cls = load_model_class(config.arch.name)
loss_head_cls = load_model_class(config.arch.loss.name)
with torch.device("cuda"):
model: nn.Module = model_cls(model_cfg)
model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore
if "DISABLE_COMPILE" not in os.environ:
model = torch.compile(model, dynamic=False) # type: ignore
# Broadcast parameters from rank 0
if world_size > 1:
with torch.no_grad():
for param in list(model.parameters()) + list(model.buffers()):
dist.broadcast(param, src=0)
# Optimizers and lr
optimizers = [
CastedSparseEmbeddingSignSGD_Distributed(
model.model.puzzle_emb.buffers(), # type: ignore
lr=0, # Needs to be set by scheduler
weight_decay=config.puzzle_emb_weight_decay,
world_size=world_size
),
AdamATan2(
model.parameters(),
lr=0, # Needs to be set by scheduler
weight_decay=config.weight_decay,
betas=(config.beta1, config.beta2)
)
]
optimizer_lrs = [
config.puzzle_emb_lr,
config.lr
]
return model, optimizers, optimizer_lrs
def cosine_schedule_with_warmup_lr_lambda(
current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
):
if current_step < num_warmup_steps:
return base_lr * float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))))
def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):
# Estimated total training steps
total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size)
# Model
model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size)
return TrainState(
step=0,
total_steps=total_steps,
model=model,
optimizers=optimizers,
optimizer_lrs=optimizer_lrs,
carry=None
)
def save_train_state(config: PretrainConfig, train_state: TrainState):
# FIXME: Only saved model.
if config.checkpoint_path is None:
return
os.makedirs(config.checkpoint_path, exist_ok=True)
torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}"))
def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):
return cosine_schedule_with_warmup_lr_lambda(
current_step=train_state.step,
base_lr=base_lr,
num_warmup_steps=round(config.lr_warmup_steps),
num_training_steps=train_state.total_steps,
min_ratio=config.lr_min_ratio
)
def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):
train_state.step += 1
if train_state.step > train_state.total_steps: # At most train_total_steps
return
# To device
batch = {k: v.cuda() for k, v in batch.items()}
# Init carry if it is None
if train_state.carry is None:
with torch.device("cuda"):
train_state.carry = train_state.model.initial_carry(batch) # type: ignore
# Forward
train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[])
((1 / global_batch_size) * loss).backward()
# Allreduce
if world_size > 1:
for param in train_state.model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad)
# Apply optimizer
lr_this_step = None
for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs):
lr_this_step = compute_lr(base_lr, config, train_state)
for param_group in optim.param_groups:
param_group['lr'] = lr_this_step
optim.step()
optim.zero_grad()
# Reduce metrics
if len(metrics):
assert not any(v.requires_grad for v in metrics.values())
metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
# Reduce and reconstruct
metric_values = torch.stack([metrics[k] for k in metric_keys])
if world_size > 1:
dist.reduce(metric_values, dst=0)
if rank == 0:
metric_values = metric_values.cpu().numpy()
reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)}
# Postprocess
count = max(reduced_metrics["count"], 1) # Avoid NaNs
reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()}
reduced_metrics["train/lr"] = lr_this_step
return reduced_metrics
def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):
with torch.inference_mode():
set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}
all_preds = {}
metric_keys = []
metric_values = None
metric_global_batch_size = [0 for _ in range(len(set_ids))]
carry = None
for set_name, batch, global_batch_size in eval_loader:
# To device
batch = {k: v.cuda() for k, v in batch.items()}
with torch.device("cuda"):
carry = train_state.model.initial_carry(batch) # type: ignore
# Forward
while True:
carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs)
if all_finish:
break
for collection in (batch, preds):
for k, v in collection.items():
if k in config.eval_save_outputs:
all_preds.setdefault(k, [])
all_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory
del carry, preds, batch, all_finish
# Aggregate
set_id = set_ids[set_name]
if metric_values is None:
metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order.
metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda")
metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys])
metric_global_batch_size[set_id] += global_batch_size
if len(all_preds) and config.checkpoint_path is not None:
all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}
os.makedirs(config.checkpoint_path, exist_ok=True)
torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}"))
# Logging
# Reduce to rank 0
if metric_values is not None:
if world_size > 1:
dist.reduce(metric_values, dst=0)
if rank == 0:
reduced_metrics = metric_values.cpu().numpy()
reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)}
for set_id, set_name in enumerate(set_ids)}
# Postprocess
for set_name, metrics in reduced_metrics.items():
count = metrics.pop("count")
reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()}
return reduced_metrics
def save_code_and_config(config: PretrainConfig):
if config.checkpoint_path is None or wandb.run is None:
return
os.makedirs(config.checkpoint_path, exist_ok=True)
# Copy code
code_list = [
get_model_source_path(config.arch.name),
get_model_source_path(config.arch.loss.name)
]
for code_file in code_list:
if code_file is not None:
code_name = os.path.basename(code_file)
shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name))
# Dump config as yaml
config_file = os.path.join(config.checkpoint_path, "all_config.yaml")
with open(config_file, "wt") as f:
yaml.dump(config.model_dump(), f)
# Log code
wandb.run.log_code(config.checkpoint_path)
def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:
objects = [None]
if rank == 0:
config = PretrainConfig(**hydra_config) # type: ignore
# Naming
if config.project_name is None:
config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch"
if config.run_name is None:
config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}"
if config.checkpoint_path is None:
config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name)
objects = [config]
if world_size > 1:
dist.broadcast_object_list(objects, src=0)
return objects[0] # type: ignore
@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
def launch(hydra_config: DictConfig):
RANK = 0
WORLD_SIZE = 1
# Initialize distributed training if in distributed environment (e.g. torchrun)
if "LOCAL_RANK" in os.environ:
# Initialize distributed, default device and dtype
dist.init_process_group(backend="nccl")
RANK = dist.get_rank()
WORLD_SIZE = dist.get_world_size()
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
# Load sync'ed config
config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE)
# Seed RNGs to ensure consistency
torch.random.manual_seed(config.seed + RANK)
# Dataset
train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs
total_iters = config.epochs // train_epochs_per_iter
assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs."
train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
# Train state
train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)
# Progress bar and logger
progress_bar = None
if RANK == 0:
progress_bar = tqdm.tqdm(total=train_state.total_steps)
wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore
wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0)
save_code_and_config(config)
# Training Loop
for _iter_id in range(total_iters):
print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")
############ Train Iter
train_state.model.train()
for set_name, batch, global_batch_size in train_loader:
metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)
if RANK == 0 and metrics is not None:
wandb.log(metrics, step=train_state.step)
progress_bar.update(train_state.step - progress_bar.n) # type: ignore
############ Evaluation
train_state.model.eval()
metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)
if RANK == 0 and metrics is not None:
wandb.log(metrics, step=train_state.step)
############ Checkpointing
if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)):
save_train_state(config, train_state)
# finalize
if dist.is_initialized():
dist.destroy_process_group()
wandb.finish()
if __name__ == "__main__":
launch()
================================================
FILE: puzzle_dataset.py
================================================
import os
import json
import numpy as np
import pydantic
import torch
from torch.utils.data import IterableDataset, get_worker_info
from models.losses import IGNORE_LABEL_ID
from dataset.common import PuzzleDatasetMetadata
def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
# Pack examples into a full batch
batch = []
batch_puzzle_indices = []
current_size = 0
while (start_index < group_order.size) and (current_size < global_batch_size):
# Pick a group and a puzzle from that group
group_id = group_order[start_index]
puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
start_index += 1
# Get range of the puzzle
puzzle_start = puzzle_indices[puzzle_id]
puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)
append_size = min(puzzle_size, global_batch_size - current_size)
# Put into batch
batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False))
current_size += append_size
return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)
class PuzzleDatasetConfig(pydantic.BaseModel):
seed: int
dataset_path: str
global_batch_size: int
test_set_mode: bool
epochs_per_iter: int # Batch X epochs in an iteration to reduce overhead.
rank: int
num_replicas: int
class PuzzleDataset(IterableDataset):
def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
super().__init__()
self.config = config
self.split = split
self.metadata = self._load_metadata()
# Checks
assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}."
self.local_batch_size = self.config.global_batch_size // self.config.num_replicas
# State
self._data = None
self._iters = 0
def _load_metadata(self) -> PuzzleDatasetMetadata:
with open(os.path.join(self.config.dataset_path, self.split, "dataset.json"), "r") as f:
return PuzzleDatasetMetadata(**json.load(f))
def _lazy_load_dataset(self):
if self._data is not None:
return
field_mmap_modes = {
"inputs": "r",
"labels": "r",
# Keep indices in memory
"puzzle_identifiers": None,
"puzzle_indices": None,
"group_indices": None
}
# Load data
self._data = {}
for set_name in self.metadata.sets:
# Load subset
self._data[set_name] = {
field_name: np.load(os.path.join(self.config.dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode)
for field_name, mmap_mode in field_mmap_modes.items()
}
def _collate_batch(self, batch):
# Convert dtype
batch = {k: v.astype(np.int32) for k, v in batch.items()}
# Convert ignore label IDs
if self.metadata.ignore_label_id is not None:
batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID
# Pad
if batch["puzzle_identifiers"].size < self.local_batch_size:
pad_size = self.local_batch_size - batch["puzzle_identifiers"].size
pad_values = {
"inputs": self.metadata.pad_id,
"labels": IGNORE_LABEL_ID,
"puzzle_identifiers": self.metadata.blank_identifier_id
}
batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}
# To tensor
return {k: torch.from_numpy(v) for k, v in batch.items()}
def _iter_test(self):
for set_name, dataset in self._data.items(): # type: ignore
total_examples = len(dataset["inputs"])
# Load examples one by one
start_index = 0
while start_index < total_examples:
# Compute indices
end_index = min(total_examples, start_index + self.config.global_batch_size)
local_start = start_index + self.config.rank * self.local_batch_size
local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)
# Get batch of examples, and also puzzle IDs
puzzle_indices = []
puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1
for i in range(local_start, local_end):
while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]:
puzzle_index += 1
puzzle_indices.append(puzzle_index)
batch = self._collate_batch({
"inputs": dataset["inputs"][local_start: local_end],
"labels": dataset["labels"][local_start: local_end],
"puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices]
})
yield set_name, batch, end_index - start_index
# Advance to next batch
start_index += self.config.global_batch_size
def _iter_train(self):
for set_name, dataset in self._data.items(): # type: ignore
# Increase epoch count
self._iters += 1
# Randomly shuffle groups
rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))
group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)])
start_index = 0
while start_index < group_order.size:
start_index, batch_indices, batch_puzzle_indices = _sample_batch(
rng,
group_order=group_order,
puzzle_indices=dataset["puzzle_indices"],
group_indices=dataset["group_indices"],
start_index=start_index,
global_batch_size=self.config.global_batch_size,
)
# Select current rank and collate
global_effective_batch_size = batch_puzzle_indices.size # Global effective batch size, excluding pads
# Drop last batch
if global_effective_batch_size < self.config.global_batch_size:
break
batch_indices = batch_indices [self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
batch = self._collate_batch({
"inputs": dataset["inputs"][batch_indices],
"labels": dataset["labels"][batch_indices],
"puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
})
yield set_name, batch, global_effective_batch_size
def __iter__(self):
worker_info = get_worker_info()
assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported."
self._lazy_load_dataset()
# Iterate using specified mode
if self.config.test_set_mode:
yield from self._iter_test()
else:
yield from self._iter_train()
================================================
FILE: puzzle_visualizer.html
================================================
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8" />
<title>ARC‐Converted Dataset Visualizer (Upload Local Folder)</title>
<style>
body {
font-family: sans-serif;
margin: 16px;
}
.selector-area {
margin-bottom: 1rem;
}
.grid-canvas {
margin: 4px;
border: 1px solid #ccc;
}
.example-container {
display: inline-block;
margin: 0 16px 16px 0;
vertical-align: top;
}
.puzzle-display {
margin-top: 1rem;
}
.puzzle-id {
font-weight: bold;
margin-bottom: 0.5rem;
}
#groupList, #puzzleList {
margin: 1rem 0;
}
.group-item, .puzzle-item {
cursor: pointer;
margin: 4px 8px 4px 0;
padding: 2px 6px;
border: 1px solid #aaa;
display: inline-block;
}
.group-item:hover, .puzzle-item:hover {
background: #eef;
}
</style>
</head>
<body>
<h1>ARC‐Converted Dataset Visualizer (Local Directory)</h1>
<div class="selector-area">
<!-- 1) Directory input with webkitdirectory, mozdirectory -->
<label>Upload ARC Folder:</label>
<input type="file" id="folderInput"
webkitdirectory mozdirectory multiple
onchange="onFolderSelected(event)" />
<br><br>
<!-- 2) We'll enable set/subset selection after user chooses a folder and data is validated -->
<label>Set:</label>
<select id="setSelect" disabled>
<option value="train">train</option>
<option value="test">test</option>
</select>
<label> Subset:</label>
<select id="subsetSelect" disabled>
<option value="all">all</option>
</select>
<button id="loadBtn" disabled>Load</button>
</div>
<div>
<div id="groupList"></div>
<div id="puzzleList"></div>
<div class="puzzle-display" id="puzzleView"></div>
</div>
<!--
3) Use local 'assets/npyjs.js' from your project folder instead of a CDN.
Make sure 'assets/npyjs.js' is the unbundled or UMD version that doesn't
contain "import" statements.
-->
<script src="assets/npyjs.js"></script>
<script>
/***************************************************************************
* Global Maps & Variables
***************************************************************************/
// Map from "train/all__inputs.npy" => File, etc.
let filesByPath = {};
// Once loaded, we store typed arrays for the chosen set/subset
let inputsArr, labelsArr;
let puzzleIndicesArr, groupIndicesArr, puzzleIdentifiersArr;
let identifiersJson;
// The shape of inputs is [N_examples, seqLen], so we discover seqLen & gridSize
let seqLen = 0;
let gridSize = 0;
/***************************************************************************
* 1) Handle folder selection: read all files, find identifiers.json,
* remove topmost folder from each file path, validate.
***************************************************************************/
function onFolderSelected(event) {
filesByPath = {};
const fileList = event.target.files;
if (!fileList || fileList.length === 0) {
alert("No files selected!");
return;
}
// We'll gather all webkitRelativePaths
const paths = [];
for (let i = 0; i < fileList.length; i++) {
// Typically "arc-aug-10/train/all__inputs.npy", etc.
const file = fileList[i];
const relPath = file.webkitRelativePath || file.mozRelativePath || file.name;
paths.push(relPath);
}
// 1. Check if we have "identifiers.json" somewhere.
const idPath = paths.find(p => p.endsWith("identifiers.json"));
if (!idPath) {
alert("Error: No 'identifiers.json' found in the uploaded folder.");
return;
}
// 2. Derive the top-level directory from that file's path
// e.g. if idPath = "arc-aug-10/identifiers.json", topDir = "arc-aug-10"
// If there's no slash, topDir = "" => do nothing
let topDir = "";
const lastSlash = idPath.lastIndexOf("/");
if (lastSlash >= 0) {
topDir = idPath.substring(0, lastSlash);
}
// 3. Rebuild filesByPath with the top folder removed.
// For example, if topDir = "arc-aug-10", then "arc-aug-10/train/all__inputs.npy"
// becomes "train/all__inputs.npy"
for (let i = 0; i < fileList.length; i++) {
const file = fileList[i];
let relPath = file.webkitRelativePath || file.mozRelativePath || file.name;
// If relPath starts with "arc-aug-10/", remove that prefix
if (topDir && relPath.startsWith(topDir + "/")) {
relPath = relPath.substring(topDir.length + 1);
}
filesByPath[relPath] = file;
}
// Enable set/subset selection and "Load"
document.getElementById("setSelect").disabled = false;
document.getElementById("subsetSelect").disabled = false;
document.getElementById("loadBtn").disabled = false;
}
// When user clicks "Load," parse the .npy for the chosen set/subset
document.getElementById("loadBtn").addEventListener("click", async () => {
document.getElementById("groupList").innerHTML = "";
document.getElementById("puzzleList").innerHTML = "";
document.getElementById("puzzleView").innerHTML = "";
const setName = document.getElementById("setSelect").value; // e.g. "train"
const subsetName = document.getElementById("subsetSelect").value; // e.g. "all"
try {
await loadDataset(setName, subsetName);
buildGroupList(); // show groups
} catch (err) {
console.error(err);
alert("Error while loading dataset: " + err);
}
});
/***************************************************************************
* 2) Load .npy from local files using Npyjs + FileReader (ArrayBuffer)
***************************************************************************/
async function loadDataset(setName, subsetName) {
const prefix = `${setName}/${subsetName}__`;
// e.g. "train/all__inputs.npy"
const inputsPath = prefix + "inputs.npy";
const labelsPath = prefix + "labels.npy";
const pIdxPath = prefix + "puzzle_indices.npy";
const gIdxPath = prefix + "group_indices.npy";
const pIdsPath = prefix + "puzzle_identifiers.npy";
const identifiersPath = "identifiers.json";
// Check existence
const needed = [inputsPath, labelsPath, pIdxPath, gIdxPath, pIdsPath, identifiersPath];
for (const f of needed) {
if (!filesByPath[f]) {
throw new Error(`Missing file: ${f}`);
}
}
// parseNpy => read from File -> ArrayBuffer -> Npyjs => typed array
const inputsNpy = await parseNpy(filesByPath[inputsPath]);
const labelsNpy = await parseNpy(filesByPath[labelsPath]);
const puzzleIndicesNpy= await parseNpy(filesByPath[pIdxPath]);
const groupIndicesNpy = await parseNpy(filesByPath[gIdxPath]);
const puzzleIdsNpy = await parseNpy(filesByPath[pIdsPath]);
inputsArr = inputsNpy.data;
labelsArr = labelsNpy.data;
puzzleIndicesArr = puzzleIndicesNpy.data;
groupIndicesArr = groupIndicesNpy.data;
puzzleIdentifiersArr = puzzleIdsNpy.data;
// shape e.g. [N_examples, seqLen]
seqLen = inputsNpy.shape[1];
gridSize = Math.sqrt(seqLen);
// read JSON
identifiersJson = await readJsonFile(filesByPath[identifiersPath]);
}
/***************************************************************************
* parseNpy => read a File as ArrayBuffer, parse with npyjs
***************************************************************************/
function parseNpy(file) {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = async () => {
try {
const arrayBuffer = reader.result;
const npy = new npyjs();
resolve(await npy.parse(arrayBuffer));
} catch (err) {
reject(err);
}
};
reader.onerror = err => reject(err);
reader.readAsArrayBuffer(file);
});
}
/***************************************************************************
* readJsonFile => read a local JSON file into object
***************************************************************************/
function readJsonFile(file) {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = () => {
try {
const obj = JSON.parse(reader.result);
resolve(obj);
} catch (err) {
reject(err);
}
};
reader.onerror = (err) => reject(err);
reader.readAsText(file);
});
}
/***************************************************************************
* 3) Build group list in UI
***************************************************************************/
function buildGroupList() {
document.getElementById("groupList").innerHTML = "<h3>Groups</h3>";
const groupListDiv = document.getElementById("groupList");
const nGroups = groupIndicesArr.length - 1;
for (let g = 0; g < nGroups; g++) {
const div = document.createElement("span");
div.className = "group-item";
div.textContent = `Group ${g}`;
div.onclick = () => onSelectGroup(g);
groupListDiv.appendChild(div);
}
}
/***************************************************************************
* onSelectGroup => show puzzles in that group
***************************************************************************/
function onSelectGroup(groupIndex) {
document.getElementById("puzzleList").innerHTML = "";
document.getElementById("puzzleView").innerHTML = "";
const puzzleListDiv = document.getElementById("puzzleList");
puzzleListDiv.innerHTML = `<h4>Puzzles in Group ${groupIndex}</h4>`;
const firstPuzzle = groupIndicesArr[groupIndex];
const lastPuzzle = groupIndicesArr[groupIndex + 1];
for (let p = firstPuzzle; p < lastPuzzle; p++) {
const puzzleIntId = puzzleIdentifiersArr[p];
const puzzleStrId = (puzzleIntId < identifiersJson.length)
? identifiersJson[puzzleIntId]
: "<unknown>";
const div = document.createElement("span");
div.className = "puzzle-item";
div.textContent = `Puzzle #${p} [ID=${puzzleIntId}: ${puzzleStrId}]`;
div.onclick = () => onSelectPuzzle(p);
puzzleListDiv.appendChild(div);
}
}
/***************************************************************************
* onSelectPuzzle => show each example
***************************************************************************/
function onSelectPuzzle(puzzleIndex) {
const puzzleView = document.getElementById("puzzleView");
puzzleView.innerHTML = "";
// puzzle ID
const puzzleIntId = puzzleIdentifiersArr[puzzleIndex];
const puzzleStrId = (puzzleIntId < identifiersJson.length)
? identifiersJson[puzzleIntId]
: "<unknown>";
const titleDiv = document.createElement("div");
titleDiv.className = "puzzle-id";
titleDiv.textContent = `Puzzle #${puzzleIndex} — ID: ${puzzleStrId}`;
puzzleView.appendChild(titleDiv);
// Examples are [puzzleIndicesArr[p], puzzleIndicesArr[p+1])
const firstExample = puzzleIndicesArr[puzzleIndex];
const lastExample = puzzleIndicesArr[puzzleIndex + 1];
for (let e = firstExample; e < lastExample; e++) {
const inputSeq = slice1D(inputsArr, e*seqLen, (e+1)*seqLen);
const outputSeq = slice1D(labelsArr, e*seqLen, (e+1)*seqLen);
const inputGrid = decodeGrid(inputSeq);
const outputGrid = decodeGrid(outputSeq);
const exDiv = document.createElement("div");
exDiv.className = "example-container";
exDiv.appendChild(document.createTextNode(`Example ${e}`));
exDiv.appendChild(document.createElement("br"));
exDiv.appendChild(renderGrid(inputGrid));
exDiv.appendChild(renderGrid(outputGrid));
puzzleView.appendChild(exDiv);
}
}
/***************************************************************************
* slice1D => typed array slicing
***************************************************************************/
function slice1D(arr, start, end) {
const result = new Uint32Array(end - start);
for (let i = start; i < end; i++) {
result[i - start] = Number(arr[i]);
}
return result;
}
/***************************************************************************
* decodeGrid => turn the flattened seq of length=gridSize^2 into 2D
***************************************************************************/
function decodeGrid(seq) {
const grid = [];
let idx = 0;
for (let r = 0; r < gridSize; r++) {
const row = [];
for (let c = 0; c < gridSize; c++) {
row.push(seq[idx]);
idx++;
}
grid.push(row);
}
return grid;
}
/***************************************************************************
* renderGrid => draws a 2D grid to <canvas>
***************************************************************************/
function renderGrid(grid2d) {
const rows = grid2d.length;
const cols = grid2d[0].length;
const scale = 10;
const canvas = document.createElement("canvas");
canvas.width = cols * scale;
canvas.height = rows * scale;
canvas.className = "grid-canvas";
const ctx = canvas.getContext("2d");
for (let r = 0; r < rows; r++) {
for (let c = 0; c < cols; c++) {
const val = grid2d[r][c];
ctx.fillStyle = indexToColor(val);
ctx.fillRect(c * scale, r * scale, scale, scale);
}
}
return canvas;
}
/***************************************************************************
* indexToColor => color palette:
* 0 => pad => white
* 1 => eos => light gray
* 2..11 => original color(0..9)
***************************************************************************/
function indexToColor(value) {
if (value === 0) return "#FFFFFF"; // pad => white
if (value === 1) return "#DDDDDD"; // eos => light gray
// shift by 2 => original color in [0..9]
const colorIdx = value - 2;
const palette = [
"#000000", // color0 => black
"#FF0000", // color1 => red
"#00FF00", // color2 => green
"#0000FF", // color3 => blue
"#FFFF00", // color4 => yellow
"#FFA500", // color5 => orange
"#800080", // color6 => purple
"#00FFFF", // color7 => cyan
"#FFC0CB", // color8 => pink
"#808080" // color9 => gray
];
if (colorIdx >= 0 && colorIdx < palette.length) {
return palette[colorIdx];
}
return "#FFFFFF"; // fallback
}
</script>
</body>
</html>
================================================
FILE: requirements.txt
================================================
torch
adam-atan2
einops
tqdm
coolname
pydantic
argdantic
wandb
omegaconf
hydra-core
huggingface_hub
================================================
FILE: utils/functions.py
================================================
import importlib
import inspect
def load_model_class(identifier: str, prefix: str = "models."):
module_path, class_name = identifier.split('@')
# Import the module
module = importlib.import_module(prefix + module_path)
cls = getattr(module, class_name)
return cls
def get_model_source_path(identifier: str, prefix: str = "models."):
module_path, class_name = identifier.split('@')
module = importlib.import_module(prefix + module_path)
return inspect.getsourcefile(module)
gitextract_7bpy0bpn/
├── .gitignore
├── .gitmodules
├── .vscode/
│ ├── launch.json
│ └── settings.json
├── LICENSE
├── README.md
├── arc_eval.ipynb
├── assets/
│ └── npyjs.js
├── config/
│ ├── arch/
│ │ └── hrm_v1.yaml
│ └── cfg_pretrain.yaml
├── dataset/
│ ├── build_arc_dataset.py
│ ├── build_maze_dataset.py
│ ├── build_sudoku_dataset.py
│ └── common.py
├── evaluate.py
├── models/
│ ├── common.py
│ ├── hrm/
│ │ └── hrm_act_v1.py
│ ├── layers.py
│ ├── losses.py
│ └── sparse_embedding.py
├── pretrain.py
├── puzzle_dataset.py
├── puzzle_visualizer.html
├── requirements.txt
└── utils/
└── functions.py
SYMBOL INDEX (109 symbols across 14 files)
FILE: assets/npyjs.js
class npyjs (line 1) | class npyjs {
method constructor (line 3) | constructor(opts) {
method float16ToFloat32Array (line 78) | float16ToFloat32Array(float16Array) {
method float16ToFloat32 (line 89) | static float16ToFloat32(float16) {
method parse (line 116) | parse(arrayBufferContents) {
method load (line 155) | async load(filename, callback, fetchArgs) {
FILE: dataset/build_arc_dataset.py
class DataProcessConfig (line 19) | class DataProcessConfig(BaseModel):
class ARCPuzzle (line 37) | class ARCPuzzle:
function arc_grid_to_np (line 43) | def arc_grid_to_np(grid: List[List[int]]):
function np_grid_to_seq_translational_augment (line 54) | def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarra...
function puzzle_hash (line 81) | def puzzle_hash(puzzle: dict):
function convert_single_arc_puzzle (line 98) | def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: ...
function load_puzzles_arcagi (line 148) | def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataPr...
function convert_dataset (line 184) | def convert_dataset(config: DataProcessConfig):
function main (line 286) | def main(config: DataProcessConfig):
FILE: dataset/build_maze_dataset.py
class DataProcessConfig (line 22) | class DataProcessConfig(BaseModel):
function convert_subset (line 30) | def convert_subset(set_name: str, config: DataProcessConfig):
function preprocess_data (line 136) | def preprocess_data(config: DataProcessConfig):
FILE: dataset/build_sudoku_dataset.py
class DataProcessConfig (line 18) | class DataProcessConfig(BaseModel):
function shuffle_sudoku (line 27) | def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
function convert_subset (line 60) | def convert_subset(set_name: str, config: DataProcessConfig):
function preprocess_data (line 163) | def preprocess_data(config: DataProcessConfig):
FILE: dataset/common.py
class PuzzleDatasetMetadata (line 12) | class PuzzleDatasetMetadata(pydantic.BaseModel):
function dihedral_transform (line 27) | def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
function inverse_dihedral_transform (line 50) | def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
FILE: evaluate.py
class EvalConfig (line 13) | class EvalConfig(pydantic.BaseModel):
function launch (line 19) | def launch():
FILE: models/common.py
function trunc_normal_init_ (line 7) | def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: fl...
FILE: models/hrm/hrm_act_v1.py
class HierarchicalReasoningModel_ACTV1InnerCarry (line 16) | class HierarchicalReasoningModel_ACTV1InnerCarry:
class HierarchicalReasoningModel_ACTV1Carry (line 22) | class HierarchicalReasoningModel_ACTV1Carry:
class HierarchicalReasoningModel_ACTV1Config (line 31) | class HierarchicalReasoningModel_ACTV1Config(BaseModel):
class HierarchicalReasoningModel_ACTV1Block (line 60) | class HierarchicalReasoningModel_ACTV1Block(nn.Module):
method __init__ (line 61) | def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> ...
method forward (line 77) | def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> tor...
class HierarchicalReasoningModel_ACTV1ReasoningModule (line 86) | class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):
method __init__ (line 87) | def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):
method forward (line 92) | def forward(self, hidden_states: torch.Tensor, input_injection: torch....
class HierarchicalReasoningModel_ACTV1_Inner (line 102) | class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):
method __init__ (line 103) | def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> ...
method _input_embeddings (line 146) | def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: t...
method empty_carry (line 168) | def empty_carry(self, batch_size: int):
method reset_carry (line 174) | def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalRea...
method forward (line 180) | def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, b...
class HierarchicalReasoningModel_ACTV1 (line 216) | class HierarchicalReasoningModel_ACTV1(nn.Module):
method __init__ (line 219) | def __init__(self, config_dict: dict):
method puzzle_emb (line 225) | def puzzle_emb(self):
method initial_carry (line 228) | def initial_carry(self, batch: Dict[str, torch.Tensor]):
method forward (line 240) | def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch:...
FILE: models/layers.py
function _find_multiple (line 19) | def _find_multiple(a, b):
function rotate_half (line 23) | def rotate_half(x: torch.Tensor):
function apply_rotary_pos_emb (line 30) | def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Te...
class CastedLinear (line 43) | class CastedLinear(nn.Module):
method __init__ (line 44) | def __init__(self,
method forward (line 58) | def forward(self, input: torch.Tensor) -> torch.Tensor:
class CastedEmbedding (line 62) | class CastedEmbedding(nn.Module):
method __init__ (line 63) | def __init__(self,
method forward (line 76) | def forward(self, input: torch.Tensor) -> torch.Tensor:
class RotaryEmbedding (line 80) | class RotaryEmbedding(nn.Module):
method __init__ (line 81) | def __init__(self, dim, max_position_embeddings, base, device=None):
method forward (line 94) | def forward(self):
class Attention (line 98) | class Attention(nn.Module):
method __init__ (line 99) | def __init__(self, hidden_size, head_dim, num_heads, num_key_value_hea...
method forward (line 112) | def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> tor...
class SwiGLU (line 138) | class SwiGLU(nn.Module):
method __init__ (line 139) | def __init__(self, hidden_size: int, expansion: float):
method forward (line 146) | def forward(self, x):
function rms_norm (line 151) | def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> to...
FILE: models/losses.py
function s (line 11) | def s(x, epsilon=1e-30):
function log_stablemax (line 19) | def log_stablemax(x, dim=-1):
function stablemax_cross_entropy (line 24) | def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
function softmax_cross_entropy (line 34) | def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
class ACTLossHead (line 40) | class ACTLossHead(nn.Module):
method __init__ (line 41) | def __init__(self, model: nn.Module, loss_type: str):
method initial_carry (line 46) | def initial_carry(self, *args, **kwargs):
method forward (line 49) | def forward(
FILE: models/sparse_embedding.py
class CastedSparseEmbedding (line 11) | class CastedSparseEmbedding(nn.Module):
method __init__ (line 12) | def __init__(self, num_embeddings: int, embedding_dim: int, batch_size...
method forward (line 28) | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
class CastedSparseEmbeddingSignSGD_Distributed (line 41) | class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
method __init__ (line 42) | def __init__(
method step (line 63) | def step(self, closure=None): # type: ignore
function _sparse_emb_signsgd_dist (line 98) | def _sparse_emb_signsgd_dist(
FILE: pretrain.py
class LossConfig (line 26) | class LossConfig(pydantic.BaseModel):
class ArchConfig (line 32) | class ArchConfig(pydantic.BaseModel):
class PretrainConfig (line 39) | class PretrainConfig(pydantic.BaseModel):
class TrainState (line 74) | class TrainState:
function create_dataloader (line 84) | def create_dataloader(config: PretrainConfig, split: str, rank: int, wor...
function create_model (line 108) | def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMe...
function cosine_schedule_with_warmup_lr_lambda (line 162) | def cosine_schedule_with_warmup_lr_lambda(
function init_train_state (line 172) | def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatas...
function save_train_state (line 190) | def save_train_state(config: PretrainConfig, train_state: TrainState):
function compute_lr (line 199) | def compute_lr(base_lr: float, config: PretrainConfig, train_state: Trai...
function train_batch (line 209) | def train_batch(config: PretrainConfig, train_state: TrainState, batch: ...
function evaluate (line 266) | def evaluate(config: PretrainConfig, train_state: TrainState, eval_loade...
function save_code_and_config (line 333) | def save_code_and_config(config: PretrainConfig):
function load_synced_config (line 359) | def load_synced_config(hydra_config: DictConfig, rank: int, world_size: ...
function launch (line 381) | def launch(hydra_config: DictConfig):
FILE: puzzle_dataset.py
function _sample_batch (line 14) | def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puz...
class PuzzleDatasetConfig (line 41) | class PuzzleDatasetConfig(pydantic.BaseModel):
class PuzzleDataset (line 53) | class PuzzleDataset(IterableDataset):
method __init__ (line 54) | def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
method _load_metadata (line 68) | def _load_metadata(self) -> PuzzleDatasetMetadata:
method _lazy_load_dataset (line 72) | def _lazy_load_dataset(self):
method _collate_batch (line 95) | def _collate_batch(self, batch):
method _iter_test (line 118) | def _iter_test(self):
method _iter_train (line 151) | def _iter_train(self):
method __iter__ (line 189) | def __iter__(self):
FILE: utils/functions.py
function load_model_class (line 5) | def load_model_class(identifier: str, prefix: str = "models."):
function get_model_source_path (line 15) | def get_model_source_path(identifier: str, prefix: str = "models."):
Condensed preview — 25 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (134K chars).
[
{
"path": ".gitignore",
"chars": 3152,
"preview": "# WandB\n/wandb/\n# checkpoints\n/checkpoints/\n# cache\n/cache/\n# data\n/data/\n\n# Byte-compiled / optimized / DLL files\n__pyc"
},
{
"path": ".gitmodules",
"chars": 364,
"preview": "[submodule \"dataset/raw-data/ConceptARC\"]\n\tpath = dataset/raw-data/ConceptARC\n\turl = git@github.com:victorvikram/Concept"
},
{
"path": ".vscode/launch.json",
"chars": 778,
"preview": "{\n // Use IntelliSense to learn about possible attributes.\n // Hover to view descriptions of existing attributes.\n"
},
{
"path": ".vscode/settings.json",
"chars": 54,
"preview": "{\n \"python.analysis.typeCheckingMode\": \"standard\"\n}"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": "\n Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 7130,
"preview": "# Hierarchical Reasoning Model\n\n\n\nReasoning, the process of devising and executing complex goal-ori"
},
{
"path": "arc_eval.ipynb",
"chars": 9317,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": null,\n \"metadata\": {},\n \"outputs\": [],\n \"source\": "
},
{
"path": "assets/npyjs.js",
"chars": 5216,
"preview": "class npyjs {\n\n constructor(opts) {\n if (opts && !('convertFloat16' in opts)) {\n console.warn([\n "
},
{
"path": "config/arch/hrm_v1.yaml",
"chars": 349,
"preview": "name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1\nloss:\n name: losses@ACTLossHead\n loss_type: stablemax_cross_entr"
},
{
"path": "config/cfg_pretrain.yaml",
"chars": 492,
"preview": "# ARC training config\n\ndefaults:\n - arch: hrm_v1\n - _self_\n\nhydra:\n output_subdir: null\n\n# Data path\ndata_path: data/"
},
{
"path": "dataset/build_arc_dataset.py",
"chars": 10084,
"preview": "from typing import List, Optional, Tuple, Dict\nfrom dataclasses import dataclass\nfrom pathlib import Path\nimport os\nimpo"
},
{
"path": "dataset/build_maze_dataset.py",
"chars": 4461,
"preview": "from typing import Optional\nimport math\nimport os\nimport csv\nimport json\nimport numpy as np\n\nfrom argdantic import ArgPa"
},
{
"path": "dataset/build_sudoku_dataset.py",
"chars": 5753,
"preview": "from typing import Optional\nimport os\nimport csv\nimport json\nimport numpy as np\n\nfrom argdantic import ArgParser\nfrom py"
},
{
"path": "dataset/common.py",
"chars": 1381,
"preview": "from typing import List, Optional\n\nimport pydantic\nimport numpy as np\n\n\n# Global list mapping each dihedral transform id"
},
{
"path": "evaluate.py",
"chars": 2490,
"preview": "from typing import List\nimport yaml\nimport os\n\nimport torch\nimport torch.distributed as dist\n\nimport pydantic\nfrom omega"
},
{
"path": "models/common.py",
"chars": 1216,
"preview": "import math\n\nimport torch\nfrom torch import nn\n\n\ndef trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: f"
},
{
"path": "models/hrm/hrm_act_v1.py",
"chars": 12161,
"preview": "from typing import Tuple, List, Dict, Optional\nfrom dataclasses import dataclass\nimport math\n\nimport torch\nimport torch."
},
{
"path": "models/layers.py",
"chars": 5654,
"preview": "from typing import Tuple\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\ntry:\n from flash_attn_in"
},
{
"path": "models/losses.py",
"chars": 3804,
"preview": "from typing import Any, Tuple, Dict, Sequence, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import "
},
{
"path": "models/sparse_embedding.py",
"chars": 4364,
"preview": "from typing import Union\n\nimport torch\nfrom torch import nn\nimport torch.distributed as dist\nfrom torch.optim.optimizer "
},
{
"path": "pretrain.py",
"chars": 15607,
"preview": "from typing import Optional, Any, Sequence, List\nfrom dataclasses import dataclass\nimport os\nimport math\nimport yaml\nimp"
},
{
"path": "puzzle_dataset.py",
"chars": 7980,
"preview": "import os\nimport json\n\nimport numpy as np\nimport pydantic\n\nimport torch\nfrom torch.utils.data import IterableDataset, ge"
},
{
"path": "puzzle_visualizer.html",
"chars": 14113,
"preview": "<!DOCTYPE html>\n<html>\n<head>\n <meta charset=\"UTF-8\" />\n <title>ARC‐Converted Dataset Visualizer (Upload Local Folder)"
},
{
"path": "requirements.txt",
"chars": 100,
"preview": "torch\nadam-atan2\neinops\ntqdm\ncoolname\npydantic\nargdantic\nwandb\nomegaconf\nhydra-core\nhuggingface_hub\n"
},
{
"path": "utils/functions.py",
"chars": 516,
"preview": "import importlib\nimport inspect\n\n\ndef load_model_class(identifier: str, prefix: str = \"models.\"):\n module_path, class"
}
]
About this extraction
This page contains the full source code of the sapientinc/HRM GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 25 files (124.9 KB), approximately 31.4k tokens, and a symbol index with 109 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.