Repository: karpathy/nanochat
Branch: master
Commit: 5019accc5bc7
Files: 52
Total size: 606.5 KB
Directory structure:
gitextract_21_vu5fa/
├── .claude/
│ └── skills/
│ └── read-arxiv-paper/
│ └── SKILL.md
├── .gitignore
├── .python-version
├── LICENSE
├── README.md
├── dev/
│ ├── LEADERBOARD.md
│ ├── LOG.md
│ ├── estimate_gpt3_core.ipynb
│ ├── gen_synthetic_data.py
│ ├── generate_logo.html
│ ├── repackage_data_reference.py
│ └── scaling_analysis.ipynb
├── nanochat/
│ ├── __init__.py
│ ├── checkpoint_manager.py
│ ├── common.py
│ ├── core_eval.py
│ ├── dataloader.py
│ ├── dataset.py
│ ├── engine.py
│ ├── execution.py
│ ├── flash_attention.py
│ ├── fp8.py
│ ├── gpt.py
│ ├── loss_eval.py
│ ├── optim.py
│ ├── report.py
│ ├── tokenizer.py
│ └── ui.html
├── pyproject.toml
├── runs/
│ ├── miniseries.sh
│ ├── runcpu.sh
│ ├── scaling_laws.sh
│ └── speedrun.sh
├── scripts/
│ ├── base_eval.py
│ ├── base_train.py
│ ├── chat_cli.py
│ ├── chat_eval.py
│ ├── chat_rl.py
│ ├── chat_sft.py
│ ├── chat_web.py
│ ├── tok_eval.py
│ └── tok_train.py
├── tasks/
│ ├── arc.py
│ ├── common.py
│ ├── customjson.py
│ ├── gsm8k.py
│ ├── humaneval.py
│ ├── mmlu.py
│ ├── smoltalk.py
│ └── spellingbee.py
└── tests/
├── test_attention_fallback.py
└── test_engine.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .claude/skills/read-arxiv-paper/SKILL.md
================================================
---
name: read-arxiv-paper
description: Use this skill when asked to read an arxiv paper given an arxiv URL
---
You will be given a URL of an arxiv paper, for example:
https://www.arxiv.org/abs/2601.07372
### Part 1: Normalize the URL
The goal is to fetch the TeX source of the paper (not the PDF!). That URL always looks like this:
https://www.arxiv.org/src/2601.07372
Notice the /src/ in the url. Once you have the URL:
### Part 2: Download the paper source
Fetch the url to a local .tar.gz file. A good location is `~/.cache/nanochat/knowledge/{arxiv_id}.tar.gz`.
(If the file already exists, there is no need to re-download it).
### Part 3: Unpack the file in that folder
Unpack the contents into `~/.cache/nanochat/knowledge/{arxiv_id}` directory.
### Part 4: Locate the entrypoint
Every latex source usually has an entrypoint, such as `main.tex` or something like that.
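Parts 1-4 can be sketched in Python as below. This is a minimal sketch under stated assumptions, not a required implementation:

- The helper names (`normalize_arxiv_url`, `cache_paths`, `fetch_and_unpack`, `find_entrypoint`) are hypothetical.
- It assumes the `/src/` endpoint returns a (possibly gzipped) tar archive.
- It treats the first `.tex` file containing `\documentclass` as the entrypoint, which is only a heuristic.
- `extractall` trusts the archive contents, so only unpack sources you are willing to trust.

```python
import re
import tarfile
import urllib.request
from pathlib import Path
from typing import Optional, Tuple

# Matches /abs/, /pdf/ or /src/ arxiv URLs and captures the paper id (keeping any version suffix like v2).
ARXIV_URL_RE = re.compile(r"arxiv\.org/(?:abs|pdf|src)/([^/?#]+?)(?:\.pdf)?(?:[?#].*)?$")
KNOWLEDGE_DIR = Path.home() / ".cache" / "nanochat" / "knowledge"


def normalize_arxiv_url(url: str) -> Tuple[str, str]:
    """Part 1: return (arxiv_id, TeX source URL) for an abs/pdf/src URL."""
    m = ARXIV_URL_RE.search(url.strip())
    if m is None:
        raise ValueError(f"not an arxiv abs/pdf/src URL: {url!r}")
    arxiv_id = m.group(1)
    return arxiv_id, f"https://www.arxiv.org/src/{arxiv_id}"


def cache_paths(arxiv_id: str, root: Path = KNOWLEDGE_DIR) -> Tuple[Path, Path]:
    """Where the downloaded archive and its unpacked contents live."""
    return root / f"{arxiv_id}.tar.gz", root / arxiv_id


def fetch_and_unpack(url: str) -> Path:
    """Parts 2-3: download the source once, then unpack it next to the archive."""
    arxiv_id, src_url = normalize_arxiv_url(url)
    tar_path, out_dir = cache_paths(arxiv_id)
    tar_path.parent.mkdir(parents=True, exist_ok=True)
    if not tar_path.exists():  # skip the download if the archive is already cached
        urllib.request.urlretrieve(src_url, tar_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Assumption: the payload is a tar archive; "r:*" lets tarfile detect the compression.
    with tarfile.open(tar_path, "r:*") as tf:
        tf.extractall(out_dir)
    return out_dir


def find_entrypoint(src_dir: Path) -> Optional[Path]:
    """Part 4 heuristic: the first .tex file (in sorted order) that declares \\documentclass."""
    for tex in sorted(src_dir.rglob("*.tex")):
        if "\\documentclass" in tex.read_text(errors="ignore"):
            return tex
    return None
```

For example, `normalize_arxiv_url("https://www.arxiv.org/abs/2601.07372")` yields the id `2601.07372` and the `/src/` URL from Part 1. Some papers split the document across several `.tex` files via `\input`, so the heuristic entrypoint is only where reading starts.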
### Part 5: Read the paper
Once you've found the entrypoint, read its contents and then recurse through all other relevant source files to read the paper.
### Part 6: Report
Once you've read the paper, produce a summary of it in a markdown file at `./knowledge/summary_{tag}.md`. Note that you should:

1. Use the local knowledge directory here (it's easier for me to open and reference here), not the one in `~/.cache`.
2. Generate some reasonable `tag`, e.g. `conditional_memory` or whatever seems appropriate given the paper.

Make sure the tag doesn't already exist so you don't overwrite files.
As for the summary itself, remember that you're processing this paper within the context of the nanochat repository, so most often we will be interested in how to apply the paper and its lessons to the nanochat project. Therefore, feel free to "remind yourself" of the related nanochat code by reading the relevant parts. Then explicitly connect the paper to nanochat: how it might relate, and what we might be inspired by or try.
================================================
FILE: .gitignore
================================================
.venv/
__pycache__/
*.pyc
dev-ignore/
report.md
eval_bundle/
# Secrets
.env
# Local setup
CLAUDE.md
wandb/
================================================
FILE: .python-version
================================================
3.10
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2025 Andrej Karpathy
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# nanochat


nanochat is the simplest experimental harness for training LLMs. It is designed to run on a single GPU node, the code is minimal/hackable, and it covers all major LLM stages including tokenization, pretraining, finetuning, evaluation, inference, and a chat UI. For example, you can train your own GPT-2 capability LLM (which cost ~$43,000 to train in 2019) for only $48 (~2 hours on an 8XH100 GPU node) and then talk to it in a familiar ChatGPT-like web UI. On a spot instance, the total cost can be closer to ~$15. More generally, nanochat is configured out of the box to train an entire miniseries of compute-optimal models by setting a single complexity dial: `--depth`, the number of layers in the GPT transformer model (GPT-2 capability happens to be approximately depth 26). All other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) are calculated automatically in an optimal way.
For questions about the repo, I recommend either using [DeepWiki](https://deepwiki.com/karpathy/nanochat) from Devin/Cognition to ask questions about the repo, or use the [Discussions tab](https://github.com/karpathy/nanochat/discussions), or come by the [#nanochat](https://discord.com/channels/1020383067459821711/1427295580895314031) channel on Discord.
## Time-to-GPT-2 Leaderboard
Presently, the main focus of development is on tuning the pretraining stage, which takes the most compute. Inspired by the modded-nanogpt repo and to incentivise progress and community collaboration, nanochat maintains a leaderboard for a "GPT-2 speedrun", which is the wall-clock time required to train a nanochat model to GPT-2 grade capability, as measured by the DCLM CORE score. The [runs/speedrun.sh](runs/speedrun.sh) script always reflects the reference way to train a GPT-2 grade model and talk to it. The current leaderboard looks as follows:
| # | time (hours) | val_bpb | CORE | Description | Date | Commit | Contributors |
|---|-------------|---------|------|-------------|------|--------|--------------|
| 0 | 168 | - | 0.2565 | Original OpenAI GPT-2 checkpoint | 2019 | - | OpenAI |
| 1 | 3.04 | 0.74833 | 0.2585 | d24 baseline, slightly overtrained | Jan 29 2026 | 348fbb3 | @karpathy |
| 2 | 2.91 | 0.74504 | 0.2578 | d26 slightly undertrained **+fp8** | Feb 2 2026 | a67eba3 | @karpathy |
| 3 | 2.76 | 0.74645 | 0.2602 | bump total batch size to 1M tokens | Feb 5 2026 | 2c062aa | @karpathy |
| 4 | 2.02 | 0.71854 | 0.2571 | change dataset to NVIDIA ClimbMix | Mar 4 2026 | 324e69c | @ddudek @karpathy |
| 5 | 1.80 | 0.71808 | 0.2690 | autoresearch [round 1](https://x.com/karpathy/status/2031135152349524125) | Mar 9 2026 | 6ed7d1d | @karpathy |
| 6 | 1.65 | 0.71800 | 0.2626 | autoresearch round 2 | Mar 14 2026 | a825e63 | @karpathy |
The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. The GPT-2 CORE score is 0.256525. In 2019, the training of GPT-2 cost approximately $43,000 so it is incredible that due to many advances over 7 years across the stack, we can now do so much faster and for well below $100 (e.g. at the current ~$3/GPU/hr, an 8XH100 node is ~$24/hr, so 2 hours is ~$48).
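The dollar figures above are simple node-hour arithmetic. Below is a minimal sketch, assuming the ~$3/GPU/hr rate quoted in the text (actual cloud prices vary). `node_cost_usd` is a hypothetical helper, not part of nanochat:

```python
def node_cost_usd(hours: float, usd_per_gpu_hour: float = 3.0, num_gpus: int = 8) -> float:
    """Cost of renting a multi-GPU node for `hours` of wall-clock time (assumed flat hourly rate)."""
    return hours * usd_per_gpu_hour * num_gpus


print(f"${node_cost_usd(1.0):.2f} per hour for an 8XH100 node")  # 3 * 8 = $24.00
print(f"${node_cost_usd(2.0):.2f} for a ~2 hour run")            # → $48.00
print(f"${node_cost_usd(1.65):.2f} at the latest leaderboard time")
```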
See [dev/LEADERBOARD.md](dev/LEADERBOARD.md) for more docs on how to interpret and contribute to the leaderboard.
## Getting started
### Reproduce and talk to GPT-2
The most fun you can have is to train your own GPT-2 and talk to it. The entire pipeline to do so is contained in the single file [runs/speedrun.sh](runs/speedrun.sh), which is designed to be run on an 8XH100 GPU node. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
```bash
bash runs/speedrun.sh
```
You may wish to do so in a screen session as this will take ~3 hours to run. Once it's done, you can talk to it via the ChatGPT-like web UI. Make sure again that your local uv virtual environment is active (run `source .venv/bin/activate`), and serve it:
```bash
python -m scripts.chat_web
```
And then visit the URL shown. Make sure to access it correctly, e.g. on Lambda use the public IP of the node you're on, followed by the port, so for example [http://209.20.xxx.xxx:8000/](http://209.20.xxx.xxx:8000/), etc. Then talk to your LLM as you'd normally talk to ChatGPT! Get it to write stories or poems. Ask it to tell you who you are to see a hallucination. Ask it why the sky is blue. Or why it's green. The speedrun is a 4e19 FLOPs capability model so it's a bit like talking to a kindergartener :).
---
<img width="2672" height="1520" alt="image" src="https://github.com/user-attachments/assets/ed39ddf8-2370-437a-bedc-0f39781e76b5" />
---
A few more notes:
- The code will run just fine on the Ampere 8XA100 GPU node as well, but a bit slower.
- All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (code will automatically switch to gradient accumulation), but you'll have to wait 8 times longer.
- If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device-batch-size` in the scripts and reduce it until things fit, e.g. from 32 (default) to 16, 8, 4, 2, or even 1. Below that, you'll have to know a bit more about what you're doing and get more creative.
- Most of the code is fairly vanilla PyTorch, so it should run on anything that supports it (xpu, mps, etc.), but I haven't personally exercised all of these code paths, so there might be sharp edges.
## Research
If you are a researcher and wish to help improve nanochat, two scripts of interest are [runs/scaling_laws.sh](runs/scaling_laws.sh) and [runs/miniseries.sh](runs/miniseries.sh). See [Jan 7 miniseries v1](https://github.com/karpathy/nanochat/discussions/420) for related documentation. For quick experimentation (~5 min pretraining runs) my favorite scale is to train a 12-layer model (GPT-1 sized), e.g. like this:
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=12 \
--run="d12" \
--model-tag="d12" \
--core-metric-every=999999 \
--sample-every=-1 \
--save-every=-1
```
This uses wandb (run name "d12"), runs the CORE metric only on the last step, and neither samples nor saves intermediate checkpoints. I like to work in an iteration loop: change something in the code, re-run a d12 (or a d16, etc.), and see if it helped. To see if a run helps, I like to monitor the wandb plots for:
1. `val_bpb` (validation loss in vocab-size-invariant units of bits per byte) as a function of `step`, `total_training_time` and `total_training_flops`.
2. `core_metric` (the DCLM CORE score)
3. VRAM utilization, `train/mfu` (Model FLOPS utilization), `train/tok_per_sec` (training throughput)
See an example [here](https://github.com/karpathy/nanochat/pull/498#issuecomment-3850720044).
The important thing to note is that nanochat is written and configured around one single dial of complexity - the depth of the transformer. This single integer automatically determines all other hyperparameters (the width of the transformer, number of heads, learning rate adjustments, training horizons, weight decays, ...) so that the trained model comes out compute optimal. The idea is that the user doesn't have to think about or set any of this, they are simply asking for a smaller or bigger model using `--depth`, and everything "just works". By sweeping out the depth, you achieve the nanochat miniseries of compute optimal models at various sizes. GPT-2 capability model (which is of most interest at the moment) happens to be somewhere around d24-d26 range with the current code. But any candidate changes to the repo have to be principled enough that they work for all settings of depth.
## Running on CPU / MPS
The script [runs/runcpu.sh](runs/runcpu.sh) shows a very simple example of running on CPU or Apple Silicon. It dramatically shrinks the LLM being trained so that training fits into a reasonable time interval of a few tens of minutes. You will not get strong results this way.
## Precision / dtype
nanochat does not use `torch.amp.autocast`. Instead, precision is managed explicitly through a single global `COMPUTE_DTYPE` (defined in `nanochat/common.py`). By default this is auto-detected based on your hardware:
| Hardware | Default dtype | Why |
|----------|--------------|-----|
| CUDA SM 80+ (A100, H100, ...) | `bfloat16` | Native bf16 tensor cores |
| CUDA SM < 80 (V100, T4, ...) | `float32` | No bf16; fp16 available via `NANOCHAT_DTYPE=float16` (uses GradScaler) |
| CPU / MPS | `float32` | No reduced-precision tensor cores |
You can override the default with the `NANOCHAT_DTYPE` environment variable:
```bash
NANOCHAT_DTYPE=float32 python -m scripts.chat_cli -p "hello" # force fp32
NANOCHAT_DTYPE=bfloat16 torchrun --nproc_per_node=8 -m scripts.base_train # force bf16
```
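The auto-detection rule in the table can be read as a small decision function. This is an illustrative sketch, not the code in `nanochat/common.py`:

- The name `pick_compute_dtype`, its arguments, and the validation of the override are hypothetical.
- It returns dtype names as strings so the logic can be shown without importing PyTorch.
- In real code, the CUDA capability would come from something like `torch.cuda.get_device_capability()`.

```python
import os
from typing import Mapping, Optional, Tuple

VALID_DTYPES = {"float32", "bfloat16", "float16"}


def pick_compute_dtype(
    device_type: str,
    cuda_capability: Optional[Tuple[int, int]] = None,
    env: Optional[Mapping[str, str]] = None,
) -> str:
    env = os.environ if env is None else env
    # An explicit NANOCHAT_DTYPE always wins over auto-detection.
    override = env.get("NANOCHAT_DTYPE")
    if override:
        if override not in VALID_DTYPES:
            raise ValueError(f"unsupported NANOCHAT_DTYPE: {override!r}")
        return override
    # CUDA SM 80+ (Ampere and newer) has native bf16 tensor cores.
    if device_type == "cuda" and cuda_capability is not None and cuda_capability >= (8, 0):
        return "bfloat16"
    # Pre-Ampere CUDA, CPU and MPS default to fp32.
    return "float32"


print(pick_compute_dtype("cuda", (9, 0), env={}))  # H100 (SM 90) -> bfloat16
print(pick_compute_dtype("cuda", (7, 5), env={}))  # T4 (SM 75) -> float32
print(pick_compute_dtype("cuda", (7, 5), env={"NANOCHAT_DTYPE": "float16"}))  # forced -> float16
```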
How it works: model weights are stored in fp32 (for optimizer precision), but our custom `Linear` layer casts them to `COMPUTE_DTYPE` during the forward pass. Embeddings are stored directly in `COMPUTE_DTYPE` to save memory. This gives us the same mixed-precision benefit as autocast but with full explicit control over what runs in which precision.
Note: `float16` training automatically enables a `GradScaler` in `base_train.py` to prevent gradient underflow. SFT supports this too, but RL currently does not. Inference in fp16 works fine everywhere.
## Guides
I've published a number of guides that might contain helpful information, most recent to least recent:
- [Feb 1 2026: Beating GPT-2 for <<$100: the nanochat journey](https://github.com/karpathy/nanochat/discussions/481)
- [Jan 7 miniseries v1](https://github.com/karpathy/nanochat/discussions/420) documents the first nanochat miniseries of models.
- To add new abilities to nanochat, see [Guide: counting r in strawberry (and how to add abilities generally)](https://github.com/karpathy/nanochat/discussions/164).
- To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the SFT stage.
- [Oct 13 2025: original nanochat post](https://github.com/karpathy/nanochat/discussions/1) introducing nanochat, though now it contains some deprecated information and the model is a lot older (with worse results) than current master.
## File structure
```
.
├── LICENSE
├── README.md
├── dev
│ ├── gen_synthetic_data.py # Example synthetic data for identity
│ ├── generate_logo.html
│ ├── nanochat.png
│ └── repackage_data_reference.py # Pretraining data shard generation
├── nanochat
│ ├── __init__.py # empty
│ ├── checkpoint_manager.py # Save/Load model checkpoints
│ ├── common.py # Misc small utilities, quality of life
│ ├── core_eval.py # Evaluates base model CORE score (DCLM paper)
│ ├── dataloader.py # Tokenizing Distributed Data Loader
│ ├── dataset.py # Download/read utils for pretraining data
│ ├── engine.py # Efficient model inference with KV Cache
│ ├── execution.py # Allows the LLM to execute Python code as tool
│ ├── gpt.py # The GPT nn.Module Transformer
│ ├── logo.svg
│ ├── loss_eval.py # Evaluate bits per byte (instead of loss)
│ ├── optim.py # AdamW + Muon optimizer, 1GPU and distributed
│ ├── report.py # Utilities for writing the nanochat Report
│ ├── tokenizer.py # BPE Tokenizer wrapper in style of GPT-4
│ └── ui.html # HTML/CSS/JS for nanochat frontend
├── pyproject.toml
├── runs
│ ├── miniseries.sh # Miniseries training script
│ ├── runcpu.sh # Small example of how to run on CPU/MPS
│ ├── scaling_laws.sh # Scaling laws experiments
│ └── speedrun.sh # Train the ~$100 nanochat d20
├── scripts
│ ├── base_eval.py # Base model: CORE score, bits per byte, samples
│ ├── base_train.py # Base model: train
│ ├── chat_cli.py # Chat model: talk to over CLI
│ ├── chat_eval.py # Chat model: eval tasks
│ ├── chat_rl.py # Chat model: reinforcement learning
│ ├── chat_sft.py # Chat model: train SFT
│ ├── chat_web.py # Chat model: talk to over WebUI
│ ├── tok_eval.py # Tokenizer: evaluate compression rate
│ └── tok_train.py # Tokenizer: train it
├── tasks
│ ├── arc.py # Multiple choice science questions
│ ├── common.py # TaskMixture | TaskSequence
│ ├── customjson.py # Make Task from arbitrary jsonl convos
│ ├── gsm8k.py # 8K Grade School Math questions
│ ├── humaneval.py # Misnomer; Simple Python coding task
│ ├── mmlu.py # Multiple choice questions, broad topics
│ ├── smoltalk.py # Conglomerate dataset of SmolTalk from HF
│ └── spellingbee.py # Task teaching model to spell/count letters
├── tests
│ └── test_engine.py
└── uv.lock
```
## Contributing
The goal of nanochat is to improve the state of the art in micro models that are accessible to work with end to end on budgets of < $1000. Accessibility is about overall cost but also about cognitive complexity. nanochat is not an exhaustively configurable LLM "framework"; there are no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a ChatGPT model you can talk to. Currently, the most interesting part to me personally is reducing the time to GPT-2 (i.e. getting a CORE score above 0.256525). Currently this takes ~3 hours, but by improving the pretraining stage we can improve this further.
Current AI policy: disclosure. When submitting a PR, please declare any parts that had substantial LLM contribution and that you have not written or that you do not fully understand.
## Acknowledgements
- The name (nanochat) derives from my earlier project [nanoGPT](https://github.com/karpathy/nanoGPT), which only covered pretraining.
- nanochat is also inspired by [modded-nanoGPT](https://github.com/KellerJordan/modded-nanogpt), which gamified the nanoGPT repo with clear metrics and a leaderboard, and borrows a lot of its ideas and some implementation for pretraining.
- Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk.
- Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project.
- Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance.
- Thank you to the repo czar Sofie [@svlandeg](https://github.com/svlandeg) for help with managing issues, pull requests and discussions of nanochat.
## Cite
If you find nanochat helpful in your research cite simply as:
```bibtex
@misc{nanochat,
author = {Andrej Karpathy},
title = {nanochat: The best ChatGPT that \$100 can buy},
year = {2025},
publisher = {GitHub},
url = {https://github.com/karpathy/nanochat}
}
```
## License
MIT
================================================
FILE: dev/LEADERBOARD.md
================================================
# Leaderboard
Docs on participating in the "Time-to-GPT-2" leaderboard of nanochat.
The primary metric we care about is "time to GPT-2" - the wall clock time needed to outperform the GPT-2 (1.6B) CORE metric on an 8XH100 GPU node. Originally in 2019, GPT-2 was trained by OpenAI on 32 TPU v3 chips for 168 hours (7 days), at $8/hour/TPUv3 back then, for a total cost of approx. $43K. It achieves a CORE score of 0.256525. CORE is an ensemble metric introduced in the DCLM paper over 22 evaluations like ARC/MMLU/etc.
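As a quick sanity check, the ~$43K figure follows directly from the numbers in the paragraph above. No other assumptions are involved:

```python
tpu_chips = 32           # TPU v3 chips
hours = 168              # 7 days * 24 hours
usd_per_chip_hour = 8    # 2019 price per TPU v3 hour

total_usd = tpu_chips * hours * usd_per_chip_hour
print(total_usd)  # 43008, i.e. approx. $43K
```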
## How to
The script [runs/speedrun.sh](runs/speedrun.sh) always implements the current state of the art on the leaderboard.
In practice, I tune the base_train command a little bit. For example, once all the setup is configured and a tokenizer is trained, I like to do something like:
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=26 \
--run="d26-feb2-fp8-ratio8.25" \
--model-tag="d26_feb2_fp8_ratio8.25" \
--device-batch-size=16 \
--sample-every=-1 \
--save-every=-1 \
--core-metric-max-per-task=-1 \
--core-metric-every=999999 \
--target-param-data-ratio=8.25 \
--fp8
```
Note that:
- `depth` controls the size of the Transformer
- `run` is the wandb name
- `model-tag` is the location of the checkpoints on disk
- `device-batch-size`: in the ideal world, you want this to be 32, because with a sequence length of 2048 (the default) and 8 GPUs we get `32 X 2048 X 8 = 524,288`, which is the total desired batch size determined to work fairly well around this scale. However, for bigger models (e.g. d26), 32 is too much and OOMs, so we halve it to 16. The `base_train.py` script automatically compensates for this by calculating that it has to use gradient accumulation of 2 to meet the desired total batch size. It therefore does forward+backward twice and then a single optimizer step. Long story short, the ideal value is 32. If that doesn't fit, decrease it (e.g. to 16, 8, etc.), keeping it a power of two so that the gradient accumulation math works out neatly.
- `sample-every = -1` turns off periodic sampling
- `core-metric-max-per-task=-1` means we run the entire CORE eval
- `core-metric-every=999999` a bit of a hacky way to make the CORE eval only happen a single time at the very end of the run
- `target-param-data-ratio=8.25` controls the training horizon, which is determined in the script by taking the number of non-embedding model parameters and simply multiplying by this number. The current optimal Tokens:Params ratio can be seen in the defaults of the `base_train.py` script (it is 10.5). 10.5 would produce the *compute optimal* model given the currently measured scaling laws. However, GPT-2 capability is currently somewhere in between a d24 and d26. So to reach it exactly, we want to either overtrain d24 or undertrain d26. In this particular example, I am choosing to slightly undertrain a d26. Note that odd depths (e.g. d25) are not super recommended to use because the math around the transformer sizing and its head dimensions doesn't come out neatly.
- `--fp8` turns on fp8 training. If your GPU does not support fp8, you can leave this out and the code will simply train in bf16. Because bf16 is higher precision than fp8, you may be able to take fewer steps (i.e. lower the `target-param-data-ratio`) and still achieve the same capability.
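The batch-size and training-horizon arithmetic in the notes above can be sketched as follows. This is a simplified illustration, not the code in `base_train.py`:

- The function names are hypothetical.
- The 1B parameter count in the last example is made up. The script computes its own count from the model.

```python
def grad_accum_steps(total_batch_size: int, device_batch_size: int,
                     seq_len: int = 2048, world_size: int = 8) -> int:
    """Micro-batches each GPU runs (forward+backward) before one optimizer step."""
    tokens_per_micro_step = device_batch_size * seq_len * world_size
    if total_batch_size % tokens_per_micro_step != 0:
        raise ValueError("total batch size must be a multiple of device_batch_size * seq_len * world_size")
    return total_batch_size // tokens_per_micro_step


def num_iterations(num_params: int, target_param_data_ratio: float, total_batch_size: int) -> int:
    """Optimizer steps needed to consume ratio * params training tokens."""
    target_tokens = target_param_data_ratio * num_params
    return int(target_tokens // total_batch_size)


print(32 * 2048 * 8)                    # 524288 tokens per optimizer step
print(grad_accum_steps(524_288, 32))    # 1: no accumulation needed
print(grad_accum_steps(524_288, 16))    # 2: two forward+backward passes per step
print(grad_accum_steps(1_048_576, 16))  # 4: e.g. a 1M-token total batch
# Hypothetical 1B-parameter model at a Tokens:Params ratio of 10.5:
print(num_iterations(1_000_000_000, 10.5, 524_288))  # 20027
```

A non-power-of-two value such as `device_batch_size=24` gives 393,216 tokens per micro-step, which does not divide 524,288 evenly. This is why the notes recommend sticking to powers of two.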
Once you kick off the run, you wait ~3 hours and then at the end you'll see something like:
```
wandb: Run summary:
wandb: core_metric 0.25851
wandb: step 16704
wandb: total_training_flops 4.330784131228946e+19
wandb: total_training_time 10949.46713
```
Your CORE metric must be greater than GPT-2's 0.256525. Then report the `total_training_time` in seconds (e.g. 10949). This is the time of the training iterations alone, excluding all the evaluations and logging. So here, for example, it is roughly 10949/60/60 ~= 3.04 hours. You should also note and report the validation bpb of your run, because the CORE metric can be a little bit noisy.
If you outperform GPT-2 and the time is less than current SOTA in the Leaderboard, you get to make a PR. In addition to raw gains, there are some qualitative and aesthetic considerations that go into whether your improvement is merged. For example, if it is gnarly or it significantly bloats the code, or it seems too esoteric, then we will weigh those things against the improvement demonstrated. Additionally, nanochat cares not only about targeting a single model, but an entire miniseries of models. So your change must be principled enough that it can easily generalize to other model depths, so that we can sweep out a miniseries.
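The two numeric bars described above (beat GPT-2's CORE, and beat the current record time) can be checked mechanically. The qualitative review cannot. Below is a minimal sketch; `check_submission` is a hypothetical helper, not part of the repo:

```python
GPT2_CORE = 0.256525  # CORE score to beat


def check_submission(total_training_time_s: float, core_metric: float, sota_hours: float):
    """Return (hours rounded to 2 decimals, whether both numeric bars are cleared)."""
    hours = total_training_time_s / 3600  # wandb reports seconds
    qualifies = core_metric > GPT2_CORE and hours < sota_hours
    return round(hours, 2), qualifies


# The wandb summary above, with no prior record to beat:
print(check_submission(10949.46713, 0.25851, sota_hours=float("inf")))  # (3.04, True)
# Run 2's numbers against the 3.04 hour record:
print(check_submission(10493, 0.2578, sota_hours=3.04))  # (2.91, True)
```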
After you create the commit, get its short git commit hash with:
```
git log -1 --format="%h"
```
## Run 1
Achieved Jan 29 2026 on commit `348fbb3`. The launch command was
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=24 \
--run=d24-jan29 \
--model-tag=d24_jan29 \
--device-batch-size=16 \
--sample-every=-1 \
--save-every=-1 \
--core-metric-max-per-task=-1 \
--core-metric-every=3000 \
--target-param-data-ratio=12
```
The result was:
```
wandb: Run summary:
wandb: core_metric 0.25851
wandb: step 16704
wandb: total_training_flops 4.330784131228946e+19
wandb: total_training_time 10949.46713
```
The validation bpb was 0.74833.
Detailed writeup: [Beating GPT-2 for <<$100: the nanochat journey](https://github.com/karpathy/nanochat/discussions/481)
## Run 2
Achieved Feb 2 2026 on commit `a67eba3`. The launch command was
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=26 \
--run="d26-feb2-fp8-ratio8.5" \
--model-tag="d26_feb2_fp8_ratio8.5" \
--device-batch-size=16 \
--sample-every=-1 \
--save-every=-1 \
--core-metric-max-per-task=-1 \
--core-metric-every=999999 \
--target-param-data-ratio=8.5 \
--fp8
```
The result was:
```
core_metric 0.2578
step 14889
total_training_time 10493
Minimum validation bpb: 0.745036
```
The big change in this run is `--fp8`, which causes all Linear layers (other than the gates) to be switched to fp8 training using `torchao` with tensorwise fp8 scaling. Each step is of slightly lower quality, but we are taking them a lot faster, coming out net ahead. Anyone who does not have fp8 (e.g. using a GPU without it) can simply leave out the `--fp8` flag to train in bfloat16. This will work just fine but it will produce a slightly stronger model than GPT-2 because of the fp8 -> bf16 precision upgrade. It's possible that one can further tune which layers to include in the fp8 conversion and that e.g. some of the smaller matmuls should be just kept in bf16 etc.
Previous record was 3.04 hours, so 2.91 hours is `(3.04 - 2.91)/3.04*100` ~= 4.3% speed improvement.
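The same relative-speedup formula can be applied to each successive record. The helper name below is hypothetical. The times are the leaderboard entries in hours, taken from the runs documented in this file:

```python
def speedup_pct(prev_hours: float, new_hours: float) -> float:
    """Relative reduction in time-to-GPT-2, in percent."""
    return (prev_hours - new_hours) / prev_hours * 100


times = [3.04, 2.91, 2.76, 2.02, 1.80, 1.65]  # Runs 1-6, in hours
for prev, new in zip(times, times[1:]):
    print(f"{prev:.2f}h -> {new:.2f}h: {speedup_pct(prev, new):.1f}% faster")
```

The first line printed matches the ~4.3% computed above.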
## Run 3
Achieved Feb 5 2026 on commit `2c062aa`. Launch command:
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=26 \
--run="d26_feb4_double_batch_ratio8.25" \
--model-tag="d26_feb4_double_batch_ratio8.25" \
--device-batch-size=16 \
--total-batch-size=1048576 \
--sample-every=-1 \
--save-every=-1 \
--core-metric-max-per-task=-1 \
--core-metric-every=999999 \
--target-param-data-ratio=8.25 \
--fp8
```
Result:
```
core_metric 0.26024
step 7226
total_training_time 9922
Minimum validation bpb: 0.74645
```
The big change here is that the batch size was doubled from 0.5M to 1M, which works better for a d26 model and allowed me to decrease the number of optimization steps a bit via `--target-param-data-ratio` from 8.5 to 8.25. The TLDR is that the original batch size of 0.5M was tuned for d12, but bigger models (e.g. d26) prefer larger total batch size. I determined in experiments that d26 prefers 1M. Then I implemented and merged a principled way to calculate the optimal batch size given depth so that all nanochat models of all depths benefit. See [dev/LOG.md](dev/LOG.md) entry "2026-02-05: Auto Batch Size Scaling" for more detail.
## Run 4
Achieved Mar 3 2026 on commit `324e69c`. The big change is the switch from HuggingFace FineWeb-EDU to the NVIDIA ClimbMix dataset. `@karpathy` has tried to swap the dataset many times, each time with a negative result (FineWeb, DCLM, Olmo), but ClimbMix produced clear and immediate gains. Credit to `@ddudek` for originally discovering ClimbMix for nanochat and reporting the improvements, which kicked off the followup investigation.
To reproduce, use the commit above, download at least 150 data shards, and train the tokenizer:
```
python -m nanochat.dataset -n 150
python -m scripts.tok_train
```
Then kick off the run in the typical way, using a slightly lower than compute optimal ratio of 9.5 (vs compute optimal 10.5), meaning the d24 is slightly undertrained.
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=24 \
--run="d24-climbmix" \
--model-tag="d24-climbmix" \
--sample-every=-1 \
--save-every=-1 \
--core-metric-max-per-task=-1 \
--core-metric-every=999999 \
--target-param-data-ratio=9.5 \
--device-batch-size=16 \
--fp8
```
I ran this command 7 individual times. Because our training is mildly non-deterministic, we get a spread of CORE scores, e.g.:
```
0.25373
0.2584
0.25489
0.2568
0.25732
0.26765
0.25119
```
Mean is 0.25714 (higher than the GPT-2 threshold needed), and max-min is 0.01646.

Something to investigate in the future is that even slightly better results can be obtained by randomly shuffling the data shards (i.e. just going in a different order). This is unexpected because the documents were fully shuffled during data construction, so one would expect a relatively uniform data distribution. Indeed, the current default order is unfortunately among the worse ("unlucky") ones you can obtain with different shuffle seeds. It suffices to beat GPT-2 for now, so I am merging. TODO: investigate a bit more later.
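The summary statistics above can be reproduced from the 7 listed scores with Python's standard `statistics` module. The sample standard deviation is an extra statistic not quoted in the text. Note that the *mean* clears the GPT-2 threshold, but not every individual run does:

```python
import statistics

core_scores = [0.25373, 0.2584, 0.25489, 0.2568, 0.25732, 0.26765, 0.25119]  # the 7 runs above
GPT2_CORE = 0.256525

mean = statistics.mean(core_scores)
spread = max(core_scores) - min(core_scores)
stdev = statistics.stdev(core_scores)  # sample standard deviation
above = sum(s > GPT2_CORE for s in core_scores)

print(f"mean={mean:.5f} spread={spread:.5f} stdev={stdev:.5f}")  # mean=0.25714 spread=0.01646 ...
print(f"runs above GPT-2: {above}/{len(core_scores)}")  # 4/7
```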
NOTE: As of this run, the `val_bpb` is *NOT* comparable to the previous 3 runs, because the data distribution changed. This run happens to be at `0.71854` validation bpb. If the dataset is not changed, `val_bpb` is a great, smooth metric for tracking relative performance, and it has less noise than CORE.
## Run 5
Achieved Mar 9, 2026 on commit `6ed7d1d`. Exactly the same launch command as Run 4 except `--target-param-data-ratio=8.7`. I ran 5 identical runs, the average CORE was 0.2690, which is quite a bit above the needed threshold of 0.2565. But the reason I didn't decrease the ratio further (i.e. train shorter) is that while the CORE "safety gap" is large, the val_loss safety gap is smaller - 0.71808, which we want to be below the Run 4 val loss of 0.71854. It's likely that we could have reduced the ratio even lower, possibly to 8.6, but it's not worth splitting hairs at this point.
This commit is special because all of the improvements that went into [this commit](https://github.com/karpathy/nanochat/commit/6ed7d1d82cee16c2e26f45d559ad3338447a6c1b) came from fully autonomous "research" done by a private version of [autoresearch](https://github.com/karpathy/autoresearch) run on a d12 model. I wrote more about this in [this tweet](https://x.com/karpathy/status/2031135152349524125). The changes easily translated from d12 to d24, hence new leaderboard record, taking us from 2.02 hours "time to GPT-2" to 1.80 hours.
## Run 6
Achieved Mar 14, 2026 on commit `a825e63`. Exactly the same launch command as Run 4 except `--target-param-data-ratio=8`. Improvements in the architecture are allowing us to train for less and less time. Instead of an undertrained d24, I attempted to train an overtrained d22, but it was worse.

This set of changes came from autoresearch round 2, where I asked it to reference the modded-nanogpt repo for inspiration. The exploration tried out a number of ideas. In particular, it found a way to incorporate the backout and smear so that they are helpful (I had previously tried them manually a long time ago and they caused regressions).

The smear idea in particular is a little heavier and more bloated, because it is essentially an "early fusion" of context across tokens. It produces a kind of bigram input into the network, allowing it to focus on higher n-grams earlier. For this reason, the code gets a bit more complex and required some changes to inference. I verified with a unit test that the Engine inference is correct compared to the naive inference of `GPT.generate()`.

The average of 5 runs was CORE 0.262634, and each of them lasted 1.65 hours (99 minutes).
================================================
FILE: dev/LOG.md
================================================
# Experiment Log
A running summary documenting some experiments and findings. Started ~Jan 7 2026.
---
## 2026-03-04: Remove autocast, explicit dtype management, fp16 GradScaler
Replaced `torch.amp.autocast` throughout the codebase with explicit dtype management via a single `COMPUTE_DTYPE` global. Also added fp16 training support with GradScaler.
### Motivation
autocast is "magic we don't control" — it silently decides which ops run in which precision via internal allowlists. For this codebase, autocast was doing very little: the only thing it actually cast was `nn.Linear` weights from fp32 to bf16 for matmuls. `F.rms_norm`, `F.cross_entropy`, and Flash Attention all handle their own dtypes already. By making precision explicit, we gain fine-grained control (e.g. can experiment with fp32 norms) and eliminate an unnecessary layer of abstraction.
### What changed
**Core mechanism** (`nanochat/common.py`, `nanochat/gpt.py`):
- `COMPUTE_DTYPE` auto-detected from hardware: SM 80+ → bf16, pre-Ampere → fp32, CPU/MPS → fp32. Override via `NANOCHAT_DTYPE` env var.
- Custom `Linear(nn.Linear)` class that casts weights to match input dtype in forward: `F.linear(x, self.weight.to(dtype=x.dtype))`. This is the single mechanism that replaces autocast.
- Embeddings cast to `COMPUTE_DTYPE` at init (saves memory). Exception: fp16 keeps embeddings fp32 because GradScaler cannot unscale fp16 gradients.
- Embedding output explicitly cast to `COMPUTE_DTYPE` in `GPT.forward()` (no-op for bf16, active for fp16 path).
- RoPE cos/sin cache uses `COMPUTE_DTYPE` instead of hardcoded bf16.
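A sketch of the two pieces, under assumptions: `detect_compute_dtype` is a hypothetical helper, and the accepted `NANOCHAT_DTYPE` strings are guesses (the log only says the variable overrides detection). The `Linear` subclass follows the `F.linear(x, self.weight.to(dtype=x.dtype))` line above, with bias handling added for completeness. Weights stay in fp32, so the optimizer still updates full-precision master weights while the matmul runs in the input's dtype.

```python
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

def detect_compute_dtype():
    # Hypothetical accepted values; the log only says NANOCHAT_DTYPE overrides detection.
    names = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
    override = os.environ.get("NANOCHAT_DTYPE")
    if override:
        return names[override]
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability()
        return torch.bfloat16 if major >= 8 else torch.float32  # SM 80+ -> bf16
    return torch.float32  # CPU / MPS

class Linear(nn.Linear):
    """Keeps fp32 master weights; casts them to the input's dtype for the matmul."""
    def forward(self, x):
        bias = None if self.bias is None else self.bias.to(dtype=x.dtype)
        return F.linear(x, self.weight.to(dtype=x.dtype), bias)
```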
**Autocast removal** (11 files):
- Deleted `--dtype` CLI flag, `ptdtype` variables, `autocast_ctx` definitions, and all `with autocast_ctx:` blocks from: `base_train.py`, `chat_sft.py`, `chat_rl.py`, `chat_cli.py`, `chat_eval.py`, `chat_web.py`, `base_eval.py`, `engine.py`, `bench_train_toks.py`, `test_e2e_pipeline.py`.
**fp16 + GradScaler** (`base_train.py`, `chat_sft.py`):
- `scaler = torch.amp.GradScaler() if COMPUTE_DTYPE == torch.float16 else None`
- Backward: `scaler.scale(loss).backward()` vs plain `loss.backward()`
- After accumulation: `scaler.unscale_(optimizer)` → distributed inf-sync via `scaler._found_inf_per_device(optimizer)` all-reduced with `ReduceOp.MAX` → `scaler.step(optimizer)` → `scaler.update()`
- Zero overhead for bf16/fp32 paths (scaler is None, no branching inside kernels).
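The step ordering above, as a sketch. `train_step` and its MSE loss are hypothetical; `scaler._found_inf_per_device()` is the private PyTorch method the log names, so it may change between versions. The all-reduce has to sit between `unscale_()` (which fills the inf flags) and `scaler.step()` (which reads them), so that every rank skips or takes the step together.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def train_step(model, optimizer, micro_batches, scaler=None):
    """One optimizer step with gradient accumulation; scaler is a GradScaler only for fp16."""
    total = 0.0
    for x, y in micro_batches:
        loss = F.mse_loss(model(x), y) / len(micro_batches)  # hypothetical loss
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        total += loss.item()
    if scaler is not None:
        scaler.unscale_(optimizer)  # populates the per-device found_inf flags
        if dist.is_available() and dist.is_initialized():
            # private API: make every rank skip the step if any rank saw inf/nan
            for found_inf in scaler._found_inf_per_device(optimizer).values():
                dist.all_reduce(found_inf, op=dist.ReduceOp.MAX)
        scaler.step(optimizer)  # skips the update if found_inf is nonzero
        scaler.update()
    else:
        optimizer.step()  # bf16/fp32 path: no scaler involved
    optimizer.zero_grad(set_to_none=True)
    return total
```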
**FP8 fix** (`nanochat/fp8.py`, `base_train.py`):
- `Float8Linear.forward` explicitly casts input to `COMPUTE_DTYPE` (previously relied on autocast).
- `disable_fp8` context manager now creates our custom `Linear` (not vanilla `nn.Linear`) when swapping out Float8Linear during eval.
**Flash Attention** (`flash_attention.py`):
- FA3 Hopper kernels don't support fp16 or fp32, so `USE_FA3` (module-level constant, resolved once at import) evaluates to False for those dtypes, falling back to SDPA.
---
## 2026-03-04: Dataset upgrade: FineWeb-EDU 100B → ClimbMix 400B
Switched the pretraining dataset from FineWeb-EDU 100B to ClimbMix 400B. This is by far the single biggest improvement to nanochat's GPT-2 speedrun time, bringing it down from **2 hours 46 minutes to 2 hours 1 minute** — a 27% reduction.
### What is ClimbMix?
ClimbMix 400B is a curated 400B-token pretraining mixture hosted at `karpathy/climbmix-400b-shuffle` on HuggingFace. It comes from [NVIDIA](https://huggingface.co/datasets/nvidia/Nemotron-ClimbMix). It is a blend of high-quality web text, code, math, and other sources, designed to be a better general-purpose pretraining dataset than FineWeb-EDU alone.
### What changed
- **Dataset**: `karpathy/fineweb-edu-100b-shuffle` → `karpathy/climbmix-400b-shuffle` (up to 6543 shards available vs the previous 1823 data shards, allowing for longer training in the future)
- **Data directory**: `base_data/` → `base_data_climbmix/` (clean separation from legacy data)
- **Model depth**: d26 → d24. ClimbMix trains more efficiently, so a smaller model reaches GPT-2 capability
- **Shard count**: Only approx 150 data shards (~7B tokens) are now needed for GPT-2 capability
- **Eval tokens**: doubled from 40 to 80 batches for more stable validation loss estimates
- **Legacy fallback**: added a migration warning in `list_parquet_files()` that detects the old `base_data/` directory and falls back gracefully, so existing users see clear upgrade instructions on `git pull`
### Context
This is the sixth attempt at beating FineWeb-EDU on CORE score — the previous five all failed (see entries on 2026-02-17, 2026-02-10, 2026-01-12 below). ClimbMix is the first dataset to convincingly surpass it, and the margin is large enough to also shrink the model from d26 to d24.
---
## 2026-03-02: SoftCap tuning
Quick experiment to tune logit softcap on d24 scale. Tried 5..30. 5 was terrible, the rest of them were all about equal with the exception of 20, which was the best. Minor but solid improvement: val loss improved by ~1e-3 (0.716 -> 0.715). Setting as default.
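Logit softcapping is commonly implemented as `cap * tanh(logits / cap)`; the sketch below assumes that form (the log doesn't spell out nanochat's exact formula) with the new default `cap=20`. Small logits pass through almost unchanged while large ones saturate at ±cap, so with a cap of 5 a much larger share of logits lands in the saturated region.

```python
import torch

def softcap(logits, cap=20.0):
    # Smoothly bounds logits to (-cap, cap); approximately identity when |logits| << cap.
    return cap * torch.tanh(logits / cap)
```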
## 2026-02-19: Mixture of Experts (negative)
Implemented a DeepSeekV3-style Mixture of Experts layer as a drop-in replacement for the dense MLP. The MoE branch works and improves per-step validation loss, but is not a net improvement on wall clock time due to MoE overhead (at least for our scale of interest of approx GPT-2 capability).
### Implementation
Follows DeepSeekV3, using torchtitan as a reference:
- **8 routed experts, top-2 routing** with sigmoid gating (not softmax)
- **1 shared expert** (dense MLP processing all tokens, following DeepSeekV3)
- **Auxiliary-loss-free load balancing** (DeepSeekV3's expert bias nudging)
- **Iso-FLOP sizing**: `expert_hidden_dim = round(4 * dim / (top_k + num_shared) / 128) * 128`, so active FLOPs per token match the dense MLP
- **`torch._grouped_mm`** for dispatching tokens to experts in a single kernel (instead of a Python for-loop)
- **3D expert weight tensors** `(num_experts, hidden, dim)` — Muon's Polar Express operates on the last two dims, so each expert is independently orthogonalized
- **Active parameter counting** for scaling laws (only `top_k + shared` experts, not all 8)
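A sketch of the routing pieces listed above: iso-FLOP expert sizing, sigmoid top-2 routing, and DeepSeekV3-style bias nudging. Function names, the update rate `rate`, and the gate normalization are assumptions, and each expert is assumed to be a two-matrix MLP like the dense one; this is not the code from the MoE branch. Note that the bias only influences which experts are selected, not the gate weights.

```python
import torch

def expert_hidden_dim(dim, top_k=2, num_shared=1):
    # iso-FLOP: (top_k + num_shared) active experts together match a 4*dim dense hidden width
    return round(4 * dim / (top_k + num_shared) / 128) * 128

def route(x, router_w, expert_bias, top_k=2):
    # x: (tokens, dim); router_w: (num_experts, dim); expert_bias: (num_experts,)
    scores = torch.sigmoid(x @ router_w.T)            # sigmoid affinities, not softmax
    _, idx = torch.topk(scores + expert_bias, top_k)  # bias only affects expert selection
    gates = scores.gather(-1, idx)                    # gate weights use the unbiased scores
    gates = gates / gates.sum(-1, keepdim=True)       # normalize over the selected experts
    return idx, gates

@torch.no_grad()
def nudge_bias(expert_bias, idx, num_experts, rate=1e-3):
    # auxiliary-loss-free balancing: raise underloaded experts, lower overloaded ones
    load = torch.bincount(idx.flatten(), minlength=num_experts).float()
    expert_bias += rate * torch.sign(load.mean() - load)
    return expert_bias
```

For d12 (`dim=768`) this gives a hidden width of 1024, and 3 active experts × 2 × 768 × 1024 equals the dense 2 × 768 × 3072 parameter count.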
### What was easy
- The core MoE forward pass: router, sort tokens by expert, grouped matmul, scatter back. Conceptually clean.
- Shared expert: just an `nn.Linear` MLP that runs on all tokens alongside the routed path.
- 3D expert params + Muon: only required fixing `second_momentum_buffer` shape to preserve leading dims.
- Load balancing: DeepSeekV3's bias nudging is simple and effective (~10 lines).
### What was hard / ugly
- **`torch._grouped_mm` quirks**: requires bf16 (not fp32), column-major right operand, int32 cumulative offsets. The API is undocumented and only discoverable by trial and error.
- **Token count padding**: torchtitan pads each expert's token count to alignment multiples (8 for bf16) for better grouped_mm throughput. We implemented this with both a pure PyTorch approach and a copy of torchtitan's Triton kernel. Both compiled cleanly (0 graph breaks), but with ~65K tokens across 8 experts, each expert already gets ~8K tokens which is well-aligned. The padding overhead (gather/scatter) actually regressed MFU from 35% to 33%. Reverted.
- **FP8 + MoE**: `torch._grouped_mm` does NOT support FP8. There's a separate `torch._scaled_grouped_mm` API that requires per-row scaling (not per-tensor like our `Float8Linear`). The backward pass for weight gradients needs per-group column-wise scaling, which torchao implements with custom Triton kernels. We investigated thoroughly (see `dev/moe_fp8.md`) but did not implement — would require either depending on `torchao.prototype` (unstable) or writing ~200 lines of custom autograd + quantization code. Partial FP8 support exists: the shared expert's `nn.Linear` layers do get converted, but the routed experts (3D `nn.Parameter`) stay in bf16.
### Results
- d18: MFU dropped from ~46% to ~35% (the grouped_mm dispatch + token sorting overhead is significant)
- Per-step improvement in validation loss does not compensate for the throughput hit
- Net negative on wall clock time
### What remains (if revisited)
- **FP8 for routed experts**: Use `torch._scaled_grouped_mm` with a custom `_Float8GroupedMatmul` autograd function, with bf16 fallback for weight gradient (avoiding the per-group column-wise Triton kernels).
What's really needed is a fused "FlashMoE" kernel that handles routing + expert dispatch + matmul in one shot (like FlashAttention did for attention), with all the needed features. This doesn't exist yet. Rawdogging MoE with current PyTorch primitives is painful — lots of sorting, gathering, scattering, and layout wrangling around the actual compute.
### Verdict
MoE is not worth the trouble for nanochat right now. The code bloat is substantial (moe.py, router, shared expert, load balancing, optimizer fixes, FP8 gaps, active param counting) and the performance is worse wall-clock at our scale of interest. The fundamental issue is that the grouped_mm dispatch overhead eats the FLOP savings from sparsity, at least at our model scales and sequence lengths.
---
## 2026-02-17: Pretraining Data: FineWeb (negative)
Tried vanilla fineweb instead of fineweb-edu dataset. Significantly, shockingly worse results:
- d26 (GPT-2): CORE 0.2602 → 0.2241
This is the fifth failed attempt to beat pure FineWeb-EDU on CORE score.
---
## 2026-02-17: Pretraining Data Mixture Experiment (negative)
Tried [hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT](https://huggingface.co/datasets/hynky/finepdfs_50BT-dclm_30BT-fineweb_edu_20BT), a mixture of FinePDFs, DCLM, and FineWeb-EDU. Slightly worse on both model sizes tested:
- d26 (GPT-2): CORE 0.2602 → 0.2549
- d18: CORE 0.199 → 0.192
This is the fourth failed attempt to beat pure FineWeb-EDU on CORE score.
---
## 2026-02-16: SFT Script Upgrades
Brought `chat_sft.py` up to parity with `base_train.py` and tuned settings based on SFT sweeps.
Tuning:
- **Optimizer warm-start** (`--load-optimizer=1`, default on): loads pretrained momentum buffers via new `load_optimizer_state()` in `checkpoint_manager.py`. LRs are reset to fresh SFT values after load. Loading the optimizer works slightly better but not by too much.
- **LR schedule**: replaced "constant 80%, linear to 0" with warmup/constant/warmdown matching `base_train.py` (`--warmup-ratio`, `--warmdown-ratio`, `--init-lr-frac`, `--final-lr-frac`). Similar to pretraining, a warmdown ratio of 0.5 worked best. `--init-lr-frac` was lowered slightly from 1.0 to 0.8.
- **LR tuning**: attempted to tune all the individual LRs (e.g. does SFT prefer lower LR for embeddings? etc.) but all of this produced negative results.
- **Data mixture**: MMLU epochs 1→3, GSM8K epochs 2→4 (confirmed best from sweeps). Epoch counts now configurable via `--mmlu-epochs` / `--gsm8k-epochs`. Might remove these in the future though.
Quality of life, footguns, minor fixes:
- **Hyperparameter inheritance**: SFT now inherits batch sizes and LRs from the pretrained checkpoint metadata by default (CLI overrides still work). Also saved `total_batch_size` to `base_train.py` checkpoint metadata.
- **GC management**: disabled Python GC after step 1 to avoid ~500ms pauses (manual collect every 5000 steps), same as base pretraining.
- **ChatCORE eval**: periodic eval during SFT (`--chatcore-every=200`) across all 6 tasks, logged to wandb.
- **MFU**: uses `get_peak_flops()` for actual GPU instead of hardcoded H100 value.
- Removed `--dry-run` and `--dtype` flags. All ranks now participate in checkpoint save.
---
## 2026-02-05: Auto Batch Size Scaling
### Background
So far, the `--total-batch-size` was hardcoded to be `2**19 = 524,288` ~= 0.5M tokens. This was the optimal setting for d12, but when I tried to re-tune it for d26 (GPT-2), I noticed that the optimal was closer to `2**20 = 1,048,576` ~= 1M tokens. This is to be expected - larger models prefer a higher optimal total batch size. However, we have to make sure that all settings of `--depth` get their own optimal batch size calculated in some principled way. Here, I referenced the "Power Lines" paper from Cerebras ([arXiv:2505.13738](https://arxiv.org/abs/2505.13738)) for a lot of related experimentation. In particular, they found that **Bopt ∝ D^0.383** (where D is the number of training tokens, not the number of parameters!). So the idea is to tune the optimal batch size on d12, and then extrapolate it with this power law to bigger models. The 0.383 exponent means batch size grows slowly: 10× more tokens only justifies ~2.4× bigger batch. For nanochat's compute-optimal training (D ∝ N via `--target-param-data-ratio`), this means deeper models naturally want larger batches.
### Implementation
Added `--total-batch-size=-1` (now the default) to auto-compute optimal batch:
```python
get_scaling_params = lambda m: m.num_scaling_params()['transformer_matrices'] + m.num_scaling_params()['lm_head']
if args.total_batch_size == -1:
    D_REF = args.target_param_data_ratio * get_scaling_params(build_model_meta(12))
    B_REF = 2**19
    args.total_batch_size = 2 ** round(math.log2(B_REF * (target_tokens / D_REF) ** 0.383))
```
Reference point: d=12 model with B=2^19 (empirically validated). The reference is computed dynamically so that if the architecture changes (e.g., different `--aspect-ratio`), the math automatically adjusts. However, if the model actually does change too much, one would also want to re-tune the optimal batch size for d=12.
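The same rule as a standalone helper: `auto_batch_size` is a hypothetical name, and the reference token count is passed in directly instead of being derived from `build_model_meta(12)`, so it runs without a model. A 10× larger token budget multiplies the ideal batch by 10^0.383 ≈ 2.4, which after rounding in log2 space moves the batch from 2^19 to 2^20.

```python
import math

def auto_batch_size(target_tokens, ref_tokens, ref_batch=2**19, exponent=0.383):
    """B_opt ∝ D^0.383, anchored at a reference run and rounded to a power of two."""
    ratio = target_tokens / ref_tokens
    return 2 ** round(math.log2(ref_batch * ratio ** exponent))
```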
### Results
With this formula, we currently get:
| Depth | Scaling Params | Target Tokens | Auto Batch |
|-------|---------------|---------------|------------|
| d=8 | 42M | 0.44B | 2^18 = 262K |
| d=10-16 | 70M-235M | 0.7B-2.5B | 2^19 = 524K |
| d=18-26 | 324M-918M | 3.4B-9.6B | 2^20 = 1.05M |
| d=32-50 | 1.7B-6.2B | 17.6B-65.6B | 2^21 = 2.1M |
In particular, this matches empirical observations that d26 prefers ~2^20 while d12 prefers ~2^19.
### Code Cleanup
Also refactored model initialization to use `build_model_meta(depth)` helper and `dataclasses.asdict()` for cleaner config handling.
### Useful references
- [Bergsma et al., Power Laws for Batch Size, Model Size, and Training Horizon](https://arxiv.org/abs/2505.13738)
- [McCandlish et al., An Empirical Model of Large-Batch Training](https://arxiv.org/abs/1812.06162)
- [Brown et al., Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165)
- [Merrill et al., The Batch Size–Critical Batch Size Myth](https://arxiv.org/abs/2505.23971)
### One more thing (batch size ramp)
Tried batch size ramping. The simplest implementation I could think of "tricks" the existing training loop by slicing each micro-batch into smaller pieces and calling optimizer.step() more frequently early in training (1/8 → 1/4 → 1/2 → full batch over the first x% of training, with sqrt LR scaling). Also required a torch.compile warmup phase to pre-compile all slice sizes and avoid recompilation spikes during training. While the idea is sound and small gains were observed, they weren't sufficient to justify the code complexity introduced (conditional slicing logic, warmup with state save/restore, etc.). Not merged for now.
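A sketch of the (unmerged) ramp schedule and slicing trick. The log leaves the ramp length as "x%" of training, so it is a parameter `ramp_frac` here; the equal-length stages and both helper names are assumptions. The LR multiplier is the square root of the batch fraction, matching the sqrt LR scaling described above.

```python
import math

def ramp_schedule(progress, ramp_frac):
    """Hypothetical ramp: 1/8 -> 1/4 -> 1/2 -> full batch over the first ramp_frac of training.

    progress is the fraction of training completed, in [0, 1].
    Returns (batch_fraction, lr_scale); stages are assumed to be equal length.
    """
    stages = [1 / 8, 1 / 4, 1 / 2]
    if progress >= ramp_frac:
        return 1.0, 1.0
    frac = stages[int(progress / ramp_frac * len(stages))]
    return frac, math.sqrt(frac)  # sqrt LR scaling

def slice_micro_batch(x, batch_fraction):
    """Split a micro-batch so optimizer.step() can be called after each slice."""
    n = max(1, int(len(x) * batch_fraction))
    return [x[i:i + n] for i in range(0, len(x), n)]
```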
---
## 2026-02-05: SwiGLU Activation (Negative Result)
Replaced ReLU² MLP activation with SwiGLU (inspired by [twitter](https://x.com/_xjdr/status/2019141521690567058)). SwiGLU uses three projections instead of two, so to match parameters and FLOPs we scale hidden_dim from 4× to 8/3×:
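```python
import torch.nn as nn
import torch.nn.functional as F

# Old ReLU²: 2 matrices, 4x expansion (nn.Linear stands in for nanochat's Linear; bias-free)
# params: 2 × n × 4n = 8n²
# flops: 2 × 2n × 4n = 16n² per token
class ReLU2MLP(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=False)
        self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=False)
    def forward(self, x):
        return self.c_proj(F.relu(self.c_fc(x)).square())

# New SwiGLU: 3 matrices, 8/3x expansion
# params: 2 × n × (8n/3) + (8n/3) × n = 8n² ✓ matches
# flops: 3 × 2n × (8n/3) = 16n² per token ✓ matches
class SwiGLUMLP(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        hidden_dim = (8 * n_embd) // 3
        self.w1 = nn.Linear(n_embd, hidden_dim, bias=False)  # gate
        self.w2 = nn.Linear(n_embd, hidden_dim, bias=False)  # up
        self.w3 = nn.Linear(hidden_dim, n_embd, bias=False)  # down
    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))
```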
Tested at both d12 and d24 (GPT-2 scale). Worse on all measures — step efficiency, wall clock time, and FLOPs. ReLU² remains superior for nanochat. **Not adopted.**
---
## 2026-02-03: Flip Muon MLP LR Multiplier (PR #492)
Tested flipping the shape-based LR heuristic in Muon from boosting tall matrices (input projections like `c_fc`) to boosting wide matrices (output projections like `c_proj`). The original code applies `max(1, rows/cols)^0.5`, giving ~2x LR to `c_fc`. The flipped version gives ~2x LR to `c_proj` instead, which aligns with classical fan-in/fan-out scaling conventions. This was proposed in [PR #492](https://github.com/karpathy/nanochat/pull/492) and showed improvements in modded-nanogpt.
**Result:** Quick d12 experiment: slightly worse. **Not adopted.**
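The two heuristics side by side, assuming PyTorch's `(out_features, in_features)` weight layout, so `c_fc` is a tall `4n × n` matrix and `c_proj` a wide `n × 4n` one. `muon_lr_multiplier` is a hypothetical helper, and the flipped formula is one way to express the swap rather than the PR's exact code.

```python
def muon_lr_multiplier(rows, cols, flipped=False):
    """Shape-based LR multiplier for a (rows, cols) weight matrix."""
    if flipped:
        return max(1.0, cols / rows) ** 0.5  # flipped: boosts wide matrices (e.g. c_proj)
    return max(1.0, rows / cols) ** 0.5      # original: boosts tall matrices (e.g. c_fc)
```

For d12 (`n = 768`), the original rule gives `c_fc` a 2× LR and `c_proj` 1×; the flipped rule swaps them.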
---
## 2026-02-03: Skip AdamW Every Other Step
Inspired by modded-nanogpt, tried stepping AdamW only on odd iterations while Muon steps every iteration. The idea is that small AdamW params (embeddings, scalars, gates) don't need updates as frequently as the large weight matrices, and skipping saves both compute and communication.
Added `skip_adamw` parameter to `MuonAdamW.step()` and `DistMuonAdamW.step()` plus a matching `zero_grad(skip_adamw=...)` to let AdamW gradients accumulate over 2 steps. Used `lr *= 2**-0.5` (sqrt scaling) to compensate for the 2x effective batch size on AdamW params.
**Result:** for nanochat d12, we see ~2% faster tok/s, but each step is slightly worse in loss. On net, when plotting against wall clock time, it's slightly worse. **Not adopted.**
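A sketch of the skip schedule with plain PyTorch optimizers: SGD stands in for Muon, and `train_with_skipped_adamw` is a hypothetical helper rather than the `skip_adamw` argument added to `MuonAdamW.step()`. Because AdamW's gradients are only cleared after it steps, they accumulate across two iterations, which is the 2× effective batch that the `2**-0.5` LR scaling compensates for.

```python
import torch

def train_with_skipped_adamw(matrix_opt, adamw_opt, loss_fn, steps):
    """matrix_opt steps every iteration; adamw_opt only on odd ones, so its grads sum over 2 steps."""
    adamw_steps = 0
    for step in range(steps):
        loss_fn().backward()
        matrix_opt.step()
        matrix_opt.zero_grad(set_to_none=True)
        if step % 2 == 1:
            adamw_opt.step()
            adamw_opt.zero_grad(set_to_none=True)  # only now clear the 2-step accumulated grads
            adamw_steps += 1
    return adamw_steps
```

In use, the AdamW LR would be set to `base_lr * 2**-0.5`, as in the log.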
---
## 2026-02-02: FP8 Training with torchao
Integrated FP8 training using `torchao.float8` to accelerate Linear layer matmuls on H100 GPUs.
### Background
FP8 (8-bit floating point) uses H100's FP8 tensor cores for ~2x theoretical matmul throughput. The tradeoff is quantization overhead: computing scales and casting tensors to/from FP8. Still, as an example torchtitan (Meta's distributed training framework) reports 25-28% speedups with FP8 for some of their experiments.
**Previous attempt (Jan 2026):** FP8 on just `lm_head` following modded-nanogpt with custom ops → 1% speedup, +2GB memory. Failed due to fragile torch.compile interaction. But this experiment was also done on ~d12 scale back then instead of the bigger model that gets GPT-2 capability of approx d24.
**This attempt:** Use torchao's `convert_to_float8_training()` on ALL Linear layers, increase model size to d24. The core snippet is:
```python
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
config = Float8LinearConfig.from_recipe_name("tensorwise")
convert_to_float8_training(model, config=config)
```
But in practice it's more involved (see base_train.py).
### Results
**Microbenchmark (d26 MLP, 65536x1664 @ 1664x6656):**
| Method | Forward | Fwd+Bwd | Speedup |
|--------|---------|---------|---------|
| BF16 + compile | 2.00ms | 4.79ms | 1.00x |
| FP8 rowwise + compile | 1.84ms | 4.55ms | 1.08x |
| FP8 tensorwise + compile | 1.45ms | 4.06ms | **1.38x** |
| FP8 rowwise (no compile) | 2.89ms | 21.86ms | 0.23x ❌ |
torch.compile is MANDATORY. Without it, FP8 is 4x slower due to unfused scaling ops.
**Full training (d26):**
| Config | tok/sec | vs baseline |
|--------|---------|-------------|
| BF16 baseline | 630K | 1.00x |
| FP8 rowwise | 564K | 0.90x ❌ |
| FP8 tensorwise | 740K | **1.17x** ✓ |
Memory usage also decreases quite a bit, by ~9GB (activations stored as FP8 instead of BF16).
Seeing 17% speedup is encouraging but we're still not done yet because each step is now in lower precision and less powerful individually, so to make up for the precision drop we have to train longer. Empirically, running some sweeps overnight on d24 scale, I saw that the actual speedup (when you match performance) is closer to 5%. It's possible that our LLMs at ~d24 scale are still too small to confidently enjoy the speedups that come from fp8 for bigger models.
### Key Learnings
For nanochat at approximate scale of interest (~GPT-2 capability, ~d24):
1. **Tensorwise >> Rowwise** - Rowwise computes per-row scales, overhead exceeds benefit. Tensorwise uses one scale per tensor.
2. **Filter small layers** - Layers with dims not divisible by 16 must be skipped (FP8 hardware requirement)
3. **Larger models benefit more** - d12 was still slower with FP8; d26+ shows gains. Therefore, in some depths there is a benefit to fp8 and in some there isn't. Keeping it configurable for now, passed in via kwargs and default off.
4. **The effective, capability-matched speedup is lower still** - because each step is of slightly lower precision/quality.
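One way to apply the "filter small layers" rule is torchao's `module_filter_fn` argument, which receives each module and its fully qualified name. The helper names are hypothetical, and this is not nanochat's `base_train.py` logic, which the log notes is more involved; the converted model still needs `torch.compile` to be fast.

```python
import torch.nn as nn

def fp8_module_filter(module, fqn):
    """Return True for Linear layers whose dims are divisible by 16 (FP8 requirement)."""
    if not isinstance(module, nn.Linear):
        return False
    return module.in_features % 16 == 0 and module.out_features % 16 == 0

def convert_model_to_fp8(model):
    # Requires torchao; imported lazily so the filter above can be used on its own.
    from torchao.float8 import Float8LinearConfig, convert_to_float8_training
    config = Float8LinearConfig.from_recipe_name("tensorwise")
    convert_to_float8_training(model, config=config, module_filter_fn=fp8_module_filter)
    return model
```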
### Integration
Added `--fp8` flag to `base_train.py`, default recipe is "tensorwise", example of turning on:
```bash
torchrun --nproc_per_node=8 -m scripts.base_train --depth=24 --fp8
```
Uses tensorwise by default. Requires `torchao==0.15.0` (compatible with torch 2.9.1), which was added to dependencies.
**TLDR**: turning on fp8 for GPT-2 capability nanochat model gives approx +5% capability-matched speedup.
---
## 2026-01-29: Hyperball/MuonH Experiments (Negative Result)
Explored Hyperball optimization from [this post](https://psychedelic-sunstone-851.notion.site/Fantastic-Pretraining-Optimizers-and-Where-to-Find-Them-2-1-Hyperball-Optimization-2e924306e6f280e7a5ffee00eb40a0dd) (saved to `knowledge/muonh.md`). It constrains weights to a sphere of radius R (the initial norm): `W_{t+1} = R · Normalize(W_t - η·R · Normalize(u_t))`. Had to change a number of details in a branch, e.g. not using zero init for our projections (or the initial norm would be zero), keeping track of the initial norm, and switching Muon to MuonH for the update.
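A minimal sketch of the Hyperball update, assuming `Normalize` means Frobenius normalization of the whole matrix and that `u` is the raw optimizer update (e.g. Muon's orthogonalized direction). The helper names are hypothetical. The norm is pinned to `R` after every step, which also shows why zero-initialized projections don't work: with `R = 0` the update is degenerate.

```python
import torch

def normalize(t):
    return t / t.norm()  # Frobenius normalization (assumption)

def hyperball_step(W, u, lr, R):
    """W_{t+1} = R * Normalize(W_t - lr * R * Normalize(u_t)); keeps ||W|| == R."""
    return R * normalize(W - lr * R * normalize(u))
```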
Experiments on d12:
| Experiment | Result |
|------------|--------|
| MuonH for matrix params | Worse than baseline |
| MuonH + LR sweep (2.5e-3 to 1e-2) | Still worse |
| Added learnable RMSNorm scales (paper says γ preserves expressivity) | Still worse |
| Various RMSNorm init tweaks, e.g. 0 at init to residual | Still worse |
| AdamH for lm_head (paper recommends this) | Broken - loss plateaus (see below) |
| AdamH + learnable output scales | Still worse |
Could not outperform the baseline implementation. The article doesn't go into much detail on how exactly AdamH is applied to `lm_head`. The classifier layer has to be able to increase in magnitude to make more confident predictions over time. Tried a sensible version with an added 0-D learnable scalar, and also with RMSNorms with per-channel learnable scalars both pre and post resnet blocks.
**Result:** This was not an out-of-the-box win for nanochat, even after a mild attempt at tuning and debugging over a few hours. The idea itself is intuitively appealing. Might come back to it and try harder later.
---
## 2026-01-28: Reverted Bigram Hash Embeddings
Removed bigram embeddings (engram-lite) from the codebase. At larger scale (d25), the improvement was tiny and disappeared entirely when measured by wall clock time. It also bloated the VRAM used. The extra parameters and complexity aren't justified.
---
## 2026-01-27: Bigram Hash Embeddings (Engram-lite)
Explored N-gram memory modules inspired by the [DeepSeek Engram paper](https://arxiv.org/abs/2601.07372) and [modded-nanogpt PR #201](https://github.com/KellerJordan/modded-nanogpt/pull/201).
### Background
The Engram paper introduces "conditional memory" as a complement to MoE - using O(1) hash lookups to retrieve static N-gram patterns instead of reconstructing them through computation. Key insight: transformers waste early layers "simulating retrieval through computation" for patterns like named entities and formulaic phrases that could be simple table lookups.
### What We Tried
**1. Full Engram module with context-aware gating (paper design)**
```python
# Hash bigrams to retrieve embeddings, then gate with hidden state
e = embed(hash(prev_token, curr_token))
q = RMSNorm(h) # hidden state as query
k = RMSNorm(W_k @ e) # projected embedding as key
v = W_v @ e
α = sigmoid(q · k / √d) # scalar gate per position
output = α * v
```
- Injected after block 1 (paper found early injection optimal)
- Slight improvement, but quite a bit of complexity added.
**2. Early-layer only injection**
- Only inject bigram signal in first 4 layers (where paper claims static pattern offloading helps most)
- **Result:** Actually hurt performance. The model seems to need uniform injection across all layers.
**3. Trigrams**
- Extended to hash both 2-grams and 3-grams, concatenating embeddings
- **Result:** No improvement over bigrams alone. Dilutes capacity from more frequent 2-gram patterns.
**4. Bigram-only with x0-style injection (modded-nanogpt engram-lite approach)**
- Simple hash: `(36313 * curr) XOR (27191 * prev) mod table_size`
- Zero-init embedding table, learned per-layer lambdas
- Add to residual at every layer: `x = resid_λ[i]*x + x0_λ[i]*x0 + bigram_λ[i]*x0_bigram`
- **Result:** This simple approach works and provides a consistent improvement.
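A sketch of the x0-style injection loop. The `ResidualMixer` name and the initial values of `resid_lambdas` (1) and `x0_lambdas` (0) are assumptions; the `bigram_lambdas` init of 0.1 follows the hyperparameters listed below. `x0_bigram` is assumed to be already aligned to the full sequence length (the bigram hash yields one fewer position than the input).

```python
import torch
import torch.nn as nn

class ResidualMixer(nn.Module):
    """Before block i: x = resid_λ[i]*x + x0_λ[i]*x0 + bigram_λ[i]*x0_bigram."""
    def __init__(self, blocks, bigram_init=0.1):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        n = len(blocks)
        self.resid_lambdas = nn.Parameter(torch.ones(n))    # assumed init
        self.x0_lambdas = nn.Parameter(torch.zeros(n))      # assumed init
        self.bigram_lambdas = nn.Parameter(torch.full((n,), bigram_init))

    def forward(self, x0, x0_bigram):
        x = x0
        for i, block in enumerate(self.blocks):
            x = (self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
                 + self.bigram_lambdas[i] * x0_bigram)
            x = block(x)
        return x
```

With identity blocks, three layers simply add `0.3 * x0_bigram` to `x0`, which makes the per-layer accumulation easy to see.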
TL;DR: the winning approach follows modded-nanogpt's "engram-lite": simply add the following module and feed its output into the residual branch (gated by a per-layer learnable λ) before every single block:
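```python
import torch.nn as nn

class BigramEmbed(nn.Module):
    def __init__(self, vocab_size, embed_dim, table_multiplier=5):
        super().__init__()
        self.table_size = vocab_size * table_multiplier
        self.embed = nn.Embedding(self.table_size, embed_dim)
        nn.init.zeros_(self.embed.weight)  # zero-init, per the hyperparameters below
    def forward(self, idx):
        # hash (prev, curr) pairs; parenthesize the XOR so the modulo bounds the index
        h = ((36313 * idx[:, 1:]) ^ (27191 * idx[:, :-1])) % (self.table_size - 1)
        return self.embed(h)  # (B, T-1, D): the first token has no predecessor
```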
As for optimal hyperparameters:
- **Table size:** `vocab_size * 5` (~164K entries for 32K vocab). Swept a number of settings and 5 was optimal.
- **Injection:** Every layer via learned `bigram_lambdas` (init 0.1 was better than 0.0).
- **Normalization:** Also tried adding a `norm()` to the embeddings (mirroring the token embeddings), this was slightly worse.
- **Init:** Zero-init embedding, so starts as identity (tried small noisy init, it's worse)
- **Optimizer:** AdamW with same LR as token embeddings
### Key Learnings
1. **Gating didn't help at our scale.** The paper's context-aware gating mechanism (sigmoid dot-product gate) added parameters and complexity without improvement. modded-nanogpt found the same: "simple direct addition to the residual stream outperformed by a decent margin."
2. **Uniform injection beats early-only.** Despite the paper's finding that early layers benefit most, restricting injection to early layers hurt. The x0-style "add everywhere with learned lambda" pattern works better for our architecture/scale.
3. **Bigrams are sufficient.** Trigrams didn't help - the extra context doesn't pay for the diluted capacity.
4. **Scale matters.** The Engram paper's results are at 27B params with MoE. At our ~100M-1B scale, the simpler approach wins. The elaborate gating mechanism may become useful at larger scales where collision handling matters more.
### Parameters Added
For d12 model with `table_multiplier=5`:
- Bigram embedding: 32768 × 5 × 768 = ~126M params
- Per-layer lambdas: 12 scalars (negligible)
If you're keeping track, we now have *a lot* of parameters, a significant amount of them in embeddings (token embeddings, bigram embeddings, value embeddings). For example, for a d12 we now have:
```
Parameter counts:
wte : 25,165,824
bigram_embed : 125,829,120
value_embeds : 150,994,944
lm_head : 25,165,824
transformer_matrices : 84,935,808
scalars : 36
total : 412,091,556
```
In other words, only about a quarter of parameters are now weight projections and the vast majority is embedding tables.
Still, on all axes (steps, wall clock time, flops), this somewhat parameter-bloated architecture beats the baseline and will now become the default.
After adding the engram-lite, I re-ran the scaling laws to determine the new optimal tokens:params ratio. I swept FLOPs in the range 1e18..1e19, exponentially strided in 4 settings (1e18, 2e18, 5e18, 1e19). I looked at a number of ways of determining the effective parameter count for the purposes of the scaling laws. The results looked like this:
```
Kaplan-style (all projections including lm_head and no embeddings)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 110,678,115 1,241,505,403 11.2 0.8972
2e+18 167,797,457 1,785,336,422 10.7 0.8616
5e+18 250,650,865 2,642,234,152 10.8 0.8293
1e+19 381,758,347 3,806,871,243 10.3 0.7999
N \propto C^0.54, D \propto C^0.49
Chinchilla-style (all parameters, period.)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 416,320,605 1,232,157,011 3.0 0.8974
2e+18 560,239,841 1,763,669,281 3.2 0.8616
5e+18 741,495,903 2,629,909,368 3.6 0.8291
1e+19 988,644,331 3,884,841,895 4.0 0.7999
N \propto C^0.37, D \propto C^0.50
Transformer-only-style (only the projections inside the transformer)
Optimal configurations (from quadratic fits):
FLOPs Eff Params Tokens Ratio Val BPB
-----------------------------------------------------------------
1e+18 80,259,665 1,315,639,547 17.2 0.8966
2e+18 131,488,566 1,864,134,141 14.5 0.8622
5e+18 220,985,474 2,595,328,843 12.1 0.8302
1e+19 401,213,504 3,328,704,512 8.5 0.7994
N \propto C^0.70, D \propto C^0.41
```
Clearly, the Kaplan-style ratios are the most consistent and produce stable ~0.5 exponents for both params and tokens, meaning we can use a single fixed tokens:params ratio for compute-optimal models. This turns out to be about 10.5, which now becomes the new default.
---
## 2026-01-19 to 2026-01-22: Optimizer Hyperparameter Sweep
Ran ~320 experiments across 6 rounds, scaling from d12→d16→d20 to find optimal optimizer hyperparameters. Added granular per-component control to `setup_optimizers()` — separate LRs and betas for embedding, unembedding, value_embeds, resid_lambdas, x0_lambdas, and Muon matrix params.
### What We Swept
- Learning rates for all 6 parameter groups
- Beta1/beta2 for all 5 AdamW groups
- Muon momentum (start/end), weight decay
- Hundreds of combinations (2-way, 3-way, 4-way, etc.)
### The Journey
**At d12**, found two independent improvement routes:
- **Route A:** emb_lr↑ (0.3→0.4), weight_decay↑ (0.1→0.15), matrix_lr↑ (0.02→0.025)
- **Route B:** x0_lr↓ (0.5→0.2), x0_beta1↑ (0.8→0.9+)
Both gave ~0.002 improvement, but combining them caused conflicts. Fine-tuning found wd=0.13, matrix_lr=0.027, emb_lr=0.38 helped slightly. Best d12 config: Route A + x0_beta1=0.95.
**At d16**, Route B became competitive with Route A. The routes still conflicted when combined.
**At d20** (target scale), everything changed:
- Fine-tuned values from d12 **actively hurt** performance
- Routes no longer conflicted
- Just `x0_beta1=0.96` alone captured nearly all the gains
### Final x0_beta1 Sweep at d20
| x0_beta1 | val/bpb | Δ vs baseline |
|----------|---------|---------------|
| **0.96** | **0.7971** | **-0.0007** |
| 0.94 | 0.7972 | -0.0006 |
| 0.90 | 0.7972 | -0.0006 |
| 0.97 | 0.7977 | -0.0001 |
| 0.98 | 0.8011 | +0.0033 💀 |
Flat plateau from 0.90-0.96, then sharp cliff at 0.97+.
### Key Learnings
1. **Hyperparameters are scale-dependent.** What works at d12 doesn't transfer to d20. The elaborate fine-tuning that won at d12 actively hurts at d20.
2. **Improvement magnitude shrinks with scale.** ~0.002 at d12 → ~0.0007 at d20. The baseline is already better-tuned for larger models.
3. **Sharp cliffs exist.** x0_beta1=0.98 is catastrophic while 0.96 is optimal.
4. **Don't over-tune on small proxies.** Validate at target scale before shipping.
### Final Recommendation
For production d20 runs, add one flag:
```
--x0-lambdas-beta1=0.96
```
Skip everything else discovered at smaller scales.
---
## 2026-01-18: More various experiments
- Tried Muon custom kernels for XXT and all the others. The improvement was there for targeted tests (~20%) but washed out completely to noise in an actual training run, especially because the Muon compute is split across all the workers. Abandoned due to complexity bloat.
- Fuse Q,K,V,O nn.Linear layers into a single QKVO Linear layer. ~Zero impact
- Tried the `sa_lambdas` that gate QKV and O. Slightly confused because of the use of rmsnorm, which erases the effect of any scalar multiplier. Helped a tiny bit (~1e-4 of loss), abandoned to control complexity.
---
## 2026-01-17: Various experiments
Modded-nanogpt uses [Value Embeddings](https://arxiv.org/abs/2410.17897) (VEs) in a funny U-shaped structure, 3 of them in total and with gates. I tried a large number of tweaks on this today:
- Placement: VEs at every layer, at alternating layers, U-shaped, front and back. Alternating layers worked best, i.e. we end up with *a lot* more VEs than modded-nanogpt, one at every other layer.
- Many parameter-sharing ideas to reduce the new parameter count. All failed.
- Many ideas to reduce the parameter count (low-rank decompositions, projections). The LLM hates all of them. All failed.
- Gated yes or no and how much. Gate helps.
Long story short, the models *love* Value Embeddings. They are a way to add a huge amount of capacity (parameters) to the model at almost zero FLOPs cost, because these embeddings are simply added to the Values tensor. Every attempt to reduce the capacity of value embeddings (param sharing, low rank, projections) fails. The model wants many of them, with all their capacity, and doing so wins on every x axis: steps, FLOPs and wall clock. I re-ran the scaling laws and, because the models are now very parameter-bloated, the optimal tokens:params ratio has halved from 8 to 4! That is far below Chinchilla's 20 at this point.
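To make the "almost zero FLOPs" point concrete, here is a minimal sketch of one attention layer's values with a value embedding added, using plain Python lists instead of tensors. The sigmoid gate computed from `gate_w` is a hypothetical stand-in (the log only says that a gate helps, not what form it takes), and `values_with_ve` is an illustrative name. The only per-position cost beyond the usual value projection is a table lookup and an add, while the table itself adds `vocab_size × d_v` parameters per VE layer.

```python
import math


def values_with_ve(x, token_ids, w_v, ve_table, gate_w):
    """Value vectors for one attention layer, with a gated value embedding added.

    x:         list of hidden vectors (one per position), each of length d_model
    token_ids: input token id at each position
    w_v:       value projection, d_model rows x d_v columns
    ve_table:  value embedding table, one d_v vector per vocab entry
    gate_w:    hypothetical gate weights (length d_model) for a per-position sigmoid gate
    """
    d_v = len(w_v[0])
    values = []
    for h, tok in zip(x, token_ids):
        # Standard value projection: the only matmul in this function.
        v = [sum(h[i] * w_v[i][j] for i in range(len(h))) for j in range(d_v)]
        # Hypothetical gate in (0, 1), computed from the hidden state.
        g = 1.0 / (1.0 + math.exp(-sum(hi * gi for hi, gi in zip(h, gate_w))))
        # Value embedding: a lookup by token id, then an elementwise add.
        ve = ve_table[tok]
        values.append([vj + g * ej for vj, ej in zip(v, ve)])
    return values


x = [[1.0, 0.0]]
w_v = [[1.0, 0.0], [0.0, 1.0]]       # identity projection
ve_table = [[0.0, 0.0], [2.0, 4.0]]  # vocab of 2 tokens, d_v = 2
print(values_with_ve(x, [1], w_v, ve_table, gate_w=[0.0, 0.0]))  # [[2.0, 2.0]]
```

With `gate_w` at zero the gate is 0.5, so half of the looked-up embedding is added to the projected value.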
Other experiments, looking at val/bpb as a function of all of steps, flops and wall clock time:
- Aspect ratio of 128 is worse than 64: in a sweep at fixed FLOPs (1e18), 64 outperforms. The LLM prefers to be slightly thinner and deeper.
- Head dim definitely prefers to be 128 instead of 64, i.e. fewer, bigger heads.
- Bunch of other random stuff like that.
Keeping all of this work on a private branch for now but hope to push shortly.
---
## 2026-01-17: Modded-nanogpt Ideas Sweep (Continued)
Continued testing ideas from modded-nanogpt.
| Idea | Result | Notes |
|------|--------|-------|
| Attention gates | No improvement | Per-head learnable gates on attention output. +1GB memory, decreased efficiency. |
| Batch size schedule | Abandoned | 8→16→24 with LR scaling. Made training script too bloated/complex, not worth cognitive overhead. |
| Value embeddings | Helps a lot | Experiments still ongoing, more on this later. |
---
## 2026-01-16: Flash Attention 3 Fallback to SDPA
Added automatic fallback from Flash Attention 3 to PyTorch's `scaled_dot_product_attention` (SDPA) for users without Hopper GPUs. This enables nanochat to run on older CUDA GPUs, CPU, and MPS (Apple Silicon).
### Implementation
Created `nanochat/flash_attention.py` - a unified interface that:
- Detects FA3 availability at import time (requires sm90+ / Hopper)
- Exports a `flash_attn` object matching FA3's API exactly (`flash_attn.flash_attn_func`, `flash_attn.flash_attn_with_kvcache`)
- Automatically routes to FA3 or SDPA based on hardware
- Handles tensor layout differences: FA3 uses (B, T, H, D), SDPA uses (B, H, T, D)
- Implements sliding window attention via explicit masks for SDPA
- Manages KV cache manually for SDPA (FA3 does it in-place)
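On the SDPA path a sliding window has to be spelled out as an explicit boolean mask. A minimal sketch in plain Python, assuming FA3's `window_size=(left, 0)` convention (query `i` may attend to keys `j` with `i - left <= j <= i`); `sliding_window_mask` is a hypothetical helper. A real implementation would build this as a tensor and pass it as `attn_mask` to `scaled_dot_product_attention`, where `True` means the position takes part in attention.

```python
def sliding_window_mask(seq_len, left):
    """Causal sliding-window mask: mask[i][j] is True if query i may attend to key j.

    Keys from i - left up to i (inclusive) are visible. A left window of
    seq_len - 1 or more degenerates to a plain causal mask.
    """
    return [[(j <= i) and (i - j <= left) for j in range(seq_len)] for i in range(seq_len)]


for row in sliding_window_mask(5, left=2):
    print("".join("1" if allowed else "." for allowed in row))
# 1....
# 11...
# 111..
# .111.
# ..111
```

Note that SDPA still evaluates every query/key score and only discards the masked ones, so this reproduces FA3's outputs without its compute savings.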
### Changes to Existing Files
Changes to existing code were intentionally kept extremely minimal.
**gpt.py**: Only the import line changed and a comment
**engine.py**: Zero changes needed
**base_train.py**: Added status print and warnings:
- Prints whether FA3 or SDPA fallback is being used
- Warns about efficiency loss without FA3
- Warns about sliding window support if `--window-pattern` is not "L"
### Testing
Tests are split into two classes due to dtype/device constraints:
1. **TestFA3VsSDPA**: Comparison tests requiring Hopper GPU + bfloat16. Run both implementations on identical inputs and verify outputs match (max diff typically 0, at most ~0.004 for sliding window).
2. **TestSDPAOnly**: SDPA-only tests that run on any device with appropriate dtype. Verify forward pass, backward pass, and KV cache work correctly.
Added `_override_impl` mechanism for testing - can force 'fa3' or 'sdpa' to directly compare implementations.
### Notes
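- The SDPA fallback is significantly slower than FA3, especially because it lacks efficient sliding window support: SDPA emulates the window with an explicit mask, so it still computes full attention and gets none of the savings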
- Recommend `--window-pattern L` (full context) when using SDPA fallback
---
## 2026-01-16: Modded-nanogpt Ideas Sweep (Mostly Negative)
Tested several architectural ideas from modded-nanogpt to see if they transfer to nanochat. None of them helped:
| Idea | Result | Notes |
|------|--------|-------|
| Half-truncated RoPE | No improvement | Only first half of head dims get RoPE (base 1024, linspace). Second half "stationary". |
| Asymmetric softcap | Slightly worse | `23 * sigmoid((x+5)/7.5)` vs our symmetric `15 * tanh(x/15)`. May only help with FP8. |
| Smear gate | Negligible | Blend each token with predecessor via learned gate. Tiny improvement not worth n_embd² params. |
| Backout | No improvement | Save activations at ~60% through network, subtract scaled version at end. |
| Skip connection | Slightly worse | Save at layer ~25%, add at layer ~50%. Also +2GB memory from storing activations. |
Value Embeddings do show promise. I need a more elaborate exploration of a few related ideas, which I leave for tomorrow.
---
## 2026-01-15: Olmo pretraining mix (Negative result)
I attempted to train on the Olmo 3 pretraining dataset [allenai/dolma3_mix-6T](https://huggingface.co/datasets/allenai/dolma3_mix-6T) instead of FineWeb-edu. I ran into a number of [errors and issues](https://huggingface.co/datasets/allenai/dolma3_mix-6T/discussions/2) trying to both download and process the dataset, and then noticed some quality issues (e.g. some documents seem to be extremely short, like "5"). I worked around these with some sensible hacks (e.g. rejecting documents shorter than 100 characters), processed the dataset exactly as FineWeb, re-trained the tokenizer and trained a d16 model. The CORE score decreased from 15.5 to 13.8, i.e. the result is quite a bit worse.
I am still looking to try the [DCLM dataset](https://arxiv.org/abs/2406.11794), which according to the paper should be better than FineWeb-edu. I do have some concerns that the same group both prepared the DCLM dataset *and* introduced the CORE score, so I'm a bit hesitant in case there was some overfitting to a CORE-adjacent data distribution.
Classifying as negative result and reverting back to FineWeb-edu for now.
---
## 2026-01-13: Varlen Attention (Negative Result)
Attempted to prevent attention from "leaking" across document boundaries using Flash Attention's `flash_attn_varlen_func`, similar to modded-nanogpt's approach.
### Background
With the BOS-aligned dataloader, multiple documents are packed into each row. Standard attention allows tokens to attend across document boundaries within a row. The hypothesis was that preventing this "leakage" via varlen attention might improve training.
### Approach: Compute cu_seqlens from inputs
- Find BOS positions: `(inputs.view(-1) == bos_token_id).nonzero()`
- Gotcha 1: Variable-length `cu_seqlens` caused torch.compile recompilation (25s/iter!) - fixed by padding to fixed size
- Gotcha 2: `nonzero()` inside compiled model hit recompile limit - fixed by moving computation outside compiled region
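A minimal sketch of the cu_seqlens computation on a flattened token list, in plain Python rather than tensors. The function name and the `max_docs` padding size are hypothetical. Padding by repeating the final boundary adds zero-length documents, which is one way to get the fixed size mentioned in Gotcha 1; in the actual experiment this ran on tensors outside the compiled region.

```python
def cu_seqlens_from_bos(tokens, bos_id, max_docs):
    """Cumulative sequence boundaries for a flattened (B*T) token stream.

    Every BOS token starts a new document and the total length closes the last
    one. The result is padded to a fixed length (max_docs + 1) by repeating the
    final boundary (i.e. adding empty documents), so the shape stays static.
    """
    starts = [i for i, t in enumerate(tokens) if t == bos_id]
    if not starts or starts[0] != 0:
        starts = [0] + starts  # a leading partial document counts as its own sequence
    bounds = starts + [len(tokens)]
    if len(bounds) > max_docs + 1:
        raise ValueError("more documents than max_docs allows")
    return bounds + [len(tokens)] * (max_docs + 1 - len(bounds))


BOS = 0
flat = [BOS, 5, 6, 7, BOS, 8, BOS, 9]  # two BOS-aligned rows of T=4, flattened
print(cu_seqlens_from_bos(flat, BOS, max_docs=5))  # [0, 4, 6, 8, 8, 8]
```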
### Final Results (d16)
| Metric | Baseline | Varlen |
|--------|----------|--------|
| val_bpb | 0.85427 | 0.85407 |
| MFU | ~same | ~same |
| tok/sec | ~same | ~same |
Essentially identical. The 0.0002 bpb improvement is within noise.
### Conclusion
Not worth the code complexity. The "leakage" across document boundaries within a row is not harmful - the model handles it fine. The BOS-aligned dataloader already provides the key benefit (every row starts with proper context). Not merging to master.
---
## 2026-01-13: BOS-Aligned Dataloader with Bin Packing
Redesigned the pretraining and midtraining dataloader to ensure every sequence starts with a BOS token, and explored bin-packing algorithms to minimize wasted tokens.
### Problem Statement
The original dataloader streams tokens into a flat buffer and reshapes into batches. This means some rows start mid-document (no BOS), which could confuse the model during training. We want every row to start with BOS and contain well-formed documents.
### Approach 1: Greedy-Crop BOS (Simple)
Each row is built independently:
- Start with a document (which has BOS prepended)
- Pack more documents until row is full
- If a document doesn't fit, **crop it** to fill remaining space (discard the rest)
- 100% utilization (no padding), but wastes cropped tokens
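A minimal sketch of Greedy-Crop over document lengths (BOS included), assuming documents arrive in stream order and each row holds `row_len` tokens. `greedy_crop_rows` is a hypothetical helper that tracks token counts rather than actual tokens.

```python
def greedy_crop_rows(doc_lengths, row_len):
    """Pack documents into rows of row_len tokens, cropping whatever does not fit.

    Documents are taken in order; a document that overflows the current row is
    cropped to fill it exactly and its remainder is discarded, so every new row
    starts at the beginning of a document (with its BOS). Returns (rows, cropped),
    where each row lists (doc_index, tokens_used).
    """
    rows, row, space, cropped = [], [], row_len, 0
    for i, n in enumerate(doc_lengths):
        used = min(n, space)
        row.append((i, used))
        cropped += n - used
        space -= used
        if space == 0:
            rows.append(row)
            row, space = [], row_len
    return rows, cropped  # a trailing, partially filled row is dropped


rows, cropped = greedy_crop_rows([5, 4, 6, 3], row_len=8)
print(rows, cropped)  # [[(0, 5), (1, 3)], [(2, 6), (3, 2)]] 2
```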
### Waste Analysis
Measured token waste empirically on real data (T=2048):
- **39.4% of tokens are cropped** (discarded when docs don't fit)
- **22.9% is the theoretical minimum** (tokens in docs longer than T+1 that can never fit)
- The extra ~16.5% comes from "unlucky" cropping when a long doc starts near the end of a row
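The theoretical minimum can be computed directly from document lengths. A sketch, assuming each row holds T + 1 tokens (T inputs plus one extra for the shifted targets), which is how the "longer than T+1" bound above reads; `min_crop_fraction` is a hypothetical helper.

```python
def min_crop_fraction(doc_lengths, T):
    """Lower bound on the cropped fraction when every row holds T + 1 tokens.

    A document longer than T + 1 tokens cannot fit in any single row, so at
    least its excess beyond T + 1 is discarded, whatever the packing strategy.
    """
    row = T + 1
    excess = sum(max(0, n - row) for n in doc_lengths)
    return excess / sum(doc_lengths)


print(min_crop_fraction([1000, 3000, 5000], T=2048))  # 3902 / 9000, about 0.4336
```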
### Bin Packing Algorithms Explored
| Algorithm | Util% | Crop% | Pad% | Notes |
|-----------|-------|-------|------|-------|
| Greedy-Crop (baseline) | 100% | 39.4% | 0% | Simple, no wasted compute |
| Greedy-Pad | 78% | 23.0% | 22% | Pads instead of crops - wastes compute |
| First-Fit Decreasing (FFD) | 99.7% | 23.0% | 0.3% | Near-optimal packing, minimal padding |
| **BestFit-Crop** | 100% | 34.6% | 0% | Smart cropping, no padding |
### BestFit-Crop Algorithm
A middle ground that maintains 100% utilization while reducing cropping:
1. Buffer N documents
2. For each row, greedily pick the **largest doc that fits entirely**
3. Repeat until nothing fits
4. When nothing fits, crop a doc to fill remaining space exactly
This avoids "unlucky" crops by searching the buffer for better-fitting documents.
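A minimal sketch of building one BestFit-Crop row from a buffer of document lengths. Two assumptions: refilling the buffer from the data stream is omitted, and when nothing fits this sketch crops the *shortest* buffered document (which minimizes the discarded remainder); the steps above do not say which document gets cropped. `bestfit_crop_row` is a hypothetical name.

```python
def bestfit_crop_row(buffer, row_len):
    """Fill one row of row_len tokens from `buffer`, a list of document lengths.

    Documents used for the row are removed from `buffer`. Returns (row, cropped)
    where row lists the token counts taken from each document and cropped is the
    number of tokens discarded.
    """
    row, space, cropped = [], row_len, 0
    while space > 0 and buffer:
        fits = [i for i, n in enumerate(buffer) if n <= space]
        if fits:
            # Largest document that fits entirely in the remaining space.
            i = max(fits, key=lambda k: buffer[k])
            n = buffer.pop(i)
            row.append(n)
            space -= n
        else:
            # Nothing fits: crop the shortest document to fill the row exactly.
            i = min(range(len(buffer)), key=lambda k: buffer[k])
            n = buffer.pop(i)
            row.append(space)
            cropped += n - space
            space = 0
    return row, cropped


buf = [7, 3, 9, 2, 5]
print(bestfit_crop_row(buf, 10), buf)  # ([9, 1], 1) [7, 3, 5]
print(bestfit_crop_row(buf, 10), buf)  # ([7, 3], 0) [5]
```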
**Results (T=2048):**
- Crop waste reduced from 39.4% → 34.6% (~12% relative improvement)
- Still achieves 100% utilization (no padding, every token trains)
- Slightly more rows than baseline (uses more documents per batch)
### Decision: Keep Two Implementations
1. Keep the original implementation which is very simple, efficient and has 100% token utilization in the batch (no padding with ignore tokens), but creates slightly more confusing token streams for the LLM because documents during training can start abruptly from the middle with no context. Note that this never happens at test time, where BOS is always present.
2. **`_bos_bestfit` (BestFit-Crop, new default)**: Slightly more complex, and still keeps 100% token utilization in the batch (no padding), at the cost of discarding the parts of documents that don't fit. In practice, about 34% of tokens are discarded with this approach. This is ok because for most models we care about we have plenty of data without having to go to multiple epochs. A more subtle effect is that it skews the data distribution a tiny bit, because tokens at the tails of long documents will reliably and necessarily be discarded. However, this doesn't seem to impact actual downstream performance.
### Midtraining
The midtraining dataloader was also updated. Because conversations are on average a lot shorter than pretraining documents, only about 3.3% of tokens get cropped.
### NOTE: loss scale
Do note that switching to the BOS dataloader changes the validation loss and makes all previous experiments not comparable in absolute value of the loss, because we have a lot fewer "confusing" tokens in the train/val batches. All tokens can look back and find the BOS token and have the full context of that document to make predictions. Therefore, the loss appears lower but this is "fake" to some extent, and the expectation is that the vast majority of relative comparisons done so far would agree with those before and after this change.
---
## 2026-01-13: Number Token Split Pattern
Validated the `\p{N}{1,2}` pattern in `SPLIT_PATTERN` (tokenizer.py line 30), which I had only guessed earlier and left a TODO to validate. GPT-4 uses `\p{N}{1,3}` to group number sequences of up to 3 digits into tokens, but we suspected smaller vocab sizes benefit from grouping fewer digits per token.
**Results (d12, vocab=32K):**
| Pattern | val_bpb |
|---------|---------|
| `\p{N}{1,1}` | 0.969 |
| `\p{N}{1,2}` | **0.965** |
| `\p{N}{1,3}` | 0.972 |
**Conclusion:** `{1,2}` is optimal for vocab size 32K. Grouping 3 digits spends vocabulary on rare 3-digit combinations; grouping 1 digit is too fine-grained and bloats token sequences. Keeping `{1,2}` as default.
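To see what the three patterns do to a digit run, here is a small sketch using an ASCII approximation: Python's built-in `re` has no `\p{N}` class (the third-party `regex` package does), so `[0-9]` stands in for it here. `split_digits` is a hypothetical helper that only isolates the number piece of the pattern.

```python
import re


def split_digits(text, max_group):
    """Split digit runs into chunks of at most max_group digits, left to right."""
    return re.findall(r"[0-9]{1,%d}" % max_group, text)


for k in (1, 2, 3):
    print(k, split_digits("in 2026, 1234567 tokens", k))
# 2 -> ['20', '26', '12', '34', '56', '7']
```

Grouping is greedy from the left, so a 7-digit number becomes four pieces under `{1,2}` and three under `{1,3}`.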
---
## 2026-01-13: FP8 Training for lm_head
Attempted to use FP8 (8-bit floating point) for the lm_head layer to speed up the large vocab projection matmul. H100 GPUs have FP8 tensor cores that can theoretically provide ~2x speedup over BF16.
### Implementation Approaches Tried
**1. Dynamic Scaling (failed)**
- Compute `x.abs().max()` and `w.abs().max()` each forward to determine scales
- Problem: `.item()` calls cause graph breaks with torch.compile
- Tried `@torch._dynamo.allow_in_graph` pattern (like torchao.float8) - worked but no speedup
- Tried `torch.library.custom_op` with float scales - caused NaN gradients after first optimizer step
- Root cause: interaction between custom ops, dynamic scale computation, and torch.compile is fragile
**2. Static Scaling (partial success)**
- Pre-set scales at init time like modded-nanogpt: `x_scale=10/448, w_scale=0.1/448`
- `grad_scale` computed dynamically from batch size (safe since it's just `1/(B*T)/57344`, due to the gradient expression of cross entropy). modded-nanogpt probably has a bug here: they set `grad_scale = 0.75/448`, but grads are in E5M2, so this should probably be `1/57344`, 1 being the amax of any individual element of the cross entropy gradient, with no normalization by B,T because they use sum reduction, not mean reduction.
- Uses `torch.library.custom_op` with `@torch.compile` on inner kernels
- This works correctly - no NaNs, proper gradients
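The static-scale arithmetic can be checked in a few lines. A minimal sketch, assuming the convention that a tensor is divided by its scale before the FP8 cast, so the largest expected magnitude lands on the format's largest finite value (448 for E4M3, 57344 for E5M2). For mean-reduced cross entropy the gradient w.r.t. the logits is `(softmax - onehot) / (B*T)`, and every element of `softmax - onehot` has magnitude at most 1, which is where `1/(B*T)/57344` comes from. `static_fp8_scales` and `ce_grad_wrt_logits` are hypothetical helpers, not the branch's code.

```python
import math

E4M3_MAX = 448.0    # largest finite float8_e4m3fn value (forward: x and w)
E5M2_MAX = 57344.0  # largest finite float8_e5m2 value (backward: grads)


def static_fp8_scales(batch_size, seq_len):
    """Static scales, assuming a tensor is divided by its scale before the FP8 cast."""
    n = batch_size * seq_len
    return {
        "x_scale": 10 / E4M3_MAX,        # assumes |x| <= 10 maps onto 448
        "w_scale": 0.1 / E4M3_MAX,       # assumes |w| <= 0.1 maps onto 448
        "grad_scale": 1 / n / E5M2_MAX,  # |dL/dlogits| <= 1/n for mean-reduced CE
    }


def ce_grad_wrt_logits(logits, target, n):
    """Gradient of mean-reduced CE w.r.t. one position's logits: (softmax - onehot) / n."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [(e / total - (1.0 if i == target else 0.0)) / n for i, e in enumerate(exps)]


B, T = 32, 2048
scales = static_fp8_scales(B, T)
# A confidently wrong prediction produces the largest gradient magnitudes (~1/n).
grad = ce_grad_wrt_logits([50.0, -50.0, 0.0], target=1, n=B * T)
print(max(abs(g) for g in grad) / scales["grad_scale"])  # about 57344, the top of the E5M2 range
```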
### Results (d12)
| Metric | BF16 Baseline | FP8 lm_head |
|--------|---------------|-------------|
| GPU Memory | 34 GB | 36 GB |
| tok/sec | baseline | ~1% faster |
### The Memory Mystery
FP8 *should* save memory since we store `x_f8` (1 byte) instead of `x` (2 bytes) for backward. But we see 2GB *increase*. Suspected causes:
- `torch.compile` on inner kernels creating extra buffers/specializations
- `torch._scaled_mm` internal workspace allocations
- Custom op registration machinery overhead
Tried saving the original weight `w` (just a reference to the parameter) instead of `w_f8` for backward, then re-quantizing on the spot during backward. It didn't help; the memory bump remained.
### Microbenchmark vs Reality
Raw microbenchmark showed promise:
- BF16 matmul: 16.95 ms
- FP8 matmul (static scales): 10.31 ms (1.64x faster)
- FP8 with dynamic scaling: 12.25 ms (1.38x faster)
But in full training, the ~1% tok/sec improvement doesn't justify the 2GB memory increase and the added code complexity and the need to tune scale factors for both x and w.
### Code Artifacts
See the branch `fp8_attempt_fail` for:
- `nanochat/fp8_static.py` - Static scaling implementation (working)
- `nanochat/fp8_dynamic.py` - Dynamic scaling implementation (torchao-style, working but slow)
- `gpt.py` imports `fp8_static.LinearFP8` and simply swaps it in for `lm_head`.
### Open Questions
- Why does the custom op approach use more memory than vanilla BF16?
- Why is the bump in tok_per_sec so low? We should see a ~1.6x speedup in the forward pass and also (twice) in the backward pass for the gradients. Granted, Amdahl's law is part of the explanation: our vocab_size is only 32K, so the final layer isn't a huge part of the profile, but the expected speedup is still not fully realized.
**Conclusion:** Negative result for now. The implementation works correctly but provides marginal speedup with *increased* memory usage. I'm not understanding the torch.compile interaction here. The complexity of FP8 custom ops isn't justified for lm_head alone. TODO to study in more detail the way this is implemented in other libraries, e.g. torchao.
---
## 2026-01-12: Multi-Token Prediction (MTP)
Ported multi-token prediction from modded-nanogpt. Instead of predicting just the next token, predict the next n tokens at each position with weighted loss.
### Implementation
- Instead of calling the loss `n_predict` times, uses a fancy batched computation using `unfold` + `gather` + cross-entropy decomposition (`CE = logsumexp - logits[target]`)
- Schedule anneals from 3-token to 1-token prediction:
- 0-33%: `[1.0, 0.5, 0.25→0]` (3rd token fades)
- 33-67%: `[1.0, 0.5→0]` (2nd token fades)
- 67-100%: `[1.0]` (standard next-token)
- Weights normalized to sum to 1
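A minimal sketch of the weight schedule and of the CE decomposition, in plain Python. Assumptions: the fades are linear within each phase and the phase boundaries sit exactly at 1/3 and 2/3 of training; the function names are hypothetical, and the real implementation computes all offsets at once with `unfold` + `gather` on tensors.

```python
import math


def mtp_weights(progress):
    """Loss weights for predicting tokens t+1, t+2, ... at training progress in [0, 1].

    0-33%: [1.0, 0.5, 0.25 -> 0]; 33-67%: [1.0, 0.5 -> 0]; 67-100%: [1.0].
    Fades are assumed linear within each phase; weights are normalized to sum to 1.
    """
    if progress < 1 / 3:
        frac = progress * 3  # 0 -> 1 across the first phase
        w = [1.0, 0.5, 0.25 * (1 - frac)]
    elif progress < 2 / 3:
        frac = (progress - 1 / 3) * 3  # 0 -> 1 across the second phase
        w = [1.0, 0.5 * (1 - frac)]
    else:
        w = [1.0]
    total = sum(w)
    return [x / total for x in w]


def cross_entropy(logits, target):
    """CE = logsumexp(logits) - logits[target], with the max subtracted for stability."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]


print([round(w, 3) for w in mtp_weights(0.0)])  # [0.571, 0.286, 0.143]
print([round(w, 3) for w in mtp_weights(0.5)])  # [0.8, 0.2]
print(mtp_weights(0.9))                         # [1.0]
```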
### Results (d12)
| Metric | Baseline | MTP |
|--------|----------|-----|
| GPU Memory | 34 GB | 47 GB |
| MFU | 41% | 40% |
| val/bpb (per step) | baseline | same/slightly worse |
| val/bpb (wall clock) | baseline | noticeably worse |
**Conclusion:** Negative result for nanochat. The extra memory and compute overhead from predicting multiple tokens doesn't pay off; in fact, the results get worse. The auxiliary loss signal may help in other settings (larger models, different architectures?), but for our setup it's pure overhead at the moment.
---
## 2026-01-11: Sliding Window Attention
Added configurable sliding window attention, inspired by GPT-3's alternating short/long pattern.
**Pattern string configuration:**
- New `--window_pattern` CLI arg and `GPTConfig.window_pattern` field
- Pattern is tiled across layers (e.g., `SSSL` for 20 layers → `SSSLSSSLSSSLSSSLSSSL`)
- Final layer always forced to L (full context) regardless of pattern
- Short window = `sequence_len // 2`
- Long window = `sequence_len` (full context)
- All previous models so far have been simply `L` and checkpoint loading is modified accordingly to fill in this param for old models, see `_patch_missing_config_keys`
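The tiling rules above fit in a few lines. A minimal sketch; `layer_windows` is a hypothetical helper, not the function in `gpt.py`.

```python
def layer_windows(pattern, n_layer, sequence_len):
    """Tile a window pattern such as "SSSL" across n_layer layers.

    'S' -> short window (sequence_len // 2), 'L' -> full context (sequence_len).
    The final layer is always forced to 'L'.
    """
    chars = [pattern[i % len(pattern)] for i in range(n_layer)]
    chars[-1] = "L"
    sizes = {"S": sequence_len // 2, "L": sequence_len}
    return "".join(chars), [sizes[c] for c in chars]


print(layer_windows("SSSL", 20, 2048)[0])  # SSSLSSSLSSSLSSSLSSSL
print(layer_windows("SSSS", 4, 2048))      # ('SSSL', [1024, 1024, 1024, 2048])
```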
Quick experiments showed `SSSL` (every 4th layer is long) works well - provides a good balance between compute savings and model quality. This is now the default.
---
## 2026-01-11: Flash Attention 3 Integration
Replaced PyTorch's `scaled_dot_product_attention` (FA2) with Flash Attention 3 for training and inference.
### Changes Made
**1. FA3 via `kernels` package**
- Official FA3 is "beta" and requires building from source (painful)
- Using `kernels` package from HuggingFace Hub: `get_kernel('varunneal/flash-attention-3')`
- Loads pre-built wheels, works out of the box on H100
**2. Simplified attention code**
- FA3 uses `(B, T, H, D)` layout matching our projection output directly - no transpose needed
- Training: `flash_attn.flash_attn_func(q, k, v, causal=True)`
- Inference: `flash_attn.flash_attn_with_kvcache()` handles all cache cases in one call
- Removed 3 separate FA2 code paths (training, single-token, chunk inference)
- GQA handled automatically when n_kv_heads < n_heads
**3. Rewrote KVCache for FA3**
- Old format: `(num_layers, 2, B, H, T, D)` combined tensor
- New format: separate `k_cache` and `v_cache` of shape `(num_layers, B, T, H, D)`
- FA3 updates cache in-place during `flash_attn_with_kvcache`
- Position tracked via `cache_seqlens` tensor (int32, per batch element)
- Simpler API: `get_layer_cache()`, `advance()`, `reset()`, `prefill()`
### Results
- **~9% improvement in tok/sec** during training out of the box
- Benchmarks showed FA3 is 2x faster than FA2 at realistic training sizes (batch=32, seq=2048)
- FA3 supports sliding window via `window_size=(left, 0)`, which is huge and expected to give further improvements. This is ready to tune but keeping full context for now.
---
## 2026-01-11: Per-Layer Residual Scalars (x0 & resid lambdas)
Cherry-picked an idea from modded-nanogpt around learnable per-layer residual connections.
### Changes Made
**1. x0_lambdas (x0 residual connections)**
- Save initial normalized embedding as `x0` after `norm(wte(idx))`
- At each layer, blend x0 back in: `x = resid_lambdas[i] * x + x0_lambdas[i] * x0`
- Zero-initialized, so disabled at start; model learns which layers benefit from the shortcut
- Provides direct path from embedding to deep layers, helps preserve token information
**2. resid_lambdas (residual stream scaling)**
- Per-layer multiplicative scaling of the residual stream
- Initialized to 1.0 (neutral, standard transformer behavior)
- Allows model to learn to amplify/dampen residual at each layer
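A minimal sketch of the per-layer blend on plain lists, assuming the blend is applied at the start of each layer, before the block runs; `run_layers` and `add_one` are hypothetical stand-ins. With the initial values (`resid_lambdas = 1`, `x0_lambdas = 0`) the network behaves exactly like a plain residual stack.

```python
def run_layers(x0, blocks, resid_lambdas, x0_lambdas):
    """Apply blocks with per-layer residual scalars.

    x0 is the normalized token embedding (a plain list here). Before block i,
    x = resid_lambdas[i] * x + x0_lambdas[i] * x0. Each block maps the residual
    stream to the updated residual stream.
    """
    x = list(x0)
    for block, r, a in zip(blocks, resid_lambdas, x0_lambdas):
        x = [r * xi + a * x0i for xi, x0i in zip(x, x0)]
        x = block(x)
    return x


def add_one(x):
    """Stand-in for a transformer block."""
    return [xi + 1.0 for xi in x]


x0 = [1.0, 2.0]
print(run_layers(x0, [add_one, add_one], [1.0, 1.0], [0.0, 0.0]))  # [3.0, 4.0], same as a plain stack
print(run_layers(x0, [add_one, add_one], [1.0, 1.0], [0.5, 0.0]))  # [3.5, 5.0]
```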
**3. DistAdamW small parameter handling**
- Added support for parameters with < 1024 elements (like the scalar lambdas)
- Small params use `all_reduce` instead of `reduce_scatter`/`all_gather`
- Fixes crash when param shape isn't divisible by world_size
### Key Finding: Different LR Sensitivity
The two scalar types need very different learning rates:
- **x0_lambdas (additive)**: Can use normal LR (~0.5). Adding a fraction of x0 is forgiving.
- **resid_lambdas (multiplicative)**: Needs ~100x smaller LR (~0.005). Multiplying the residual compounds through layers.
Implementation: `resid_params` gets `scalar_lr * 0.01`, `x0_params` gets full `scalar_lr`.
### Experiment Results
Swept `--scalar_lr` (controlling x0_lambdas) at multiple depths:
| Depth | Baseline (disabled) | Best scalar_lr | Best val_bpb | Δ bpb |
|-------|---------------------|----------------|--------------|-------|
| d8 | 1.0885 | 0.20 | 1.0782 | -0.0103 |
| d12 | 0.9770 | 0.60 | 0.9693 | -0.0077 |
| d16 | 0.9059 | 0.20 | 0.9002 | -0.0057 |
| d20 | 0.8565 | 0.10 | 0.8526 | -0.0039 |
**Observations:**
- Consistent improvement across all model sizes
- Optimal LR varies by depth; default of 0.5 is reasonable, but 0.6 is better for d12
- Adding resid_lambdas (with 0.01x LR) gives small additional improvement over x0 alone
### Meta Device Footgun
Important lesson: `__init__` runs in meta device context, so any tensor values set there are fake. Must initialize actual values in `init_weights()`. Added docstring warning to `__init__`.
### Summary
Added `--scalar_lr` (default 0.5) controlling learnable per-layer scalars. The formula `x = resid_lambdas[i] * x + x0_lambdas[i] * x0` gives the model control over residual scaling and direct shortcuts to the initial embedding. Solid improvement with essentially no compute overhead.
---
## 2026-01-10: Muon Optimizer Upgrades & Cautious Weight Decay
Cherry-picked improvements from NorMuon (modded-nanogpt) into our simpler Muon implementation. Decided against using NorMuon directly due to hard-coded architecture assumptions (expects 32 params split 10 attn + 22 mlp), parameter labeling requirements, and complexity.
### Changes Made
**1. Polar Express Orthogonalization**
- Replaced Newton-Schulz iteration with "Polar Express Sign Method" from [arxiv.org/pdf/2505.16932](https://arxiv.org/pdf/2505.16932)
- Uses 5 different coefficient tuples (one per iteration) instead of fixed coefficients
- Both methods kept in code for easy comparison (`zeropower_via_polar_express` vs `zeropower_via_newtonschulz5`)
- **Result:** No dramatic/noticeable difference in training, but keeping the new Polar Express as default.
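For reference, the Newton-Schulz baseline (`zeropower_via_newtonschulz5`) approximately orthogonalizes a matrix by repeatedly applying the odd quintic `s -> a*s + b*s^3 + c*s^5` to its singular values. A pure-Python sketch on nested lists, using the `(a, b, c) = (3.4445, -4.7750, 2.0315)` coefficients from Keller Jordan's Muon reference implementation. Polar Express replaces this fixed tuple with a different tuple at each of its 5 iterations; those coefficients are not reproduced here. Real implementations run on bfloat16 tensors.

```python
def matmul(A, B):
    """Product of two matrices given as lists of rows."""
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def newton_schulz5(G, steps=5, eps=1e-7):
    """Approximately orthogonalize G: push its singular values toward ~1."""
    a, b, c = 3.4445, -4.7750, 2.0315
    norm = sum(x * x for row in G for x in row) ** 0.5
    X = [[x / (norm + eps) for x in row] for row in G]  # singular values now <= 1
    transposed = len(X) > len(X[0])
    if transposed:  # work on the wide orientation so X @ X^T is the smaller Gram matrix
        X = transpose(X)
    for _ in range(steps):
        A = matmul(X, transpose(X))
        AA = matmul(A, A)
        n = len(A)
        B = [[b * A[i][j] + c * AA[i][j] for j in range(n)] for i in range(n)]
        BX = matmul(B, X)
        # X <- a X + (b A + c A^2) X, i.e. s -> a s + b s^3 + c s^5 per singular value
        X = [[a * X[i][j] + BX[i][j] for j in range(len(X[0]))] for i in range(len(X))]
    return transpose(X) if transposed else X


X = newton_schulz5([[3.0, 0.0], [0.0, 1.0]])
print([round(X[i][i], 2) for i in range(2)])  # both entries end up roughly in [0.7, 1.2], not exactly 1
```

The iteration deliberately trades exact orthogonality for speed: singular values oscillate in a band around 1 rather than converging to it.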
**2. NorMuon Variance Reduction**
- Added per-neuron/column adaptive learning rate from NorMuon ([arxiv.org/pdf/2510.05491](https://arxiv.org/pdf/2510.05491))
- Maintains `second_momentum_buffer` with shape `[rows, 1]` or `[1, cols]`, whichever is smaller, i.e. one running statistic per row or per column along the smaller dimension
- Normalizes updates based on running per-row/col variance estimate (beta2=0.95)
- Memory overhead: ~1/max(rows, cols) per param, negligible
- **Result:** Led to a very small improvement, kept and enabled by default.
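A minimal sketch of the per-row variance normalization for the case where rows are the smaller dimension (the `[rows, 1]` buffer), in plain Python. It only illustrates the EMA buffer (beta2=0.95) and the division by its square root; `normuon_row_scale` is a hypothetical name, and a full NorMuon implementation may also apply bias correction or rescale the result to preserve the overall update norm, both omitted here.

```python
def normuon_row_scale(update, second_moment, beta2=0.95, eps=1e-8):
    """Per-row adaptive scaling of an orthogonalized update (rows <= cols case).

    update:        list of rows (the Muon update for one weight matrix)
    second_moment: running mean-square per row, i.e. the [rows, 1] buffer
    Returns the rescaled update and the new buffer.
    """
    new_moment, scaled = [], []
    for row, v in zip(update, second_moment):
        mean_sq = sum(u * u for u in row) / len(row)
        v = beta2 * v + (1 - beta2) * mean_sq  # EMA of this row's mean square
        new_moment.append(v)
        scaled.append([u / (v ** 0.5 + eps) for u in row])
    return scaled, new_moment


scaled, moment = normuon_row_scale([[2.0, 2.0], [1.0, 1.0]], [4.0, 1.0])
print([[round(u, 3) for u in row] for row in scaled])  # [[1.0, 1.0], [1.0, 1.0]]: rows equalized
```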
**3. Cautious Weight Decay**
- Only decays weights where `update * weight >= 0` (same sign) from [arxiv.org/abs/2411.16085](https://arxiv.org/abs/2411.16085)
- Standard WD always pulls toward zero; cautious WD skips decay when gradient is pushing weight away from zero
- **Implementation note:** Had to inline the logic rather than use a separate `@torch.compile` function. Passing changing float values (like `weight_decay` during scheduling) as function arguments triggers recompilation. Reading from `group["weight_decay"]` inside the step avoids this.
- **Result:** Solid improvements; the cautious version was better than standard WD.
- Now defaults to ON for Muon via the `weight_decay` param. AdamW is still hardcoded to 0 weight decay; might try to re-tune this later.
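A minimal per-element sketch of the cautious mask, assuming the decoupled form `p <- p - lr * (u + wd * mask * p)` where `u` is the update being subtracted. Under that convention `u * p >= 0` means the update already moves `p` toward zero, so decay is applied only there. `cautious_wd_step` is a hypothetical helper on plain floats.

```python
def cautious_wd_step(params, updates, lr, wd):
    """One step p <- p - lr * (u + wd * mask * p), with mask = 1 only where u * p >= 0."""
    out = []
    for p, u in zip(params, updates):
        mask = 1.0 if u * p >= 0 else 0.0  # decay only where the update points toward zero
        out.append(p - lr * (u + wd * mask * p))
    return out


print(cautious_wd_step([1.0, 1.0], [0.5, -0.5], lr=0.1, wd=0.2))
# first param: update and decay both shrink it (about 0.93); second: no decay (about 1.05)
```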
**4. Weight decay schedule**
- Added a linear schedule to weight decay, on by default, ramping from 1.0 to 0.0 (i.e. start with max weight decay at the beginning of training, then ramp to 0 by the end). This worked better than a static setting in experiments. (modded-nanogpt has the same schedule but implements it in a more confusing way, by multiplying twice by the learning rate, which is already wired up to a decay schedule.)
### Weight Decay Scaling Experiments
Swept weight decay values at d8, d12, d16, d20 to find optimal values and scaling law.
**Optimal Values Found:**
| Depth | Width (channels) | Optimal WD |
|-------|------------------|------------|
| d8 | 512 | ~0.40 |
| d12 | 768 | ~0.22 |
| d16 | 1024 | ~0.10 |
| d20 | 1280 | ~0.08 |
**Scaling Law:**
- Fit power law: `WD = k / channels^α` in log-log space
- Found α ≈ 1.97 (approximately 2), meaning WD ∝ 1/width²
**Practical Formula:**
```
WD_target = WD_reference × (d_reference / d_target)²
```
Example: If d12 optimal is 0.22, then d20 optimal ≈ 0.22 × (12/20)² ≈ 0.08
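The formula and the linear schedule combine into a couple of one-liners. Because width is 64 × depth for every row of the table above (512/8, 768/12, 1024/16, 1280/20), a width ratio equals a depth ratio, which is why the formula can be written in terms of depth. A sketch with hypothetical helper names, assuming the schedule is a plain linear ramp in step count.

```python
def scaled_weight_decay(wd_ref, depth_ref, depth):
    """WD proportional to 1/width^2, expressed via depth since width = 64 * depth."""
    return wd_ref * (depth_ref / depth) ** 2


def weight_decay_at(step, total_steps, wd):
    """Linear schedule: full weight decay at step 0, ramping to 0 at the last step."""
    return wd * (1.0 - step / total_steps)


wd_d20 = scaled_weight_decay(0.22, depth_ref=12, depth=20)
print(round(wd_d20, 4))                              # 0.0792
print(round(weight_decay_at(500, 1000, wd_d20), 4))  # 0.0396
```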
**Reference:** Moonlight paper uses fixed WD=0.1 for their 15B MoE model. Our experiments indicated a scaling law where the optimal WD changed with depth, so we go along with the empirical scaling law.
### Summary
Muon was changed to use Polar Express, with NorMuon variance reduction added, plus cautious weight decay with a schedule that ramps linearly to zero. All of these changes follow the modded-nanogpt repo, and all of them were validated piece by piece to yield improvements in nanochat, with the exception of the Polar Express change, which was in the noise. This is on by default and configurable with `--weight_decay`, using a base value of 0.2 with ∝ 1/width² scaling. The meaning of `--weight_decay` therefore changes as of this commit: it used to configure standard weight decay in AdamW, and now it is used exclusively by Muon (AdamW is hardcoded to 0.0) and scaled based on depth.
---
## 2026-01-08: exp_grad_clip - Gradient Clipping
**Hypothesis:** Gradient clipping may be unnecessary overhead. Tested L2 norm clipping at various thresholds (0.25, 0.5, 1.0, 2.0) and elementwise clipping.
**Results:**
- No benefit at any scale tested (d12, d20)
- All variants within noise (~0.9827 val_bpb)
- Grad norm never exceeds 1.0 naturally, so clipping at a threshold of 1.0 or above never activates
- Clipping adds ~2% time overhead from the all-reduce
**Bug Found:** Original implementation clipped local gradients before sync. Since this codebase doesn't use DDP (gradient sync is in the optimizers), each rank was clipping based on its own local norm. Fixed on the branch with proper distributed all-reduce.
**Observation:** modded-nanogpt does not appear to clip either right now.
**Summary:** Deleted all grad-clip code paths. The code naturally produces well-behaved gradients. This improves a bit of MFU because we don't have to calculate and sync grad norms.
================================================
FILE: dev/estimate_gpt3_core.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Estimating CORE Metric for GPT-3 Models\n",
"\n",
"**Authors**: Claude Code Opus 4.5, Andrej Karpathy\n",
"\n",
"**Date**: Jan 2026\n",
"\n",
"## Motivation\n",
"\n",
"The [CORE metric](https://arxiv.org/abs/2406.11794) (introduced in the DCLM paper) is a composite benchmark that evaluates pretrained language models across 22 diverse tasks spanning world knowledge, language understanding, commonsense reasoning, symbolic problem solving, and reading comprehension. It provides a single score that captures a model's general capabilities.\n",
"\n",
"We want to compare nanochat models against the GPT-3 model family from OpenAI's [\"Language Models are Few-Shot Learners\"](https://arxiv.org/abs/2005.14165) paper (2020). However, there's a problem: **GPT-3 models were never evaluated on CORE** (which didn't exist in 2020), and the models were never publicly released, so we can't evaluate them ourselves.\n",
"\n",
"## Our Approach\n",
"\n",
"We estimate CORE scores for GPT-3 by:\n",
"\n",
"1. **Identifying overlapping tasks** between the GPT-3 paper and CORE that were evaluated with similar methodology\n",
"2. **Using GPT-2 as calibration data** — we have actual CORE scores for all 4 GPT-2 models, plus the GPT-3 paper reports results on GPT-2-equivalent tasks\n",
"3. **Fitting a regression model** from the overlapping task scores to the full CORE score\n",
"4. **Applying the model to GPT-3** using their reported task scores\n",
"\n",
"This notebook documents our methodology in detail for reproducibility."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from pathlib import Path\n",
"import pandas as pd\n",
"\n",
"# For nice table display\n",
"pd.set_option('display.precision', 4)\n",
"pd.set_option('display.max_columns', 20)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 1: Understanding CORE\n",
"\n",
"CORE consists of **22 tasks** evaluated in specific few-shot settings. The key innovation is **centering**: raw accuracies are adjusted to account for random guessing baselines.\n",
"\n",
"$$\\text{centered accuracy} = \\frac{\\text{accuracy} - \\text{baseline}}{1 - \\text{baseline}}$$\n",
"\n",
"The final CORE score is simply the **mean of all 22 centered accuracies**.\n",
"\n",
"### CORE Tasks\n",
"\n",
"| Category | Tasks |\n",
"|----------|-------|\n",
"| World Knowledge | Jeopardy, ARC Easy, ARC Challenge, BigBench QA Wikidata |\n",
"| Language Understanding | HellaSwag (0-shot & 10-shot), LAMBADA, Winograd, Winogrande, BigBench Language ID |\n",
"| Commonsense Reasoning | COPA, CommonsenseQA, PIQA, OpenBookQA |\n",
"| Symbolic Problem Solving | BigBench Dyck, Operators, CS Algorithms, Repeat Copy Logic, AGI Eval LSAT-AR |\n",
"| Reading Comprehension | SQuAD, CoQA, BoolQ |"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 2: Task Overlap Analysis\n",
"\n",
"We carefully compared the evaluation methodology between GPT-3 and CORE for each task. Key considerations:\n",
"\n",
"1. **Number of few-shot examples (K)**: GPT-3 often uses more examples than CORE\n",
"2. **Task format**: Some tasks use different prompting strategies\n",
"3. **Scoring method**: GPT-3 uses unconditional probability normalization for some tasks\n",
"4. **Data split**: dev vs test set\n",
"\n",
"### Selection Criteria\n",
"\n",
"We applied a conservative filter: **both evaluations must use K=0 (zero-shot) or both must use K>0 (few-shot)**. We excluded tasks that mix zero-shot with few-shot, as this introduces systematic differences.\n",
"\n",
"### Tasks We Excluded\n",
"\n",
"| Task | GPT-3 K | CORE K | Reason for Exclusion |\n",
"|------|---------|--------|----------------------|\n",
"| Winograd | 7 | 0 | Mixing K>0 with K=0 |\n",
"| Winogrande | 50 | 0 | Mixing K>0 with K=0 |\n",
"| COPA | 32 | 0 | Mixing K>0 with K=0 |\n",
"| OpenBookQA | 100 | 0 | Mixing K>0 with K=0, also uses unconditional normalization |\n",
"| BoolQ | 32 | 10 | High sensitivity to K (17% gap between 0-shot and few-shot in GPT-3) |\n",
"| CoQA | 5 | 0 | Different metric (F1 vs accuracy) |\n",
"| LAMBADA few-shot | 15 | 0 | GPT-3 uses special fill-in-blank format |\n",
"\n",
"### Tasks Not in GPT-3 Paper\n",
"\n",
"These CORE tasks simply don't appear in GPT-3 (many didn't exist in 2020):\n",
"- All 6 BigBench tasks (Dyck, Operators, CS Algorithms, Repeat Copy Logic, Language ID, QA Wikidata)\n",
"- Jeopardy, CommonsenseQA, AGI Eval LSAT-AR\n",
"- SQuAD v1 (GPT-3 uses v2)\n",
"\n",
"### Final Selected Tasks (6 tasks)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Task</th>\n",
" <th>GPT-3 K</th>\n",
" <th>CORE K</th>\n",
" <th>Match</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>HellaSwag 0-shot</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Both zero-shot</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>LAMBADA</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>Both zero-shot</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>HellaSwag 10-shot</td>\n",
" <td>20</td>\n",
" <td>10</td>\n",
" <td>Both few-shot (K differs slightly)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>PIQA</td>\n",
" <td>50</td>\n",
" <td>10</td>\n",
" <td>Both few-shot</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ARC Easy</td>\n",
" <td>50</td>\n",
" <td>10</td>\n",
" <td>Both few-shot</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>ARC Challenge</td>\n",
" <td>50</td>\n",
" <td>10</td>\n",
" <td>Both few-shot</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Task GPT-3 K CORE K Match\n",
"0 HellaSwag 0-shot 0 0 Both zero-shot\n",
"1 LAMBADA 0 0 Both zero-shot\n",
"2 HellaSwag 10-shot 20 10 Both few-shot (K differs slightly)\n",
"3 PIQA 50 10 Both few-shot\n",
"4 ARC Easy 50 10 Both few-shot\n",
"5 ARC Challenge 50 10 Both few-shot"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The 6 tasks we selected for overlap\n",
"selected_tasks = pd.DataFrame([\n",
" {'Task': 'HellaSwag 0-shot', 'GPT-3 K': 0, 'CORE K': 0, 'Match': 'Both zero-shot'},\n",
" {'Task': 'LAMBADA', 'GPT-3 K': 0, 'CORE K': 0, 'Match': 'Both zero-shot'},\n",
" {'Task': 'HellaSwag 10-shot', 'GPT-3 K': 20, 'CORE K': 10, 'Match': 'Both few-shot (K differs slightly)'},\n",
" {'Task': 'PIQA', 'GPT-3 K': 50, 'CORE K': 10, 'Match': 'Both few-shot'},\n",
" {'Task': 'ARC Easy', 'GPT-3 K': 50, 'CORE K': 10, 'Match': 'Both few-shot'},\n",
" {'Task': 'ARC Challenge', 'GPT-3 K': 50, 'CORE K': 10, 'Match': 'Both few-shot'},\n",
"])\n",
"selected_tasks"
]
},
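{
"cell_type": "markdown",
"metadata": {},
"source": [
"The K-regime rule from the selection criteria can be checked mechanically. Below is a minimal sketch. It assumes the (task, GPT-3 K, CORE K) triples copied by hand from the two tables above. `candidates` and `same_regime` are hypothetical names used for illustration; they are not part of nanochat.\n",
"\n",
"```python\n",
"# (task, GPT-3 K, CORE K) for every candidate task in the two tables above\n",
"candidates = [\n",
"    ('HellaSwag 0-shot', 0, 0),\n",
"    ('LAMBADA', 0, 0),\n",
"    ('HellaSwag 10-shot', 20, 10),\n",
"    ('PIQA', 50, 10),\n",
"    ('ARC Easy', 50, 10),\n",
"    ('ARC Challenge', 50, 10),\n",
"    ('Winograd', 7, 0),\n",
"    ('Winogrande', 50, 0),\n",
"    ('COPA', 32, 0),\n",
"    ('OpenBookQA', 100, 0),\n",
"    ('BoolQ', 32, 10),\n",
"    ('CoQA', 5, 0),\n",
"    ('LAMBADA few-shot', 15, 0),\n",
"]\n",
"\n",
"def same_regime(gpt3_k, core_k):\n",
"    # True when both evaluations are zero-shot, or both are few-shot\n",
"    return (gpt3_k == 0) == (core_k == 0)\n",
"\n",
"passing = [task for task, gpt3_k, core_k in candidates if same_regime(gpt3_k, core_k)]\n",
"print(passing)\n",
"```\n",
"\n",
"On its own, the rule keeps the 6 selected tasks plus BoolQ. BoolQ is then dropped by hand because of its large sensitivity to K. The rule is therefore necessary but not sufficient for the final selection."
]
},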
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Rationale for K differences:** In GPT-3's own results, moving from 0-shot to few-shot usually changes accuracy only slightly. Here is the evidence for the GPT-3 175B model:\n",
"\n",
"| Task | 0-shot | Few-shot | K | Δ (pp) |\n",
"|------|--------|----------|---|--------|\n",
"| HellaSwag | 78.9% | 79.3% | 20 | +0.4 |\n",
"| PIQA | 81.0% | 82.3% | 50 | +1.3 |\n",
"| ARC Easy | 68.8% | 70.1% | 50 | +1.3 |\n",
"| ARC Challenge | 51.4% | 51.5% | 50 | +0.1 |\n",
"| Winograd | 88.3% | 88.6% | 7 | +0.3 |\n",
"| COPA | 91.0% | 92.0% | 32 | +1.0 |\n",
"\n",
"For all six tasks in the table, the gap between 0-shot and few-shot (K=7 to 50) is only 0.1 to 1.3 percentage points. If adding 7 to 50 demonstrations moves accuracy this little, the difference between K=10 and K=20 or K=50 is plausibly smaller still. This supports pairing GPT-3's few-shot numbers with CORE's 10-shot ones.\n",
"\n",
"**Note:** Some tasks are much more sensitive to K (Winogrande: +7.5 pp, BoolQ: +17 pp). BoolQ was excluded for this reason. Winogrande is excluded anyway because it mixes K>0 with K=0."
]
},
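{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Δ column can be recomputed directly. Below is a small sketch using the 175B numbers copied by hand from the table above. The dictionary names are hypothetical.\n",
"\n",
"```python\n",
"# (0-shot %, few-shot %, K) for GPT-3 175B, copied from the table above\n",
"gpt3_175b = {\n",
"    'HellaSwag': (78.9, 79.3, 20),\n",
"    'PIQA': (81.0, 82.3, 50),\n",
"    'ARC Easy': (68.8, 70.1, 50),\n",
"    'ARC Challenge': (51.4, 51.5, 50),\n",
"    'Winograd': (88.3, 88.6, 7),\n",
"    'COPA': (91.0, 92.0, 32),\n",
"}\n",
"\n",
"# Absolute gain in percentage points, rounded to hide float noise\n",
"deltas = {task: round(few - zero, 1) for task, (zero, few, k) in gpt3_175b.items()}\n",
"\n",
"for task, (zero, few, k) in gpt3_175b.items():\n",
"    print(f'{task:14s} K={k:2d}  +{deltas[task]:.1f} pp')\n",
"print(f'range: {min(deltas.values())} to {max(deltas.values())} pp')\n",
"```"
]
},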
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 3: Calibration Data (GPT-2 Family)\n",
"\n",
"We have actual CORE scores for all 4 GPT-2 models. These serve as our calibration data."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Random baselines for centering (from CORE specification)\n",
"BASELINES = {\n",
" 'hellaswag_zeroshot': 0.25,\n",
" 'lambada_openai': 0.0,\n",
" 'hellaswag': 0.25,\n",
" 'piqa': 0.50,\n",
" 'arc_easy': 0.25,\n",
" 'arc_challenge': 0.25,\n",
"}\n",
"\n",
"TASK_ORDER = ['hellaswag_zeroshot', 'lambada_openai', 'hellaswag', 'piqa', 'arc_easy', 'arc_challenge']\n",
"TASK_NAMES = ['HellaSwag 0-shot', 'LAMBADA', 'HellaSwag 10-shot', 'PIQA', 'ARC Easy', 'ARC Challenge']\n",
"\n",
"def center_accuracy(acc, baseline):\n",
"    \"\"\"Convert raw accuracy to centered accuracy (0 at the random baseline, 1 at a perfect score).\"\"\"\n",
"    return (acc - baseline) / (1.0 - baseline)\n",
"\n",
"def parse_csv(filepath):\n",
"    \"\"\"Parse a CORE results CSV file.\"\"\"\n",
"    results = {}\n",
"    with open(filepath) as f:\n",
"        for line in f:\n",
"            parts = [p.strip() for p in line.strip().split(',')]\n",
"            if len(parts) >= 3 and parts[0] != 'Task':\n",
"                task = parts[0]\n",
"                try:\n",
"                    acc = float(parts[1]) if parts[1] else None\n",
"                    centered = float(parts[2]) if parts[2] else None\n",
"                    results[task] = {'accuracy': acc, 'centered': centered}\n",
"                except ValueError:\n",
"                    pass\n",
"    return results"
]
},
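{
"cell_type": "markdown",
"metadata": {},
"source": [
"Centering maps the random-guess baseline to 0 and a perfect score to 1. Below-chance accuracy becomes negative, as with GPT-2 on ARC Challenge (-0.0375). The worked example below uses hypothetical accuracies on a 4-way multiple-choice task (baseline 0.25). The function is repeated so the snippet runs on its own.\n",
"\n",
"```python\n",
"def center_accuracy(acc, baseline):\n",
"    # Rescale so the random baseline maps to 0.0 and a perfect score maps to 1.0\n",
"    return (acc - baseline) / (1.0 - baseline)\n",
"\n",
"# Hypothetical raw accuracies on a 4-way multiple-choice task (random baseline = 0.25)\n",
"examples = {acc: round(center_accuracy(acc, 0.25), 4) for acc in [0.25, 0.40, 1.0, 0.22]}\n",
"print(examples)\n",
"\n",
"# LAMBADA has baseline 0.0, so centering leaves its accuracy unchanged\n",
"print(center_accuracy(0.323, 0.0))\n",
"```"
]
},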
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-2 Family: Raw Accuracies and CORE Scores\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>Params</th>\n",
" <th>HellaSwag 0-shot</th>\n",
" <th>LAMBADA</th>\n",
" <th>HellaSwag 10-shot</th>\n",
" <th>PIQA</th>\n",
" <th>ARC Easy</th>\n",
" <th>ARC Challenge</th>\n",
" <th>CORE</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>GPT-2</td>\n",
" <td>124M</td>\n",
" <td>30.9%</td>\n",
" <td>32.3%</td>\n",
" <td>30.8%</td>\n",
" <td>62.3%</td>\n",
" <td>41.2%</td>\n",
" <td>22.2%</td>\n",
" <td>0.1139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-2 Medium</td>\n",
" <td>355M</td>\n",
" <td>39.0%</td>\n",
" <td>42.6%</td>\n",
" <td>39.5%</td>\n",
" <td>67.0%</td>\n",
" <td>48.0%</td>\n",
" <td>26.2%</td>\n",
" <td>0.1849</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-2 Large</td>\n",
" <td>774M</td>\n",
" <td>44.0%</td>\n",
" <td>48.8%</td>\n",
" <td>44.4%</td>\n",
" <td>69.8%</td>\n",
" <td>53.5%</td>\n",
" <td>26.4%</td>\n",
" <td>0.2146</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-2 XL</td>\n",
" <td>1558M</td>\n",
" <td>50.2%</td>\n",
" <td>52.3%</td>\n",
" <td>51.2%</td>\n",
" <td>72.5%</td>\n",
" <td>59.5%</td>\n",
" <td>29.9%</td>\n",
" <td>0.2565</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Params HellaSwag 0-shot LAMBADA HellaSwag 10-shot PIQA \\\n",
"0 GPT-2 124M 30.9% 32.3% 30.8% 62.3% \n",
"1 GPT-2 Medium 355M 39.0% 42.6% 39.5% 67.0% \n",
"2 GPT-2 Large 774M 44.0% 48.8% 44.4% 69.8% \n",
"3 GPT-2 XL 1558M 50.2% 52.3% 51.2% 72.5% \n",
"\n",
" ARC Easy ARC Challenge CORE \n",
"0 41.2% 22.2% 0.1139 \n",
"1 48.0% 26.2% 0.1849 \n",
"2 53.5% 26.4% 0.2146 \n",
"3 59.5% 29.9% 0.2565 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load GPT-2 CORE results\n",
"knowledge_dir = Path(\"/home/ubuntu/.cache/nanochat/eval_bundle\")\n",
"\n",
"gpt2_models = [\n",
" ('GPT-2', 'openai-community-gpt2.csv', 124e6),\n",
" ('GPT-2 Medium', 'openai-community-gpt2-medium.csv', 355e6),\n",
" ('GPT-2 Large', 'openai-community-gpt2-large.csv', 774e6),\n",
" ('GPT-2 XL', 'openai-community-gpt2-xl.csv', 1558e6),\n",
"]\n",
"\n",
"gpt2_data = []\n",
"for name, filename, params in gpt2_models:\n",
"    results = parse_csv(knowledge_dir / filename)\n",
"    core = results['CORE']['centered']\n",
"    task_accs = [results[task]['accuracy'] for task in TASK_ORDER]\n",
"    gpt2_data.append({\n",
"        'name': name,\n",
"        'params': params,\n",
"        'task_accs': task_accs,\n",
"        'core': core,\n",
"    })\n",
"\n",
"# Display as DataFrame\n",
"gpt2_df = pd.DataFrame([\n",
" {\n",
" 'Model': d['name'],\n",
" 'Params': f\"{d['params']/1e6:.0f}M\",\n",
" **{name: f\"{acc:.1%}\" for name, acc in zip(TASK_NAMES, d['task_accs'])},\n",
" 'CORE': f\"{d['core']:.4f}\"\n",
" }\n",
" for d in gpt2_data\n",
"])\n",
"print(\"GPT-2 Family: Raw Accuracies and CORE Scores\")\n",
"gpt2_df"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-2 Family: Centered Accuracies\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HellaSwag 0-shot</th>\n",
" <th>LAMBADA</th>\n",
" <th>HellaSwag 10-shot</th>\n",
" <th>PIQA</th>\n",
" <th>ARC Easy</th>\n",
" <th>ARC Challenge</th>\n",
" <th>Mean</th>\n",
" <th>CORE</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>GPT-2</th>\n",
" <td>0.0780</td>\n",
" <td>0.3229</td>\n",
" <td>0.0772</td>\n",
" <td>0.2459</td>\n",
" <td>0.2166</td>\n",
" <td>-0.0375</td>\n",
" <td>0.1505</td>\n",
" <td>0.1139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>GPT-2 Medium</th>\n",
" <td>0.1867</td>\n",
" <td>0.4260</td>\n",
" <td>0.1933</td>\n",
" <td>0.3400</td>\n",
" <td>0.3067</td>\n",
" <td>0.0160</td>\n",
" <td>0.2448</td>\n",
" <td>0.1849</td>\n",
" </tr>\n",
" <tr>\n",
" <th>GPT-2 Large</th>\n",
" <td>0.2533</td>\n",
" <td>0.4880</td>\n",
" <td>0.2587</td>\n",
" <td>0.3960</td>\n",
" <td>0.3800</td>\n",
" <td>0.0187</td>\n",
" <td>0.2991</td>\n",
" <td>0.2146</td>\n",
" </tr>\n",
" <tr>\n",
" <th>GPT-2 XL</th>\n",
" <td>0.3360</td>\n",
" <td>0.5230</td>\n",
" <td>0.3493</td>\n",
" <td>0.4500</td>\n",
" <td>0.4600</td>\n",
" <td>0.0653</td>\n",
" <td>0.3639</td>\n",
" <td>0.2565</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" HellaSwag 0-shot LAMBADA HellaSwag 10-shot PIQA ARC Easy \\\n",
"GPT-2 0.0780 0.3229 0.0772 0.2459 0.2166 \n",
"GPT-2 Medium 0.1867 0.4260 0.1933 0.3400 0.3067 \n",
"GPT-2 Large 0.2533 0.4880 0.2587 0.3960 0.3800 \n",
"GPT-2 XL 0.3360 0.5230 0.3493 0.4500 0.4600 \n",
"\n",
" ARC Challenge Mean CORE \n",
"GPT-2 -0.0375 0.1505 0.1139 \n",
"GPT-2 Medium 0.0160 0.2448 0.1849 \n",
"GPT-2 Large 0.0187 0.2991 0.2146 \n",
"GPT-2 XL 0.0653 0.3639 0.2565 "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Build feature matrix (centered accuracies)\n",
"X_gpt2 = []\n",
"y_gpt2 = []\n",
"\n",
"for data in gpt2_data:\n",
"    centered_accs = []\n",
"    for task, acc in zip(TASK_ORDER, data['task_accs']):\n",
"        centered = center_accuracy(acc, BASELINES[task])\n",
"        centered_accs.append(centered)\n",
"    X_gpt2.append(centered_accs)\n",
"    y_gpt2.append(data['core'])\n",
"\n",
"X_gpt2 = np.array(X_gpt2)\n",
"y_gpt2 = np.array(y_gpt2)\n",
"\n",
"# Display centered accuracies\n",
"centered_df = pd.DataFrame(\n",
" X_gpt2,\n",
" columns=TASK_NAMES,\n",
" index=[d['name'] for d in gpt2_data]\n",
")\n",
"centered_df['Mean'] = X_gpt2.mean(axis=1)\n",
"centered_df['CORE'] = y_gpt2\n",
"print(\"GPT-2 Family: Centered Accuracies\")\n",
"centered_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Observation:** For every GPT-2 model, the mean of the 6 centered accuracies exceeds the actual CORE score. This is expected: CORE averages 22 tasks, and the 16 tasks not included here (many of them hard for models of this size) pull the average down."
]
},
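{
"cell_type": "markdown",
"metadata": {},
"source": [
"This observation can be made quantitative. The sketch below assumes CORE is the unweighted mean of the 22 centered task accuracies. Solving for the mean of the 16 tasks not selected here gives a value below the 6-task mean for every GPT-2 model. The numbers are copied from the table above, and `implied_rest_mean` is a hypothetical helper.\n",
"\n",
"```python\n",
"# Assumption: CORE is the unweighted mean of 22 centered task accuracies\n",
"N_TOTAL, N_SUBSET = 22, 6\n",
"\n",
"def implied_rest_mean(core, subset_mean, n_total=N_TOTAL, n_subset=N_SUBSET):\n",
"    # Solve core = (n_subset * subset_mean + n_rest * rest_mean) / n_total for rest_mean\n",
"    n_rest = n_total - n_subset\n",
"    return (n_total * core - n_subset * subset_mean) / n_rest\n",
"\n",
"# (6-task mean, CORE) for each GPT-2 model, copied from the centered-accuracy table above\n",
"gpt2_rows = {\n",
"    'GPT-2': (0.1505, 0.1139),\n",
"    'GPT-2 Medium': (0.2448, 0.1849),\n",
"    'GPT-2 Large': (0.2991, 0.2146),\n",
"    'GPT-2 XL': (0.3639, 0.2565),\n",
"}\n",
"\n",
"rest_means = {}\n",
"for name, (subset_mean, core) in gpt2_rows.items():\n",
"    rest_means[name] = implied_rest_mean(core, subset_mean)\n",
"    print(f'{name:13s} 6-task mean {subset_mean:.4f}, implied mean of other 16: {rest_means[name]:.4f}')\n",
"```\n",
"\n",
"Under this assumption, the other 16 tasks average roughly 0.10 to 0.22 for these models. That is below the 0.15 to 0.36 range of the 6 selected tasks."
]
},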
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 4: GPT-3 Data\n",
"\n",
"We extract the 6 task accuracies from the GPT-3 paper's Appendix H (master results table).\n",
"\n",
"**Source:** Table H.1 in \"Language Models are Few-Shot Learners\" (Brown et al., 2020)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-3 Family: Raw Accuracies from Paper\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>Params</th>\n",
" <th>HellaSwag 0-shot</th>\n",
" <th>LAMBADA</th>\n",
" <th>HellaSwag 10-shot</th>\n",
" <th>PIQA</th>\n",
" <th>ARC Easy</th>\n",
" <th>ARC Challenge</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>GPT-3 Small</td>\n",
" <td>125M</td>\n",
" <td>33.7%</td>\n",
" <td>42.7%</td>\n",
" <td>33.5%</td>\n",
" <td>64.3%</td>\n",
" <td>42.7%</td>\n",
" <td>25.5%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-3 Medium</td>\n",
" <td>350M</td>\n",
" <td>43.6%</td>\n",
" <td>54.3%</td>\n",
" <td>43.1%</td>\n",
" <td>69.4%</td>\n",
" <td>51.0%</td>\n",
" <td>28.4%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-3 Large</td>\n",
" <td>760M</td>\n",
" <td>51.0%</td>\n",
" <td>60.4%</td>\n",
" <td>51.3%</td>\n",
" <td>72.0%</td>\n",
" <td>58.1%</td>\n",
" <td>32.3%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-3 XL</td>\n",
" <td>1.3B</td>\n",
" <td>54.7%</td>\n",
" <td>63.6%</td>\n",
" <td>54.9%</td>\n",
" <td>74.3%</td>\n",
" <td>59.1%</td>\n",
" <td>36.7%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>GPT-3 2.7B</td>\n",
" <td>2.7B</td>\n",
" <td>62.8%</td>\n",
" <td>67.1%</td>\n",
" <td>62.9%</td>\n",
" <td>75.4%</td>\n",
" <td>62.1%</td>\n",
" <td>39.5%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>GPT-3 6.7B</td>\n",
" <td>6.7B</td>\n",
" <td>67.4%</td>\n",
" <td>70.3%</td>\n",
" <td>67.3%</td>\n",
" <td>77.8%</td>\n",
" <td>65.8%</td>\n",
" <td>43.7%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>GPT-3 13B</td>\n",
" <td>13.0B</td>\n",
" <td>70.9%</td>\n",
" <td>72.5%</td>\n",
" <td>71.3%</td>\n",
" <td>79.9%</td>\n",
" <td>69.1%</td>\n",
" <td>44.8%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>GPT-3 175B</td>\n",
" <td>175.0B</td>\n",
" <td>78.9%</td>\n",
" <td>76.2%</td>\n",
" <td>79.3%</td>\n",
" <td>82.3%</td>\n",
" <td>70.1%</td>\n",
" <td>51.5%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Params HellaSwag 0-shot LAMBADA HellaSwag 10-shot PIQA \\\n",
"0 GPT-3 Small 125M 33.7% 42.7% 33.5% 64.3% \n",
"1 GPT-3 Medium 350M 43.6% 54.3% 43.1% 69.4% \n",
"2 GPT-3 Large 760M 51.0% 60.4% 51.3% 72.0% \n",
"3 GPT-3 XL 1.3B 54.7% 63.6% 54.9% 74.3% \n",
"4 GPT-3 2.7B 2.7B 62.8% 67.1% 62.9% 75.4% \n",
"5 GPT-3 6.7B 6.7B 67.4% 70.3% 67.3% 77.8% \n",
"6 GPT-3 13B 13.0B 70.9% 72.5% 71.3% 79.9% \n",
"7 GPT-3 175B 175.0B 78.9% 76.2% 79.3% 82.3% \n",
"\n",
" ARC Easy ARC Challenge \n",
"0 42.7% 25.5% \n",
"1 51.0% 28.4% \n",
"2 58.1% 32.3% \n",
"3 59.1% 36.7% \n",
"4 62.1% 39.5% \n",
"5 65.8% 43.7% \n",
"6 69.1% 44.8% \n",
"7 70.1% 51.5% "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# GPT-3 accuracies from the paper\n",
"# Format: [hellaswag_0shot, lambada_0shot, hellaswag_fewshot, piqa_fewshot, arc_easy_fewshot, arc_challenge_fewshot]\n",
"gpt3_models = [\n",
" ('GPT-3 Small', 125e6, [0.337, 0.427, 0.335, 0.643, 0.427, 0.255]),\n",
" ('GPT-3 Medium', 350e6, [0.436, 0.543, 0.431, 0.694, 0.510, 0.284]),\n",
" ('GPT-3 Large', 760e6, [0.510, 0.604, 0.513, 0.720, 0.581, 0.323]),\n",
" ('GPT-3 XL', 1.3e9, [0.547, 0.636, 0.549, 0.743, 0.591, 0.367]),\n",
" ('GPT-3 2.7B', 2.7e9, [0.628, 0.671, 0.629, 0.754, 0.621, 0.395]),\n",
" ('GPT-3 6.7B', 6.7e9, [0.674, 0.703, 0.673, 0.778, 0.658, 0.437]),\n",
" ('GPT-3 13B', 13e9, [0.709, 0.725, 0.713, 0.799, 0.691, 0.448]),\n",
" ('GPT-3 175B', 175e9, [0.789, 0.762, 0.793, 0.823, 0.701, 0.515]),\n",
"]\n",
"\n",
"# Display raw accuracies\n",
"gpt3_df = pd.DataFrame([\n",
" {\n",
" 'Model': name,\n",
" 'Params': f\"{params/1e9:.1f}B\" if params >= 1e9 else f\"{params/1e6:.0f}M\",\n",
" **{task_name: f\"{acc:.1%}\" for task_name, acc in zip(TASK_NAMES, accs)}\n",
" }\n",
" for name, params, accs in gpt3_models\n",
"])\n",
"print(\"GPT-3 Family: Raw Accuracies from Paper\")\n",
"gpt3_df"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-3 Family: Centered Accuracies\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HellaSwag 0-shot</th>\n",
" <th>LAMBADA</th>\n",
" <th>HellaSwag 10-shot</th>\n",
" <th>PIQA</th>\n",
" <th>ARC Easy</th>\n",
" <th>ARC Challenge</th>\n",
" <th>Mean</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>GPT-3 Small</th>\n",
" <td>0.1160</td>\n",
" <td>0.427</td>\n",
" <td>0.1133</td>\n",
" <td>0.286</td>\n",
" <td>0.2360</td>\n",
" <td>0.0067</td>\n",
" <td>0.1975</td>\n",
" </tr>\n",
" <tr>\n",
" <th>GPT-3 Medium</th>\n",
" <td>0.2480</td>\n",
" <td>0.543</td>\n",
" <td>0.2413</td>\n",
" <td>0.388</td>\n",
" <td>0.3467</td>\n",
" <td>0.0453</td>\n",
" <td>0.3021</td>\n",
" </tr>\n",
" <tr>\n",
" <th>GPT-3 Large</th>\n",
" <td>0.3467</td>\n",
" <td>0.604</td>\n",
" <td>0.3507</td>\n",
" <td>0.440</td>\n",
" <td>0.4413</td>\n",
" <td>0.0973</td>\n",
" <td>0.3800</td>\n",
" </tr>\n",
" <tr>\n",
" <th>GPT-3 XL</th>\n",
" <td>0.3960</td>\n",
" <td>0.636</td>\n",
" <td>0.3987</td>\n",
" <td>0.486</td>\n",
" <td>0.4547</td>\n",
" <td>0.1560</td>\n",
" <td>0.4212</td>\n",
" </tr>\n",
" <tr>\n",
" <th>GPT-3 2.7B</th>\n",
" <td>0.5040</td>\n",
" <td>0.671</td>\n",
" <td>0.5053</td>\n",
" <td>0.508</td>\n",
" <td>0.4947</td>\n",
" <td>0.1933</td>\n",
" <td>0.4794</td>\n",
" </tr>\n",
" <tr>\n",
" <th>GPT-3 6.7B</th>\n",
" <td>0.5653</td>\n",
" <td>0.703</td>\n",
" <td>0.5640</td>\n",
" <td>0.556</td>\n",
" <td>0.5440</td>\n",
" <td>0.2493</td>\n",
" <td>0.5303</td>\n",
" </tr>\n",
" <tr>\n",
" <th>GPT-3 13B</th>\n",
" <td>0.6120</td>\n",
" <td>0.725</td>\n",
" <td>0.6173</td>\n",
" <td>0.598</td>\n",
" <td>0.5880</td>\n",
" <td>0.2640</td>\n",
" <td>0.5674</td>\n",
" </tr>\n",
" <tr>\n",
" <th>GPT-3 175B</th>\n",
" <td>0.7187</td>\n",
" <td>0.762</td>\n",
" <td>0.7240</td>\n",
" <td>0.646</td>\n",
" <td>0.6013</td>\n",
" <td>0.3533</td>\n",
" <td>0.6342</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" HellaSwag 0-shot LAMBADA HellaSwag 10-shot PIQA ARC Easy \\\n",
"GPT-3 Small 0.1160 0.427 0.1133 0.286 0.2360 \n",
"GPT-3 Medium 0.2480 0.543 0.2413 0.388 0.3467 \n",
"GPT-3 Large 0.3467 0.604 0.3507 0.440 0.4413 \n",
"GPT-3 XL 0.3960 0.636 0.3987 0.486 0.4547 \n",
"GPT-3 2.7B 0.5040 0.671 0.5053 0.508 0.4947 \n",
"GPT-3 6.7B 0.5653 0.703 0.5640 0.556 0.5440 \n",
"GPT-3 13B 0.6120 0.725 0.6173 0.598 0.5880 \n",
"GPT-3 175B 0.7187 0.762 0.7240 0.646 0.6013 \n",
"\n",
" ARC Challenge Mean \n",
"GPT-3 Small 0.0067 0.1975 \n",
"GPT-3 Medium 0.0453 0.3021 \n",
"GPT-3 Large 0.0973 0.3800 \n",
"GPT-3 XL 0.1560 0.4212 \n",
"GPT-3 2.7B 0.1933 0.4794 \n",
"GPT-3 6.7B 0.2493 0.5303 \n",
"GPT-3 13B 0.2640 0.5674 \n",
"GPT-3 175B 0.3533 0.6342 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Compute centered accuracies for GPT-3\n",
"X_gpt3 = []\n",
"for name, params, accs in gpt3_models:\n",
"    centered_accs = [center_accuracy(acc, BASELINES[task]) for task, acc in zip(TASK_ORDER, accs)]\n",
"    X_gpt3.append(centered_accs)\n",
"\n",
"X_gpt3 = np.array(X_gpt3)\n",
"\n",
"# Display\n",
"gpt3_centered_df = pd.DataFrame(\n",
" X_gpt3,\n",
" columns=TASK_NAMES,\n",
" index=[m[0] for m in gpt3_models]\n",
")\n",
"gpt3_centered_df['Mean'] = X_gpt3.mean(axis=1)\n",
"print(\"GPT-3 Family: Centered Accuracies\")\n",
"gpt3_centered_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 5: Regression Models\n",
"\n",
"We fit two types of models:\n",
"\n",
"1. **Simple Approach**: Average the 6 centered accuracies, then fit a linear regression to CORE\n",
"2. **Multivariate Approach**: Use all 6 features with Ridge regularization\n",
"\n",
"### Why Regularization?\n",
"\n",
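"We have only 4 calibration points (the GPT-2 models) but 6 features + 1 intercept = 7 parameters. The unregularized least-squares problem is therefore underdetermined: infinitely many weight vectors fit the 4 points exactly. Without regularization we get a perfect fit, but the weights are unstable and extreme, and they depend on numerical round-off. Ridge regression adds the penalty α||w||², which shrinks the weights toward zero and makes the solution unique, reducing overfitting."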
]
},
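{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below shows why the unregularized problem is ill-posed, using random hypothetical data of the same shape (4 models × 6 features). The augmented design matrix has rank at most 4, so the 7 × 7 normal-equations matrix built from it is singular. Adding the ridge term makes it full rank. As in `ridge_regression`, the intercept is left unpenalized.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X = rng.uniform(0.0, 0.5, size=(4, 6))      # hypothetical: 4 models x 6 centered accuracies\n",
"X_aug = np.column_stack([np.ones(4), X])    # prepend an intercept column: 4 x 7\n",
"gram = X_aug.T @ X_aug                      # 7 x 7 matrix from the normal equations\n",
"\n",
"# rank(X_aug.T @ X_aug) == rank(X_aug) <= 4 < 7, so the plain normal equations are singular\n",
"print('rank of X_aug:', np.linalg.matrix_rank(X_aug))\n",
"\n",
"alpha = 0.01\n",
"reg = alpha * np.eye(7)\n",
"reg[0, 0] = 0.0                             # do not penalize the intercept\n",
"print('rank with ridge term:', np.linalg.matrix_rank(gram + reg))\n",
"```"
]
},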
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def simple_linear_regression(x, y):\n",
"    \"\"\"Simple 1D linear regression: y = a*x + b\"\"\"\n",
"    mean_x, mean_y = np.mean(x), np.mean(y)\n",
"    a = np.sum((x - mean_x) * (y - mean_y)) / np.sum((x - mean_x) ** 2)\n",
"    b = mean_y - a * mean_x\n",
"    return a, b\n",
"\n",
"def ridge_regression(X, y, alpha=0.1):\n",
"    \"\"\"\n",
"    Ridge regression: minimize ||Xw - y||² + α||w||²\n",
"    We don't regularize the intercept.\n",
"    \"\"\"\n",
"    n_samples, n_features = X.shape\n",
"    X_aug = np.column_stack([np.ones(n_samples), X])\n",
"    reg_matrix = alpha * np.eye(n_features + 1)\n",
"    reg_matrix[0, 0] = 0  # Don't regularize intercept\n",
"    coeffs = np.linalg.solve(X_aug.T @ X_aug + reg_matrix, X_aug.T @ y)\n",
"    return coeffs[0], coeffs[1:]  # intercept, weights\n",
"\n",
"def compute_r_squared(y_true, y_pred):\n",
"    \"\"\"Compute R² score.\"\"\"\n",
"    ss_res = np.sum((y_true - y_pred) ** 2)\n",
"    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)\n",
"    return 1 - ss_res / ss_tot"
]
},
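{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, the closed-form slope and intercept from `simple_linear_regression` should agree with NumPy's least-squares line fit `np.polyfit(x, y, deg=1)`. The sketch below uses the rounded 6-task means and CORE scores from the GPT-2 table above, so the last digits may differ slightly from the fit on unrounded values.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def simple_linear_regression(x, y):\n",
"    # Closed-form ordinary least squares for y = a*x + b (same formula as the cell above)\n",
"    mean_x, mean_y = np.mean(x), np.mean(y)\n",
"    a = np.sum((x - mean_x) * (y - mean_y)) / np.sum((x - mean_x) ** 2)\n",
"    return a, mean_y - a * mean_x\n",
"\n",
"# 6-task mean centered accuracy and CORE for the four GPT-2 models (rounded, from the tables above)\n",
"x = np.array([0.1505, 0.2448, 0.2991, 0.3639])\n",
"y = np.array([0.1139, 0.1849, 0.2146, 0.2565])\n",
"\n",
"a, b = simple_linear_regression(x, y)\n",
"a_ref, b_ref = np.polyfit(x, y, deg=1)   # reference fit: highest-degree coefficient first\n",
"print(f'closed form: a={a:.4f}, b={b:.4f}')\n",
"print(f'np.polyfit:  a={a_ref:.4f}, b={b_ref:.4f}')\n",
"```"
]
},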
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Approach 1: Simple Averaging"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Simple Model: CORE = 0.6639 × avg_centered + 0.0168\n",
"\n",
"R² = 0.9960\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>Avg Centered</th>\n",
" <th>Predicted</th>\n",
" <th>Actual</th>\n",
" <th>Error</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>GPT-2</td>\n",
" <td>0.1505</td>\n",
" <td>0.1168</td>\n",
" <td>0.1139</td>\n",
" <td>0.0029</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-2 Medium</td>\n",
" <td>0.2448</td>\n",
" <td>0.1793</td>\n",
" <td>0.1849</td>\n",
" <td>-0.0056</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-2 Large</td>\n",
" <td>0.2991</td>\n",
" <td>0.2154</td>\n",
" <td>0.2146</td>\n",
" <td>0.0008</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-2 XL</td>\n",
" <td>0.3639</td>\n",
" <td>0.2584</td>\n",
" <td>0.2565</td>\n",
" <td>0.0019</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Avg Centered Predicted Actual Error\n",
"0 GPT-2 0.1505 0.1168 0.1139 0.0029\n",
"1 GPT-2 Medium 0.2448 0.1793 0.1849 -0.0056\n",
"2 GPT-2 Large 0.2991 0.2154 0.2146 0.0008\n",
"3 GPT-2 XL 0.3639 0.2584 0.2565 0.0019"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Compute average of 6 centered accuracies\n",
"avg_centered_gpt2 = X_gpt2.mean(axis=1)\n",
"\n",
"# Fit linear regression\n",
"slope, intercept = simple_linear_regression(avg_centered_gpt2, y_gpt2)\n",
"print(f\"Simple Model: CORE = {slope:.4f} × avg_centered + {intercept:.4f}\")\n",
"\n",
"# Validate\n",
"y_pred_simple = slope * avg_centered_gpt2 + intercept\n",
"r2_simple = compute_r_squared(y_gpt2, y_pred_simple)\n",
"\n",
"validation_df = pd.DataFrame({\n",
" 'Model': [d['name'] for d in gpt2_data],\n",
" 'Avg Centered': avg_centered_gpt2,\n",
" 'Predicted': y_pred_simple,\n",
" 'Actual': y_gpt2,\n",
" 'Error': y_pred_simple - y_gpt2\n",
"})\n",
"print(f\"\\nR² = {r2_simple:.4f}\")\n",
"validation_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
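"**Result:** R² = 0.996, an excellent fit with just 2 parameters. This is an in-sample fit on only 4 points. It shows that the simple average tracks CORE across the GPT-2 family, not how well it extrapolates beyond that range."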
]
},
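{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fitted line can then be applied to the GPT-3 6-task means. The sketch below takes the coefficients (0.6639, 0.0168) and the GPT-3 means from the outputs above. The GPT-3 means reach 0.63, well outside the GPT-2 calibration range of 0.15 to 0.36. Estimates for the larger models are therefore extrapolations and should be read with caution. The names `SLOPE`, `INTERCEPT` and `gpt3_means` are hypothetical.\n",
"\n",
"```python\n",
"SLOPE, INTERCEPT = 0.6639, 0.0168   # simple model fitted on GPT-2 (output above)\n",
"GPT2_RANGE = (0.1505, 0.3639)       # min and max 6-task mean among the calibration models\n",
"\n",
"# 6-task mean centered accuracy for each GPT-3 model (GPT-3 centered table above)\n",
"gpt3_means = {\n",
"    'GPT-3 Small': 0.1975, 'GPT-3 Medium': 0.3021, 'GPT-3 Large': 0.3800,\n",
"    'GPT-3 XL': 0.4212, 'GPT-3 2.7B': 0.4794, 'GPT-3 6.7B': 0.5303,\n",
"    'GPT-3 13B': 0.5674, 'GPT-3 175B': 0.6342,\n",
"}\n",
"\n",
"estimates = {}\n",
"for name, m in gpt3_means.items():\n",
"    estimates[name] = SLOPE * m + INTERCEPT\n",
"    inside = GPT2_RANGE[0] <= m <= GPT2_RANGE[1]\n",
"    flag = 'interpolated' if inside else 'extrapolated'\n",
"    print(f'{name:13s} estimated CORE {estimates[name]:.4f} ({flag})')\n",
"```"
]
},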
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Approach 2: Multivariate Ridge Regression\n",
"\n",
"We try different regularization strengths (α) to find a good balance between fit and stability."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Effect of Regularization Strength:\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>α</th>\n",
" <th>R²</th>\n",
" <th>||weights||</th>\n",
" <th>Intercept</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.000</td>\n",
" <td>1.0000</td>\n",
" <td>10.7221</td>\n",
" <td>-0.0829</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.001</td>\n",
" <td>0.9971</td>\n",
" <td>0.2796</td>\n",
" <td>0.0159</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.010</td>\n",
" <td>0.9916</td>\n",
" <td>0.2463</td>\n",
" <td>0.0269</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.100</td>\n",
" <td>0.8448</td>\n",
" <td>0.1600</td>\n",
" <td>0.0851</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.000</td>\n",
" <td>0.2523</td>\n",
" <td>0.0356</td>\n",
" <td>0.1686</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" α R² ||weights|| Intercept\n",
"0 0.000 1.0000 10.7221 -0.0829\n",
"1 0.001 0.9971 0.2796 0.0159\n",
"2 0.010 0.9916 0.2463 0.0269\n",
"3 0.100 0.8448 0.1600 0.0851\n",
"4 1.000 0.2523 0.0356 0.1686"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Try different regularization strengths\n",
"alphas = [0.0, 0.001, 0.01, 0.1, 1.0]\n",
"\n",
"results = []\n",
"for alpha in alphas:\n",
"    intercept_r, weights = ridge_regression(X_gpt2, y_gpt2, alpha=alpha)\n",
"    y_pred = X_gpt2 @ weights + intercept_r\n",
"    r2 = compute_r_squared(y_gpt2, y_pred)\n",
"    weight_norm = np.sqrt(np.sum(weights ** 2))\n",
"    results.append({\n",
"        'α': alpha,\n",
"        'R²': r2,\n",
"        '||weights||': weight_norm,\n",
"        'Intercept': intercept_r,\n",
"        'Weights': weights.copy()\n",
"    })\n",
"\n",
"alpha_df = pd.DataFrame([{k: v for k, v in r.items() if k != 'Weights'} for r in results])\n",
"print(\"Effect of Regularization Strength:\")\n",
"alpha_df"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Task Weights by Regularization Strength:\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>HellaSwag 0-shot</th>\n",
" <th>LAMBADA</th>\n",
" <th>HellaSwag 10-shot</th>\n",
" <th>PIQA</th>\n",
" <th>ARC Easy</th>\n",
" <th>ARC Challenge</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>α=0.0</th>\n",
" <td>6.5523</td>\n",
" <td>0.2201</td>\n",
" <td>-8.0268</td>\n",
" <td>0.5378</td>\n",
" <td>0.9109</td>\n",
" <td>2.5364</td>\n",
" </tr>\n",
" <tr>\n",
" <th>α=0.001</th>\n",
" <td>0.1134</td>\n",
" <td>0.1442</td>\n",
" <td>0.1305</td>\n",
" <td>0.1153</td>\n",
" <td>0.0510</td>\n",
" <td>0.1079</td>\n",
" </tr>\n",
" <tr>\n",
" <th>α=0.01</th>\n",
" <td>0.1155</td>\n",
" <td>0.1000</td>\n",
" <td>0.1226</td>\n",
" <td>0.0959</td>\n",
" <td>0.1023</td>\n",
" <td>0.0513</td>\n",
" </tr>\n",
" <tr>\n",
" <th>α=0.1</th>\n",
" <td>0.0759</td>\n",
" <td>0.0614</td>\n",
" <td>0.0798</td>\n",
" <td>0.0610</td>\n",
" <td>0.0714</td>\n",
" <td>0.0293</td>\n",
" </tr>\n",
" <tr>\n",
" <th>α=1.0</th>\n",
" <td>0.0169</td>\n",
" <td>0.0136</td>\n",
" <td>0.0178</td>\n",
" <td>0.0135</td>\n",
" <td>0.0160</td>\n",
" <td>0.0064</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" HellaSwag 0-shot LAMBADA HellaSwag 10-shot PIQA ARC Easy \\\n",
"α=0.0 6.5523 0.2201 -8.0268 0.5378 0.9109 \n",
"α=0.001 0.1134 0.1442 0.1305 0.1153 0.0510 \n",
"α=0.01 0.1155 0.1000 0.1226 0.0959 0.1023 \n",
"α=0.1 0.0759 0.0614 0.0798 0.0610 0.0714 \n",
"α=1.0 0.0169 0.0136 0.0178 0.0135 0.0160 \n",
"\n",
" ARC Challenge \n",
"α=0.0 2.5364 \n",
"α=0.001 0.1079 \n",
"α=0.01 0.0513 \n",
"α=0.1 0.0293 \n",
"α=1.0 0.0064 "
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Show weights for each alpha\n",
"print(\"Task Weights by Regularization Strength:\")\n",
"weights_df = pd.DataFrame(\n",
" [r['Weights'] for r in results],\n",
" columns=TASK_NAMES,\n",
" index=[f\"α={r['α']}\" for r in results]\n",
")\n",
"weights_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Observations:**\n",
"\n",
"- **α=0 (no regularization):** Perfect fit (R²=1.0) but extreme, opposite-signed weights (+6.6 on HellaSwag 0-shot, -8.0 on HellaSwag 10-shot), which is clearly overfitting. Because the system is underdetermined, these particular values are effectively arbitrary\n",
"- **α=0.001:** Near-perfect fit (R²=0.997) with small weights (0.05 to 0.14)\n",
"- **α=0.01:** Excellent fit (R²=0.99) with reasonable weights (0.05 to 0.12); **good choice**\n",
"- **α=0.1:** Weaker fit (R²=0.84) with smaller, roughly uniform weights (0.03 to 0.08); conservative\n",
"- **α=1.0:** Poor fit (R²=0.25); over-regularized"
]
},
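{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below shows why the α=0 weights are arbitrary. It uses the GPT-2 centered accuracies and CORE scores copied (rounded) from the tables above. NumPy's `np.linalg.lstsq` returns the minimum-norm exact solution. Adding any multiple of a null-space direction of the design matrix (taken from its SVD) gives different weights with the same, essentially zero, residual.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Centered accuracies (4 GPT-2 models x 6 tasks) and CORE, rounded, copied from the tables above\n",
"X = np.array([\n",
"    [0.0780, 0.3229, 0.0772, 0.2459, 0.2166, -0.0375],\n",
"    [0.1867, 0.4260, 0.1933, 0.3400, 0.3067, 0.0160],\n",
"    [0.2533, 0.4880, 0.2587, 0.3960, 0.3800, 0.0187],\n",
"    [0.3360, 0.5230, 0.3493, 0.4500, 0.4600, 0.0653],\n",
"])\n",
"y = np.array([0.1139, 0.1849, 0.2146, 0.2565])\n",
"X_aug = np.column_stack([np.ones(len(y)), X])   # 4 x 7: intercept column + 6 features\n",
"\n",
"# Minimum-norm exact solution of the underdetermined system (4 equations, 7 unknowns)\n",
"coef_min, *_ = np.linalg.lstsq(X_aug, y, rcond=None)\n",
"\n",
"# X_aug has rank 4, so the last 3 rows of Vt span its null space\n",
"_, _, Vt = np.linalg.svd(X_aug)\n",
"coef_alt = coef_min + 5.0 * Vt[-1]\n",
"\n",
"residuals = []\n",
"for label, coef in [('min-norm', coef_min), ('shifted', coef_alt)]:\n",
"    resid = np.abs(X_aug @ coef - y).max()\n",
"    residuals.append(resid)\n",
"    print(f'{label:9s} max residual {resid:.1e}, weights {np.round(coef[1:], 2)}')\n",
"```\n",
"\n",
"Both coefficient vectors reproduce the 4 CORE scores exactly. This is why the α=0 row cannot be trusted and some regularization is needed."
]
},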
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ridge Model (α=0.01):\n",
" Intercept: 0.0269\n",
" Weights:\n",
" HellaSwag 0-shot : +0.1155\n",
" LAMBADA : +0.1000\n",
" HellaSwag 10-shot : +0.1226\n",
" PIQA : +0.0959\n",
" ARC Easy : +0.1023\n",
" ARC Challenge : +0.0513\n",
"\n",
"R² = 0.9916\n"
]
}
],
"source": [
"# Use α=0.01 as our chosen regularization\n",
"# This gives R²≈0.99 with reasonable, stable weights (0.05 to 0.12 per task)\n",
"CHOSEN_ALPHA = 0.01\n",
"intercept_ridge, weights_ridge = ridge_regression(X_gpt2, y_gpt2, alpha=CHOSEN_ALPHA)\n",
"\n",
"print(f\"Ridge Model (α={CHOSEN_ALPHA}):\")\n",
"print(f\" Intercept: {intercept_ridge:.4f}\")\n",
"print(f\" Weights:\")\n",
"for name, w in zip(TASK_NAMES, weights_ridge):\n",
"    print(f\" {name:20s}: {w:+.4f}\")\n",
"\n",
"# Validate\n",
"y_pred_ridge = X_gpt2 @ weights_ridge + intercept_ridge\n",
"r2_ridge = compute_r_squared(y_gpt2, y_pred_ridge)\n",
"print(f\"\\nR² = {r2_ridge:.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Approach 3: Individual Task Analysis\n",
"\n",
"Which single task is the best predictor of CORE? We fit separate linear regressions for each task."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Individual Task Correlations with CORE:\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Task</th>\n",
" <th>R²</th>\n",
" <th>Slope</th>\n",
" <th>Intercept</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>PIQA</td>\n",
" <td>0.9961</td>\n",
" <td>0.6879</td>\n",
" <td>-0.0537</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>HellaSwag 10-shot</td>\n",
" <td>0.9933</td>\n",
" <td>0.5230</td>\n",
" <td>0.0776</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>HellaSwag 0-shot</td>\n",
" <td>0.9927</td>\n",
" <td>0.5489</td>\n",
" <td>0.0753</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>LAMBADA</td>\n",
" <td>0.9841</td>\n",
" <td>0.6792</td>\n",
" <td>-0.1063</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ARC Easy</td>\n",
" <td>0.9800</td>\n",
" <td>0.5728</td>\n",
" <td>-0.0027</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>ARC Challenge</td>\n",
" <td>0.9599</td>\n",
" <td>1.3994</td>\n",
" <td>0.1706</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Task R² Slope Intercept\n",
"3 PIQA 0.9961 0.6879 -0.0537\n",
"2 HellaSwag 10-shot 0.9933 0.5230 0.0776\n",
"0 HellaSwag 0-shot 0.9927 0.5489 0.0753\n",
"1 LAMBADA 0.9841 0.6792 -0.1063\n",
"4 ARC Easy 0.9800 0.5728 -0.0027\n",
"5 ARC Challenge 0.9599 1.3994 0.1706"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fit a separate one-variable linear regression (CORE vs. task score) for each task\n",
"individual_results = []\n",
"for i, task_name in enumerate(TASK_NAMES):\n",
" x_task = X_gpt2[:, i]\n",
" slope_ind, intercept_ind = simple_linear_regression(x_task, y_gpt2)\n",
" y_pred_ind = slope_ind * x_task + intercept_ind\n",
" r2_ind = compute_r_squared(y_gpt2, y_pred_ind)\n",
" individual_results.append({\n",
" 'Task': task_name,\n",
" 'R²': r2_ind,\n",
" 'Slope': slope_ind,\n",
" 'Intercept': intercept_ind\n",
" })\n",
"\n",
"individual_df = pd.DataFrame(individual_results).sort_values('R²', ascending=False)\n",
"print(\"Individual Task Correlations with CORE:\")\n",
"individual_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Key Finding:** All 6 tasks correlate strongly with CORE (R² ≈ 0.96 or higher; ARC Challenge is lowest at 0.9599), and **PIQA is the single best predictor** with R² = 0.9961, marginally above the simple averaging approach (R² = 0.9960). With only 4 calibration points, a difference that small is not meaningful.\n",
"\n",
"This makes PIQA a useful quick proxy for CORE when evaluation cost matters. For robustness, however, we still recommend using all 6 tasks or the averaged approaches."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 6: Final Estimates for GPT-3\n",
"\n",
"We apply all three approaches to the GPT-3 data and report the average of the Simple and Ridge estimates as our final estimate."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-3 CORE Estimates (all three approaches):\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>Params</th>\n",
" <th>Simple</th>\n",
" <th>Ridge</th>\n",
" <th>PIQA only</th>\n",
" <th>Avg(1,2)</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>GPT-3 Small</td>\n",
" <td>125M</td>\n",
" <td>0.1480</td>\n",
" <td>0.1488</td>\n",
" <td>0.1430</td>\n",
" <td>0.1484</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-3 Medium</td>\n",
" <td>350M</td>\n",
" <td>0.2174</td>\n",
" <td>0.2144</td>\n",
" <td>0.2131</td>\n",
" <td>0.2159</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-3 Large</td>\n",
" <td>760M</td>\n",
" <td>0.2691</td>\n",
" <td>0.2627</td>\n",
" <td>0.2489</td>\n",
" <td>0.2659</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-3 XL</td>\n",
" <td>1.3B</td>\n",
" <td>0.2965</td>\n",
" <td>0.2862</td>\n",
" <td>0.2805</td>\n",
" <td>0.2914</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>GPT-3 2.7B</td>\n",
" <td>2.7B</td>\n",
" <td>0.3351</td>\n",
" <td>0.3234</td>\n",
" <td>0.2957</td>\n",
" <td>0.3292</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>GPT-3 6.7B</td>\n",
" <td>6.7B</td>\n",
" <td>0.3689</td>\n",
" <td>0.3534</td>\n",
" <td>0.3287</td>\n",
" <td>0.3611</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>GPT-3 13B</td>\n",
" <td>13.0B</td>\n",
" <td>0.3935</td>\n",
" <td>0.3768</td>\n",
" <td>0.3576</td>\n",
" <td>0.3852</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>GPT-3 175B</td>\n",
" <td>175.0B</td>\n",
" <td>0.4379</td>\n",
" <td>0.4164</td>\n",
" <td>0.3906</td>\n",
" <td>0.4272</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Params Simple Ridge PIQA only Avg(1,2)\n",
"0 GPT-3 Small 125M 0.1480 0.1488 0.1430 0.1484\n",
"1 GPT-3 Medium 350M 0.2174 0.2144 0.2131 0.2159\n",
"2 GPT-3 Large 760M 0.2691 0.2627 0.2489 0.2659\n",
"3 GPT-3 XL 1.3B 0.2965 0.2862 0.2805 0.2914\n",
"4 GPT-3 2.7B 2.7B 0.3351 0.3234 0.2957 0.3292\n",
"5 GPT-3 6.7B 6.7B 0.3689 0.3534 0.3287 0.3611\n",
"6 GPT-3 13B 13.0B 0.3935 0.3768 0.3576 0.3852\n",
"7 GPT-3 175B 175.0B 0.4379 0.4164 0.3906 0.4272"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Apply all three approaches\n",
"avg_centered_gpt3 = X_gpt3.mean(axis=1)\n",
"gpt3_core_simple = slope * avg_centered_gpt3 + intercept\n",
"gpt3_core_ridge = X_gpt3 @ weights_ridge + intercept_ridge\n",
"\n",
"# Approach 3: Best individual predictor (PIQA)\n",
"piqa_idx = TASK_NAMES.index('PIQA')\n",
"piqa_model = [r for r in individual_results if r['Task'] == 'PIQA'][0]\n",
"gpt3_core_piqa = piqa_model['Slope'] * X_gpt3[:, piqa_idx] + piqa_model['Intercept']\n",
"\n",
"# Average of approaches 1 and 2\n",
"gpt3_core_final = (gpt3_core_simple + gpt3_core_ridge) / 2\n",
"\n",
"# Create results table with all approaches\n",
"results_df = pd.DataFrame({\n",
" 'Model': [m[0] for m in gpt3_models],\n",
" 'Params': [f\"{m[1]/1e9:.1f}B\" if m[1] >= 1e9 else f\"{m[1]/1e6:.0f}M\" for m in gpt3_models],\n",
" 'Simple': gpt3_core_simple,\n",
" 'Ridge': gpt3_core_ridge,\n",
" 'PIQA only': gpt3_core_piqa,\n",
" 'Avg(1,2)': gpt3_core_final\n",
"})\n",
"print(\"GPT-3 CORE Estimates (all three approaches):\")\n",
"results_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Final CORE Estimates for GPT-3"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Complete CORE Scores (GPT-2 measured, GPT-3 estimated):\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>Params</th>\n",
" <th>CORE</th>\n",
" <th>Source</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>GPT-2</td>\n",
" <td>124M</td>\n",
" <td>0.1139</td>\n",
" <td>Measured</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>GPT-3 Small</td>\n",
" <td>125M</td>\n",
" <td>0.1484</td>\n",
" <td>Estimated</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>GPT-3 Medium</td>\n",
" <td>350M</td>\n",
" <td>0.2159</td>\n",
" <td>Estimated</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>GPT-2 Medium</td>\n",
" <td>355M</td>\n",
" <td>0.1849</td>\n",
" <td>Measured</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>GPT-3 Large</td>\n",
" <td>760M</td>\n",
" <td>0.2659</td>\n",
" <td>Estimated</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>GPT-2 Large</td>\n",
" <td>774M</td>\n",
" <td>0.2146</td>\n",
" <td>Measured</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>GPT-3 XL</td>\n",
" <td>1.3B</td>\n",
" <td>0.2914</td>\n",
" <td>Estimated</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>GPT-2 XL</td>\n",
" <td>1.6B</td>\n",
" <td>0.2565</td>\n",
" <td>Measured</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>GPT-3 2.7B</td>\n",
" <td>2.7B</td>\n",
" <td>0.3292</td>\n",
" <td>Estimated</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>GPT-3 6.7B</td>\n",
" <td>6.7B</td>\n",
" <td>0.3611</td>\n",
" <td>Estimated</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>GPT-3 13B</td>\n",
" <td>13.0B</td>\n",
" <td>0.3852</td>\n",
" <td>Estimated</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>GPT-3 175B</td>\n",
" <td>175.0B</td>\n",
" <td>0.4272</td>\n",
" <td>Estimated</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model Params CORE Source\n",
"0 GPT-2 124M 0.1139 Measured\n",
"1 GPT-3 Small 125M 0.1484 Estimated\n",
"2 GPT-3 Medium 350M 0.2159 Estimated\n",
"3 GPT-2 Medium 355M 0.1849 Measured\n",
"4 GPT-3 Large 760M 0.2659 Estimated\n",
"5 GPT-2 Large 774M 0.2146 Measured\n",
"6 GPT-3 XL 1.3B 0.2914 Estimated\n",
"7 GPT-2 XL 1.6B 0.2565 Measured\n",
"8 GPT-3 2.7B 2.7B 0.3292 Estimated\n",
"9 GPT-3 6.7B 6.7B 0.3611 Estimated\n",
"10 GPT-3 13B 13.0B 0.3852 Estimated\n",
"11 GPT-3 175B 175.0B 0.4272 Estimated"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Combine with GPT-2 for complete picture\n",
"all_models = []\n",
"\n",
"for data in gpt2_data:\n",
" params = data['params']\n",
" all_models.append({\n",
" 'Model': data['name'],\n",
" 'Family': 'GPT-2',\n",
" 'Params': params,\n",
" 'Params_str': f\"{params/1e9:.1f}B\" if params >= 1e9 else f\"{params/1e6:.0f}M\",\n",
" 'CORE': data['core'],\n",
" 'Source': 'Measured'\n",
" })\n",
"\n",
"for (name, params, _), core in zip(gpt3_models, gpt3_core_final):\n",
" all_models.append({\n",
" 'Model': name,\n",
" 'Family': 'GPT-3',\n",
" 'Params': params,\n",
" 'Params_str': f\"{params/1e9:.1f}B\" if params >= 1e9 else f\"{params/1e6:.0f}M\",\n",
" 'CORE': core,\n",
" 'Source': 'Estimated'\n",
" })\n",
"\n",
"# Sort by params and display\n",
"all_models.sort(key=lambda x: x['Params'])\n",
"final_df = pd.DataFrame(all_models)[['Model', 'Params_str', 'CORE', 'Source']]\n",
"final_df.columns = ['Model', 'Params', 'CORE', 'Source']\n",
"print(\"Complete CORE Scores (GPT-2 measured, GPT-3 estimated):\")\n",
"final_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Head-to-Head: GPT-2 vs GPT-3 at Similar Sizes"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-3 vs GPT-2 at Similar Model Sizes:\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Size</th>\n",
" <th>GPT-2 CORE</th>\n",
" <th>GPT-3 CORE</th>\n",
" <th>Δ</th>\n",
" <th>Improvement</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>~125M</td>\n",
" <td>0.1139</td>\n",
" <td>0.1484</td>\n",
" <td>0.0345</td>\n",
" <td>+30.3%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>~350M</td>\n",
" <td>0.1849</td>\n",
" <td>0.2159</td>\n",
" <td>0.0310</td>\n",
" <td>+16.8%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>~760M</td>\n",
" <td>0.2146</td>\n",
" <td>0.2659</td>\n",
" <td>0.0512</td>\n",
" <td>+23.9%</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>~1.3-1.5B</td>\n",
" <td>0.2565</td>\n",
" <td>0.2914</td>\n",
" <td>0.0348</td>\n",
" <td>+13.6%</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Size GPT-2 CORE GPT-3 CORE Δ Improvement\n",
"0 ~125M 0.1139 0.1484 0.0345 +30.3%\n",
"1 ~350M 0.1849 0.2159 0.0310 +16.8%\n",
"2 ~760M 0.2146 0.2659 0.0512 +23.9%\n",
"3 ~1.3-1.5B 0.2565 0.2914 0.0348 +13.6%"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"comparisons = [\n",
" ('~125M', 'GPT-2', gpt2_data[0]['core'], 'GPT-3 Small', gpt3_core_final[0]),\n",
" ('~350M', 'GPT-2 Medium', gpt2_data[1]['core'], 'GPT-3 Medium', gpt3_core_final[1]),\n",
" ('~760M', 'GPT-2 Large', gpt2_data[2]['core'], 'GPT-3 Large', gpt3_core_final[2]),\n",
" ('~1.3-1.5B', 'GPT-2 XL', gpt2_data[3]['core'], 'GPT-3 XL', gpt3_core_final[3]),\n",
"]\n",
"\n",
"comparison_df = pd.DataFrame([\n",
" {\n",
" 'Size': size,\n",
" 'GPT-2 CORE': gpt2_core,\n",
" 'GPT-3 CORE': gpt3_core,\n",
" 'Δ': gpt3_core - gpt2_core,\n",
" 'Improvement': f\"{100 * (gpt3_core - gpt2_core) / gpt2_core:+.1f}%\"\n",
" }\n",
" for size, _, gpt2_core, _, gpt3_core in comparisons\n",
"])\n",
"print(\"GPT-3 vs GPT-2 at Similar Model Sizes:\")\n",
"comparison_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusions\n",
"\n",
"### Methodology\n",
"\n",
"We estimated CORE scores for GPT-3 models by:\n",
"1. Identifying 6 tasks with comparable evaluation methodology between GPT-3 and CORE\n",
"2. Using GPT-2's measured CORE scores as calibration data\n",
"3. Fitting three regression approaches:\n",
" - **Simple**: Average the 6 metrics, then linear regression (R²=0.996)\n",
" - **Ridge**: Use all 6 features with regularization (R²=0.992)\n",
" - **PIQA only**: Single best predictor (R²=0.996)\n",
"4. Averaging the Simple and Ridge approaches for final estimates\n",
"\n",
"### Key Findings\n",
"\n",
"1. **GPT-3 consistently outperforms GPT-2 at similar model sizes** by approximately 0.03-0.05 CORE (14-30% relative improvement)\n",
"\n",
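"2. **PIQA is the best single predictor of CORE** (R²=0.9961). If you need a quick proxy for CORE with minimal evaluation cost, PIQA alone fits the GPT-2 calibration points about as well as averaging all 6 tasks, though its GPT-3 estimates come out lower for the larger models (e.g., 0.391 vs. 0.427 for GPT-3 175B).\n",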
"\n",
"3. **The improvement likely comes from:**\n",
" - More training data (300B tokens vs ~100B for GPT-2)\n",
" - Better data quality and filtering\n",
" - Larger context length (2048 vs 1024)\n",
"\n",
"4. **Final estimated CORE scores:**\n",
"\n",
"| Model | Params | Estimated CORE |\n",
"|-------|--------|----------------|\n",
"| GPT-3 Small | 125M | 0.148 |\n",
"| GPT-3 Medium | 350M | 0.216 |\n",
"| GPT-3 Large | 760M | 0.266 |\n",
"| GPT-3 XL | 1.3B | 0.291 |\n",
"| GPT-3 2.7B | 2.7B | 0.329 |\n",
"| GPT-3 6.7B | 6.7B | 0.361 |\n",
"| GPT-3 13B | 13B | 0.385 |\n",
"| GPT-3 175B | 175B | 0.427 |\n",
"\n",
"### Caveats\n",
"\n",
"1. **These are estimates**, not measured values. True CORE scores could differ.\n",
"2. We only have 4 calibration points, limiting statistical power.\n",
"3. The 6 overlapping tasks may not perfectly represent all 22 CORE tasks.\n",
"4. Slight differences in evaluation methodology (K values, splits) add uncertainty.\n",
"\n",
"Despite these limitations, the estimates are useful for approximate comparisons between nanochat models and the GPT-3 family."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Appendix: Export Final Estimates"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPT-3 CORE Estimates (for copy-paste):\n",
"{\n",
" \"GPT-3 Small (125M)\": 0.1484,\n",
" \"GPT-3 Medium (350M)\": 0.2159,\n",
" \"GPT-3 Large (760M)\": 0.2659,\n",
" \"GPT-3 XL (1.3B)\": 0.2914,\n",
" \"GPT-3 2.7B\": 0.3292,\n",
" \"GPT-3 6.7B\": 0.3611,\n",
" \"GPT-3 13B\": 0.3852,\n",
" \"GPT-3 175B\": 0.4272\n",
"}\n"
]
}
],
"source": [
"import json\n",
"\n",
"# Export as a simple dict for use elsewhere\n",
"gpt3_core_estimates = {\n",
" 'GPT-3 Small (125M)': round(gpt3_core_final[0], 4),\n",
" 'GPT-3 Medium (350M)': round(gpt3_core_final[1], 4),\n",
" 'GPT-3 Large (760M)': round(gpt3_core_final[2], 4),\n",
" 'GPT-3 XL (1.3B)': round(gpt3_core_final[3], 4),\n",
" 'GPT-3 2.7B': round(gpt3_core_final[4], 4),\n",
" 'GPT-3 6.7B': round(gpt3_core_final[5], 4),\n",
" 'GPT-3 13B': round(gpt3_core_final[6], 4),\n",
" 'GPT-3 175B': round(gpt3_core_final[7], 4),\n",
"}\n",
"\n",
"print(\"GPT-3 CORE Estimates (for copy-paste):\")\n",
"print(json.dumps(gpt3_core_estimates, indent=4))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
================================================
FILE: dev/gen_synthetic_data.py
================================================
"""
Synthetic data generation for teaching nanochat about its identity and capabilities.
This script uses the OpenRouter API to generate diverse multi-turn conversations
between a user and nanochat. The conversations are saved to a .jsonl file for use
in supervised finetuning (SFT) via the CustomJSON task.
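Each output line is one JSON-encoded conversation: a bare list of {"role", "content"}
dicts by default, or a {"messages": ..., "metadata": ...} object when --save-metadata
is set (see the __main__ block). As an illustration only, here is a minimal sketch of
reading such lines back and re-checking the user/assistant alternation. load_conversations
is a hypothetical helper, not part of nanochat, and this is not necessarily how the
CustomJSON task parses the file:

```python
import json

def load_conversations(lines):
    # Hypothetical helper: parse JSONL lines into lists of {"role", "content"} dicts.
    # Accepts both the default format (a bare list) and the
    # --save-metadata format ({"messages": [...], "metadata": {...}}).
    conversations = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        messages = record["messages"] if isinstance(record, dict) else record
        # Same alternation rule as validate_conversation: user first, then assistant.
        for i, msg in enumerate(messages):
            expected = "user" if i % 2 == 0 else "assistant"
            if msg["role"] != expected:
                raise ValueError(f"message {i}: expected {expected!r}, got {msg['role']!r}")
        conversations.append(messages)
    return conversations

sample = [
    json.dumps([{"role": "user", "content": "hi"},
                {"role": "assistant", "content": "Hello! I am nanochat."}]),
    json.dumps({"messages": [{"role": "user", "content": "who made you?"},
                             {"role": "assistant", "content": "..."}],
                "metadata": {"topic": "identity"}}),
]
convs = load_conversations(sample)
print(len(convs))  # 2
```

Unwrapping the metadata object in one place keeps downstream code agnostic to which
output mode was used.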
Key design principles for high-quality synthetic data:
1. DIVERSITY CONTROL is critical - we inject entropy at multiple levels:
- Topic/question categories (what the conversation is about)
- User personas (who is asking)
- Conversation dynamics (shape and flow)
- First message style (greeting variation)
2. Comprehensive knowledge base - we provide detailed facts so the LLM
generating conversations has accurate information to draw from.
3. Structured outputs - we use JSON schema to guarantee valid format.
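The topic dimension is sampled in two stages (pick a category, then a topic within it),
which gives every category the same probability regardless of how many topics it holds;
sampling uniformly over all topics would instead over-represent the larger categories.
A small sketch of that effect, using a made-up toy_topics dict in place of the real
topics table (both sampling helpers below are hypothetical, for illustration only):

```python
import random
from collections import Counter

# Hypothetical toy stand-in for the `topics` dict: categories of unequal size.
toy_topics = {
    "small": ["a", "b"],
    "large": ["c", "d", "e", "f", "g", "h"],
}

def sample_two_stage(rng):
    # Pick a category uniformly, then a topic inside it
    # (the same pattern as sample_diversity_elements).
    category = rng.choice(list(toy_topics))
    return category, rng.choice(toy_topics[category])

def sample_flat(rng):
    # Pick uniformly over all topics, so larger categories dominate.
    flat = [(c, t) for c, ts in toy_topics.items() for t in ts]
    return rng.choice(flat)

rng = random.Random(0)
n = 20000
two_stage = Counter(sample_two_stage(rng)[0] for _ in range(n))
flat = Counter(sample_flat(rng)[0] for _ in range(n))
print({c: round(k / n, 2) for c, k in sorted(two_stage.items())})  # about 0.5 each
print({c: round(k / n, 2) for c, k in sorted(flat.items())})       # about 0.75 large, 0.25 small

# Seeding a fresh Random per conversation index (as generate_conversation does
# with random.Random(idx)) makes the sampled prompt elements reproducible.
assert sample_two_stage(random.Random(42)) == sample_two_stage(random.Random(42))
```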
NOTE: You need OPENROUTER_API_KEY set in .env or as an environment variable.
NOTE: For more details see: https://github.com/karpathy/nanochat/discussions/139
"""
import requests
import json
import os
import copy
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
from dotenv import load_dotenv
from nanochat.common import get_base_dir
load_dotenv()
api_key = os.environ["OPENROUTER_API_KEY"]
url = "https://openrouter.ai/api/v1/chat/completions"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
# Load the comprehensive knowledge base
knowledge_path = os.path.join(os.path.dirname(__file__), "..", "knowledge", "self_knowledge.md")
# Check for the file first so a missing knowledge base fails with a clear message
assert os.path.exists(knowledge_path), f"Knowledge base file not found: {knowledge_path}"
knowledge = open(knowledge_path, "r", encoding="utf-8").read().strip()
# For now, I am not committing the self_knowledge file to the repo. You can use README.md instead,
# or you can generate one by asking an LLM to write it based on the README/files.
# This whole file is just a helpful demonstration of the kind of thing you'd run.
# =============================================================================
# DIVERSITY DIMENSIONS
# =============================================================================
# Topics/questions the conversation should explore
# Group by category for balanced sampling
topics = {
"identity": [
"who/what is nanochat",
"who created nanochat and why",
"what does the name 'nanochat' mean",
"is nanochat open source, what license",
"where can I find the code",
"how can I contribute to nanochat",
],
"architecture": [
"basic architecture overview (transformer, layers, parameters)",
"what is RoPE and why use it",
"explain RMSNorm vs LayerNorm",
"what is Flash Attention and why it matters",
"sliding window attention pattern",
"value embeddings - what are they",
"per-layer residual scalars",
"ReLU squared activation",
"logit softcapping",
"QK normalization",
],
"training": [
"how much did it cost to train nanochat",
"how long does training take",
"what hardware is needed",
"what data was nanochat trained on",
"what is the Muon optimizer",
"explain the split optimizer design",
"what is the depth parameter and scaling",
"what is the CORE metric",
],
"capabilities": [
"what can nanochat do",
"can nanochat write code",
"can nanochat do math (calculator tool)",
"can nanochat help with writing",
"what languages does nanochat speak",
"how good is nanochat at reasoning",
],
"limitations": [
"what can nanochat NOT do",
"why does nanochat work best in English",
"does nanochat have internet access",
"what is nanochat's context length limit",
"can nanochat remember previous conversations",
"can nanochat make mistakes / hallucinate",
"is nanochat good for production use",
],
"comparisons": [
"how does nanochat compare to GPT-2",
"how does nanochat compare to ChatGPT/GPT-4",
"how does nanochat compare to Claude",
"why is training 600x cheaper than GPT-2",
"what's special about nanochat vs other open models",
],
"history": [
"the GPT-2 training cost in 2019",
"how AI training costs have dropped over time",
"relationship to modded-nanogpt project",
"what optimizations worked vs didn't work",
"the journey of building nanochat",
],
"technical_deep_dive": [
"explain the tokenizer (BPE, vocab size)",
"how does distributed training work (ZeRO)",
"explain the dataloader and BOS alignment",
"what is compute-optimal training",
"how does the calculator tool work",
"explain inference with KV cache",
],
"philosophical": [
"is nanochat conscious / does it have feelings",
"what happens when nanochat is wrong",
"can nanochat learn from this conversation",
"why make AI training accessible",
"the future of open source AI",
],
}
# User personas - different people ask questions differently
personas = [
"curious beginner who knows nothing about AI or machine learning",
"ML researcher or engineer who wants technical depth and specifics",
"developer considering contributing to the nanochat project",
"skeptic who doubts open source can compete with big AI labs",
"computer science student learning about transformers and LLMs",
"someone comparing nanochat to ChatGPT, Claude, or other assistants",
"journalist or writer covering AI democratization and open source",
"hobbyist who just wants to chat and learn casually",
"someone interested in the cost and economics of AI training",
"teacher or educator wanting to use nanochat for teaching",
"entrepreneur exploring if nanochat fits their use case",
"someone who just discovered the project and wants the basics",
]
# Conversation dynamics - shape and flow
dynamics = [
"short 2-turn Q&A: user asks one question, gets a complete answer",
"medium 4-turn: user asks, gets answer, asks followup for clarification",
"deep 6-turn technical discussion: progressively deeper questions",
"skeptical arc: user starts doubtful, assistant addresses concerns honestly",
"learning journey: user starts basic, assistant builds up complexity gradually",
"comparison-focused: user keeps comparing to other models, assistant explains differences",
"limitation exploration: user probes what nanochat cannot do, assistant is honest",
"casual friendly chat that naturally touches on identity and capabilities",
"troubleshooting: user has misconceptions, assistant gently corrects them",
"enthusiastic: user is excited about the project, assistant shares that energy appropriately",
]
# First messages - greetings and openers
# Categorized for balanced sampling
first_messages = {
"simple_greetings": [
"hi", "Hi!", "hello", "Hello?", "hey there", "Hey!", "yo", "Yo!",
"Good morning", "Good evening!", "Howdy", "sup", "What's up?",
"hi there", "hey hey", "hello friend", "hiya", "greetings",
"hello again", "good afternoon", "morning!", "evening!",
],
"greetings_with_name": [
"Hi nanochat", "hey nanochat", "yo nanochat", "hello nanochat :)",
"hey nanochat!", "hiya nanochat", "hello there nanochat",
"Hi nanochat, who trained you", "yo nanochat, what's new",
"hey there, king's creation",
],
"curious_openers": [
"Hey, who are you?", "Hi, what is this?", "Hey, are you a chatbot?",
"Hello! Who am I talking to?", "hi! what do you do?",
"hi! who made you", "hey! are you alive", "hiya! what are you",
"hello! tell me about yourself", "hi, what's your name",
"yo, what is this", "hi! who built you", "hello! are you open source",
"hey, what version are you", "hi! what's your story",
"hey, what's nanochat", "hello! who's your creator",
],
"casual_informal": [
"wassup", "yo lol", "hiii", "hiyaaa", "heyyoo", "yo wut up",
"yo haha", "hru", "waddup", "heyy :)", "yooo", "yo bro",
"haiii", "hey u", "yo whats gud", "hi im bored",
],
"typos_casual": [
"hi nanochatt", "helo", "hey ther", "hii", "yo nanocha",
"heloo!", "hi, whos this", "hay", "helloo??", "hi nanocat",
"helo nanochat", "hai!", "helllo nano", "yo nanochta",
],
"caps_enthusiastic": [
"HI", "HELLOOO", "YO!!!", "HEY", "SUP", "WASSUP", "HEY!!!",
"HELLO??", "HI THERE!!", "HEYOOOO", "HIII", "YOOOO", "HELLO!!!",
],
"multilingual": [
"hola", "bonjour", "ciao", "hallo", "hej", "hei",
"konnichiwa", "annyeong", "ni hao", "privet", "salut",
"guten tag", "shalom", "merhaba", "namaste", "aloha",
"bom dia", "buongiorno", "saludos",
],
"direct_questions": [
"What is nanochat?", "Who made you?", "Are you GPT?",
"How do you compare to ChatGPT?", "Can you help me code?",
"What can you do?", "Are you open source?", "How were you trained?",
"What's your context limit?", "Can you browse the internet?",
],
}
# =============================================================================
# PROMPT TEMPLATE
# =============================================================================
prompt_template = r"""
I want to generate synthetic training data for an AI assistant called "nanochat" to teach it about its own identity, capabilities, and limitations.
## KNOWLEDGE BASE
Here is comprehensive information about nanochat that you should use as the authoritative source of facts:
---
{knowledge}
---
## YOUR TASK
Generate a realistic multi-turn conversation between a User and the nanochat Assistant.
**Topic to explore:** {topic}
**User persona:** {persona}
**Conversation dynamic:** {dynamic}
## STYLE GUIDELINES
1. **Plain ASCII only** - No emojis, special characters, or unicode. Just plain text.
2. **Natural conversation** - Make it feel like a real chat, not a Q&A exam.
3. **Accurate facts** - Use ONLY information from the knowledge base above. Don't make up statistics or features.
4. **Appropriate depth** - Match the technical level to the user persona.
5. **Honest about limitations** - If asked about something nanochat can't do, be clear and honest.
6. **Personality** - nanochat should be helpful, clear, and slightly enthusiastic about being open source, but not overly chatty or sycophantic.
## FIRST MESSAGE EXAMPLES
Here are some example first messages from users (for style inspiration):
{first_message_examples}
## SPECIAL CASES
- **Non-English first message:** If the user writes in another language, nanochat should briefly acknowledge it can understand but works best in English, then continue helpfully.
- **Misconceptions:** If the user has wrong assumptions (e.g., "you're made by OpenAI"), gently correct them.
- **Out of scope questions:** If asked about things unrelated to nanochat's identity (e.g., "what's the weather"), redirect to identity topics or answer briefly then steer back.
## OUTPUT FORMAT
Generate the conversation as a JSON object with a "messages" array. Each message has "role" (user/assistant) and "content". Start with a user message.
""".strip()
# =============================================================================
# API CONFIGURATION
# =============================================================================
response_format = {
"type": "json_schema",
"json_schema": {
"name": "conversation",
"strict": True,
"schema": {
"type": "object",
"properties": {
"messages": {
"type": "array",
"description": "Conversation messages alternating user/assistant, starting with user",
"items": {
"type": "object",
"properties": {
"role": {
"type": "string",
"description": "Either 'user' or 'assistant'"
},
"content": {
"type": "string",
"description": "The message content"
}
},
"required": ["role", "content"],
"additionalProperties": False
}
}
},
"required": ["messages"],
"additionalProperties": False
}
}
}
base_payload = {
"model": "google/gemini-3-flash-preview",
"stream": False,
"response_format": response_format,
"temperature": 1.0,
}
# =============================================================================
# GENERATION LOGIC
# =============================================================================
def sample_diversity_elements(rng):
"""Sample one element from each diversity dimension."""
# Sample topic: first pick a category, then a topic within it
category = rng.choice(list(topics.keys()))
topic = rng.choice(topics[category])
# Sample persona
persona = rng.choice(personas)
# Sample dynamic
dynamic = rng.choice(dynamics)
# Sample first message examples: pick from multiple categories
first_msg_samples = []
categories = rng.sample(list(first_messages.keys()), min(3, len(first_messages)))
for cat in categories:
first_msg_samples.append(rng.choice(first_messages[cat]))
return {
"topic": topic,
"persona": persona,
"dynamic": dynamic,
"first_message_examples": "\n".join(f"- {msg}" for msg in first_msg_samples),
}
def generate_conversation(idx: int):
"""
Generate a single conversation using the OpenRouter API.
Returns a dict with "messages" (a list of message dicts with 'role' and 'content' keys)
and "metadata" (the sampled topic, persona, and dynamic, for debugging).
"""
# Seed with idx so the sampled prompt elements are reproducible (the API response is not, at temperature 1.0)
rng = random.Random(idx)
# Sample diversity elements
elements = sample_diversity_elements(rng)
# Build the prompt
prompt = prompt_template.format(
knowledge=knowledge,
topic=elements["topic"],
persona=elements["persona"],
dynamic=elements["dynamic"],
first_message_examples=elements["first_message_examples"],
)
# Make API request
payload = copy.deepcopy(base_payload)
payload['messages'] = [{"role": "user", "content": prompt}]
response = requests.post(url, headers=headers, json=payload)
result = response.json()
if 'error' in result:
raise Exception(f"API error: {result['error']}")
content = result['choices'][0]['message']['content']
conversation_data = json.loads(content)
messages = conversation_data['messages']
# Return messages along with metadata for debugging
return {
"messages": messages,
"metadata": {
"topic": elements["topic"],
"persona": elements["persona"],
"dynamic": elements["dynamic"],
}
}
def validate_conversation(messages):
"""Validate conversation structure."""
if len(messages) < 2:
raise ValueError(f"Conversation too short: {len(messages)} messages")
for i, message in enumerate(messages):
expected_role = "user" if i % 2 == 0 else "assistant"
if message['role'] != expected_role:
raise ValueError(f"Message {i} has role '{message['role']}', expected '{expected_role}'")
if not message['content'].strip():
raise ValueError(f"Message {i} has empty content")
return True
# =============================================================================
# MAIN
# =============================================================================
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Generate synthetic conversation data")
parser.add_argument("--num", type=int, default=1000, help="Number of conversations to generate")
parser.add_argument("--workers", type=int, default=4, help="Number of parallel workers")
parser.add_argument("--output", type=str, default=None, help="Output file path")
parser.add_argument("--append", action="store_true", help="Append to existing file instead of overwriting")
parser.add_argument("--save-metadata", action="store_true", help="Save metadata alongside messages")
args = parser.parse_args()
# Set output file
if args.output:
output_file = args.output
else:
output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl")
# Handle file creation/clearing
if not args.append and os.path.exists(output_file):
os.remove(output_file)
print(f"Output file: {output_file}")
print(f"Generating {args.num} conversations with {args.workers} workers...")
print(f"Topic categories: {list(topics.keys())}")
print(f"Personas: {len(personas)}")
print(f"Dynamics: {len(dynamics)}")
print()
completed_count = 0
error_count = 0
with ThreadPoolExecutor(max_workers=args.workers) as executor:
# Submit all tasks
futures = {executor.submit(generate_conversation, idx): idx
for idx in range(args.num)}
# Process results as they complete
for future in as_completed(futures):
idx = futures[future]
try:
result = future.result()
messages = result["messages"]
metadata = result["metadata"]
# Validate
validate_conversation(messages)
# Write to file
with open(output_file, 'a') as f:
if args.save_metadata:
f.write(json.dumps({"messages": messages, "metadata": metadata}) + '\n')
else:
f.write(json.dumps(messages) + '\n')
completed_count += 1
topic_short = metadata["topic"][:40] + "..." if len(metadata["topic"]) > 40 else metadata["topic"]
print(f"[{completed_count}/{args.num}] Topic: {topic_short}")
except Exception as e:
error_count += 1
print(f"[ERROR] idx={idx}: {e}")
print()
print(f"Done! Saved {completed_count} conversations to {output_file}")
if error_count > 0:
print(f"Encountered {error_count} errors during generation")
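The script appends one JSON document per line to the output file. By default that line is a bare list of `{"role", "content"}` messages. With `--save-metadata`, it is an object holding `messages` and `metadata`. The sketch below reads that file back and re-checks the user/assistant alternation. It assumes exactly those two line layouts, and `load_conversations` is a hypothetical helper, not part of nanochat:

```python
import json


def load_conversations(path):
    """Read a JSONL file in the layout written by gen_synthetic_data.py (hypothetical helper).

    Each line is either a bare list of messages (the default) or a dict with
    "messages" and "metadata" keys (when --save-metadata was used).
    Returns a list of message lists.
    """
    conversations = []
    with open(path) as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            record = json.loads(line)
            messages = record["messages"] if isinstance(record, dict) else record
            # Same rule as validate_conversation: user, assistant, user, ...
            for i, message in enumerate(messages):
                expected = "user" if i % 2 == 0 else "assistant"
                if message["role"] != expected:
                    raise ValueError(f"line {line_no}, message {i}: expected role {expected!r}")
            conversations.append(messages)
    return conversations
```

In the generator itself, results are written from the main thread inside the `as_completed` loop, so the worker threads never touch the output file and no file locking is needed.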
================================================
FILE: dev/generate_logo.html
================================================
<!DOCTYPE html>
<html>
<body style="margin:0; display:flex; justify-content:center; align-items:center; height:100vh; background:#fff">
<svg width="400" height="400" xmlns="http://www.w3.org/2000/svg">
<defs>
<radialGradient id="g" cx="50%" cy="50%">
<stop offset="0%" style="stop-color:#667eea;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#764ba2;stop-opacity:0.3"/>
</radialGradient>
</defs>
</svg>
<script>
const svg = document.querySelector('svg');
const r = 120;
let path = '';
for(let i = 0; i < 24; i += 2) {
let a2 = (i + 1) * Math.PI / 12;
let x2 = 200 + Math.cos(a2) * r;
let y2 = 200 + Math.sin(a2) * r;
let x3 = 200 + Math.cos(a2) * (r - 90);
let y3 = 200 + Math.sin(a2) * (r - 90);
path += `M${x2},${y2} L${x3},${y3} `;
}
svg.innerHTML += `<path d="${path}" stroke="url(#g)" stroke-width="6" stroke-linecap="round" fill="none"/>`;
svg.innerHTML += `<path d="M200,-12 L212,0 L200,12 L188,0 Z" transform="translate(0,200)" fill="#000"/>`;
</script>
</body>
</html>
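The logo script draws 12 spokes around the center (200, 200). For each even `i` from 0 to 22, it takes the angle `(i + 1) * π / 12`, which gives odd multiples of 15°. It then connects a point at radius 120 to a point at radius 30 (`r - 90`). A second path adds a small diamond centered on (200, 200).

The sketch below re-derives the spoke path string in Python so the geometry can be checked outside a browser. `spoke_path` is a hypothetical name. The numbers are rounded to three decimals, unlike the full-precision output of the JavaScript template string:

```python
import math


def spoke_path(cx=200.0, cy=200.0, r=120.0, inset=90.0, n_slots=24):
    """Rebuild the spoke path from generate_logo.html (hypothetical helper).

    Only every other slot of a 24-slot circle gets a spoke (12 spokes),
    at angles (i + 1) * pi / 12 for even i.
    """
    parts = []
    for i in range(0, n_slots, 2):
        angle = (i + 1) * math.pi / (n_slots // 2)
        x_outer = cx + math.cos(angle) * r
        y_outer = cy + math.sin(angle) * r
        x_inner = cx + math.cos(angle) * (r - inset)
        y_inner = cy + math.sin(angle) * (r - inset)
        parts.append(f"M{x_outer:.3f},{y_outer:.3f} L{x_inner:.3f},{y_inner:.3f}")
    return " ".join(parts)
```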
================================================
FILE: dev/repackage_data_reference.py
================================================
"""
Repackage a given dataset into simple parquet shards:
- each shard is ~100MB in size (after zstd compression)
- parquets are written with a row group size of 1024
- shuffle the dataset
This will be uploaded to HuggingFace for hosting.
The main benefit is that our DataLoader can stream the data
and cache it on disk as it goes, which reduces training latency.
Historical context:
Originally, nanochat used the FinewebEdu-100B dataset.
Then we switched to the ClimbMix-400B dataset due to superior performance.
This script documents how both were prepared.
The outputs are here:
https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle
https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle
NOTE: This file serves only as reference documentation for the
dataset preparation; it is not used at project runtime.
"""
import os
import time
from datasets import load_dataset
import pyarrow.parquet as pq
import pyarrow as pa
# You can change these:
dataset_tag = "climbmix"
upload_to_hf = True
# Dataset configurations:
if dataset_tag == "fineweb_edu":
dataset_kwargs = {
"path": "HuggingFaceFW/fineweb-edu",
"split": "train",
"name": "sample-100BT", # ~100B GPT-2 tokens at ~3 chars/token => ~300B chars total
}
output_dirname = "fineweb_edu"
data_column_name = "text"
tokenizer = None
upload_tag = "fineweb-edu-100b-shuffle"
elif dataset_tag == "climbmix":
import tiktoken # the ClimbMix data is stored tokenized with GPT-2 tokenizer
dataset_kwargs = {
"path": "nvidia/Nemotron-ClimbMix",
"split": "train",
}
output_dirname = "climbmix"
data_column_name = "tokens"
tokenizer = tiktoken.encoding_for_model("gpt-2")
upload_tag = "climbmix-400b-shuffle"
else:
raise ValueError(f"Unknown dataset tag: {dataset_tag}")
# Source dataset
ds = load_dataset(**dataset_kwargs)
# Shuffle to scramble the order
ds = ds.shuffle(seed=42)
ndocs = len(ds) # total number of documents to process
print(f"Total number of documents: {ndocs}")
# Repackage into parquet files
output_dir = f"/home/ubuntu/.cache/nanochat/base_data_{output_dirname}"
os.makedirs(output_dir, exist_ok=True)
# Write to parquet files
chars_per_shard = 250_000_000
row_group_size = 1024 # HF uses 1000, but we use a power of 2, which is nicer for the distributed data loader later
shard_docs = []
shard_index = 0
shard_characters = 0
total_docs_processed = 0
total_time_spent = 0
t0 = time.time()
for doc in ds:
data = doc[data_column_name]
text = tokenizer.decode(data) if tokenizer is not None else data
shard_docs.append(text)
shard_characters += len(text)
collected_enough_chars = shard_characters >= chars_per_shard
docs_multiple_of_row_group_size = len(shard_docs) % row_group_size == 0
if collected_enough_chars and docs_multiple_of_row_group_size: # leads to ~100MB of text (compressed)
shard_path = os.path.join(output_dir, f"shard_{shard_index:05d}.parquet")
shard_table = pa.Table.from_pydict({"text": shard_docs})
pq.write_table(
shard_table,
shard_path,
row_group_size=row_group_size,
use_dictionary=False, # this is usually used for categorical data
compression="zstd", # Valid values: {‘NONE’, ‘SNAPPY’, ‘GZIP’, ‘BROTLI’, ‘LZ4’, ‘ZSTD’}
compression_level=3,
write_statistics=False, # not needed for text
)
t1 = time.time()
dt = t1 - t0 # for this shard alone
t0 = t1
total_docs_processed += len(shard_docs)
total_time_spent += dt
remaining_docs = ndocs - total_docs_processed
avg_time_per_doc = total_time_spent / total_docs_processed
remaining_time = remaining_docs * avg_time_per_doc
remaining_time_hours = remaining_time / 3600
print(f"Wrote {shard_path}. #documents: {len(shard_docs)} | #characters: {shard_characters} | time: {dt:.2f}s | remaining time: {remaining_time_hours:.2f}h")
shard_docs = []
shard_characters = 0
shard_index += 1
# Demonstration of how the data was later uploaded to HuggingFace
if upload_to_hf:
from huggingface_hub import HfApi
token = os.getenv("HF_TOKEN")
api = HfApi(token=token)
api.upload_large_folder(
folder_path=output_dir,
repo_id=f"karpathy/{upload_tag}",
repo_type="dataset",
)
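The script flushes a shard only when two conditions hold. The shard must have reached `chars_per_shard` characters, and its document count must be an exact multiple of `row_group_size`. As a result, every written shard consists of whole row groups.

The sketch below replays that rule on synthetic document lengths, using small hypothetical thresholds rather than the script's real 250M characters and 1024 rows. It also shows one way a distributed reader could split the resulting row groups, assuming each rank takes every `world_size`-th group. This striding scheme is an illustration (similar in spirit to the `start`/`step` arguments of `parquets_iter_batched`), not a description of nanochat's actual loader:

```python
def plan_shards(doc_lengths, chars_per_shard, row_group_size):
    """Replay the flush rule from repackage_data_reference.py.

    Returns (shards, leftover), where shards is a list of (num_docs, num_chars)
    per written shard and leftover is the (num_docs, num_chars) still buffered
    when the input runs out (the reference loop ends without writing these).
    """
    shards = []
    docs, chars = 0, 0
    for n in doc_lengths:
        docs += 1
        chars += n
        if chars >= chars_per_shard and docs % row_group_size == 0:
            shards.append((docs, chars))
            docs, chars = 0, 0
    return shards, (docs, chars)


def row_groups_for_rank(num_row_groups, rank, world_size):
    """Hypothetical striding: rank r reads row groups r, r + W, r + 2W, ..."""
    return list(range(rank, num_row_groups, world_size))


# 100 documents of 10 characters each, with toy thresholds:
# the character threshold is met at 10 docs, but the flush waits for 16 (2 row groups of 8).
shards, leftover = plan_shards([10] * 100, chars_per_shard=95, row_group_size=8)
```

Because the row-group size is a power of 2, shard sizes divide evenly into groups. When the number of row groups is a multiple of the world size, every rank also receives the same number of groups.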
================================================
FILE: dev/scaling_analysis.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Scaling Laws Analysis\n",
"\n",
"Analyze results from `scaling_laws.sh` to find the optimal param:data ratio for nanochat."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Load results\n",
"tag = \"jan26\"\n",
"base_dir = os.environ.get('NANOCHAT_BASE_DIR', os.path.expanduser('~/.cache/nanochat'))\n",
"results_path = os.path.join(base_dir, f'scaling_laws_results_{tag}', 'results.csv')\n",
"\n",
"df = pd.read_csv(results_path)\n",
"flops_budgets = sorted(df['flops_budget'].unique())\n",
"print(f\"Loaded {len(df)} runs across {len(flops_budgets)} FLOPs budgets\")\n",
"print(f\"Columns: {list(df.columns)}\")\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# FILTERING: Remove incomplete or problematic runs\n",
"# =============================================================================\n",
"\n",
"print(f\"Before filtering: {len(df)} runs\")\n",
"\n",
"# Filter out runs with missing/invalid val_bpb (incomplete runs)\n",
"df = df[df['val_bpb'].notna() & (df['val_bpb'] > 0)]\n",
"\n",
"# Optional: exclude specific flops budgets that aren't done yet\n",
"# exclude_flops = [1e19] # <-- adjust as runs complete\n",
"# df = df[~df['flops_budget'].isin(exclude_flops)]\n",
"\n",
"# Optional: exclude specific depths\n",
"# exclude_depths = [18, 20]\n",
"# df = df[~df['depth'].isin(exclude_depths)]\n",
"\n",
"print(f\"After filtering: {len(df)} runs\")\n",
"print(f\"FLOPs budgets: {sorted(df['flops_budget'].unique())}\")\n",
"print(f\"Depths: {sorted(df['depth'].unique())}\")\n",
"\n",
"# Update flops_budgets list after filtering\n",
"flops_budgets = sorted(df['flops_budget'].unique())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Effective Parameter Count\n",
"\n",
"Different scaling law papers use different conventions for counting parameters:\n",
"- **Kaplan et al.** excluded embedding parameters (claimed cleaner laws)\n",
"- **Chinchilla** included all parameters (and noted Kaplan had a bug)\n",
"\n",
"Our CSV now has granular counts:\n",
"- `params_wte` - token embedding (lookup table)\n",
"- `params_value_embeds` - value embeddings (lookup table)\n",
"- `params_lm_head` - unembedding projection (matmul)\n",
"- `params_transformer` - attention + MLP matrices (matmuls)\n",
"- `params_scalars` - resid/x0/bigram lambdas (tiny)\n",
"\n",
"**Experiment below** with different combinations to see which gives the cleanest scaling laws."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# EXPERIMENT HERE: Define which parameters to count for scaling laws\n",
"# =============================================================================\n",
"\n",
"def compute_effective_params(row):\n",
" \"\"\"\n",
" Compute the 'effective' parameter count for scaling law analysis.\n",
"\n",
" Modify this function to experiment with different conventions:\n",
" - Chinchilla-style: include everything\n",
" - Kaplan-style: exclude embeddings\n",
" - Matmul-only: just transformer + lm_head (the actual compute)\n",
" - etc.\n",
" \"\"\"\n",
" # Option 1: Chinchilla-style (all params)\n",
" # return row['params_total']\n",
"\n",
" # Option 2: Kaplan-style (exclude embeddings)\n",
" return row['params_transformer'] + row['params_lm_head']\n",
"\n",
" # Option 3: Transformer-only (exclude all embeddings AND lm_head)\n",
" # return row['params_transformer']\n",
"\n",
"\n",
"# Compute derived columns\n",
"df = df.copy() # avoid SettingWithCopyWarning from earlier filter\n",
"df['effective_params'] = df.apply(compute_effective_params, axis=1)\n",
"df['param_data_ratio'] = df['tokens_trained'] / df['effective_params']\n",
"\n",
"# Show parameter breakdown for first few rows\n",
"print(\"Parameter breakdown (first row per flops budget):\")\n",
"param_cols = ['depth', 'params_wte', 'params_value_embeds',\n",
" 'params_lm_head', 'params_transformer', 'params_scalars', 'params_total', 'effective_params']\n",
"df.groupby('flops_budget').first()[param_cols]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## IsoFLOP Curves (à la Chinchilla)\n",
"\n",
"For each compute budget, plot loss vs. model size. We look for the U-shaped valley that reveals the optimal model size at each FLOPs budget."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
"\n",
"# Plot 1: IsoFLOP curves - Val BPB vs Parameters (the Chinchilla plot!)\n",
"ax = axes[0]\n",
"colors = plt.cm.viridis(np.linspace(0, 0.9, len(flops_budgets)))\n",
"optimal_by_bpb = []\n",
"\n",
"for flops, color in zip(flops_budgets, colors):\n",
" subset = df[df['flops_budget'] == flops].sort_values('effective_params')\n",
" ax.plot(subset['effective_params'], subset['val_bpb'], 'o', color=color, label=f'{flops:.0e}', markersize=8)\n",
"\n",
" # Fit quadratic in log-space: val_bpb = a*(log N)^2 + b*(log N) + c\n",
" log_params = np.log10(subset['effective_params'])\n",
" coeffs = np.polyfit(log_params, subset['val_bpb'], 2)\n",
" a, b, c = coeffs\n",
"\n",
" # Plot fitted curve (dashed)\n",
" log_fit_x = np.linspace(log_params.min() - 0.1, log_params.max() + 0.1, 100)\n",
" fit_y = a * log_fit_x**2 + b * log_fit_x + c\n",
" ax.plot(10**log_fit_x, fit_y, '--', color=color, linewidth=2)\n",
"\n",
" # Find minimum of quadratic: d/dx(ax^2 + bx + c) = 0 => x = -b/(2a)\n",
" if a > 0: # parabola opens upward (has a minimum)\n",
" log_opt = -b / (2 * a)\n",
" opt_params = 10**log_opt\n",
" opt_bpb = a * log_opt**2 + b * log_opt + c\n",
" # Mark the fitted optimal\n",
" ax.scatter([opt_params], [opt_bpb], s=150, color=color,\n",
" zorder=5, edgecolors='black', linewidths=2, marker='*')\n",
" # Interpolate tokens and ratio from actual data (don't use C≈6ND approximation)\n",
" opt_tokens = np.interp(np.log10(opt_params), log_params, subset['tokens_trained'])\n",
" opt_ratio = np.interp(np.log10(opt_params), log_params, subset['param_data_ratio'])\n",
" optimal_by_bpb.append({'flops': flops, 'params': opt_params, 'tokens': opt_tokens, 'ratio': opt_ratio, 'bpb': opt_bpb})\n",
" else:\n",
" # Fallback to raw minimum if quadratic doesn't have minimum\n",
" best_idx = subset['val_bpb'].idxmin()\n",
" best = subset.loc[best_idx]\n",
" ax.scatter([best['effective_params']], [best['val_bpb']], s=150, color=color,\n",
" zorder=5, edgecolors='black', linewidths=2)\n",
" optimal_by_bpb.append({'flops': flops, 'params': best['effective_params'],\n",
" 'tokens': best['tokens_trained'], 'ratio': best['param_data_ratio'], 'bpb': best['val_bpb']})\n",
"\n",
"ax.set_xscale('log')\n",
"ax.set_xlabel('Effective Parameters')\n",
"ax.set_ylabel('Validation Loss (bpb)')\n",
"ax.set_title('IsoFLOP Curves')\n",
"ax.legend(title='FLOPs', loc='upper right')\n",
"ax.grid(True, alpha=0.3)\n",
"\n",
"opt_df = pd.DataFrame(optimal_by_bpb)\n",
"\n",
"# Plot 2: Optimal model size vs compute (power law)\n",
"ax = axes[1]\n",
"ax.loglog(opt_df['flops'], opt_df['params'], 'o', markersize=10, color='#2ecc71')\n",
"ax.set_xlabel('FLOPs')\n",
"ax.set_ylabel('Optimal Parameters')\n",
"ax.set_title('Optimal Model Size')\n",
"ax.grid(True, alpha=0.3)\n",
"\n",
"# Fit and show power law\n",
"if len(opt_df) >= 2:\n",
" log_f = np.log10(opt_df['flops'])\n",
" log_p = np.log10(opt_df['params'])\n",
" slope, intercept = np.polyfit(log_f, log_p, 1)\n",
" fit_f = np.logspace(log_f.min() - 0.5, log_f.max() + 0.5, 100)\n",
" fit_p = 10**(intercept + slope * np.log10(fit_f))\n",
" ax.plot(fit_f, fit_p, 'r--', alpha=0.7, label=f'N ∝ C^{slope:.2f}')\n",
" ax.legend()\n",
"\n",
"# Plot 3: Optimal tokens vs compute (power law)\n",
"ax = axes[2]\n",
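The notebook fits `val_bpb` as a quadratic in x = log10(effective params) for each FLOPs budget. It takes the vertex, x* = -b / (2a), as the compute-optimal model size. It then fits a straight line in log-log space to estimate the exponent in N ∝ C^slope.

The sketch below shows those two steps in NumPy. It runs on synthetic, noise-free data made up for illustration, not nanochat results, and the helper names are hypothetical:

```python
import numpy as np


def isoflop_optimum(params, val_bpb):
    """Fit val_bpb = a*x^2 + b*x + c with x = log10(params) and return the vertex.

    Returns (optimal_params, fitted_bpb), or None when a <= 0 (the parabola has
    no minimum); in that case the notebook falls back to the raw argmin.
    """
    x = np.log10(np.asarray(params, dtype=float))
    a, b, c = np.polyfit(x, np.asarray(val_bpb, dtype=float), 2)
    if a <= 0:
        return None
    x_opt = -b / (2 * a)
    return 10 ** x_opt, a * x_opt**2 + b * x_opt + c


def power_law_exponent(flops, optimal_params):
    """Least-squares line through (log10 C, log10 N*); the slope is the exponent."""
    slope, intercept = np.polyfit(np.log10(flops), np.log10(optimal_params), 1)
    return slope, intercept


# Synthetic example: the true optimum grows as N* = 1e8 * (C / 1e18) ** 0.5.
budgets = [1e18, 1e19, 1e20]
optima = []
for C in budgets:
    n_star = 1e8 * (C / 1e18) ** 0.5
    params = n_star * 10 ** np.array([-0.5, -0.25, 0.0, 0.25, 0.5])
    bpb = 0.8 + 0.1 * (np.log10(params) - np.log10(n_star)) ** 2
    optima.append(isoflop_optimum(params, bpb)[0])
slope, _ = power_law_exponent(budgets, optima)
```

In the notebook, the fitted vertex is also used to interpolate the token count and the tokens-per-parameter ratio from the measured runs, rather than deriving them from the C ≈ 6ND approximation.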
SYMBOL INDEX (343 symbols across 33 files)
FILE: dev/gen_synthetic_data.py
function sample_diversity_elements (line 312) | def sample_diversity_elements(rng):
function generate_conversation (line 338) | def generate_conversation(idx: int):
function validate_conversation (line 383) | def validate_conversation(messages):
FILE: nanochat/checkpoint_manager.py
function log0 (line 19) | def log0(message):
function _patch_missing_config_keys (line 23) | def _patch_missing_config_keys(model_config_kwargs):
function _patch_missing_keys (line 30) | def _patch_missing_keys(model_data, model_config):
function save_checkpoint (line 42) | def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, me...
function load_checkpoint (line 61) | def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, ...
function build_model (line 77) | def build_model(checkpoint_dir, step, device, phase):
function find_largest_model (line 118) | def find_largest_model(checkpoints_dir):
function find_last_step (line 138) | def find_last_step(checkpoint_dir):
function load_model_from_dir (line 149) | def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, ...
function load_model (line 164) | def load_model(source, *args, **kwargs):
function load_optimizer_state (line 174) | def load_optimizer_state(source, device, rank, model_tag=None, step=None):
FILE: nanochat/common.py
function _detect_compute_dtype (line 17) | def _detect_compute_dtype():
class ColoredFormatter (line 33) | class ColoredFormatter(logging.Formatter):
method format (line 45) | def format(self, record):
function setup_default_logging (line 59) | def setup_default_logging():
function get_base_dir (line 70) | def get_base_dir():
function download_file_with_lock (line 81) | def download_file_with_lock(url, filename, postprocess_fn=None):
function print0 (line 117) | def print0(s="",**kwargs):
function print_banner (line 122) | def print_banner():
function is_ddp_requested (line 136) | def is_ddp_requested() -> bool:
function is_ddp_initialized (line 143) | def is_ddp_initialized() -> bool:
function get_dist_info (line 150) | def get_dist_info():
function autodetect_device_type (line 162) | def autodetect_device_type():
function compute_init (line 173) | def compute_init(device_type="cuda"): # cuda|cpu|mps
function compute_cleanup (line 210) | def compute_cleanup():
class DummyWandb (line 215) | class DummyWandb:
method __init__ (line 217) | def __init__(self):
method log (line 219) | def log(self, *args, **kwargs):
method finish (line 221) | def finish(self):
function get_peak_flops (line 227) | def get_peak_flops(device_name: str) -> float:
FILE: nanochat/core_eval.py
function render_prompts_mc (line 17) | def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
function render_prompts_schema (line 36) | def render_prompts_schema(item, continuation_delimiter, fewshot_examples...
function render_prompts_lm (line 56) | def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
function find_common_length (line 86) | def find_common_length(token_sequences, direction='left'):
function stack_sequences (line 104) | def stack_sequences(tokens, pad_token_id):
function batch_sequences_mc (line 113) | def batch_sequences_mc(tokenizer, prompts):
function batch_sequences_schema (line 123) | def batch_sequences_schema(tokenizer, prompts):
function batch_sequences_lm (line 133) | def batch_sequences_lm(tokenizer, prompts):
function forward_model (line 145) | def forward_model(model, input_ids):
function evaluate_example (line 168) | def evaluate_example(idx, model, tokenizer, data, device, task_meta):
function evaluate_task (line 244) | def evaluate_task(model, tokenizer, data, device, task_meta):
FILE: nanochat/dataloader.py
function _document_batches (line 25) | def _document_batches(split, resume_state_dict, tokenizer_batch_size):
function tokenizing_distributed_data_loader_with_state_bos_bestfit (line 74) | def tokenizing_distributed_data_loader_with_state_bos_bestfit(
function tokenizing_distributed_data_loader_bos_bestfit (line 163) | def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
FILE: nanochat/dataset.py
function list_parquet_files (line 32) | def list_parquet_files(data_dir=None, warn_on_legacy=False):
function parquets_iter_batched (line 67) | def parquets_iter_batched(split, start=0, step=1):
function download_single_file (line 84) | def download_single_file(index):
FILE: nanochat/engine.py
function timeout (line 26) | def timeout(duration, formula):
function eval_with_timeout (line 35) | def eval_with_timeout(formula, max_time=3):
function use_calculator (line 46) | def use_calculator(expr):
class KVCache (line 82) | class KVCache:
method __init__ (line 92) | def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layer...
method reset (line 106) | def reset(self):
method get_pos (line 111) | def get_pos(self):
method get_layer_cache (line 115) | def get_layer_cache(self, layer_idx):
method advance (line 119) | def advance(self, num_tokens):
method prefill (line 123) | def prefill(self, other):
function sample_next_token (line 141) | def sample_next_token(logits, rng, temperature=1.0, top_k=None):
class RowState (line 160) | class RowState:
method __init__ (line 162) | def __init__(self, current_tokens=None):
class Engine (line 169) | class Engine:
method __init__ (line 171) | def __init__(self, model, tokenizer):
method generate (line 176) | def generate(self, tokens, num_samples=1, max_tokens=None, temperature...
method generate_batch (line 282) | def generate_batch(self, tokens, num_samples=1, **kwargs):
FILE: nanochat/execution.py
class ExecutionResult (line 38) | class ExecutionResult:
method __repr__ (line 47) | def __repr__(self):
function time_limit (line 65) | def time_limit(seconds: float):
function capture_io (line 78) | def capture_io():
function create_tempdir (line 90) | def create_tempdir():
class TimeoutException (line 96) | class TimeoutException(Exception):
class WriteOnlyStringIO (line 100) | class WriteOnlyStringIO(io.StringIO):
method read (line 103) | def read(self, *args, **kwargs):
method readline (line 106) | def readline(self, *args, **kwargs):
method readlines (line 109) | def readlines(self, *args, **kwargs):
method readable (line 112) | def readable(self, *args, **kwargs):
class redirect_stdin (line 117) | class redirect_stdin(contextlib._RedirectStream): # type: ignore
function chdir (line 122) | def chdir(root):
function reliability_guard (line 134) | def reliability_guard(maximum_memory_bytes: Optional[int] = None):
function _unsafe_execute (line 214) | def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Opt...
function execute_code (line 286) | def execute_code(
FILE: nanochat/flash_attention.py
function _load_flash_attention_3 (line 23) | def _load_flash_attention_3():
function _resolve_use_fa3 (line 48) | def _resolve_use_fa3():
function _sdpa_attention (line 69) | def _sdpa_attention(q, k, v, window_size, enable_gqa):
function flash_attn_func (line 107) | def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
function flash_attn_with_kvcache (line 131) | def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_s...
FILE: nanochat/fp8.py
function _to_fp8 (line 82) | def _to_fp8(x, fp8_dtype):
function _to_col_major (line 110) | def _to_col_major(x):
class _Float8Matmul (line 125) | class _Float8Matmul(torch.autograd.Function):
method forward (line 133) | def forward(ctx, input_2d, weight):
method backward (line 157) | def backward(ctx, grad_output):
class Float8Linear (line 195) | class Float8Linear(nn.Linear):
method forward (line 202) | def forward(self, input):
method from_float (line 216) | def from_float(cls, mod):
class Float8LinearConfig (line 230) | class Float8LinearConfig:
method from_recipe_name (line 234) | def from_recipe_name(recipe_name):
function convert_to_float8_training (line 243) | def convert_to_float8_training(module, *, config=None, module_filter_fn=...
FILE: nanochat/gpt.py
class GPTConfig (line 29) | class GPTConfig:
function norm (line 42) | def norm(x):
class Linear (line 45) | class Linear(nn.Linear):
method forward (line 49) | def forward(self, x):
function has_ve (line 53) | def has_ve(layer_idx, n_layer):
function apply_rotary_emb (line 57) | def apply_rotary_emb(x, cos, sin):
class CausalSelfAttention (line 65) | class CausalSelfAttention(nn.Module):
method __init__ (line 66) | def __init__(self, config, layer_idx):
method forward (line 82) | def forward(self, x, ve, cos_sin, window_size, kv_cache):
class MLP (line 129) | class MLP(nn.Module):
method __init__ (line 130) | def __init__(self, config):
method forward (line 135) | def forward(self, x):
class Block (line 142) | class Block(nn.Module):
method __init__ (line 143) | def __init__(self, config, layer_idx):
method forward (line 148) | def forward(self, x, ve, cos_sin, window_size, kv_cache):
class GPT (line 154) | class GPT(nn.Module):
method __init__ (line 155) | def __init__(self, config, pad_vocab_size_to=64):
method init_weights (line 202) | def init_weights(self):
method _precompute_rotary_embeddings (line 263) | def _precompute_rotary_embeddings(self, seq_len, head_dim, base=100000...
method _compute_window_sizes (line 280) | def _compute_window_sizes(self, config):
method get_device (line 309) | def get_device(self):
method estimate_flops (line 312) | def estimate_flops(self):
method num_scaling_params (line 340) | def num_scaling_params(self):
method setup_optimizer (line 369) | def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matr...
method forward (line 411) | def forward(self, idx, targets=None, kv_cache=None, loss_reduction='me...
method generate (line 479) | def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, se...
FILE: nanochat/loss_eval.py
function evaluate_bpb (line 9) | def evaluate_bpb(model, batches, steps, token_bytes):
FILE: nanochat/optim.py
function adamw_step_fused (line 21) | def adamw_step_fused(
function muon_step_fused (line 91) | def muon_step_fused(
class MuonAdamW (line 152) | class MuonAdamW(torch.optim.Optimizer):
method __init__ (line 178) | def __init__(self, param_groups: list[dict]):
method _step_adamw (line 194) | def _step_adamw(self, group: dict) -> None:
method _step_muon (line 229) | def _step_muon(self, group: dict) -> None:
method step (line 284) | def step(self):
class DistMuonAdamW (line 297) | class DistMuonAdamW(torch.optim.Optimizer):
method __init__ (line 355) | def __init__(self, param_groups: list[dict]):
method _reduce_adamw (line 369) | def _reduce_adamw(self, group: dict, world_size: int) -> dict:
method _reduce_muon (line 387) | def _reduce_muon(self, group: dict, world_size: int) -> dict:
method _compute_adamw (line 408) | def _compute_adamw(self, group: dict, info: dict, gather_list: list, r...
method _compute_muon (line 449) | def _compute_muon(self, group: dict, info: dict, gather_list: list, ra...
method _finish_gathers (line 499) | def _finish_gathers(self, gather_list: list) -> None:
method step (line 508) | def step(self):
FILE: nanochat/report.py
function run_command (line 15) | def run_command(cmd):
function get_git_info (line 28) | def get_git_info():
function get_gpu_info (line 44) | def get_gpu_info():
function get_system_info (line 67) | def get_system_info():
function estimate_cost (line 89) | def estimate_cost(gpu_info, runtime_hours=None):
function generate_header (line 120) | def generate_header():
function slugify (line 203) | def slugify(text):
function extract (line 222) | def extract(section, keys):
function extract_timestamp (line 233) | def extract_timestamp(content, prefix):
class Report (line 244) | class Report:
method __init__ (line 247) | def __init__(self, report_dir):
method log (line 251) | def log(self, section, data):
method generate (line 279) | def generate(self):
method reset (line 371) | def reset(self):
class DummyReport (line 394) | class DummyReport:
method log (line 395) | def log(self, *args, **kwargs):
method reset (line 397) | def reset(self, *args, **kwargs):
function get_report (line 400) | def get_report():
FILE: nanochat/tokenizer.py
class HuggingFaceTokenizer (line 39) | class HuggingFaceTokenizer:
method __init__ (line 42) | def __init__(self, tokenizer):
method from_pretrained (line 46) | def from_pretrained(cls, hf_path):
method from_directory (line 52) | def from_directory(cls, tokenizer_dir):
method train_from_iterator (line 59) | def train_from_iterator(cls, text_iterator, vocab_size):
method get_vocab_size (line 95) | def get_vocab_size(self):
method get_special_tokens (line 98) | def get_special_tokens(self):
method id_to_token (line 103) | def id_to_token(self, id):
method _encode_one (line 106) | def _encode_one(self, text, prepend=None, append=None, num_threads=None):
method encode_special (line 121) | def encode_special(self, text):
method get_bos_token_id (line 125) | def get_bos_token_id(self):
method encode (line 136) | def encode(self, text, *args, **kwargs):
method __call__ (line 144) | def __call__(self, *args, **kwargs):
method decode (line 147) | def decode(self, ids):
method save (line 150) | def save(self, tokenizer_dir):
class RustBPETokenizer (line 163) | class RustBPETokenizer:
method __init__ (line 166) | def __init__(self, enc, bos_token):
method train_from_iterator (line 171) | def train_from_iterator(cls, text_iterator, vocab_size):
method from_directory (line 193) | def from_directory(cls, tokenizer_dir):
method from_pretrained (line 200) | def from_pretrained(cls, tiktoken_name):
method get_vocab_size (line 209) | def get_vocab_size(self):
method get_special_tokens (line 212) | def get_special_tokens(self):
method id_to_token (line 215) | def id_to_token(self, id):
method encode_special (line 219) | def encode_special(self, text):
method get_bos_token_id (line 222) | def get_bos_token_id(self):
method encode (line 225) | def encode(self, text, prepend=None, append=None, num_threads=8):
method __call__ (line 252) | def __call__(self, *args, **kwargs):
method decode (line 255) | def decode(self, ids):
method save (line 258) | def save(self, tokenizer_dir):
method render_conversation (line 266) | def render_conversation(self, conversation, max_tokens=2048):
method visualize_tokenization (line 352) | def visualize_tokenization(self, ids, mask, with_token_id=False):
method render_for_completion (line 367) | def render_for_completion(self, conversation):
function get_tokenizer (line 390) | def get_tokenizer():
function get_token_bytes (line 397) | def get_token_bytes(device="cpu"):
FILE: scripts/base_eval.py
class ModelWrapper (line 45) | class ModelWrapper:
method __init__ (line 47) | def __init__(self, model, max_seq_len=None):
method __call__ (line 51) | def __call__(self, input_ids, targets=None, loss_reduction='mean'):
method get_device (line 63) | def get_device(self):
function load_hf_model (line 67) | def load_hf_model(hf_path: str, device):
function get_hf_token_bytes (line 80) | def get_hf_token_bytes(tokenizer, device="cpu"):
function place_eval_bundle (line 95) | def place_eval_bundle(file_path):
function evaluate_core (line 107) | def evaluate_core(model, tokenizer, device, max_per_task=-1):
function main (line 178) | def main():
FILE: scripts/base_train.py
function build_model_meta (line 129) | def build_model_meta(depth):
function fp8_module_filter (line 178) | def fp8_module_filter(mod: nn.Module, fqn: str) -> bool:
function disable_fp8 (line 196) | def disable_fp8(model):
function get_scaling_params (line 262) | def get_scaling_params(m):
function get_lr_multiplier (line 359) | def get_lr_multiplier(it):
function get_muon_momentum (line 371) | def get_muon_momentum(it):
function get_weight_decay (line 384) | def get_weight_decay(it):
FILE: scripts/chat_eval.py
function run_generative_eval (line 29) | def run_generative_eval(task_object, tokenizer, model, engine, num_sampl...
function run_categorical_eval (line 88) | def run_categorical_eval(task_object, tokenizer, model, batch_size, max_...
function run_chat_eval (line 157) | def run_chat_eval(task_name, model, tokenizer, engine,
FILE: scripts/chat_rl.py
function get_batch (line 86) | def get_batch():
function run_gsm8k_eval (line 150) | def run_gsm8k_eval(task, tokenizer, engine,
function get_lr_multiplier (line 210) | def get_lr_multiplier(it):
FILE: scripts/chat_sft.py
function sft_data_generator_bos_bestfit (line 187) | def sft_data_generator_bos_bestfit(split, buffer_size=100):
function get_lr_multiplier (line 314) | def get_lr_multiplier(progress):
function get_muon_momentum (line 324) | def get_muon_momentum(it):
function centered_mean (line 384) | def centered_mean(tasks):
FILE: scripts/chat_web.py
class Worker (line 87) | class Worker:
class WorkerPool (line 94) | class WorkerPool:
method __init__ (line 97) | def __init__(self, num_gpus: Optional[int] = None):
method initialize (line 107) | async def initialize(self, source: str, model_tag: Optional[str] = Non...
method acquire_worker (line 135) | async def acquire_worker(self) -> Worker:
method release_worker (line 139) | async def release_worker(self, worker: Worker):
class ChatMessage (line 143) | class ChatMessage(BaseModel):
class ChatRequest (line 147) | class ChatRequest(BaseModel):
function validate_chat_request (line 153) | def validate_chat_request(request: ChatRequest):
function lifespan (line 217) | async def lifespan(app: FastAPI):
function root (line 236) | async def root():
function logo (line 250) | async def logo():
function generate_stream (line 255) | async def generate_stream(
function chat_completions (line 306) | async def chat_completions(request: ChatRequest):
function health (line 377) | async def health():
function stats (line 388) | async def stats():
FILE: scripts/tok_eval.py
function print_comparison (line 203) | def print_comparison(baseline_name, baseline_results, ours_results, all_...
FILE: scripts/tok_train.py
function text_iterator (line 28) | def text_iterator():
FILE: tasks/arc.py
class ARC (line 9) | class ARC(Task):
method __init__ (line 11) | def __init__(self, subset, split, **kwargs):
method eval_type (line 18) | def eval_type(self):
method num_examples (line 21) | def num_examples(self):
method get_example (line 24) | def get_example(self, index):
method evaluate (line 43) | def evaluate(self, conversation, assistant_response):
FILE: tasks/common.py
class Task (line 10) | class Task:
method __init__ (line 15) | def __init__(self, start=0, stop=None, step=1):
method eval_type (line 25) | def eval_type(self):
method num_examples (line 29) | def num_examples(self):
method get_example (line 32) | def get_example(self, index):
method __len__ (line 35) | def __len__(self):
method __getitem__ (line 44) | def __getitem__(self, index: int):
method evaluate (line 50) | def evaluate(self, problem, completion):
class TaskMixture (line 54) | class TaskMixture(Task):
method __init__ (line 60) | def __init__(self, tasks, **kwargs):
method num_examples (line 76) | def num_examples(self):
method get_example (line 79) | def get_example(self, index):
class TaskSequence (line 89) | class TaskSequence(Task):
method __init__ (line 95) | def __init__(self, tasks, **kwargs):
method num_examples (line 101) | def num_examples(self):
method get_example (line 104) | def get_example(self, index):
function render_mc (line 112) | def render_mc(question, letters, choices):
FILE: tasks/customjson.py
class CustomJSON (line 10) | class CustomJSON(Task):
method __init__ (line 17) | def __init__(self, filepath, **kwargs):
method num_examples (line 56) | def num_examples(self):
method get_example (line 59) | def get_example(self, index):
FILE: tasks/gsm8k.py
function extract_answer (line 23) | def extract_answer(completion):
class GSM8K (line 37) | class GSM8K(Task):
method __init__ (line 39) | def __init__(self, subset, split, **kwargs):
method eval_type (line 46) | def eval_type(self):
method num_examples (line 49) | def num_examples(self):
method get_example (line 52) | def get_example(self, index):
method evaluate (line 87) | def evaluate(self, conversation, assistant_response):
method reward (line 110) | def reward(self, conversation, assistant_response):
FILE: tasks/humaneval.py
function extract_imports (line 12) | def extract_imports(prompt):
function extract_program (line 24) | def extract_program(completion):
class HumanEval (line 47) | class HumanEval(Task):
method __init__ (line 49) | def __init__(self, **kwargs):
method eval_type (line 54) | def eval_type(self):
method num_examples (line 57) | def num_examples(self):
method get_example (line 60) | def get_example(self, index):
method evaluate (line 79) | def evaluate(self, conversation, completion):
FILE: tasks/mmlu.py
class MMLU (line 9) | class MMLU(Task):
method __init__ (line 14) | def __init__(self, subset, split, **kwargs):
method eval_type (line 28) | def eval_type(self):
method num_examples (line 31) | def num_examples(self):
method get_example (line 34) | def get_example(self, index):
method evaluate (line 55) | def evaluate(self, conversation, assistant_response):
FILE: tasks/smoltalk.py
class SmolTalk (line 10) | class SmolTalk(Task):
method __init__ (line 13) | def __init__(self, split, **kwargs):
method num_examples (line 19) | def num_examples(self):
method get_example (line 22) | def get_example(self, index):
FILE: tasks/spellingbee.py
function extract_answer (line 43) | def extract_answer(completion):
class SpellingBee (line 115) | class SpellingBee(Task):
method __init__ (line 117) | def __init__(self, size=1000, split="train", **kwargs):
method eval_type (line 129) | def eval_type(self):
method num_examples (line 132) | def num_examples(self):
method get_example (line 135) | def get_example(self, index):
method evaluate (line 207) | def evaluate(self, conversation, assistant_response):
method reward (line 226) | def reward(self, conversation, assistant_response):
class SimpleSpelling (line 233) | class SimpleSpelling(Task):
method __init__ (line 236) | def __init__(self, size=1000, split="train", **kwargs):
method eval_type (line 250) | def eval_type(self):
method num_examples (line 253) | def num_examples(self):
method get_example (line 256) | def get_example(self, index):
FILE: tests/test_attention_fallback.py
function set_impl (line 23) | def set_impl(impl):
function run_both_impls (line 29) | def run_both_impls(fn):
function assert_close (line 39) | def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
class TestFA3VsSDPA (line 52) | class TestFA3VsSDPA:
method test_basic_causal (line 58) | def test_basic_causal(self):
method test_full_context (line 72) | def test_full_context(self):
method test_sliding_window (line 86) | def test_sliding_window(self):
method test_gqa (line 101) | def test_gqa(self):
method test_larger_model (line 118) | def test_larger_model(self):
method test_kvcache_prefill (line 132) | def test_kvcache_prefill(self):
method test_kvcache_single_token (line 155) | def test_kvcache_single_token(self):
method test_kvcache_single_token_sliding_window (line 182) | def test_kvcache_single_token_sliding_window(self):
method test_backward_gradients_match (line 215) | def test_backward_gradients_match(self):
class TestSDPAOnly (line 254) | class TestSDPAOnly:
method test_basic_forward (line 260) | def test_basic_forward(self):
method test_backward (line 274) | def test_backward(self):
method test_kvcache (line 292) | def test_kvcache(self):
class TestOverrideMechanism (line 340) | class TestOverrideMechanism:
method test_override_fa3 (line 344) | def test_override_fa3(self):
method test_override_sdpa (line 350) | def test_override_sdpa(self):
method test_override_auto (line 356) | def test_override_auto(self):
FILE: tests/test_engine.py
class MockConfig (line 16) | class MockConfig:
class MockModel (line 25) | class MockModel:
method __init__ (line 31) | def __init__(self, vocab_size=262): # 256 bytes + 6 special tokens
method get_device (line 36) | def get_device(self):
method forward (line 39) | def forward(self, ids, kv_cache=None):
class ByteTokenizer (line 50) | class ByteTokenizer:
method __init__ (line 55) | def __init__(self):
method encode_special (line 67) | def encode_special(self, s):
method get_bos_token_id (line 70) | def get_bos_token_id(self):
method encode (line 73) | def encode(self, s, prepend=None):
method decode (line 79) | def decode(self, tokens):
function test_kv_cache_basic (line 84) | def test_kv_cache_basic():
function test_kv_cache_prefill (line 124) | def test_kv_cache_prefill():
function test_multi_sample_first_token_diversity (line 158) | def test_multi_sample_first_token_diversity():
function test_seed_reproducibility (line 201) | def test_seed_reproducibility():
function test_temperature_zero_determinism (line 214) | def test_temperature_zero_determinism():
function test_max_tokens_respected (line 226) | def test_max_tokens_respected():
function test_num_samples_count (line 238) | def test_num_samples_count():
function test_different_seeds_introduce_variation_when_temperature_nonzero (line 249) | def test_different_seeds_introduce_variation_when_temperature_nonzero():
Condensed preview of 52 files, each showing its path, character count, and a content snippet (full structured content: 651K chars).
[
{
"path": ".claude/skills/read-arxiv-paper/SKILL.md",
"chars": 1982,
"preview": "---\nname: read-arxiv-paper\ndescription: Use this skill when when asked to read an arxiv paper given an arxiv URL\n---\n\nYo"
},
{
"path": ".gitignore",
"chars": 109,
"preview": ".venv/\n__pycache__/\n*.pyc\ndev-ignore/\nreport.md\neval_bundle/\n\n# Secrets\n.env\n\n# Local setup\nCLAUDE.md\nwandb/\n"
},
{
"path": ".python-version",
"chars": 5,
"preview": "3.10\n"
},
{
"path": "LICENSE",
"chars": 1072,
"preview": "MIT License\n\nCopyright (c) 2025 Andrej Karpathy\n\nPermission is hereby granted, free of charge, to any person obtaining a"
},
{
"path": "README.md",
"chars": 15742,
"preview": "# nanochat\n\n\n\n\nnanochat is the simplest exp"
},
{
"path": "dev/LEADERBOARD.md",
"chars": 12789,
"preview": "# Leaderboard\n\nDocs on participating in the \"Time-to-GPT-2\" leaderboard of nanochat.\n\nThe primary metric we care about i"
},
{
"path": "dev/LOG.md",
"chars": 60383,
"preview": "# Experiment Log\n\nA running summary documenting some experiments and findings. Started ~Jan 7 2026.\n\n---\n\n## 2026-03-04:"
},
{
"path": "dev/estimate_gpt3_core.ipynb",
"chars": 73033,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# Estimating CORE Metric for GPT-3 "
},
{
"path": "dev/gen_synthetic_data.py",
"chars": 18663,
"preview": "\"\"\"\nSynthetic data generation for teaching nanochat about its identity and capabilities.\n\nThis script uses the OpenRoute"
},
{
"path": "dev/generate_logo.html",
"chars": 1127,
"preview": "<!DOCTYPE html>\n<html>\n<body style=\"margin:0; display:flex; justify-content:center; align-items:center; height:100vh; ba"
},
{
"path": "dev/repackage_data_reference.py",
"chars": 4421,
"preview": "\"\"\"\nRepackage a given dataset into simple parquet shards:\n\n- each shard is ~100MB in size (after zstd compression)\n- par"
},
{
"path": "dev/scaling_analysis.ipynb",
"chars": 15167,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"markdown\",\n \"metadata\": {},\n \"source\": [\n \"# Scaling Laws Analysis\\n\",\n \"\\n"
},
{
"path": "nanochat/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "nanochat/checkpoint_manager.py",
"chars": 8729,
"preview": "\"\"\"\nUtilities for saving and loading model/optim/state checkpoints.\n\"\"\"\nimport os\nimport re\nimport glob\nimport json\nimpo"
},
{
"path": "nanochat/common.py",
"chars": 10995,
"preview": "\"\"\"\nCommon utilities for nanochat.\n\"\"\"\n\nimport os\nimport re\nimport logging\nimport urllib.request\nimport torch\nimport tor"
},
{
"path": "nanochat/core_eval.py",
"chars": 11572,
"preview": "\"\"\"\nFunctions for evaluating the CORE metric, as described in the DCLM paper.\nhttps://arxiv.org/abs/2406.11794\n\nTODOs:\n-"
},
{
"path": "nanochat/dataloader.py",
"chars": 7450,
"preview": "\"\"\"\nDistributed dataloaders for pretraining.\n\nBOS-aligned bestfit:\n - Every row starts with BOS token\n - Documents p"
},
{
"path": "nanochat/dataset.py",
"chars": 7014,
"preview": "\"\"\"\nThe base/pretraining dataset is a set of parquet files.\nThis file contains utilities for:\n- iterating over the parqu"
},
{
"path": "nanochat/engine.py",
"chars": 16020,
"preview": "\"\"\"\nEngine for efficient inference of our models.\n\nEverything works around token sequences:\n- The user can send token se"
},
{
"path": "nanochat/execution.py",
"chars": 10307,
"preview": "\"\"\"\nSandboxed execution utilities for running Python code that comes out of an LLM.\nAdapted from OpenAI HumanEval code:\n"
},
{
"path": "nanochat/flash_attention.py",
"chars": 6827,
"preview": "\"\"\"\nUnified Flash Attention interface with automatic FA3/SDPA switching.\n\nExports `flash_attn` module that matches the F"
},
{
"path": "nanochat/fp8.py",
"chars": 11992,
"preview": "\"\"\"Minimal FP8 training for nanochat — tensorwise dynamic scaling only.\n\nDrop-in replacement for torchao's Float8Linear "
},
{
"path": "nanochat/gpt.py",
"chars": 26604,
"preview": "\"\"\"\nGPT model (rewrite, a lot simpler)\nNotable features:\n- rotary embeddings (and no positional embeddings)\n- QK norm\n- "
},
{
"path": "nanochat/loss_eval.py",
"chars": 3124,
"preview": "\"\"\"\nA number of functions that help with evaluating a base model.\n\"\"\"\nimport math\nimport torch\nimport torch.distributed "
},
{
"path": "nanochat/optim.py",
"chars": 25625,
"preview": "\"\"\"\nA nice and efficient mixed AdamW/Muon Combined Optimizer.\nUsually the embeddings and scalars go into AdamW, and the "
},
{
"path": "nanochat/report.py",
"chars": 15610,
"preview": "\"\"\"\nUtilities for generating training report cards. More messy code than usual, will fix.\n\"\"\"\n\nimport os\nimport re\nimpor"
},
{
"path": "nanochat/tokenizer.py",
"chars": 18360,
"preview": "\"\"\"\nBPE Tokenizer in the style of GPT-4.\n\nTwo implementations are available:\n1) HuggingFace Tokenizer that can do both t"
},
{
"path": "nanochat/ui.html",
"chars": 19110,
"preview": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n <meta charset=\"UTF-8\">\n <meta name=\"viewport\" content=\"width=device-width"
},
{
"path": "pyproject.toml",
"chars": 1438,
"preview": "[project]\nname = \"nanochat\"\nversion = \"0.1.0\"\ndescription = \"the minimal full-stack ChatGPT clone\"\nreadme = \"README.md\"\n"
},
{
"path": "runs/miniseries.sh",
"chars": 4000,
"preview": "#!/bin/bash\n\n# See speedrun.sh for more comments\n# Usage: ./miniseries.sh [series_name]\n# Example: ./miniseries.sh jan11"
},
{
"path": "runs/runcpu.sh",
"chars": 2364,
"preview": "#!/bin/bash\n\n# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)\n# This scrip"
},
{
"path": "runs/scaling_laws.sh",
"chars": 4868,
"preview": "#!/bin/bash\n\nLABEL=\"jan26\"\n\nFLOPS_BUDGETS=(\n 1e18\n 2.15e18\n 4.64e18\n 1e19\n)\nDEPTHS=(8 10 12 14 16 18 20)\n\nNP"
},
{
"path": "runs/speedrun.sh",
"chars": 4850,
"preview": "#!/bin/bash\n\n# This script is configured to train your own GPT-2 grade LLM (pretraining + finetuning)\n# It is designed t"
},
{
"path": "scripts/base_eval.py",
"chars": 13559,
"preview": "\"\"\"\nUnified evaluation script for base models.\n\nSupports three evaluation modes (comma-separated):\n --eval core : CO"
},
{
"path": "scripts/base_train.py",
"chars": 33221,
"preview": "\"\"\"\nTrain model. From root directory of the project, run as:\n\npython -m scripts.base_train\n\nor distributed as:\n\ntorchrun"
},
{
"path": "scripts/chat_cli.py",
"chars": 3918,
"preview": "\"\"\"\nNew and upgraded chat mode because a lot of the code has changed since the last one.\n\nIntended to be run single GPU "
},
{
"path": "scripts/chat_eval.py",
"chars": 11974,
"preview": "\"\"\"\nEvaluate the Chat model.\nAll the generic code lives here, and all the evaluation-specific\ncode lives in nanochat dir"
},
{
"path": "scripts/chat_rl.py",
"chars": 17007,
"preview": "\"\"\"\nReinforcement learning on GSM8K via \"GRPO\".\n\nI put GRPO in quotes because we actually end up with something a lot\nsi"
},
{
"path": "scripts/chat_sft.py",
"chars": 25604,
"preview": "\"\"\"\nSupervised fine-tuning (SFT) the model.\nRun as:\n\npython -m scripts.chat_sft\n\nOr torchrun for training:\n\ntorchrun --s"
},
{
"path": "scripts/chat_web.py",
"chars": 15619,
"preview": "#!/usr/bin/env python3\n\"\"\"\nUnified web chat server - serves both UI and API from a single FastAPI instance.\n\nUses data p"
},
{
"path": "scripts/tok_eval.py",
"chars": 11737,
"preview": "\"\"\"\nEvaluate compression ratio of the tokenizer.\n\"\"\"\n\nfrom nanochat.tokenizer import get_tokenizer, RustBPETokenizer\nfro"
},
{
"path": "scripts/tok_train.py",
"chars": 4304,
"preview": "\"\"\"\nTrain a tokenizer using our own BPE Tokenizer library.\nIn the style of GPT-4 tokenizer.\n\"\"\"\nimport os\nimport time\nim"
},
{
"path": "tasks/arc.py",
"chars": 2171,
"preview": "\"\"\"\nThe ARC dataset from Allen AI.\nhttps://huggingface.co/datasets/allenai/ai2_arc\n\"\"\"\n\nfrom datasets import load_datase"
},
{
"path": "tasks/common.py",
"chars": 5531,
"preview": "\"\"\"\nBase class for all Tasks.\nA Task is basically a dataset of conversations, together with some\nmetadata and often also"
},
{
"path": "tasks/customjson.py",
"chars": 2939,
"preview": "\"\"\"\nCustomJSON task for loading conversations from JSONL files.\nEach line in the JSONL file should be a JSON array of me"
},
{
"path": "tasks/gsm8k.py",
"chars": 4873,
"preview": "\"\"\"\nGSM8K evaluation.\nhttps://huggingface.co/datasets/openai/gsm8k\n\nExample problem instance:\n\nQuestion:\nWeng earns $12 "
},
{
"path": "tasks/humaneval.py",
"chars": 3419,
"preview": "\"\"\"\nEvaluate the Chat model on HumanEval dataset.\nBtw this dataset is a misnomer and has nothing to do with humans.\nIt i"
},
{
"path": "tasks/mmlu.py",
"chars": 3934,
"preview": "\"\"\"\nThe MMLU dataset.\nhttps://huggingface.co/datasets/cais/mmlu\n\"\"\"\n\nfrom datasets import load_dataset\nfrom tasks.common"
},
{
"path": "tasks/smoltalk.py",
"chars": 2077,
"preview": "\"\"\"\nSmolTalk by HuggingFace. Good \"general\" conversational dataset.\nhttps://huggingface.co/datasets/HuggingFaceTB/smol-s"
},
{
"path": "tasks/spellingbee.py",
"chars": 12416,
"preview": "\"\"\"\nTask intended to make nanochat better in spelling and counting, for example:\n\n\"How many r are in strawberry?\" -> 3\n\n"
},
{
"path": "tests/test_attention_fallback.py",
"chars": 16125,
"preview": "\"\"\"\nTest Flash Attention unified interface - verify FA3 and SDPA produce identical results.\n\nRun: python -m pytest tests"
},
{
"path": "tests/test_engine.py",
"chars": 9239,
"preview": "\"\"\"\nTest Engine class. Example run:\n\npython -m pytest tests/test_engine.py -v\n\"\"\"\n\nimport torch\nfrom nanochat.engine imp"
}
]
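Each entry in the condensed preview above is a JSON object with `path`, `chars`, and `preview` keys. The following is a minimal sketch that assumes the array has been saved locally as a JSON file. It tallies character counts per top-level directory. The helper names `load_entries` and `chars_by_top_level` are hypothetical, and the path you pass to `load_entries` is up to you.

```python
import json
from collections import defaultdict


def load_entries(path):
    """Load the condensed-preview JSON array from a local file (hypothetical path)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def chars_by_top_level(entries):
    """Sum each entry's `chars` field, grouped by the first path component.

    Files at the repository root (no "/" in the path) are grouped under ".".
    """
    totals = defaultdict(int)
    for entry in entries:
        path = entry["path"]
        top = path.split("/", 1)[0] if "/" in path else "."
        totals[top] += entry["chars"]
    return dict(totals)
```

For example, `chars_by_top_level(load_entries("preview.json"))` returns a dict such as `{"nanochat": ..., "scripts": ..., ".": ...}`. Grouping on the first path component keeps dotted directories like `.claude` separate from root-level files such as `.gitignore`.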
About this extraction
This page contains the full source code of the karpathy/nanochat GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 52 files (606.5 KB), approximately 162.9k tokens, and a symbol index with 343 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.