[
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "content": "blank_issues_enabled: false\ncontact_links: []\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/eval_request.md",
    "content": "---\nname: Evaluation request\nabout: Propose a new benchmark or diagnostic to add\ntitle: \"[Eval] \"\nlabels: [\"evaluation\", \"needs-triage\"]\nassignees: []\n---\n\n## Motivation\nWhy is this evaluation important for HOPE/TITAN reproduction?\n\n## Task details\n- Dataset / benchmark:\n- Metric(s):\n- Expected runtime / hardware:\n\n## Environment target\n- OS:\n- Python:\n- Torch:\n- Preferred backend (`cpu` / `cuda` / `mps` / `rocm`):\n\n## Implementation sketch\nOutline scripts/flags needed (e.g., extend `scripts/eval/zeroshot.py`).\n\n## Acceptance criteria\nDescribe what needs to be captured (JSON fields, plots, etc.).\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/faithfulness_gap.md",
    "content": "---\nname: Faithfulness gap\nabout: Report deviations vs. the Nested Learning / HOPE specs\ntitle: \"[Faithfulness] \"\nlabels: [\"faithfulness\", \"needs-triage\"]\nassignees: []\n---\n\n## Summary\nDescribe the suspected deviation (cite paper section/equation).\n\n## Evidence\n- Config(s) / checkpoints affected\n- Logs / screenshots / metrics\n- Steps to reproduce\n\n## Environment\n- OS:\n- Python:\n- Torch:\n- Backend (`cpu` / `cuda` / `mps` / `rocm`):\n- GPU/accelerator model (if any):\n\nIf using ROCm: this project currently treats ROCm support as best-effort. Include HIP/ROCm version and exact torch build.\n\n## Expected behavior\nWhat should happen according to the paper?\n\n## Additional context\nAdd any extra notes, e.g., suggested fix or related PRs.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/perf_regression.md",
    "content": "---\nname: Performance regression\nabout: Report a training / eval performance drop vs. baseline\ntitle: \"[Perf] \"\nlabels: [\"performance\", \"needs-triage\"]\nassignees: []\n---\n\n## Summary\nDescribe the regression and the baseline you’re comparing against.\n\n## Baseline\n- Config / checkpoint:\n- Metrics (loss / ppl / eval scores):\n\n## Repro steps\nExact commands with overrides, plus hardware details.\n\n## Environment\n- OS:\n- Python:\n- Torch:\n- Backend (`cpu` / `cuda` / `mps` / `rocm`):\n- GPU/accelerator model (if any):\n\nIf using ROCm: this project currently treats ROCm support as best-effort. Include HIP/ROCm version and exact torch build.\n\n## Logs / artifacts\nAttach relevant logs, W&B links, or JSON eval files.\n\n## Suspected cause\nOptional theory or related commits/PRs.\n"
  },
  {
    "path": ".github/workflows/ci.yml",
    "content": "name: CI\n\non:\n  push:\n    branches: [\"main\"]\n  pull_request:\n    branches: [\"main\"]\n\njobs:\n  lint-and-test:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.12\"\n\n      - name: Set up uv\n        uses: astral-sh/setup-uv@v3\n        with:\n          version: \"0.9.8\"\n\n      - name: Sync dependencies\n        run: uv sync --all-extras --dev\n\n      - name: Ruff\n        run: uv run ruff check .\n\n      - name: Mypy\n        run: uv run mypy src\n\n      - name: Verify docs path references\n        run: uv run python scripts/checks/verify_docs_refs.py\n\n      - name: Verify README critical commands\n        run: bash scripts/checks/check_readme_commands.sh\n\n      - name: Guard tracked file sizes / artifact extensions\n        run: bash scripts/checks/check_git_tracked_sizes.sh\n\n      - name: Verify scripts/data help exits cleanly\n        run: bash scripts/checks/check_data_script_help.sh\n\n      - name: Pytest\n        run: uv run pytest\n\n  cross-platform-smoke:\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [ubuntu-latest, macos-latest, windows-latest]\n        python-version: [\"3.10\", \"3.12\"]\n    runs-on: ${{ matrix.os }}\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Set up uv\n        uses: astral-sh/setup-uv@v3\n        with:\n          version: \"0.9.8\"\n\n      - name: Sync dependencies\n        run: uv sync --dev\n\n      - name: CLI help + doctor + smoke\n        run: |\n          uv run nl --help\n          uv run nl doctor --json\n          uv run nl smoke --config-name pilot_smoke --device cpu --batch-size 1 --seq-len 8\n          uv run python -m nested_learning --help\n\n  wheel-install-smoke:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.12\"\n\n      - name: Set up uv\n        uses: astral-sh/setup-uv@v3\n        with:\n          version: \"0.9.8\"\n\n      - name: Build wheel\n        run: uv build\n\n      - name: Install wheel in isolated venv\n        run: |\n          python -m venv /tmp/wheel-smoke\n          /tmp/wheel-smoke/bin/python -m pip install --upgrade pip\n          /tmp/wheel-smoke/bin/python -m pip install dist/*.whl\n\n      - name: Verify wheel entrypoints outside repo configs\n        run: |\n          /tmp/wheel-smoke/bin/python -m nested_learning --help\n          /tmp/wheel-smoke/bin/python -m nested_learning doctor --json\n          /tmp/wheel-smoke/bin/python - <<'PY'\n          import subprocess\n          import sys\n          import tempfile\n\n          tmp = tempfile.mkdtemp(prefix=\"nl-wheel-smoke-\")\n          cmd = [\n              sys.executable,\n              \"-m\",\n              \"nested_learning\",\n              \"smoke\",\n              \"--config-name\",\n              \"pilot_smoke\",\n              \"--device\",\n              \"cpu\",\n              \"--batch-size\",\n              \"1\",\n              \"--seq-len\",\n              \"8\",\n          ]\n          subprocess.run(cmd, cwd=tmp, check=True)\n          
PY\n\n  cpu-ddp-smoke:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.12\"\n\n      - name: Set up uv\n        uses: astral-sh/setup-uv@v3\n        with:\n          version: \"0.9.8\"\n\n      - name: Sync dependencies\n        run: uv sync --all-extras --dev\n\n      - name: CPU DDP smoke (gloo backend)\n        run: bash scripts/run_cpu_ddp_smoke.sh\n\n  passkey-smoke:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.12\"\n\n      - name: Set up uv\n        uses: astral-sh/setup-uv@v3\n        with:\n          version: \"0.9.8\"\n\n      - name: Sync dependencies\n        run: uv sync --all-extras --dev\n\n      - name: Run synthetic passkey memorization test\n        run: bash scripts/tests/run_passkey_smoke.sh\n\n  fidelity-subset:\n    runs-on: ubuntu-latest\n    timeout-minutes: 20\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.12\"\n\n      - name: Set up uv\n        uses: astral-sh/setup-uv@v3\n        with:\n          version: \"0.9.8\"\n\n      - name: Sync dependencies\n        run: uv sync --all-extras --dev\n\n      - name: Run fidelity subset + compliance report\n        run: bash scripts/checks/run_fidelity_ci_subset.sh\n"
  },
  {
    "path": ".github/workflows/packages.yml",
    "content": "name: Packages\n\non:\n  push:\n    tags:\n      - \"v*\"\n  workflow_dispatch:\n\npermissions:\n  contents: read\n  packages: write\n\nenv:\n  REGISTRY: ghcr.io\n  IMAGE_NAME: ${{ github.repository_owner }}/nested-learning-dist\n\njobs:\n  publish-ghcr:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.12\"\n\n      - name: Set up uv\n        uses: astral-sh/setup-uv@v3\n        with:\n          version: \"0.9.8\"\n\n      - name: Build source and wheel distributions\n        run: uv build\n\n      - name: Generate checksums\n        run: |\n          cd dist\n          sha256sum * > SHA256SUMS.txt\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Login to GHCR\n        uses: docker/login-action@v3\n        with:\n          registry: ${{ env.REGISTRY }}\n          username: ${{ github.actor }}\n          password: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Extract OCI metadata\n        id: meta\n        uses: docker/metadata-action@v5\n        with:\n          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}\n          tags: |\n            type=raw,value=${{ github.ref_name }},enable=${{ github.event_name == 'push' }}\n            type=raw,value=latest,enable=${{ github.event_name == 'push' && !contains(github.ref_name, 'rc') }}\n            type=raw,value=edge,enable=${{ github.event_name == 'workflow_dispatch' }}\n            type=sha,format=short,enable=${{ github.event_name == 'workflow_dispatch' }}\n\n      - name: Build and publish GHCR package image\n        uses: docker/build-push-action@v6\n        with:\n          context: .\n          file: ./docker/Dockerfile.dist\n          push: true\n          tags: ${{ steps.meta.outputs.tags }}\n          labels: ${{ steps.meta.outputs.labels }}\n"
  },
  {
    "path": ".github/workflows/release.yml",
    "content": "name: Release\n\non:\n  push:\n    tags:\n      - \"v*\"\n\npermissions:\n  contents: write\n  id-token: write\n\njobs:\n  build:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.12\"\n\n      - name: Set up uv\n        uses: astral-sh/setup-uv@v3\n        with:\n          version: \"0.9.8\"\n\n      - name: Build source and wheel distributions\n        run: uv build\n\n      - name: Twine check\n        run: uvx twine check dist/*\n\n      - name: Upload dist artifacts\n        uses: actions/upload-artifact@v4\n        with:\n          name: dist\n          path: dist/*\n\n  publish-testpypi:\n    if: contains(github.ref_name, 'rc')\n    needs: build\n    runs-on: ubuntu-latest\n    environment:\n      name: testpypi\n      url: https://test.pypi.org/p/nested-learning\n    steps:\n      - name: Download dist artifacts\n        uses: actions/download-artifact@v4\n        with:\n          name: dist\n          path: dist\n\n      - name: Publish to TestPyPI via Trusted Publishing\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          repository-url: https://test.pypi.org/legacy/\n          packages-dir: dist/\n\n  publish-pypi:\n    if: ${{ !contains(github.ref_name, 'rc') }}\n    needs: build\n    runs-on: ubuntu-latest\n    environment:\n      name: pypi\n      url: https://pypi.org/p/nested-learning\n    steps:\n      - name: Download dist artifacts\n        uses: actions/download-artifact@v4\n        with:\n          name: dist\n          path: dist\n\n      - name: Publish to PyPI via Trusted Publishing\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          packages-dir: dist/\n\n  publish-github-release:\n    needs:\n      - build\n      - publish-testpypi\n      - publish-pypi\n    if: |\n      always() &&\n      needs.build.result == 'success' &&\n      (needs.publish-testpypi.result == 'success' || needs.publish-testpypi.result == 'skipped') &&\n      (needs.publish-pypi.result == 'success' || needs.publish-pypi.result == 'skipped')\n    runs-on: ubuntu-latest\n    steps:\n      - name: Download dist artifacts\n        uses: actions/download-artifact@v4\n        with:\n          name: dist\n          path: dist\n\n      - name: Generate checksums\n        run: |\n          cd dist\n          sha256sum * > SHA256SUMS.txt\n\n      - name: Build release preamble\n        run: |\n          cat > release_preamble.md <<EOF\n          Package release for \\`${GITHUB_REF_NAME}\\`.\n\n          Install:\n          \\`\\`\\`bash\n          pip install nested-learning==${GITHUB_REF_NAME#v}\n          \\`\\`\\`\n\n          Included assets:\n          - source distribution (\\`.tar.gz\\`)\n          - wheel (\\`.whl\\`)\n          - \\`SHA256SUMS.txt\\` checksums\n\n          For compatibility/support details:\n          - https://github.com/${GITHUB_REPOSITORY}/blob/main/docs/COMPATIBILITY_MATRIX.md\n          - https://github.com/${GITHUB_REPOSITORY}/blob/main/docs/VERSIONING_POLICY.md\n          EOF\n\n      - name: Publish GitHub Release\n        uses: softprops/action-gh-release@v2\n        with:\n          prerelease: ${{ contains(github.ref_name, 'rc') }}\n          generate_release_notes: true\n          body_path: release_preamble.md\n          files: |\n            dist/*\n"
  },
  {
    "path": ".github/workflows/security.yml",
    "content": "name: Security\n\non:\n  push:\n    branches: [\"main\"]\n  pull_request:\n    branches: [\"main\"]\n  schedule:\n    - cron: \"0 6 * * 1\"\n\njobs:\n  dependency-audit:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Set up Python\n        uses: actions/setup-python@v5\n        with:\n          python-version: \"3.12\"\n\n      - name: Set up uv\n        uses: astral-sh/setup-uv@v3\n        with:\n          version: \"0.9.8\"\n\n      - name: Export requirements\n        run: uv export --all-extras --dev --format requirements-txt --output-file /tmp/requirements.txt\n\n      - name: pip-audit\n        run: uvx pip-audit -r /tmp/requirements.txt\n        continue-on-error: true\n\n  secret-scan:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Gitleaks scan\n        uses: gitleaks/gitleaks-action@v2\n        env:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n\n"
  },
  {
    "path": ".gitignore",
    "content": "# Environment / tooling\n.venv/\n__pycache__/\n*.pyc\n.pytest_cache/\n.ruff_cache/\n.mypy_cache/\n\n# Local artifacts\nlogs/\nartifacts/\n/data/\noutputs/\ncheckpoints/\n*.pt\ntrain.log\ntrain_dist.log\nref_repos/\nconfigs/_tmp*\ngit.env\ndocs/POSTS.md\ndocs/EX_*.md\ndocs/CHECK_2_PLANNING_MODEL_REQUEST.md\ndocs/CHECK_2_PLANNING_MODEL_RESPONSE.md\ndocs/planner_check2_attachments.zip\ndocs/tmp/\ndocs_tmp/\nwandb/\neval/*_ci.json\n\n# Local paper scans / scratch references (keep tracked references separate)\ngoogle_papers/*_arXiv_v1.pdf\ngoogle_papers/*_arXiv_v1/\ngoogle_papers/Nested_Learning_Full_Paper.pdf\ngoogle_papers/Nested_Learning_Full_Paper/\n\n# Editors\n.DS_Store\n.idea/\n.vscode/\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "# Changelog\n\nAll notable changes to this project will be documented here. The format loosely follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) and uses semantic versioning once tagged releases begin.\n\n## [Unreleased]\n### Added\n- Optional attention KV-cache path for continuous streaming inference (`init_attention_cache`, `attention_cache`, `return_attention_cache`) across HOPE/TITAN/Transformer blocks.\n- Boundary-target online chunking mode (`train.online_boundary_targets`) and optional training-time attention-cache carry (`train.online_carry_attention_cache`) for stronger chunk-boundary semantics.\n- Evaluation streaming-state utilities (`src/nested_learning/eval_state.py`) plus continual-eval controls (`--eval-state-mode`, `--eval-use-fast-state`, `--eval-use-attention-cache`).\n- Compliance report automation (`scripts/checks/compliance_report.py`) with CI subset + mechanism smoke integration.\n- Flash/SDPA-backed self-attention path with safe fallbacks, unlocking PyTorch 2.9 SDPA kernels by default.\n- Hydra toggles for bf16 autocast (`train.mixed_precision.*`), `torch.compile` (`train.compile.*`), and fused optimizers.\n- Muon + AdamW hybrid optimizer option exposed via `optim.type=muon`, routing ≥2D matrices through `torch.optim.Muon`.\n- Test-time memorization flags (`--memorize*`) documented in README + `docs/guide.md`, matching TITAN eval behavior.\n- Automation helpers: `scripts/run_e2e_smoke.sh` documented in Quickstart, plus new `scripts/run_cpu_ddp_smoke.sh` for CPU-only DDP/gloo smoke coverage.\n- Streaming contract doc (`docs/STREAMING_CONTRACT.md`) defining sequence/segment/chunk semantics and fast-state lifecycle.\n- Cadence verification utility (`scripts/checks/verify_update_cadence.py`) with synthetic tests and release-checklist integration.\n- Fidelity CI subset runner (`scripts/checks/run_fidelity_ci_subset.sh`) and mechanism-auditing smoke runner (`scripts/run_mechanism_audit_smoke.sh`).\n- Progress/status docs for P7 execution (`docs/PLAN_PROGRESS_P7.md`, `docs/IMPLEMENTATION_STATUS.md`).\n- Bug-report reproducibility checklist (`docs/BUG_REPORT_CHECKLIST.md`).\n- Boundary-state training-loop regression coverage (`tests/test_boundary_state_training_loop.py`) plus eval-loader/metadata roundtrip coverage (`tests/test_checkpoint_metadata_and_eval_loaders.py`).\n- `scripts/checks/check_data_script_help.sh` to guarantee `scripts/data/* --help` exits cleanly; wired into CI.\n- Markdown anchor verification in `scripts/checks/verify_docs_refs.py` with dedicated unit coverage.\n- Tag release automation now creates GitHub Release entries with attached wheel/sdist artifacts plus `SHA256SUMS.txt`.\n- Added GHCR package publishing workflow (`.github/workflows/packages.yml`) so the Packages tab contains a versioned `nested-learning-dist` OCI bundle.\n\n### Changed\n- README / compliance / streaming docs now reflect boundary-target mode, optional KV-cache carry, and explicit scope boundaries.\n- CPU DDP smoke now includes strict-mode fail-fast verification.\n- Repository license metadata now matches the shipped Apache-2.0 text; badges updated accordingly.\n- README and guide refreshed with performance knobs, optimizer guidance, and memorization instructions so release consumers have a single source of truth.\n- Release checklist tracks the new CPU DDP smoke script to keep packaging instructions aligned with available tooling.\n- Training loop strict-mode guardrails: `train.strict_streaming_contract` now fails fast on known semantics violations (DDP feature downgrades, shared-batch fast-state, non-paper-defined variants in strict mode).\n- CMS telemetry now includes cadence metrics (`updates_applied`, `tokens_flushed`, `pending_tokens`, `gate_hits`) to make update-frequency behavior auditable.\n- Paper-auditing preset now explicitly enables strict streaming contract checks.\n- `configs/pilot_paper_faithful.yaml` now explicitly sets `train.online_updates=true` and tests verify no implicit algorithm-mode fallback.\n- Boundary-state mode now emits an explicit startup warning code (`experimental_boundary_state_mode`) and validates cache/chunk constraints early.\n- Checkpoint metadata now records algorithm/online flags (`algorithm_mode`, `online_updates`, `online_boundary_targets`, `online_carry_attention_cache`, `use_fast_state`), and release manifest includes those flags.\n- Data split fallback policy is deterministic across data scripts (`train -> validation -> test -> first available`) with explicit available-splits logging.\n\n### Upcoming\n- GitHub Actions workflow covering `ruff`, `mypy`, and `pytest`.\n- End-to-end release dry-run ahead of the `v0.1.0` tag.\n\n## [0.1.0] - 2025-11-09\n### Added\n- PyTorch **2.9.0** / torchvision **0.24.0** environment managed via `uv` with reproducible `pyproject.toml` + `uv.lock`.\n- HOPE block implementation (attention → TITAN memory → CMS + deep optimizers) with configurable level clocks and self-modifier wiring.\n- Hydra config tree for pilot, mid, target, and CPU-only smoke runs plus DDP/FSDP/DeepSpeed entrypoints.\n- Data tooling: tokenizer trainer, corpus filtering, mixture processing, and `scripts/data/run_sample.sh` shortcut emitting stats under `data/mixtures/`.\n- Evaluation suite: zero-shot benchmark CLI (PIQA/HellaSwag/WinoGrande/ARC/BoolQ/SIQA), Needle-in-a-Haystack generator, continual-learning forgetting analyzer.\n- Sample artifacts (`artifacts/examples/pilot_dummy.pt`, `logs/pilot_smoke.json`, `logs/mid_smoke.json`) for reproducing eval commands without lengthy training.\n- Documentation set (`docs/stage1_plan.md`, `docs/stage2_plan.md`, `docs/data_pipeline.md`, `docs/guide.md`) outlining architecture, scaling strategy, and onboarding.\n\n### Changed\n- README rewritten with badges, quickstart commands, and references to the new guide + release checklist.\n- Logging defaults clarified (`logging.backend=json|wandb`), with instructions for saving structured metrics under `logs/`.\n\n### Known gaps\n- Release automation and CI are tracked in `docs/release_plan.md`.\n- Scaling guidance for >100 B token corpora pending additional storage + GPU availability.\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# Nested Learning Reproduction\n\n![CI](https://github.com/kmccleary3301/nested_learning/actions/workflows/ci.yml/badge.svg)\n![Security](https://github.com/kmccleary3301/nested_learning/actions/workflows/security.yml/badge.svg)\n![Python](https://img.shields.io/badge/python-3.10%20to%203.12-blue)\n![PyTorch](https://img.shields.io/badge/pytorch-2.9.0-red)\n![License](https://img.shields.io/badge/license-Apache--2.0-green)\n![Status](https://img.shields.io/badge/tests-smoke--ready-lightgrey)\n\nMechanism-level reproduction of Google's Nested Learning (HOPE) architecture (HOPE blocks, CMS, and Self‑Modifying TITANs), matching the quality bar set by lucidrains' TITAN reference while remaining fully open-source and `uv` managed.\n\nFaithfulness scope (high level):\n- ✅ HOPE / CMS / Self‑Modifying Titans update rules + wiring (mechanism-level)\n- ✅ Tensor-level invariants covered by unit tests (teach-signal, δℓ, CMS chunking, causality)\n- ✅ Boundary-target online chunking + optional attention-cache carry path are implemented\n- ⚠️ Stable default uses stop-grad online writes; an experimental single-process boundary-state mode supports differentiable write paths\n- ⚠️ Multi‑GPU mechanism-auditing online updates are not supported in this repo (DDP disables some features)\n\nPaper reference pin:\n- Source: `google_papers/Nested_Learning_Full_Paper/Nested_Learning_Full_Paper.md`\n- SHA-256: `7524af0724ac8e3bad9163bf0e79c85b490a26bc30b92d96b0bdf17a27f9febc`\n\n## Quickstart\n```bash\nuv python install 3.12\nuv sync --all-extras\nuv run nl doctor --json > logs/runtime_doctor.json\nuv run bash scripts/data/run_sample.sh\nuv run nl smoke --config-name pilot_smoke --device cpu\nuv run bash scripts/run_smoke.sh pilot  # CPU-friendly HOPE block smoke test\nuv run bash scripts/run_e2e_smoke.sh    # sync + sample data + smoke train + zeroshot eval\nuv run bash scripts/run_mechanism_audit_smoke.sh\nuv run python scripts/eval/zeroshot.py \\\n  --config configs/hope/pilot.yaml \\\n  --checkpoint artifacts/examples/pilot_dummy.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --tasks piqa --max-samples 32 --device cpu\n```\n\n## Requirements\n- Python 3.10-3.12\n- PyTorch 2.9.x+ (golden environment in this repo uses 2.9.x)\n- `uv` (recommended for development) or `pip` for package-style usage\n\n## Compatibility\n- Support tiers and OS/runtime matrix: `docs/COMPATIBILITY_MATRIX.md`\n- Versioning/stability policy: `docs/VERSIONING_POLICY.md`\n- Golden repro environment: Python 3.12 + `uv lock` + PyTorch 2.9.x\n\nmacOS / Apple Silicon expectations:\n- Mac users can run install + CLI + eval/smoke workflows.\n- `train.device=mps` is supported for small/local runs.\n- Linux + CUDA remains the only Tier 1 full-training path in this repo.\n- Cross-backend numerical parity (CUDA vs MPS) is not guaranteed.\n- If MPS is unavailable, device selection falls back to CPU (`nl doctor --json` shows this clearly).\n\n## Installation (pip-first)\n1. Create and activate a virtual environment.\n2. Install Torch first (CPU/CUDA wheel selection is backend-specific).\n3. 
Install this project.\n\nCPU example:\n```bash\npython -m venv .venv\nsource .venv/bin/activate\npython -m pip install --upgrade pip\npython -m pip install \"torch>=2.9,<3\" --index-url https://download.pytorch.org/whl/cpu\npython -m pip install -e .\n```\n\nCUDA example (adjust index URL to your CUDA runtime):\n```bash\npython -m venv .venv\nsource .venv/bin/activate\npython -m pip install --upgrade pip\npython -m pip install \"torch>=2.9,<3\" --index-url https://download.pytorch.org/whl/cu128\npython -m pip install -e .\n```\n\n## Setup (uv dev workflow)\n```bash\nuv python install 3.12\nuv sync --all-extras\n```\n\nDeveloper checks:\n- `uv run ruff check .`\n- `uv run mypy src`\n- `uv run pytest`\n- `uv run bash scripts/checks/run_fidelity_ci_subset.sh`\n- `uv run python scripts/checks/compliance_report.py --config configs/pilot.yaml --output eval/compliance_report.json`\n\n## CLI\nThe package ships with `nl` for portable workflows across local/dev/prod environments.\n\n```bash\n# runtime compatibility snapshot\nuv run nl doctor --json\n\n# architecture/config smoke on chosen device\nuv run nl smoke --config-name pilot_smoke --device cpu --batch-size 1 --seq-len 8\n\n# static fidelity checks for a config\nuv run nl audit --config-name pilot_paper_faithful\n\n# train with Hydra overrides\nuv run nl train --config-name pilot --override train.device=cuda:1 --override train.steps=100\n```\n\n`python -m nested_learning ...` is also supported.\n\n## First 30 Minutes\nUse this path for a fast first success on CPU:\n\n```bash\nuv sync --all-extras\nuv run bash scripts/data/run_sample.sh\nuv run bash scripts/run_smoke.sh pilot\nuv run bash scripts/run_mechanism_audit_smoke.sh\n```\n\nThis confirms:\n- data/tokenizer pipeline is operational,\n- model/training loop runs end-to-end,\n- cadence checks pass for a mechanism-auditing smoke run.\n\n## Data Pipeline\n1. **Tokenizer training**\n   ```bash\n   uv run python scripts/data/train_tokenizer.py \\\n     --manifest configs/data/refinedweb_mixture.yaml \\\n     --vocab-size 32000 \\\n     --output-dir artifacts/tokenizer/refinedweb_mix \\\n     --log-file data/mixtures/refinedweb_mix_tokenizer.json\n   ```\n2. **Corpus filtering + sharding**\n   ```bash\n   uv run python scripts/data/process_mixture.py \\\n     configs/data/refinedweb_mixture_filtered.yaml \\\n     --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n     --log-file data/mixtures/refinedweb_mix_filtered_shards.json\n   ```\n3. **Sample pipeline** (downloads/licensed datasets, filters, shards, records stats)\n   ```bash\n   uv run bash scripts/data/run_sample.sh\n   ```\n4. **Full pipeline** (set env vars like `RW_LIMIT`, `WIKI_LIMIT`, etc. to scale ingestion)\n  ```bash\n  uv run bash scripts/data/run_full.sh  # default ~50k docs per corpus; increase limits as needed\n  ```\n\n### Data Troubleshooting\n- If `scripts/data/run_sample.sh` cannot find `artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model`, rerun:\n  ```bash\n  uv run bash scripts/data/run_sample.sh\n  ```\n  The script auto-trains the tokenizer when missing.\n- If `scripts/data/run_full.sh` fails with `Bad split: train. 
Available splits: ['test']`, use split fallback:\n  ```bash\n  FALLBACK_SPLIT=test uv run bash scripts/data/run_full.sh\n  ```\n  You can also override per-corpus splits (for example `RW_SPLIT=test`).\n\n## Training\n- Single GPU / CPU:\n  ```bash\n  uv run nl train --config-name pilot_smoke\n  ```\n- Apple Silicon (MPS, if available):\n  ```bash\n  uv run nl train --config-name pilot_smoke --override train.device=mps\n  ```\n  Use this path for smoke and small local runs; long/full-scale paper-regime runs are not a supported Mac target in this repository.\n- Script-based entrypoint (legacy-compatible):\n  ```bash\n  uv run python train.py --config-name pilot_smoke\n  ```\n- DDP (torchrun):\n  ```bash\n  torchrun --nproc_per_node=2 train_dist.py --config-name mid\n  ```\n- CPU-only DDP smoke (verifies `gloo` backend and deterministic seeding):\n  ```bash\n  uv run bash scripts/run_cpu_ddp_smoke.sh\n  ```\n- FSDP (see `docs/FSDP_SCALING_GUIDE.md` for VRAM/batch sizing):\n  ```bash\n  # 760M run\n  torchrun --nproc_per_node=2 train_fsdp.py --config-name hope/mid_fsdp\n  # 1.3B run\n  torchrun --nproc_per_node=2 train_fsdp.py --config-name hope/target_fsdp\n  ```\n- DeepSpeed (requires `deepspeed` installed separately):\n  ```bash\n  deepspeed --num_gpus=2 train_deepspeed.py --config-name target \\\n    deepspeed.config=configs/deepspeed/zero3.json\n  ```\n\n### Mechanism-auditing presets (HOPE / Nested Learning)\n\nUse the mechanism-auditing preset configs (single GPU):\n\n```bash\nuv run python train.py --config-name pilot_paper_faithful\n# HOPE self-mod variant:\nuv run python train.py --config-name pilot_selfmod_paper_faithful\n```\n\nNotes:\n- These presets set `data.batch_size=1` to avoid cross-sample fast-memory sharing.\n- Online chunking supports one-token overlap **or** explicit boundary-target mode (`train.online_boundary_targets=true`).\n- Optional attention-state carry across chunks is available in training via `train.online_carry_attention_cache=true`.\n- The exact sequence/segment/chunk/buffer semantics are documented in `docs/STREAMING_CONTRACT.md`.\n\nOverrides:\n- `optim.type=m3` (paper optimizer option)\n- `train.steps=...` / `train.device=...`\n\nSee `docs/PAPER_COMPLIANCE.md` for full fidelity notes.\nSee `docs/STREAMING_CONTRACT.md` for the precise streaming/update contract used by this repo.\n\n## Scope Boundaries (Current)\n- This repo targets mechanism-auditing fidelity, not full paper-scale results parity.\n- Boundary-state gradient-through-write exists as an experimental constrained path; it is not yet treated as production/full-scale paper reproduction.\n- Distributed mechanism-auditing path for boundary-target + attention-cache carry is not implemented.\n\n### Pilot (3 B tokens) workflow\n1. Ensure TMUX session:\n   ```bash\n   tmux new -s pilot_train\n   ```\n2. Launch the long run on `cuda:1` (≈52 h wall clock):\n   ```bash\n   set -a && source git.env && set +a\n   export UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy\n   uv run python train.py --config-name pilot \\\n     logging.enabled=true logging.backend=wandb \\\n     logging.project=nested-learning logging.run_name=pilot-main-$(date +%Y%m%d%H%M%S) \\\n     train.device=cuda:1\n   ```\n3. Checkpoints appear in `artifacts/checkpoints/pilot/step_*.pt` every 1 000 steps; the accompanying W&B run captures full telemetry.\n4. 
Copy the final checkpoint, config, logs, and eval JSON/CSV into `artifacts/pilot_release/` for distribution.\n\n## Logging\nSet `logging.enabled=true` in Hydra configs (or override via CLI) to send metrics to W&B (default). For local JSON logs, use `logging.backend=json logging.path=logs/run.json`. Sample outputs reside in `logs/` and `artifacts/examples/`.\n\n## Evaluation\n- Zero-shot:\n  ```bash\n  uv run python scripts/eval/zeroshot.py \\\n  --config configs/hope/mid.yaml \\\n  --checkpoint checkpoints/mid/step_000100.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --tasks all --max-samples 200 --device cuda:0\n  ```\n  Use `uv run python scripts/eval/zeroshot.py --list-tasks` to display the full benchmark roster (PIQA, HellaSwag, WinoGrande, ARC-E/C, BoolQ, SIQA, CommonsenseQA, OpenBookQA). See `docs/zeroshot_eval.md` for details.\n- Needle-in-a-Haystack:\n  ```bash\n  uv run python scripts/eval/niah.py \\\n    --config configs/hope/mid.yaml \\\n    --checkpoint checkpoints/mid/step_000100.pt \\\n    --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n    --context-lengths 2048 4096 8192 --samples-per-length 20\n  ```\n- Continual-learning forgetting:\n  ```bash\n  uv run python scripts/eval/continual.py \\\n    --config configs/hope/mid.yaml \\\n    --checkpoints checkpoints/mid/step_000050.pt checkpoints/mid/step_000100.pt \\\n    --segments-yaml configs/data/continual_segments_sample.yaml \\\n    --batch-size 4 --max-batches 10 --memorize --memorize-steps 2\n  ```\n  Plot forgetting curves via `uv run python scripts/eval/plot_forgetting.py --continual-json eval/continual_mid.json`.\n- Long-context diagnostics:\n  ```bash\n  uv run python scripts/eval/passkey.py --config configs/hope/pilot.yaml --checkpoint artifacts/checkpoints/pilot/step_230000.pt \\\n    --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model --samples 64 --memorize\n\n  uv run python scripts/eval/pg19_perplexity.py --config configs/hope/pilot.yaml --checkpoint artifacts/checkpoints/pilot/step_230000.pt \\\n    --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model --max-samples 64\n  ```\n\nEvaluation summaries are written to `eval/` alongside per-task JSON metrics.\n\n### Test-time memorization toggles\nEvery evaluator supports TITAN-style memorization so you can reproduce test-time adaptation:\n```bash\nuv run python scripts/eval/zeroshot.py \\\n  ... 
\\\n  --memorize \\\n  --memorize-steps 2 \\\n  --memorize-use-correct-answer \\\n  --memorize-no-reset \\\n  --memorize-paths titan,cms_fast \\\n  --memorize-surprise-threshold 0.01\n```\n- `--memorize` turns on the learner with one LMS step per example by default.\n- `--memorize-steps` controls the number of adaptation passes per prompt.\n- `--memorize-use-correct-answer` injects ground-truth text during memorization for ablations.\n- `--memorize-no-reset` carries memories across samples; omit it to reset every question.\n- `--memorize-paths` restricts which levels receive teach-signal updates (`titan`, `cms_fast`, or `all`).\n- `--memorize-surprise-threshold` gates updates on average teach-signal norm, matching the paper’s surprise trigger.\n\nMemorization metrics (baseline vs adaptive) are emitted alongside task accuracy for easy comparisons.\n\n## Architecture variants\nSelect the paper-defined variant via `model.block_variant` in Hydra configs:\n- `hope_attention` (paper HOPE-Attention): `Attention → CMS` (paper-defined).\n- `hope_selfmod` (paper HOPE scaffold): `Self-modifying Titans (Eqs. 83–93; Eq. 91 residual MLP memories) → CMS` with (by default) **fixed q** and **local conv window=4**, plus chunked updates via `model.self_mod_chunk_size` (others) and `model.self_mod_chunk_size_memory` (M_memory). See `docs/PAPER_COMPLIANCE.md` for the “differentiable read / update-pass writes” semantics.\n- `hope_hybrid` (legacy): `Attention + TitanMemory + CMS` (exploratory; not paper-defined).\n- `transformer` (baseline): `Attention → MLP` (no TITAN/CMS learning updates; useful for Phase 2 comparisons).\n\nSelf-modifying Titans knobs (ablation-friendly, paper-aligned):\n- `model.self_mod_objective` (`l2` vs `dot`), `model.self_mod_use_rank1_precond` (DGD-like preconditioner), `model.self_mod_use_alpha` (weight-decay/retention gate), `model.self_mod_stopgrad_vhat`, `model.self_mod_momentum`, `model.self_mod_adaptive_q`, `model.self_mod_local_conv_window`.\n\n## Fast state (Nested Learning semantics)\nIn-context updates can run against a per-context fast state so meta parameters never change:\n- `HOPEModel.init_fast_state()` / `TitanOnlyModel.init_fast_state()` returns a `ModelFastState`.\n- `MemorizeConfig.use_fast_state=true` (default) requires passing `fast_state` into `memorize_tokens()` / `memorize_sequence()`; evaluation scripts handle this automatically.\n- Training can also run update passes against a per-batch fast state via `train.use_fast_state=true` (meta+delta fast state: meta params are learnable; online updates write deltas only). If `data.batch_size>1`, CMS/TITAN fast state is shared across the batch; use `data.batch_size=1` for strict per-context semantics.\n\n
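A minimal CLI sketch pinning these semantics explicitly (flags as documented above; the mechanism-auditing presets already set `data.batch_size=1`, so the second override is redundant but harmless):\n\n```bash\nuv run nl train --config-name pilot_paper_faithful \\\n  --override train.use_fast_state=true \\\n  --override data.batch_size=1\n```\n\n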
See `docs/PAPER_COMPLIANCE.md`.\n\n## Releases\nBefore tagging or announcing a new checkpoint, work through:\n- `docs/release_checklist.md` (model/eval artifact release bundle)\n- `docs/PACKAGE_RELEASE_CHECKLIST.md` (package/GitHub/PyPI release flow)\n- `docs/PYPI_TRUSTED_PUBLISHING.md` (one-time OIDC setup for TestPyPI/PyPI)\n\nTag pushes (`v*`) automatically publish:\n- PyPI/TestPyPI package artifacts (via Trusted Publishing), and\n- a GitHub Release entry with wheel, sdist, and `SHA256SUMS.txt` in the Releases tab.\n- a GitHub Packages (GHCR) OCI bundle (`nested-learning-dist`) containing `dist/*`.\n\nGitHub Packages note:\n- The repo publishes an OCI artifact bundle to GHCR (shown under the Packages tab), not a Python package registry endpoint.\n- Python installs should still use PyPI (`pip install nested-learning`).\n\nExample (pull/extract dist artifacts from GHCR):\n```bash\ndocker pull ghcr.io/kmccleary3301/nested-learning-dist:latest\ncid=$(docker create ghcr.io/kmccleary3301/nested-learning-dist:latest)\ndocker cp \"$cid:/dist\" ./dist_from_ghcr\ndocker rm \"$cid\"\n```\n\nFor versioning semantics and breaking-change expectations, see `docs/VERSIONING_POLICY.md`.\n\nFor reproducibility bug reports, use `docs/BUG_REPORT_CHECKLIST.md`.\n\n## Performance & optimizer options\n- **Mixed precision:** enable bf16 autocast via `train.mixed_precision.enabled=true train.mixed_precision.dtype=bf16` (already enabled in pilot/mid/target configs).\n- **`torch.compile`:** accelerate attention/core loops by toggling `train.compile.enable=true train.compile.mode=max-autotune`; failure falls back to eager unless `train.compile.strict=true`.\n- **Muon hybrid (default):** all HOPE configs now set `optim.type=muon`, routing ≥2D tensors through PyTorch 2.9's Muon optimizer while embeddings/norms stay on AdamW. Training logs emit `optim.muon_param_elems` / `optim.adamw_param_elems` so you can confirm the split.\n- **Fused AdamW fallback:** override with `optim.type=adamw optim.fused=auto` if Muon is unavailable or if you want to compare against the AdamW ablation in `reports/ablations.md`.\n- **Surprise gating:** set `model.surprise_threshold=<float>` to gate all inner updates. By default the surprise metric is the average L2 norm of the (scaled/clipped) teach signal (`model.surprise_metric=l2`); you can also use `loss` or `logit_entropy` for ablations. Evaluation CLIs expose `--memorize-surprise-threshold` for ad-hoc gating.\n\nAll Hydra knobs can be overridden from the CLI or composed via config groups (`configs/hope/*.yaml`). 
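For example, a single training invocation setting the documented knobs explicitly (several are already defaults in the shipped configs; the values here are illustrative, not tuned recommendations):\n\n```bash\nuv run nl train --config-name pilot \\\n  --override train.mixed_precision.enabled=true \\\n  --override train.mixed_precision.dtype=bf16 \\\n  --override train.compile.enable=true \\\n  --override optim.type=muon \\\n  --override model.surprise_threshold=0.01\n```\n\n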
Use these flags in tandem with `scripts/run_e2e_smoke.sh` (automation) or `scripts/run_cpu_ddp_smoke.sh` (CPU-only determinism check) to validate releases quickly.\n\n## Documentation & References\n- `docs/IMPLEMENTATION_STATUS.md` – current mechanism-level status matrix.\n- `docs/PAPER_COMPLIANCE.md` – equation-to-code fidelity notes and explicit boundaries.\n- `docs/STREAMING_CONTRACT.md` – exact sequence/segment/chunk/update semantics.\n- `docs/release_checklist.md` – release readiness checklist.\n- `docs/data_pipeline.md` – large-scale sharding/tokenizer workflow.\n- `docs/scaling_guidance.md` – roadmap for expanding data + compute footprints.\n- `docs/stage2_plan.md` – Stage 2 architecture + experiment roadmap.\n- `docs/PHASE_2_PLAN.md` – detailed Phase 2 execution plan.\n- `docs/stage2_progress.md` – progress tracker for the latest faithfulness remediation sprint.\n- `docs/experiments_report.md` – draft paper covering completed experiments.\n- `docs/future_directions.md` – prioritized roadmap after the initial release.\n- `reports/stage2_smoke.md` – exact commands/artifacts for the release-ready smoke workflow.\n- `docs/FSDP_SCALING_GUIDE.md` – dual-RTX 6000 Ada instructions for the mid/target FSDP configs.\n- `google_papers/` – PDFs/markdown of Nested Learning & TITAN papers.\n- `CHANGELOG.md` – user-facing changes per release.\n\n## Contributing\n1. Run formatting/tests (`uv run ruff check .`, `uv run pytest`).\n2. Document new configs or scripts in the relevant docs under `docs/` and update `CHANGELOG.md`.\n3. Open a PR referencing the relevant NL/TITAN spec sections and tests.\n"
  },
  {
    "path": "TODO.md",
    "content": "# Project TODOs\n\n## Planner Finalization – P0 Foundation\n- [x] Add first-class package CLI (`nl`) with `doctor`, `smoke`, `train`, and `audit` commands.\n- [x] Support module entrypoint (`python -m nested_learning`).\n- [x] Register CLI script in `pyproject.toml` for pip/uv installs.\n- [x] Implement runtime capability detection and JSON doctor output.\n- [x] Add cross-platform smoke tests for CLI/config composition.\n- [x] Validate with lint + mypy + full pytest.\n\n## Planner Finalization – P1 Distribution/CI\n- [x] Relax package compatibility ranges (`python>=3.10`, `torch>=2.9,<3`) while keeping lockfile golden env.\n- [x] Split optional dependencies into extras (`gpu`, `logging`, `viz`) for lighter base installs.\n- [x] Add compatibility/support-tier documentation (`docs/COMPATIBILITY_MATRIX.md`).\n- [x] Add versioning/stability policy (`docs/VERSIONING_POLICY.md`).\n- [x] Add package release checklist (`docs/PACKAGE_RELEASE_CHECKLIST.md`).\n- [x] Expand CI with cross-platform smoke and wheel-install smoke lanes.\n- [x] Add release automation workflow (`.github/workflows/release.yml`) for tag-based TestPyPI/PyPI publish.\n- [x] Update README to pip-first install + compatibility/versioning links + CLI usage.\n\n## Stage 2 – Results Reproduction\n- [ ] **Data Engineering**\n  - [ ] Acquire RefinedWeb + supplement corpora under `data/raw/`.\n  - [x] Implement filtering/dedup scripts (language ID, length bounds).\n  - [x] Run `scripts/data/train_tokenizer.py` on combined corpus and store tokenizer artifacts.\n  - [x] Shard each corpus component with `scripts/data/process_mixture.py`; log mixture stats.\n  - [x] Automate `sample` and `full` pipelines via `scripts/data/run_sample.sh` / `scripts/data/run_full.sh`.\n- [ ] **Infrastructure & Configs**\n  - [x] Build Hydra config tree (`configs/hope/`) for pilot/mid/target, including optimizer + level schedules.\n  - [x] Integrate logging (W&B/MLflow) hooks into training loop and configs.\n  - [x] Provide DeepSpeed + FSDP launcher scripts with resume support.\n  - [x] Add CI workflow (`.github/workflows/ci.yml`) for lint/type/tests via `uv`.\n- [ ] **Scaling Training**\n  - [x] Run pilot (160 M, 3 B tokens) to validate pipeline + self-mod updates. *(Step 230 k packaged 13 Nov; resume after TITAN baseline catches up.)*\n  - [ ] Scale to 760 M / 30 B tokens; capture checkpoints + metrics. *(100-step mid run stable; longer runs waiting on teach-scale tuning + compute.)*\n  - [ ] Execute 1.3 B / 100 B training with long-context curriculum.\n- [ ] **Evaluation Harness**\n  - [x] Implement `scripts/eval/zeroshot.py` scaffolding (PIQA baseline).\n  - [x] Extend zero-shot harness to cover PIQA/HellaSwag/WinoGrande/ARC-E/C/BoolQ/SIQA/CommonsenseQA/OpenBookQA and document usage.\n  - [x] Build NIAH long-context scaffolding script (`scripts/eval/niah.py`).\n  - [x] Add continual-learning scripts measuring forgetting over streaming domains.\n  - [x] Capture Stage 2 eval packs (zeroshot/NIAH/continual) from pilot checkpoints once stable (step 230 k release).\n- [ ] **Ablations & Analysis**\n  - [x] Run teach-scale sweep (0.05/0.10/0.15) on pilot checkpoints. *(0.05 & 0.15 short + 25 k long runs logged; see `logs/pilot-teach05-20251114010549.json` and `logs/pilot-teach15-long-20251114185448.json`.)*\n  - [x] Run self-modifier off/on comparison at pilot scale.\n  - [ ] Test CMS depth variations and optimizer variants.\n  - [ ] Compare attention backbones (full vs. sliding vs. 
DeltaNet).\n- [ ] **Baseline Monitoring**\n  - [x] Finish TITAN long run (25 k steps, `cuda:0`, TMPDIR `/mnt/drive_4/tmp_titan`) and mirror HOPE packaging/eval workflow.\n- [ ] **Documentation & Release**\n  - [ ] Maintain experiment logs under `reports/`.\n  - [ ] Publish data pipeline instructions + provenance for each corpus.\n  - [ ] Summarize final metrics vs. baselines in Stage 2 report.\n\n## Immediate Sprint Focus (Nov 15)\n- [x] Design CMS sparse-chunk ablation config that stays within 49 GB (dim 384, seq 1024, batch 2, update periods 8/32/128/512).\n- [x] Run CMS sparse-chunk experiment, package checkpoint (`artifacts/checkpoints/pilot_cms_sparse/step_005000.pt`), and produce evals (`eval/*_pilot_cms_sparse_step5000.json`).\n- [x] Launch optimizer ablation comparing Muon hybrid vs fused AdamW on pilot-scale smoke (5–10 k steps) and archive eval metrics.\n- [x] Roll the new CMS + optimizer findings into `reports/ablations.md`, `docs/stage2_progress.md`, and outline the resulting Stage 2 training plan updates.\n\n## Planner Follow-up (P2)\n- [x] Manifest validation report (`scripts/data/validate_mixture.py`) + token overlap stats.\n- [x] Tokenizer coverage JSON via `scripts/data/check_tokenizer_coverage.py` + regression guard (`scripts/checks/tokenizer_coverage_guard.py`).\n- [x] Extend eval suite with passkey, PG‑19, and continual forgetting plots (see `scripts/eval/run_pilot_suite.sh` + `reports/plots/` output).\n- [x] Generate long-context/continual eval artifacts for pilot & TITAN checkpoints (`eval/passkey_*`, `eval/pg19_*`, `eval/continual_*`).\n- [x] Fill checkpoint reports (`reports/checkpoints/pilot_step230000.md`, `.../titan_step25000.md`, `.../pilot_teach05_long.md`, CMS variants, optimizer ablations, self-mod off).\n- [x] Run the same reporting workflow for future checkpoints (teach15 long, CMS sparse/no chunk, optimizer ablations) before publishing.\n\n## Planner Follow-up (P1)\n- [x] Make Muon the default outer optimizer (pilot/mid/target configs), log Muon vs AdamW param counts, and confirm bf16/SDPA/compile flags in training logs.\n- [x] Finalize FSDP/ZeRO configs for 760 M / 1.3 B (with grad checkpointing + VRAM notes) and document usage.\n- [x] Implement atomic checkpoint sidecars (SHA256, RNG state, tokenizer hash) plus a strict `scripts/checkpoint/verify.py`.\n- [x] Extend CI with CPU DDP determinism smoke + synthetic passkey memorization test.\n\n## Stage 2 – Execution Sprint (Nov 17)\n- [x] Relaunch HOPE pilot run on `cuda:1` (Muon + surprise gating) and produce fresh checkpoints/logs.\n  - Status (Jan 9): relaunch stopped at `artifacts/checkpoints/pilot_relaunch/step_477000.pt` and verified via `scripts/checkpoint/verify.py`.\n- [x] Package the new pilot checkpoint via `scripts/package_pilot_release.sh` and rerun the full eval suite (zeroshot/NIAH/continual/passkey/PG19) with memorize path/threshold metadata.\n  - Done: `reports/checkpoints/pilot_relaunch_step477000.md` + `eval/*_pilot.json` and refreshed `artifacts/pilot_release/`.\n- [x] Restart TITAN long baseline, mirror the eval suite, and record surprise gating stats.\n  - Status (Jan 9): packaged + evaluated `artifacts/checkpoints/mid_titan_long/step_032000.pt` (see `reports/checkpoints/titan_long_step32000.md` and `eval/*_titan.json`).\n- [ ] Run the mid-scale FSDP config (`configs/hope/mid_fsdp.yaml`), monitor VRAM, and archive checkpoints/logs.\n  - Status (Jan 10): 2×GPU FSDP smoke runs (synthetic) complete, including update pass and checkpoint saving (FSDP ranks now all participate in 
FULL_STATE_DICT gathering).\n- [x] Update `reports/checkpoints/` + `reports/ablations.md` with the new HOPE/TITAN results (include memorize paths/surprise thresholds).\n- [x] Refresh `docs/stage2_progress.md`, `docs/experiments_report.md`, and `docs/stage2_plan.md` with the latest execution status and next scaling steps.\n"
  },
  {
    "path": "configs/ablations/cms_sparse.yaml",
    "content": "defaults:\n  - /pilot\n  - _self_\n\nmodel:\n  dim: 384\n  num_layers: 8\n  heads: 6\n  titan_level:\n    name: titan\n    update_period: 8\n    optimizer_key: titan_opt\n  cms_hidden_multiplier: 2\n  cms_levels:\n    - name: cms_fast\n      update_period: 8\n      optimizer_key: cms_opt\n    - name: cms_mid\n      update_period: 32\n      optimizer_key: cms_opt\n    - name: cms_slow\n      update_period: 128\n      optimizer_key: cms_opt\n    - name: cms_ultra\n      update_period: 512\n      optimizer_key: cms_opt\n\ndata:\n  seq_len: 1024\n  batch_size: 2\n  num_workers: 2\n\ntrain:\n  online_updates: true\n  online_chunk_size: 0\n  per_layer_teach_signal: true\n  steps: 5000\n  device: \"cuda:1\"\n  checkpoint:\n    dir: artifacts/checkpoints/pilot_cms_sparse\n    save_interval: 1000\n  log_interval: 25\n\nlogging:\n  path: logs/pilot_cms_sparse_metrics.json\n  run_name: pilot-cms-sparse\n"
  },
  {
    "path": "configs/ablations/selfmod_chunked_8_64.yaml",
    "content": "defaults:\n  - /pilot\n  - _self_\n\nmodel:\n  block_variant: hope_selfmod\n  self_mod_chunk_size: 8\n  self_mod_chunk_size_memory: 64\n\ntrain:\n  online_updates: true\n  online_chunk_size: 0\n  per_layer_teach_signal: true\n  steps: 5000\n  device: \"cuda:1\"\n  checkpoint:\n    dir: artifacts/checkpoints/pilot_selfmod_chunked_8_64\n    save_interval: 1000\n\nlogging:\n  enabled: true\n  backend: json\n  path: logs/pilot_selfmod_chunked_8_64_metrics.json\n  run_name: pilot-selfmod-chunked-8-64\n"
  },
  {
    "path": "configs/ablations/selfmod_momentum_off.yaml",
    "content": "defaults:\n  - /pilot\n  - _self_\n\nmodel:\n  block_variant: hope_selfmod\n  self_mod_momentum: 0.0\n\ntrain:\n  online_updates: true\n  online_chunk_size: 0\n  per_layer_teach_signal: true\n  steps: 5000\n  device: \"cuda:1\"\n  checkpoint:\n    dir: artifacts/checkpoints/pilot_selfmod_momentum_off\n    save_interval: 1000\n\nlogging:\n  enabled: true\n  backend: json\n  path: logs/pilot_selfmod_momentum_off_metrics.json\n  run_name: pilot-selfmod-momentum-off\n"
  },
  {
    "path": "configs/ablations/selfmod_momentum_on.yaml",
    "content": "defaults:\n  - /pilot\n  - _self_\n\nmodel:\n  block_variant: hope_selfmod\n  self_mod_momentum: 0.9\n\ntrain:\n  online_updates: true\n  online_chunk_size: 0\n  per_layer_teach_signal: true\n  steps: 5000\n  device: \"cuda:1\"\n  checkpoint:\n    dir: artifacts/checkpoints/pilot_selfmod_momentum_on\n    save_interval: 1000\n\nlogging:\n  enabled: true\n  backend: json\n  path: logs/pilot_selfmod_momentum_on_metrics.json\n  run_name: pilot-selfmod-momentum-on\n"
  },
  {
    "path": "configs/ablations/selfmod_no_alpha.yaml",
    "content": "defaults:\n  - /pilot\n  - _self_\n\nmodel:\n  block_variant: hope_selfmod\n  self_mod_use_alpha: false\n\ntrain:\n  online_updates: true\n  online_chunk_size: 0\n  per_layer_teach_signal: true\n  steps: 5000\n  device: \"cuda:1\"\n  checkpoint:\n    dir: artifacts/checkpoints/pilot_selfmod_no_alpha\n    save_interval: 1000\n\nlogging:\n  enabled: true\n  backend: json\n  path: logs/pilot_selfmod_no_alpha_metrics.json\n  run_name: pilot-selfmod-no-alpha\n"
  },
  {
    "path": "configs/ablations/selfmod_no_cms.yaml",
    "content": "defaults:\n  - /pilot\n  - _self_\n\nmodel:\n  block_variant: hope_selfmod\n  cms_levels: []\n\ntrain:\n  online_updates: true\n  online_chunk_size: 0\n  per_layer_teach_signal: true\n  steps: 5000\n  device: \"cuda:1\"\n  checkpoint:\n    dir: artifacts/checkpoints/pilot_selfmod_no_cms\n    save_interval: 1000\n\nlogging:\n  enabled: true\n  backend: json\n  path: logs/pilot_selfmod_no_cms_metrics.json\n  run_name: pilot-selfmod-no-cms\n"
  },
  {
    "path": "configs/ablations/selfmod_rank1_precond_off.yaml",
    "content": "defaults:\n  - /pilot\n  - _self_\n\nmodel:\n  block_variant: hope_selfmod\n  self_mod_use_rank1_precond: false\n\ntrain:\n  online_updates: true\n  online_chunk_size: 0\n  per_layer_teach_signal: true\n  steps: 5000\n  device: \"cuda:1\"\n  checkpoint:\n    dir: artifacts/checkpoints/pilot_selfmod_rank1_off\n    save_interval: 1000\n\nlogging:\n  enabled: true\n  backend: json\n  path: logs/pilot_selfmod_rank1_off_metrics.json\n  run_name: pilot-selfmod-rank1-off\n"
  },
  {
    "path": "configs/data/continual_segments_sample.yaml",
    "content": "segments:\n  - name: refinedweb_2018\n    shards_dir: data/shards/refinedweb_sample\n  - name: wikipedia_sample\n    shards_dir: data/shards/wikipedia_sample\n  - name: c4_sample\n    shards_dir: data/shards/c4_sample\n  - name: redpajama_sample\n    shards_dir: data/shards/redpajama_sample\n"
  },
  {
    "path": "configs/data/fineweb_edu_longdoc_filtered_sample.yaml",
    "content": "name: fineweb_edu_longdoc_filtered_sample\ntokenizer_output_dir: artifacts/tokenizer/fineweb_edu_longdoc\ndatasets:\n  - name: fineweb_edu_longdoc\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/fineweb_edu_longdoc_en_sample.txt\n    sample_limit: 5000\n    seq_len: 4096\n    sequences_per_shard: 1024\n    output_dir: data/shards/fineweb_edu_longdoc_sample\n    max_records: null\n\n"
  },
  {
    "path": "configs/data/fineweb_edu_mixture_full.yaml",
    "content": "name: fineweb_edu_full\ntokenizer_output_dir: artifacts/tokenizer/fineweb_edu\ndatasets:\n  - name: fineweb_edu\n    dataset: HuggingFaceFW/fineweb-edu\n    subset: sample-100BT\n    split: train\n    text_column: text\n    sample_limit: 100000\n    seq_len: 4096\n    sequences_per_shard: 1024\n    output_dir: data/shards/fineweb_edu_full\n    max_records: null\n\n"
  },
  {
    "path": "configs/data/fineweb_edu_mixture_sample.yaml",
    "content": "name: fineweb_edu_sample\ntokenizer_output_dir: artifacts/tokenizer/fineweb_edu\ndatasets:\n  - name: fineweb_edu\n    dataset: HuggingFaceFW/fineweb-edu\n    subset: sample-10BT\n    split: train\n    text_column: text\n    sample_limit: 5000\n    seq_len: 2048\n    sequences_per_shard: 1024\n    output_dir: data/shards/fineweb_edu_sample\n    max_records: 10000\n\n"
  },
  {
    "path": "configs/data/refinedweb_mixture.yaml",
    "content": "name: refinedweb_mix_v1\ntokenizer_output_dir: artifacts/tokenizer/refinedweb_mix\ndatasets:\n  - name: refinedweb\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/refinedweb_en_full.txt\n    seq_len: 2048\n    sequences_per_shard: 2048\n    output_dir: data/shards/refinedweb\n    max_records: null\n  - name: books\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/wikipedia_en_full.txt\n    seq_len: 2048\n    sequences_per_shard: 2048\n    output_dir: data/shards/wikipedia\n    max_records: null\n  - name: c4\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/c4_en_full.txt\n    seq_len: 2048\n    sequences_per_shard: 2048\n    output_dir: data/shards/c4\n    max_records: null\n  - name: redpajama\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/redpajama_en_full.txt\n    seq_len: 2048\n    sequences_per_shard: 2048\n    output_dir: data/shards/redpajama\n    max_records: null\n  - name: code\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/code_en_full.txt\n    seq_len: 2048\n    sequences_per_shard: 2048\n    output_dir: data/shards/code\n    max_records: null\n"
  },
  {
    "path": "configs/data/refinedweb_mixture_filtered.yaml",
    "content": "name: refinedweb_mix_filtered\ntokenizer_output_dir: artifacts/tokenizer/refinedweb_mix\ndatasets:\n  - name: refinedweb\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/refinedweb_en_sample.txt\n    seq_len: 512\n    sequences_per_shard: 512\n    output_dir: data/shards/refinedweb_filtered\n    max_records: null\n  - name: wikipedia\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/wikipedia_en_sample.txt\n    seq_len: 512\n    sequences_per_shard: 512\n    output_dir: data/shards/wikipedia_filtered\n    max_records: null\n  - name: c4\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/c4_en_sample.txt\n    seq_len: 512\n    sequences_per_shard: 512\n    output_dir: data/shards/c4_filtered\n    max_records: null\n  - name: redpajama\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/redpajama_en_sample.txt\n    seq_len: 512\n    sequences_per_shard: 512\n    output_dir: data/shards/redpajama_filtered\n    max_records: null\n  - name: code\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/code_en_sample.txt\n    seq_len: 512\n    sequences_per_shard: 512\n    output_dir: data/shards/code_filtered\n    max_records: null\n"
  },
  {
    "path": "configs/data/refinedweb_mixture_full.yaml",
    "content": "name: refinedweb_mix_full\ntokenizer_output_dir: artifacts/tokenizer/refinedweb_mix\ndatasets:\n  - name: refinedweb\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/refinedweb_en_full.txt\n    seq_len: 2048\n    sequences_per_shard: 1024\n    output_dir: data/shards/refinedweb_full\n    max_records: null\n  - name: wikipedia\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/wikipedia_en_full.txt\n    seq_len: 2048\n    sequences_per_shard: 1024\n    output_dir: data/shards/wikipedia_full\n    max_records: null\n  - name: c4\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/c4_en_full.txt\n    seq_len: 2048\n    sequences_per_shard: 1024\n    output_dir: data/shards/c4_full\n    max_records: null\n  - name: redpajama\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/redpajama_en_full.txt\n    seq_len: 2048\n    sequences_per_shard: 1024\n    output_dir: data/shards/redpajama_full\n    max_records: null\n  - name: code\n    dataset: text\n    split: train\n    text_column: text\n    data_files: data/filtered/code_en_full.txt\n    seq_len: 2048\n    sequences_per_shard: 1024\n    output_dir: data/shards/code_full\n    max_records: null\n"
  },
  {
    "path": "configs/data/refinedweb_mixture_sample.yaml",
    "content": "name: refinedweb_mix_sample\ntokenizer_output_dir: artifacts/tokenizer/refinedweb_mix\ndatasets:\n  - name: refinedweb\n    dataset: HuggingFaceFW/fineweb\n    subset: sample-10BT\n    split: train\n    text_column: text\n    sample_limit: 5000\n    seq_len: 512\n    sequences_per_shard: 512\n    output_dir: data/shards/refinedweb_sample\n    max_records: 10000\n  - name: books\n    dataset: wikimedia/wikipedia\n    subset: 20231101.en\n    split: train\n    text_column: text\n    sample_limit: 2000\n    seq_len: 512\n    sequences_per_shard: 512\n    output_dir: data/shards/wikipedia_sample\n    max_records: 5000\n  - name: c4\n    dataset: allenai/c4\n    subset: en\n    split: train\n    text_column: text\n    sample_limit: 2000\n    seq_len: 512\n    sequences_per_shard: 512\n    output_dir: data/shards/c4_sample\n    max_records: 4000\n  - name: redpajama\n    dataset: cerebras/SlimPajama-627B\n    split: train\n    text_column: text\n    sample_limit: 2000\n    seq_len: 512\n    sequences_per_shard: 512\n    output_dir: data/shards/redpajama_sample\n    max_records: 4000\n  - name: code\n    dataset: codeparrot/codeparrot-clean-train\n    split: train\n    text_column: content\n    sample_limit: 2000\n    seq_len: 512\n    sequences_per_shard: 512\n    output_dir: data/shards/code_sample\n    max_records: 4000\n"
  },
  {
    "path": "configs/deepspeed/zero3.json",
    "content": "{\n  \"bf16\": {\n    \"enabled\": true\n  },\n  \"train_batch_size\": 64,\n  \"gradient_accumulation_steps\": 1,\n  \"zero_optimization\": {\n    \"stage\": 3,\n    \"reduce_bucket_size\": 50000000,\n    \"stage3_param_persistence_threshold\": 100000,\n    \"stage3_prefetch_bucket_size\": 50000000\n  },\n  \"optimizer\": {\n    \"type\": \"AdamW\",\n    \"params\": {\n      \"lr\": 0.0002,\n      \"betas\": [\n        0.9,\n        0.95\n      ],\n      \"eps\": 1e-08,\n      \"weight_decay\": 0.01\n    }\n  }\n}\n"
  },
  {
    "path": "configs/hope/mid.yaml",
    "content": "defaults:\n  - _self_\n\nhydra:\n  run:\n    dir: .\n  output_subdir: null\n  job:\n    chdir: false\n\nmodel:\n  vocab_size: 32000\n  dim: 1024\n  num_layers: 24\n  heads: 16\n  surprise_threshold: null\n  freeze_backbone: false\n  titan_level:\n    name: titan\n    update_period: 16\n    optimizer_key: titan_opt\n  cms_levels:\n    - name: cms_fast\n      update_period: 1\n      optimizer_key: cms_opt\n    - name: cms_mid\n      update_period: 4\n      optimizer_key: cms_opt\n    - name: cms_slow\n      update_period: 32\n      optimizer_key: cms_opt\n    - name: cms_ultra\n      update_period: 128\n      optimizer_key: cms_opt\n  optimizers:\n    titan_opt:\n      type: deep_momentum\n      lr: 8.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        variant: nl_l2_precond\n    cms_opt:\n      type: deep_momentum\n      lr: 4.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        variant: nl_l2_precond\n\ndata:\n  source: mixture\n  batch_size: 16\n  num_workers: 4\n  mixture:\n    samples_per_epoch: 8192\n    seed: 42\n    sources:\n      - name: refinedweb\n        shards_dir: data/shards/refinedweb_full\n        weight: 0.4\n      - name: wikipedia\n        shards_dir: data/shards/wikipedia_full\n        weight: 0.2\n      - name: c4\n        shards_dir: data/shards/c4_full\n        weight: 0.15\n      - name: redpajama\n        shards_dir: data/shards/redpajama_full\n        weight: 0.15\n      - name: code\n        shards_dir: data/shards/code_full\n        weight: 0.1\n\ntrain:\n  strict_streaming_contract: false\n  online_updates: true\n  online_chunk_size: 0\n  online_boundary_targets: false\n  online_carry_attention_cache: false\n  per_layer_teach_signal: true\n  steps: 100\n  log_interval: 10\n  device: \"cuda:1\"\n  seed: 808\n  deterministic: false\n  step_offset: 0\n  mixed_precision:\n    enabled: true\n    dtype: bf16\n  compile:\n    enable: true\n    mode: max-autotune\n  fsdp:\n    auto_wrap_min_params: 2000000\n    cpu_offload: false\n  checkpoint:\n    enable: true\n    dir: checkpoints/mid\n    save_interval: 50\n    resume_path: null\n    resume_tag: null\n\noptim:\n  type: muon\n  lr: 2.0e-4\n  weight_decay: 0.02\n  momentum: 0.95\n  betas:\n    - 0.9\n    - 0.999\n\nlogging:\n  enabled: false\n  backend: wandb\n  project: nested-learning\n  run_name: mid-${now:%Y%m%d%H%M%S}\n  path: logs/mid_metrics.json\n"
  },
  {
    "path": "configs/hope/mid_fsdp.yaml",
    "content": "defaults:\n  - mid\n  - _self_\n\nmodel:\n  gradient_checkpointing: true\n\ndata:\n  batch_size: 8  # per-rank micro-batch for 2× RTX 6000 Ada\n  num_workers: 6\n\ntrain:\n  strict_streaming_contract: false\n  online_updates: true\n  online_chunk_size: 0\n  online_boundary_targets: false\n  online_carry_attention_cache: false\n  per_layer_teach_signal: true\n  steps: 250000\n  log_interval: 20\n  device: \"cuda\"\n  mixed_precision:\n    enabled: true\n    dtype: bf16\n  compile:\n    enable: false\n  fsdp:\n    auto_wrap_min_params: 2000000\n    cpu_offload: false\n  checkpoint:\n    enable: true\n    dir: artifacts/checkpoints/mid_fsdp\n    save_interval: 1000\n    resume_path: null\n    resume_tag: null\n\noptim:\n  type: muon\n  lr: 2.0e-4\n  weight_decay: 0.01\n\nlogging:\n  enabled: true\n  backend: wandb\n  project: nested-learning\n  run_name: hope-mid-fsdp-${now:%Y%m%d%H%M%S}\n  path: logs/mid_fsdp_metrics.json\n"
  },
  {
    "path": "configs/hope/pilot.yaml",
    "content": "defaults:\n  - /pilot\n"
  },
  {
    "path": "configs/hope/pilot_attention.yaml",
    "content": "defaults:\n  - /pilot\n  - _self_\n\nmodel:\n  block_variant: hope_attention\n  qk_l2_norm: true\n  local_conv_window: 4\n\n"
  },
  {
    "path": "configs/hope/pilot_selfmod.yaml",
    "content": "defaults:\n  - /pilot\n  - _self_\n\nmodel:\n  block_variant: hope_selfmod\n  # Chunk update cadence (paper §8.2): other memories update more often than M_memory.\n  self_mod_chunk_size: 8\n  self_mod_chunk_size_memory: 64\n\ntrain:\n  online_updates: true\n  online_chunk_size: 0\n  per_layer_teach_signal: true\n  checkpoint:\n    dir: artifacts/checkpoints/pilot_selfmod\n\nlogging:\n  run_name: pilot-selfmod\n  path: logs/pilot_selfmod_metrics.json\n"
  },
  {
    "path": "configs/hope/pilot_transformer.yaml",
    "content": "defaults:\n  - /pilot\n  - _self_\n\nmodel:\n  block_variant: transformer\n  qk_l2_norm: true\n  local_conv_window: 4\n\n"
  },
  {
    "path": "configs/hope/target.yaml",
    "content": "defaults:\n  - _self_\n\nhydra:\n  run:\n    dir: .\n  output_subdir: null\n  job:\n    chdir: false\n\nmodel:\n  vocab_size: 32000\n  dim: 1536\n  num_layers: 32\n  heads: 24\n  surprise_threshold: null\n  freeze_backbone: false\n  titan_level:\n    name: titan\n    update_period: 32\n    optimizer_key: titan_opt\n  cms_levels:\n    - name: cms_fast\n      update_period: 1\n      optimizer_key: cms_fast_opt\n    - name: cms_mid\n      update_period: 4\n      optimizer_key: cms_mid_opt\n    - name: cms_slow\n      update_period: 32\n      optimizer_key: cms_slow_opt\n    - name: cms_ultra\n      update_period: 128\n      optimizer_key: cms_slow_opt\n    - name: cms_anchor\n      update_period: 512\n      optimizer_key: cms_anchor_opt\n  optimizers:\n    titan_opt:\n      type: deep_momentum\n      lr: 6.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        variant: nl_l2_precond\n    cms_fast_opt:\n      type: deep_momentum\n      lr: 3.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        variant: nl_l2_precond\n    cms_mid_opt:\n      type: deep_momentum\n      lr: 2.5e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        variant: nl_l2_precond\n    cms_slow_opt:\n      type: deep_momentum\n      lr: 2.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        variant: nl_l2_precond\n    cms_anchor_opt:\n      type: deep_momentum\n      lr: 1.5e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        variant: nl_l2_precond\n\ndata:\n  source: mixture\n  batch_size: 32\n  num_workers: 8\n  mixture:\n    samples_per_epoch: 32768\n    seed: 123\n    sources:\n      - name: refinedweb\n        shards_dir: data/shards/refinedweb_filtered\n        weight: 0.35\n      - name: wikipedia\n        shards_dir: data/shards/wikipedia_filtered\n        weight: 0.2\n      - name: c4\n        shards_dir: data/shards/c4_filtered\n        weight: 0.15\n      - name: redpajama\n        shards_dir: data/shards/redpajama_filtered\n        weight: 0.2\n      - name: code\n        shards_dir: data/shards/code_filtered\n        weight: 0.1\n\ntrain:\n  strict_streaming_contract: false\n  online_updates: true\n  online_chunk_size: 0\n  online_boundary_targets: false\n  online_carry_attention_cache: false\n  per_layer_teach_signal: true\n  steps: 200\n  log_interval: 10\n  device: \"cuda:1\"\n  seed: 9001\n  deterministic: false\n  step_offset: 0\n  mixed_precision:\n    enabled: true\n    dtype: bf16\n  compile:\n    enable: true\n    mode: max-autotune\n  fsdp:\n    auto_wrap_min_params: 2000000\n    cpu_offload: false\n  checkpoint:\n    enable: true\n    dir: checkpoints/target\n    save_interval: 100\n    resume_path: null\n    resume_tag: null\n\noptim:\n  type: muon\n  lr: 1.5e-4\n  weight_decay: 0.02\n  momentum: 0.95\n  betas:\n    - 0.9\n    - 0.999\n\nlogging:\n  enabled: false\n  backend: wandb\n  project: nested-learning\n  run_name: target-${now:%Y%m%d%H%M%S}\n  path: logs/target_metrics.json\n\ndeepspeed:\n  config: configs/deepspeed/zero3.json\n"
  },
  {
    "path": "configs/hope/target_fsdp.yaml",
    "content": "defaults:\n  - target\n  - _self_\n\nmodel:\n  gradient_checkpointing: true\n\ndata:\n  batch_size: 4  # per-rank micro-batch\n  num_workers: 8\n\ntrain:\n  strict_streaming_contract: false\n  online_updates: true\n  online_chunk_size: 0\n  online_boundary_targets: false\n  online_carry_attention_cache: false\n  per_layer_teach_signal: true\n  steps: 300000\n  log_interval: 20\n  device: \"cuda\"\n  mixed_precision:\n    enabled: true\n    dtype: bf16\n  compile:\n    enable: false\n  fsdp:\n    auto_wrap_min_params: 2500000\n    cpu_offload: false\n  checkpoint:\n    enable: true\n    dir: artifacts/checkpoints/target_fsdp\n    save_interval: 1000\n    resume_path: null\n    resume_tag: null\n\noptim:\n  type: muon\n  lr: 1.5e-4\n  weight_decay: 0.01\n\nlogging:\n  enabled: true\n  backend: wandb\n  project: nested-learning\n  run_name: hope-target-fsdp-${now:%Y%m%d%H%M%S}\n  path: logs/target_fsdp_metrics.json\n"
  },
  {
    "path": "configs/mid_smoke.yaml",
    "content": "hydra:\n  run:\n    dir: .\n  output_subdir: null\n  job:\n    chdir: false\n\nmodel:\n  vocab_size: 32000\n  dim: 256\n  num_layers: 4\n  heads: 8\n  titan_level:\n    name: titan\n    update_period: 16\n    optimizer_key: titan_opt\n  cms_levels:\n    - name: cms_fast\n      update_period: 1\n      optimizer_key: cms_opt\n    - name: cms_mid\n      update_period: 4\n      optimizer_key: cms_opt\n    - name: cms_slow\n      update_period: 16\n      optimizer_key: cms_opt\n    - name: cms_ultra\n      update_period: 64\n      optimizer_key: cms_opt\n  optimizers:\n    titan_opt:\n      type: deep_momentum\n      lr: 8.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n    cms_opt:\n      type: deep_momentum\n      lr: 4.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n\ndata:\n  source: mixture\n  batch_size: 4\n  num_workers: 0\n  mixture:\n    samples_per_epoch: 128\n    seed: 0\n    sources:\n      - name: refinedweb\n        shards_dir: data/shards/refinedweb_filtered\n        weight: 0.4\n      - name: wikipedia\n        shards_dir: data/shards/wikipedia_filtered\n        weight: 0.2\n      - name: c4\n        shards_dir: data/shards/c4_filtered\n        weight: 0.15\n      - name: redpajama\n        shards_dir: data/shards/redpajama_filtered\n        weight: 0.15\n      - name: code\n        shards_dir: data/shards/code_filtered\n        weight: 0.1\n\ntrain:\n  strict_streaming_contract: false\n  online_updates: true\n  online_chunk_size: 0\n  online_boundary_targets: false\n  online_carry_attention_cache: false\n  per_layer_teach_signal: true\n  steps: 10\n  log_interval: 1\n  device: \"cpu\"\n  seed: 2024\n  deterministic: true\n  mixed_precision:\n    enabled: false\n    dtype: bf16\n  compile:\n    enable: false\n  checkpoint:\n    enable: true\n    dir: artifacts/checkpoints/mid_smoke\n    save_interval: 10\n    save_last: true\n\noptim:\n  type: adamw\n  lr: 2.0e-4\n  fused: false\n\nlogging:\n  enabled: true\n  backend: json\n  path: logs/mid_smoke.json\n"
  },
  {
    "path": "configs/mid_stage2.yaml",
    "content": "hydra:\n  run:\n    dir: .\n  output_subdir: null\n  job:\n    chdir: false\n\nmodel:\n  vocab_size: 32000\n  dim: 768\n  num_layers: 18\n  heads: 12\n  teach_scale: 0.05\n  teach_clip: 5.0\n  teach_schedule:\n    warmup_steps: 20\n    decay_start: 80\n    decay_duration: 40\n  titan_level:\n    name: titan\n    update_period: 16\n    optimizer_key: titan_opt\n  cms_levels:\n    - name: cms_fast\n      update_period: 1\n      optimizer_key: cms_opt\n    - name: cms_mid\n      update_period: 4\n      optimizer_key: cms_opt\n    - name: cms_slow\n      update_period: 32\n      optimizer_key: cms_opt\n    - name: cms_ultra\n      update_period: 128\n      optimizer_key: cms_opt\n  optimizers:\n    titan_opt:\n      type: deep_momentum\n      lr: 8.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n    cms_opt:\n      type: deep_momentum\n      lr: 4.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n\ndata:\n  source: mixture\n  batch_size: 8\n  num_workers: 2\n  mixture:\n    samples_per_epoch: 1024\n    seed: 42\n    sources:\n      - name: refinedweb\n        shards_dir: data/shards/refinedweb_full\n        weight: 0.4\n      - name: wikipedia\n        shards_dir: data/shards/wikipedia_full\n        weight: 0.2\n      - name: c4\n        shards_dir: data/shards/c4_full\n        weight: 0.15\n      - name: redpajama\n        shards_dir: data/shards/redpajama_full\n        weight: 0.15\n      - name: code\n        shards_dir: data/shards/code_full\n        weight: 0.1\n\ntrain:\n  strict_streaming_contract: false\n  online_updates: true\n  online_chunk_size: 0\n  online_boundary_targets: false\n  online_carry_attention_cache: false\n  per_layer_teach_signal: true\n  steps: 100\n  log_interval: 10\n  device: \"cuda\"\n  seed: 3401\n  deterministic: false\n  mixed_precision:\n    enabled: true\n    dtype: bf16\n  compile:\n    enable: true\n    mode: max-autotune\n  fsdp:\n    auto_wrap_min_params: 2000000\n    cpu_offload: false\n  checkpoint:\n    enable: true\n    dir: artifacts/checkpoints/mid_stage2\n    save_interval: 100\n    resume_path: null\n    resume_tag: null\n\noptim:\n  type: adamw\n  lr: 3.0e-5\n  fused: auto\n\nlogging:\n  enabled: true\n  backend: json\n  path: logs/mid_stage2.json\n"
  },
  {
    "path": "configs/mid_stage2_smoke.yaml",
    "content": "hydra:\n  run:\n    dir: .\n  output_subdir: null\n  job:\n    chdir: false\n\nmodel:\n  vocab_size: 32000\n  dim: 512\n  num_layers: 12\n  heads: 8\n  teach_scale: 0.2\n  teach_clip: 2.0\n  titan_level:\n    name: titan\n    update_period: 16\n    optimizer_key: titan_opt\n  cms_levels:\n    - name: cms_fast\n      update_period: 1\n      optimizer_key: cms_opt\n    - name: cms_mid\n      update_period: 4\n      optimizer_key: cms_opt\n    - name: cms_slow\n      update_period: 16\n      optimizer_key: cms_opt\n  optimizers:\n    titan_opt:\n      type: deep_momentum\n      lr: 6.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n    cms_opt:\n      type: deep_momentum\n      lr: 3.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n\ndata:\n  source: mixture\n  batch_size: 8\n  num_workers: 2\n  mixture:\n    samples_per_epoch: 512\n    seed: 0\n    sources:\n      - name: refinedweb\n        shards_dir: data/shards/refinedweb_filtered\n        weight: 0.4\n      - name: wikipedia\n        shards_dir: data/shards/wikipedia_filtered\n        weight: 0.2\n      - name: c4\n        shards_dir: data/shards/c4_filtered\n        weight: 0.15\n      - name: redpajama\n        shards_dir: data/shards/redpajama_filtered\n        weight: 0.15\n      - name: code\n        shards_dir: data/shards/code_filtered\n        weight: 0.1\n\ntrain:\n  strict_streaming_contract: false\n  online_updates: true\n  online_chunk_size: 0\n  online_boundary_targets: false\n  online_carry_attention_cache: false\n  per_layer_teach_signal: true\n  steps: 60\n  log_interval: 5\n  device: \"cuda\"\n  seed: 777\n  deterministic: false\n  mixed_precision:\n    enabled: true\n    dtype: bf16\n  compile:\n    enable: false\n  fsdp:\n    auto_wrap_min_params: 1500000\n    cpu_offload: false\n  checkpoint:\n    enable: true\n    dir: artifacts/checkpoints/mid_stage2_smoke\n    save_interval: 60\n    resume_path: null\n    resume_tag: null\n\noptim:\n  type: adamw\n  lr: 1.0e-4\n  fused: auto\n\nlogging:\n  enabled: true\n  backend: json\n  path: logs/mid_stage2_smoke.json\n"
  },
  {
    "path": "configs/mid_titan_baseline.yaml",
    "content": "hydra:\n  run:\n    dir: .\n  output_subdir: null\n  job:\n    chdir: false\n\nmodel:\n  type: titan\n  vocab_size: 32000\n  dim: 768\n  num_layers: 18\n  heads: 12\n  surprise_threshold: 0.02\n  freeze_backbone: false\n  titan_level:\n    name: titan\n    update_period: 16\n    optimizer_key: titan_opt\n  optimizers:\n    titan_opt:\n      type: deep_momentum\n      lr: 8.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n  teach_scale: 0.10\n  teach_clip: 4.0\n  teach_schedule:\n    warmup_steps: 60\n    decay_start: 140\n    decay_duration: 80\n\ndata:\n  source: mixture\n  batch_size: 4\n  num_workers: 2\n  mixture:\n    samples_per_epoch: 1024\n    seed: 42\n    sources:\n      - name: refinedweb\n        shards_dir: data/shards/refinedweb_full\n        weight: 0.4\n      - name: wikipedia\n        shards_dir: data/shards/wikipedia_full\n        weight: 0.2\n      - name: c4\n        shards_dir: data/shards/c4_full\n        weight: 0.15\n      - name: redpajama\n        shards_dir: data/shards/redpajama_full\n        weight: 0.15\n      - name: code\n        shards_dir: data/shards/code_full\n        weight: 0.1\n\ntrain:\n  strict_streaming_contract: false\n  online_updates: true\n  online_chunk_size: 0\n  online_boundary_targets: false\n  online_carry_attention_cache: false\n  per_layer_teach_signal: true\n  steps: 220\n  log_interval: 20\n  device: \"cuda:1\"\n  seed: 451\n  deterministic: false\n  step_offset: 0\n  mixed_precision:\n    enabled: true\n    dtype: bf16\n  compile:\n    enable: false\n  checkpoint:\n    enable: true\n    dir: artifacts/checkpoints/mid_titan_baseline\n    save_interval: 100\n    resume_path: null\n    resume_tag: null\n\noptim:\n  type: adamw\n  lr: 1.0e-5\n  fused: auto\n\nlogging:\n  enabled: true\n  backend: json\n  path: logs/mid_titan_baseline.json\n  run_name: mid_titan_baseline\n"
  },
  {
    "path": "configs/pilot.yaml",
    "content": "defaults:\n  - _self_\n\nhydra:\n  run:\n    dir: .\n  output_subdir: null\n  job:\n    chdir: false\n\nmodel:\n  vocab_size: 32000\n  dim: 512\n  num_layers: 12\n  heads: 8\n  teach_scale: 0.10\n  teach_clip: 5.0\n  surprise_threshold: 0.02\n  freeze_backbone: false\n  self_mod_lr: 0.001\n  teach_schedule:\n    warmup_steps: 2000\n    decay_start: 120000\n    decay_duration: 20000\n  titan_level:\n    name: titan\n    update_period: 8\n    optimizer_key: titan_opt\n  cms_levels:\n    - name: cms_fast\n      update_period: 1\n      optimizer_key: cms_opt\n    - name: cms_mid\n      update_period: 4\n      optimizer_key: cms_opt\n    - name: cms_slow\n      update_period: 32\n      optimizer_key: cms_opt\n    - name: cms_ultra\n      update_period: 128\n      optimizer_key: cms_opt\n  optimizers:\n    titan_opt:\n      type: deep_momentum\n      lr: 6.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        # Best-effort paper mapping: rank-1 context projection preconditioner.\n        variant: nl_l2_precond\n    cms_opt:\n      type: deep_momentum\n      lr: 3.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        # Best-effort paper mapping: rank-1 context projection preconditioner.\n        variant: nl_l2_precond\n\ndata:\n  source: mixture\n  seq_len: 2048\n  batch_size: 6\n  num_workers: 4\n  mixture:\n    samples_per_epoch: 65536\n    seed: 1337\n    sources:\n      - name: refinedweb\n        shards_dir: data/shards/refinedweb_filtered\n        weight: 0.4\n      - name: wikipedia\n        shards_dir: data/shards/wikipedia_filtered\n        weight: 0.2\n      - name: c4\n        shards_dir: data/shards/c4_filtered\n        weight: 0.15\n      - name: redpajama\n        shards_dir: data/shards/redpajama_filtered\n        weight: 0.15\n      - name: code\n        shards_dir: data/shards/code_filtered\n        weight: 0.1\n\ntrain:\n  algorithm_mode: two_pass_stopgrad_updates\n  strict_streaming_contract: false\n  online_updates: true\n  online_chunk_size: 0\n  online_boundary_targets: false\n  online_carry_attention_cache: false\n  per_layer_teach_signal: true\n  steps: 246667\n  log_interval: 50\n  device: \"cuda:1\"\n  seed: 1337\n  deterministic: false\n  step_offset: 0\n  mixed_precision:\n    enabled: true\n    dtype: bf16\n  compile:\n    enable: false\n    mode: max-autotune\n  checkpoint:\n    enable: true\n    dir: artifacts/checkpoints/pilot\n    save_interval: 1000\n    save_last: true\n    resume_path: null\n    resume_tag: null\n\noptim:\n  type: muon\n  lr: 2.5e-4\n  weight_decay: 0.02\n  momentum: 0.95\n  betas:\n    - 0.9\n    - 0.999\n\nlogging:\n  enabled: true\n  backend: json\n  path: logs/pilot_metrics.json\n  project: nested-learning\n  run_name: pilot-main\n"
  },
  {
    "path": "configs/pilot_paper_faithful.yaml",
    "content": "defaults:\n  - /pilot\n  - _self_\n\nmodel:\n  # Explicit paper-defined variant (avoid inheriting repo default `hope_hybrid`).\n  block_variant: hope_attention\n  # Paper-faithful: treat \"surprise\" as the (scaled) teach signal itself, without threshold gating.\n  surprise_threshold: null\n  # Paper updates on the last (possibly partial) chunk; enable flush for non-multiple seq lengths.\n  cms_flush_partial_at_end: true\n  # Paper: q is non-adaptive and uses a fixed projection.\n  self_mod_adaptive_q: false\n  # Paper: local causal conv in the HOPE self-mod module.\n  self_mod_local_conv_window: 4\n\ndata:\n  # Paper-faithful semantics: CMS/TITAN fast state is per-context; this repo currently treats\n  # each *batch* as a single shared context when batch_size>1.\n  batch_size: 1\n\ntrain:\n  algorithm_mode: two_pass_stopgrad_updates\n  # Keep this explicit (instead of inherited) so paper-faithful behavior is visible in one file.\n  online_updates: true\n  # Paper: re-initialize fast memories per context (sequence).\n  use_fast_state: true\n  strict_streaming_contract: true\n  # Use explicit boundary-token supervision (no overlap approximation).\n  online_boundary_targets: true\n  # Carry attention state across chunks during online updates.\n  online_carry_attention_cache: true\n  # Fail fast if DDP would silently disable paper-critical features.\n  fail_if_paper_faithful_disabled: true\n\noptim:\n  # Ensure meta-learning updates include memory module initial states (paper §8.2).\n  param_policy: all\n\nlogging:\n  run_name: pilot-paper-faithful\n  path: logs/pilot_paper_faithful_metrics.json\n"
  },
  {
    "path": "configs/pilot_selfmod_paper_faithful.yaml",
    "content": "defaults:\n  - /pilot_paper_faithful\n  - _self_\n\nmodel:\n  block_variant: hope_selfmod\n  # Chunk update cadence (paper §8.2): other memories update more often than M_memory.\n  self_mod_chunk_size: 8\n  self_mod_chunk_size_memory: 64\n  self_mod_use_skip: false\n\ntrain:\n  checkpoint:\n    dir: artifacts/checkpoints/pilot_selfmod_paper_faithful\n\nlogging:\n  run_name: pilot-selfmod-paper-faithful\n  path: logs/pilot_selfmod_paper_faithful_metrics.json\n"
  },
  {
    "path": "configs/pilot_smoke.yaml",
    "content": "hydra:\n  run:\n    dir: .\n  output_subdir: null\n  job:\n    chdir: false\n\nmodel:\n  vocab_size: 32000\n  dim: 128\n  num_layers: 2\n  heads: 4\n  titan_level:\n    name: titan\n    update_period: 8\n    optimizer_key: titan_opt\n  cms_levels:\n    - name: cms_fast\n      update_period: 1\n      optimizer_key: cms_opt\n    - name: cms_mid\n      update_period: 4\n      optimizer_key: cms_opt\n    - name: cms_slow\n      update_period: 16\n      optimizer_key: cms_opt\n  optimizers:\n    titan_opt:\n      type: deep_momentum\n      lr: 1.0e-3\n      params:\n        beta: 0.9\n        beta2: 0.999\n    cms_opt:\n      type: deep_momentum\n      lr: 5.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n\ndata:\n  source: synthetic\n  vocab_size: 32000\n  seq_len: 64\n  dataset_size: 1024\n  batch_size: 4\n  num_workers: 0\n\ntrain:\n  strict_streaming_contract: false\n  online_updates: true\n  online_chunk_size: 0\n  online_boundary_targets: false\n  online_carry_attention_cache: false\n  per_layer_teach_signal: true\n  steps: 10\n  log_interval: 1\n  device: \"cpu\"\n  seed: 1234\n  deterministic: true\n  mixed_precision:\n    enabled: false\n    dtype: bf16\n  compile:\n    enable: false\n  checkpoint:\n    enable: true\n    dir: artifacts/checkpoints/pilot_smoke\n    save_interval: 10\n    save_last: true\n\noptim:\n  type: adamw\n  lr: 3.0e-4\n  fused: false\n\nlogging:\n  enabled: true\n  backend: json\n  path: logs/pilot_smoke.json\n"
  },
  {
    "path": "configs/resolved/cms_sparse_eval.yaml",
    "content": "hydra:\n  run:\n    dir: .\n  output_subdir: null\n  job:\n    chdir: false\nmodel:\n  vocab_size: 32000\n  dim: 384\n  num_layers: 8\n  heads: 6\n  teach_scale: 0.1\n  teach_clip: 5.0\n  self_mod_lr: 0.001\n  teach_schedule:\n    warmup_steps: 2000\n    decay_start: 120000\n    decay_duration: 20000\n  titan_level:\n    name: titan\n    update_period: 8\n    optimizer_key: titan_opt\n  cms_levels:\n  - name: cms_fast\n    update_period: 8\n    optimizer_key: cms_opt\n  - name: cms_mid\n    update_period: 32\n    optimizer_key: cms_opt\n  - name: cms_slow\n    update_period: 128\n    optimizer_key: cms_opt\n  - name: cms_ultra\n    update_period: 512\n    optimizer_key: cms_opt\n  optimizers:\n    titan_opt:\n      type: deep_momentum\n      lr: 0.0006\n      params:\n        beta: 0.9\n        beta2: 0.999\n    cms_opt:\n      type: deep_momentum\n      lr: 0.0003\n      params:\n        beta: 0.9\n        beta2: 0.999\n  cms_hidden_multiplier: 2\ndata:\n  source: mixture\n  seq_len: 1024\n  batch_size: 2\n  num_workers: 2\n  mixture:\n    samples_per_epoch: 65536\n    seed: 1337\n    sources:\n    - name: refinedweb\n      shards_dir: data/shards/refinedweb_filtered\n      weight: 0.4\n    - name: wikipedia\n      shards_dir: data/shards/wikipedia_filtered\n      weight: 0.2\n    - name: c4\n      shards_dir: data/shards/c4_filtered\n      weight: 0.15\n    - name: redpajama\n      shards_dir: data/shards/redpajama_filtered\n      weight: 0.15\n    - name: code\n      shards_dir: data/shards/code_filtered\n      weight: 0.1\ntrain:\n  online_updates: true\n  online_chunk_size: 0\n  per_layer_teach_signal: true\n  steps: 5000\n  log_interval: 25\n  device: cuda:1\n  seed: 1337\n  deterministic: false\n  mixed_precision:\n    enabled: true\n    dtype: bf16\n  compile:\n    enable: false\n    mode: max-autotune\n  checkpoint:\n    enable: true\n    dir: artifacts/checkpoints/pilot_cms_sparse\n    save_interval: 1000\n    save_last: true\n    resume_path: null\n    resume_tag: null\noptim:\n  type: adamw\n  lr: 0.00025\n  fused: auto\nlogging:\n  enabled: true\n  backend: json\n  path: logs/pilot_cms_sparse_metrics.json\n  project: nested-learning\n  run_name: pilot-cms-sparse\n"
  },
  {
    "path": "configs/resolved/phase2_pilot_attention_eval.yaml",
    "content": "model:\n  vocab_size: 32000\n  dim: 512\n  num_layers: 12\n  heads: 8\n  teach_scale: 0.10\n  teach_clip: 5.0\n  surprise_threshold: 0.02\n  freeze_backbone: false\n  qk_l2_norm: true\n  local_conv_window: 4\n  block_variant: hope_attention\n  teach_schedule:\n    warmup_steps: 2000\n    decay_start: 120000\n    decay_duration: 20000\n  titan_level:\n    name: titan\n    update_period: 8\n    optimizer_key: titan_opt\n  cms_levels:\n    - name: cms_fast\n      update_period: 1\n      optimizer_key: cms_opt\n    - name: cms_mid\n      update_period: 4\n      optimizer_key: cms_opt\n    - name: cms_slow\n      update_period: 32\n      optimizer_key: cms_opt\n    - name: cms_ultra\n      update_period: 128\n      optimizer_key: cms_opt\n  optimizers:\n    titan_opt:\n      type: deep_momentum\n      lr: 6.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        variant: nl_l2_precond\n    cms_opt:\n      type: deep_momentum\n      lr: 3.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        variant: nl_l2_precond\n\n"
  },
  {
    "path": "configs/resolved/phase2_pilot_transformer_eval.yaml",
    "content": "model:\n  vocab_size: 32000\n  dim: 512\n  num_layers: 12\n  heads: 8\n  teach_scale: 0.10\n  teach_clip: 5.0\n  surprise_threshold: 0.02\n  freeze_backbone: false\n  qk_l2_norm: true\n  local_conv_window: 4\n  block_variant: transformer\n  teach_schedule:\n    warmup_steps: 2000\n    decay_start: 120000\n    decay_duration: 20000\n  titan_level:\n    name: titan\n    update_period: 8\n    optimizer_key: titan_opt\n  cms_levels:\n    - name: cms_fast\n      update_period: 1\n      optimizer_key: cms_opt\n    - name: cms_mid\n      update_period: 4\n      optimizer_key: cms_opt\n    - name: cms_slow\n      update_period: 32\n      optimizer_key: cms_opt\n    - name: cms_ultra\n      update_period: 128\n      optimizer_key: cms_opt\n  optimizers:\n    titan_opt:\n      type: deep_momentum\n      lr: 6.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        variant: nl_l2_precond\n    cms_opt:\n      type: deep_momentum\n      lr: 3.0e-4\n      params:\n        beta: 0.9\n        beta2: 0.999\n        variant: nl_l2_precond\n\n"
  },
  {
    "path": "docker/Dockerfile.dist",
    "content": "FROM scratch\n\nLABEL org.opencontainers.image.title=\"nested-learning-dist\"\nLABEL org.opencontainers.image.description=\"OCI bundle containing nested-learning wheel/sdist/checksums from CI release builds.\"\n\n# Built artifacts generated in workflow before image build.\nCOPY dist/ /dist/\n"
  },
  {
    "path": "docs/BUG_REPORT_CHECKLIST.md",
    "content": "# Bug Report Checklist\n\nUse this checklist when filing reproducibility or correctness issues.\n\n## Required Context\n\n- Commit SHA (`git rev-parse --short HEAD`)\n- Exact command line used\n- Config name and CLI overrides\n- Device/runtime details (`python --version`, `uv --version`, `nvidia-smi` if CUDA)\n\n## Required Artifacts\n\n- JSON training log path (if training path involved)\n- Full traceback/error output\n- Minimal failing input or dataset pointer\n- If streaming/cadence related: include `scripts/checks/verify_update_cadence.py` output\n- Include `scripts/checks/compliance_report.py` output (or note why unavailable)\n\n## Fast Reproduction Path\n\n1. Run `uv run bash scripts/checks/run_fidelity_ci_subset.sh`.\n2. Run `uv run bash scripts/run_mechanism_audit_smoke.sh`.\n3. Attach outputs and note which step failed.\n\n## Streaming/Cadence Issues\n\n- Specify `train.strict_streaming_contract` value.\n- Specify `train.online_updates`, `train.online_chunk_size`, `train.online_boundary_targets`, `train.online_carry_attention_cache`, `model.cms_flush_partial_at_end`.\n- Include expected vs observed update counts per level.\n"
  },
  {
    "path": "docs/COMPATIBILITY_MATRIX.md",
    "content": "# Compatibility Matrix\n\nThis document defines the support contract for runtime/backends.\n\n## Support Tiers\n\n- **Tier 1 (Supported):** CI-tested on every PR; regressions treated as bugs.\n- **Tier 2 (Supported with caveats):** tested periodically/partially; backend caveats apply.\n- **Tier 3 (Best-effort):** community-supported; no guaranteed CI lane.\n- **Unsupported:** intentionally out of scope; fail-fast when correctness is at risk.\n\n## Matrix\n\n| OS | Python | CPU | CUDA (NVIDIA) | MPS | ROCm |\n|---|---|---|---|---|---|\n| Linux x86_64 | 3.10-3.12 | Tier 1 (import/eval/smoke) | Tier 1 (import/eval/smoke/full training) | Unsupported | Tier 3 |\n| macOS Apple Silicon | 3.10-3.12 | Tier 2 (import/eval/smoke) | Unsupported | Tier 2 (import/eval), Tier 3 (smoke) | Unsupported |\n| macOS Intel | 3.10-3.12 | Tier 2 (import/eval), Tier 3 (smoke) | Unsupported | Unsupported | Unsupported |\n| Windows | 3.10-3.12 | Tier 2 (import/eval), Tier 3 (smoke) | Tier 3 (user-managed) | Unsupported | Unsupported |\n\nNotes:\n- CPU full-scale training is not a supported target.\n- Strict paper-faithful online-update semantics in distributed settings remain constrained by design.\n- Numerical parity across backend families (CUDA/MPS/ROCm) is not guaranteed.\n\n## Apple Silicon (MPS) practical expectations\n\nOn macOS Apple Silicon, this repo is intended to support:\n- install/import,\n- CLI diagnostics (`nl doctor`),\n- smoke/eval workflows,\n- small local runs with `train.device=mps`.\n\nThis repo does not currently treat macOS/MPS as a full paper-scale training target.\nFor full-size training and published artifact reproduction, prefer Linux + CUDA Tier 1 environments.\n\n## Runtime Degradation Policy\n\nAt runtime, unsupported performance features should degrade gracefully:\n- if flash/mem-efficient SDPA is unavailable, use math SDPA;\n- if `torch.compile` is unavailable/disabled, continue without compile;\n- if requested mixed precision is unsupported on the backend, degrade to fp32 and log it.\n\nUse `nl doctor --json` to capture capability snapshots in machine-readable form.\n\n## Golden Environment\n\nFor reproducibility of this repository’s published artifacts, prefer:\n- Python 3.12\n- PyTorch 2.9.x\n- `uv lock` / `uv sync --all-extras --dev`\n\nThe package metadata allows broader install ranges for portability, while the lockfile remains the canonical dev/test environment.\n"
  },
  {
    "path": "docs/FSDP_SCALING_GUIDE.md",
    "content": "# FSDP/ZeRO Scaling Guide (RTX 6000 Ada Dual-GPU Rig)\n\nThis note captures the configuration we will use for the Stage 2 mid (≈760 M) and target (≈1.3 B) HOPE models when running on the dual RTX 6000 Ada workstation (2× 48 GB). It accompanies the new Hydra configs `configs/hope/mid_fsdp.yaml` and `configs/hope/target_fsdp.yaml`.\n\n## Hardware & Software Assumptions\n- 2× NVIDIA RTX 6000 Ada (48 GB each)\n- CUDA 12.4, PyTorch 2.9, `uv` environment\n- NCCL backend, FSDP via `torch.distributed.fsdp`\n- Checkpoints stored under `artifacts/checkpoints/{mid_fsdp,target_fsdp}`\n\n## Config summary\n\n| Model | Params | Config | Per-rank micro-batch | Global batch (nranks=2) | Expected VRAM | Notes |\n|-------|--------|--------|----------------------|-------------------------|---------------|-------|\n| HOPE mid | ~760 M (dim 1024, 24L) | `configs/hope/mid_fsdp.yaml` | 8 sequences × 2048 tokens | 16×2048 tokens | 43–45 GB | bf16 activations, Muon outer optimizer, NL inner optimizer, gradient checkpointing, FSDP auto-wrap ≥2 M params |\n| HOPE target | ~1.3 B (dim 1536, 32L) | `configs/hope/target_fsdp.yaml` | 4 sequences × 2048 tokens | 8×2048 tokens | 46–48 GB | Slightly smaller per-rank batch to stay under 48 GB; Muon + checkpointing identical to mid config |\n\nBoth configs default to:\n- `optim.type = muon` (outer optimizer) with `nl_l2_precond` inner updates already wired through model lvl optimizers.\n- bf16 autocast (`train.mixed_precision.enabled = true, dtype = bf16`).\n- Gradient checkpointing via `model.gradient_checkpointing = true` (saves ~3 GB per rank).\n- `train.compile.enable = false` (Torch.compile can be toggled on after validation).\n- FSDP auto-wrap policy set via `train.fsdp.auto_wrap_min_params`.\n\n## Launch commands\n\n```bash\n# Mid model, 2 GPUs\nUV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy \\\nuv run torchrun --nproc_per_node=2 train_fsdp.py \\\n  --config-name hope/mid_fsdp logging.run_name=mid-fsdp-${USER}\n\n# Target model, 2 GPUs\nUV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy \\\nuv run torchrun --nproc_per_node=2 train_fsdp.py \\\n  --config-name hope/target_fsdp logging.run_name=target-fsdp-${USER}\n```\n\nTo resume from a checkpoint, set `train.checkpoint.resume_path` (path to `step_xxxxxx.pt`). State dicts use FSDP’s full-state sharding with CPU offload for rank 0.\n\n## ZeRO / DeepSpeed note\n\nFor multi-node runs or larger batch sizes, leverage `train_deepspeed.py` with `configs/deepspeed/zero3.json`. The per-model configs above can be reused by passing `--config-name hope/mid_fsdp` together with `DEEPSPEED_CONFIG=configs/deepspeed/zero3.json`.\n\n## Logging & Monitoring\n\n- JSON metrics live at `logs/mid_fsdp_metrics.json` or `logs/target_fsdp_metrics.json`.\n- W&B logging is enabled by default (`project = nested-learning`).\n- Additional telemetry (teach-signal norms, projector stats, CMS chunk samples) already flows through the model update metrics; ensure your W&B dashboard visualizes:\n  - `layer*.titan.titan.grad_norm`\n  - `layer*.titan.titan.ctx_norm` / `proj_norm`\n  - `layer*.cms.cms_fast.chunk_samples`, etc.\n\n## Checklist before starting a long run\n1. `uv run pytest` (ensure faithfulness tests pass).\n2. `nvidia-smi` — GPUs idle and temps normal.\n3. Confirm dataset shards (`data/shards/*_full/`) available locally.\n4. W&B credentials set (`WANDB_API_KEY`).\n5. 
\n## ZeRO / DeepSpeed note\n\nFor multi-node runs or larger batch sizes, use `train_deepspeed.py` with `configs/deepspeed/zero3.json`. The per-model configs above can be reused by passing `--config-name hope/mid_fsdp` together with `DEEPSPEED_CONFIG=configs/deepspeed/zero3.json`.\n\n## Logging & Monitoring\n\n- JSON metrics live at `logs/mid_fsdp_metrics.json` or `logs/target_fsdp_metrics.json`.\n- W&B logging is enabled by default (`project = nested-learning`).\n- Additional telemetry (teach-signal norms, projector stats, CMS chunk samples) already flows through the model update metrics; ensure your W&B dashboard visualizes:\n  - `layer*.titan.titan.grad_norm`\n  - `layer*.titan.titan.ctx_norm` / `proj_norm`\n  - `layer*.cms.cms_fast.chunk_samples`, etc.\n\n## Checklist before starting a long run\n1. `uv run pytest` (ensure faithfulness tests pass).\n2. `nvidia-smi` — GPUs idle and temps normal.\n3. Confirm dataset shards (`data/shards/*_full/`) are available locally.\n4. W&B credentials set (`WANDB_API_KEY`).\n5. For the target config, consider setting `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` to reduce fragmentation.\n\nThis guide should let collaborators pick up the FSDP configs immediately without reverse-engineering the Hydra hierarchy.\n"
  },
  {
    "path": "docs/IMPLEMENTATION_STATUS.md",
    "content": "# Implementation Status (Source of Truth)\n\nThis table is the canonical mechanism-status map for this repo.\n\n| Mechanism | Status | Evidence |\n|---|---|---|\n| Teach-signal alignment (LM head tied to embeddings) | Implemented | `src/nested_learning/model.py`, `src/nested_learning/training.py`, `tests/test_teach_signal.py`, `tests/test_tied_weight_guard.py` |\n| Per-layer local teach signals (`δ_l`) | Implemented (single-process path) | `src/nested_learning/model.py`, `src/nested_learning/training.py`, `tests/test_teach_signal.py` |\n| CMS chunk accumulation + cross-call cadence | Implemented | `src/nested_learning/hope/block.py`, `tests/test_cms.py`, `tests/test_cms_cross_call.py`, `tests/test_model_streaming_cadence.py` |\n| CMS finalize/partial flush semantics | Implemented | `src/nested_learning/hope/block.py`, `tests/test_cms_flush_partial.py`, `docs/STREAMING_CONTRACT.md` |\n| Online chunking (overlap mode) | Implemented | `src/nested_learning/training.py`, `tests/test_online_chunking.py` |\n| Online chunking (boundary-target mode) | Implemented | `src/nested_learning/training.py`, `configs/pilot_paper_faithful.yaml`, `tests/test_online_chunking.py` |\n| Optional attention-cache carry across chunk calls | Implemented (single-process path) | `src/nested_learning/backbones.py`, `src/nested_learning/model.py`, `tests/test_attention_cache.py`, `tests/test_eval_state.py` |\n| Strict runtime guardrails | Implemented | `src/nested_learning/training.py`, `tests/test_strict_streaming_contract.py`, `tests/test_distributed_fail_fast.py`, `tests/test_fast_state_batch_semantics.py` |\n| Training algorithm mode banner/validation | Implemented (`two_pass_stopgrad_updates`, `boundary_state_grad_through_write`) | `src/nested_learning/training.py`, `tests/test_strict_streaming_contract.py` |\n| Boundary-state gradient-through-write algorithm mode | Experimental (single-process constrained path) | `src/nested_learning/training.py`, `tests/test_boundary_state_mode.py`, `tests/test_algorithm_mode_grad.py`, `docs/PAPER_COMPLIANCE.md` |\n| Online-updates fast-state invariant (`online_updates && !use_fast_state`) | Implemented (warn/error guard) | `src/nested_learning/training.py`, `tests/test_strict_streaming_contract.py` |\n| Inner optimizer mapping (`nl_l2_precond`) | Implemented (best-effort mapping) | `src/nested_learning/optim/deep.py`, `tests/test_optim.py`, `docs/PAPER_COMPLIANCE.md` |\n| Surprise-gated update flow | Implemented | `src/nested_learning/model.py`, `src/nested_learning/hope/block.py`, `src/nested_learning/titan/model.py` |\n| Test-time memorization path in eval harnesses | Implemented | `src/nested_learning/memorize.py`, `scripts/eval/*.py`, `tests/test_memorization.py` |\n| Compliance automation report | Implemented | `scripts/checks/compliance_report.py`, `scripts/checks/run_fidelity_ci_subset.sh`, `scripts/run_mechanism_audit_smoke.sh` |\n| Doc-to-code reference guard (anti-overclaim drift) | Implemented | `scripts/checks/verify_docs_refs.py`, `.github/workflows/ci.yml`, `tests/test_verify_docs_refs.py` |\n| Portable package/CLI entrypoints (`nl`, `python -m nested_learning`) | Implemented | `src/nested_learning/cli.py`, `src/nested_learning/__main__.py`, `tests/test_cli_tooling.py`, `pyproject.toml` |\n| Cross-platform smoke + wheel install CI gates | Implemented | `.github/workflows/ci.yml` (`cross-platform-smoke`, `wheel-install-smoke`) |\n| Package release automation (tag -> TestPyPI/PyPI) | Implemented | `.github/workflows/release.yml`, 
`docs/PACKAGE_RELEASE_CHECKLIST.md` |\n| Full boundary-state gradient-through-write algorithm from paper | Partially implemented (experimental) | Constrained single-process mode exists; not yet treated as production-ready or at full-scale parity (`docs/PAPER_COMPLIANCE.md`) |\n| Distributed mechanism-auditing parity for online/per-layer/boundary-cache path | Deferred | DDP strict fail-fast + documented limits (`src/nested_learning/training.py`, `scripts/run_cpu_ddp_smoke.sh`) |\n| Paper-scale training/eval reproduction | Deferred | Explicitly out of sprint scope (`docs/PAPER_COMPLIANCE.md`) |\n\n## Validation Entrypoints\n\n- Fidelity subset: `scripts/checks/run_fidelity_ci_subset.sh`\n- Mechanism-auditing smoke: `scripts/run_mechanism_audit_smoke.sh`\n- Full tests: `uv run pytest`\n
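\nTo run all three in sequence from the repo root (invocation style follows `docs/BUG_REPORT_CHECKLIST.md`):\n\n```bash\nuv run bash scripts/checks/run_fidelity_ci_subset.sh\nuv run bash scripts/run_mechanism_audit_smoke.sh\nuv run pytest\n```\n"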
  },
  {
    "path": "docs/P4_REMEDIATION_PLAN.md",
    "content": "# P4 Remediation Plan — Status & Tracking (Paper-Faithful HOPE/Nested Learning)\n\nThis file started as an execution checklist for the P4 “paper faithfulness” sprint. It is now maintained as a **status page** so contributors can quickly see what’s implemented, what is verified by tests, and what follow‑ups remain.\n\nFor the canonical “how to run paper‑faithful mode” guide, see `docs/PAPER_COMPLIANCE.md`.\n\n## Status (core remediation)\n\n**P0/P1 core faithfulness items:** complete.\n\nImplemented behaviors (with pointers):\n- **Self‑modifying TITAN path always-on** during the inner/update pass (does not require an external teach signal).  \n  Code: `src/nested_learning/hope/block.py`  \n  Test: `tests/test_selfmod_online.py`\n- **CMS update semantics** use per‑token δ targets and **sum‑over‑chunk** accumulation (no chunk‑mean broadcast).  \n  Code: `src/nested_learning/hope/block.py` (`_chunk_loss`, `_CmsBuffer`, `_pop_buffer_chunk`)  \n  Test: `tests/test_cms.py`\n- **Online CMS read‑after‑write** behavior (later tokens can see updated CMS weights when using the online training path).  \n  Code: `src/nested_learning/hope/block.py` (`_cms_forward_online`) + `src/nested_learning/training.py` (`train.online_updates`)  \n  Test: `tests/test_cms.py` (`test_cms_online_updates_affect_later_tokens`)\n- **Per‑layer local error signals (δℓ)** computed via autograd and routed into each block.  \n  Code: `src/nested_learning/model.py` (`forward_with_block_outputs`, `teach_signals`) + `src/nested_learning/training.py` (`_compute_layer_teach_signals`)  \n  Test: `tests/test_teach_signal.py`\n- **Paper optimizer option (M3)** implemented and selectable via `optim.type=m3`.  \n  Code: `src/nested_learning/optim/m3.py`  \n  Test: `tests/test_m3.py`\n\nDocs/telemetry added:\n- Paper‑faithful run flags + code mapping: `docs/PAPER_COMPLIANCE.md`\n- README “paper‑faithful mode” snippet: `README.md`\n- Per‑layer update telemetry (e.g. `layerX.cms.*`) emitted via `HOPEModel._gather_block_stats()`.\n\n## Remaining follow-ups (optional hardening, not required for “implemented correctly”)\n\nThese are improvements that strengthen the validation story or reduce ambiguity, but they are not required to claim the core mechanism is implemented:\n\n- [ ] Add an explicit unit test that demonstrates **per‑token δ vs chunk‑mean broadcast** leads to different update directions (sanity test on a toy CMS).\n- [ ] Add a “two chunks vs one chunk” regression test to lock in chunk boundary semantics in `train.online_updates` mode.\n- [ ] Expose `cms_online_updates` / `cms_chunk_reduction` / `selfmod_online_updates` as Hydra config toggles (currently paper‑faithful defaults live in the HOPE block configs).\n- [ ] Port the `train.py` online/per‑layer δℓ path to multi‑GPU (FSDP or custom DDP) so paper‑faithful mode scales beyond single GPU.\n\n"
  },
  {
    "path": "docs/PACKAGE_RELEASE_CHECKLIST.md",
    "content": "# Package Release Checklist (PyPI/GitHub)\n\nUse this checklist for package distribution releases (separate from checkpoint/artifact releases).\n\n## Pre-Release (RC)\n\n- [ ] `uv run ruff check .`\n- [ ] `uv run mypy src`\n- [ ] `uv run pytest -q`\n- [ ] `uv build`\n- [ ] `uvx twine check dist/*`\n- [ ] wheel install smoke works outside repo tree:\n  - [ ] `python -m venv /tmp/nl-wheel`\n  - [ ] `pip install dist/*.whl`\n  - [ ] `python -m nested_learning --help`\n  - [ ] `python -m nested_learning doctor --json`\n  - [ ] `python -m nested_learning smoke --config-name pilot_smoke --device cpu --batch-size 1 --seq-len 8`\n- [ ] `CHANGELOG.md` updated with:\n  - [ ] release highlights\n  - [ ] breaking changes (or explicit “none”)\n- [ ] `README.md` reflects current compatibility tiers and install guidance.\n- [ ] Trusted Publishing configured per `docs/PYPI_TRUSTED_PUBLISHING.md`.\n- [ ] Tag created for RC (`vX.Y.ZrcN`) and TestPyPI publish succeeds.\n\n## Final Release\n\n- [ ] Re-run validation checks listed above.\n- [ ] Promote release notes for `vX.Y.Z`.\n- [ ] PyPI publish workflow succeeds via Trusted Publishing (OIDC).\n- [ ] GitHub Release workflow creates/updates the tag release entry in the Releases tab.\n- [ ] Release assets attached by automation:\n  - [ ] wheel (`.whl`)\n  - [ ] source tarball (`.tar.gz`)\n  - [ ] `SHA256SUMS.txt`\n- [ ] GitHub Packages (GHCR) workflow succeeds and publishes `nested-learning-dist` for the tag.\n- [ ] Release notes include migration notes (if any) and links to compatibility/versioning docs.\n\n## Post-Release\n\n- [ ] Confirm install from PyPI in a clean environment.\n- [ ] Confirm `nl doctor` and `nl smoke` on at least one non-maintainer machine or CI lane.\n- [ ] Open follow-up issues for deferred release items.\n"
  },
  {
    "path": "docs/PAPER_COMPLIANCE.md",
    "content": "# Paper Compliance / Fidelity Guide (Nested Learning / HOPE)\n\nThis doc explains the **fidelity‑critical behaviors** (what the paper relies on) and how they map to this repo’s code, flags, and tests.\n\nIt is deliberately **mechanism‑focused**: you can use it to answer “did we implement the architecture/update rules correctly?” without requiring full‑scale training reproduction.\n\nFor exact chunk/segment/buffer semantics, see `docs/STREAMING_CONTRACT.md`.\n\n## Paper Reference Pin\n\nAll compliance/equation references in this repo are pinned to:\n\n- Source: `google_papers/Nested_Learning_Full_Paper/Nested_Learning_Full_Paper.md`\n- SHA-256: `7524af0724ac8e3bad9163bf0e79c85b490a26bc30b92d96b0bdf17a27f9febc`\n\n## Scope\n\n**In scope**\n- HOPE blocks (attention + CMS + TITAN/self‑mod paths) and the *nested/online* update mechanism.\n- Correct teach‑signal alignment (LM head vs embedding), per‑layer local error signals (δℓ), and chunk‑accumulated CMS updates.\n- A paper‑style optimizer option (M3) alongside practical defaults.\n\n**Out of scope (today)**\n- Full bi‑level meta‑learning experiments over explicit task episodes (outer objective over tasks + inner adaptation per task).\n- Results parity at the original paper’s compute scale.\n\n## Semantic contract (important)\n\nThis repo focuses on **mechanism-level fidelity** (update rules + dataflow) with explicit tests.\n\n- **Differentiable reads:** the forward pass used to compute the outer LM loss is standard autograd.\n- **Stop‑grad writes:** online memory updates are applied in an explicit update pass (typically under `torch.no_grad()`), so we do **not** backprop through online writes.\n- **Algorithm mode:** `train.algorithm_mode=two_pass_stopgrad_updates` is the stable default.  \n  `train.algorithm_mode=boundary_state_grad_through_write` is available as an **experimental single-process mechanism path** with strict runtime constraints (`online_updates=true`, `per_layer_teach_signal=true`, `use_fast_state=true`, non-DDP). 
It is not yet treated as full paper-training reproduction.\n- **Boundary-target mode:** we support explicit boundary-token supervision (`train.online_boundary_targets=true`) and optional attention-state carry across chunks (`train.online_carry_attention_cache=true`) for stronger streaming equivalence, while keeping stop-grad write semantics.\n- **Fast-state guardrail:** `train.online_updates=true` with `train.use_fast_state=false` now emits a warning in non-strict mode and raises in strict/paper-faithful mode.\n- **Meta initialization (fast-state mode):** when `train.use_fast_state=true`, meta parameters are not mutated by online updates, but the *read-path* meta parameters still receive outer gradients:\n  - CMS/TITAN fast state uses **meta+delta** (forward uses `meta + delta`; updates write deltas only).\n  - HOPE‑SelfMod uses a detached per‑context state, but the read path uses a **straight‑through meta gradient** link so the meta initialization remains trainable.\n\n## Quick start: mechanism-auditing presets (single GPU)\n\nThe highest-fidelity execution path in this repo is **single‑GPU** `train.py`, because it supports both:\n1) **per‑layer δℓ** teach signals and  \n2) **online chunked training** where later tokens’ loss/gradients can see earlier memory updates.\n\nMinimal smoke:\n\n```bash\nuv run python train.py --config-name pilot_paper_faithful train.steps=5\n```\n\nNote: these presets set `data.batch_size=1` to avoid cross-sample memory sharing\nwhen `train.use_fast_state=true`.\n\nOptional: select the paper optimizer variant for the *outer* step:\n\n```bash\nuv run python train.py --config-name pilot_paper_faithful train.steps=5 optim.type=m3\n```\n\nMechanism-auditing HOPE self-mod variant:\n\n```bash\nuv run python train.py --config-name pilot_selfmod_paper_faithful train.steps=5\n```\n\nBoundary-state experimental smoke (single process only):\n\n```bash\nuv run python train.py --config-name pilot_paper_faithful \\\n  train.algorithm_mode=boundary_state_grad_through_write \\\n  train.steps=5\n```\n\nBoundary-state mode tradeoffs:\n- Keeps cross-chunk write paths differentiable, which increases activation retention and memory pressure.\n- Usually runs slower than `two_pass_stopgrad_updates` due to larger autograd graphs.\n- Intended for mechanism probing and diagnostics, not for long production runs in this repo yet.\n\n## Mechanism-Auditing vs Practical Mode (Matrix)\n\nThis repo supports both mechanism-auditing presets (for correctness checks) and practical defaults (for running pilots quickly).\n\n| Mechanism | Paper intent | This repo (single GPU) | Notes / Tests |\n|---|---|---|---|\n| Teach‑signal alignment | δ uses LM head weights | `compute_teach_signal()` matches autograd grad | `tests/test_teach_signal.py` |\n| Per‑layer δℓ | block‑local error signals | `train.per_layer_teach_signal=true` | `tests/test_teach_signal.py` |\n| Online chunked training | later tokens can “see” earlier inner updates | `train.online_updates=true` with either overlap mode or explicit boundary-target mode + end-of-sequence finalize | `src/nested_learning/training.py`, `tests/test_online_chunking.py` |\n| CMS chunk accumulation | sum over token deltas per chunk | `cms_chunk_reduction=\"sum\"` default | `tests/test_cms.py`, `tests/test_cms_delta_rule.py` |\n| CMS partial-chunk flush | update on final partial chunk | `model.cms_flush_partial_at_end=true` | `tests/test_cms_flush_partial.py` |\n| CMS cadence across chunked calls | `update_period` accumulation must survive multiple update-pass 
calls | fast-state CMS buffers persist until `finalize_updates=true` | `tests/test_cms_cross_call.py` |\n| CMS LayerNorm | paper is architecture-light; norm is optional | `model.cms_use_layernorm=true` (default) | `tests/test_cms.py` |\n| HOPE‑SelfMod local conv | local conv window=4 (paper HOPE module) | `SelfModifyingTitansConfig.local_conv_window=4` default (causal depthwise) | `tests/test_selfmod_local_conv.py` |\n| HOPE‑SelfMod fixed q | paper: `q_t = x_t W_q` non‑adaptive | `SelfModifyingTitansConfig.adaptive_q=false` default | `tests/test_selfmod_adaptive_q.py` |\n| HOPE‑SelfMod Eq. (91) skip | no projection skip term (`w_skip`) | `model.self_mod_use_skip=false` (mechanism-auditing presets) | `tests/test_residual_mlp_memory.py` |\n| HOPE‑SelfMod read/write separation | differentiable read; stopgrad through writes | forward uses differentiable read; updates occur only in explicit update pass | `tests/test_selfmod_grad_flow.py`, `tests/test_hope_selfmod_update_pass.py` |\n| Fast‑state isolation | per‑context inner updates without mutating meta params, while read‑path meta init remains learnable | `train.use_fast_state=true` | CMS/TITAN use **meta+delta**; HOPE‑SelfMod read path uses straight‑through meta gradients. Meta params remain unchanged during updates and still receive outer grads (`tests/test_hope_selfmod_fast_state_meta_unchanged.py`, `tests/test_fast_state_meta_grads.py`, `tests/test_fast_state_selfmod_meta_grads.py`, `tests/test_fast_state_forward_equivalence.py`, `tests/test_fast_state_batch_semantics.py`) |\n| Surprise metric | paper “surprise” trigger | `model.surprise_metric=l2` (default); also `loss`, `logit_entropy` | `tests/test_surprise_metric.py`, `tests/test_faithfulness_harness.py` |\n| Outer optimizer | M3 option exists | `optim.type=m3` | `tests/test_m3.py` |\n| Outer param policy | include memory initial states in meta-update | `optim.param_policy=all` | `tests/test_optimizer_param_policy.py` |\n| DDP fail-fast | avoid silent paper-divergent fallbacks | `train.fail_if_paper_faithful_disabled=true` | `tests/test_distributed_fail_fast.py` |\n| Multi‑GPU | (not required by paper) | DDP disables `online_updates` + `per_layer_teach_signal`; FSDP uses offline updates | documented below |\n\nSurprise-gating note: for `model.surprise_metric=l2`, the current implementation applies a\nchunk-level gate from mean teach-signal norm, then applies token-level masking inside TITAN/CMS\nupdates. This behavior is intentionally tested (`tests/test_surprise_metric.py`).\n\n## Claims Boundary (What We Claim vs What We Do Not)\n\n| Claim category | Status | Notes |\n|---|---|---|\n| CMS/TITAN/self-mod mechanism wiring | Implemented | Unit tests cover teach-signal, chunking, cadence primitives, and update-path invariants. |\n| Mechanism-auditing single-GPU path | Implemented | Uses per-layer teach signals + explicit stop-grad update pass. |\n| Full paper boundary-state gradient training through online writes | Partially implemented (experimental) | `train.algorithm_mode=boundary_state_grad_through_write` enables a constrained single-process differentiable write path; still not treated as production/full-scale reproduction. |\n| Cross-chunk attention-state continuity (KV cache) | Partially implemented | Optional cache-carry path is available in model APIs and training boundary-target mode; distributed faithful path remains deferred. |\n| Full paper-scale result reproduction | Not implemented | Compute/data scale parity is intentionally deferred. 
|\n\n## Implementation Fidelity vs Result Fidelity\n\n- **Implementation fidelity (this repo target):** architecture/update-path correctness, teach-signal alignment, cadence, chunking semantics, and guardrails.\n- **Result fidelity (deferred):** matching full-paper training scale, data budget, and final benchmark curves.\n- This repo treats implementation fidelity as complete only when mechanism checks/tests pass; result parity is explicitly a separate track.\n\n## Scale Statement (Current vs Paper)\n\n- Current mechanism-auditing and pilot runs are intentionally below the full paper scale.\n- This repo does **not** claim paper-scale result reproduction at current compute/data settings.\n- Maintainer stance: prioritize faithful implementation and auditable behavior first; scale-up remains optional contributor work.\n\n## Paper-Faithful Configs (Usage + Caveats)\n\n| Config | Purpose | Default Algorithm Mode | Caveats |\n|---|---|---|---|\n| `configs/pilot_paper_faithful.yaml` | HOPE-attention mechanism-auditing baseline | `two_pass_stopgrad_updates` | Single-process intended; sets `data.batch_size=1`, `strict_streaming_contract=true`, boundary-target + cache-carry enabled |\n| `configs/pilot_selfmod_paper_faithful.yaml` | HOPE self-mod mechanism-auditing baseline | `two_pass_stopgrad_updates` | Same constraints as above; self-mod paper knobs forced (`self_mod_use_skip=false`, fixed `q`, local conv) |\n\nBoundary-state experimental override:\n- `train.algorithm_mode=boundary_state_grad_through_write`\n- Requires: `online_updates=true`, `per_layer_teach_signal=true`, `use_fast_state=true`, single-process (non-DDP).\n\n## Equation / Mechanism Code Pointers (file:line)\n\n| Paper mechanism | Code pointer |\n|---|---|\n| Teach-signal proxy `dL/dh` via LM head weights | `src/nested_learning/training.py:225` |\n| Per-layer teach signals (`δℓ`) from block outputs | `src/nested_learning/training.py:295` |\n| Online chunk iterators (overlap / boundary-target) | `src/nested_learning/training.py:352`, `src/nested_learning/training.py:369` |\n| Algorithm-mode constraints (including boundary-state experimental mode) | `src/nested_learning/training.py:606` |\n| Online cache/chunk constraint guards | `src/nested_learning/training.py:650` |\n| Online chunked train loop + update pass wiring | `src/nested_learning/training.py:685` |\n| Run-feature telemetry (algorithm + online flags) | `src/nested_learning/training.py:1418` |\n| Checkpoint metadata with algorithm/online flags | `src/nested_learning/training.py:1492` |\n| Tied embedding / LM head weight contract | `src/nested_learning/model.py:156` |\n| Block output capture for δℓ path | `src/nested_learning/model.py:317` |\n| Fast-state init + attention-cache init | `src/nested_learning/model.py:531`, `src/nested_learning/model.py:578` |\n| CMS chunk accumulation + cadence telemetry | `src/nested_learning/hope/block.py:297`, `src/nested_learning/hope/block.py:341`, `src/nested_learning/hope/block.py:365` |\n| CMS partial flush on final chunk | `src/nested_learning/hope/block.py:342`, `src/nested_learning/hope/block.py:941`, `src/nested_learning/hope/block.py:1493` |\n| Surprise gating threshold logic | `src/nested_learning/hope/block.py:567`, `src/nested_learning/hope/block.py:1676` |\n| Differentiable inner-update path toggle | `src/nested_learning/optim/manager.py:109`, `src/nested_learning/optim/manager.py:125` |\n| Test-time memorization with path/threshold controls | `src/nested_learning/memorize.py:169`, `src/nested_learning/memorize.py:292`, 
`src/nested_learning/memorize.py:366` |\n\n## Reproducibility Protocol (Mechanism Track)\n\n1. Environment:\n   - `uv sync --all-extras --dev`\n   - PyTorch `2.9.0`\n2. Determinism:\n   - set `train.seed=<int>`\n   - set `train.deterministic=true` for deterministic smoke runs\n3. Minimal mechanism run:\n   - `uv run python train.py --config-name pilot_paper_faithful train.steps=5`\n4. Optional boundary-state mechanism probe:\n   - `uv run python train.py --config-name pilot_paper_faithful train.algorithm_mode=boundary_state_grad_through_write train.steps=5`\n5. Validation gates:\n   - `uv run ruff check .`\n   - `uv run mypy src`\n   - `bash scripts/checks/run_fidelity_ci_subset.sh`\n   - `uv run pytest -q`\n\n## Community-Reported Remediation Map\n\n- Data split fallback robustness: `docs/data_pipeline.md` + `scripts/data/{train_tokenizer,shard_corpus,filter_corpus}.py`\n- Missing tokenizer/help ergonomics: `scripts/data/run_sample.sh`, `scripts/checks/check_data_script_help.sh`, CI workflow\n- Boundary-state mode guardrails + visibility: `src/nested_learning/training.py` + `tests/test_strict_streaming_contract.py` + `tests/test_boundary_state_training_loop.py`\n- Packaging metadata completeness: `src/nested_learning/training.py` + `scripts/package_pilot_release.sh` + `tests/test_package_release_script.py`\n\n## Acceptance Checklist (Mechanism Fidelity)\n\n- [x] Teach signal uses LM head weights with tied embedding head.\n- [x] Per-layer teach signals (`δℓ`) are available and tested.\n- [x] Online chunked updates support overlap + boundary-target semantics.\n- [x] CMS chunk accumulation/cadence is audited with machine-readable reports.\n- [x] Surprise gating behavior is tested (loss/entropy/l2 paths).\n- [x] Test-time memorization path controls (`paths`, `surprise_threshold`) are implemented and tested.\n- [x] Algorithm mode + online flags are emitted in run telemetry and checkpoint metadata.\n- [x] Data scripts have deterministic split fallback and CI help-smoke coverage.\n- [x] Security/release gates block large/binary artifact leakage.\n- [ ] Full paper-scale result reproduction (explicitly out of current scope).\n\n## Concepts → implementation mapping\n\n### 1) Outer parameters vs inner (“fast”) procedure\n\nIn this codebase:\n- **Outer update** = the standard optimizer step (`optimizer.step()`) on the model parameters after backprop.\n- **Inner update** = memory/fast updates applied *outside* the gradient graph using teach signals (δ), e.g. CMS updates and self‑modifying TITAN updates.\n\nWhere:\n- Outer loop: `src/nested_learning/training.py` (`run_training_loop`)\n- Inner update calls: inside the training loop after backward:\n  - `base_model(tokens, teach_signal=...)` or `base_model(tokens, teach_signals=[...])`\n- The update logic lives in the block implementations:\n  - `src/nested_learning/hope/block.py`\n\n### 2) “Levels” and update frequencies\n\nLevels are represented explicitly as `LevelSpec` entries with independent `update_period`s.\n\nWhere:\n- Specs: `src/nested_learning/levels.py`\n- Config surface (Hydra): `model.titan_level` and `model.cms_levels` in `configs/*.yaml`\n- Enforcement:\n  - Online CMS buffering + update‑period gating in `src/nested_learning/hope/block.py`\n  - Level optimizer tick/step orchestration in `src/nested_learning/optim/manager.py`\n\n### 3) Teach signal alignment (LM head gradient proxy)\n\nThe global teach signal is an approximation to **dL/dh**, where `h` is the hidden state **before** the LM head. 
This approximation must align to the LM head weights.\n\nIn this repo, `h` is explicitly the **post-LayerNorm hidden** (the exact input to `lm_head`), and tests pin this contract.\n\nWhere:\n- Weight tying is explicit: `src/nested_learning/model.py` (`self.lm_head.weight = self.embed.weight`)\n- Teach signal implementation: `src/nested_learning/training.py` (`compute_teach_signal`)\n- Unit coverage: `tests/test_teach_signal.py`\n\n### 4) Per‑layer local error signals (δℓ)\n\nWhen enabled, we compute a teach signal **per block output** (δℓ) via autograd and route it into each block’s update path.\n\nWhere:\n- Block output capture: `src/nested_learning/model.py` (`forward_with_block_outputs`)\n- δℓ computation: `src/nested_learning/training.py` (`_compute_layer_teach_signals`)\n- Routing to blocks: `src/nested_learning/model.py` (`teach_signals=[...]`)\n- Unit coverage: `tests/test_teach_signal.py` (shape + matching expectations)\n\nFlag:\n- `train.per_layer_teach_signal=true`\n\n### 5) Chunked online training (read‑after‑write for *loss*, not just updates)\n\nThis is the core “gradient propagation across frequencies” concern:\n\nIf you compute the LM loss on a full sequence **once**, and only apply memory updates after the backward pass, then later tokens’ loss does not reflect earlier inner updates.\n\nTo make later tokens “see” earlier inner updates during training, we support an **online chunked training** mode:\n- Split the sequence into chunks.\n- For each chunk:\n  1) forward → loss  \n  2) `loss.backward()` **accumulating** gradients across chunks (we do not zero grads per chunk)  \n  3) apply inner updates in `torch.no_grad()`  \n  4) proceed to the next chunk with updated memory\n- At the end, we do a single outer `optimizer.step()`.\n- Chunking supports **one-token overlap** mode and **explicit boundary-target** mode.\n- In fast-state mode, CMS accumulation buffers persist across calls and are finalized (optional partial flush + reset) only when `finalize_updates=true` for the sequence end.\n\nWhere:\n- `src/nested_learning/training.py` (search for `online_updates`)\n\nFlags:\n- `train.online_updates=true`\n- `train.online_chunk_size=0` (auto‑infer a chunk size from the minimum CMS update period)\n\n### 6) CMS update semantics (per‑token δ + sum‑over‑chunk accumulation)\n\nCMS updates are applied using:\n- **per‑token δ targets** (no chunk‑mean broadcast), and\n- **sum‑over‑chunk reduction** for the CMS update loss (rather than mean), which preserves the “accumulate over C tokens” semantics.\n\nWe implement the CMS local objective via a **gradient-shaping construction**:\n- `_chunk_loss()` chooses a target `t = stopgrad(prediction − δ)` so that `∂loss/∂prediction ∝ δ` under the chosen mask and reduction.\n- This matches the paper’s δ-based local learning rule while letting us implement the update via standard autograd.\n- Verified by `tests/test_cms_delta_rule.py`.\n\nWhere:\n- Chunk loss reduction: `src/nested_learning/hope/block.py` (`_chunk_loss`, `cms_chunk_reduction=\"sum\"`)\n- Online buffering by update_period and “pop exactly C tokens”: `src/nested_learning/hope/block.py` (`_CmsBuffer`, `_pop_buffer_chunk`, `_cms_forward_online`)\n- Unit coverage:\n  - `tests/test_cms.py` (online updates affect later tokens; update_period gating)\n\nNotes:\n- In the Hydra configs, CMS chunk reduction / online toggles are mechanism-auditing defaults inside the HOPE block configs. 
They are not currently exposed as top-level YAML keys; changing them requires a small code change.\n- `model.cms_flush_partial_at_end` is exposed because it affects correctness when sequence lengths are not exact multiples of update periods.\n\n### 7) Self‑modifying TITAN path (always‑on)\n\nSelf‑modifying TITAN updates run in the update pass; they do not require the teach signal to be nonzero, but they **do** require an explicit update call (i.e., passing `teach_signal`/`teach_signals` to trigger the update pass).\n\nWhere:\n- `src/nested_learning/hope/block.py` (self‑mod update path)\n- Unit coverage: `tests/test_selfmod_online.py`\n\n### 8) Outer optimizer options (including paper M3)\n\nDefault outer optimizer in configs is practical and reproducible (`optim.type=muon` hybrid with AdamW fallback for 1D params).\n\nPaper option:\n- `optim.type=m3` selects the M3 optimizer (multi‑scale momentum + Newton‑Schulz orthogonalization).\n\nWhere:\n- `src/nested_learning/optim/m3.py`\n- Unit coverage: `tests/test_m3.py`\n\n### 8.1 `nl_l2_precond` mapping assumptions (best-effort)\n\nThe inner deep optimizer variant `nl_l2_precond` is implemented as a rank-1 projection-style preconditioner:\n\n- Context vector `x_t`: repo uses the provided level context (typically mean hidden state over batch/sequence for that update event).\n- Projector: update is projected orthogonal to context via `g - (g·u)u` where `u = x_t / ||x_t||`.\n- This is a best-effort mechanism mapping, not a formal proof of exact paper-equation equivalence under all normalizations/objective variants.\n\nCode + tests:\n- `src/nested_learning/optim/deep.py` (`_nl_precondition`)\n- `tests/test_optim.py`\n\n## Distributed training caveats (important)\n\nMechanism-auditing mode is currently focused on `train.py` (single‑GPU).\n\n- **DDP (`train_dist.py`)**: calls the shared training loop, but explicitly disables:\n  - per‑layer teach signals (`train.per_layer_teach_signal`)\n  - online chunked training (`train.online_updates`)\n  because these require capturing block outputs and applying sequential inner updates in a way that is not yet DDP‑safe in this repo.\n  - If you want to avoid silent fallback behavior, set `train.fail_if_paper_faithful_disabled=true` to raise instead of disabling.\n\n- **FSDP (`train_fsdp.py`)**: currently uses a simpler “offline” update pass with a global teach signal after each outer step. 
It does not yet implement per‑layer δℓ or online chunked training.\n\nIf you need mechanism-auditing semantics at multi-GPU scale, the next engineering task is to port the `train.py` online/per-layer flow to FSDP (or a custom DDP scheme) while keeping correctness tests.\n\n## Verification checklist (fast)\n\nRun the fidelity tests:\n\n```bash\nuv run python scripts/checks/verify_docs_refs.py\n\nuv run pytest \\\n  tests/test_teach_signal.py \\\n  tests/test_cms.py \\\n  tests/test_cms_cross_call.py \\\n  tests/test_cms_flush_partial.py \\\n  tests/test_online_chunking.py \\\n  tests/test_attention_cache.py \\\n  tests/test_eval_state.py \\\n  tests/test_selfmod_online.py \\\n  tests/test_m3.py \\\n  tests/test_residual_mlp_memory.py \\\n  tests/test_selfmod_local_conv.py \\\n  tests/test_selfmod_adaptive_q.py \\\n  tests/test_selfmod_grad_flow.py \\\n  tests/test_hope_selfmod_update_pass.py \\\n  tests/test_cms_delta_rule.py \\\n  tests/test_selfmod_dgd_linear.py \\\n  tests/test_optimizer_param_policy.py \\\n  tests/test_distributed_fail_fast.py \\\n  tests/test_strict_streaming_contract.py \\\n  tests/test_verify_docs_refs.py\n```\n\nConfirm you’re running with the intended features:\n- startup `run_features` should include `train.algorithm_mode=two_pass_stopgrad_updates` and `train.backprop_through_online_writes=false`.\n- training logs should include `teach_signal_norm` and per‑layer update telemetry (e.g. `layer0.cms.cms_fast.grad_norm`) when an update pass runs.\n- streaming semantics should match `docs/STREAMING_CONTRACT.md` for the selected config mode.\n\n## Known gaps / intentionally deferred work\n\n- Full task‑episode meta‑learning evaluation loops are not implemented.\n- Multi‑GPU mechanism-auditing training (online + per-layer δℓ) is not yet implemented; the fully distributed path with boundary-target + attention-cache carry is likewise deferred.\n- Large‑scale results reproduction is not a requirement for claiming mechanism fidelity in this repo.\n
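\n## Appendix: teach-signal alignment sketch (illustrative)\n\nA minimal, repo‑independent illustration of the teach-signal contract in section 3 above: the analytic teach signal `dL/dh`, formed from the LM‑head weights, matches the autograd gradient at the pre‑head hidden states. The repo’s `compute_teach_signal` also handles the CE shift/padding details, which this sketch omits; all names and shapes here are assumptions.\n\n```python\nimport torch\nimport torch.nn.functional as F\n\ntorch.manual_seed(0)\nvocab, dim, tokens = 11, 6, 5\nlm_head = torch.randn(vocab, dim)                 # stands in for the tied LM head weight\nh = torch.randn(tokens, dim, requires_grad=True)  # post-LayerNorm hidden states\ntargets = torch.randint(vocab, (tokens,))\n\nloss = F.cross_entropy(h @ lm_head.T, targets)\nloss.backward()\n\n# Closed-form dL/dh for mean-reduced cross-entropy:\n# ((softmax(logits) - onehot(targets)) @ W) / num_tokens\nprobs = torch.softmax(h.detach() @ lm_head.T, dim=-1)\nprobs[torch.arange(tokens), targets] -= 1.0\nteach = probs @ lm_head / tokens\n\nassert torch.allclose(teach, h.grad, atol=1e-5)\n```\n"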
  },
  {
    "path": "docs/PHASE2_LONG_CONTEXT_COMPARISON.md",
    "content": "# Phase 2 – HOPE-Attention vs Transformer (Long-Context Sanity)\n\nThis repo includes a lightweight Phase‑2 sanity check that compares **HOPE-Attention** (Attention → CMS) against a **baseline Transformer** on synthetic long‑context retrieval prompts.\n\nThe goal is not to claim paper‑level results (that requires large‑scale training), but to provide a **reproducible, implementation-level signal** that:\n\n- HOPE-Attention’s fast-state memorization path can **improve the margin/logprob** of the correct answer on long contexts.\n- The baseline Transformer **cannot**, because it has no in‑context update path.\n\n## What to run\n\nThis uses resolved, eval-friendly configs (no Hydra composition required):\n- `configs/resolved/phase2_pilot_attention_eval.yaml`\n- `configs/resolved/phase2_pilot_transformer_eval.yaml`\n\nAnd uses the init checkpoints generated under `artifacts/checkpoints/phase2_init/` (gitignored).\n\nRun (GPU recommended):\n\n```bash\nUV_LINK_MODE=copy UV_CACHE_DIR=/tmp/uv-cache \\\nuv run python scripts/eval/compare_variants.py \\\n  --a-config configs/resolved/phase2_pilot_attention_eval.yaml \\\n  --a-checkpoint artifacts/checkpoints/phase2_init/hope_attention_step000000.pt \\\n  --b-config configs/resolved/phase2_pilot_transformer_eval.yaml \\\n  --b-checkpoint artifacts/checkpoints/phase2_init/transformer_step000000.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --device cuda:1 \\\n  --output eval/phase2_compare_smoke_lastlayer_metrics.json \\\n  --seed 0 \\\n  --smoke \\\n  --memorize \\\n  --memorize-use-correct-answer \\\n  --memorize-layers last \\\n  --memorize-paths cms_fast\n```\n\n## What to look at\n\nOpen `eval/phase2_compare_smoke_lastlayer_metrics.json` and compare:\n\n- **HOPE-Attention (A)**:\n  - `a.passkey.mean_margin_delta` > 0\n  - `a.niah.niah_256_mean_margin_delta` > 0\n- **Transformer (B)**:\n  - corresponding `*_mean_margin_delta` fields are exactly `0.0`\n\nThis demonstrates a concrete Phase‑2 differentiator at pilot scale: **test‑time learning updates move the model in a direction that improves long‑context answer margins**, and the baseline cannot.\n\n"
  },
  {
    "path": "docs/PHASE_2_PLAN.md",
    "content": "# Phase 2 Plan – Execution & Results Packaging\n\n## Immediate Remediation Tasks (from EX_PHASE_1_CRITIQUE)\n\nBefore resuming large-scale runs, we must land the following **P0 faithfulness fixes** plus high-priority engineering upgrades. Each item lists the concrete code touchpoints, validation criteria, and downstream dependencies.\n\n### 1. Tie LM head weights + correct teach signal\n- **Scope**: `src/nested_learning/model.py`, `src/nested_learning/titan/model.py`, `src/nested_learning/training.py`, unit tests under `tests/`.\n- **Actions**:\n  1. Tie `lm_head.weight` to `embed.weight` for HOPE + TITAN models.\n  2. Update `compute_teach_signal` to:\n     - Use `model.lm_head.weight.detach()` instead of embeddings.\n     - Shift logits/targets to align with CE loss (`logits[:, :-1]` vs `tokens[:, 1:]`).\n     - Pad the teacher signal to maintain sequence length.\n  3. Add `tests/test_teach_signal.py` performing a finite-difference gradient check.\n- **Acceptance**: Unit test passing; manual verification on pilot smoke run logs (teach-signal norms logged).\n\n### 2. Implement CMS chunk accumulation (Eq. 31)\n- **Scope**: `src/nested_learning/cms.py` (or equivalent), `src/nested_learning/levels.py`, new telemetry structs, tests.\n- **Actions**:\n  1. Add per-level ring buffers sized to `update_period`.\n  2. Accumulate gradients/error proxies each step; only trigger optimizer update when buffer is full, then clear.\n  3. Emit `UpdateEvent` metrics (count, L2 norm) per level.\n  4. Unit test verifying exactly one update per `update_period` ticks.\n- **Acceptance**: Tests pass; pilot smoke shows stepped CMS updates in logs.\n\n### 3. Add L2-regression inner update (Eq. 27–29)\n- **Scope**: `src/nested_learning/optim/deep_momentum.py`, model forward hooks to pass `x_t`, tests.\n- **Actions**:\n  1. Introduce `variant=\"nl_l2_precond\"` that computes the rank-1 projector from input activations.\n  2. Route the relevant activations into the optimizer context.\n  3. Config flag in `configs/hope/*.yaml` to enable this variant.\n  4. Toy test: optimization reduces regression objective.\n- **Acceptance**: Unit test + pilot smoke run with `variant` enabled (log preconditioner statistics).\n\n### 4. Enable test-time memorization\n- **Scope**: `scripts/eval/zeroshot.py`, `scripts/eval/niah.py`, `scripts/eval/continual.py`, model eval hooks.\n- **Actions**:\n  1. Add flags (`--memorize`, `--memorize-steps`, `--memory-lr`, `--surprise-threshold`).\n  2. Implement TITAN memory updates (and optional CMS fast level) when `memorize=True`.\n  3. Add synthetic integration test ensuring memorization improves accuracy on a constructed needle task.\n- **Acceptance**: Tests pass; eval scripts produce separate `*_memorize.json` outputs with metrics > baseline on synthetic task.\n\n### 5. PyTorch performance upgrades\n- **Scope**: `src/nested_learning/*.py` (attention, training loop), optim factory.\n- **Actions**:\n  1. Replace `nn.MultiheadAttention` with manual QKV + `torch.nn.functional.scaled_dot_product_attention`, enabling FlashAttention where supported.\n  2. Wrap training step in `torch.autocast(device_type, dtype=torch.bfloat16)`; add config switch.\n  3. Add `torch.compile` (guarded) to model init.\n  4. Use fused AdamW (`fused=True`) for outer optimizer.\n- **Acceptance**: Pilot smoke runtime improves or stays stable; fallback path works on CPU.\n\n### 6. Muon integration\n- **Scope**: `src/nested_learning/optim/factory.py`, configs.\n- **Actions**:\n  1. 
Detect availability of `torch.optim.Muon`.\n  2. Split param groups: matrices → Muon, embeddings/biases/LayerNorm → AdamW.\n  3. Config knob `optim.outer.type = mixed_muon_adamw`.\n  4. Benchmark vs AdamW and log results.\n- **Acceptance**: Pilot smoke runs succeed with Muon; documentation updated.\n\n### 7. Seeding & backend robustness\n- **Scope**: training entrypoints (`train*.py`), `nested_learning/training.py`.\n- **Actions**:\n  1. Add `--seed` (Hydra config) and set Python/NumPy/Torch seeds + DataLoader worker init.\n  2. Auto-select DDP backend (`nccl` for CUDA, `gloo` otherwise); expose override.\n  3. Add CPU DDP smoke job in CI.\n- **Acceptance**: Seed reproducibility test (two runs same seed → identical loss trace); CI job green.\n\n### 8. Documentation & licensing polish\n- **Scope**: `pyproject.toml`, README, release docs.\n- **Actions**:\n  1. Align license declaration with `LICENSE` (Apache-2.0).\n  2. Ensure all referenced scripts are shipped; add `scripts/run_e2e_smoke.sh`.\n  3. Update README with memorization instructions and Muon requirements.\n- **Acceptance**: Lint job confirms license metadata; README diff reviewed.\n\nThese items are **blocking** for Stage 2 long runs. Only after P0 checklist completion do we resume the training/eval roadmap below.\n\n## 1. Training Runs\n1. **Pilot (160M / 3B tokens)**\n   - Objective: confirm stability, log teach-scale findings, generate base checkpoints for eval harnesses.\n   - Actions: run `configs/hope/pilot.yaml` with the full shard mixture; log to W&B and artifacts/.\n2. **Mid-scale (760M / 30B tokens)**\n   - Objective: produce the headline zero-shot/NIAH results.\n   - Actions: run `configs/hope/mid.yaml` (FSDP or DeepSpeed), capture checkpoints every ~50k steps.\n3. **Target (1.3B / 100B tokens)**\n   - Objective: long-context + continual-learning showcase.\n   - Actions: integrate 8k context curriculum, run with DeepSpeed ZeRO-3, checkpoint frequently.\n\n## 2. Evaluation Campaign\n1. **Zero-shot pack** – Use `scripts/eval/zeroshot.py --tasks all` on pilot/mid/target checkpoints; store JSON in `eval/zeroshot_*.json` and plot aggregated table in `docs/experiments_report.md`.\n2. **NIAH curves** – Run `scripts/eval/niah.py` (2048→512k) for each major checkpoint and plot accuracy vs. context length.\n3. **Continual-learning** – Run `scripts/eval/continual.py` across chronological segments; generate forgetting plots and correlate with level clocks.\n\n## 3. Baseline Comparisons\n- Reproduce lighter TITAN/Transformer baselines (reuse refs or simple adaptations) to evaluate on the same data/eval tasks.\n- Log results alongside HOPE for direct comparison in `reports/ablations.md` and W&B dashboards.\n\n## 4. Ablations\n1. Self-modifier on/off.\n2. CMS depth variations (1 vs. 3 vs. 5 levels).\n3. Deep optimizer variants per level.\n4. Attention swap (full vs. sliding-window/DeltaNet).\nRecord commands + metrics in `reports/ablations.md`.\n\n## 5. Documentation & Release\n1. Update `docs/experiments_report.md` with tables/plots.\n2. Record stability tricks and teach-scale notes in `docs/stability_journal.md`.\n3. Prepare a blog/paper draft summarizing architecture, training setup, and results.\n4. Tag a release (`v0.2-stage2-prep`) with checkpoints, configs, eval JSONs.\n\n## 6. Outreach & Community\n- Share follow-up results posts (link to W&B dashboards, zero-shot tables, long-context plots).\n- Invite collaborators for continual-learning and scaling experiments via README/Issues/Discussions.\n\n## 7. 
Tracking\n- Keep `TODO.md` updated per milestone.\n- Use W&B projects for each run (pilot/mid/target) and link them in `docs/stage2_progress.md`.\n
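\n## 8. Appendix – SDPA Attention Sketch (Illustrative)\n\nRemediation item 5 (“PyTorch performance upgrades”) above calls for replacing `nn.MultiheadAttention` with a manual QKV projection plus `torch.nn.functional.scaled_dot_product_attention`. The sketch below shows the general shape of that change in plain PyTorch, which lets the runtime dispatch to FlashAttention kernels where supported; module and dimension names are assumptions, not the repo’s implementation.\n\n```python\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nclass CausalSelfAttention(nn.Module):\n    def __init__(self, dim: int, n_heads: int) -> None:\n        super().__init__()\n        self.n_heads = n_heads\n        self.qkv = nn.Linear(dim, 3 * dim, bias=False)\n        self.proj = nn.Linear(dim, dim, bias=False)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        b, t, d = x.shape\n        q, k, v = self.qkv(x).chunk(3, dim=-1)\n        # [B, T, D] -> [B, H, T, D/H] so SDPA sees per-head slices\n        q, k, v = (z.view(b, t, self.n_heads, d // self.n_heads).transpose(1, 2) for z in (q, k, v))\n        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)\n        return self.proj(out.transpose(1, 2).reshape(b, t, d))\n\nattn = CausalSelfAttention(dim=64, n_heads=4)\nprint(attn(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])\n```\n"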
  },
  {
    "path": "docs/PYPI_TRUSTED_PUBLISHING.md",
    "content": "# PyPI Trusted Publishing Setup\n\nThis repository ships `.github/workflows/release.yml` for OIDC-based publishing.\nUse this checklist once per repository to activate it.\nThe same workflow also publishes a GitHub Release entry (Releases tab) with wheel/sdist/checksum assets for each tag.\n\n## 1) Configure TestPyPI trusted publisher\n\nIn TestPyPI project settings (`nested-learning`):\n- Publisher: **GitHub**\n- Owner: `kmccleary3301`\n- Repository: `nested_learning`\n- Workflow name: `release.yml`\n- Environment: `testpypi`\n\n## 2) Configure PyPI trusted publisher\n\nIn PyPI project settings (`nested-learning`):\n- Publisher: **GitHub**\n- Owner: `kmccleary3301`\n- Repository: `nested_learning`\n- Workflow name: `release.yml`\n- Environment: `pypi`\n\n## 3) Validate release tags\n\n- RC tags (publish to TestPyPI): `vX.Y.ZrcN`\n- Stable tags (publish to PyPI): `vX.Y.Z`\n\nExample:\n```bash\ngit tag v0.2.0rc1\ngit push origin v0.2.0rc1\n```\n\n## 4) Verify workflow permissions\n\n`release.yml` requires:\n- `id-token: write` (for OIDC)\n- `contents: write`\n\nNo long-lived PyPI API tokens are required.\n\n## 5) Recommended first dry-run\n\n1. Create RC tag and publish to TestPyPI.\n2. Create clean virtualenv.\n3. Install package from TestPyPI.\n4. Run:\n   - `python -m nested_learning --help`\n   - `python -m nested_learning doctor --json`\n   - `python -m nested_learning smoke --config-name pilot_smoke --device cpu --batch-size 1 --seq-len 8`\n\n## 6) Verify GitHub release assets\n\nAfter the tag workflow completes, confirm the Releases tab entry for that tag contains:\n- `nested_learning-<version>-py3-none-any.whl`\n- `nested_learning-<version>.tar.gz`\n- `SHA256SUMS.txt`\n\n## 7) Verify GitHub Packages tab (GHCR)\n\nThe repository also ships `.github/workflows/packages.yml`, which publishes:\n- `ghcr.io/<owner>/nested-learning-dist:<tag>`\n\nThis is an OCI artifact bundle for distribution files (`dist/*`) and appears in the GitHub Packages tab.\nUse PyPI for normal `pip install` workflows.\n"
  },
  {
    "path": "docs/STREAMING_CONTRACT.md",
    "content": "# Streaming Contract (Mechanism-Auditing Mode)\n\nThis document defines the exact streaming semantics used by the single-GPU mechanism-auditing path.\n\n## Terms\n\n- `sequence`: one tokenized training example of length `T` used for next-token LM loss.\n- `segment`: one externally provided slice of a longer document (used in eval/inference workflows).\n- `chunk`: one training-time online slice used inside `train.online_updates=true`.\n- `batch`: the `B` sequences processed together by the dataloader.\n- `fast-state context`: per-context mutable memory state (CMS/TITAN/self-mod) used for online updates.\n\n## State Scope and Lifetime\n\n- Base model parameters (meta params): persistent across all steps.\n- Fast-state: initialized per batch in training when `train.use_fast_state=true`.\n- With `data.batch_size>1`, fast-state is currently shared across examples in the same batch.\n- In mechanism-auditing presets we set `data.batch_size=1` to preserve per-context semantics.\n\n## CMS Buffer Lifecycle\n\nFor each CMS level `level_name`:\n\n1. `initialize`: create empty buffer with `inputs`, `teach`, `active`, `count=0`.\n2. `accumulate`: append current tokens and increment `count`.\n3. `boundary update`: while `count >= update_period`, pop exactly `update_period` tokens and apply one update.\n4. `finalize`:\n   - if `cms_flush_partial_at_end=true`, flush remaining partial tokens once.\n   - clear buffer contents and reset count to zero.\n5. `reset`: equivalent to finalize + clear, used at sequence end.\n\n## `finalize_updates` Contract\n\n- `finalize_updates=false`:\n  - accumulate/update only full `update_period` boundaries.\n  - do not partial-flush.\n  - keep pending tokens for the next chunk call.\n- `finalize_updates=true`:\n  - apply normal boundary updates.\n  - optional partial flush (`cms_flush_partial_at_end=true`).\n  - clear per-level CMS buffers after finalize.\n\nTraining uses `finalize_updates=true` only on the last chunk of the sequence.\n\n## Chunk-Boundary Objective Semantics\n\nTwo training modes are supported:\n\n1. **Overlap mode (default)**: one-token overlap between neighboring chunks.\n2. 
**Boundary-target mode**: no overlap; each chunk receives explicit `next_tokens` boundary targets.\n\nExample for tokens `[t0 t1 t2 t3 t4]` and `chunk_size=2`:\n\n- Overlap mode:\n  - chunk 1: `[t0 t1]` contributes pair `t0->t1`\n  - chunk 2: `[t1 t2 t3]` contributes pairs `t1->t2`, `t2->t3`\n  - chunk 3: `[t3 t4]` contributes pair `t3->t4`\n- Boundary-target mode:\n  - chunk 1: `[t0 t1]` + boundary target `t2`\n  - chunk 2: `[t2 t3]` + boundary target `t4`\n  - chunk 3: `[t4]` (no boundary target)\n\nTotal supervised pairs remain `T-1` in both modes (see the illustrative sketch at the end of this document).\n\nBoundary-target mode is enabled with `train.online_boundary_targets=true`. Pairing it with `train.online_carry_attention_cache=true` is the canonical paper-faithful setting for transformer-backed chunked runs in this repo.\n\n## Segment Semantics for Long Documents\n\n- A segment is external input partitioning, not the same as training chunking.\n- Optional attention-state carry is available via model attention cache APIs:\n  - `model.init_attention_cache()`\n  - `model(..., attention_cache=..., return_attention_cache=True)`\n- Training can carry attention state across chunk calls when:\n  - `train.online_boundary_targets=true`\n  - `train.online_carry_attention_cache=true`\n- Fast-memory updates can persist across steps when the caller reuses fast-state.\n\n## Strict Mode\n\nSet `train.strict_streaming_contract=true` to fail fast on known semantics violations:\n\n- distributed training with unsupported paper-auditing features,\n- fast-state with `data.batch_size>1`,\n- `train.online_updates=true` with `train.use_fast_state=false`,\n- non-paper-defined variants under strict paper-auditing expectations,\n- invalid boundary/carry combinations for online chunking.\n\n## Cadence Verification Example\n\nAfter a run that emits JSON metrics, validate a CMS level cadence:\n\n```bash\nuv run python scripts/checks/verify_update_cadence.py \\\n  --log-path logs/mechanism_audit_smoke.json \\\n  --metric-prefix layer0.cms.cms_mid \\\n  --total-tokens 8 \\\n  --update-period 4 \\\n  --output reports/cadence_mechanism_audit_smoke.json\n```\n\nExpected report keys:\n- `ok`\n- `metric_prefix`\n- `expected`\n- `observed`\n- `checks`\n
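\n## Appendix: Chunk-Splitting Sketch (Illustrative)\n\nThe snippet below is a minimal, repo-independent sketch of the two chunking modes, reproducing the worked `[t0..t4]`, `chunk_size=2` example above. It is not the repo’s iterator code; the helper names and the exact boundary rule shown are assumptions chosen to match the worked example.\n\n```python\ndef overlap_chunks(tokens, chunk_size):\n    # First chunk holds chunk_size tokens; later chunks repeat the previous last token.\n    chunks, start = [], 0\n    while start < len(tokens) - 1:\n        end = min(start + chunk_size + (1 if start else 0), len(tokens))\n        chunks.append(tokens[start:end])\n        start = end - 1\n    return chunks\n\ndef boundary_target_chunks(tokens, chunk_size):\n    # Non-overlapping chunks; each chunk also carries the next token as a boundary target.\n    out = []\n    for start in range(0, len(tokens), chunk_size):\n        chunk = tokens[start:start + chunk_size]\n        target = tokens[start + chunk_size] if start + chunk_size < len(tokens) else None\n        out.append((chunk, target))\n    return out\n\ntokens = ['t0', 't1', 't2', 't3', 't4']\nassert overlap_chunks(tokens, 2) == [['t0', 't1'], ['t1', 't2', 't3'], ['t3', 't4']]\nassert boundary_target_chunks(tokens, 2) == [\n    (['t0', 't1'], 't2'),\n    (['t2', 't3'], 't4'),\n    (['t4'], None),\n]\n```\n"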
  },
  {
    "path": "docs/VERSIONING_POLICY.md",
    "content": "# Versioning and Stability Policy\n\nThis repository follows SemVer-style versioning with explicit 0.x constraints.\n\n## Current Phase: 0.x\n\nBefore `1.0.0`, stability guarantees are intentionally limited:\n- `0.x.y` patch releases should be non-breaking for normal workflows.\n- `0.X.0` minor releases may include breaking changes to config schema, defaults, CLI behavior, or checkpoint metadata.\n\n## Public Surface\n\nStable-ish surfaces (prioritized for compatibility):\n- `nl` CLI commands and flags\n- Hydra config schema for primary shipped configs\n- checkpoint sidecar metadata fields used by verification tooling\n\nExplicitly unstable surfaces:\n- internal Python module APIs\n- experimental mechanism paths and ablation-only options\n- ad hoc scripts under `scripts/` unless documented as stable entrypoints\n\n## Breaking Change Handling\n\nWhen a release introduces breakage:\n1. call it out in `CHANGELOG.md`,\n2. include migration notes,\n3. keep old behavior behind compatibility flags where reasonable for at least one minor cycle.\n\n## Golden Environment vs Supported Ranges\n\n- Golden reproduction environment: lockfile-based (`uv lock`, Python 3.12, PyTorch 2.9.x).\n- Package metadata supports broader compatibility ranges for portability.\n- If range installs diverge from golden behavior, prefer golden env for paper-faithful runs.\n\n"
  },
  {
    "path": "docs/compute_plan.md",
    "content": "# Compute Reservation Plan (Stage 2)\n\n## Hardware\n- Cluster: 2× nodes with dual NVIDIA RTX 6000 Ada (48 GB VRAM) + 64-core CPU + 512 GB RAM.\n- Scheduler: Slurm (partition `gpu-a6000`), 2 nodes available concurrently.\n\n## Reservations\n| Phase | Resources | Duration | Window | Purpose |\n|-------|-----------|----------|--------|---------|\n| Pilot run | 1 node (2× A6000) | 3 days | Week 1 (Mon–Wed) | 160 M param sanity run, tokenizer validation |\n| Ablations | 1 node | 2 days | Week 1 (Thu–Fri) | Self-modifier/CMS toggles at pilot scale |\n| Mid-scale | 2 nodes | 10 days | Weeks 2–3 | 760 M training to 30 B tokens + evals |\n| Mid evals | 1 node | 2 days | Week 3 (end) | Zero-shot + NIAH scripts on mid checkpoint |\n| Target warmup | 2 nodes | 3 days | Week 4 (start) | 1.3 B config dry run (short token budget) |\n| Target full run | 2 nodes | 14 days | Weeks 4–6 | 1.3 B / 100 B tokens |\n| Final evals | 1 node | 3 days | Week 6 | Long-context + continual learning |\n\n## Actions\n1. Submit Slurm reservations (`scripts/compute/create_reservations.sh`) for the windows above; tag jobs with `NL-Stage2`.\n2. Pre-stage datasets/token shards on node-local NVMe before each run to avoid network bottlenecks.\n3. Enable checkpoint mirroring to shared storage every 12 hours for resilience.\n4. Maintain utilization log in `reports/compute_usage.md` (to be created after first run).\n"
  },
  {
    "path": "docs/continual_classification_eval.md",
    "content": "# Continual Classification Evaluation (CLINC / Banking77 / DBpedia14)\n\nThe Nested Learning paper highlights **class-incremental continual learning** in the text classification\ndomain (CLINC, Banking77, DBpedia). This repo provides a lightweight, implementation-first harness that\ntreats classification as **generative label selection**:\n\n- Prompt: `Text: ... \\nLabel:`\n- Score each candidate label by log-probability of the label string\n- Optionally apply HOPE/TITAN/CMS **test-time memorization** after each example (fast-state by default)\n\n## Script\n\nUse `scripts/eval/continual_classification.py`.\n\n### Smoke run (CPU)\n\n```bash\nuv run python scripts/eval/continual_classification.py \\\n  --config configs/pilot_smoke.yaml \\\n  --checkpoint artifacts/checkpoints/pilot_smoke/step_000010.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --dataset clinc \\\n  --smoke \\\n  --device cpu \\\n  --output eval/continual_cls_smoke.json\n```\n\n### Memorization-enabled run\n\n```bash\nuv run python scripts/eval/continual_classification.py \\\n  --config configs/hope/pilot_attention.yaml \\\n  --checkpoint artifacts/checkpoints/pilot/step_230000.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --dataset banking77 \\\n  --task-size 10 --train-per-label 25 --eval-per-label 25 \\\n  --memorize --memorize-steps 1 \\\n  --memorize-paths titan,cms_fast \\\n  --memorize-surprise-threshold 0.02 \\\n  --device cuda:0 \\\n  --output eval/continual_cls_banking77.json\n```\n\nNotes:\n- `--task-size` controls class increments (how many labels per task).\n- `--memorize-no-reset` (default) keeps the fast-state across examples/tasks, matching a continual setting.\n- For “pure baseline” continual evaluation, omit `--memorize`.\n\n### Offline / local JSONL\n\nIf you don’t want to rely on HuggingFace downloads, supply a JSONL file:\n\n```bash\nuv run python scripts/eval/continual_classification.py \\\n  --config configs/pilot_smoke.yaml \\\n  --checkpoint artifacts/checkpoints/pilot_smoke/step_000010.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --local-jsonl data/local_continual_fixture.jsonl \\\n  --task-size 3 --train-per-label 2 --eval-per-label 2 \\\n  --smoke --device cpu \\\n  --output eval/continual_cls_local.json\n```\n\nEach line must be: `{\"text\": \"...\", \"label\": \"...\"}`.\n\n## Output\n\nThe JSON contains:\n- `task_accuracy_matrix[i][j]`: accuracy on task `i` evaluated after finishing task `j`\n- `avg_accuracy_final`: average accuracy after the last task\n- `avg_forgetting`: average (`max_acc_i - final_acc_i`) across tasks\n\nThis harness is intentionally lightweight so the community can refine the exact protocol to match the\npaper’s class-incremental schedules and reporting conventions.\n\n## Plotting\n\n```bash\nuv run python scripts/eval/plot_continual_classification.py \\\n  --continual-json eval/continual_cls_banking77.json \\\n  --output reports/plots/continual_cls_banking77.png\n```\n"
  },
  {
    "path": "docs/continual_eval.md",
    "content": "# Continual-Learning Evaluation Guide\n\nUse `scripts/eval/continual.py` to quantify forgetting across streaming segments. Supply:\n\n- `--config`: Hydra config for the HOPE model.\n- `--checkpoints`: ordered list of checkpoint paths (chronological training steps).\n- `--segments-yaml`: YAML describing segment names + shard directories (see `configs/data/continual_segments_sample.yaml`).\n- `--batch-size`, `--max-batches`: evaluation throughput controls (0 = entire shard).\n- `--eval-state-mode`: `reset_per_sample` (default) or `carry_across_samples`.\n- `--eval-use-fast-state` / `--eval-use-attention-cache`: enable inference-time streaming state carry semantics.\n\nExample:\n```bash\nuv run python scripts/eval/continual.py \\\n  --config configs/hope/mid.yaml \\\n  --checkpoints checkpoints/mid/step_000050.pt checkpoints/mid/step_000100.pt \\\n  --segments-yaml configs/data/continual_segments_sample.yaml \\\n  --batch-size 4 --max-batches 20 \\\n  --eval-state-mode carry_across_samples \\\n  --eval-use-attention-cache \\\n  --memorize --memorize-steps 2 \\\n  --memorize-paths titan,cms_fast \\\n  --memorize-surprise-threshold 0.02 \\\n  --output eval/continual_mid.json\n```\n\nWith memorization enabled the output includes baseline vs. memorize cross-entropy, Titan/CMS update stats per segment, the active memory paths, and the surprise threshold used. Adjust `--memorize-paths` (comma-separated) to restrict which levels update (e.g., `titan` only, or `titan,cms_fast`) and `--memorize-surprise-threshold` to replicate the paper’s surprise gating.\n\nNote: memorization uses a per-context fast state by default, so evaluation does not mutate checkpoint weights.\n\nTo visualize forgetting curves:\n\n```bash\nuv run python scripts/eval/plot_forgetting.py \\\n  --continual-json eval/continual_mid.json \\\n  --segment refinedweb_2018 \\\n  --output reports/plots/continual_mid_refinedweb.png\n```\n\nThe plot overlays baseline vs. memorize loss across checkpoints for the chosen segment. For full-scale runs, replace the sample YAML with the production segment list (e.g., chronological Wikipedia shards, MAWI sequences, etc.) and archive both the JSON and plot in your checkpoint report.\n"
  },
  {
    "path": "docs/data_pipeline.md",
    "content": "# Data Pipeline (Stage 2)\n\nThis document explains how to generate tokenizer artifacts and token shards for Stage 2 training.\n\n## Prerequisites\n- Ensure the `uv` environment is synced (`uv sync --all-extras`).\n- Large storage mounted at `data/raw/` and `data/shards/`.\n- HF datasets cache configured with valid credentials if accessing gated sets.\n\n## Dataset acquisition & licensing\nThe Stage 2 mixture mimics RefinedWeb + supplements. Download each source into `data/raw/<source>/` and document provenance before filtering.\n\n| Source | License / Terms | Acquisition Command(s) | Notes |\n|--------|-----------------|------------------------|-------|\n| RefinedWeb / FineWeb proxy | CC BY 4.0 (FineWeb) | `uv run python scripts/data/shard_corpus.py --dataset HuggingFaceFW/fineweb --subset sample-10BT --split train --output data/raw/refinedweb.ndjsonl --limit 20000000` | Keep a copy of the HF dataset card; respect scraping policies. |\n| FineWeb-Edu | CC BY 4.0 (FineWeb) | Use `HuggingFaceFW/fineweb-edu` (e.g., `subset=sample-10BT`) via `scripts/data/filter_corpus.py` + `scripts/data/process_mixture.py`. | Paper-aligned option; prefer long-doc filtering if matching the paper’s setup. |\n| Wikipedia 2023-12 dump | CC BY-SA 3.0 | Download `https://huggingface.co/datasets/wikipedia/20220301.en` via HF CLI or mirror the XML dump. | Use HF `datasets load_dataset` inside the filtering script to avoid storing raw XML. |\n| C4 (en) | ODC-By | `uv run python scripts/data/shard_corpus.py --dataset allenai/c4 --subset en --split train --output data/raw/c4_en.ndjsonl --limit 8000000` | Heavy dataset; ensure disk quota before streaming. |\n| RedPajama CC subset | CC BY | Use `togethercomputer/RedPajama-Data-1T-Sample` or the CC subset tarballs. | Store gzipped JSONL files under `data/raw/redpajama/*.jsonl.gz`. |\n| Code (Stack/Python mix) | Mostly MIT/Apache | Pull from `bigcode/starcoderdata` shards or permissively licensed repos. | Preserve LICENSE metadata per shard (`data/raw/code/LICENSES.md`). |\n\nEvery corpus contribution is tracked in `data/manifest/refinedweb_full_manifest.json`. Regenerate or edit this manifest whenever the mixture changes so downstream runs can validate shard presence and licensing.\n\nTo verify the manifest against local shards:\n\n```bash\nuv run python scripts/data/validate_mixture.py \\\n  --manifest data/manifest/refinedweb_full_manifest.json \\\n  --output data/mixtures/refinedweb_mix_manifest_report.json\n```\n\nAll raw pulls should include a short README describing the source URL, date retrieved, and any filters applied. Update `docs/data_pipeline.md` whenever the mix changes so downstream users know which corpora are safe to redistribute.\n\n## 1. Train tokenizer (multi-corpus manifest)\n\n```bash\nuv run python scripts/data/train_tokenizer.py \\\n  --manifest configs/data/refinedweb_mixture.yaml \\\n  --vocab-size 32000 \\\n  --output-dir artifacts/tokenizer/refinedweb_mix \\\n  --log-file data/mixtures/refinedweb_mix_tokenizer.json\n```\n\nThe manifest pulls small samples from FineWeb (RefinedWeb proxy), Wikimedia/Wikipedia, AllenAI C4, SlimPajama, and codeparrot code datasets. 
Outputs live in `artifacts/tokenizer/refinedweb_mix/`.\n\n### Sample pipeline note (hard vocab limit)\nWhen training a tokenizer on **tiny local samples**, SentencePiece can fail if it cannot reach the requested `--vocab-size` (default `hard_vocab_limit=true`).\n\nFor the repo’s sample pipeline (`scripts/data/run_sample.sh`), we disable this check:\n\n```bash\nuv run python scripts/data/train_tokenizer.py ... --no-hard-vocab-limit\n```\n\nFor “paper-faithful” runs, prefer training on a sufficiently large corpus and keep the default `--hard-vocab-limit`.\n\n### Tokenizer checksum\nRecord the checksum of every published tokenizer so collaborators can verify integrity before launching runs.\n\n```bash\nuv run python scripts/data/check_tokenizer.py \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --expected-sha256 f8871517ca968839bf6b9595a6e7891e6b8c6a70fd4df788696bce35be62d6c2 \\\n  --metadata-json artifacts/tokenizer/refinedweb_mix/checksum.json\n```\n\nThe command prints the SHA-256 digest and writes a JSON record (optional). Keep the expected hash in this doc so CI/scripts can assert integrity. Update the hash whenever the tokenizer is retrained.\n\n### Coverage sanity check\nBefore publishing a tokenizer, capture coverage metrics on a representative sample:\n\n```bash\nuv run python scripts/data/check_tokenizer_coverage.py \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --sample-file data/filtered/refinedweb_en_sample.txt \\\n  --max-lines 5000 \\\n  --output data/mixtures/refinedweb_mix_tokenizer_coverage.json\n```\n\nThe script reports tokens/word, proportion of single-token words, and a histogram of piece lengths. Add the JSON to your release bundle so collaborators can verify coverage.\n\n#### Automated regression guard\nAdd a regression check to CI or pre-release automation to ensure coverage does not drift:\n\n```bash\nuv run python scripts/checks/tokenizer_coverage_guard.py \\\n  --baseline data/mixtures/refinedweb_mix_tokenizer_coverage.json \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --sample-file data/filtered/refinedweb_en_sample.txt \\\n  --max-lines 5000 \\\n  --output data/mixtures/refinedweb_mix_tokenizer_coverage_latest.json\n```\n\nThe guard fails if `avg_tokens_per_word` increases by more than `0.05` or if the single/two-token coverage drops by more than `2 %`. Adjust tolerances via CLI flags if a new tokenizer intentionally changes segmentation. Include the generated JSON in release bundles alongside the manifest validation report.\n\n## 2. Shard mixture components\n\n```bash\nuv run python scripts/data/process_mixture.py \\\n  configs/data/refinedweb_mixture_filtered.yaml \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --log-file data/mixtures/refinedweb_mix_filtered_shards.json\n```\n\nThis iterates over each dataset entry (either streamed from HF or the filtered local files), tokenizes at sequence length 2048, and writes NumPy shards to `data/shards/<dataset>`. Stats (records, sequences, shards, total tokens) are recorded in `data/mixtures/refinedweb_mix_shards_full.json`.\n\n## 3. 
Legacy pilot data\n- `data/shards/tinystories_train/` retains 1,718 shards for unit tests and smoke runs.\n\n## Troubleshooting Matrix\n\n| Symptom | Likely cause | Deterministic fix |\n|---|---|---|\n| `run_sample.sh` cannot find `artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model` | Tokenizer has not been trained yet | Re-run `uv run bash scripts/data/run_sample.sh`; it auto-trains tokenizer when missing |\n| `Bad split: train. Available splits: ['test']` | Dataset exposes a non-`train` split | Use fallback (`FALLBACK_SPLIT=test uv run bash scripts/data/run_full.sh`) or set per-source split env vars like `RW_SPLIT=test` |\n| `Bad split` in tokenizer/shard/filter scripts | Requested split absent in source dataset | Built-in fallback now resolves in order `train -> validation -> test -> first available` and logs available splits |\n| SentencePiece fails to hit requested vocab size on tiny corpora | `hard_vocab_limit=true` with too little data | Use `--no-hard-vocab-limit` for sample runs; keep hard limit for large production corpora |\n| Tokenizer coverage regresses between runs | Different corpus sample or tokenizer settings | Run `scripts/data/check_tokenizer_coverage.py` and `scripts/checks/tokenizer_coverage_guard.py` against baseline JSON |\n\n## 4. Filtering & deduplication\nBefore sharding full-scale corpora, run language filtering + dedup to keep only high-quality English segments:\n\n```bash\nuv run python scripts/data/filter_corpus.py \\\n  --dataset HuggingFaceFW/fineweb \\\n  --subset sample-10BT \\\n  --split train \\\n  --text-column text \\\n  --output-path data/filtered/fineweb_en.txt \\\n  --min-chars 200 \\\n  --max-chars 8000 \\\n  --lang-threshold 0.85\n```\n\nAdjust dataset/subset arguments per manifest entry. The script enforces language probabilities via `langdetect`, performs length screening, and deduplicates using a rolling hash window. Point `scripts/data/process_mixture.py` to these filtered files (or custom dataset definitions) for large-scale processing.\n\n## 4.1 FineWeb-Edu manifests (paper-aligned)\n\nThis repo includes two manifest recipes for FineWeb-Edu:\n- `configs/data/fineweb_edu_mixture_sample.yaml` (subset `sample-10BT`, bounded `max_records`)\n- `configs/data/fineweb_edu_mixture_full.yaml` (subset `sample-100BT`, `seq_len=4096`)\n\nTokenizer training:\n```bash\nuv run python scripts/data/train_tokenizer.py \\\n  --manifest configs/data/fineweb_edu_mixture_sample.yaml \\\n  --vocab-size 32000 \\\n  --output-dir artifacts/tokenizer/fineweb_edu \\\n  --log-file data/mixtures/fineweb_edu_tokenizer_samples.json\n```\n\nSharding:\n```bash\nuv run python scripts/data/process_mixture.py \\\n  configs/data/fineweb_edu_mixture_sample.yaml \\\n  --tokenizer-path artifacts/tokenizer/fineweb_edu/spm_32000_unigram.model \\\n  --log-file data/mixtures/fineweb_edu_sample_shards.json\n```\n\nIf you want to more closely mimic “long document” regimes, filter first (higher `min_chars` / `max_chars`)\nand then switch the manifest entry to `dataset: text` + `data_files: <filtered_file>`. 
\n#### 4.1.1 FineWeb-Edu long-doc filtered sample (turnkey)\n\nFor a concrete, paper-aligned “long document” recipe, use:\n- `configs/data/fineweb_edu_longdoc_filtered_sample.yaml`\n\nStep 1 — create a filtered long-doc file (example settings; tune `min_chars`/`max_chars` to match your needs):\n\n```bash\nuv run python scripts/data/filter_corpus.py \\\n  --dataset HuggingFaceFW/fineweb-edu \\\n  --subset sample-10BT \\\n  --split train \\\n  --text-column text \\\n  --target-lang en \\\n  --lang-threshold 0.85 \\\n  --min-chars 2000 \\\n  --max-chars 20000 \\\n  --limit 5000 \\\n  --output-path data/filtered/fineweb_edu_longdoc_en_sample.txt \\\n  --force-exit\n```\n\nStep 2 — train a tokenizer on that filtered file:\n\n```bash\nuv run python scripts/data/train_tokenizer.py \\\n  --manifest configs/data/fineweb_edu_longdoc_filtered_sample.yaml \\\n  --vocab-size 32000 \\\n  --output-dir artifacts/tokenizer/fineweb_edu_longdoc \\\n  --log-file data/mixtures/fineweb_edu_longdoc_tokenizer_samples.json\n```\n\nStep 3 — shard into tokenized `.npy` shards:\n\n```bash\nuv run python scripts/data/process_mixture.py \\\n  configs/data/fineweb_edu_longdoc_filtered_sample.yaml \\\n  --tokenizer-path artifacts/tokenizer/fineweb_edu_longdoc/spm_32000_unigram.model \\\n  --log-file data/mixtures/fineweb_edu_longdoc_sample_shards.json\n```\n\nAll outputs (`data/filtered/`, `data/shards/`, `artifacts/tokenizer/`) are gitignored.\n\n## 5. Artifacts & stats\n- Tokenizer samples: `data/mixtures/refinedweb_mix_tokenizer.json`\n- Shard stats (pilot stream): `data/mixtures/refinedweb_mix_shards.json`\n- Shard stats (filtered sample run): `data/mixtures/refinedweb_mix_filtered_shards.json`\n- Shard stats (full filtered run, seq_len=2048): `data/mixtures/refinedweb_mix_shards_full.json`\n- Latest corpus verification log: `logs/data_inventory_2025-11-10.md` (matches `data/mixtures/refinedweb_mix_full_shards.json` with `verified_at_utc` timestamp).\n- Tokenizer model: `artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model`\n- Continual-learning sample segments: `configs/data/continual_segments_sample.yaml`\n\n## 6. Next steps\n- Integrate the full shards into the training configs (see `configs/hope/mid.yaml`, `configs/hope/target.yaml`).\n- Automate periodic re-generation (e.g., weekly) if new data arrives.\n- Version mixture manifests and stats under `configs/data/` as recipes evolve.\n"
  },
  {
    "path": "docs/env_matrix.md",
    "content": "# Environment Matrix – Stage 2\n\nThis document captures the exact runtime state used for the Stage 2 sprint so collaborators can reproduce the setup without guesswork.\n\n## 1. Runtime Summary\n\n| Component | Version | Notes / Verification |\n|-----------|---------|----------------------|\n| OS | Ubuntu 22.04 LTS (kernel 6.x) | `cat /etc/os-release` (see host) |\n| Python | 3.12.2 (conda-forge build) | `uv run python -V` |\n| uv | 0.9.8 | `uv --version` |\n| PyTorch | 2.9.0+cu128 | `uv run python -c \"import torch; print(torch.__version__)\"` |\n| torchvision | 0.24.0+cu128 | `uv run python -c \"import torchvision; print(torchvision.__version__)\"` |\n| torchaudio | 2.9.0+cu128 | `uv run python -c \"import torchaudio; print(torchaudio.__version__)\"` |\n| CUDA runtime | 12.8 (PyTorch wheels) | `uv run python -c \"import torch; print(torch.version.cuda)\"` |\n| NVIDIA driver | 550.90.07 | `nvidia-smi --query-gpu=name,driver_version --format=csv` |\n| GPUs | 2 × NVIDIA RTX 6000 Ada (49 GB) | Prefer `cuda:1` for single-GPU jobs |\n\n## 2. uv / Dependency Management\n- `pyproject.toml` + `uv.lock` pin all Python dependencies.\n- Sync command: `uv sync --all-extras`.\n- When installing torch 2.9 manually: `uv pip install torch==2.9.0 torchvision==0.24.0 torchaudio==2.9.0 --index-url https://download.pytorch.org/whl/cu128`.\n- Cache guidance: set `UV_CACHE_DIR=/tmp/uv-cache` if the default path lacks space.\n\n## 3. GPU Usage Notes\n- Default to `cuda:1` for long single-GPU training/eval to avoid interfering with tmux sessions pinned to GPU 0.\n- Distributed jobs use `torchrun --nproc_per_node=2` with both GPUs.\n- Driver 550.90.07 is confirmed compatible with the PyTorch 2.9.0/cu128 wheel (the CUDA 12.8 runtime ships with the wheel); no additional toolkit install needed.\n- Set `NCCL_IB_DISABLE=1` if networking errors appear (not observed yet).\n\n## 4. Verification Checklist\nRun the following snippet after provisioning a new machine to confirm parity:\n\n```bash\nuv --version\nuv run python -V\nuv run python - <<'PY'\nimport torch, torchvision, torchaudio\nprint('torch', torch.__version__, 'cuda', torch.version.cuda)\nprint('torchvision', torchvision.__version__)\nprint('torchaudio', torchaudio.__version__)\nprint('device0', torch.cuda.get_device_name(0))\nPY\nnvidia-smi --query-gpu=name,memory.total,driver_version --format=csv,noheader\n```\n\nRecord the outputs in `logs/env_checks/<date>.txt` before running large jobs.\n\n## 5. Known Good Combinations\n| Stack | Status | Notes |\n|-------|--------|-------|\n| torch 2.9.0 + torchvision 0.24.0 + CUDA 12.8 | ✅ | Current default; supports FlashAttention and Muon optimizers. |\n| torch 2.9.0 + torchvision 0.23.x | ❌ | Version mismatch; torchvision 0.23 expects torch 2.8. |\n| torch 2.5.0 + torchvision 0.20.0 | ✅ legacy | Use only if targeting older runs (no Muon support). |\n\n## 6. Process\n1. Clone the repo → `git clone https://github.com/kmccleary3301/nested_learning.git`.\n2. `cd nested_learning && uv sync --all-extras`.\n3. Verify versions via the checklist above.\n4. Export `WANDB_API_KEY` from `git.env` (sourced manually) before training.\n5. Launch jobs via `uv run ...` to guarantee the locked environment.\n\nKeeping this matrix current prevents silent drift when PyTorch or CUDA releases change. Update it whenever `uv.lock` or the driver stack changes.\n"
  },
  {
    "path": "docs/experiments_report.md",
    "content": "# Experiments Report – Nested Learning Reproduction\n\n_Draft covering work completed through 9 Nov 2025. This document is meant to accompany the initial public release so contributors understand what has been reproduced and what remains._\n\n---\n\n## 1. Overview\n- **Goal:** Reproduce key aspects of Google's Nested Learning (HOPE) architecture using public tooling (`uv`, PyTorch 2.9.0) and release a community-ready codebase.\n- **Hardware:** Dual RTX 6000 Ada (49 GB each). All long-running experiments in this report use a single GPU (`cuda:1`) to accommodate other projects on the host.\n- **Data:** Filtered RefinedWeb mixture (FineWeb, Wikipedia, C4, SlimPajama, CodeParrot). Sample pipeline (`scripts/data/run_sample.sh`) for smoke tests; full pipeline (`scripts/data/run_full.sh`) for larger runs. Tokenizer: SentencePiece unigram 32k.\n\n---\n\n## 2. Experimental Setup\n| Component | Details |\n|-----------|---------|\n| Framework | PyTorch 2.9.0, CUDA 12.8 |\n| Dependency Mgmt | `uv` with `pyproject.toml` + `uv.lock` |\n| Logging | JSON logs under `logs/` (W&B optional but disabled for release) |\n| Training Driver | `train.py` (single GPU), `train_dist.py` (torchrun) |\n| Evaluation | `scripts/eval/zeroshot.py`, `scripts/eval/niah.py`, `scripts/eval/continual.py` |\n| Teach Signal | Outer teach signal derived from the logits residual; scale/clip adjustable per config with runtime scheduling |\n\n### Key Configurations\n1. **HOPE Mid (single GPU)**\n   - Config: `configs/mid_stage2.yaml`\n   - Dim = 768, 18 layers, 12 heads, TITAN level + CMS levels (fast/mid/slow/ultra)\n   - Teach schedule: warmup 60 steps, decay start 140, duration 80 (for the 220-step run)\n   - Gradient clipping applied inside TITAN and CMS blocks\n\n2. **TITAN Baseline**\n   - Config: `configs/mid_titan_baseline.yaml` (`model.type=titan`)\n   - Same backbone (attention + TITAN memory) but no CMS/self-mod update path\n   - Teach schedule mirrors the HOPE run to enable apples-to-apples comparison\n\n---\n\n## 3. Experiments\n\n### 3.1 Data Pipeline Validation\n| Command | Purpose |\n|---------|---------|\n| `uv run bash scripts/data/run_sample.sh` | Smoke-friendly filtering + sharding (RefinedWeb/Wiki/C4/SlimPajama/Code) |\n| `RW_LIMIT=20000 ... uv run bash scripts/data/run_full.sh` | Full pipeline (run in tmux `data_full`) to produce `_full` shards |\n| `uv run python scripts/data/process_mixture.py configs/data/refinedweb_mixture_full.yaml ...` | Re-sharding with the SentencePiece tokenizer |\n\nArtifacts: `data/filtered/*_full.txt`, `data/shards/*_full`, stats in `data/mixtures/refinedweb_mix_full_shards.json`.\n\n- Manifest validation: `data/manifest/refinedweb_full_manifest.json` lists every corpus (shard dir, license, download URL). Running `uv run python scripts/data/validate_mixture.py --manifest ...` produces overlap and size stats (`data/mixtures/refinedweb_mix_manifest_report.json`) so we can spot missing/duplicate shards before training.\n- Tokenizer coverage: `scripts/data/check_tokenizer_coverage.py` now emits coverage JSON (`data/mixtures/refinedweb_mix_tokenizer_coverage.json`). On the filtered RefinedWeb sample the 32k unigram tokenizer averages 1.34 tokens/word with ~77% single-token words, confirming adequate coverage before scaling runs.\n\n### 3.2 HOPE vs TITAN (single GPU, 220 steps)\nAll runs below use batch size 4, optimizer LR 1e‑5, teach_scale 0.10, teach_clip 4.0, and the runtime schedule (warmup 60, decay 140→220). 
Commands launched via tmux to keep the CLI free.\n\n| Model | Checkpoint | PIQA (128) | Winogrande (128) | Notes |\n|-------|------------|------------|------------------|-------|\n| HOPE | `artifacts/checkpoints/mid_stage2_ts10_single220_schedD/step_000220.pt` | 0.469 | 0.594 | Loss drops from 10.55 → 8.55; NIAH still ~0 |\n| TITAN | `artifacts/checkpoints/mid_titan_baseline/step_000200.pt` | 0.469 | 0.594 | Loss similar; continuous memory absent |\n\nNIAH results (`eval/niah_mid_stage2_ts10_single220_schedD.json`, `eval/niah_mid_titan_baseline.json`) remain near random at 2k/4k tokens for both models. Continual-learning logs are finite but noisy (short runs). A longer training window is needed to expose the advantages cited in the paper (e.g., HOPE surpassing TITAN on long-context recall).\n\n### 3.3 Teach-Scale Sweep (short runs)\n| teach_scale | Configuration | Checkpoint | Final loss (step 40) |\n|-------------|---------------|------------|----------------------|\n| 0.05 | `logs/mid_stage2_single_ts05.json` | `artifacts/checkpoints/mid_stage2_single_ts05/step_000040.pt` | 9.81 |\n| 0.10 | `logs/mid_stage2_single_ts10.json` | `artifacts/checkpoints/mid_stage2_single_ts10/step_000040.pt` | 9.77 |\n| 0.20 | `logs/mid_stage2_single_ts20.json` | `artifacts/checkpoints/mid_stage2_single_ts20/step_000040.pt` | 9.76 |\n\nEven at 0.20, residual clipping kept the run stable, indicating headroom for larger teach scales once the data window grows.\n\n### 3.4 Dual-GPU Smoke (HOPE)\n| Command | Output |\n|---------|--------|\n| `uv run torchrun --nproc_per_node=2 train_dist.py --config-name mid_stage2_smoke` | `artifacts/checkpoints/mid_stage2_smoke/step_000060.pt`, `logs/mid_stage2_smoke.json` |\n| `uv run python scripts/eval/zeroshot.py ...` | `eval/zeroshot_mid_stage2_smoke.json` |\n| `uv run python scripts/eval/niah.py ...` | `eval/niah_mid_stage2_smoke.json` |\n| `uv run python scripts/eval/continual.py ...` | `eval/continual_mid_stage2_smoke.json` |\n\nThese runs validate the distributed training/eval path and are the recommended “smoke” workflows for contributors.\n\n### 3.5 Test-Time Memorization Harness\nHOPE/TITAN models now support TITAN-style test-time learning via shared CLI flags:\n\n```\nuv run python scripts/eval/zeroshot.py \\\n  --config configs/mid_stage2_smoke.yaml \\\n  --checkpoint artifacts/checkpoints/mid_stage2_smoke/step_000060.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --tasks piqa \\\n  --max-samples 32 \\\n  --output eval/zeroshot_mid_stage2_smoke_piqa_mem.json \\\n  --device cuda:1 \\\n  --memorize \\\n  --memorize-steps 2 \\\n  --memorize-use-correct-answer\n```\n\nNIAH and continual harnesses expose analogous options (`--memorize`, `--memorize-steps`, `--memorize-no-reset`, `--memorize-use-correct-answer`). 
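\n\nIn pseudocode terms, the memorize path amounts to the following (a sketch only — `reset_memory_state`, `teach_step`, and `score` are illustrative names, not the repo’s API):\n\n```python\n# Sketch of the test-time memorization loop (illustrative names).\ndef evaluate_with_memorization(model, samples, steps=2, use_answer=False,\n                               reset=True):\n    results = []\n    for sample in samples:\n        if reset:  # --memorize-no-reset would skip this\n            model.reset_memory_state()       # fresh fast state per sample\n        context = sample.prompt + (sample.answer if use_answer else \"\")\n        for _ in range(steps):               # --memorize-steps\n            model.teach_step(context)        # replay through the teach-signal path\n        results.append(model.score(sample))  # then score as usual\n    return results\n```\n\n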
The memorization loop replays the prompt (optionally augmented with the correct answer) through the teach-signal pathway before each eval query, letting us probe TITAN-style “learning at test time”.\n\nPilot PIQA example (32-sample subset, single GPU):\n\n| Mode | Command / Output | Accuracy |\n|------|------------------|----------|\n| Baseline | `eval/zeroshot_mid_stage2_smoke_piqa_baseline.json` | 0.5625 |\n| Memorize (prompt + answer, 2 steps) | `eval/zeroshot_mid_stage2_smoke_piqa_mem.json` | 0.5625 |\n\nAt this scale, memorization neither helps nor hurts, but the infrastructure is in place to replicate the substantial gains reported in HOPE/TITAN once longer contexts and richer checkpoints are available.\n\n### 3.6 Long-context diagnostics (pilot step 230k)\n- **Passkey retrieval (`eval/passkey_pilot_step230000.json`):** 64 prompts with 256 filler sentences each. Accuracy baseline vs memorize is flat at 0.484 while Titan updates average ~2.13 (CMS-fast disabled). This confirms the harness works but also shows we need longer training to see the passkey delta reported in the paper.\n- **PG-19 perplexity (`eval/pg19_pilot_step230000.json`):** Streaming PG-19 excerpts truncated to 2048 tokens yield PPL ≈ 2.5k for both baseline and memorize settings (4 samples). The script is part of the pilot suite so future checkpoints can report comparable long-form perplexities out-of-the-box.\n\n### 3.7 Continual forgetting plots\n`scripts/eval/continual.py` now records both baseline and memorize CE per segment. Running it on checkpoints `[5k, 10k, 230k]` and passing the JSON into `scripts/eval/plot_forgetting.py` produces `reports/plots/continual_pilot_refinedweb.png`, which shows continual CE dropping from ~48 at step 5k to ~8 at step 230k on the RefinedWeb segment (memorization on). These plots will accompany every checkpoint report going forward.\n\n### 3.8 Pilot (3 B tokens) – 230 k-step snapshot\n- **Config:** `configs/pilot.yaml` (dim 512, 12 layers, TITAN + CMS fast/mid/slow/ultra, teach_schedule warmup 2k → decay 120k→140k). Train batch = 6, seq_len = 2048, Muon optimizer, bf16 autocast + SDPA + `torch.compile`.\n- **Run status:** The HOPE pilot reached step 246 667 (≈3.0 B tokens). We package the step 230 000 checkpoint as the release artifact because it predates the LR cooldown and logged stable eval metrics.\n- **Metrics (memorization enabled, 256-sample cap per task):**\n\n  | Eval | HOPE (step 230k) | TITAN (step 9k, reference) |\n  |------|------------------|----------------------------|\n  | PIQA | **0.496** | 0.492 |\n  | HellaSwag | 0.297 | – |\n  | Winogrande | 0.473 | – |\n  | ARC-E / ARC-C | 0.285 / 0.234 | – |\n  | BoolQ | 0.367 | – |\n  | SIQA | 0.316 | – |\n  | CommonSenseQA | 0.180 | – |\n  | OpenBookQA | 0.113 | – |\n  | NIAH (2 k → 65 k) | 0.625 / 0.50 / 0.375 / 0.50 / 0.75 / 0.50 | 0.50 @ 2–8 k |\n  | Continual CE (RefinedWeb/Wiki/C4/RP) | 8.06 / 7.79 / 7.68 / 7.95 | 12–14 |\n\n- **Packaging:** `artifacts/pilot_release/` mirrors the 230 k checkpoint (`checkpoint.pt`), config snapshot, pilot logs, metadata with the 3 B-token goal, and eval JSONs (legacy step 22 k + new step 230 k). 
TITAN short-run metrics remain bundled.\n- **Next:** With both HOPE (step 230 k) and TITAN (step 25 k) packaged, the immediate tasks are (1) run the queued ablations (teach-scale, CMS chunking, optimizer swaps) on the HOPE checkpoint tree, and (2) extend evaluation coverage to larger configs before resuming the HOPE long run past 246 k steps.\n\n- **TITAN baseline (25 k steps):** The long run on `configs/mid_titan_baseline.yaml` wrapped at step 25 000 (`artifacts/checkpoints/mid_titan_baseline/step_025000.pt`, W&B `titan-long-20251113192738`). Fresh evals (memorization on, 256 max samples) show:\n\n  | Eval | TITAN (step 25k) |\n  |------|------------------|\n  | PIQA / HellaSwag / Winogrande | 0.484 / 0.293 / 0.480 |\n  | ARC-E / ARC-C / BoolQ / SIQA | 0.281 / 0.250 / 0.398 / 0.293 |\n  | CSQA / OpenBookQA | 0.188 / 0.145 |\n  | NIAH (2 k → 65 k) | 0.50 / 0.625 / 0.125 / 0.75 / 0.50 / 0.125 |\n  | Continual CE (RefinedWeb/Wiki/C4/RP) | 8.36 / 8.12 / 7.85 / 8.11 |\n\n  Outputs live in `eval/zeroshot_titan_step25000.json`, `eval/niah_titan_step25000.json`, `eval/continual_titan_step25000.json` (also copied into `artifacts/pilot_release/` alongside `titan_step_025000.pt`). These numbers now provide the matched baseline for HOPE step 230 k comparisons and upcoming ablations.\n\n### 3.9 HOPE Pilot Relaunch (toward step 250 k, surprise-gated)\n\n- **Config:** `configs/pilot.yaml` with Muon outer optimizer, `nl_l2_precond` inner variant, `teach_scale=0.10`, `surprise_threshold=0.02`.\n- **Checkpoint:** `artifacts/checkpoints/pilot_relaunch/step_477000.pt` (verified via `scripts/checkpoint/verify.py`; sidecars stored alongside the checkpoint).\n- **Eval suite:** `eval/zeroshot_pilot.json`, `eval/niah_pilot.json`, `eval/continual_pilot.json`, `eval/passkey_pilot.json`, `eval/pg19_pilot.json`.\n- **Report:** `reports/checkpoints/pilot_relaunch_step477000.md`.\n- **Note:** with `surprise_threshold=0.02` the memorize harness recorded 0 update events on these short prompts, so memorization deltas are ≈0 (expected for this gated configuration).\n\n### 3.10 TITAN Long Baseline Relaunch (toward step 25 k)\n\n- **Config:** `configs/mid_titan_baseline.yaml`, `teach_scale=0.10`, `surprise_threshold=0.02`.\n- **Checkpoint:** `artifacts/checkpoints/mid_titan_long/step_032000.pt` (verified via `scripts/checkpoint/verify.py`; sidecars stored alongside the checkpoint).\n- **Eval suite:** `eval/zeroshot_titan.json`, `eval/niah_titan.json`, `eval/continual_titan.json`, `eval/passkey_titan.json`, `eval/pg19_titan.json`.\n- **Report:** `reports/checkpoints/titan_long_step32000.md`.\n- **Note:** with `surprise_threshold=0.02` the memorize harness recorded 0 update events on these short prompts, so memorization deltas are ≈0 (expected for this gated configuration).\n\n---\n\n## 4. Observations & Lessons Learned\n1. **NaNs past 80 steps:** Early runs blew up after 80 steps once teach_scale exceeded 0.05. Introducing runtime scaling + residual clipping inside TITAN/CMS eliminated the NaNs and allowed 220-step runs on a single GPU.\n2. **Batch-size constraints:** With only one GPU, we reduced the per-GPU batch to 4 to stay within 49 GB VRAM. DDP runs will need gradient checkpointing or FSDP to scale further.\n3. **NIAH is data hungry:** Every HOPE/TITAN run so far shows near-random recall at 2k/4k tokens; longer contexts and more tokens are required to differentiate the architectures.\n4. **Teach signal scheduling:** A linear warmup (60 steps) followed by linear decay (start 140) kept the 220-step run stable; the shape is sketched below. Future runs should explore cosine or per-level schedules.\n
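\nA minimal sketch of that schedule (illustrative only; the function name is not the repo’s API):\n\n```python\ndef teach_scale_at(step: int, base: float = 0.10, warmup: int = 60,\n                   decay_start: int = 140, decay_steps: int = 80) -> float:\n    \"\"\"Linear warmup to `base`, hold, then linear decay to zero by step 220.\"\"\"\n    if step < warmup:\n        return base * step / warmup\n    if step < decay_start:\n        return base\n    return base * max(0.0, 1.0 - (step - decay_start) / decay_steps)\n```\n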
\n---\n\n## 5. Limitations\n- Current comparisons cover only the 160 M-scale HOPE/TITAN pair; larger configs (760 M / 1.3 B) remain untrained.\n- Scaling beyond the pilot is still blocked on additional compute + stability sweeps for teach_scale, CMS depth, and optimizer variants.\n- DDP/TITAN runs still rely on JSON logging; integration with structured logging (e.g., W&B) is deferred to future contributors.\n- The pipeline uses filtered RefinedWeb proxies; exact data parity with Google’s internal corpora is not guaranteed.\n\n---\n\n## 6. Next Steps\n1. **Longer Runs:** Extend both HOPE and TITAN baselines to billions more tokens using FSDP/DeepSpeed (target the ≥760 M parameter config).\n2. **Eval Coverage:** Integrate the full RAFT/ARC suite plus additional long-context datasets (Needle-in-a-Haystack 32k, PassKey tasks).\n3. **HPO:** Once stable runs exist, sweep teach_scale/clip, CMS depth, and self-mod learning rates to quantify HOPE vs TITAN gains.\n4. **Automation:** Add CI for data sampling + dual-GPU smoke to catch regressions, and consider nightly tmux scripts for longer training jobs.\n\n---\n\n## 7. References\n- `docs/stage2_progress.md` – running log of all Stage 2 work.\n- `docs/stability_journal.md` – chronological notes on NaN fixes, teach-scale tuning, tmux jobs.\n- `reports/stage2_smoke.md` – command cheat sheet for reproducing the smoke runs referenced here.\n\nThis report will be updated as we push beyond short runs and start reproducing the full metrics from Google's Nested Learning paper.\n"
  },
  {
    "path": "docs/future_directions.md",
    "content": "# Future Directions – Nested Learning Reproduction\n\nThis roadmap outlines high-impact areas for contributors once the initial public release is out. Items are organized by theme and roughly prioritized.\n\n---\n\n## 1. Scaling the Architecture\n1. **Longer Runs (≥3B tokens):** Use FSDP or DeepSpeed ZeRO to train the 760 M config on the filtered `_full` shards. Target at least 3B tokens so HOPE’s long-context advantages can emerge.\n2. **Target Config (1.3 B / 100 B tokens):** Prepare configs and launcher scripts for multi-node environments (Slurm, Kubernetes). Emphasize reproducible manifests and resume logic.\n3. **Context Expansion:** Integrate FlashAttention2 or block-sparse attention to push context lengths beyond 32k tokens. Update `scripts/eval/niah.py` accordingly.\n\n## 2. Evaluation & Analysis\n1. **Full Benchmark Suite:** Extend `scripts/eval/zeroshot.py` to include ARC-E/C, BoolQ, and SIQA by default with standard prompts. Automate results aggregation into Markdown tables.\n2. **Long-Context Benchmarks:** Add Passkey, PG19, and retrieval tasks in addition to Needle-in-a-Haystack.\n3. **Continual Learning:** Create larger segment manifests (e.g., Wikipedia by year) and compute forgetting metrics across dozens of checkpoints.\n\n## 3. Optimization & HPO\n1. **Teach-Scale Scheduling:** Explore cosine or per-level schedules; integrate gradient clipping hyperparameters through Hydra sweeps.\n2. **Optimizer Variants:** Try Muon/DeepMomentum for TITAN/CMS updates. Compare against simple SGD/Adam baselines.\n3. **Automated Sweeps:** Wire up lightweight HPO (Ray Tune, Ax) for pilot configs to test teach_scale, clip, and CMS depth combinations.\n\n## 4. Data & Tooling\n1. **Dataset Expansion:** Add book/video/code corpora, ensure licensing compliance, and document provenance.\n2. **Tokenizer Experiments:** Evaluate alternative vocab sizes or SentencePiece BPE to see if certain domains benefit.\n3. **CI Enhancements:** Add GPU-aware smoke tests (e.g., a GitHub self-hosted runner) to catch regressions in dual-GPU workflows.\n\n## 5. Documentation & Community\n1. **Release Notes:** Publish structured release notes with each tagged version (capabilities, limitations, roadmap).\n2. **Contributor Guides:** Document coding standards, logging conventions, and how to submit new configs/evals.\n3. **Experiment Tracking:** Encourage use of the `docs/experiments_report.md` template for all major runs to keep the public record up to date.\n\n---\n\nContributors are welcome to pick any of these items (or propose new ones) via GitHub issues or pull requests. Please cross-reference this file so efforts stay coordinated.\n"
  },
  {
    "path": "docs/phase2_comparison.md",
    "content": "# Phase 2 – HOPE-Attention vs Transformer Baseline\n\nPhase 2 is “implementation-complete” when we can compare the **paper-defined HOPE-Attention** variant\n(`Attention → CMS`) against a **standard Transformer** baseline (`Attention → MLP`) using the same\ntokenizer, context lengths, and evaluation harness.\n\nThis does **not** require paper-scale training; it’s intended for correctness/ergonomics and\nCPU-friendly smoke checks.\n\n## 0) Prerequisites\n\n- A SentencePiece tokenizer at `artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model`.\n  - If missing, run `uv run bash scripts/data/run_sample.sh` (see `docs/guide.md`).\n\n## 1) Smoke checkpoints (CPU)\n\nTrain two tiny smoke checkpoints from the same base config:\n\n```bash\n# HOPE-Attention smoke (paper-defined variant)\nuv run python train.py --config-name pilot_smoke \\\n  model.block_variant=hope_attention \\\n  model.qk_l2_norm=true model.local_conv_window=4 \\\n  train.checkpoint.dir=artifacts/checkpoints/pilot_smoke_attention \\\n  logging.path=logs/pilot_smoke_attention.json\n\n# Transformer baseline smoke\nuv run python train.py --config-name pilot_smoke \\\n  model.block_variant=transformer \\\n  model.qk_l2_norm=true model.local_conv_window=4 \\\n  train.checkpoint.dir=artifacts/checkpoints/pilot_smoke_transformer \\\n  logging.path=logs/pilot_smoke_transformer.json\n```\n\n## 2) Long-context comparison (CPU)\n\nUse the comparison runner (writes a single JSON with both results):\n\n```bash\nuv run python scripts/eval/compare_variants.py \\\n  --a-config configs/pilot_smoke.yaml \\\n  --a-checkpoint artifacts/checkpoints/pilot_smoke_attention/step_000010.pt \\\n  --b-config configs/pilot_smoke.yaml \\\n  --b-checkpoint artifacts/checkpoints/pilot_smoke_transformer/step_000010.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --device cpu \\\n  --smoke \\\n  --output eval/phase2_compare_smoke.json\n```\n\nFor larger GPU-backed pilots, use the dedicated Hydra configs:\n- `configs/hope/pilot_attention.yaml`\n- `configs/hope/pilot_transformer.yaml`\n\nand rerun the comparison script on the resulting checkpoints.\n\n## 3) Adaptation sanity check (no training)\n\nThis repo also includes a deterministic unit-level smoke that demonstrates **in-context adaptation**\nexists for `hope_attention` (via CMS fast-state updates) and is absent for `transformer`:\n\n```bash\nuv run pytest -q tests/test_phase2_memorization_delta.py\n```\n\nFor a standalone JSON output (no tokenizer/checkpoints required):\n\n```bash\nuv run python scripts/eval/phase2_memorization_delta_smoke.py --device cpu\n```\n"
  },
  {
    "path": "docs/release_checklist.md",
    "content": "# Release Checklist (Stage 2)\n\nUse this list before tagging/publishing any checkpoint bundle.\n\n## Faithfulness & Tests\n- [ ] `uv run pytest` (especially `tests/test_teach_signal.py`, `tests/test_cms.py`, `tests/test_optim.py`, `tests/test_memorization.py`)\n- [ ] Inner optimizer variant (`nl_l2_precond`) enabled in configs\n- [ ] Teach-signal log shows finite norms across recent steps\n- [ ] CMS chunk telemetry confirms expected update cadence\n- [ ] Run `uv run python scripts/checks/verify_update_cadence.py --log-path <json_log> --metric-prefix <layerN.cms.level> --total-tokens <T> --update-period <C> [--flush-partial] --output reports/cadence_<run>.json`\n- [ ] Run `uv run python scripts/checks/compliance_report.py --config <config.yaml> --cadence-report reports/cadence_<run>.json --output reports/compliance_<run>.json`\n- [ ] Run `uv run python scripts/checks/verify_docs_refs.py` (prevents doc/code reference drift)\n- [ ] `bash scripts/run_cpu_ddp_smoke.sh` (CPU DDP determinism)\n- [ ] `bash scripts/tests/run_passkey_smoke.sh` (synthetic memorization)\n\n## Artifacts\n- [ ] Checkpoint `.pt` + `.yaml` + `.meta.json` (with tokenizer hash) in `artifacts/...`\n- [ ] Checkpoint `.meta.json` includes `algorithm_mode` + online flags (`online_updates`, `online_boundary_targets`, `online_carry_attention_cache`, `use_fast_state`)\n- [ ] Tokenizer model + checksum JSON included\n- [ ] Eval JSON/CSV (zero-shot, NIAH, continual) appended to `eval/`\n- [ ] Checkpoint report filled from `docs/templates/checkpoint_report.md`\n- [ ] Long-context extras (passkey, PG-19) + forgetting plots saved (`eval/passkey_*.json`, `eval/pg19_*.json`, `reports/plots/*.png`)\n- [ ] Run `uv run python scripts/checkpoint/verify.py --checkpoint <path>` on every artifact\n\n## Data & Provenance\n- [ ] `data/manifest/refinedweb_full_manifest.json` updated if mixture changed\n- [ ] `scripts/data/validate_mixture.py --manifest ...` report archived\n- [ ] Tokenizer coverage JSON generated via `scripts/data/check_tokenizer_coverage.py`\n- [ ] Coverage guard run (`scripts/checks/tokenizer_coverage_guard.py`) and JSON attached\n\n## Logging & Monitoring\n- [ ] W&B run link recorded in report\n- [ ] Local JSON logs copied to `logs/`\n- [ ] Memorization stats (surprise counts, Titan/CMS updates) summarized\n\n## Distribution\n- [ ] README references any new scripts/configs\n- [ ] Issue templates / release notes updated if new features shipped\n- [ ] (Optional) Outreach draft prepared in maintainer notes\n- [ ] Release manifest records algorithm mode + online flags for packaged checkpoints\n- [ ] Run `bash scripts/checks/check_git_tracked_sizes.sh` before push/tag (prevents large binaries/artifact extensions from being tracked)\n\nCheck these boxes before pushing tags or announcing new checkpoints so collaborators can reproduce results confidently.\n"
  },
  {
    "path": "docs/scaling_guidance.md",
    "content": "# Scaling Guidance – Nested Learning Reproduction\n\nThis document describes how to extend the current smoke-tested Nested Learning (HOPE) stack to larger datasets, hardware targets, and experiment scopes without changing the core codebase.\n\n---\n\n## 1. Hardware Tiers\n| Tier | GPUs | VRAM | Usage |\n|------|------|------|-------|\n| **Dev / Smoke** | CPU or 1× RTX 6000 Ada | 0–48 GB | Pipeline validation, unit/integration tests |\n| **Pilot** | 2× RTX 6000 Ada | 48 GB each | `configs/hope/pilot.yaml`, seq len ≤2K, ≤5 B tokens |\n| **Mid** | 4–8× RTX 6000 Ada (or single H200) | 48–80 GB | `configs/hope/mid.yaml`, seq len 4K, 30 B tokens |\n| **Target** | Future dual H200 | 141 GB each | `configs/hope/target.yaml`, seq len 8K+, 100 B tokens |\n\nRecommendations:\n- Prefer `cuda:1` on dual-GPU workstations to keep `cuda:0` free for interactive workloads.\n- Enable `torch.set_float32_matmul_precision(\"high\")` on Ampere+/Ada to benefit from TF32 matmul kernels.\n- For >2 GPUs, switch to `train_dist.py` (DDP) or `train_fsdp.py` with `train.fsdp.auto_wrap_min_params` tuned to 2 M for HOPE blocks.\n\n---\n\n## 2. Storage & Data Layout\n| Corpus Slice | Sample (current) | Full target | Disk (approx) |\n|--------------|------------------|-------------|---------------|\n| RefinedWeb / FineWeb proxy | 2k docs → 4 shards | 4 B docs | 1.2 TB |\n| Wikipedia EN | 1k docs | Full dump | 70 GB |\n| C4 EN | 1k docs | 400 M docs | 300 GB |\n| SlimPajama | 1k docs | 600 B tokens | 450 GB |\n| CodeParrot clean | 1k files | 50 B tokens | 200 GB |\n\nScaling procedure:\n1. **Raw ingestion:** Stage compressed corpora under `data/raw/` (ensure ≥3 TB free for target runs).\n2. **Filtering:** Drive `scripts/data/run_full.sh` with env vars (e.g., `RW_LIMIT=1000000 WIKI_LIMIT=250000`) to control per-corpus document counts. The script wraps `filter_corpus.py` + `process_mixture.py` + tokenizer training for the `configs/data/refinedweb_mixture_full.yaml` manifest. Keep `--force-exit` to make failures loud.\n3. **Tokenizer retrain:** Rerun `scripts/data/train_tokenizer.py` with the combined manifest once filtered corpora exceed ~100 M tokens to avoid domain skew. Store models under `artifacts/tokenizer/<mixture_name>/`.\n4. **Sharding:** Update `configs/data/refinedweb_mixture_filtered.yaml` with new `max_records` and `sequence_length` values (e.g., 2048 for pilot, 4096 for mid, 8192 for target). Run `scripts/data/process_mixture.py` pointing at the new tokenizer and filtered text.\n5. **Stats:** Version log files in `data/mixtures/` (`*_shards.json`, `*_tokenizer.json`) for each scale; include total tokens, sequences, and shards to keep a reproducibility audit trail (see the reference sketch at the end of this document).\n
\n---\n\n## 3. Training Scale-Up\n1. **Pilot (≤160 M params):** Use `train.py --config-name hope/pilot` on dual RTX 6000. Set `train.device=cuda:1`, `train.checkpoint.enable=true`, `train.checkpoint.save_interval=500`. Expect ~3 mins per 100 steps on synthetic data.\n2. **Mid (≈760 M params):** Launch via `torchrun --nproc_per_node=2 train_dist.py --config-name hope/mid train.device=cuda train.steps=5000`. Feed mixture shards from `data/shards/*_filtered`. Enable gradient checkpointing if activations exceed memory (see the `model.gradient_checkpointing` flag in the config).\n3. **Target (≥1.3 B params):** Prefer DeepSpeed or FSDP. Example:\n   ```bash\n   torchrun --nproc_per_node=2 train_fsdp.py --config-name hope/target \\\n     train.fsdp.auto_wrap_min_params=4000000 \\\n     train.checkpoint.enable=true train.checkpoint.dir=checkpoints/target\n   ```\n   When the H200 cluster arrives, increase `train.data.seq_len` to 8192 and switch the attention backend to FlashAttention (set `model.attn.impl=flash` in the config).\n4. **Logging:** Route to W&B for long runs (`logging.backend=wandb`). For on-prem runs, keep JSON logs under `logs/<run>.json` and ship them to artifact storage after each training window.\n5. **Optimizer tuning:** Stage 2 experiments call for Muon/DeepMomentum variants. Define new entries under `model.optimizers` keyed by level names; adjust `lr` and `beta` per level clock frequency.\n\n---\n\n## 4. Evaluation Expectations\n| Stage | Checkpoints | Eval commands |\n|-------|-------------|---------------|\n| Smoke | `artifacts/checkpoints/pilot_smoke/step_000010.pt` | `scripts/eval/zeroshot.py` with `--max-samples 32`, NIAH `--samples-per-length 3`, continual sample segments |\n| Pilot | Every 1k steps | `--tasks all`, `niah` with `context-lengths 4k 8k`, continual segments covering refinedweb/wikipedia/code |\n| Mid/Target | Every 1k steps (rotated) | Full tasks + ARC-C, BoolQ, SIQA on the entire validation set, NIAH up to 32k tokens, continual across ≥5 segments |\n\nArchive outputs under `eval/<run>/<timestamp>.json` to make comparisons easy. Use the provided `eval/zeroshot_full_smoke.json`, `eval/niah_smoke.json`, and `eval/continual_smoke.json` as formatting references.\n\n---\n\n## 5. Roadmap to H200 Cluster\n1. **Data:** Mirror filtered corpora to shared storage; ensure tokenizer + shard manifests are versioned with commit SHAs.\n2. **Compute:** Port launcher scripts to Slurm or Kubernetes. Provide templates under `scripts/infra/` (TODO) with environment exports (`MASTER_ADDR`, `MASTER_PORT`).\n3. **Long-context:** Integrate block-sparse or state-space attention when sequence lengths exceed 32k tokens. Keep `context_lengths` in the config and toggle via Hydra overrides.\n4. **Reliability:** Add resumption guidelines in `docs/release_plan.md` (FSDP resume path). Store checkpoints in object storage with lifecycle policies.\n\nThis document should be updated whenever new corpora, hardware, or launch scripts are introduced so contributors can quickly understand how to move beyond the smoke-tested baseline.\n
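\n---\n\n_Reference sketch (for §2, step 5):_ roll the per-scale stats logs up into one audit summary. The JSON field names (`total_tokens`, `sequences`, `shards`) are assumptions — adapt them to whatever `process_mixture.py` actually emits.\n\n```python\n# Sketch: aggregate shard-stats logs into one mixture-wide summary.\nimport json\nfrom pathlib import Path\n\ntotals = {\"total_tokens\": 0, \"sequences\": 0, \"shards\": 0}  # assumed fields\nfor stats_file in sorted(Path(\"data/mixtures\").glob(\"*_shards*.json\")):\n    stats = json.loads(stats_file.read_text())\n    for key in totals:\n        totals[key] += int(stats.get(key, 0))\n    print(stats_file.name, {k: stats.get(k) for k in totals})\nprint(\"mixture totals:\", totals)\n```\n"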
  },
  {
    "path": "docs/spec_interfaces.md",
    "content": "# Interface Notes for Nested Learning Modules\n\n## LevelClock / LevelSpec (`nested_learning.levels`)\n- `LevelSpec`: name, update_period, warmup, jitter, optimizer binding.\n- `LevelClock`: tracks global step, exposes `should_update(name)` and `record_update(name)`; keeps timeline for logging.\n\n## AssocMemory (`nested_learning.assoc_memory`)\n- Abstracts retrieval (`forward`) and writeback (`update`). Concrete memories (TITAN, CMS) implement this along with optional `reset_state` from `SupportsReset` protocol.\n\n## CMS (`nested_learning.cms` – forthcoming)\n- Chain of MLPs with per-level clocks; includes `forward` for retrieval/composition and `maybe_update` for gated parameter updates.\n\n## TITAN Memory (`nested_learning.titan.memory` – forthcoming)\n- Learnable long-term memory approximating lucidrains implementation; provides `score_surprise` + `update` to support self-modifier pathways.\n\n## SelfModifier & Deep Optimizers (`nested_learning.hope.self_mod`, `nested_learning.optim.deep`) \n- SelfModifier: neural updater that emits parameter deltas based on (key, value, error).\n- Deep optimizers: generalize momentum/Adam with pluggable associative memories.\n\n## HOPE Block/Model (`nested_learning.hope.block`, `nested_learning.model`)\n- Composition of attention backbone, TITAN retrieval, CMS consolidation, and self-mod updates, assembled into full autoregressive model.\n\nThese notes ensure consistency across workstreams and act as inline documentation while implementing the remaining modules.\n"
  },
  {
    "path": "docs/sprint_next_plan.md",
    "content": "# Sprint Plan – Stage 2 Pilot & Results Sprint\n\n**Window:** Nov 10 – Nov 17, 2025 (7 days)  \n**Goal:** Produce reproducible pilot-scale HOPE checkpoints + evaluation packs, validate the data/infra path for mid-scale runs, and capture documentation that unlocks Stage 2 scaling.  \n**Success Criteria:**\n1. Pilot (≈160 M params, 3 B tokens) runs end-to-end with checkpoints + W&B logs.\n2. Zero-shot, NIAH, and continual-learning eval JSONs produced for the pilot checkpoint and compared to TITAN baseline.\n3. Data provenance + environment setup documented so collaborators can rerun without context.\n4. Reports updated with pilot metrics, open issues, and risk mitigations.\n\n## Constraints & Resources\n- **Hardware:** dual RTX 6000 Ada (48 GB) → default to `cuda:1` for single-GPU jobs; DDP uses both.\n- **Framework:** PyTorch 2.9 + torchvision 0.24 via `uv`.\n- **Data:** RefinedWeb/FineWeb proxy mix (`data/shards/*_full`) already filtered; tokenizer at `artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model`.\n- **Tracking:** W&B project `hope-stage2`; JSON logs under `logs/`.\n- **Blocking risks:** insufficient storage for 100 B tokens, instability when teach_scale > 0.1, and long eval runtimes (>30 min) that must run inside tmux.\n\n## Workstreams & Detailed TODOs\n\n### P0 – Faithfulness & Critique Remediation (must finish before P1+)\n| ID | Task | Subtasks | Acceptance |\n|----|------|----------|------------|\n| F1 | **Weight tying + teach-signal fix** | (a) Tie `lm_head.weight` to `embed.weight` in HOPE + TITAN models. (b) Rewrite `compute_teach_signal` to time-shift targets and use the head weight. (c) Emit teach-signal norms in logs. (d) Add `tests/test_teach_signal.py` verifying finite-difference gradient vs analytic teacher. | CI test passes; pilot smoke log contains `teach_signal_norm` with finite values. |\n| F2 | **CMS chunk accumulation** | (a) Introduce per-level buffers sized to `update_period`. (b) Trigger optimizer updates only when buffer fills, then clear. (c) Add telemetry (count, L2 magnitude). (d) Unit test that `update_period=3` produces exactly one update every three ticks. | Unit test green; logs show stepped CMS updates aligned with periods. |\n| F3 | **L2 regression inner rule (Eq. 27–29)** | (a) Add `variant=\"nl_l2_precond\"` to deep optimizer. (b) Plumb activations into optimizer context. (c) Config flag + defaults in `configs/hope/*.yaml`. (d) Toy regression test demonstrating objective decrease. | Test passes; pilot smoke with `variant` enabled logs preconditioner stats. |\n| F4 | **Test-time memorization path** | (a) Add CLI flags to eval scripts. (b) Implement Titan memory updates during eval when `memorize=True`. (c) Optional CMS fast-level updates. (d) Synthetic integration test verifying improved accuracy on memorization-enabled run. | `scripts/tests/test_memorization_eval.py` green; eval JSONs contain `_memorize` variants. |\n| F5 | **PyTorch perf upgrades (SDPA, autocast, compile, fused AdamW)** | (a) Replace attention with manual QKV + SDPA (Flash support). (b) Wrap train step in `torch.autocast(..., dtype=torch.bfloat16)` with config switch. (c) Add guarded `torch.compile`. (d) Enable fused AdamW. | Pilot smoke runtime comparison recorded; fallback path works on CPU (smoke test). |\n| F6 | **Muon outer optimizer option** | (a) Detect `torch.optim.Muon`. (b) Split param groups (matrices→Muon, others→AdamW). (c) Config knob `optim.outer.type`. (d) Document trade-offs + log metrics. 
| Pilot smoke completes with `optim.outer.type=muon_mix`; README/env docs updated. |\n| F7 | **Seeding + backend robustness** | (a) Hydra-level `seed` field; set Python/NumPy/Torch seeds. (b) DataLoader worker init + manual seed for rng. (c) Auto-select DDP backend; allow override. (d) Add CPU DDP CI job. | Two identical runs (same seed) produce identical log traces; CI includes CPU DDP smoke. |\n| F8 | **License & packaging polish** | (a) Align `pyproject.toml` license with Apache-2.0. (b) Ensure referenced scripts ship (`scripts/run_e2e_smoke.sh`). (c) README update for memorization, Muon, env. | Lint job verifies license metadata; README instructions reviewed/approved. |\n\n### 1. Data & Environment Readiness\n| ID | Task | Details | Owner | Status | Deliverable |\n|----|------|---------|-------|--------|-------------|\n| D1 | Corpus inventory | Verify presence + integrity of `data/shards/*_full` (RefinedWeb, Wikipedia, C4, RedPajama, Code) and update stats in `data/mixtures/refinedweb_mix_full_shards.json`. | KM | Not Started | Updated JSON + short log |\n| D2 | Provenance doc | Extend `docs/data_pipeline.md` with acquisition commands, licensing notes, and shard counts per corpus. | KM | Not Started | PR-ready doc |\n| D3 | Tokenizer lock | Record checksum + training command for `artifacts/tokenizer/...32k.model` in `docs/data_pipeline.md`; script to assert checksum before runs. | KM | Not Started | Script `scripts/data/check_tokenizer.py` + doc snippet |\n| D4 | Env matrix | Confirm `uv.lock` captures torch 2.9 stack; add `docs/env_matrix.md` describing GPU driver, CUDA runtime, `uv` commands, and fallback instructions. | KM | Not Started | Env doc + verified `uv pip list` diff |\n\n### 2. Pilot Training Execution\n| ID | Task | Details | Owner | Status | Deliverable |\n|----|------|---------|-------|--------|-------------|\n| P1 | Upgrade `configs/hope/pilot.yaml` | Move from synthetic toy config to 160 M spec (layers=12, dim=768, seq=2048, batch ladder). Include teach schedule, optimizer, data mixture hook. | KM | Not Started | Updated YAML + changelog |\n| P2 | Dry-run smoke | `uv run python train.py --config-name hope/pilot --train.steps=50 --train.device=cuda:1` using sample shards to confirm stability/logging. | KM | Not Started | Log `logs/pilot_smoke.json` |\n| P3 | Full pilot launch | tmux-managed torchrun/DeepSpeed to 3 B tokens (≈150k steps). Enable checkpoints every 1k steps in `artifacts/checkpoints/pilot`. | KM | Not Started | Checkpoints + `logs/pilot_full.json` + W&B link |\n| P4 | Monitoring hooks | Ensure gradient/teach-scale stats captured (level update magnitudes, CMS norms). Implement metrics in `src/training/callbacks.py`. | KM | Not Started | Callback code + metrics in log |\n| P5 | Artifact packaging | Copy best checkpoint, config snapshot, log, and eval metadata to `artifacts/pilot_release/`. | KM | Not Started | Structured folder for sharing |\n\n### 3. Evaluation & Baselines\n| ID | Task | Details | Owner | Status | Deliverable |\n|----|------|---------|-------|--------|-------------|\n| E1 | Zero-shot sweep | Run `scripts/eval/zeroshot.py` on pilot + TITAN baseline checkpoints (PIQA, HellaSwag, WinoGrande, ARC-E/C, BoolQ, SIQA, OpenBookQA). Store JSON under `eval/zeroshot_pilot_v1.json`. | KM | Not Started | Eval JSON + summary table |\n| E2 | NIAH curve | Expand `scripts/eval/niah.py` to 2k→64k contexts, add CLI for seeds/batch. Plot accuracy vs. length and save PNG + CSV in `reports/plots/niah_pilot.*`. 
| KM | Not Started | CSV + plot |\n| E3 | Continual-learning bench | Finalize `scripts/eval/continual.py` to iterate through `configs/data/continual_segments_full.yaml`, log forgetting metrics, and compare to TITAN baseline. | KM | In Progress (per scaffolding) | JSON `eval/continual_pilot.json` + diff vs. baseline |\n| E4 | Baseline rerun | Reproduce TITAN-only run matching pilot data/time (use `configs/mid_titan_baseline.yaml` adjusted for 160 M). Document differences and store checkpoints. | KM | Not Started | Baseline checkpoint + eval JSON |\n\n### 4. Documentation & Reporting\n| ID | Task | Details | Owner | Status | Deliverable |\n|----|------|---------|-------|--------|-------------|\n| R1 | Update `docs/experiments_report.md` | Add pilot training summary, metrics tables, and comparison vs. TITAN. Include open issues + next actions. | KM | Not Started | Updated report |\n| R2 | `docs/stage2_progress.md` refresh | Append sprint log entries (date, what ran, pointers to artifacts). | KM | Not Started | Progress section |\n| R3 | `reports/ablations.md` | Stub section for pilot-scale ablations (self-modifier on/off, CMS depth). Outline command templates even if runs pending. | KM | Not Started | Markdown updates |\n| R4 | Release checklist | Update `docs/release_checklist.md` (or add if missing) with pilot deliverables, git tags, and artifact verification steps. | KM | Not Started | Checklist doc |\n\n### 5. Outreach & Coordination\n| ID | Task | Details | Owner | Status | Deliverable |\n|----|------|---------|-------|--------|-------------|\n| O1 | Issues roadmap | Open GitHub issues for P1–P5, E1–E4, R1–R4 to invite contributions; include artifact links. | KM | Not Started | Issue list |\n| O2 | README updates | Highlight pilot deliverables, add “How to reproduce pilot run” section with commands + data requirements. | KM | Not Started | README diff |\n| O3 | Community sync | Draft short update in `docs/POSTS.md` for Discord/Twitter after pilot results land (link to dashboards). | KM | Not Started | Draft |\n\n## Execution Order & Dependencies\n1. **D1–D4** unblock everything; complete before P1.\n2. **P1 → P2 → P3** sequential; P4 instrumentation can merge during P2/P3 once metrics tested.\n3. **P3 completion** gates E1–E3; E4 can run in parallel using existing baseline config.\n4. **R1–R4** depend on eval outputs; draft skeletons early to keep pace.\n5. Outreach tasks (O1–O3) happen once initial pilot artifacts exist.\n\n## Tracking & Reporting\n- Update `TODO.md` and this sprint doc daily with status (☐/△/✓).\n- Maintain tmux session names (`pilot_full`, `pilot_eval`, `pilot_baseline`) and log paths.\n- Push commits frequently; tag `v0.2.0-pilot` when criteria met.\n- Capture blockers in `docs/stage2_progress.md` so future shifts have context.\n\n## Risks & Mitigations\n| Risk | Mitigation |\n|------|------------|\n| Pilot instability when teach_scale > 0.05 | Implement gradient clipping per level, adaptive schedules, and fallback to 0.05 if divergence occurs. |\n| Data storage pressure during 3 B-token run | Stream shards using lazy loader; clean intermediate caches under `tmp/` after runs. |\n| Eval runtime (NIAH up to 64k) | Batch contexts, reuse cached passkeys, and run in tmux `pilot_eval`. |\n| Artifact drift | Snapshot `uv.lock`, configs, and tokenizer hash into `artifacts/pilot_release/metadata.json`. 
|\n\n---\n\nThis sprint plan is self-contained; executing the tasks above will deliver a fully documented pilot-scale reproduction plus the infrastructure needed for mid-scale Stage 2 runs.\n"
  },
  {
    "path": "docs/stage2_plan.md",
    "content": "# Stage 2 Plan – Nested Learning (HOPE) Results Reproduction\n\nThis document details Stage 2 goals: reproduce the key experimental results from Google’s Nested Learning (HOPE) paper/blog using the Stage 1 codebase. It is self-contained and assumes Stage 1 deliverables (architecture, training harness, tests, `uv` environment) are ready.\n\n**Status update (Jan 2026):** pilot relaunch and baseline checkpoints have been packaged and evaluated (HOPE `artifacts/checkpoints/pilot_relaunch/step_477000.pt`, TITAN `artifacts/checkpoints/mid_titan_long/step_032000.pt`). See `reports/checkpoints/pilot_relaunch_step477000.md`, `reports/checkpoints/titan_long_step32000.md`, and the latest eval JSONs under `eval/`.\n\n---\n\n## 1. Objectives\n- Train HOPE models at multiple scales (pilot 160 M, target 760 M and 1.3 B parameters) using public corpora approximating the paper’s data mix.\n- Reproduce headline metrics: perplexity on pretraining validation, zero-shot scores on reasoning benchmarks, long-context recall, and continual-learning forgetting curves.\n- Provide complete experiment artefacts: configs, logs, checkpoints, evaluation scripts, and analysis notebooks.\n\n---\n\n## 2. Scope\n\n### 2.1 In scope\n- Data pipeline build-out (tokenization, sharding, streaming) for ≥100 B tokens.\n- Distributed training scripts (FSDP/DeepSpeed) with logging + checkpointing.\n- Evaluation harnesses for LM, QA/reasoning, long-context (NIAH variants), and continual-learning tasks.\n- Ablations mirroring the paper (self-modifier toggles, CMS depth, optimizer variants, attention replacements).\n\n### 2.2 Out of scope\n- Non-HOPE architectures beyond the paper’s baselines (Transformer, TITAN, SAMBA, DeltaNet) except insofar as comparisons require reimplementation.\n- Deployment/serving.\n- Hyper-parameter sweeps beyond reproducing reported configs.\n\n---\n\n## 3. Data & Tokenization\n\n| Component | Choice | Notes |\n|-----------|--------|-------|\n| Tokenizer | SentencePiece unigram 32k (shared with Stage 1) | Train on combined corpus below; manifest in `configs/data/refinedweb_mixture.yaml` |\n| Base corpus | RefinedWeb / FineWeb proxy (≈600 B tokens) | Deduplicated, filtered for quality |\n| Supplements | Books3 (if license permits), Stack/Code subset, Wikipedia, C4, RedPajama CC | Provide balanced mixture to mimic broad-domain data |\n| Reasoning eval data | PIQA, HellaSwag, WinoGrande, ARC-E/C, SIQA, BoolQ | Use HF datasets; no training on eval splits |\n| Long-context eval | Needles-in-a-Haystack (Passkey/Number/Word), PG19, NarrativeQA | Scripts to synthesize passkey tasks to 512k tokens |\n| Continual tasks | Streaming Wikipedia by year + domain shift (news → code → conversations); synthetic permuted classes for stress-test | Track forgetting via accuracy drop on earlier segments |\n\nData pipeline tasks:\n1. Acquire corpora (cc-by or permissible) into `data/raw/`.\n2. Normalize & filter (language detection, length bounds, dedup).\n3. Train tokenizer, store at `artifacts/tokenizer/spm_unigram_32k.model`.\n4. Shard dataset into binary `.bin` or HF streaming format with 2048-token sequences + metadata.\n\n---\n\n## 4. 
Training Strategy\n\n### 4.1 Model scales\n| Name | Params | Layers | Dim | Heads | Sequence | Tokens |\n|------|--------|--------|-----|-------|----------|--------|\n| Pilot | 160 M | 12 | 512 | 8 | 2k | 3 B |\n| Mid | 760 M | 24 | 1024 | 16 | 4k | 30 B |\n| Target | 1.3 B | 32 | 1536 | 24 | 8k | 100 B |\n\nLevel schedules (example; see the cadence sketch at the end of this document):\n- TITAN level: update every 8/16/32 steps for Pilot/Mid/Target.\n- CMS levels: {fast = 1, mid = 4, slow = 32, ultra = 128} update periods, gated by warmups.\n\n### 4.2 Optimizers\n- Outer weights: AdamW (β1=0.9, β2=0.95 for high-LR stability), cosine decay, warmup 2k steps.\n- Inner memories: DeepMomentum variants (preconditioned momentum for TITAN, DMGD for CMS).\n- Gradient clipping: 1.0 outer, 0.3 inner deltas; update dropout 0.2 on self-mod outputs.\n\n### 4.3 Distributed setup\n- Framework: PyTorch FSDP (full-shard) or DeepSpeed ZeRO-3.\n- Precision: BF16 activations/weights, FP32 master weights; optional FlashAttention for long contexts.\n- Checkpointing every 1k steps with partitioned state (model + optimizer + level clocks).\n- Logging via WandB/MLflow with structured metrics (loss, ppl, level update magnitudes, memory norms).\n\n### 4.4 Curriculum\n1. Pilot run on 3 B tokens to validate the pipeline and run ablations quickly (<=12 GPU-days).\n2. Scale to 30 B tokens (760 M) once metrics are stable; capture the full eval suite.\n3. Final 1.3 B / 100 B run with refined hyper-params and longer contexts (8k); integrate long-context tasks into training via mixture-of-lengths.\n\n---\n\n## 5. Evaluation Plan\n\n### 5.1 Language Modeling\n- Validation perplexity on held-out RefinedWeb shards plus WikiText-103.\n- Log per-domain ppl to monitor forgetting when streaming.\n\n### 5.2 Zero-shot Benchmarks\n- Implement script `scripts/eval/zeroshot.py` pulling HF datasets (initial PIQA support, extend to others).\n- Metrics: accuracy for PIQA/HellaSwag/WinoGrande/ARC-E/C/SIQA/BoolQ; match the table from the paper.\n\n### 5.3 Long-context (NIAH)\n- Generate custom sequences via `scripts/eval/niah.py` (currently scaffolds synthetic pass-key prompts); extend to context lengths up to 512k tokens.\n- Evaluate recall accuracy vs. context length; compare HOPE vs. Transformer/TITAN baseline checkpoints.\n\n### 5.4 Continual Learning\n- Streaming tasks: sequential corpora (e.g., Wiki by year). After each segment, evaluate on all previous segments to compute average forgetting (Δ accuracy/perplexity), as sketched below. Use `scripts/eval/continual.py` with a segments YAML describing shard directories and checkpoint ordering.\n- Use HOPE’s level stats to correlate update frequency with forgetting reduction.\n
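\nThe forgetting computation itself is mechanical; a minimal sketch, assuming one accuracy value per (training-segment, eval-segment) pair (higher is better):\n\n```python\n# metric[t][s] = accuracy on segment s after finishing training segment t.\n# Forgetting for segment s at time t is its best past score minus its\n# current score; we average over all previously seen segments.\ndef average_forgetting(metric: list[list[float]]) -> float:\n    drops: list[float] = []\n    for t in range(1, len(metric)):\n        for s in range(t):  # only segments seen before time t\n            best_past = max(metric[u][s] for u in range(s, t))\n            drops.append(best_past - metric[t][s])\n    return sum(drops) / len(drops) if drops else 0.0\n```\n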
\n### 5.5 Ablations\n1. Self-modifier disabled.\n2. CMS depth variations (k=1 vs. 3 vs. 4 levels).\n3. Deep optimizer variants per level.\n4. Attention backbone swap (full vs. sliding-window vs. DeltaNet).\n\nEach ablation run uses Pilot scale unless specified; record metrics in `reports/ablations.md`.\n\n---\n\n## 6. Deliverables & Acceptance Criteria\n| Deliverable | Criteria |\n|-------------|----------|\n| Data pipeline | Scripts in `scripts/data/`, tokenizer artefacts, documentation; reproducible shards |\n| Training configs | Hydra YAMLs under `configs/hope/{pilot,mid,target}.yaml`; include optimizer, level schedules |\n| Distributed training scripts | `train_dist.py`, launchers for FSDP/DeepSpeed with resume support |\n| Evaluation suite | CLI tools for LM, zero-shot, NIAH, continual forgetting; CI test on small checkpoints |\n| Reports | Markdown/Notebook summaries of metrics vs. baselines; highlight deviations |\n\n---\n\n## 7. Work Breakdown (Stage 2)\n1. **Data Engineering** – ingest/filter/pack corpora; train tokenizer; unit tests for sharding.\n2. **Infra & Configs** – Hydra config tree, logging integration, distributed launcher templates.\n3. **Scaling Training** – pilot → mid → target runs; monitor; adjust hyper-params.\n4. **Evaluation** – implement LM + zero-shot harness, NIAH generator, continual-learning scripts.\n5. **Ablations & Analysis** – run targeted toggles, plot results, compare to paper.\n6. **Documentation & Release** – write experiment logs, dataset README, reproduction checklists. Keep `docs/release_checklist.md` updated and treat it as the gate for tagging/publishing checkpoints.\n\nEach workstream is tracked via `TODO.md` or the issue tracker; dependency-wise, (1) must land before (3) and (4), and so on.\n\n---\n\n## 8. Timeline (indicative)\n| Week | Milestone |\n|------|-----------|\n| 1 | Data pipeline + tokenizer complete; pilot configs ready |\n| 2 | Pilot training + ablations; evaluation harness validated |\n| 3–4 | Mid-scale (760 M) training + zero-shot/NIAH evals |\n| 5–6 | Target (1.3 B) training, long-context + continual learning results |\n| 7 | Ablations finalized, comparison vs. baselines, publish report |\n\n---\n\n## 9. Risks & Mitigations\n| Risk | Mitigation |\n|------|------------|\n| Dataset licensing/availability | Stick to permissive corpora; document provenance |\n| Compute instability at 100 B tokens | Use gradient checkpointing, monitor memory, schedule restarts |\n| Eval drift vs. paper | Match prompt templates from the Eleuther harness; verify tokenization alignment |\n| Long-context efficiency | Integrate FlashAttention2 or block-sparse attention for >32k tokens |\n| Continual learning metrics noisy | Average over multiple seeds; use bootstrapped confidence intervals |\n\n---\n\n## 10. Exit Criteria\n- Matching (or within tolerance of) reported perplexity and zero-shot accuracy at 760 M and 1.3 B.\n- Demonstrated long-context recall advantage over the Transformer baseline at ≥256k tokens.\n- Documented continual-learning improvements (reduced forgetting) with plots.\n- All scripts/configs reproducible via `uv run` workflows; README updated with instructions.\n
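\n---\n\n_Reference sketch (for the §4.1 level schedules):_ the per-level cadence can be expressed as a tiny clock. Illustrative only — the real `LevelClock`/`LevelSpec` live in `nested_learning.levels` (see `docs/spec_interfaces.md`) and may differ; jitter and optimizer binding are omitted here:\n\n```python\nfrom dataclasses import dataclass, field\n\n\n@dataclass\nclass LevelSpec:\n    name: str\n    update_period: int  # e.g., fast=1, mid=4, slow=32, ultra=128\n    warmup: int = 0     # no updates before this global step\n\n\n@dataclass\nclass LevelClock:\n    specs: dict[str, LevelSpec]\n    step: int = 0\n    timeline: list[tuple[int, str]] = field(default_factory=list)\n\n    def should_update(self, name: str) -> bool:\n        spec = self.specs[name]\n        return self.step >= spec.warmup and self.step % spec.update_period == 0\n\n    def record_update(self, name: str) -> None:\n        self.timeline.append((self.step, name))  # timeline kept for logging\n```\n"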
  },
  {
    "path": "docs/stage2_progress.md",
    "content": "# Stage 2 Progress\n\nLast updated: `2026-02-24`\n\n## Sprint Status\n\n- **A-series (algorithm-mode + boundary-state fidelity):** Done\n- **B-series (docs/usability/data-script robustness):** Done\n- **C-series (cadence/compliance/mechanism tests):** Done\n- **D-series (security/release hygiene gates):** Done\n- **E-series (paper-compliance reconciliation + reproducibility):** Done\n- **F-series (final validation + reporting):** In progress (documentation/report closure)\n- **P0-series (packaging/CLI/runtime portability foundation):** Done (`nl doctor/smoke/train/audit`, `python -m nested_learning`)\n- **P1-series (distribution/CI/release scaffolding):** Done (compat matrix, pip-first README, cross-platform smoke CI, release workflow)\n\n## Done Criteria\n\n1. Boundary-state path is runnable, guarded, and explicitly marked experimental.\n2. Paper-faithful configs are explicit and test-covered.\n3. Checkpoint metadata and release manifest include algorithm + online flags.\n4. Data scripts have deterministic split fallback and `--help` smoke checks in CI.\n5. Docs reference checks validate both file paths and markdown anchors.\n6. Security gates block accidental binary/artifact tracking.\n7. Compliance reports are generated from current configs and included in `reports/`.\n\n## Validation Snapshot\n\n- `uv run ruff check .` -> pass\n- `uv run mypy src` -> pass\n- `bash scripts/checks/run_fidelity_ci_subset.sh` -> pass\n- `uv run pytest -q` -> pass\n\n## Generated Compliance Artifacts\n\n- `reports/compliance_summary_pilot_paper_faithful.json`\n- `reports/compliance_summary_pilot.json`\n- `reports/cadence_mechanism_audit_smoke.json`\n- `reports/compliance_mechanism_audit_smoke.json`\n- `reports/security_release_gate.md`\n"
  },
  {
    "path": "docs/templates/checkpoint_report.md",
    "content": "# Checkpoint Report Template\n\nCopy this template into `reports/checkpoints/<run>.md` (or similar) for every published checkpoint.\n\n## 1. Run Summary\n- **Model / Config:** (e.g., HOPE pilot, `configs/pilot.yaml`)\n- **Checkpoint path:** `artifacts/checkpoints/...`\n- **Hydra overrides:** `...`\n- **Tokens seen / steps:** (e.g., 3 B tokens / 230 k steps)\n- **Outer optimizer:** (Muon/AdamW + settings)\n- **Inner optimizer variant:** (`nl_l2_precond`, etc.)\n- **Teach schedule:** warmup/decay parameters\n\n## 2. Environment\n- Git commit SHA\n- PyTorch / CUDA / cuDNN versions\n- `uv.lock` hash\n- Tokenizer path + SHA256\n\n## 3. Training Metrics\n- Plot or table for loss/ppl vs step (include teach-signal norm)\n- Gradient norms (global + per-level if available)\n- Notable events (OOM retries, restarts)\n\n## 4. Memory-System Telemetry\n- Average `layer*.titan.*.grad_norm` and projector norms\n- CMS chunk stats (`chunk_samples`, updates per 1k tokens)\n- Surprise/memorization triggers (counts, thresholds)\n\n## 5. Evaluation\n- Zero-shot table (baseline vs memorize accuracy)\n- NIAH accuracies by context length + memorize deltas\n- Continual CE per segment (baseline vs memorize)\n- Additional diagnostics (LongBench, PG-19, etc.)\n\n## 6. Reproduction Commands\n```\n# train\nuv run python train.py --config-name ...\n# eval\nuv run python scripts/eval/zeroshot.py ...\n```\n\n## 7. Risks / Notes\n- Known deviations from the paper\n- TODOs before scaling this checkpoint (e.g., data quirks, missing ablations)\n\n---\n\nUse this structure to keep every release auditable and to make comparisons across HOPE/TITAN checkpoints straightforward.\n"
  },
  {
    "path": "docs/zeroshot_eval.md",
    "content": "# Zero-shot Evaluation Guide\n\nThe script `scripts/eval/zeroshot.py` evaluates HOPE checkpoints on \ncommon reasoning benchmarks. Tasks currently supported:\n\n- `piqa`\n- `hellaswag`\n- `winogrande`\n- `arc_easy`\n- `arc_challenge`\n- `boolq`\n- `siqa`\n- `commonsenseqa`\n- `openbookqa`\n- Synthetic LongBench-style passkey (`scripts/eval/passkey.py`)\n- PG-19 perplexity (`scripts/eval/pg19_perplexity.py`)\n\n## Usage\n\n```bash\nuv run python scripts/eval/zeroshot.py \\\n  --config configs/hope/mid.yaml \\\n  --checkpoint checkpoints/mid/checkpoint_best.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --tasks all \\\n  --max-samples 500 \\\n  --output eval/zeroshot_mid.json\n```\n\nSet `--tasks` to a comma-separated list (e.g., `piqa,hellaswag`) or `all`.\nUse `--list-tasks` to print available options.\n\nEach task logs accuracy and sample count into the JSON file. Adjust\n`--max-samples` (0 = evaluate entire validation set) based on runtime.\n\n`scripts/eval/zeroshot.py` now exposes `--eval-state-mode`, but only\n`reset_per_sample` is supported for multi-choice scoring in this implementation.\n\nFor reproducibility, record the checkpoint SHA, tokenizer version, \nand command invocation alongside the JSON results.\n\n## Long-context diagnostics\n\n### Needle-in-a-Haystack (NIAH)\n\nProbe retrieval across increasing prompt lengths:\n\n```bash\nuv run python scripts/eval/niah.py \\\n  --config configs/hope/pilot.yaml \\\n  --checkpoint artifacts/checkpoints/pilot/step_230000.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --context-lengths 2048 --context-lengths 4096 --context-lengths 8192 \\\n  --samples-per-length 50 \\\n  --memorize --memorize-steps 2 \\\n  --memorize-paths titan,cms_fast \\\n  --memorize-surprise-threshold 0.02 \\\n  --output eval/niah_pilot.json\n```\n\n### RULER-ish NIAH Suite\n\nRun multiple synthetic retrieval variants (single-needle, multi-needle, KV retrieval, and positioned needles):\n\n```bash\nuv run python scripts/eval/niah_suite.py \\\n  --config configs/hope/pilot.yaml \\\n  --checkpoint artifacts/checkpoints/pilot/step_230000.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --context-tokens 2048 --context-tokens 4096 \\\n  --samples-per-length 25 \\\n  --memorize --memorize-steps 2 \\\n  --memorize-paths titan,cms_fast \\\n  --memorize-surprise-threshold 0.02 \\\n  --output eval/niah_suite_pilot.json\n```\n\nPlot the suite as accuracy vs context length:\n\n```bash\nuv run python scripts/eval/plot_niah_suite.py \\\n  --niah-suite-json eval/niah_suite_pilot.json \\\n  --output reports/plots/niah_suite_pilot.png\n```\n\n### Passkey Retrieval\nGenerate synthetic passkey prompts to stress memorization at test time:\n\n```bash\nuv run python scripts/eval/passkey.py \\\n  --config configs/hope/pilot.yaml \\\n  --checkpoint artifacts/checkpoints/pilot/step_230000.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --samples 64 --filler-sentences 256 \\\n  --memorize --memorize-steps 2 \\\n  --memorize-paths titan,cms_fast \\\n  --memorize-surprise-threshold 0.02 \\\n  --output eval/passkey_pilot.json\n```\n\nThe JSON reports baseline vs. memorize accuracy, Titan/CMS update stats, the active memory paths, and the surprise threshold. 
Use `--memorize-paths` to restrict updates to `titan`, `cms_fast`, or any comma-separated combination, and `--memorize-surprise-threshold` to match the paper’s surprise-gated updates.\n\nFor passkey/NIAH scripts, `--eval-state-mode=reset_per_sample` is currently required to avoid branch contamination between answer candidates.\n\nNote: memorization runs against a per-example fast state by default (so checkpoint weights are not mutated). If you call `memorize_sequence()` programmatically with `MemorizeConfig.use_fast_state=True`, pass a `fast_state = model.init_fast_state()` object (see the sketch at the end of this guide).\n\n### PG-19 Perplexity\n\n```bash\nuv run python scripts/eval/pg19_perplexity.py \\\n  --config configs/hope/pilot.yaml \\\n  --checkpoint artifacts/checkpoints/pilot/step_230000.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --max-samples 64 \\\n  --output eval/pg19_pilot.json\n```\n\nThis computes long-form perplexity (baseline vs. memorize) and records total tokens processed.\n
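\n### Programmatic Memorization (Sketch)\n\nA minimal sketch of the programmatic path described in the passkey notes above. The import path, the remaining `MemorizeConfig` fields, and the `memorize_sequence()` keyword names are assumptions that mirror the CLI flags; check the source before relying on them:\n\n```python\n# Sketch only: the import path and keyword names are assumptions; the values\n# mirror the CLI flags used throughout this guide.\nfrom nested_learning.memorize import MemorizeConfig, memorize_sequence  # assumed path\n\n# `model` is a loaded HOPE checkpoint and `token_ids` is a 1-D LongTensor of\n# the sequence to memorize (tokenized with the same SentencePiece model).\nfast_state = model.init_fast_state()  # per-example fast state; checkpoint weights stay frozen\n\ncfg = MemorizeConfig(\n    use_fast_state=True,      # pass the explicit fast_state object below\n    steps=2,                  # mirrors --memorize-steps\n    paths='titan,cms_fast',   # mirrors --memorize-paths\n    surprise_threshold=0.02,  # mirrors --memorize-surprise-threshold\n)\n\nstats = memorize_sequence(model, token_ids, config=cfg, fast_state=fast_state)\nprint(stats)  # per-path update counts, e.g. titan_mem_updates / cms_fast_updates\n```\n"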
  },
  {
    "path": "eval/continual_dummy.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/examples/pilot_dummy.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 10.550894003791585,\n      \"wikipedia_sample\": 10.583863976883562,\n      \"c4_sample\": 10.598039421783268,\n      \"redpajama_sample\": 10.557760518590998\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_mid_stage2.json",
    "content": "{\n  \"checkpoints\": [\n    \"artifacts/checkpoints/mid_stage2/step_000100.pt\"\n  ],\n  \"segments\": [\n    {\n      \"name\": \"refinedweb_segment\",\n      \"shards_dir\": \"data/shards/refinedweb_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"wikipedia_segment\",\n      \"shards_dir\": \"data/shards/wikipedia_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"redpajama_segment\",\n      \"shards_dir\": \"data/shards/redpajama_filtered\",\n      \"max_batches\": 10\n    }\n  ],\n  \"metrics\": {\n    \"step_000100\": {\n      \"refinedweb_segment\": {\n        \"loss\": 7.937718627690803,\n        \"ppl\": 2800.9631739004617\n      },\n      \"wikipedia_segment\": {\n        \"loss\": 8.102176148177593,\n        \"ppl\": 3301.645132319484\n      },\n      \"redpajama_segment\": {\n        \"loss\": 7.927422247660837,\n        \"ppl\": 2772.271357018084\n      }\n    }\n  },\n  \"forgetting\": {\n    \"refinedweb_segment\": 0.0,\n    \"wikipedia_segment\": 0.0,\n    \"redpajama_segment\": 0.0\n  }\n}"
  },
  {
    "path": "eval/continual_mid_stage2_smoke.json",
    "content": "{\n  \"checkpoints\": [\n    \"artifacts/checkpoints/mid_stage2_smoke/step_000060.pt\"\n  ],\n  \"segments\": [\n    {\n      \"name\": \"refinedweb_segment\",\n      \"shards_dir\": \"data/shards/refinedweb_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"wikipedia_segment\",\n      \"shards_dir\": \"data/shards/wikipedia_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"redpajama_segment\",\n      \"shards_dir\": \"data/shards/redpajama_filtered\",\n      \"max_batches\": 10\n    }\n  ],\n  \"metrics\": {\n    \"step_000060\": {\n      \"refinedweb_segment\": {\n        \"loss\": 8.373536780669031,\n        \"ppl\": 4330.926544499951\n      },\n      \"wikipedia_segment\": {\n        \"loss\": 8.520759884112035,\n        \"ppl\": 5017.86530456549\n      },\n      \"redpajama_segment\": {\n        \"loss\": 8.357480258531066,\n        \"ppl\": 4261.942232700102\n      }\n    }\n  },\n  \"forgetting\": {\n    \"refinedweb_segment\": 0.0,\n    \"wikipedia_segment\": 0.0,\n    \"redpajama_segment\": 0.0\n  }\n}"
  },
  {
    "path": "eval/continual_mid_stage2_ts10.json",
    "content": "{\n  \"checkpoints\": [\n    \"artifacts/checkpoints/mid_stage2_ts10/step_000080.pt\"\n  ],\n  \"segments\": [\n    {\n      \"name\": \"refinedweb_segment\",\n      \"shards_dir\": \"data/shards/refinedweb_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"wikipedia_segment\",\n      \"shards_dir\": \"data/shards/wikipedia_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"redpajama_segment\",\n      \"shards_dir\": \"data/shards/redpajama_filtered\",\n      \"max_batches\": 10\n    }\n  ],\n  \"metrics\": {\n    \"step_000080\": {\n      \"refinedweb_segment\": {\n        \"loss\": NaN,\n        \"ppl\": NaN\n      },\n      \"wikipedia_segment\": {\n        \"loss\": NaN,\n        \"ppl\": NaN\n      },\n      \"redpajama_segment\": {\n        \"loss\": NaN,\n        \"ppl\": NaN\n      }\n    }\n  },\n  \"forgetting\": {\n    \"refinedweb_segment\": NaN,\n    \"wikipedia_segment\": NaN,\n    \"redpajama_segment\": NaN\n  }\n}"
  },
  {
    "path": "eval/continual_mid_stage2_ts10_single120_clip.json",
    "content": "{\n  \"checkpoints\": [\n    \"artifacts/checkpoints/mid_stage2_ts10_single120_clip/step_000120.pt\"\n  ],\n  \"segments\": [\n    {\n      \"name\": \"refinedweb_segment\",\n      \"shards_dir\": \"data/shards/refinedweb_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"wikipedia_segment\",\n      \"shards_dir\": \"data/shards/wikipedia_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"redpajama_segment\",\n      \"shards_dir\": \"data/shards/redpajama_filtered\",\n      \"max_batches\": 10\n    }\n  ],\n  \"metrics\": {\n    \"step_000120\": {\n      \"refinedweb_segment\": {\n        \"loss\": 10.065063906555773,\n        \"ppl\": 23507.243967214054\n      },\n      \"wikipedia_segment\": {\n        \"loss\": 10.100256753761007,\n        \"ppl\": 24349.26038607696\n      },\n      \"redpajama_segment\": {\n        \"loss\": 10.057008508133562,\n        \"ppl\": 23318.644393446848\n      }\n    }\n  },\n  \"forgetting\": {\n    \"refinedweb_segment\": 0.0,\n    \"wikipedia_segment\": 0.0,\n    \"redpajama_segment\": 0.0\n  }\n}"
  },
  {
    "path": "eval/continual_mid_stage2_ts10_single140_schedC.json",
    "content": "{\n  \"checkpoints\": [\n    \"artifacts/checkpoints/mid_stage2_ts10_single140_schedC/step_000140.pt\"\n  ],\n  \"segments\": [\n    {\n      \"name\": \"refinedweb_segment\",\n      \"shards_dir\": \"data/shards/refinedweb_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"wikipedia_segment\",\n      \"shards_dir\": \"data/shards/wikipedia_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"redpajama_segment\",\n      \"shards_dir\": \"data/shards/redpajama_filtered\",\n      \"max_batches\": 10\n    }\n  ],\n  \"metrics\": {\n    \"step_000140\": {\n      \"refinedweb_segment\": {\n        \"loss\": 10.11361301369863,\n        \"ppl\": 24676.656967001243\n      },\n      \"wikipedia_segment\": {\n        \"loss\": 10.087023663007583,\n        \"ppl\": 24029.16699711637\n      },\n      \"redpajama_segment\": {\n        \"loss\": 10.103308081580234,\n        \"ppl\": 24423.671430180777\n      }\n    }\n  },\n  \"forgetting\": {\n    \"refinedweb_segment\": 0.0,\n    \"wikipedia_segment\": 0.0,\n    \"redpajama_segment\": 0.0\n  }\n}"
  },
  {
    "path": "eval/continual_mid_stage2_ts10_single220_schedD.json",
    "content": "{\n  \"checkpoints\": [\n    \"artifacts/checkpoints/mid_stage2_ts10_single220_schedD/step_000220.pt\"\n  ],\n  \"segments\": [\n    {\n      \"name\": \"refinedweb_segment\",\n      \"shards_dir\": \"data/shards/refinedweb_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"wikipedia_segment\",\n      \"shards_dir\": \"data/shards/wikipedia_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"redpajama_segment\",\n      \"shards_dir\": \"data/shards/redpajama_filtered\",\n      \"max_batches\": 10\n    }\n  ],\n  \"metrics\": {\n    \"step_000220\": {\n      \"refinedweb_segment\": {\n        \"loss\": 10.070207008317025,\n        \"ppl\": 23628.45554963\n      },\n      \"wikipedia_segment\": {\n        \"loss\": 10.095827823049168,\n        \"ppl\": 24241.657657346874\n      },\n      \"redpajama_segment\": {\n        \"loss\": 10.052455834913161,\n        \"ppl\": 23212.72352009285\n      }\n    }\n  },\n  \"forgetting\": {\n    \"refinedweb_segment\": 0.0,\n    \"wikipedia_segment\": 0.0,\n    \"redpajama_segment\": 0.0\n  }\n}"
  },
  {
    "path": "eval/continual_mid_stage2_ts10_single80.json",
    "content": "{\n  \"checkpoints\": [\n    \"artifacts/checkpoints/mid_stage2_ts10_single80/step_000080.pt\"\n  ],\n  \"segments\": [\n    {\n      \"name\": \"refinedweb_segment\",\n      \"shards_dir\": \"data/shards/refinedweb_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"wikipedia_segment\",\n      \"shards_dir\": \"data/shards/wikipedia_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"redpajama_segment\",\n      \"shards_dir\": \"data/shards/redpajama_filtered\",\n      \"max_batches\": 10\n    }\n  ],\n  \"metrics\": {\n    \"step_000080\": {\n      \"refinedweb_segment\": {\n        \"loss\": NaN,\n        \"ppl\": NaN\n      },\n      \"wikipedia_segment\": {\n        \"loss\": NaN,\n        \"ppl\": NaN\n      },\n      \"redpajama_segment\": {\n        \"loss\": NaN,\n        \"ppl\": NaN\n      }\n    }\n  },\n  \"forgetting\": {\n    \"refinedweb_segment\": NaN,\n    \"wikipedia_segment\": NaN,\n    \"redpajama_segment\": NaN\n  }\n}"
  },
  {
    "path": "eval/continual_mid_stage2_ts10_single80lr2e5.json",
    "content": "{\n  \"checkpoints\": [\n    \"artifacts/checkpoints/mid_stage2_ts10_single80lr2e5/step_000080.pt\"\n  ],\n  \"segments\": [\n    {\n      \"name\": \"refinedweb_segment\",\n      \"shards_dir\": \"data/shards/refinedweb_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"wikipedia_segment\",\n      \"shards_dir\": \"data/shards/wikipedia_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"redpajama_segment\",\n      \"shards_dir\": \"data/shards/redpajama_filtered\",\n      \"max_batches\": 10\n    }\n  ],\n  \"metrics\": {\n    \"step_000080\": {\n      \"refinedweb_segment\": {\n        \"loss\": NaN,\n        \"ppl\": NaN\n      },\n      \"wikipedia_segment\": {\n        \"loss\": NaN,\n        \"ppl\": NaN\n      },\n      \"redpajama_segment\": {\n        \"loss\": NaN,\n        \"ppl\": NaN\n      }\n    }\n  },\n  \"forgetting\": {\n    \"refinedweb_segment\": NaN,\n    \"wikipedia_segment\": NaN,\n    \"redpajama_segment\": NaN\n  }\n}"
  },
  {
    "path": "eval/continual_mid_stage2_ts20.json",
    "content": "{\n  \"checkpoints\": [\n    \"artifacts/checkpoints/mid_stage2_ts20/step_000080.pt\"\n  ],\n  \"segments\": [\n    {\n      \"name\": \"refinedweb_segment\",\n      \"shards_dir\": \"data/shards/refinedweb_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"wikipedia_segment\",\n      \"shards_dir\": \"data/shards/wikipedia_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"redpajama_segment\",\n      \"shards_dir\": \"data/shards/redpajama_filtered\",\n      \"max_batches\": 10\n    }\n  ],\n  \"metrics\": {\n    \"step_000080\": {\n      \"refinedweb_segment\": {\n        \"loss\": NaN,\n        \"ppl\": NaN\n      },\n      \"wikipedia_segment\": {\n        \"loss\": NaN,\n        \"ppl\": NaN\n      },\n      \"redpajama_segment\": {\n        \"loss\": NaN,\n        \"ppl\": NaN\n      }\n    }\n  },\n  \"forgetting\": {\n    \"refinedweb_segment\": NaN,\n    \"wikipedia_segment\": NaN,\n    \"redpajama_segment\": NaN\n  }\n}"
  },
  {
    "path": "eval/continual_mid_titan_baseline.json",
    "content": "{\n  \"checkpoints\": [\n    \"artifacts/checkpoints/mid_titan_baseline/step_000200.pt\"\n  ],\n  \"segments\": [\n    {\n      \"name\": \"refinedweb_segment\",\n      \"shards_dir\": \"data/shards/refinedweb_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"wikipedia_segment\",\n      \"shards_dir\": \"data/shards/wikipedia_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"redpajama_segment\",\n      \"shards_dir\": \"data/shards/redpajama_filtered\",\n      \"max_batches\": 10\n    }\n  ],\n  \"metrics\": {\n    \"step_000200\": {\n      \"refinedweb_segment\": {\n        \"loss\": 8.688015418603229,\n        \"ppl\": 5931.399209854872\n      },\n      \"wikipedia_segment\": {\n        \"loss\": 8.77053791202911,\n        \"ppl\": 6441.636566384941\n      },\n      \"redpajama_segment\": {\n        \"loss\": 8.648725595492905,\n        \"ppl\": 5702.874331994893\n      }\n    }\n  },\n  \"forgetting\": {\n    \"refinedweb_segment\": 0.0,\n    \"wikipedia_segment\": 0.0,\n    \"redpajama_segment\": 0.0\n  }\n}"
  },
  {
    "path": "eval/continual_pilot.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot_relaunch/step_477000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 20.093504903834393,\n      \"wikipedia_sample\": 19.44820931690313,\n      \"c4_sample\": 19.59033499342588,\n      \"redpajama_sample\": 19.960534548526173\n    },\n    \"segment_baseline_losses\": {\n      \"refinedweb_2018\": 20.093504903834393,\n      \"wikipedia_sample\": 19.44820931690313,\n      \"c4_sample\": 19.59033499342588,\n      \"redpajama_sample\": 19.960534548526173\n    },\n    \"segment_memorize_delta\": {\n      \"refinedweb_2018\": 0.0,\n      \"wikipedia_sample\": 0.0,\n      \"c4_sample\": 0.0,\n      \"redpajama_sample\": 0.0\n    },\n    \"memorize_stats\": {\n      \"refinedweb_2018\": {\n        \"titan_mem_updates\": 0.0,\n        \"titan_update_events\": 0.0,\n        \"cms_fast_updates\": 0.0,\n        \"cms_fast_update_events\": 0.0,\n        \"cms_mid_updates\": 0.0,\n        \"cms_mid_update_events\": 0.0,\n        \"cms_slow_updates\": 0.0,\n        \"cms_slow_update_events\": 0.0,\n        \"cms_ultra_updates\": 0.0,\n        \"cms_ultra_update_events\": 0.0\n      },\n      \"wikipedia_sample\": {\n        \"titan_mem_updates\": 0.0,\n        \"titan_update_events\": 0.0,\n        \"cms_fast_updates\": 0.0,\n        \"cms_fast_update_events\": 0.0,\n        \"cms_mid_updates\": 0.0,\n        \"cms_mid_update_events\": 0.0,\n        \"cms_slow_updates\": 0.0,\n        \"cms_slow_update_events\": 0.0,\n        \"cms_ultra_updates\": 0.0,\n        \"cms_ultra_update_events\": 0.0\n      },\n      \"c4_sample\": {\n        \"titan_mem_updates\": 0.0,\n        \"titan_update_events\": 0.0,\n        \"cms_fast_updates\": 0.0,\n        \"cms_fast_update_events\": 0.0,\n        \"cms_mid_updates\": 0.0,\n        \"cms_mid_update_events\": 0.0,\n        \"cms_slow_updates\": 0.0,\n        \"cms_slow_update_events\": 0.0,\n        \"cms_ultra_updates\": 0.0,\n        \"cms_ultra_update_events\": 0.0\n      },\n      \"redpajama_sample\": {\n        \"titan_mem_updates\": 0.0,\n        \"titan_update_events\": 0.0,\n        \"cms_fast_updates\": 0.0,\n        \"cms_fast_update_events\": 0.0,\n        \"cms_mid_updates\": 0.0,\n        \"cms_mid_update_events\": 0.0,\n        \"cms_slow_updates\": 0.0,\n        \"cms_slow_update_events\": 0.0,\n        \"cms_ultra_updates\": 0.0,\n        \"cms_ultra_update_events\": 0.0\n      }\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_pilot_cms_nochunk_step5000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot_cms_nochunk/step_005000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 46.577801645058706,\n      \"wikipedia_sample\": 47.75575092114114,\n      \"c4_sample\": 49.66668201901908,\n      \"redpajama_sample\": 52.05659361240215\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_pilot_cms_sparse_step5000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot_cms_sparse/step_005000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 25.5893353068432,\n      \"wikipedia_sample\": 25.045852667104942,\n      \"c4_sample\": 25.094903910836596,\n      \"redpajama_sample\": 25.40295529598826\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_pilot_multi.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot/step_005000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 44.24760044642857,\n      \"wikipedia_sample\": 45.375970064823875,\n      \"c4_sample\": 44.95872179244129,\n      \"redpajama_sample\": 44.70617126345401\n    },\n    \"segment_baseline_losses\": {\n      \"refinedweb_2018\": 44.23934954439824,\n      \"wikipedia_sample\": 45.36217817392368,\n      \"c4_sample\": 44.94942056017612,\n      \"redpajama_sample\": 44.70034246575342\n    },\n    \"segment_memorize_delta\": {\n      \"refinedweb_2018\": -0.008250902030333407,\n      \"wikipedia_sample\": -0.013791890900193948,\n      \"c4_sample\": -0.009301232265165993,\n      \"redpajama_sample\": -0.005828797700587529\n    },\n    \"memorize_stats\": {\n      \"refinedweb_2018\": {\n        \"titan_mem_updates\": 1.613788730129727,\n        \"cms_fast_updates\": 0.0\n      },\n      \"wikipedia_sample\": {\n        \"titan_mem_updates\": 1.8479480545959177,\n        \"cms_fast_updates\": 0.0\n      },\n      \"c4_sample\": {\n        \"titan_mem_updates\": 1.4406822725594566,\n        \"cms_fast_updates\": 0.0\n      },\n      \"redpajama_sample\": {\n        \"titan_mem_updates\": 1.3893899804404555,\n        \"cms_fast_updates\": 0.0\n      }\n    }\n  },\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot/step_010000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 8.350521820725293,\n      \"wikipedia_sample\": 8.060312242004036,\n      \"c4_sample\": 8.085568613548801,\n      \"redpajama_sample\": 8.269488489557853\n    },\n    \"segment_baseline_losses\": {\n      \"refinedweb_2018\": 8.350520960738747,\n      \"wikipedia_sample\": 8.060311095355308,\n      \"c4_sample\": 8.085567896893346,\n      \"redpajama_sample\": 8.269487151801004\n    },\n    \"segment_memorize_delta\": {\n      \"refinedweb_2018\": -8.599865459757439e-07,\n      \"wikipedia_sample\": -1.1466487279676585e-06,\n      \"c4_sample\": -7.166554549797866e-07,\n      \"redpajama_sample\": -1.3377568492956016e-06\n    },\n    \"memorize_stats\": {\n      \"refinedweb_2018\": {\n        \"titan_mem_updates\": 0.3657227133185188,\n        \"cms_fast_updates\": 0.0\n      },\n      \"wikipedia_sample\": {\n        \"titan_mem_updates\": 0.37036076905133086,\n        \"cms_fast_updates\": 0.0\n      },\n      \"c4_sample\": {\n        \"titan_mem_updates\": 0.3642253975777976,\n        \"cms_fast_updates\": 0.0\n      },\n      \"redpajama_sample\": {\n        \"titan_mem_updates\": 0.37168583213599504,\n        \"cms_fast_updates\": 0.0\n      }\n    }\n  },\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot/step_230000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 8.13594605208843,\n      \"wikipedia_sample\": 7.762453704057608,\n      \"c4_sample\": 7.744226671431629,\n      \"redpajama_sample\": 8.051934157518957\n    },\n    \"segment_baseline_losses\": {\n      \"refinedweb_2018\": 8.13594590875734,\n      \"wikipedia_sample\": 7.762453608503547,\n      \"c4_sample\": 7.744226528100538,\n      \"redpajama_sample\": 8.051934348627078\n    },\n    \"segment_memorize_delta\": {\n      \"refinedweb_2018\": -1.433310909959573e-07,\n      \"wikipedia_sample\": -9.555406066397154e-08,\n      \"c4_sample\": -1.433310909959573e-07,\n      \"redpajama_sample\": 1.9110812132794308e-07\n    },\n    \"memorize_stats\": {\n      \"refinedweb_2018\": {\n        \"titan_mem_updates\": 0.14358223369094958,\n        \"cms_fast_updates\": 0.0\n      },\n  
    \"wikipedia_sample\": {\n        \"titan_mem_updates\": 0.15011811023973465,\n        \"cms_fast_updates\": 0.0\n      },\n      \"c4_sample\": {\n        \"titan_mem_updates\": 0.14206775680657913,\n        \"cms_fast_updates\": 0.0\n      },\n      \"redpajama_sample\": {\n        \"titan_mem_updates\": 0.14818353760242076,\n        \"cms_fast_updates\": 0.0\n      }\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_pilot_opt_adamw_step5000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot-opt-adamw-20251115173858/step_005000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 50.124265953705965,\n      \"wikipedia_sample\": 43.31982135212818,\n      \"c4_sample\": 39.34402806231654,\n      \"redpajama_sample\": 38.658693508439335\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_pilot_opt_muon_step5000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot-opt-muon-20251115180139/step_005000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 11.34836554779232,\n      \"wikipedia_sample\": 11.255398661096502,\n      \"c4_sample\": 11.203000970829256,\n      \"redpajama_sample\": 10.75289801132889\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_pilot_selfmod_off_step5000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot_selfmod_off/step_005000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 45.68723512414383,\n      \"wikipedia_sample\": 44.921795116805285,\n      \"c4_sample\": 44.4194165851272,\n      \"redpajama_sample\": 45.51563474651419\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_pilot_step22000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/pilot_release/checkpoint.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 47.07892000978474,\n      \"wikipedia_sample\": 45.77716793052838,\n      \"c4_sample\": 46.112149889921724,\n      \"redpajama_sample\": 46.673732570939336\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_pilot_step230000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot/step_230000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 8.065531237577972,\n      \"wikipedia_sample\": 7.796267658390411,\n      \"c4_sample\": 7.68133831584057,\n      \"redpajama_sample\": 7.9526173451642\n    },\n    \"segment_baseline_losses\": {\n      \"refinedweb_2018\": 8.065531357020548,\n      \"wikipedia_sample\": 7.796267897275563,\n      \"c4_sample\": 7.68133826806354,\n      \"redpajama_sample\": 7.9526173451642\n    },\n    \"segment_memorize_delta\": {\n      \"refinedweb_2018\": 1.1944257671814285e-07,\n      \"wikipedia_sample\": 2.3888515165992885e-07,\n      \"c4_sample\": -4.777703033198577e-08,\n      \"redpajama_sample\": 0.0\n    },\n    \"memorize_stats\": {\n      \"refinedweb_2018\": {\n        \"titan_mem_updates\": 0.29348359665116774,\n        \"cms_fast_updates\": 0.0\n      },\n      \"wikipedia_sample\": {\n        \"titan_mem_updates\": 0.29477991902535194,\n        \"cms_fast_updates\": 0.0\n      },\n      \"c4_sample\": {\n        \"titan_mem_updates\": 0.28394374698655866,\n        \"cms_fast_updates\": 0.0\n      },\n      \"redpajama_sample\": {\n        \"titan_mem_updates\": 0.2949711520357745,\n        \"cms_fast_updates\": 0.0\n      }\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_pilot_teach05_long_step25000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot_teach05_long/step_025000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 52.060874816536206,\n      \"wikipedia_sample\": 49.42592037671233,\n      \"c4_sample\": 48.898553311521525,\n      \"redpajama_sample\": 50.883419459393345\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_pilot_teach05_step2000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot_teach05/step_002000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 37.402111362524465,\n      \"wikipedia_sample\": 33.21865062377691,\n      \"c4_sample\": 35.86063688692515,\n      \"redpajama_sample\": 32.89820301033513\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_pilot_teach15_long_step25000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot_teach15_long/step_025000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 7.913888830494741,\n      \"wikipedia_sample\": 7.628488153207559,\n      \"c4_sample\": 7.560239023705051,\n      \"redpajama_sample\": 7.7894536744358485\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_pilot_teach15_step2000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/pilot_teach15/step_002000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 69.37442438233855,\n      \"wikipedia_sample\": 66.62145991316046,\n      \"c4_sample\": 66.5390625,\n      \"redpajama_sample\": 68.55203797700587\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_smoke.json",
    "content": "{\n  \"checkpoints\": [\n    \"artifacts/checkpoints/pilot_smoke/step_000010.pt\"\n  ],\n  \"segments\": [\n    {\n      \"name\": \"refinedweb_segment\",\n      \"shards_dir\": \"data/shards/refinedweb_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"wikipedia_segment\",\n      \"shards_dir\": \"data/shards/wikipedia_filtered\",\n      \"max_batches\": 10\n    },\n    {\n      \"name\": \"redpajama_segment\",\n      \"shards_dir\": \"data/shards/redpajama_filtered\",\n      \"max_batches\": 10\n    }\n  ],\n  \"metrics\": {\n    \"step_000010\": {\n      \"refinedweb_segment\": {\n        \"loss\": 10.532247871055528,\n        \"ppl\": 37505.687647419305\n      },\n      \"wikipedia_segment\": {\n        \"loss\": 10.51791113090142,\n        \"ppl\": 36971.81449407471\n      },\n      \"redpajama_segment\": {\n        \"loss\": 10.54121800850049,\n        \"ppl\": 37843.63225938314\n      }\n    }\n  },\n  \"forgetting\": {\n    \"refinedweb_segment\": 0.0,\n    \"wikipedia_segment\": 0.0,\n    \"redpajama_segment\": 0.0\n  }\n}"
  },
  {
    "path": "eval/continual_titan.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/mid_titan_long/step_032000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 8.536108780653743,\n      \"wikipedia_sample\": 8.046991647619556,\n      \"c4_sample\": 7.78577950136069,\n      \"redpajama_sample\": 8.107161848512414\n    },\n    \"segment_baseline_losses\": {\n      \"refinedweb_2018\": 8.536108780653743,\n      \"wikipedia_sample\": 8.046991647619556,\n      \"c4_sample\": 7.78577950136069,\n      \"redpajama_sample\": 8.107161848512414\n    },\n    \"segment_memorize_delta\": {\n      \"refinedweb_2018\": 0.0,\n      \"wikipedia_sample\": 0.0,\n      \"c4_sample\": 0.0,\n      \"redpajama_sample\": 0.0\n    },\n    \"memorize_stats\": {\n      \"refinedweb_2018\": {\n        \"titan_mem_updates\": 0.0,\n        \"titan_update_events\": 0.0,\n        \"cms_fast_updates\": 0.0,\n        \"cms_fast_update_events\": 0.0,\n        \"cms_mid_updates\": 0.0,\n        \"cms_mid_update_events\": 0.0,\n        \"cms_slow_updates\": 0.0,\n        \"cms_slow_update_events\": 0.0,\n        \"cms_ultra_updates\": 0.0,\n        \"cms_ultra_update_events\": 0.0\n      },\n      \"wikipedia_sample\": {\n        \"titan_mem_updates\": 0.0,\n        \"titan_update_events\": 0.0,\n        \"cms_fast_updates\": 0.0,\n        \"cms_fast_update_events\": 0.0,\n        \"cms_mid_updates\": 0.0,\n        \"cms_mid_update_events\": 0.0,\n        \"cms_slow_updates\": 0.0,\n        \"cms_slow_update_events\": 0.0,\n        \"cms_ultra_updates\": 0.0,\n        \"cms_ultra_update_events\": 0.0\n      },\n      \"c4_sample\": {\n        \"titan_mem_updates\": 0.0,\n        \"titan_update_events\": 0.0,\n        \"cms_fast_updates\": 0.0,\n        \"cms_fast_update_events\": 0.0,\n        \"cms_mid_updates\": 0.0,\n        \"cms_mid_update_events\": 0.0,\n        \"cms_slow_updates\": 0.0,\n        \"cms_slow_update_events\": 0.0,\n        \"cms_ultra_updates\": 0.0,\n        \"cms_ultra_update_events\": 0.0\n      },\n      \"redpajama_sample\": {\n        \"titan_mem_updates\": 0.0,\n        \"titan_update_events\": 0.0,\n        \"cms_fast_updates\": 0.0,\n        \"cms_fast_update_events\": 0.0,\n        \"cms_mid_updates\": 0.0,\n        \"cms_mid_update_events\": 0.0,\n        \"cms_slow_updates\": 0.0,\n        \"cms_slow_update_events\": 0.0,\n        \"cms_ultra_updates\": 0.0,\n        \"cms_ultra_update_events\": 0.0\n      }\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_titan_relaunch_step001000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/mid_titan_long/step_001000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 42.24322330601761,\n      \"wikipedia_sample\": 41.00887353228963,\n      \"c4_sample\": 39.92439552501223,\n      \"redpajama_sample\": 41.27286761558219\n    },\n    \"segment_baseline_losses\": {\n      \"refinedweb_2018\": 42.24322330601761,\n      \"wikipedia_sample\": 41.00887353228963,\n      \"c4_sample\": 39.92439552501223,\n      \"redpajama_sample\": 41.27286761558219\n    },\n    \"segment_memorize_delta\": {\n      \"refinedweb_2018\": 0.0,\n      \"wikipedia_sample\": 0.0,\n      \"c4_sample\": 0.0,\n      \"redpajama_sample\": 0.0\n    },\n    \"memorize_stats\": {\n      \"refinedweb_2018\": {\n        \"titan_mem_updates\": 0.0,\n        \"cms_fast_updates\": 0.0\n      },\n      \"wikipedia_sample\": {\n        \"titan_mem_updates\": 0.0,\n        \"cms_fast_updates\": 0.0\n      },\n      \"c4_sample\": {\n        \"titan_mem_updates\": 0.0,\n        \"cms_fast_updates\": 0.0\n      },\n      \"redpajama_sample\": {\n        \"titan_mem_updates\": 0.0,\n        \"cms_fast_updates\": 0.0\n      }\n    }\n  }\n]"
  },
  {
    "path": "eval/continual_titan_step25000.json",
    "content": "[\n  {\n    \"checkpoint\": \"artifacts/checkpoints/mid_titan_baseline/step_025000.pt\",\n    \"segment_losses\": {\n      \"refinedweb_2018\": 8.374168058570817,\n      \"wikipedia_sample\": 8.117256561697346,\n      \"c4_sample\": 7.871031807080174,\n      \"redpajama_sample\": 8.099312727418665\n    },\n    \"segment_baseline_losses\": {\n      \"refinedweb_2018\": 8.363879824524522,\n      \"wikipedia_sample\": 8.11803621116683,\n      \"c4_sample\": 7.859244497041646,\n      \"redpajama_sample\": 8.098724401067148\n    },\n    \"segment_memorize_delta\": {\n      \"refinedweb_2018\": -0.01028823404629442,\n      \"wikipedia_sample\": 0.0007796494694840561,\n      \"c4_sample\": -0.011787310038527288,\n      \"redpajama_sample\": -0.0005883263515169546\n    },\n    \"memorize_stats\": {\n      \"refinedweb_2018\": {\n        \"titan_mem_updates\": 0.0,\n        \"cms_fast_updates\": 0.0\n      },\n      \"wikipedia_sample\": {\n        \"titan_mem_updates\": 0.0,\n        \"cms_fast_updates\": 0.0\n      },\n      \"c4_sample\": {\n        \"titan_mem_updates\": 0.0,\n        \"cms_fast_updates\": 0.0\n      },\n      \"redpajama_sample\": {\n        \"titan_mem_updates\": 0.0,\n        \"cms_fast_updates\": 0.0\n      }\n    }\n  }\n]"
  },
  {
    "path": "eval/niah_dummy.json",
    "content": "{\n  \"niah_2048\": 0.0,\n  \"niah_4096\": 0.0\n}"
  },
  {
    "path": "eval/niah_mid_stage2.json",
    "content": "{\n  \"niah_2048\": 0.3333333333333333,\n  \"niah_4096\": 0.3333333333333333\n}"
  },
  {
    "path": "eval/niah_mid_stage2_smoke.json",
    "content": "{\n  \"niah_2048\": 0.6,\n  \"niah_4096\": 0.6,\n  \"niah_8192\": 0.6\n}"
  },
  {
    "path": "eval/niah_mid_stage2_ts10.json",
    "content": "{\n  \"niah_2048\": 0.0\n}"
  },
  {
    "path": "eval/niah_mid_stage2_ts10_single120_clip.json",
    "content": "{\n  \"niah_2048\": 0.0\n}"
  },
  {
    "path": "eval/niah_mid_stage2_ts10_single140_schedC.json",
    "content": "{\n  \"niah_2048\": 0.4\n}"
  },
  {
    "path": "eval/niah_mid_stage2_ts10_single220_schedD.json",
    "content": "{\n  \"niah_2048\": 0.4,\n  \"niah_4096\": 0.8\n}"
  },
  {
    "path": "eval/niah_mid_stage2_ts10_single80.json",
    "content": "{\n  \"niah_2048\": 0.0\n}"
  },
  {
    "path": "eval/niah_mid_stage2_ts10_single80lr2e5.json",
    "content": "{\n  \"niah_2048\": 0.0\n}"
  },
  {
    "path": "eval/niah_mid_stage2_ts20.json",
    "content": "{\n  \"niah_2048\": 0.0\n}"
  },
  {
    "path": "eval/niah_mid_titan_baseline.json",
    "content": "{\n  \"niah_2048\": 0.6,\n  \"niah_4096\": 0.0\n}"
  },
  {
    "path": "eval/niah_pilot.json",
    "content": "{\n  \"niah_2048\": 0.625,\n  \"niah_2048_baseline_accuracy\": 0.625,\n  \"niah_2048_memorize_accuracy\": 0.625,\n  \"niah_2048_memorize_delta\": 0.0,\n  \"niah_4096\": 0.5,\n  \"niah_4096_baseline_accuracy\": 0.5,\n  \"niah_4096_memorize_accuracy\": 0.5,\n  \"niah_4096_memorize_delta\": 0.0,\n  \"niah_8192\": 0.625,\n  \"niah_8192_baseline_accuracy\": 0.625,\n  \"niah_8192_memorize_accuracy\": 0.625,\n  \"niah_8192_memorize_delta\": 0.0,\n  \"niah_16384\": 0.375,\n  \"niah_16384_baseline_accuracy\": 0.375,\n  \"niah_16384_memorize_accuracy\": 0.375,\n  \"niah_16384_memorize_delta\": 0.0,\n  \"niah_32768\": 0.375,\n  \"niah_32768_baseline_accuracy\": 0.375,\n  \"niah_32768_memorize_accuracy\": 0.375,\n  \"niah_32768_memorize_delta\": 0.0,\n  \"niah_65536\": 0.5,\n  \"niah_65536_baseline_accuracy\": 0.5,\n  \"niah_65536_memorize_accuracy\": 0.5,\n  \"niah_65536_memorize_delta\": 0.0,\n  \"niah_titan_mem_updates\": 0.0,\n  \"niah_titan_update_events\": 0.0,\n  \"niah_cms_fast_updates\": 0.0,\n  \"niah_cms_fast_update_events\": 0.0,\n  \"niah_cms_mid_updates\": 0.0,\n  \"niah_cms_mid_update_events\": 0.0,\n  \"niah_cms_slow_updates\": 0.0,\n  \"niah_cms_slow_update_events\": 0.0,\n  \"niah_cms_ultra_updates\": 0.0,\n  \"niah_cms_ultra_update_events\": 0.0,\n  \"niah_memorize_paths\": \"titan,cms_fast\",\n  \"niah_memorize_surprise_threshold\": 0.02\n}"
  },
  {
    "path": "eval/niah_pilot_cms_nochunk_step5000.json",
    "content": "{\n  \"niah_2048\": 0.75,\n  \"niah_4096\": 0.25,\n  \"niah_8192\": 0.25,\n  \"niah_16384\": 0.25,\n  \"niah_32768\": 0.75,\n  \"niah_65536\": 0.5\n}"
  },
  {
    "path": "eval/niah_pilot_cms_sparse_step5000.json",
    "content": "{\n  \"niah_2048\": 0.75,\n  \"niah_4096\": 0.5,\n  \"niah_8192\": 0.625,\n  \"niah_16384\": 0.625,\n  \"niah_32768\": 0.5,\n  \"niah_65536\": 0.375\n}"
  },
  {
    "path": "eval/niah_pilot_opt_adamw_step5000.json",
    "content": "{\n  \"niah_2048\": 0.75,\n  \"niah_4096\": 1.0,\n  \"niah_8192\": 0.5,\n  \"niah_16384\": 0.75,\n  \"niah_32768\": 0.5,\n  \"niah_65536\": 0.25\n}"
  },
  {
    "path": "eval/niah_pilot_opt_muon_step5000.json",
    "content": "{\n  \"niah_2048\": 0.5,\n  \"niah_4096\": 0.5,\n  \"niah_8192\": 0.25,\n  \"niah_16384\": 0.75,\n  \"niah_32768\": 0.75,\n  \"niah_65536\": 0.75\n}"
  },
  {
    "path": "eval/niah_pilot_selfmod_off_step5000.json",
    "content": "{\n  \"niah_2048\": 0.75,\n  \"niah_4096\": 0.75,\n  \"niah_8192\": 0.5,\n  \"niah_16384\": 0.75,\n  \"niah_32768\": 0.25,\n  \"niah_65536\": 0.75\n}"
  },
  {
    "path": "eval/niah_pilot_step22000.json",
    "content": "{\n  \"niah_2048\": 1.0,\n  \"niah_4096\": 0.0,\n  \"niah_8192\": 0.0\n}"
  },
  {
    "path": "eval/niah_pilot_step230000.json",
    "content": "{\n  \"niah_2048\": 0.25,\n  \"niah_2048_baseline_accuracy\": 0.25,\n  \"niah_2048_memorize_accuracy\": 0.25,\n  \"niah_2048_memorize_delta\": 0.0,\n  \"niah_4096\": 0.625,\n  \"niah_4096_baseline_accuracy\": 0.625,\n  \"niah_4096_memorize_accuracy\": 0.625,\n  \"niah_4096_memorize_delta\": 0.0,\n  \"niah_8192\": 0.5,\n  \"niah_8192_baseline_accuracy\": 0.5,\n  \"niah_8192_memorize_accuracy\": 0.5,\n  \"niah_8192_memorize_delta\": 0.0,\n  \"niah_16384\": 0.625,\n  \"niah_16384_baseline_accuracy\": 0.625,\n  \"niah_16384_memorize_accuracy\": 0.625,\n  \"niah_16384_memorize_delta\": 0.0,\n  \"niah_32768\": 0.625,\n  \"niah_32768_baseline_accuracy\": 0.625,\n  \"niah_32768_memorize_accuracy\": 0.625,\n  \"niah_32768_memorize_delta\": 0.0,\n  \"niah_65536\": 0.375,\n  \"niah_65536_baseline_accuracy\": 0.375,\n  \"niah_65536_memorize_accuracy\": 0.375,\n  \"niah_65536_memorize_delta\": 0.0,\n  \"niah_titan_mem_updates\": 1.5477444703160508,\n  \"niah_cms_fast_updates\": 0.0\n}"
  },
  {
    "path": "eval/niah_pilot_teach05_long_step25000.json",
    "content": "{\n  \"niah_2048\": 0.25,\n  \"niah_4096\": 0.5,\n  \"niah_8192\": 0.375,\n  \"niah_16384\": 0.75,\n  \"niah_32768\": 0.75,\n  \"niah_65536\": 0.75\n}"
  },
  {
    "path": "eval/niah_pilot_teach05_step2000.json",
    "content": "{\n  \"niah_2048\": 0.5,\n  \"niah_4096\": 0.75,\n  \"niah_8192\": 1.0,\n  \"niah_16384\": 0.75,\n  \"niah_32768\": 0.25,\n  \"niah_65536\": 1.0\n}"
  },
  {
    "path": "eval/niah_pilot_teach15_long_step25000.json",
    "content": "{\n  \"niah_2048\": 0.75,\n  \"niah_4096\": 0.625,\n  \"niah_8192\": 0.375,\n  \"niah_16384\": 0.75,\n  \"niah_32768\": 0.5,\n  \"niah_65536\": 0.75\n}"
  },
  {
    "path": "eval/niah_pilot_teach15_step2000.json",
    "content": "{\n  \"niah_2048\": 0.75,\n  \"niah_4096\": 0.75,\n  \"niah_8192\": 0.75,\n  \"niah_16384\": 0.5,\n  \"niah_32768\": 0.25,\n  \"niah_65536\": 0.5\n}"
  },
  {
    "path": "eval/niah_smoke.json",
    "content": "{\n  \"niah_2048\": 0.6666666666666666,\n  \"niah_4096\": 0.0,\n  \"niah_8192\": 0.6666666666666666\n}"
  },
  {
    "path": "eval/niah_titan.json",
    "content": "{\n  \"niah_2048\": 0.375,\n  \"niah_2048_baseline_accuracy\": 0.375,\n  \"niah_2048_memorize_accuracy\": 0.375,\n  \"niah_2048_memorize_delta\": 0.0,\n  \"niah_4096\": 0.375,\n  \"niah_4096_baseline_accuracy\": 0.375,\n  \"niah_4096_memorize_accuracy\": 0.375,\n  \"niah_4096_memorize_delta\": 0.0,\n  \"niah_8192\": 0.5,\n  \"niah_8192_baseline_accuracy\": 0.5,\n  \"niah_8192_memorize_accuracy\": 0.5,\n  \"niah_8192_memorize_delta\": 0.0,\n  \"niah_16384\": 0.75,\n  \"niah_16384_baseline_accuracy\": 0.75,\n  \"niah_16384_memorize_accuracy\": 0.75,\n  \"niah_16384_memorize_delta\": 0.0,\n  \"niah_32768\": 0.625,\n  \"niah_32768_baseline_accuracy\": 0.625,\n  \"niah_32768_memorize_accuracy\": 0.625,\n  \"niah_32768_memorize_delta\": 0.0,\n  \"niah_65536\": 0.375,\n  \"niah_65536_baseline_accuracy\": 0.375,\n  \"niah_65536_memorize_accuracy\": 0.375,\n  \"niah_65536_memorize_delta\": 0.0,\n  \"niah_titan_mem_updates\": 0.0,\n  \"niah_titan_update_events\": 0.0,\n  \"niah_cms_fast_updates\": 0.0,\n  \"niah_cms_fast_update_events\": 0.0,\n  \"niah_cms_mid_updates\": 0.0,\n  \"niah_cms_mid_update_events\": 0.0,\n  \"niah_cms_slow_updates\": 0.0,\n  \"niah_cms_slow_update_events\": 0.0,\n  \"niah_cms_ultra_updates\": 0.0,\n  \"niah_cms_ultra_update_events\": 0.0,\n  \"niah_memorize_paths\": \"titan\",\n  \"niah_memorize_surprise_threshold\": 0.02\n}"
  },
  {
    "path": "eval/niah_titan_relaunch_step001000.json",
    "content": "{\n  \"niah_2048\": 0.5,\n  \"niah_2048_baseline_accuracy\": 0.5,\n  \"niah_2048_memorize_accuracy\": 0.5,\n  \"niah_2048_memorize_delta\": 0.0,\n  \"niah_4096\": 0.75,\n  \"niah_4096_baseline_accuracy\": 0.75,\n  \"niah_4096_memorize_accuracy\": 0.75,\n  \"niah_4096_memorize_delta\": 0.0,\n  \"niah_8192\": 0.625,\n  \"niah_8192_baseline_accuracy\": 0.625,\n  \"niah_8192_memorize_accuracy\": 0.625,\n  \"niah_8192_memorize_delta\": 0.0,\n  \"niah_16384\": 0.625,\n  \"niah_16384_baseline_accuracy\": 0.625,\n  \"niah_16384_memorize_accuracy\": 0.625,\n  \"niah_16384_memorize_delta\": 0.0,\n  \"niah_32768\": 0.375,\n  \"niah_32768_baseline_accuracy\": 0.375,\n  \"niah_32768_memorize_accuracy\": 0.375,\n  \"niah_32768_memorize_delta\": 0.0,\n  \"niah_65536\": 0.375,\n  \"niah_65536_baseline_accuracy\": 0.375,\n  \"niah_65536_memorize_accuracy\": 0.375,\n  \"niah_65536_memorize_delta\": 0.0,\n  \"niah_titan_mem_updates\": 0.0,\n  \"niah_cms_fast_updates\": 0.0,\n  \"niah_memorize_paths\": \"titan\",\n  \"niah_memorize_surprise_threshold\": 0.02\n}"
  },
  {
    "path": "eval/niah_titan_step25000.json",
    "content": "{\n  \"niah_2048\": 0.125,\n  \"niah_2048_baseline_accuracy\": 0.125,\n  \"niah_2048_memorize_accuracy\": 0.125,\n  \"niah_2048_memorize_delta\": 0.0,\n  \"niah_4096\": 0.5,\n  \"niah_4096_baseline_accuracy\": 0.5,\n  \"niah_4096_memorize_accuracy\": 0.5,\n  \"niah_4096_memorize_delta\": 0.0,\n  \"niah_8192\": 0.25,\n  \"niah_8192_baseline_accuracy\": 0.25,\n  \"niah_8192_memorize_accuracy\": 0.25,\n  \"niah_8192_memorize_delta\": 0.0,\n  \"niah_16384\": 0.5,\n  \"niah_16384_baseline_accuracy\": 0.5,\n  \"niah_16384_memorize_accuracy\": 0.5,\n  \"niah_16384_memorize_delta\": 0.0,\n  \"niah_32768\": 0.5,\n  \"niah_32768_baseline_accuracy\": 0.5,\n  \"niah_32768_memorize_accuracy\": 0.5,\n  \"niah_32768_memorize_delta\": 0.0,\n  \"niah_65536\": 0.625,\n  \"niah_65536_baseline_accuracy\": 0.625,\n  \"niah_65536_memorize_accuracy\": 0.625,\n  \"niah_65536_memorize_delta\": 0.0,\n  \"niah_titan_mem_updates\": 0.0,\n  \"niah_cms_fast_updates\": 0.0\n}"
  },
  {
    "path": "eval/passkey_pilot.json",
    "content": "{\n  \"samples\": 64,\n  \"filler_sentences\": 256,\n  \"accuracy_base\": 0.4375,\n  \"accuracy_memorize\": 0.4375,\n  \"accuracy_delta\": 0.0,\n  \"path_stats\": {\n    \"titan_mem_updates\": 0.0,\n    \"titan_update_events\": 0.0,\n    \"cms_fast_updates\": 0.0,\n    \"cms_fast_update_events\": 0.0,\n    \"cms_mid_updates\": 0.0,\n    \"cms_mid_update_events\": 0.0,\n    \"cms_slow_updates\": 0.0,\n    \"cms_slow_update_events\": 0.0,\n    \"cms_ultra_updates\": 0.0,\n    \"cms_ultra_update_events\": 0.0\n  },\n  \"memorize_paths\": \"titan,cms_fast\",\n  \"memorize_surprise_threshold\": 0.02\n}"
  },
  {
    "path": "eval/passkey_pilot_step230000.json",
    "content": "{\n  \"samples\": 64,\n  \"filler_sentences\": 256,\n  \"accuracy_base\": 0.484375,\n  \"accuracy_memorize\": 0.484375,\n  \"accuracy_delta\": 0.0,\n  \"path_stats\": {\n    \"titan_mem_updates\": 2.1295860992622053,\n    \"cms_fast_updates\": 0.0\n  }\n}"
  },
  {
    "path": "eval/passkey_titan.json",
    "content": "{\n  \"samples\": 64,\n  \"filler_sentences\": 256,\n  \"accuracy_base\": 0.46875,\n  \"accuracy_memorize\": 0.46875,\n  \"accuracy_delta\": 0.0,\n  \"path_stats\": {\n    \"titan_mem_updates\": 0.0,\n    \"titan_update_events\": 0.0,\n    \"cms_fast_updates\": 0.0,\n    \"cms_fast_update_events\": 0.0,\n    \"cms_mid_updates\": 0.0,\n    \"cms_mid_update_events\": 0.0,\n    \"cms_slow_updates\": 0.0,\n    \"cms_slow_update_events\": 0.0,\n    \"cms_ultra_updates\": 0.0,\n    \"cms_ultra_update_events\": 0.0\n  },\n  \"memorize_paths\": \"titan\",\n  \"memorize_surprise_threshold\": 0.02\n}"
  },
  {
    "path": "eval/passkey_titan_relaunch_step001000.json",
    "content": "{\n  \"samples\": 64,\n  \"filler_sentences\": 256,\n  \"accuracy_base\": 0.5,\n  \"accuracy_memorize\": 0.5,\n  \"accuracy_delta\": 0.0,\n  \"path_stats\": {\n    \"titan_mem_updates\": 0.0,\n    \"cms_fast_updates\": 0.0\n  },\n  \"memorize_paths\": \"titan\",\n  \"memorize_surprise_threshold\": 0.02\n}"
  },
  {
    "path": "eval/passkey_titan_step25000.json",
    "content": "{\n  \"samples\": 64,\n  \"filler_sentences\": 256,\n  \"accuracy_base\": 0.53125,\n  \"accuracy_memorize\": 0.546875,\n  \"accuracy_delta\": 0.015625,\n  \"path_stats\": {\n    \"titan_mem_updates\": 0.0,\n    \"cms_fast_updates\": 0.0\n  }\n}"
  },
  {
    "path": "eval/pg19_pilot.json",
    "content": "{\n  \"samples\": 32,\n  \"tokens\": 65504,\n  \"ppl_base\": 285944896.0,\n  \"ppl_memorize\": 285944896.0,\n  \"ppl_delta\": 0.0,\n  \"memorize_paths\": \"titan,cms_fast\",\n  \"memorize_surprise_threshold\": 0.02\n}"
  },
  {
    "path": "eval/pg19_pilot_step230000.json",
    "content": "{\n  \"samples\": 4,\n  \"tokens\": 8188,\n  \"ppl_base\": 2497.421875,\n  \"ppl_memorize\": 2497.421875,\n  \"ppl_delta\": 0.0\n}"
  },
  {
    "path": "eval/pg19_titan.json",
    "content": "{\n  \"samples\": 32,\n  \"tokens\": 65504,\n  \"ppl_base\": 2449.884765625,\n  \"ppl_memorize\": 2449.884765625,\n  \"ppl_delta\": 0.0,\n  \"memorize_paths\": \"titan\",\n  \"memorize_surprise_threshold\": 0.02\n}"
  },
  {
    "path": "eval/pg19_titan_relaunch_step001000.json",
    "content": "{\n  \"samples\": 32,\n  \"tokens\": 65504,\n  \"ppl_base\": 2.931641597034496e+17,\n  \"ppl_memorize\": 2.931641597034496e+17,\n  \"ppl_delta\": 0.0,\n  \"memorize_paths\": \"titan\",\n  \"memorize_surprise_threshold\": 0.02\n}"
  },
  {
    "path": "eval/pg19_titan_step25000.json",
    "content": "{\n  \"samples\": 4,\n  \"tokens\": 8188,\n  \"ppl_base\": 3122.819580078125,\n  \"ppl_memorize\": 3233.149658203125,\n  \"ppl_delta\": -110.330078125\n}"
  },
  {
    "path": "eval/phase2_compare_smoke_lastlayer_metrics.json",
    "content": "{\n  \"seed\": 0,\n  \"device\": \"cuda:1\",\n  \"tokenizer_path\": \"artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model\",\n  \"a\": {\n    \"config\": \"configs/resolved/phase2_pilot_attention_eval.yaml\",\n    \"checkpoint\": \"artifacts/checkpoints/phase2_init/hope_attention_step000000.pt\",\n    \"passkey\": {\n      \"samples\": 8,\n      \"filler_sentences\": 20,\n      \"accuracy_base\": 0.5,\n      \"accuracy_memorize\": 0.5,\n      \"accuracy_delta\": 0.0,\n      \"mean_logprob_true_base\": -1294.908706665039,\n      \"mean_logprob_true_memorize\": -1294.8721160888672,\n      \"mean_logprob_true_delta\": 0.036590576171875,\n      \"mean_logprob_false_base\": -1297.3745727539062,\n      \"mean_logprob_false_memorize\": -1297.3665771484375,\n      \"mean_logprob_false_delta\": 0.00799560546875,\n      \"mean_margin_base\": 2.4658660888671875,\n      \"mean_margin_memorize\": 2.4944610595703125,\n      \"mean_margin_delta\": 0.028594970703125,\n      \"memorize_paths\": \"cms_fast\",\n      \"memorize_use_correct_answer\": true,\n      \"memorize_stats\": {\n        \"titan_mem_updates\": 0.0,\n        \"titan_update_events\": 0.0,\n        \"cms_fast_updates\": 36.15032583102584,\n        \"cms_fast_update_events\": 48.0,\n        \"cms_mid_updates\": 0.0,\n        \"cms_mid_update_events\": 0.0,\n        \"cms_slow_updates\": 0.0,\n        \"cms_slow_update_events\": 0.0,\n        \"cms_ultra_updates\": 0.0,\n        \"cms_ultra_update_events\": 0.0\n      }\n    },\n    \"niah\": {\n      \"niah_256_baseline_accuracy\": 0.625,\n      \"niah_256_memorize_accuracy\": 0.625,\n      \"niah_256_memorize_delta\": 0.0,\n      \"niah_256_mean_logprob_true_base\": -991.9136581420898,\n      \"niah_256_mean_logprob_true_memorize\": -991.8013305664062,\n      \"niah_256_mean_logprob_true_delta\": 0.11232757568359375,\n      \"niah_256_mean_margin_base\": 62.82862854003906,\n      \"niah_256_mean_margin_memorize\": 62.92052459716797,\n      \"niah_256_mean_margin_delta\": 0.09189605712890625,\n      \"memorize_paths\": \"cms_fast\",\n      \"memorize_use_correct_answer\": true,\n      \"memorize_stats\": {\n        \"titan_mem_updates\": 0.0,\n        \"titan_update_events\": 0.0,\n        \"cms_fast_updates\": 97.94030156359076,\n        \"cms_fast_update_events\": 37.0,\n        \"cms_mid_updates\": 0.0,\n        \"cms_mid_update_events\": 0.0,\n        \"cms_slow_updates\": 0.0,\n        \"cms_slow_update_events\": 0.0,\n        \"cms_ultra_updates\": 0.0,\n        \"cms_ultra_update_events\": 0.0\n      }\n    }\n  },\n  \"b\": {\n    \"config\": \"configs/resolved/phase2_pilot_transformer_eval.yaml\",\n    \"checkpoint\": \"artifacts/checkpoints/phase2_init/transformer_step000000.pt\",\n    \"passkey\": {\n      \"samples\": 8,\n      \"filler_sentences\": 20,\n      \"accuracy_base\": 0.625,\n      \"accuracy_memorize\": 0.625,\n      \"accuracy_delta\": 0.0,\n      \"mean_logprob_true_base\": -2081.151657104492,\n      \"mean_logprob_true_memorize\": -2081.151657104492,\n      \"mean_logprob_true_delta\": 0.0,\n      \"mean_logprob_false_base\": -2115.255126953125,\n      \"mean_logprob_false_memorize\": -2115.255126953125,\n      \"mean_logprob_false_delta\": 0.0,\n      \"mean_margin_base\": 34.10346984863281,\n      \"mean_margin_memorize\": 34.10346984863281,\n      \"mean_margin_delta\": 0.0,\n      \"memorize_paths\": \"cms_fast\",\n      \"memorize_use_correct_answer\": true,\n      \"memorize_stats\": {\n        \"titan_mem_updates\": 0.0,\n        
\"titan_update_events\": 0.0,\n        \"cms_fast_updates\": 0.0,\n        \"cms_fast_update_events\": 0.0,\n        \"cms_mid_updates\": 0.0,\n        \"cms_mid_update_events\": 0.0,\n        \"cms_slow_updates\": 0.0,\n        \"cms_slow_update_events\": 0.0,\n        \"cms_ultra_updates\": 0.0,\n        \"cms_ultra_update_events\": 0.0\n      }\n    },\n    \"niah\": {\n      \"niah_256_baseline_accuracy\": 0.625,\n      \"niah_256_memorize_accuracy\": 0.625,\n      \"niah_256_memorize_delta\": 0.0,\n      \"niah_256_mean_logprob_true_base\": -1839.57421875,\n      \"niah_256_mean_logprob_true_memorize\": -1839.57421875,\n      \"niah_256_mean_logprob_true_delta\": 0.0,\n      \"niah_256_mean_margin_base\": 104.52041625976562,\n      \"niah_256_mean_margin_memorize\": 104.52041625976562,\n      \"niah_256_mean_margin_delta\": 0.0,\n      \"memorize_paths\": \"cms_fast\",\n      \"memorize_use_correct_answer\": true,\n      \"memorize_stats\": {\n        \"titan_mem_updates\": 0.0,\n        \"titan_update_events\": 0.0,\n        \"cms_fast_updates\": 0.0,\n        \"cms_fast_update_events\": 0.0,\n        \"cms_mid_updates\": 0.0,\n        \"cms_mid_update_events\": 0.0,\n        \"cms_slow_updates\": 0.0,\n        \"cms_slow_update_events\": 0.0,\n        \"cms_ultra_updates\": 0.0,\n        \"cms_ultra_update_events\": 0.0\n      }\n    }\n  },\n  \"memorize\": {\n    \"enabled\": true,\n    \"steps\": 1,\n    \"reset\": true,\n    \"use_correct_answer\": true,\n    \"paths\": \"cms_fast\",\n    \"surprise_threshold\": null\n  }\n}"
  },
  {
    "path": "eval/zeroshot_full_smoke.json",
    "content": "{\n  \"piqa_accuracy\": 0.625,\n  \"piqa_samples\": 32,\n  \"hellaswag_accuracy\": 0.0,\n  \"hellaswag_samples\": 32,\n  \"winogrande_accuracy\": 0.46875,\n  \"winogrande_samples\": 32,\n  \"arc_arc-easy_accuracy\": 0.21875,\n  \"arc_arc-easy_samples\": 32,\n  \"arc_arc-challenge_accuracy\": 0.28125,\n  \"arc_arc-challenge_samples\": 32,\n  \"boolq_accuracy\": 0.84375,\n  \"boolq_samples\": 32,\n  \"siqa_accuracy\": 0.25,\n  \"siqa_samples\": 32\n}"
  },
  {
    "path": "eval/zeroshot_mid_stage2.json",
    "content": "{\n  \"piqa_accuracy\": 0.5,\n  \"piqa_samples\": 16,\n  \"hellaswag_accuracy\": 0.0,\n  \"hellaswag_samples\": 16,\n  \"winogrande_accuracy\": 0.625,\n  \"winogrande_samples\": 16\n}"
  },
  {
    "path": "eval/zeroshot_mid_stage2_smoke.json",
    "content": "{\n  \"piqa_accuracy\": 0.484375,\n  \"piqa_samples\": 64,\n  \"hellaswag_accuracy\": 0.0,\n  \"hellaswag_samples\": 64,\n  \"winogrande_accuracy\": 0.484375,\n  \"winogrande_samples\": 64,\n  \"arc_arc-easy_accuracy\": 0.203125,\n  \"arc_arc-easy_samples\": 64,\n  \"arc_arc-challenge_accuracy\": 0.25,\n  \"arc_arc-challenge_samples\": 64,\n  \"boolq_accuracy\": 0.28125,\n  \"boolq_samples\": 64,\n  \"siqa_accuracy\": 0.3125,\n  \"siqa_samples\": 64\n}"
  },
  {
    "path": "eval/zeroshot_mid_stage2_smoke_piqa_baseline.json",
    "content": "{\n  \"piqa_accuracy\": 0.5625,\n  \"piqa_samples\": 32\n}"
  },
  {
    "path": "eval/zeroshot_mid_stage2_smoke_piqa_mem.json",
    "content": "{\n  \"piqa_accuracy\": 0.5625,\n  \"piqa_samples\": 32\n}"
  },
  {
    "path": "eval/zeroshot_mid_stage2_ts10.json",
    "content": "{\n  \"piqa_accuracy\": 0.46875,\n  \"piqa_samples\": 32,\n  \"winogrande_accuracy\": 0.59375,\n  \"winogrande_samples\": 32\n}"
  },
  {
    "path": "eval/zeroshot_mid_stage2_ts10_single120_clip.json",
    "content": "{\n  \"piqa_accuracy\": 0.59375,\n  \"piqa_samples\": 32,\n  \"winogrande_accuracy\": 0.40625,\n  \"winogrande_samples\": 32\n}"
  },
  {
    "path": "eval/zeroshot_mid_stage2_ts10_single140_schedC.json",
    "content": "{\n  \"piqa_accuracy\": 0.546875,\n  \"piqa_samples\": 64,\n  \"winogrande_accuracy\": 0.484375,\n  \"winogrande_samples\": 64\n}"
  },
  {
    "path": "eval/zeroshot_mid_stage2_ts10_single220_schedD.json",
    "content": "{\n  \"piqa_accuracy\": 0.5078125,\n  \"piqa_samples\": 128,\n  \"winogrande_accuracy\": 0.4921875,\n  \"winogrande_samples\": 128\n}"
  },
  {
    "path": "eval/zeroshot_mid_stage2_ts10_single80.json",
    "content": "{\n  \"piqa_accuracy\": 0.46875,\n  \"piqa_samples\": 32,\n  \"winogrande_accuracy\": 0.59375,\n  \"winogrande_samples\": 32\n}"
  },
  {
    "path": "eval/zeroshot_mid_stage2_ts10_single80lr2e5.json",
    "content": "{\n  \"piqa_accuracy\": 0.46875,\n  \"piqa_samples\": 32,\n  \"winogrande_accuracy\": 0.59375,\n  \"winogrande_samples\": 32\n}"
  },
  {
    "path": "eval/zeroshot_mid_stage2_ts20.json",
    "content": "{\n  \"piqa_accuracy\": 0.46875,\n  \"piqa_samples\": 32,\n  \"winogrande_accuracy\": 0.59375,\n  \"winogrande_samples\": 32\n}"
  },
  {
    "path": "eval/zeroshot_mid_titan_baseline.json",
    "content": "{\n  \"piqa_accuracy\": 0.5078125,\n  \"piqa_samples\": 128,\n  \"winogrande_accuracy\": 0.4375,\n  \"winogrande_samples\": 128\n}"
  },
  {
    "path": "eval/zeroshot_pilot.json",
    "content": "{\n  \"piqa_accuracy\": 0.5390625,\n  \"piqa_samples\": 256,\n  \"piqa_baseline_accuracy\": 0.53125,\n  \"piqa_memorize_accuracy\": 0.5390625,\n  \"piqa_memorize_delta\": 0.0078125,\n  \"piqa_memorize_paths\": \"titan,cms_fast\",\n  \"piqa_memorize_surprise_threshold\": 0.02,\n  \"piqa_titan_mem_updates\": 1017722.1927137375,\n  \"piqa_titan_update_events\": 816.0,\n  \"piqa_cms_fast_updates\": 70510.01682426152,\n  \"piqa_cms_fast_update_events\": 15480.0,\n  \"piqa_cms_mid_updates\": 0.0,\n  \"piqa_cms_mid_update_events\": 0.0,\n  \"piqa_cms_slow_updates\": 0.0,\n  \"piqa_cms_slow_update_events\": 0.0,\n  \"piqa_cms_ultra_updates\": 0.0,\n  \"piqa_cms_ultra_update_events\": 0.0,\n  \"hellaswag_accuracy\": 0.2734375,\n  \"hellaswag_samples\": 256,\n  \"hellaswag_baseline_accuracy\": 0.27734375,\n  \"hellaswag_memorize_accuracy\": 0.2734375,\n  \"hellaswag_memorize_delta\": -0.00390625,\n  \"hellaswag_memorize_paths\": \"titan,cms_fast\",\n  \"hellaswag_memorize_surprise_threshold\": 0.02,\n  \"hellaswag_titan_mem_updates\": 29645.10869884491,\n  \"hellaswag_titan_update_events\": 24.0,\n  \"hellaswag_cms_fast_updates\": 2120.844047217164,\n  \"hellaswag_cms_fast_update_events\": 636.0,\n  \"hellaswag_cms_mid_updates\": 0.0,\n  \"hellaswag_cms_mid_update_events\": 0.0,\n  \"hellaswag_cms_slow_updates\": 0.0,\n  \"hellaswag_cms_slow_update_events\": 0.0,\n  \"hellaswag_cms_ultra_updates\": 0.0,\n  \"hellaswag_cms_ultra_update_events\": 0.0,\n  \"winogrande_accuracy\": 0.48046875,\n  \"winogrande_samples\": 256,\n  \"winogrande_baseline_accuracy\": 0.48046875,\n  \"winogrande_memorize_accuracy\": 0.48046875,\n  \"winogrande_memorize_delta\": 0.0,\n  \"winogrande_memorize_paths\": \"titan,cms_fast\",\n  \"winogrande_memorize_surprise_threshold\": 0.02,\n  \"winogrande_titan_mem_updates\": 0.0,\n  \"winogrande_titan_update_events\": 0.0,\n  \"winogrande_cms_fast_updates\": 0.0,\n  \"winogrande_cms_fast_update_events\": 0.0,\n  \"winogrande_cms_mid_updates\": 0.0,\n  \"winogrande_cms_mid_update_events\": 0.0,\n  \"winogrande_cms_slow_updates\": 0.0,\n  \"winogrande_cms_slow_update_events\": 0.0,\n  \"winogrande_cms_ultra_updates\": 0.0,\n  \"winogrande_cms_ultra_update_events\": 0.0,\n  \"arc_arc-easy_accuracy\": 0.3125,\n  \"arc_arc-easy_samples\": 256,\n  \"arc_arc-easy_baseline_accuracy\": 0.3046875,\n  \"arc_arc-easy_memorize_accuracy\": 0.3125,\n  \"arc_arc-easy_memorize_delta\": 0.0078125,\n  \"arc_arc-easy_memorize_paths\": \"titan,cms_fast\",\n  \"arc_arc-easy_memorize_surprise_threshold\": 0.02,\n  \"arc_arc-easy_titan_mem_updates\": 505374.0824314356,\n  \"arc_arc-easy_titan_update_events\": 408.0,\n  \"arc_arc-easy_cms_fast_updates\": 35437.47048758948,\n  \"arc_arc-easy_cms_fast_update_events\": 9072.0,\n  \"arc_arc-easy_cms_mid_updates\": 0.0,\n  \"arc_arc-easy_cms_mid_update_events\": 0.0,\n  \"arc_arc-easy_cms_slow_updates\": 0.0,\n  \"arc_arc-easy_cms_slow_update_events\": 0.0,\n  \"arc_arc-easy_cms_ultra_updates\": 0.0,\n  \"arc_arc-easy_cms_ultra_update_events\": 0.0,\n  \"arc_arc-challenge_accuracy\": 0.21875,\n  \"arc_arc-challenge_samples\": 256,\n  \"arc_arc-challenge_baseline_accuracy\": 0.22265625,\n  \"arc_arc-challenge_memorize_accuracy\": 0.21875,\n  \"arc_arc-challenge_memorize_delta\": -0.00390625,\n  \"arc_arc-challenge_memorize_paths\": \"titan,cms_fast\",\n  \"arc_arc-challenge_memorize_surprise_threshold\": 0.02,\n  \"arc_arc-challenge_titan_mem_updates\": 297412.9418039322,\n  \"arc_arc-challenge_titan_update_events\": 240.0,\n  
\"arc_arc-challenge_cms_fast_updates\": 20917.3373906645,\n  \"arc_arc-challenge_cms_fast_update_events\": 5376.0,\n  \"arc_arc-challenge_cms_mid_updates\": 0.0,\n  \"arc_arc-challenge_cms_mid_update_events\": 0.0,\n  \"arc_arc-challenge_cms_slow_updates\": 0.0,\n  \"arc_arc-challenge_cms_slow_update_events\": 0.0,\n  \"arc_arc-challenge_cms_ultra_updates\": 0.0,\n  \"arc_arc-challenge_cms_ultra_update_events\": 0.0,\n  \"boolq_accuracy\": 0.6328125,\n  \"boolq_samples\": 256,\n  \"boolq_baseline_accuracy\": 0.6328125,\n  \"boolq_memorize_accuracy\": 0.6328125,\n  \"boolq_memorize_delta\": 0.0,\n  \"boolq_memorize_paths\": \"titan,cms_fast\",\n  \"boolq_memorize_surprise_threshold\": 0.02,\n  \"boolq_titan_mem_updates\": 0.0,\n  \"boolq_titan_update_events\": 0.0,\n  \"boolq_cms_fast_updates\": 0.0,\n  \"boolq_cms_fast_update_events\": 0.0,\n  \"boolq_cms_mid_updates\": 0.0,\n  \"boolq_cms_mid_update_events\": 0.0,\n  \"boolq_cms_slow_updates\": 0.0,\n  \"boolq_cms_slow_update_events\": 0.0,\n  \"boolq_cms_ultra_updates\": 0.0,\n  \"boolq_cms_ultra_update_events\": 0.0,\n  \"siqa_accuracy\": 0.30078125,\n  \"siqa_samples\": 256,\n  \"siqa_baseline_accuracy\": 0.30078125,\n  \"siqa_memorize_accuracy\": 0.30078125,\n  \"siqa_memorize_delta\": 0.0,\n  \"siqa_memorize_paths\": \"titan,cms_fast\",\n  \"siqa_memorize_surprise_threshold\": 0.02,\n  \"siqa_titan_mem_updates\": 0.0,\n  \"siqa_titan_update_events\": 0.0,\n  \"siqa_cms_fast_updates\": 0.0,\n  \"siqa_cms_fast_update_events\": 0.0,\n  \"siqa_cms_mid_updates\": 0.0,\n  \"siqa_cms_mid_update_events\": 0.0,\n  \"siqa_cms_slow_updates\": 0.0,\n  \"siqa_cms_slow_update_events\": 0.0,\n  \"siqa_cms_ultra_updates\": 0.0,\n  \"siqa_cms_ultra_update_events\": 0.0,\n  \"commonsenseqa_accuracy\": 0.1640625,\n  \"commonsenseqa_samples\": 256,\n  \"commonsenseqa_baseline_accuracy\": 0.15234375,\n  \"commonsenseqa_memorize_accuracy\": 0.1640625,\n  \"commonsenseqa_memorize_delta\": 0.01171875,\n  \"commonsenseqa_memorize_paths\": \"titan,cms_fast\",\n  \"commonsenseqa_memorize_surprise_threshold\": 0.02,\n  \"commonsenseqa_titan_mem_updates\": 1356566.724259615,\n  \"commonsenseqa_titan_update_events\": 1092.0,\n  \"commonsenseqa_cms_fast_updates\": 94493.39003942604,\n  \"commonsenseqa_cms_fast_update_events\": 23448.0,\n  \"commonsenseqa_cms_mid_updates\": 0.0,\n  \"commonsenseqa_cms_mid_update_events\": 0.0,\n  \"commonsenseqa_cms_slow_updates\": 0.0,\n  \"commonsenseqa_cms_slow_update_events\": 0.0,\n  \"commonsenseqa_cms_ultra_updates\": 0.0,\n  \"commonsenseqa_cms_ultra_update_events\": 0.0,\n  \"openbookqa_accuracy\": 0.15625,\n  \"openbookqa_samples\": 256,\n  \"openbookqa_baseline_accuracy\": 0.1484375,\n  \"openbookqa_memorize_accuracy\": 0.15625,\n  \"openbookqa_memorize_delta\": 0.0078125,\n  \"openbookqa_memorize_paths\": \"titan,cms_fast\",\n  \"openbookqa_memorize_surprise_threshold\": 0.02,\n  \"openbookqa_titan_mem_updates\": 2107297.096786499,\n  \"openbookqa_titan_update_events\": 1692.0,\n  \"openbookqa_cms_fast_updates\": 146285.0985297151,\n  \"openbookqa_cms_fast_update_events\": 33012.0,\n  \"openbookqa_cms_mid_updates\": 0.0,\n  \"openbookqa_cms_mid_update_events\": 0.0,\n  \"openbookqa_cms_slow_updates\": 0.0,\n  \"openbookqa_cms_slow_update_events\": 0.0,\n  \"openbookqa_cms_ultra_updates\": 0.0,\n  \"openbookqa_cms_ultra_update_events\": 0.0\n}"
  },
  {
    "path": "eval/zeroshot_pilot_cms_nochunk_step5000.json",
    "content": "{\n  \"piqa_accuracy\": 0.51953125,\n  \"piqa_samples\": 256,\n  \"hellaswag_accuracy\": 0.27734375,\n  \"hellaswag_samples\": 256,\n  \"winogrande_accuracy\": 0.47265625,\n  \"winogrande_samples\": 256,\n  \"arc_arc-easy_accuracy\": 0.3203125,\n  \"arc_arc-easy_samples\": 256,\n  \"arc_arc-challenge_accuracy\": 0.2421875,\n  \"arc_arc-challenge_samples\": 256,\n  \"boolq_accuracy\": 0.6328125,\n  \"boolq_samples\": 256,\n  \"siqa_accuracy\": 0.30078125,\n  \"siqa_samples\": 256,\n  \"commonsenseqa_accuracy\": 0.19140625,\n  \"commonsenseqa_samples\": 256,\n  \"openbookqa_accuracy\": 0.1484375,\n  \"openbookqa_samples\": 256\n}"
  },
  {
    "path": "eval/zeroshot_pilot_cms_sparse_step5000.json",
    "content": "{\n  \"piqa_accuracy\": 0.515625,\n  \"piqa_samples\": 256,\n  \"hellaswag_accuracy\": 0.2578125,\n  \"hellaswag_samples\": 256,\n  \"winogrande_accuracy\": 0.5,\n  \"winogrande_samples\": 256,\n  \"arc_arc-easy_accuracy\": 0.3046875,\n  \"arc_arc-easy_samples\": 256,\n  \"arc_arc-challenge_accuracy\": 0.234375,\n  \"arc_arc-challenge_samples\": 256,\n  \"boolq_accuracy\": 0.3671875,\n  \"boolq_samples\": 256,\n  \"siqa_accuracy\": 0.26171875,\n  \"siqa_samples\": 256,\n  \"commonsenseqa_accuracy\": 0.19140625,\n  \"commonsenseqa_samples\": 256,\n  \"openbookqa_accuracy\": 0.140625,\n  \"openbookqa_samples\": 256\n}"
  },
  {
    "path": "eval/zeroshot_pilot_debug.json",
    "content": "{\n  \"piqa_accuracy\": 0.59375,\n  \"piqa_samples\": 32,\n  \"piqa_baseline_accuracy\": 0.59375,\n  \"piqa_memorize_accuracy\": 0.59375,\n  \"piqa_memorize_delta\": 0.0,\n  \"piqa_memorize_paths\": \"all\",\n  \"piqa_titan_mem_updates\": 214.61288151843473,\n  \"piqa_cms_fast_updates\": 293.2037897591945\n}"
  },
  {
    "path": "eval/zeroshot_pilot_dummy_piqa.json",
    "content": "{\n  \"piqa_accuracy\": 0.0,\n  \"piqa_samples\": 2\n}"
  },
  {
    "path": "eval/zeroshot_pilot_opt_adamw_step5000.json",
    "content": "{\n  \"piqa_accuracy\": 0.55859375,\n  \"piqa_samples\": 256,\n  \"hellaswag_accuracy\": 0.2734375,\n  \"hellaswag_samples\": 256,\n  \"winogrande_accuracy\": 0.5,\n  \"winogrande_samples\": 256,\n  \"boolq_accuracy\": 0.3671875,\n  \"boolq_samples\": 256\n}"
  },
  {
    "path": "eval/zeroshot_pilot_opt_muon_step5000.json",
    "content": "{\n  \"piqa_accuracy\": 0.53125,\n  \"piqa_samples\": 256,\n  \"hellaswag_accuracy\": 0.3125,\n  \"hellaswag_samples\": 256,\n  \"winogrande_accuracy\": 0.484375,\n  \"winogrande_samples\": 256,\n  \"boolq_accuracy\": 0.5703125,\n  \"boolq_samples\": 256\n}"
  },
  {
    "path": "eval/zeroshot_pilot_selfmod_off_step5000.json",
    "content": "{\n  \"piqa_accuracy\": 0.515625,\n  \"piqa_samples\": 256,\n  \"hellaswag_accuracy\": 0.265625,\n  \"hellaswag_samples\": 256,\n  \"winogrande_accuracy\": 0.46484375,\n  \"winogrande_samples\": 256,\n  \"arc_arc-easy_accuracy\": 0.2890625,\n  \"arc_arc-easy_samples\": 256,\n  \"arc_arc-challenge_accuracy\": 0.20703125,\n  \"arc_arc-challenge_samples\": 256,\n  \"boolq_accuracy\": 0.6328125,\n  \"boolq_samples\": 256,\n  \"siqa_accuracy\": 0.33203125,\n  \"siqa_samples\": 256,\n  \"commonsenseqa_accuracy\": 0.1640625,\n  \"commonsenseqa_samples\": 256,\n  \"openbookqa_accuracy\": 0.1640625,\n  \"openbookqa_samples\": 256\n}"
  },
  {
    "path": "eval/zeroshot_pilot_step22000.json",
    "content": "{\n  \"piqa_accuracy\": 0.53125,\n  \"piqa_samples\": 128\n}"
  },
  {
    "path": "eval/zeroshot_pilot_step230000.json",
    "content": "{\n  \"piqa_accuracy\": 0.515625,\n  \"piqa_samples\": 256,\n  \"piqa_baseline_accuracy\": 0.51171875,\n  \"piqa_memorize_accuracy\": 0.515625,\n  \"piqa_memorize_delta\": 0.00390625,\n  \"piqa_titan_mem_updates\": 10.03807739145368,\n  \"piqa_cms_fast_updates\": 0.0,\n  \"hellaswag_accuracy\": 0.30078125,\n  \"hellaswag_samples\": 256,\n  \"hellaswag_baseline_accuracy\": 0.296875,\n  \"hellaswag_memorize_accuracy\": 0.30078125,\n  \"hellaswag_memorize_delta\": 0.00390625,\n  \"hellaswag_titan_mem_updates\": 9.644964346148374,\n  \"hellaswag_cms_fast_updates\": 0.0,\n  \"winogrande_accuracy\": 0.4921875,\n  \"winogrande_samples\": 256,\n  \"winogrande_baseline_accuracy\": 0.484375,\n  \"winogrande_memorize_accuracy\": 0.4921875,\n  \"winogrande_memorize_delta\": 0.0078125,\n  \"winogrande_titan_mem_updates\": 9.799855060760141,\n  \"winogrande_cms_fast_updates\": 0.0,\n  \"arc_arc-easy_accuracy\": 0.28515625,\n  \"arc_arc-easy_samples\": 256,\n  \"arc_arc-easy_baseline_accuracy\": 0.28515625,\n  \"arc_arc-easy_memorize_accuracy\": 0.28515625,\n  \"arc_arc-easy_memorize_delta\": 0.0,\n  \"arc_arc-easy_titan_mem_updates\": 9.155673045361393,\n  \"arc_arc-easy_cms_fast_updates\": 0.0,\n  \"arc_arc-challenge_accuracy\": 0.25,\n  \"arc_arc-challenge_samples\": 256,\n  \"arc_arc-challenge_baseline_accuracy\": 0.24609375,\n  \"arc_arc-challenge_memorize_accuracy\": 0.25,\n  \"arc_arc-challenge_memorize_delta\": 0.00390625,\n  \"arc_arc-challenge_titan_mem_updates\": 9.186319984716842,\n  \"arc_arc-challenge_cms_fast_updates\": 0.0,\n  \"boolq_accuracy\": 0.3671875,\n  \"boolq_samples\": 256,\n  \"boolq_baseline_accuracy\": 0.3671875,\n  \"boolq_memorize_accuracy\": 0.3671875,\n  \"boolq_memorize_delta\": 0.0,\n  \"boolq_titan_mem_updates\": 8.10761637164109,\n  \"boolq_cms_fast_updates\": 0.0,\n  \"siqa_accuracy\": 0.3125,\n  \"siqa_samples\": 256,\n  \"siqa_baseline_accuracy\": 0.31640625,\n  \"siqa_memorize_accuracy\": 0.3125,\n  \"siqa_memorize_delta\": -0.00390625,\n  \"siqa_titan_mem_updates\": 10.038975774096901,\n  \"siqa_cms_fast_updates\": 0.0,\n  \"commonsenseqa_accuracy\": 0.1875,\n  \"commonsenseqa_samples\": 256,\n  \"commonsenseqa_baseline_accuracy\": 0.19140625,\n  \"commonsenseqa_memorize_accuracy\": 0.1875,\n  \"commonsenseqa_memorize_delta\": -0.00390625,\n  \"commonsenseqa_titan_mem_updates\": 10.048177535929167,\n  \"commonsenseqa_cms_fast_updates\": 0.0,\n  \"openbookqa_accuracy\": 0.140625,\n  \"openbookqa_samples\": 256,\n  \"openbookqa_baseline_accuracy\": 0.140625,\n  \"openbookqa_memorize_accuracy\": 0.140625,\n  \"openbookqa_memorize_delta\": 0.0,\n  \"openbookqa_titan_mem_updates\": 10.512614026167608,\n  \"openbookqa_cms_fast_updates\": 0.0\n}"
  },
  {
    "path": "eval/zeroshot_pilot_teach05_long_step25000.json",
    "content": "{\n  \"piqa_accuracy\": 0.5078125,\n  \"piqa_samples\": 256,\n  \"hellaswag_accuracy\": 0.28515625,\n  \"hellaswag_samples\": 256,\n  \"winogrande_accuracy\": 0.4765625,\n  \"winogrande_samples\": 256,\n  \"arc_arc-easy_accuracy\": 0.3203125,\n  \"arc_arc-easy_samples\": 256,\n  \"arc_arc-challenge_accuracy\": 0.23828125,\n  \"arc_arc-challenge_samples\": 256,\n  \"boolq_accuracy\": 0.3671875,\n  \"boolq_samples\": 256,\n  \"siqa_accuracy\": 0.328125,\n  \"siqa_samples\": 256,\n  \"commonsenseqa_accuracy\": 0.19921875,\n  \"commonsenseqa_samples\": 256,\n  \"openbookqa_accuracy\": 0.14453125,\n  \"openbookqa_samples\": 256\n}"
  },
  {
    "path": "eval/zeroshot_pilot_teach05_step2000.json",
    "content": "{\n  \"piqa_accuracy\": 0.453125,\n  \"piqa_samples\": 128,\n  \"hellaswag_accuracy\": 0.2734375,\n  \"hellaswag_samples\": 128,\n  \"winogrande_accuracy\": 0.5078125,\n  \"winogrande_samples\": 128,\n  \"arc_arc-easy_accuracy\": 0.25,\n  \"arc_arc-easy_samples\": 128,\n  \"arc_arc-challenge_accuracy\": 0.2265625,\n  \"arc_arc-challenge_samples\": 128,\n  \"boolq_accuracy\": 0.6640625,\n  \"boolq_samples\": 128,\n  \"siqa_accuracy\": 0.2890625,\n  \"siqa_samples\": 128,\n  \"commonsenseqa_accuracy\": 0.1875,\n  \"commonsenseqa_samples\": 128,\n  \"openbookqa_accuracy\": 0.1796875,\n  \"openbookqa_samples\": 128\n}"
  },
  {
    "path": "eval/zeroshot_pilot_teach15_long_step25000.json",
    "content": "{\n  \"piqa_accuracy\": 0.49609375,\n  \"piqa_samples\": 256,\n  \"hellaswag_accuracy\": 0.3046875,\n  \"hellaswag_samples\": 256,\n  \"winogrande_accuracy\": 0.5,\n  \"winogrande_samples\": 256,\n  \"arc_arc-easy_accuracy\": 0.30078125,\n  \"arc_arc-easy_samples\": 256,\n  \"arc_arc-challenge_accuracy\": 0.23828125,\n  \"arc_arc-challenge_samples\": 256,\n  \"boolq_accuracy\": 0.3671875,\n  \"boolq_samples\": 256,\n  \"siqa_accuracy\": 0.31640625,\n  \"siqa_samples\": 256,\n  \"commonsenseqa_accuracy\": 0.17578125,\n  \"commonsenseqa_samples\": 256,\n  \"openbookqa_accuracy\": 0.125,\n  \"openbookqa_samples\": 256\n}"
  },
  {
    "path": "eval/zeroshot_pilot_teach15_step2000.json",
    "content": "{\n  \"piqa_accuracy\": 0.484375,\n  \"piqa_samples\": 128,\n  \"hellaswag_accuracy\": 0.2578125,\n  \"hellaswag_samples\": 128,\n  \"winogrande_accuracy\": 0.4609375,\n  \"winogrande_samples\": 128,\n  \"arc_arc-easy_accuracy\": 0.203125,\n  \"arc_arc-easy_samples\": 128,\n  \"arc_arc-challenge_accuracy\": 0.21875,\n  \"arc_arc-challenge_samples\": 128,\n  \"boolq_accuracy\": 0.3359375,\n  \"boolq_samples\": 128,\n  \"siqa_accuracy\": 0.34375,\n  \"siqa_samples\": 128,\n  \"commonsenseqa_accuracy\": 0.2109375,\n  \"commonsenseqa_samples\": 128,\n  \"openbookqa_accuracy\": 0.1484375,\n  \"openbookqa_samples\": 128\n}"
  },
  {
    "path": "eval/zeroshot_smoke.json",
    "content": "{\n  \"piqa_accuracy\": 0.5,\n  \"piqa_samples\": 16\n}"
  },
  {
    "path": "eval/zeroshot_titan.json",
    "content": "{\n  \"piqa_accuracy\": 0.484375,\n  \"piqa_samples\": 256,\n  \"piqa_baseline_accuracy\": 0.51171875,\n  \"piqa_memorize_accuracy\": 0.484375,\n  \"piqa_memorize_delta\": -0.02734375,\n  \"piqa_memorize_paths\": \"titan\",\n  \"piqa_memorize_surprise_threshold\": 0.02,\n  \"piqa_titan_mem_updates\": 0.0,\n  \"piqa_titan_update_events\": 0.0,\n  \"piqa_cms_fast_updates\": 0.0,\n  \"piqa_cms_fast_update_events\": 0.0,\n  \"piqa_cms_mid_updates\": 0.0,\n  \"piqa_cms_mid_update_events\": 0.0,\n  \"piqa_cms_slow_updates\": 0.0,\n  \"piqa_cms_slow_update_events\": 0.0,\n  \"piqa_cms_ultra_updates\": 0.0,\n  \"piqa_cms_ultra_update_events\": 0.0,\n  \"hellaswag_accuracy\": 0.26171875,\n  \"hellaswag_samples\": 256,\n  \"hellaswag_baseline_accuracy\": 0.265625,\n  \"hellaswag_memorize_accuracy\": 0.26171875,\n  \"hellaswag_memorize_delta\": -0.00390625,\n  \"hellaswag_memorize_paths\": \"titan\",\n  \"hellaswag_memorize_surprise_threshold\": 0.02,\n  \"hellaswag_titan_mem_updates\": 0.0,\n  \"hellaswag_titan_update_events\": 0.0,\n  \"hellaswag_cms_fast_updates\": 0.0,\n  \"hellaswag_cms_fast_update_events\": 0.0,\n  \"hellaswag_cms_mid_updates\": 0.0,\n  \"hellaswag_cms_mid_update_events\": 0.0,\n  \"hellaswag_cms_slow_updates\": 0.0,\n  \"hellaswag_cms_slow_update_events\": 0.0,\n  \"hellaswag_cms_ultra_updates\": 0.0,\n  \"hellaswag_cms_ultra_update_events\": 0.0,\n  \"winogrande_accuracy\": 0.47265625,\n  \"winogrande_samples\": 256,\n  \"winogrande_baseline_accuracy\": 0.48828125,\n  \"winogrande_memorize_accuracy\": 0.47265625,\n  \"winogrande_memorize_delta\": -0.015625,\n  \"winogrande_memorize_paths\": \"titan\",\n  \"winogrande_memorize_surprise_threshold\": 0.02,\n  \"winogrande_titan_mem_updates\": 0.0,\n  \"winogrande_titan_update_events\": 0.0,\n  \"winogrande_cms_fast_updates\": 0.0,\n  \"winogrande_cms_fast_update_events\": 0.0,\n  \"winogrande_cms_mid_updates\": 0.0,\n  \"winogrande_cms_mid_update_events\": 0.0,\n  \"winogrande_cms_slow_updates\": 0.0,\n  \"winogrande_cms_slow_update_events\": 0.0,\n  \"winogrande_cms_ultra_updates\": 0.0,\n  \"winogrande_cms_ultra_update_events\": 0.0,\n  \"arc_arc-easy_accuracy\": 0.2734375,\n  \"arc_arc-easy_samples\": 256,\n  \"arc_arc-easy_baseline_accuracy\": 0.29296875,\n  \"arc_arc-easy_memorize_accuracy\": 0.2734375,\n  \"arc_arc-easy_memorize_delta\": -0.01953125,\n  \"arc_arc-easy_memorize_paths\": \"titan\",\n  \"arc_arc-easy_memorize_surprise_threshold\": 0.02,\n  \"arc_arc-easy_titan_mem_updates\": 0.0,\n  \"arc_arc-easy_titan_update_events\": 0.0,\n  \"arc_arc-easy_cms_fast_updates\": 0.0,\n  \"arc_arc-easy_cms_fast_update_events\": 0.0,\n  \"arc_arc-easy_cms_mid_updates\": 0.0,\n  \"arc_arc-easy_cms_mid_update_events\": 0.0,\n  \"arc_arc-easy_cms_slow_updates\": 0.0,\n  \"arc_arc-easy_cms_slow_update_events\": 0.0,\n  \"arc_arc-easy_cms_ultra_updates\": 0.0,\n  \"arc_arc-easy_cms_ultra_update_events\": 0.0,\n  \"arc_arc-challenge_accuracy\": 0.21484375,\n  \"arc_arc-challenge_samples\": 256,\n  \"arc_arc-challenge_baseline_accuracy\": 0.2265625,\n  \"arc_arc-challenge_memorize_accuracy\": 0.21484375,\n  \"arc_arc-challenge_memorize_delta\": -0.01171875,\n  \"arc_arc-challenge_memorize_paths\": \"titan\",\n  \"arc_arc-challenge_memorize_surprise_threshold\": 0.02,\n  \"arc_arc-challenge_titan_mem_updates\": 0.0,\n  \"arc_arc-challenge_titan_update_events\": 0.0,\n  \"arc_arc-challenge_cms_fast_updates\": 0.0,\n  \"arc_arc-challenge_cms_fast_update_events\": 0.0,\n  \"arc_arc-challenge_cms_mid_updates\": 0.0,\n  
\"arc_arc-challenge_cms_mid_update_events\": 0.0,\n  \"arc_arc-challenge_cms_slow_updates\": 0.0,\n  \"arc_arc-challenge_cms_slow_update_events\": 0.0,\n  \"arc_arc-challenge_cms_ultra_updates\": 0.0,\n  \"arc_arc-challenge_cms_ultra_update_events\": 0.0,\n  \"boolq_accuracy\": 0.390625,\n  \"boolq_samples\": 256,\n  \"boolq_baseline_accuracy\": 0.390625,\n  \"boolq_memorize_accuracy\": 0.390625,\n  \"boolq_memorize_delta\": 0.0,\n  \"boolq_memorize_paths\": \"titan\",\n  \"boolq_memorize_surprise_threshold\": 0.02,\n  \"boolq_titan_mem_updates\": 0.0,\n  \"boolq_titan_update_events\": 0.0,\n  \"boolq_cms_fast_updates\": 0.0,\n  \"boolq_cms_fast_update_events\": 0.0,\n  \"boolq_cms_mid_updates\": 0.0,\n  \"boolq_cms_mid_update_events\": 0.0,\n  \"boolq_cms_slow_updates\": 0.0,\n  \"boolq_cms_slow_update_events\": 0.0,\n  \"boolq_cms_ultra_updates\": 0.0,\n  \"boolq_cms_ultra_update_events\": 0.0,\n  \"siqa_accuracy\": 0.32421875,\n  \"siqa_samples\": 256,\n  \"siqa_baseline_accuracy\": 0.3125,\n  \"siqa_memorize_accuracy\": 0.32421875,\n  \"siqa_memorize_delta\": 0.01171875,\n  \"siqa_memorize_paths\": \"titan\",\n  \"siqa_memorize_surprise_threshold\": 0.02,\n  \"siqa_titan_mem_updates\": 0.0,\n  \"siqa_titan_update_events\": 0.0,\n  \"siqa_cms_fast_updates\": 0.0,\n  \"siqa_cms_fast_update_events\": 0.0,\n  \"siqa_cms_mid_updates\": 0.0,\n  \"siqa_cms_mid_update_events\": 0.0,\n  \"siqa_cms_slow_updates\": 0.0,\n  \"siqa_cms_slow_update_events\": 0.0,\n  \"siqa_cms_ultra_updates\": 0.0,\n  \"siqa_cms_ultra_update_events\": 0.0,\n  \"commonsenseqa_accuracy\": 0.1640625,\n  \"commonsenseqa_samples\": 256,\n  \"commonsenseqa_baseline_accuracy\": 0.1796875,\n  \"commonsenseqa_memorize_accuracy\": 0.1640625,\n  \"commonsenseqa_memorize_delta\": -0.015625,\n  \"commonsenseqa_memorize_paths\": \"titan\",\n  \"commonsenseqa_memorize_surprise_threshold\": 0.02,\n  \"commonsenseqa_titan_mem_updates\": 0.0,\n  \"commonsenseqa_titan_update_events\": 0.0,\n  \"commonsenseqa_cms_fast_updates\": 0.0,\n  \"commonsenseqa_cms_fast_update_events\": 0.0,\n  \"commonsenseqa_cms_mid_updates\": 0.0,\n  \"commonsenseqa_cms_mid_update_events\": 0.0,\n  \"commonsenseqa_cms_slow_updates\": 0.0,\n  \"commonsenseqa_cms_slow_update_events\": 0.0,\n  \"commonsenseqa_cms_ultra_updates\": 0.0,\n  \"commonsenseqa_cms_ultra_update_events\": 0.0,\n  \"openbookqa_accuracy\": 0.1171875,\n  \"openbookqa_samples\": 256,\n  \"openbookqa_baseline_accuracy\": 0.12890625,\n  \"openbookqa_memorize_accuracy\": 0.1171875,\n  \"openbookqa_memorize_delta\": -0.01171875,\n  \"openbookqa_memorize_paths\": \"titan\",\n  \"openbookqa_memorize_surprise_threshold\": 0.02,\n  \"openbookqa_titan_mem_updates\": 0.0,\n  \"openbookqa_titan_update_events\": 0.0,\n  \"openbookqa_cms_fast_updates\": 0.0,\n  \"openbookqa_cms_fast_update_events\": 0.0,\n  \"openbookqa_cms_mid_updates\": 0.0,\n  \"openbookqa_cms_mid_update_events\": 0.0,\n  \"openbookqa_cms_slow_updates\": 0.0,\n  \"openbookqa_cms_slow_update_events\": 0.0,\n  \"openbookqa_cms_ultra_updates\": 0.0,\n  \"openbookqa_cms_ultra_update_events\": 0.0\n}"
  },
  {
    "path": "eval/zeroshot_titan_relaunch_step001000.json",
    "content": "{\n  \"piqa_accuracy\": 0.5234375,\n  \"piqa_samples\": 256,\n  \"piqa_baseline_accuracy\": 0.5234375,\n  \"piqa_memorize_accuracy\": 0.5234375,\n  \"piqa_memorize_delta\": 0.0,\n  \"piqa_memorize_paths\": \"titan\",\n  \"piqa_memorize_surprise_threshold\": 0.02,\n  \"piqa_titan_mem_updates\": 0.0,\n  \"piqa_cms_fast_updates\": 0.0,\n  \"hellaswag_accuracy\": 0.2890625,\n  \"hellaswag_samples\": 256,\n  \"hellaswag_baseline_accuracy\": 0.2890625,\n  \"hellaswag_memorize_accuracy\": 0.2890625,\n  \"hellaswag_memorize_delta\": 0.0,\n  \"hellaswag_memorize_paths\": \"titan\",\n  \"hellaswag_memorize_surprise_threshold\": 0.02,\n  \"hellaswag_titan_mem_updates\": 0.0,\n  \"hellaswag_cms_fast_updates\": 0.0,\n  \"winogrande_accuracy\": 0.52734375,\n  \"winogrande_samples\": 256,\n  \"winogrande_baseline_accuracy\": 0.52734375,\n  \"winogrande_memorize_accuracy\": 0.52734375,\n  \"winogrande_memorize_delta\": 0.0,\n  \"winogrande_memorize_paths\": \"titan\",\n  \"winogrande_memorize_surprise_threshold\": 0.02,\n  \"winogrande_titan_mem_updates\": 0.0,\n  \"winogrande_cms_fast_updates\": 0.0,\n  \"arc_arc-easy_accuracy\": 0.26953125,\n  \"arc_arc-easy_samples\": 256,\n  \"arc_arc-easy_baseline_accuracy\": 0.26953125,\n  \"arc_arc-easy_memorize_accuracy\": 0.26953125,\n  \"arc_arc-easy_memorize_delta\": 0.0,\n  \"arc_arc-easy_memorize_paths\": \"titan\",\n  \"arc_arc-easy_memorize_surprise_threshold\": 0.02,\n  \"arc_arc-easy_titan_mem_updates\": 0.0,\n  \"arc_arc-easy_cms_fast_updates\": 0.0,\n  \"arc_arc-challenge_accuracy\": 0.21875,\n  \"arc_arc-challenge_samples\": 256,\n  \"arc_arc-challenge_baseline_accuracy\": 0.21875,\n  \"arc_arc-challenge_memorize_accuracy\": 0.21875,\n  \"arc_arc-challenge_memorize_delta\": 0.0,\n  \"arc_arc-challenge_memorize_paths\": \"titan\",\n  \"arc_arc-challenge_memorize_surprise_threshold\": 0.02,\n  \"arc_arc-challenge_titan_mem_updates\": 0.0,\n  \"arc_arc-challenge_cms_fast_updates\": 0.0,\n  \"boolq_accuracy\": 0.6328125,\n  \"boolq_samples\": 256,\n  \"boolq_baseline_accuracy\": 0.6328125,\n  \"boolq_memorize_accuracy\": 0.6328125,\n  \"boolq_memorize_delta\": 0.0,\n  \"boolq_memorize_paths\": \"titan\",\n  \"boolq_memorize_surprise_threshold\": 0.02,\n  \"boolq_titan_mem_updates\": 0.0,\n  \"boolq_cms_fast_updates\": 0.0,\n  \"siqa_accuracy\": 0.3125,\n  \"siqa_samples\": 256,\n  \"siqa_baseline_accuracy\": 0.3125,\n  \"siqa_memorize_accuracy\": 0.3125,\n  \"siqa_memorize_delta\": 0.0,\n  \"siqa_memorize_paths\": \"titan\",\n  \"siqa_memorize_surprise_threshold\": 0.02,\n  \"siqa_titan_mem_updates\": 0.0,\n  \"siqa_cms_fast_updates\": 0.0,\n  \"commonsenseqa_accuracy\": 0.203125,\n  \"commonsenseqa_samples\": 256,\n  \"commonsenseqa_baseline_accuracy\": 0.203125,\n  \"commonsenseqa_memorize_accuracy\": 0.203125,\n  \"commonsenseqa_memorize_delta\": 0.0,\n  \"commonsenseqa_memorize_paths\": \"titan\",\n  \"commonsenseqa_memorize_surprise_threshold\": 0.02,\n  \"commonsenseqa_titan_mem_updates\": 0.0,\n  \"commonsenseqa_cms_fast_updates\": 0.0,\n  \"openbookqa_accuracy\": 0.1640625,\n  \"openbookqa_samples\": 256,\n  \"openbookqa_baseline_accuracy\": 0.1640625,\n  \"openbookqa_memorize_accuracy\": 0.1640625,\n  \"openbookqa_memorize_delta\": 0.0,\n  \"openbookqa_memorize_paths\": \"titan\",\n  \"openbookqa_memorize_surprise_threshold\": 0.02,\n  \"openbookqa_titan_mem_updates\": 0.0,\n  \"openbookqa_cms_fast_updates\": 0.0\n}"
  },
  {
    "path": "eval/zeroshot_titan_step25000.json",
    "content": "{\n  \"piqa_accuracy\": 0.48828125,\n  \"piqa_samples\": 256,\n  \"piqa_baseline_accuracy\": 0.484375,\n  \"piqa_memorize_accuracy\": 0.48828125,\n  \"piqa_memorize_delta\": 0.00390625,\n  \"piqa_titan_mem_updates\": 0.0,\n  \"piqa_cms_fast_updates\": 0.0,\n  \"hellaswag_accuracy\": 0.296875,\n  \"hellaswag_samples\": 256,\n  \"hellaswag_baseline_accuracy\": 0.296875,\n  \"hellaswag_memorize_accuracy\": 0.296875,\n  \"hellaswag_memorize_delta\": 0.0,\n  \"hellaswag_titan_mem_updates\": 0.0,\n  \"hellaswag_cms_fast_updates\": 0.0,\n  \"winogrande_accuracy\": 0.47265625,\n  \"winogrande_samples\": 256,\n  \"winogrande_baseline_accuracy\": 0.48828125,\n  \"winogrande_memorize_accuracy\": 0.47265625,\n  \"winogrande_memorize_delta\": -0.015625,\n  \"winogrande_titan_mem_updates\": 0.0,\n  \"winogrande_cms_fast_updates\": 0.0,\n  \"arc_arc-easy_accuracy\": 0.2890625,\n  \"arc_arc-easy_samples\": 256,\n  \"arc_arc-easy_baseline_accuracy\": 0.28515625,\n  \"arc_arc-easy_memorize_accuracy\": 0.2890625,\n  \"arc_arc-easy_memorize_delta\": 0.00390625,\n  \"arc_arc-easy_titan_mem_updates\": 0.0,\n  \"arc_arc-easy_cms_fast_updates\": 0.0,\n  \"arc_arc-challenge_accuracy\": 0.25,\n  \"arc_arc-challenge_samples\": 256,\n  \"arc_arc-challenge_baseline_accuracy\": 0.2421875,\n  \"arc_arc-challenge_memorize_accuracy\": 0.25,\n  \"arc_arc-challenge_memorize_delta\": 0.0078125,\n  \"arc_arc-challenge_titan_mem_updates\": 0.0,\n  \"arc_arc-challenge_cms_fast_updates\": 0.0,\n  \"boolq_accuracy\": 0.40625,\n  \"boolq_samples\": 256,\n  \"boolq_baseline_accuracy\": 0.41015625,\n  \"boolq_memorize_accuracy\": 0.40625,\n  \"boolq_memorize_delta\": -0.00390625,\n  \"boolq_titan_mem_updates\": 0.0,\n  \"boolq_cms_fast_updates\": 0.0,\n  \"siqa_accuracy\": 0.2890625,\n  \"siqa_samples\": 256,\n  \"siqa_baseline_accuracy\": 0.296875,\n  \"siqa_memorize_accuracy\": 0.2890625,\n  \"siqa_memorize_delta\": -0.0078125,\n  \"siqa_titan_mem_updates\": 0.0,\n  \"siqa_cms_fast_updates\": 0.0,\n  \"commonsenseqa_accuracy\": 0.18359375,\n  \"commonsenseqa_samples\": 256,\n  \"commonsenseqa_baseline_accuracy\": 0.1953125,\n  \"commonsenseqa_memorize_accuracy\": 0.18359375,\n  \"commonsenseqa_memorize_delta\": -0.01171875,\n  \"commonsenseqa_titan_mem_updates\": 0.0,\n  \"commonsenseqa_cms_fast_updates\": 0.0,\n  \"openbookqa_accuracy\": 0.14453125,\n  \"openbookqa_samples\": 256,\n  \"openbookqa_baseline_accuracy\": 0.15234375,\n  \"openbookqa_memorize_accuracy\": 0.14453125,\n  \"openbookqa_memorize_delta\": -0.0078125,\n  \"openbookqa_titan_mem_updates\": 0.0,\n  \"openbookqa_cms_fast_updates\": 0.0\n}"
  },
  {
    "path": "google_papers/Nested_Learning/Nested_Learning.json",
    "content": "{\r\n  \"pages\": [\r\n    {\r\n      \"index\": 0,\r\n      \"markdown\": \"# Nested Learning: The Illusion of Deep Learning Architectures \\n\\nAli Behrouz<br>Google Research<br>USA<br>alibehrouz@google.com\\n\\nMeisam Razaviyayn<br>Google Research<br>USA<br>rezavyayn@google.com\\n\\nPeiling Zhong<br>Google Research<br>USA<br>peilinz@google.com<br>Vahab Mirrokni<br>Google Research<br>USA<br>mirrokni@google.com\\n\\n## Abstract\\n\\nOver the last decades, developing more powerful neural architectures and simultaneously designing optimization algorithms to effectively train them have been the core of research efforts to enhance the capability of machine learning models. Despite the recent progresses, particularly in developing Language Models (LMs), there are fundamental challenges and unanswered questions about how such models can continually learn/memorize, self-improved, and find \\\"effective solutions,\\\". In this paper, we present a new learning paradigm, called Nested Learning (NL), that coherently represents a model with a set of nested, multi-level, and/or parallel optimization problems, each of which with its own \\\"context flow\\\". NL reveals that existing deep learning methods learns from data through compressing their own context flow, and explain how in-context learning emerges in large models. NL suggests a path (a new dimension to deep learning) to design more expressive learning algorithms with more \\\"levels\\\", resulting in higher-order in-context learning abilities. In addition to its neuroscientifically plausible and mathematically white-box nature, we advocate for its importance by presenting three core contributions: (1) Deep Optimizers: Based on NL, we show that well-known gradient-based optimizers (e.g., Adam, SGD with Momentum, etc.) are in fact associative memory modules that aim to compress the gradients with gradient descent. Building on this insight, we present a set of more expressive optimizers with deep memory and/or more powerful learning rules; (2) Self-Modifying Titans: Taking advantage of NL's insights on learning algorithms, we present a novel sequence model that learns how to modify itself by learning its own update algorithm; and (3) Continuum Memory System: We present a new formulation for memory system that generalizes the traditional viewpoint of \\\"long-term/short-term memory\\\". Combining our self-modifying sequence model with the continuum memory system, we present a learning module, called HOPE, showing promising results in language modeling, continual learning, and long-context reasoning tasks.\\n\\n## 1 Introduction\\n\\nThis version of the paper has been extensively summarized to fit the page limit of NeurIPS camera ready, and some materials, experiments, discussions, and methods are moved to appendix, which might make some parts hard to follow or cause inconsistencies. To avoid such cases, please read our arXiv version instead [1].\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 1,\r\n      \"markdown\": \"![img-0.jpeg](img-0.jpeg)\\n\\nFigure 1: The uniform and reusable structure as well as multi time scale update in the brain are the key components to unlock the continual learning in humans. 
Nested Learning (NL) allows for multi-time-scale updates for each component of the brain, while showing that well-known architectures such as Transformers are in fact linear layers with different update frequencies.\\n\\nFor decades, AI research has focused on designing machine learning algorithms that learn from data [2–5] or experience [6–8], often by optimizing an objective $\\\\mathcal{L}(\\\\boldsymbol{\\\\theta})$ over parameters $\\\\boldsymbol{\\\\theta} \\\\in \\\\Theta$ with gradient-based methods. While traditional machine learning techniques required careful engineering and domain expertise to design feature extractors, limiting their ability to directly process and learn from natural data [9], deep representation learning offered a fully automated alternative to discover the representations needed for the task. Thereafter, deep learning has been an inseparable part of large-scale computational models with seminal success in chemistry and biology [10], games [11, 12], computer vision [13, 14], and multimodal and natural language understanding [15–17].\\n\\nStacking multiple layers, as is done in deep learning models, provides the models with larger capacity, better expressive power in representing complex features, and more internal computations (e.g., #FLOPS) [18–20], all of which are critical and desirable characteristics for static tasks that require in-distribution predictions over a previously fixed set. This deep design, however, is not a universal solution to all the challenges and cannot help the expressive power of the models in multiple aspects, for example: (i) The computational depth of deep models might not change with more layers [21, 22], leaving their ability to implement complex algorithms untouched compared to traditional shallow approaches [23]; (ii) The capacity of some classes of parameters might show marginal improvement with increasing the depth/width of the model [24]; (iii) The training process might converge to a suboptimal solution, mainly due to the suboptimal choice of the optimizer or its hyperparameters; and (iv) The model's ability to quickly adapt to a new task, continually learn, and/or generalize to out-of-distribution data might not change with stacking more layers and requires more careful design.\\n\\nThe core part of the efforts to overcome the above challenges and to enhance the capability of deep learning models concentrates on: (1) developing more expressive classes of parameters (i.e., neural architectures) [13, 25–28]; (2) introducing objectives that can better model the tasks [29–32]; (3) designing more efficient/effective optimization algorithms to find better solutions or with more resilience to forgetting [33–36]; and (4) scaling the model size to enhance its expressivity, when the \\\"right\\\" choices of architecture, objective, and optimization algorithm are made [24, 37, 38]. Collectively, these advancements and new findings on scaling patterns of deep models have established the foundations upon which Large Language Models (LLMs) have been built.\\n\\nThe development of LLMs marks a pivotal milestone in deep learning research: a paradigm shift from task-specific models to more general-purpose systems with various emergent capabilities as a result of scaling the \\\"right\\\" architectures [38, 39]. 
Despite all their success and remarkable capabilities in diverse sets of tasks [15, 40, 41], LLMs are largely static after their initial deployment phase, meaning that they successfully perform tasks learned during pre- or post-training, but are unable to continually acquire new capabilities beyond their immediate context. The only adaptable component of LLMs is their *in-context learning* ability–a (known to be emergent) characteristic of LLMs that enables fast adaptation to the context and thus zero- or few-shot task performance [38]. Beyond in-context learning, recent efforts to overcome the static nature of LLMs are computationally expensive, require external components, lack generalization, and/or suffer from catastrophic forgetting [42–44], which has led researchers to question if there is a need to revisit how to design machine learning\",\r\n      \"images\": [\r\n        {\r\n          \"id\": \"img-0.jpeg\",\r\n          \"top_left_x\": 289,\r\n          \"top_left_y\": 191,\r\n          \"bottom_right_x\": 1409,\r\n          \"bottom_right_y\": 589,\r\n          \"image_base64\": null,\r\n          \"image_annotation\": null\r\n        }\r\n      ],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 2,\r\n      \"markdown\": \"models and if a new learning paradigm beyond stacking of layers is required to unleash the capabilities of LLMs in continual setups.\\n\\nCurrent Models only Experience the Immediate Present. As an analogy and to better illustrate the static nature of LLMs, we use the example of anterograde amnesia-a neurological condition where a person cannot form new long-term memories after the onset of the disorder, while existing memories remain intact [45]. This condition limits the person's knowledge and experiences to a short window of present and long past-before the onset of the disorder-which results in continuously experiencing the immediate present as if it were always new. The memory processing system of current LLMs suffers from a similar pattern. Their knowledge is limited to either the immediate context that fits into their context window, or the knowledge in MLP layers that stores the long past, before the onset of the \\\"end of pre-training.\\\" This analogy has motivated us to take inspiration from the neurophysiology literature and how the brain consolidates its short-term memories:\\n\\n# 1.1 Human Brain Perspective and Neurophysiological Motivation \\n\\nThe human brain is highly efficient and effective when it comes to continual learning (a.k.a. effective context management), which is often attributed to neuroplasticity-the brain's remarkable capacity to change itself in response to new experiences, memories, learning, and even damage [46, 47]. Recent studies support that the formation of long-term memory involves at least two distinct but complementary consolidation processes [48-50]: (1) A rapid \\\"online\\\" consolidation (also known as synaptic consolidation) phase occurs immediately or soon after learning, even during wakefulness. 
This is when new and initially fragile memory traces are stabilized and begin transferring from short-term to long-term storage; (2) An \\\"offline\\\" consolidation (also known as systems consolidation) process-replaying the recently encoded patterns during sharp-wave ripples (SWRs) in the hippocampus, coordinated with cortical sleep spindles and slow oscillations-strengthens and reorganizes the memory and supports transfer to cortical sites [51-53].\\nComing back to the analogy of anterograde amnesia, evidence indicates that the condition can impact both stages, but especially the online consolidation phase, mainly due to the fact that the hippocampus is the gateway for encoding new declarative memories, and so its damage means new information will never be stored in long-term memory. As mentioned above, the design of LLMs, and more specifically Transformer-based backbones, suffers from a similar condition after the pre-training phase. That is, the information provided in the context never impacts the long-term memory parameters (e.g., feedforward layers), and so the model is not capable of acquiring new knowledge or skills, unless the information is still stored in the short-term memory (e.g., attention). To this end, although the second stage is equally, or even more, crucial for the consolidation of memories, and its absence can damage the process and might cause loss of memory [54, 55], in this work, we focus on the first stage: memory consolidation as an online process. We provide additional discussion on the human brain perspective and its connection to NL in Appendix A.\\n\\nNotations. We let $x \\\\in \\\\mathbb{R}^{N \\\\times d_{h}}$ be the input, $\\\\mathcal{M}_{t}$ represent the state of memory/model $\\\\mathcal{M}$ at time $t$, $\\\\mathbf{K}$ be the keys, $\\\\mathbf{V}$ be the values, and $\\\\mathbf{Q}$ be the query matrices. We use bold lowercase letters with subscript $t$ to refer to the vector corresponding to input $t$ (i.e., $\\\\mathbf{k}_{t}, \\\\mathbf{v}_{t}$, and $\\\\mathbf{q}_{t}$ ). We further refer to the distribution of any entity $f$ as $p(f)$. Throughout the paper, we use simple MLPs with $\\\\mathcal{L}_{\\\\mathcal{M}} \\\\geq 1$ layers and residual connections as the architecture of the memory module $\\\\mathcal{M}(\\\\cdot)$. When needed, we parameterize the memory module with $\\\\boldsymbol{\\\\theta}_{\\\\mathcal{M}} \\\\supseteq\\\\left\\\\{W_{1}, W_{2}, \\\\ldots, W_{\\\\mathcal{L}_{\\\\mathcal{M}}}\\\\right\\\\}$, which at least includes the parameters of the linear layers in the MLP. We use superscripts with parentheses to refer to parameters at different levels of nested learning (different update frequencies): i.e., $W^{(\\\\ell)}$.\\n\\n## 2 Nested Learning\\n\\nThis section discusses the motivations, formal definitions, and general high-level implications of Nested Learning (NL). We start with a formulation of associative memory and then, using step-by-step examples, build the intuition behind architecture decomposition and its connection to modeling a neural network as an integrated system of optimization problems. 
We aim to first show how existing methods and concepts in deep learning fall under the NL paradigm, and then we present new formulations that go beyond traditional methods and/or provide insights on how to improve existing algorithms and designs.\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 3,\r\n      \"markdown\": \"![img-1.jpeg](img-1.jpeg)\\n\\nFigure 2: The Nested Learning paradigm, which represents a machine learning model and its training procedure as a set of nested optimization problems. (Left) An example of a hybrid architecture. While the deep learning perspective, as the flattened image of NL, does not provide insight into the depth of computation in the blocks, NL transparently represents all the inner gradient flows. (Right) A Neural Learning Module: a computational model that learns how to compress its own context flow. For example, the first level corresponds to the model's outermost training loop, often referred to as the \\\"pre-training\\\" step.\\n\\n# 2.1 Associative Memory \\n\\nAssociative memory-the ability to form and retrieve connections between events-is a fundamental mental process and is an inseparable component of human learning [56]. Often in the literature, the concepts of memorization and learning are used interchangeably; in the neuropsychology literature, however, these two are clearly distinguished. More specifically, following the neuropsychology literature [57], we build our terminology based on the following definition of memory and learning:\\n\\n## Learning vs. Memorization:\\n\\nMemory is a neural update caused by an input, and learning is the process for acquiring effective and useful memory.\\n\\nIn this work, our goal is to first show that all the elements of a computational sequence model, including optimizers and neural networks, are associative memory systems that compress their own context flow. Broadly speaking, associative memory is an operator that maps a set of keys to a set of values. We follow the general definition of associative memory by Behrouz et al. [58]:\\nDefinition 1 (Associative Memory). Given a set of keys $\\\\mathcal{K} \\\\subseteq \\\\mathbb{R}^{d_{k}}$ and values $\\\\mathcal{V} \\\\subseteq \\\\mathbb{R}^{d_{v}}$, associative memory is an operator $\\\\mathcal{M}: \\\\mathcal{K} \\\\rightarrow \\\\mathcal{V}$ that maps the set of keys $\\\\mathcal{K}$ to the set of values $\\\\mathcal{V}$. To learn such a mapping from the data, an objective $\\\\hat{\\\\mathcal{L}}(\\\\cdot ; \\\\cdot)$ measures the quality of the mapping and $\\\\mathcal{M}$ can be defined as:\\n\\n$$\\n\\\\mathcal{M}^{*}=\\\\arg \\\\min _{\\\\mathcal{M}} \\\\quad \\\\hat{\\\\mathcal{L}}(\\\\mathcal{M}(\\\\mathcal{K}) ; \\\\mathcal{V})\\n$$\\n\\nWhile the operator itself is a memory and the mapping acts as a memorization process (i.e., memorizing the connections of events in the context), acquiring such an effective operator from the data is a learning process. It is notable that, here, keys and values can be arbitrary events that the memory aims to map and are not limited to tokens. Later in this section, we will discuss that, given a context flow, keys and values might be tokens, gradients, sub-sequences, etc. Furthermore, while the term associative memory is more common in the neuroscience and neuropsychology literature, the above formulation is also closely related to data compression and low-dimensional representation. 
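\\n\\nEditorial note (not part of the original paper): Equation 1 can be made concrete with a minimal PyTorch sketch in which the memory is a single linear map fit to a toy key/value set; the squared-error objective, dimensions, and all variable names here are illustrative assumptions rather than the paper's setup.\\n\\n```python\\nimport torch\\n\\n# Toy instance of Equation 1: fit a linear memory M so that M(k) is close to v.\\ntorch.manual_seed(0)\\nd_k, d_v, n = 8, 4, 32\\nK = torch.randn(n, d_k)  # keys\\nV = torch.randn(n, d_v)  # values\\nM = torch.zeros(d_k, d_v, requires_grad=True)  # the memory operator\\n\\nopt = torch.optim.SGD([M], lr=0.1)\\nfor _ in range(200):\\n    opt.zero_grad()\\n    loss = ((K @ M - V) ** 2).mean()  # L-hat(M(K); V), here a plain MSE\\n    loss.backward()\\n    opt.step()\\nprint(round(loss.item(), 4))  # mapping error after memorization\\n```\\n\\n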
That is, returning to Equation 1, one can interpret the optimization process as the training process of a network $\\\\mathcal{M}(.)$ that aims to compress the mappings into its parameters and so represent them in a lower-dimensional space.\\n\\nIn sequence modeling, where keys and values are input tokens (e.g., tokenized text), the choice of objective and the optimization process for solving Equation 1 can result in distinct sequence\",\r\n      \"images\": [\r\n        {\r\n          \"id\": \"img-1.jpeg\",\r\n          \"top_left_x\": 294,\r\n          \"top_left_y\": 191,\r\n          \"bottom_right_x\": 1409,\r\n          \"bottom_right_y\": 613,\r\n          \"image_base64\": null,\r\n          \"image_annotation\": null\r\n        }\r\n      ],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 4,\r\n      \"markdown\": \"modeling architectures (see [59] and [58]) such as global/local softmax attention [27], or other modern recurrent models [28, 60, 61]. This simple formulation of sequence models provides us with a better understanding of their internal processes and also a tool to directly compare their modeling power based on their objective and optimization process. In the following, using step-by-step examples, we discuss how this formulation can be applied to all components of a neural architecture (including its optimization process in pre-training) and, in fact, how a model is an integrated system of multi-level, nested, and/or parallel memories, each with its own context flow.\\n\\nA Simple Example of MLP Training. We start with a simple example, in which we aim to train a 1-layer MLP (parameterized with $W$ ) for task $\\\\mathcal{T}$ and on dataset $\\\\mathcal{D}_{\\\\text {train }}=\\\\left\\\\{x_{1}, \\\\ldots, x_{\\\\left|\\\\mathcal{D}_{\\\\text {train }}\\\\right|}\\\\right\\\\}$ by optimizing the objective $\\\\mathcal{L}(\\\\cdot ; \\\\cdot)$ with gradient descent. In this case, the training process is equivalent to the following optimization problem:\\n\\n$$\\nW^{*}=\\\\arg \\\\min _{W} \\\\mathcal{L}\\\\left(W ; \\\\mathcal{D}_{\\\\text {train }}\\\\right)\\n$$\\n\\nwhose optimization by gradient descent results in a weight update rule equivalent to:\\n\\n$$\\n\\\\begin{aligned}\\nW_{t+1} & =W_{t}-\\\\eta_{t+1} \\\\nabla_{W_{t}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right) \\\\\\\\\\n& =W_{t}-\\\\eta_{t+1} \\\\nabla_{y_{t+1}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right) \\\\otimes x_{t+1}, \\\\quad \\\\text { where } x_{t+1} \\\\sim \\\\mathcal{D}_{\\\\text {train }}\\n\\\\end{aligned}\\n$$\\n\\nwhere $y_{t+1}=W x_{t+1}$ is the output of the model for input $x_{t+1}$. Given this formulation, one can let $u_{t+1}=\\\\nabla_{y_{t+1}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right)$ and reformulate the backpropagation process as the solution to an optimization problem of finding an optimal associative memory that maps input data points $\\\\mathcal{D}_{\\\\text {train }}=\\\\left\\\\{x_{t}\\\\right\\\\}_{t=1}^{\\\\left|\\\\mathcal{D}_{\\\\text {train }}\\\\right|}$ to their corresponding $u_{t+1}=\\\\nabla_{y_{t+1}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right)$. That is, we let $\\\\mathcal{M}(\\\\cdot)=W_{t}$ 
parametrize the memory, and use dot-product similarity to measure the quality of $W_{t}$'s mapping between $x_{t+1}$ and $\\\\nabla_{y_{t+1}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right)$ :\\n\\n$$\\n\\\\begin{aligned}\\nW_{t+1} & =\\\\arg \\\\min _{W}\\\\left\\\\langle W x_{t+1}, u_{t+1}\\\\right\\\\rangle+\\\\frac{1}{2 \\\\eta_{t+1}}\\\\left\\\\|W-W_{t}\\\\right\\\\|_{2}^{2} \\\\\\\\\\n& =\\\\arg \\\\min _{W}\\\\left\\\\langle W x_{t+1}, \\\\nabla_{y_{t+1}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right)\\\\right\\\\rangle+\\\\frac{1}{2 \\\\eta_{t+1}}\\\\left\\\\|W-W_{t}\\\\right\\\\|_{2}^{2}\\n\\\\end{aligned}\\n$$\\n\\nIn the above formulation, $u_{t+1}=\\\\nabla_{y_{t+1}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right)$ can be interpreted as a local surprise signal in representation space that quantifies the mismatch between the current output and the structure the objective $\\\\mathcal{L}(\\\\cdot ; \\\\cdot)$ enforces. Therefore, this formulation translates the training phase of the model into a process of acquiring an effective memory that maps data samples to their Local Surprise Signal (LSS) in representation space-defined as the mismatch between the current output and the structure enforced by the objective $\\\\mathcal{L}(\\\\cdot ; \\\\cdot)$. Accordingly, in this example, our model has a single gradient flow over the data samples, which is only active over dataset $\\\\mathcal{D}_{\\\\text {train }}=\\\\left\\\\{x_{1}, \\\\ldots, x_{\\\\left|\\\\mathcal{D}_{\\\\text {train }}\\\\right|}\\\\right\\\\}$ and will be frozen for any other data samples afterwards (a.k.a. inference or test time).\\n\\nNext, in the above example, we replace the gradient descent algorithm with its enhanced momentum-based variant, resulting in an update rule of:\\n\\n$$\\n\\\\begin{aligned}\\n& W_{t+1}=W_{t}-\\\\mathbf{m}_{t+1} \\\\\\\\\\n& \\\\mathbf{m}_{t+1}=\\\\mathbf{m}_{t}-\\\\eta_{t+1} \\\\nabla_{W_{t}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right)=\\\\mathbf{m}_{t}-\\\\eta_{t+1} \\\\nabla_{y_{t+1}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right) \\\\otimes x_{t+1}\\n\\\\end{aligned}\\n$$\\n\\nIn Equation 8, given the previous state of Equation 7 (at time $t$ ), the values of $\\\\nabla_{W_{t}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right)$ or, similarly, $\\\\nabla_{y_{t+1}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right)$ are independent of the recurrence in Equation 8 and so can be pre-computed beforehand. To this end, we let $u_{t+1}=\\\\nabla_{W_{t}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right)$, and so Equation 8 can be reformulated as:\\n\\n$$\\n\\\\begin{aligned}\\nW_{t+1} & =W_{t}-\\\\mathbf{m}_{t+1} \\\\\\\\\\n\\\\mathbf{m}_{t+1} & =\\\\arg \\\\min _{\\\\mathbf{m}}-\\\\left\\\\langle\\\\mathbf{m}, \\\\nabla_{W_{t}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right)\\\\right\\\\rangle+\\\\eta_{t+1}\\\\left\\\\|\\\\mathbf{m}-\\\\mathbf{m}_{t}\\\\right\\\\|_{2}^{2} \\\\\\\\\\n& =\\\\arg \\\\min _{\\\\mathbf{m}}-\\\\left\\\\langle\\\\mathbf{m} x_{t+1}, \\\\nabla_{y_{t+1}} \\\\mathcal{L}\\\\left(W_{t} ; x_{t+1}\\\\right)\\\\right\\\\rangle+\\\\eta_{t+1}\\\\left\\\\|\\\\mathbf{m}-\\\\mathbf{m}_{t}\\\\right\\\\|_{2}^{2}\\n\\\\end{aligned}\\n$$\\n\\nwhere the optimization problem in Equation 10 is equivalent to one step of gradient descent with an adaptive learning rate of $\\\\eta_{t+1}$. 
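\\n\\nEditorial note (not part of the original paper): the interpretation of the momentum buffer as a memory that compresses past gradients can be checked numerically. The sketch below uses illustrative names and fixed rates, and unrolls a momentum recursion of the form above to verify that the buffer equals an exponentially weighted sum of all past gradients.\\n\\n```python\\nimport torch\\n\\n# Momentum as a key-less memory over the gradient sequence.\\ntorch.manual_seed(0)\\nalpha, eta, steps = 0.9, 0.1, 5\\ngrads = [torch.randn(3, 3) for _ in range(steps)]  # stand-ins for grad L(W_t; x_t)\\n\\nm = torch.zeros(3, 3)\\nfor g in grads:\\n    m = alpha * m - eta * g  # recursive memory write\\n\\n# Unrolled form: m = -eta * sum_i alpha**(steps-1-i) * g_i\\nm_unrolled = sum(-eta * alpha ** (steps - 1 - i) * g for i, g in enumerate(grads))\\nprint(torch.allclose(m, m_unrolled))  # True: the buffer stores a compressed gradient history\\n```\\n\\n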
Given these formulations, one can interpret the momentum term as either: (1) a key-less associative memory that compresses the gradients into its parameters, or (2) an associative memory that learns how to map data points to their corresponding LSS-value. Interestingly, this formulation reveals that gradient descent with momentum is indeed a two-level\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 5,\r\n      \"markdown\": \"optimization process, where the memory is optimized by the simple gradient descent algorithm. This process is closely related to Fast Weight Programs (FWPs) [62], where the weight update process (i.e., Equation 9) is the slow network whose momentum weight is generated by a fast network (i.e., Equation 10).\\nConcluding the above examples, we observed that the training process of a 1-layer MLP with: (1) Gradient descent is a 1-level associative memory that learns how to map data points to their corresponding LSS-value; and (2) Gradient descent with momentum is a 2-level associative memory (or optimization process) in which the inner level learns to store gradient values into its parameters, and then the outer level updates the slow weight (i.e., $W_{t}$ ) with the value of the inner-level memory. While these are the simplest examples with respect to both architecture and optimizer algorithms, one might ask if similar conclusions can be made in more complex setups.\\n\\nAn Example of Architectural Decomposition. In the next example, we replace the MLP module with a linear attention [60]. That is, we aim to train a 1-layer linear attention for task $\\\\mathcal{T}$ and on a sequence of $\\\\mathcal{D}_{\\\\text {train }}=\\\\left\\\\{x_{1}, \\\\ldots, x_{\\\\left|\\\\mathcal{D}_{\\\\text {train }}\\\\right|}\\\\right\\\\}$ by optimizing the objective $\\\\mathcal{L}$ with gradient descent. Recalling the unnormalized linear attention formulation:\\n\\n$$\\n\\\\begin{aligned}\\n& \\\\mathbf{k}_{t}=x_{t} W_{\\\\mathbf{k}}, \\\\quad \\\\mathbf{v}_{t}=x_{t} W_{\\\\mathbf{v}}, \\\\quad \\\\mathbf{q}_{t}=x_{t} W_{\\\\mathbf{q}} \\\\\\\\\\n& \\\\mathcal{M}_{t}=\\\\mathcal{M}_{t-1}+\\\\mathbf{v}_{t} \\\\mathbf{k}_{t}^{\\\\top} \\\\\\\\\\n& y_{t}=\\\\mathcal{M}_{t} \\\\mathbf{q}_{t}\\n\\\\end{aligned}\\n$$\\n\\nAs discussed in earlier studies [58, 59], the recurrence in Equation 13 can be reformulated as the optimization process of a matrix-valued associative memory $\\\\mathcal{M}_{t}(\\\\cdot)$, which aims to compress the mappings of keys and values into its parameters. 
In more detail, in Definition 1, if we let $\\\\tilde{\\\\mathcal{L}}\\\\left(\\\\mathcal{M}_{t-1} ; \\\\mathbf{k}_{t}, \\\\mathbf{v}_{t}\\\\right):=-\\\\left\\\\langle\\\\mathcal{M}_{t-1} \\\\mathbf{k}_{t}, \\\\mathbf{v}_{t}\\\\right\\\\rangle$ and aim to optimize the memory with gradient descent, the memory update rule is: (Note that $\\\\nabla \\\\tilde{\\\\mathcal{L}}\\\\left(\\\\mathcal{M}_{t-1} ; \\\\mathbf{k}_{t}, \\\\mathbf{v}_{t}\\\\right)=-\\\\mathbf{v}_{t} \\\\mathbf{k}_{t}^{\\\\top}$ and we let learning rate $\\\\eta_{t}=1$ )\\n\\n$$\\n\\\\begin{aligned}\\n\\\\mathcal{M}_{t+1} & =\\\\arg \\\\min _{\\\\mathcal{M}}-\\\\left\\\\langle\\\\mathcal{M} \\\\mathbf{k}_{t+1}, \\\\mathbf{v}_{t+1}\\\\right\\\\rangle+\\\\left\\\\|\\\\mathcal{M}-\\\\mathcal{M}_{t}\\\\right\\\\|_{2}^{2} \\\\quad \\\\text { with gradient descent } \\\\\\\\\\n\\\\Rightarrow \\\\mathcal{M}_{t+1} & =\\\\mathcal{M}_{t}-\\\\nabla \\\\tilde{\\\\mathcal{L}}\\\\left(\\\\mathcal{M}_{t} ; \\\\mathbf{k}_{t+1}, \\\\mathbf{v}_{t+1}\\\\right)=\\\\mathcal{M}_{t}+\\\\mathbf{v}_{t+1} \\\\mathbf{k}_{t+1}^{\\\\top}\\n\\\\end{aligned}\\n$$\\n\\nwhich is equivalent to the update rule of an unnormalized linear attention in Equation 13. Also, note that, as we observed in the first example, training a linear layer with gradient descent is a 1-layer optimization problem of an associative memory (see Equation 3), and so the general training/updating process of the projection layers (i.e., $W_{\\\\mathbf{k}}, W_{\\\\mathbf{v}}$, and $W_{\\\\mathbf{q}}$ ) is itself an optimization process of an associative memory. Accordingly, this setup, i.e., training a linear attention with gradient descent, can be seen as a two-level optimization process, where the outer loop (also known as the training process) optimizes the projection layers with gradient descent, while the inner loop optimizes the inner memory $\\\\mathcal{M}_{t}$ with gradient descent.\\n\\nNote that, as discussed above, here we have two associative memories, and so each has its own optimization process and gradient flow. That is, in the optimization of the outer-level parameters $W_{\\\\mathbf{k}}, W_{\\\\mathbf{v}}$, and $W_{\\\\mathbf{q}}$ there is no gradient with respect to the parameter $\\\\mathcal{M}(\\\\cdot)$ and so there is no backpropagation through it. Similarly, in the inner level, there is no backpropagation through the projection layers and they are considered frozen. Furthermore, it is notable that in this example, the above formulation is also closely connected to the FWP perspective of linear attention [63], where the projections are considered slow weights, and the memory update in Equation 13 is the fast weight update rule.\\n\\nArchitectural Decomposition with More Levels. In both examples above, we discussed simple cases that can be translated into 2-level optimization processes, which also coincides with their FWP interpretations. In practice, however, we need to use more powerful optimization algorithms to train the model, and/or use more powerful recurrent update rules for the memory. As a simple example, assume we use gradient descent with momentum to train a linear attention model. In the above examples, we showed how the linear attention component can be decomposed into two nested optimization problems. 
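\\n\\nEditorial note (not part of the original paper): the equivalence between the recurrence in Equation 13 and one gradient-descent step on the objective in Equation 14 can be verified directly; the sketch below uses illustrative dimensions and a learning rate of 1.\\n\\n```python\\nimport torch\\n\\n# One GD step on -<M k, v> reproduces the linear-attention state update.\\ntorch.manual_seed(0)\\nd = 4\\nM_t = torch.randn(d, d)\\nk, v = torch.randn(d), torch.randn(d)\\n\\nM_rec = M_t + torch.outer(v, k)  # recurrent form (Equation 13)\\n\\nM_var = M_t.clone().requires_grad_(True)\\nloss = -(M_var @ k) @ v  # the inner objective from Equation 14\\nloss.backward()  # grad w.r.t. M is -v k^T\\nM_gd = (M_var - M_var.grad).detach()  # learning rate 1\\n\\nprint(torch.allclose(M_rec, M_gd))  # True\\n```\\n\\n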
Similarly, here the model can be represented as a 2-level optimization problem, where (1) the inner level optimizes the memory to compress the context using gradient descent (see Equation 15), and (2) the outer level optimizes the projection layers with gradient descent with momentum. Interestingly, from the first example, we know that the "gradient descent with momentum" algorithm is itself a 2-level optimization problem, in which the momentum term is an associative memory that compresses the past gradients into its parameters.

# 2.2 Nested Optimization Problems

In the previous section, we provided examples to demonstrate how one can decompose a machine learning model into a set of nested, or multi-level, optimization problems. Next, we first present a formal formulation of nested learning problems and then define the Neural Learning Module, an integrated computational system that learns from data.

As we observed in the previous section, while we decomposed the model into a set of optimization processes, it is still unclear whether we can define a hierarchy (or order) over these problems and uniquely represent the model in this format. Inspired by the hierarchy of brain waves, which indicates the information-processing frequency of each part (discussed in Section 1), we use the update rate of each optimization problem to order the components into multiple levels. To this end, we let one update step over one data point be the unit of time, and define the update frequency of each component as:

Definition 2 (Update Frequency). For any component $A$, which can be a parametric component (e.g., learnable weights, or the momentum term in gradient descent with momentum) or a non-parametric component (e.g., an attention block), we define its frequency, denoted $f_A$, as its number of updates per unit of time.

Given the above update frequency, we can order the components of a machine learning algorithm by the operator $(\cdot \succ \cdot)$. We say $A$ is faster than $B$, denoted $A \succ B$, if: (1) $f_A > f_B$, or (2) $f_A = f_B$ but the computation of $B$'s state at time $t$ requires the computation of $A$'s state at time $t$. In this definition, when $A \nsucc B$ and $B \nsucc A$, we write $A \stackrel{t}{\sim} B$, which indicates that $A$ and $B$ have the same update frequency but their computations are independent of each other (later, we provide an example of this case in the AdamW optimizer). Based on this operator, we sort the components into an ordered set of "levels", where (1) components in the same level have the same update frequency, and (2) the higher the level, the lower its frequency. Notably, based on the above definition, each component has its own optimization problem and thus its own context. When we optimize each component's inner objective with a gradient-based optimizer, the above statement is equivalent to having an exclusive gradient flow for each component in the model. In the general case, however, one can use a non-parametric solution (as we later discuss for attention).
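As a toy illustration of Definition 2, the sketch below assigns hypothetical update frequencies to a few components and groups them into levels. A fuller implementation would also break frequency ties using the dependency rule above (e.g., the momentum state is computed before the weights it updates, so momentum is the faster component at equal frequency). All names and numbers here are made up for illustration.

```python
# Illustrative sketch of Definition 2: order components by update frequency.
# Frequencies are hypothetical "updates per unit of time", not paper values.
from collections import defaultdict

freq = {
    "attention_memory": 1.0,   # updated every token
    "momentum": 1.0 / 8,       # updated once per batch of 8 tokens (assumed)
    "weights_W": 1.0 / 8,      # same frequency; W_t also depends on m_t
    "persistent_mlp": 0.0,     # frozen after pre-training
}

levels = defaultdict(list)
for name, f in freq.items():
    levels[f].append(name)

# Higher level <=> lower frequency; equal-f components share a level here.
for level, f in enumerate(sorted(levels, reverse=True)):
    print(f"level {level}: f={f} -> {levels[f]}")
```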
Neural Learning Module. Given the above definition of nested learning problems, we define the neural learning module as a new representation of machine learning models that views a model as an interconnected system of components, each with its own gradient flow. Note that, orthogonal to deep learning, nested learning allows us to define neural learning models with more levels, resulting in more expressive architectures.

Nested learning allows computational models that are composed of multiple (multi-layer) levels to learn from and process data at different levels of abstraction and time-scales.

Next, we study optimizers and well-known deep learning architectures from the nested learning perspective, and provide examples of how NL can help enhance those components.

### 2.3 Optimizers as Learning Modules

In this section, we start by understanding how well-known optimizers and their variants are special instances of nested learning. Recall the gradient descent method with momentum,

$$
\begin{aligned}
& W_{i+1} = W_i + \mathbf{m}_{i+1} \\
& \mathbf{m}_{i+1} = \alpha_{i+1} \mathbf{m}_i - \eta_i \nabla \mathcal{L}(W_i; x_i)
\end{aligned}
$$

where the matrix (or vector) $\mathbf{m}_i$ is the momentum at state $i$, and $\alpha_i$ and $\eta_i$ are adaptive momentum and learning rates, respectively. Assuming $\alpha_{i+1} = 1$, the momentum term can be viewed as the result of optimizing the following objective with gradient descent:

$$
\min_{\mathbf{m}} \left\langle \mathbf{m} \nabla \mathcal{L}(W_i; x_i)^{\top}, \mathbf{I} \right\rangle
$$

This interpretation shows that momentum can indeed be viewed as a meta memory module that learns how to memorize gradients of the objective in its parameters.
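Read this way, gradient descent with momentum is two nested memories, which the following sketch makes explicit: the inner level performs one decayed GD step on the value-less objective above (whose gradient with respect to $\mathbf{m}$ is simply $\nabla \mathcal{L}(W_i; x_i)$), and the outer level adds the memory's state to the weights. The toy loss and all sizes are assumptions for illustration.

```python
# Sketch: gradient descent with momentum re-read as a two-level process.
# Inner level: the momentum m is a value-less associative memory over past
# gradients. Outer level: W_{i+1} = W_i + m_{i+1}.
import torch

def loss(W, x):                      # toy outer objective (assumed)
    return 0.5 * ((x @ W) ** 2).sum()

torch.manual_seed(0)
W = torch.randn(4, 4)
m = torch.zeros_like(W)
alpha, eta = 0.9, 0.1

for step in range(3):
    x = torch.randn(2, 4)
    W_ = W.clone().requires_grad_(True)
    loss(W_, x).backward()
    g = W_.grad                      # grad L(W_i; x_i)
    m = alpha * m - eta * g          # inner level: memory update over gradients
    W = W + m                        # outer level: slow-weight update
```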
While any reasonable choice (e.g., random features) of preconditioning can improve the expressivity of the initial version of GD with momentum per se is a value-less memory (i.e., mapping all gradients to a single value), the above perspective gives more intuition about what preconditioning are more useful. That is, the momentum acts as a memory that aims to map gradients to their corresponding values, and so a function of gradients (e.g., information about Hessian) can provide the memory with a more meaningful mappings.\\n\\nExtension: More Expressive Objectives. As discussed by Behrouz et al. [58], optimizing an inner objective of dot-product similarity results in Hebbian-like update rule, which can cause the memory to be less effective. A natural extension of this internal objective is to use $\\\\ell_{2}(\\\\cdot)$ regression loss (for measuring the corresponding key-value mapping fitness) and minimize the loss function $\\\\|\\\\mathbf{m} \\\\nabla \\\\mathcal{L}\\\\left(W_{i} ; x_{i}\\\\right)^{\\\\top}-\\\\mathbf{P}_{i}\\\\|_{2}^{2}$, resulting in the update rule of:\\n\\n$$\\n\\\\begin{aligned}\\n& W_{i+1}=W_{i}+\\\\mathbf{m}_{i+1} \\\\\\\\\\n& \\\\mathbf{m}_{i+1}=\\\\left(\\\\alpha_{i+1} \\\\mathbf{I}-\\\\nabla \\\\mathcal{L}\\\\left(W_{i} ; x_{i}\\\\right)^{\\\\top} \\\\nabla \\\\mathcal{L}\\\\left(W_{i} ; x_{i}\\\\right)\\\\right) \\\\mathbf{m}_{i}-\\\\eta_{t} \\\\mathbf{P}_{i} \\\\nabla \\\\mathcal{L}\\\\left(W_{i} ; x_{i}\\\\right)\\n\\\\end{aligned}\\n$$\\n\\nThis update is based on delta-rule [64] and so it allows the memory (momentum) to better manage its limited capacity and better memorize the series of past gradients.\\n\\nExtension: More Expressive Memory. As discussed earlier, momentum can be viewed as a meta memory model that uses a linear layer (i.e., matrix-valued) to compress the past gradient values. Due to the linear nature of momentum, only linear functions of past gradients can be learned by its internal objective. To increase the learning capacity of this module, one alternative is to use alternative powerful persistent learning modules: i.e., replacing a linear matrix-valued memory for momentum with an MLP. Therefore, momentum as the a memory for the past gradients, has more capacity to capture the underlying dynamics of the gradients. To this end, we extend the formulation in Equation 17 as:\\n\\n$$\\nW_{i+1}=W_{i}+\\\\mathbf{m}_{i+1}\\\\left(\\\\mathbf{u}_{i}\\\\right), \\\\quad \\\\text { and } \\\\quad \\\\mathbf{m}_{i+1}=\\\\alpha_{i+1} \\\\mathbf{m}_{i}-\\\\eta_{t} \\\\nabla \\\\mathcal{L}^{(2)}\\\\left(\\\\mathbf{m}_{i} ; \\\\mathbf{u}_{i}, \\\\mathbf{I}\\\\right)\\n$$\\n\\nwhere $\\\\mathbf{u}_{i}=\\\\nabla \\\\mathcal{L}\\\\left(W_{i} ; x_{i}\\\\right)$ and $\\\\nabla \\\\mathcal{L}^{(2)}(\\\\cdot)$ is the internal objective of momentum (e.g., dot product similarity $\\\\left\\\\langle\\\\mathbf{m}\\\\left(\\\\mathbf{u}_{i}^{\\\\top}\\\\right), \\\\mathbf{1}\\\\right\\\\rangle$ ). We refer to this variant as Deep Momentum Gradient Descent (DMGD).\\n\\nExtension: None Linear Outputs. Building upon the above perspective, in which we see the momentum as a neural architecture, one common technique to enhance the representation power of momentum memory module is to use non-linearity on top of its output [28, 65]. 
Extension: Non-Linear Outputs. Building on the above perspective, in which we see the momentum as a neural architecture, one common technique for enhancing the representational power of the momentum memory module is to apply a non-linearity on top of its output [28, 65]. That is, we re-formulate Equation 23 as:

$$
W_{i+1} = W_i + \sigma\left(\mathbf{m}_{i+1}(\mathbf{u}_i)\right), \quad \text{and} \quad \mathbf{m}_{i+1} = \alpha_{i+1} \mathbf{m}_i - \eta_i \nabla \mathcal{L}^{(2)}(\mathbf{m}_i; \mathbf{u}_i, \mathbf{I})
$$

where $\sigma(\cdot)$ is an arbitrary non-linearity. As an example, if we let $\sigma(\cdot) = \text{Newton-Schulz}(\cdot)$, where $\text{Newton-Schulz}(\cdot)$ is the iterative Newton-Schulz method [66], and let $\mathbf{m}(\cdot)$ be a linear layer, the resulting optimizer is equivalent to the Muon optimizer [34].

Going Beyond Simple Backpropagation. As discussed earlier in Section 2.1, the pre-training process with backpropagation is a form of associative memory, where input data is mapped to the surprise caused by its predicted output, $\nabla_{y_t} \mathcal{L}(W_t; x_t)$:

$$
W_{t+1} = W_t - \eta_{t+1} \nabla_{W_t} \mathcal{L}(W_t; x_t) = W_t - \eta_{t+1} \nabla_{y_t} \mathcal{L}(W_t; x_t) \otimes x_t, \quad \text{where } x_t \sim \mathcal{D}_{\text{train}}
$$

which, from the associative memory perspective, is equivalent to one step of gradient descent on the optimization problem:

$$
\min_W \left\langle W x_t, \nabla_{y_t} \mathcal{L}(W_t; x_t) \right\rangle
$$

As we discuss in Appendix C, the above formulation ignores the dependencies between data samples such as $x_t$. To extend it to a more powerful formulation that also considers the dependencies between data points (which is extremely important when we use the optimizer in token space, as tokens are not independent), we use the $\ell_2$ regression objective with one step of gradient descent:

$$
\min_W \left\| W x_t - \nabla_{y_t} \mathcal{L}(W_t; x_t) \right\|_2^2
$$

This formulation results in a new variant of gradient descent, which can be simplified as follows:

$$
\begin{aligned}
W_{t+1} & = W_t\left(\mathbf{I} - x_t x_t^{\top}\right) - \eta_{t+1} \nabla_{W_t} \mathcal{L}(W_t; x_t) \\
& = W_t\left(\mathbf{I} - x_t x_t^{\top}\right) - \eta_{t+1} \nabla_{y_t} \mathcal{L}(W_t; x_t) \otimes x_t, \quad \text{where } x_t \sim \mathcal{D}_{\text{train}}
\end{aligned}
$$

Later, we use this optimizer as the internal optimizer of our HOPE architecture.
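The sketch below implements one step of this modified gradient-descent variant, $W_{t+1} = W_t(\mathbf{I} - x_t x_t^{\top}) - \eta_{t+1} \nabla_{W_t} \mathcal{L}(W_t; x_t)$. The toy loss used to produce a gradient is an assumption, as are all dimensions.

```python
# Sketch of the modified gradient-descent step derived from the l2 objective
# min_W ||W x_t - grad_y L||^2:  W_{t+1} = W_t (I - x x^T) - eta * grad_W L.
import torch

torch.manual_seed(0)
d_in, d_out, eta = 4, 3, 0.1
W = torch.randn(d_out, d_in)
x = torch.randn(d_in)

W_ = W.clone().requires_grad_(True)
loss = 0.5 * ((W_ @ x) ** 2).sum()      # placeholder L(W_t; x_t), assumed
loss.backward()

I = torch.eye(d_in)
W_next = W @ (I - torch.outer(x, x)) - eta * W_.grad
```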
# 3 HOPE: A Self-Referential Learning Module with Continuum Memory

Existing architectural backbones consist of (1) a working memory module (e.g., attention), which is responsible for actively fusing information across the sequence length, and (2) a feed-forward layer (e.g., an MLP) that fuses information across features and acts as the persistent memory, or knowledge storage, of the pre-training phase. From the NL perspective, pre-training is the phase in which the outermost level of the learning module is updated over its limited context flow. Accordingly, in the continual setup, such a pre-training phase is rarely repeated over time, and so its corresponding knowledge storage needs to be updated only rarely. Given this intuition, we extend the traditional long-term/short-term memory viewpoint and suggest a knowledge-storage feed-forward block for each level (frequency domain).

Given the definition of frequency, the Continuum Memory System (CMS) is formalized as a chain of MLP blocks $\operatorname{MLP}^{(f_1)}(\cdot), \ldots, \operatorname{MLP}^{(f_k)}(\cdot)$, each associated with a chunk size $C^{(\ell)} := \frac{\max_{\ell} C^{(\ell)}}{f_{\ell}}$, such that given input $x = \{x_1, \ldots, x_T\}$ the output of the chain is calculated as (we disregard normalizations for the sake of clarity):

$$
y_t = \operatorname{MLP}^{(f_k)}\left(\operatorname{MLP}^{(f_{k-1})}\left(\cdots \operatorname{MLP}^{(f_1)}(x_t)\right)\right)
$$

where the parameters of the $\ell$-th MLP block, i.e., $\boldsymbol{\theta}^{(f_\ell)}$, are updated every $C^{(\ell)}$ steps:

$$
\boldsymbol{\theta}_{i+1}^{(f_\ell)} = \boldsymbol{\theta}_i^{(f_\ell)} - \begin{cases}
\sum_{t=i-C^{(\ell)}}^{i} \eta_t^{(\ell)} f\left(\boldsymbol{\theta}_t^{(f_\ell)}; x_t\right) & \text{if } i \equiv 0 \pmod{C^{(\ell)}} \\
0 & \text{otherwise}
\end{cases}
$$

In Appendix B.1, we discuss different variants of this formulation, including fully nested MLP layers. Here, $\eta_t^{(\ell)}$ are the learning rates corresponding to $\boldsymbol{\theta}^{(f_\ell)}$, and $f(\cdot)$ is the error component of an arbitrary optimizer (e.g., $\nabla \mathcal{L}(\boldsymbol{\theta}_t^{(f_\ell)}; x_t)$ in gradient descent). The conventional Transformer block [27] is a special instance of this formulation with $k = 1$. Notably, Equation 31 provides an important interpretation: the parameters $\boldsymbol{\theta}_t^{(f_\ell)}$ are responsible for compressing their own context into their parameters, and so they are a representation of the abstract knowledge of their context.
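As a rough illustration of the CMS update schedule in Equation 31, the sketch below runs a chain of three MLP blocks whose parameters are refreshed only at their own chunk boundaries, accumulating the error term in between. It assumes plain gradients for $f(\cdot)$, a placeholder objective, and made-up chunk sizes.

```python
# Sketch of a Continuum Memory System update schedule: a chain of MLP blocks
# where level l's parameters are refreshed only every C_l steps (Equation 31).
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 8
chunk_sizes = [1, 4, 16]                        # C^(l): fast to slow levels
blocks = nn.ModuleList(nn.Sequential(nn.Linear(d, d), nn.GELU())
                       for _ in chunk_sizes)
accum = [[torch.zeros_like(p) for p in b.parameters()] for b in blocks]
eta = 1e-2

for i in range(1, 33):                          # steps, one token each
    x = torch.randn(d)
    y = x
    for b in blocks:
        y = b(y)
    loss = (y ** 2).sum()                       # placeholder objective
    grads = torch.autograd.grad(
        loss, [p for b in blocks for p in b.parameters()])
    g_iter = iter(grads)
    for l, b in enumerate(blocks):
        for acc, p, g in zip(accum[l], b.parameters(), g_iter):
            acc += g                            # stash error until the boundary
            if i % chunk_sizes[l] == 0:         # level l's chunk boundary
                with torch.no_grad():
                    p -= eta * acc
                acc.zero_()
```

With chunk sizes [1, 4, 16], the first block behaves like an ordinary per-token layer, while the last is updated only twice over the 32 steps, mirroring the frequency hierarchy.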
HOPE. We further present a self-referential learning module based on Titans [28] and our variant of gradient descent in Section B.1. Combining this self-referential sequence model with the continuum memory system results in the HOPE architecture.

![img-2.jpeg](img-2.jpeg)

Figure 3: A comparison of the HOPE architectural backbone with Transformers (normalization and potential data-dependent components are removed for the sake of clarity).

Table 1: Performance of HOPE and baselines on language modeling and common-sense reasoning tasks. Hybrid models are marked with *.

| Model | Wiki. ppl $\downarrow$ | LMB. ppl $\downarrow$ | LMB. acc $\uparrow$ | PIQA acc $\uparrow$ | Hella. acc_n $\uparrow$ | Wino. acc $\uparrow$ | ARC-e acc $\uparrow$ | ARC-c acc_n $\uparrow$ | SIQA acc $\uparrow$ | BoolQ acc $\uparrow$ | Avg. $\uparrow$ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| **340M params / 15B tokens** | | | | | | | | | | | |
| HOPE (ours) | 26.05 | 29.38 | 35.40 | 64.62 | 40.11 | 51.19 | 56.92 | 28.49 | 38.33 | 60.12 | 46.90 |
| **760M params / 30B tokens** | | | | | | | | | | | |
| Transformer++ | 25.21 | 27.64 | 35.78 | 66.92 | 42.19 | 51.95 | 60.38 | 32.46 | 39.51 | 60.37 | 48.69 |
| RetNet | 26.08 | 24.45 | 34.51 | 67.19 | 41.63 | 52.09 | 63.17 | 32.78 | 38.36 | 57.92 | 48.46 |
| DeltaNet | 24.37 | 24.60 | 37.06 | 66.93 | 41.98 | 50.65 | 64.87 | 31.39 | 39.88 | 59.02 | 48.97 |
| TTT | 24.17 | 23.51 | 34.74 | 67.25 | 43.92 | 50.99 | 64.53 | 33.81 | 40.16 | 59.58 | 47.32 |
| Samba* | 20.63 | 22.71 | 39.72 | 69.19 | 47.35 | 52.01 | 66.92 | 33.20 | 38.98 | 61.24 | 51.08 |
| Titans (LMM) | 20.04 | 21.96 | 37.40 | 69.28 | 48.46 | 52.27 | 66.31 | 35.84 | 40.13 | 62.76 | 51.56 |
| HOPE (ours) | 20.53 | 20.47 | 39.02 | 70.13 | 49.21 | 52.70 | 66.89 | 36.05 | 40.71 | 63.29 | 52.26 |
| **1.3B params / 100B tokens** | | | | | | | | | | | |
| Transformer++ | 18.53 | 18.32 | 42.60 | 70.02 | 50.23 | 53.51 | 68.83 | 35.10 | 40.66 | 57.09 | 52.25 |
| RetNet | 19.08 | 17.27 | 40.52 | 70.07 | 49.16 | 54.14 | 67.34 | 33.78 | 40.78 | 60.39 | 52.02 |
| DeltaNet | 17.71 | 16.88 | 42.46 | 70.72 | 50.93 | 53.35 | 68.47 | 35.66 | 40.22 | 55.29 | 52.14 |
| Samba* | 16.13 | 13.29 | 44.94 | 70.94 | 53.42 | 55.56 | 68.81 | 36.17 | 39.96 | 62.11 | 54.00 |
| Titans (LMM) | 15.60 | 11.41 | 49.14 | 73.09 | 56.31 | 59.81 | 72.43 | 40.82 | 42.05 | 60.97 | 56.82 |
| HOPE (ours) | 15.11 | 11.63 | 50.01 | 73.29 | 56.84 | 60.19 | 72.30 | 41.24 | 42.52 | 61.46 | 57.23 |

# 4 Experiments

For the sake of space, in the main paper we report the results of HOPE's evaluation on language modeling and common-sense reasoning tasks. We report an extensive set of additional results, including experiments on optimizers, the emergence of in-context learning, the continual learning abilities of HOPE, ablation studies, long-context tasks, etc., in the appendix. Details about the experimental setups and the other datasets used are in Appendix G.

Language Modeling and Common-sense Reasoning. We follow recent sequence modeling studies [28, 67, 68] and report the results of HOPE and baselines at sizes of 340M, 760M, and 1.3B parameters on language modeling and commonsense reasoning downstream tasks. These results are reported in Table 1. HOPE demonstrates strong performance across all scales and benchmark tasks, outperforming both Transformers and recent modern recurrent neural networks, including Gated DeltaNet and Titans.
Comparing HOPE to Titans and Gated DeltaNet, we can see that dynamically changing the key, value, and query projections based on the context, as well as a deep memory module, can result in a model with lower perplexity and higher accuracy on benchmark tasks.

# References

[1] Ali Behrouz, Meisam Razaviyayn, Peilin Zhong, and Vahab Mirrokni. Nested learning: The illusion of deep learning architectures. arXiv preprint.
[2] Walter Pitts. The linear theory of neuron networks: The dynamic problem. The Bulletin of Mathematical Biophysics, 5:23-31, 1943.
[3] Warren S McCulloch. The brain computing machine. Electrical Engineering, 68(6):492-497, 1949.
[4] Warren S McCulloch and Walter Pitts. The statistical organization of nervous activity. Biometrics, 4(2):91-99, 1948.
[5] Arthur L Samuel. Some studies in machine learning using the game of checkers. IBM Journal of Research and Development, 3(3):210-229, 1959.
[6] David Silver and Richard S Sutton. Welcome to the era of experience. Google AI, 1, 2025.
[7] Richard S Sutton, Andrew G Barto, et al. Reinforcement learning: An introduction, volume 1. 1998.
[8] Jonathan H. Connell and Sridhar Mahadevan. Robot learning. Robotica, 17(2):229-235, 1999. doi: 10.1017/S0263574799271172.
[9] Yann LeCun, Yoshua Bengio, and Geoffrey Hinton. Deep learning. Nature, 521(7553):436-444, 2015.
[10] John Jumper, Richard Evans, Alexander Pritzel, Tim Green, Michael Figurnov, Olaf Ronneberger, Kathryn Tunyasuvunakool, Russ Bates, Augustin Žídek, Anna Potapenko, et al. Highly accurate protein structure prediction with AlphaFold. Nature, 596(7873):583-589, 2021.
[11] David Silver, Aja Huang, Chris J Maddison, Arthur Guez, Laurent Sifre, George Van Den Driessche, Julian Schrittwieser, Ioannis Antonoglou, Veda Panneershelvam, Marc Lanctot, et al. Mastering the game of Go with deep neural networks and tree search. Nature, 529(7587):484-489, 2016.
[12] David Silver, Thomas Hubert, Julian Schrittwieser, Ioannis Antonoglou, Matthew Lai, Arthur Guez, Marc Lanctot, Laurent Sifre, Dharshan Kumaran, Thore Graepel, et al. A general reinforcement learning algorithm that masters chess, shogi, and Go through self-play. Science, 362(6419):1140-1144, 2018.
[13] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. ImageNet classification with deep convolutional neural networks. Advances in Neural Information Processing Systems, 25, 2012.
[14] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=YicbFdNTTy.
[15] Gheorghe Comanici, Eric Bieber, Mike Schaekermann, Ice Pasupat, Noveen Sachdeva, Inderjit Dhillon, Marcel Blistein, Ori Ram, Dan Zhang, Evan Rosen, et al. Gemini 2.5: Pushing the frontier with advanced reasoning, multimodality, long context, and next generation agentic capabilities. arXiv preprint arXiv:2507.06261, 2025.
[16] Aixin Liu, Bei Feng, Bing Xue, Bingxuan Wang, Bochao Wu, Chengda Lu, Chenggang Zhao, Chengqi Deng, Chenyu Zhang, Chong Ruan, et al. DeepSeek-V3 technical report. arXiv preprint arXiv:2412.19437, 2024.
[17] Josh Achiam, Steven Adler, Sandhini Agarwal, Lama Ahmad, Ilge Akkaya, Florencia Leoni Aleman, Diogo Almeida, Janko Altenschmidt, Sam Altman, Shyamal Anadkat, et al. GPT-4 technical report. arXiv preprint arXiv:2303.08774, 2023.
[18] Guido Montúfar, Razvan Pascanu, Kyunghyun Cho, and Yoshua Bengio. On the number of linear regions of deep neural networks. Advances in Neural Information Processing Systems, 27, 2014.
[19] Ben Poole, Subhaneil Lahiri, Maithra Raghu, Jascha Sohl-Dickstein, and Surya Ganguli. Exponential expressivity in deep neural networks through transient chaos. Advances in Neural Information Processing Systems, 29, 2016.
[20] Joel Hestness, Sharan Narang, Newsha Ardalani, Gregory Diamos, Heewoo Jun, Hassan Kianinejad, Md Mostofa Ali Patwary, Yang Yang, and Yanqi Zhou. Deep learning scaling is predictable, empirically. arXiv preprint arXiv:1712.00409, 2017.
[21] William Merrill, Ashish Sabharwal, and Noah A Smith. Saturated transformers are constant-depth threshold circuits. Transactions of the Association for Computational Linguistics, 10:843-856, 2022.
[22] Clayton Sanford, Daniel Hsu, and Matus Telgarsky. Transformers, parallel computation, and logarithmic depth. In Forty-first International Conference on Machine Learning, 2024. URL https://openreview.net/forum?id=QCZabhKQhB.
[23] William Merrill, Jackson Petty, and Ashish Sabharwal. The illusion of state in state-space models. In Forty-first International Conference on Machine Learning, 2024. URL https://openreview.net/forum?id=QZgo9JZpLq.
[24] Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. Scaling laws for neural language models. arXiv preprint arXiv:2001.08361, 2020.
[25] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Computation, 9(8):1735-1780, 1997.
[26] Kunihiko Fukushima. Neocognitron: A self-organizing neural network model for a mechanism of pattern recognition unaffected by shift in position. Biological Cybernetics, 36(4):193-202, 1980.
[27] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In I. Guyon, U. Von Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017. URL https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf.
[28] Ali Behrouz, Peilin Zhong, and Vahab Mirrokni. Titans: Learning to memorize at test time. arXiv preprint arXiv:2501.00663, 2024.
[29] David E Rumelhart, Geoffrey E Hinton, and Ronald J Williams. Learning representations by back-propagating errors. Nature, 323(6088):533-536, 1986.
[30] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial networks. Communications of the ACM, 63(11):139-144, 2020.
[31] Shaden Alshammari, John Hershey, Axel Feldmann, William T Freeman, and Mark Hamilton. I-Con: A unifying framework for representation learning. arXiv preprint arXiv:2504.16929, 2025.
[32] R Devon Hjelm, Alex Fedorov, Samuel Lavoie-Marchildon, Karan Grewal, Phil Bachman, Adam Trischler, and Yoshua Bengio. Learning deep representations by mutual information estimation and maximization. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=Bklr3j0cKX.
[33] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
[34] K Jordan, Y Jin, V Boza, Y Jiacheng, F Cecista, L Newhouse, and J Bernstein. Muon: An optimizer for hidden layers in neural networks, 2024. URL https://kellerjordan.github.io/posts/muon.
[35] Vineet Gupta, Tomer Koren, and Yoram Singer. Shampoo: Preconditioned stochastic tensor optimization. In International Conference on Machine Learning, pages 1842-1850. PMLR, 2018.
[36] Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, and Sham M. Kakade. SOAP: Improving and stabilizing Shampoo using Adam for language modeling. In The Thirteenth International Conference on Learning Representations, 2025. URL https://openreview.net/forum?id=IDxZhXrpNf.
[37] Jordan Hoffmann, Sebastian Borgeaud, Arthur Mensch, Elena Buchatskaya, Trevor Cai, Eliza Rutherford, Diego de Las Casas, Lisa Anne Hendricks, Johannes Welbl, Aidan Clark, et al. Training compute-optimal large language models. arXiv preprint arXiv:2203.15556, 2022.
[38] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in Neural Information Processing Systems, 33:1877-1901, 2020.
[39] Rylan Schaeffer, Brando Miranda, and Sanmi Koyejo. Are emergent abilities of large language models a mirage? Advances in Neural Information Processing Systems, 36:55565-55581, 2023.
[40] Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, and Caiming Xiong. CodeGen: An open large language model for code with multi-turn program synthesis. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=iaYcJKpY2B.
[41] Wenhai Wang, Zhe Chen, Xiaokang Chen, Jiannan Wu, Xizhou Zhu, Gang Zeng, Ping Luo, Tong Lu, Jie Zhou, Yu Qiao, et al. VisionLLM: Large language model is also an open-ended decoder for vision-centric tasks. Advances in Neural Information Processing Systems, 36:61501-61513, 2023.
[42] Sabri Eyuboglu, Ryan Ehrlich, Simran Arora, Neel Guha, Dylan Zinsley, Emily Liu, Will Tennien, Atri Rudra, James Zou, Azalia Mirhoseini, et al. Cartridges: Lightweight and general-purpose long context representations via self-study. arXiv preprint arXiv:2506.06266, 2025.
[43] Hongzhou Yu, Tianhao Cheng, Yingwen Wang, Wen He, Qing Wang, Ying Cheng, Yuejie Zhang, Rui Feng, and Xiaobo Zhang. FineMedLM-o1: Enhancing medical knowledge reasoning ability of LLM from supervised fine-tuning to test-time training. In Second Conference on Language Modeling, 2025. URL https://openreview.net/forum?id=7ZwuGZCopw.
[44] Ekin Akyürek, Mehul Damani, Adam Zweiger, Linlu Qiu, Han Guo, Jyothish Pari, Yoon Kim, and Jacob Andreas. The surprising effectiveness of test-time training for few-shot learning. In Forty-second International Conference on Machine Learning, 2025.
[45] William Beecher Scoville and Brenda Milner. Loss of recent memory after bilateral hippocampal lesions. Journal of Neurology, Neurosurgery, and Psychiatry, 20(1):11, 1957.
[46] Alvaro Pascual-Leone, Amir Amedi, Felipe Fregni, and Lotfi B Merabet. The plastic human brain cortex. Annual Review of Neuroscience, 28(1):377-401, 2005.
[47] Michael V Johnston. Plasticity in the developing brain: implications for rehabilitation. Developmental Disabilities Research Reviews, 15(2):94-101, 2009.
[48] Akihiro Goto, Ayaka Bota, Ken Miya, Jingbo Wang, Suzune Tsukamoto, Xinzhi Jiang, Daichi Hirai, Masanori Murayama, Tomoki Matsuda, Thomas J. McHugh, Takeharu Nagai, and Yasunori Hayashi. Stepwise synaptic plasticity events drive the early phase of memory consolidation. Science, 374(6569):857-863, 2021. doi: 10.1126/science.abj9195. URL https://www.science.org/doi/abs/10.1126/science.abj9195.
[49] Uwe Frey and Richard GM Morris. Synaptic tagging and long-term potentiation. Nature, 385(6616):533-536, 1997.
[50] Wannan Yang, Chen Sun, Roman Huszár, Thomas Hainmueller, Kirill Kiselev, and György Buzsáki. Selection of experience for memory by hippocampal sharp wave ripples. Science, 383(6690):1478-1483, 2024.
[51] Daoyun Ji and Matthew A Wilson. Coordinated memory replay in the visual cortex and hippocampus during sleep. Nature Neuroscience, 10(1):100-107, 2007.
[52] Adrien Peyrache, Mehdi Khamassi, Karim Benchenane, Sidney I Wiener, and Francesco P Battaglia. Replay of rule-learning related neural patterns in the prefrontal cortex during sleep. Nature Neuroscience, 12(7):919-926, 2009.
[53] David J Foster and Matthew A Wilson. Reverse replay of behavioural sequences in hippocampal place cells during the awake state. Nature, 440(7084):680-683, 2006.
[54] Sean PA Drummond, Gregory G Brown, J Christian Gillin, John L Stricker, Eric C Wong, and Richard B Buxton. Altered brain response to verbal learning following sleep deprivation. Nature, 403(6770):655-657, 2000.
[55] Seung-Schik Yoo, Peter T Hu, Ninad Gujar, Ferenc A Jolesz, and Matthew P Walker. A deficit in the ability to form new human memories without sleep. Nature Neuroscience, 10(3):385-392, 2007.
[56] W Scott Terry. Learning and memory: Basic principles, processes, and procedures. Routledge, 2017.
[57] Hideyuki Okano, Tomoo Hirano, and Evan Balaban. Learning and memory. Proceedings of the National Academy of Sciences, 97(23):12403-12404, 2000.
[58] Ali Behrouz, Meisam Razaviyayn, Peilin Zhong, and Vahab Mirrokni. It's all connected: A journey through test-time memorization, attentional bias, retention, and online optimization. arXiv preprint arXiv:2504.13173, 2025.
[59] Bo Liu, Rui Wang, Lemeng Wu, Yihao Feng, Peter Stone, and Qiang Liu. Longhorn: State space models are amortized online learners. arXiv preprint arXiv:2407.14207, 2024.
[60] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are RNNs: Fast autoregressive transformers with linear attention. In International Conference on Machine Learning, pages 5156-5165. PMLR, 2020.
[61] Yutao Sun, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, and Furu Wei. Retentive network: A successor to Transformer for large language models. arXiv preprint arXiv:2307.08621, 2023.
[62] Jürgen Schmidhuber. Learning to control fast-weight memories: An alternative to dynamic recurrent networks. Neural Computation, 4(1):131-139, 1992.
[63] Imanol Schlag, Kazuki Irie, and Jürgen Schmidhuber. Linear transformers are secretly fast weight programmers. In International Conference on Machine Learning, pages 9355-9366. PMLR, 2021.
[64] DL Prados and SC Kak. Neural network capacity using delta rule. Electronics Letters, 25(3):197-199, 1989.
[65] Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, et al. Learning to (learn at test time): RNNs with expressive hidden states. arXiv preprint arXiv:2407.04620, 2024.
[66] Nicholas J Higham. Functions of Matrices: Theory and Computation. SIAM, 2008.
[67] Songlin Yang, Jan Kautz, and Ali Hatamizadeh. Gated Delta Networks: Improving Mamba2 with delta rule. arXiv preprint arXiv:2412.06464, 2024.
[68] Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, and Yoon Kim. Parallelizing linear transformers with the delta rule over sequence length. Advances in Neural Information Processing Systems, 37:115491-115522, 2024.
[69] Matteo Tiezzi, Michele Casoni, Alessandro Betti, Tommaso Guidi, Marco Gori, and Stefano Melacci. On the resurgence of recurrent models for long sequences: Survey and research opportunities in the Transformer era. arXiv preprint arXiv:2402.08132, 2024.
[70] Bo Peng, Eric Alcaide, Quentin Gregory Anthony, Alon Albalak, Samuel Arcadinho, Stella Biderman, Huanqi Cao, Xin Cheng, Michael Nguyen Chung, Leon Derczynski, Xingjian Du, Matteo Grella, Kranthi Kiran GV, Xuzheng He, Haowen Hou, Przemyslaw Kazienko, Jan Kocon, Jiaming Kong, Bartłomiej Koptyra, Hayden Lau, Jiaju Lin, Krishna Sri Ipsit Mantri, Ferdinand Mom, Atsushi Saito, Guangyu Song, Xiangru Tang, Johan S. Wind, Stanisław Wozniak, Zhenyuan Zhang, Qinghua Zhou, Jian Zhu, and Rui-Jie Zhu. RWKV: Reinventing RNNs for the Transformer era. In The 2023 Conference on Empirical Methods in Natural Language Processing, 2023. URL https://openreview.net/forum?id=7SaXczaBpG.
[71] Jimmy T.H. Smith, Andrew Warrington, and Scott Linderman. Simplified state space layers for sequence modeling. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=Ai8Hw3AXqks.
[72] Ramin Hasani, Mathias Lechner, Tsun-Hsuan Wang, Makram Chahine, Alexander Amini, and Daniela Rus. Liquid structural state-space models. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=g40TKRKfS7R.
[73] Ali Behrouz, Michele Santacatterina, and Ramin Zabih. MambaMixer: Efficient selective state space models with dual token and channel selection. arXiv preprint arXiv:2403.19888, 2024.
[74] Bo Peng, Daniel Goldstein, Quentin Anthony, Alon Albalak, Eric Alcaide, Stella Biderman, Eugene Cheah, Xingjian Du, Teddy Ferdinan, Haowen Hou, et al. Eagle and Finch: RWKV with matrix-valued states and dynamic recurrence. arXiv preprint arXiv:2404.05892, 2024.
[75] Bo Peng, Ruichong Zhang, Daniel Goldstein, Eric Alcaide, Haowen Hou, Janna Lu, William Merrill, Guangyu Song, Kaifeng Tan, Saiteja Utpala, et al. RWKV-7 "Goose" with expressive dynamic state evolution. arXiv preprint arXiv:2503.14456, 2025.
[76] Julien Siems, Timur Carstensen, Arber Zela, Frank Hutter, Massimiliano Pontil, and Riccardo Grazzi. DeltaProduct: Increasing the expressivity of DeltaNet through products of Householders. arXiv preprint arXiv:2502.10297, 2025.
[77] John J Hopfield. Neural networks and physical systems with emergent collective computational abilities. Proceedings of the National Academy of Sciences, 79(8):2554-2558, 1982.
[78] Jürgen Schmidhuber. Reducing the ratio between learning complexity and number of time varying variables in fully recurrent nets. In ICANN'93: Proceedings of the International Conference on Artificial Neural Networks, Amsterdam, The Netherlands, 13-16 September 1993, pages 460-463. Springer, 1993.
[79] Donald Olding Hebb. The organization of behavior: A neuropsychological theory. Psychology Press, 2005.
[80] Tsendsuren Munkhdalai and Hong Yu. Neural semantic encoders. In Proceedings of the conference. Association for Computational Linguistics. Meeting, volume 1, page 397. NIH Public Access, 2017.
[81] Tsendsuren Munkhdalai, Alessandro Sordoni, Tong Wang, and Adam Trischler. Metalearned neural memory. Advances in Neural Information Processing Systems, 32, 2019.
[82] Kazuki Irie, Imanol Schlag, Robert Csordas, and Jürgen Schmidhuber. Going beyond linear transformers with recurrent fast weight programmers. Advances in Neural Information Processing Systems, 34:7703-7717, 2021.
[83] Ke Alexander Wang, Jiaxin Shi, and Emily B Fox. Test-time regression: a unifying framework for designing sequence models with associative memory. arXiv preprint arXiv:2501.12352, 2025.
[84] Kazuki Irie, Robert Csordas, and Jürgen Schmidhuber. The dual form of neural networks revisited: Connecting test time predictions to training patterns via spotlights of attention. In International Conference on Machine Learning, pages 9639-9659. PMLR, 2022.
[85] Kazuki Irie, Imanol Schlag, Róbert Csordás, and Jürgen Schmidhuber. A modern self-referential weight matrix that learns to modify itself. In International Conference on Machine Learning, pages 9660-9677. PMLR, 2022.
[86] Jongho Park, Jaeseung Park, Zheyang Xiong, Nayoung Lee, Jaewoong Cho, Samet Oymak, Kangwook Lee, and Dimitris Papailiopoulos. Can Mamba learn how to learn? A comparative study on in-context learning tasks. In Forty-first International Conference on Machine Learning, 2024. URL https://openreview.net/forum?id=GbFluKMmtE.
[87] Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models. In International Conference on Learning Representations, 2017. URL https://openreview.net/forum?id=Byj72udxe.
[88] Denis Paperno, German Kruszewski, Angeliki Lazaridou, Ngoc Quan Pham, Raffaella Bernardi, Sandro Pezzelle, Marco Baroni, Gemma Boleda, and Raquel Fernandez. The LAMBADA dataset: Word prediction requiring a broad discourse context. In Katrin Erk and Noah A. Smith, editors, Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1525-1534, Berlin, Germany, August 2016. Association for Computational Linguistics. doi: 10.18653/v1/P16-1144. URL https://aclanthology.org/P16-1144/.
[89] Yonatan Bisk, Rowan Zellers, Jianfeng Gao, Yejin Choi, et al. PIQA: Reasoning about physical commonsense in natural language. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pages 7432-7439, 2020.
[90] Rowan Zellers, Ari Holtzman, Yonatan Bisk, Ali Farhadi, and Yejin Choi. HellaSwag: Can a machine really finish your sentence? In Anna Korhonen, David Traum, and Lluis Marquez, editors, Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 4791-4800, Florence, Italy, July 2019. Association for Computational Linguistics. doi: 10.18653/v1/P19-1472. URL https://aclanthology.org/P19-1472/.
[91] Keisuke Sakaguchi, Ronan Le Bras, Chandra Bhagavatula, and Yejin Choi. WinoGrande: An adversarial Winograd schema challenge at scale. Communications of the ACM, 64(9):99-106, 2021.
[92] Peter Clark, Isaac Cowhey, Oren Etzioni, Tushar Khot, Ashish Sabharwal, Carissa Schoenick, and Oyvind Tafjord. Think you have solved question answering? Try ARC, the AI2 Reasoning Challenge. arXiv preprint arXiv:1803.05457, 2018.
[93] Maarten Sap, Hannah Rashkin, Derek Chen, Ronan Le Bras, and Yejin Choi. Social IQa: Commonsense reasoning about social interactions. In Kentaro Inui, Jing Jiang, Vincent Ng, and Xiaojun Wan, editors, Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pages 4463-4473, Hong Kong, China, November 2019. Association for Computational Linguistics. doi: 10.18653/v1/D19-1454. URL https://aclanthology.org/D19-1454/.
[94] Christopher Clark, Kenton Lee, Ming-Wei Chang, Tom Kwiatkowski, Michael Collins, and Kristina Toutanova. BoolQ: Exploring the surprising difficulty of natural yes/no questions. In Jill Burstein, Christy Doran, and Thamar Solorio, editors, Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 2924-2936, Minneapolis, Minnesota, June 2019. Association for Computational Linguistics. doi: 10.18653/v1/N19-1300. URL https://aclanthology.org/N19-1300/.
[95] Michael Poli, Armin W Thomas, Eric Nguyen, Pragaash Ponnusamy, Björn Deiseroth, Kristian Kersting, Taiji Suzuki, Brian Hie, Stefano Ermon, Christopher Ré, et al. Mechanistic design and scaling of hybrid architectures. arXiv preprint arXiv:2403.17844, 2024.

# NeurIPS Paper Checklist
## 1. Claims

Question: Do the main claims made in the abstract and introduction accurately reflect the paper's contributions and scope?
Answer: [Yes]
Justification: The abstract and introduction both accurately state the main theoretical contributions of the paper and accurately discuss our experimental results.
Guidelines:

- The answer NA means that the abstract and introduction do not include the claims made in the paper.
- The abstract and/or introduction should clearly state the claims made, including the contributions made in the paper and important assumptions and limitations. A No or NA answer to this question will not be perceived well by the reviewers.
- The claims made should match theoretical and experimental results, and reflect how much the results can be expected to generalize to other settings.
- It is fine to include aspirational goals as motivation as long as it is clear that these goals are not attained by the paper.

## 2. Limitations

Question: Does the paper discuss the limitations of the work performed by the authors?
Answer: [Yes]
Justification: We discuss the limitations of our method in the main text, including its scope, assumptions, and potential areas for future improvement.
Guidelines:

- The answer NA means that the paper has no limitation, while the answer No means that the paper has limitations but they are not discussed in the paper.
- The authors are encouraged to create a separate "Limitations" section in their paper.
- The paper should point out any strong assumptions and how robust the results are to violations of these assumptions (e.g., independence assumptions, noiseless settings, model well-specification, asymptotic approximations only holding locally). The authors should reflect on how these assumptions might be violated in practice and what the implications would be.
- The authors should reflect on the scope of the claims made, e.g., if the approach was only tested on a few datasets or with a few runs. In general, empirical results often depend on implicit assumptions, which should be articulated.
- The authors should reflect on the factors that influence the performance of the approach. For example, a facial recognition algorithm may perform poorly when image resolution is low or images are taken in low lighting. Or a speech-to-text system might not be used reliably to provide closed captions for online lectures because it fails to handle technical jargon.
- The authors should discuss the computational efficiency of the proposed algorithms and how they scale with dataset size.
- If applicable, the authors should discuss possible limitations of their approach to address problems of privacy and fairness.
- While the authors might fear that complete honesty about limitations might be used by reviewers as grounds for rejection, a worse outcome might be that reviewers discover limitations that aren't acknowledged in the paper. The authors should use their best judgment and recognize that individual actions in favor of transparency play an important role in developing norms that preserve the integrity of the community. Reviewers will be specifically instructed to not penalize honesty concerning limitations.

## 3. Theory assumptions and proofs

Question: For each theoretical result, does the paper provide the full set of assumptions and a complete (and correct) proof?
Answer: [Yes]
Justification: All assumptions are stated directly in the statements of the theorems. As much of the proofs as possible is provided in the main paper text, and we provide detailed proofs of all lemmas and theorems introduced in this work.
Guidelines:

- The answer NA means that the paper does not include theoretical results.
- All the theorems, formulas, and proofs in the paper should be numbered and cross-referenced.
- All assumptions should be clearly stated or referenced in the statement of any theorems.
- The proofs can either appear in the main paper or the supplemental material, but if they appear in the supplemental material, the authors are encouraged to provide a short proof sketch to provide intuition.
- Inversely, any informal proof provided in the core of the paper should be complemented by formal proofs provided in appendix or supplemental material.
- Theorems and Lemmas that the proof relies upon should be properly referenced.

## 4. Experimental result reproducibility

Question: Does the paper fully disclose all the information needed to reproduce the main experimental results of the paper to the extent that it affects the main claims and/or conclusions of the paper (regardless of whether the code and data are provided or not)?
Answer: [Yes]
Justification: This paper fully discloses all the information needed to reproduce the main experimental results.
Guidelines:

- The answer NA means that the paper does not include experiments.
- If the paper includes experiments, a No answer to this question will not be perceived well by the reviewers: Making the paper reproducible is important, regardless of whether the code and data are provided or not.
- If the contribution is a dataset and/or model, the authors should describe the steps taken to make their results reproducible or verifiable.
- Depending on the contribution, reproducibility can be accomplished in various ways. For example, if the contribution is a novel architecture, describing the architecture fully might suffice, or if the contribution is a specific model and empirical evaluation, it may be necessary to either make it possible for others to replicate the model with the same dataset, or provide access to the model. In general, releasing code and data is often one good way to accomplish this, but reproducibility can also be provided via detailed instructions for how to replicate the results, access to a hosted model (e.g., in the case of a large language model), releasing of a model checkpoint, or other means that are appropriate to the research performed.
- While NeurIPS does not require releasing code, the conference does require all submissions to provide some reasonable avenue for reproducibility, which may depend on the nature of the contribution. For example:
  (a) If the contribution is primarily a new algorithm, the paper should make it clear how to reproduce that algorithm.
  (b) If the contribution is primarily a new model architecture, the paper should describe the architecture clearly and fully.
  (c) If the contribution is a new model (e.g., a large language model), then there should either be a way to access this model for reproducing the results or a way to reproduce the model (e.g., with an open-source dataset or instructions for how to construct the dataset).
  (d) We recognize that reproducibility may be tricky in some cases, in which case authors are welcome to describe the particular way they provide for reproducibility. In the case of closed-source models, it may be that access to the model is limited in some way (e.g., to registered users), but it should be possible for other researchers to have some path to reproducing or verifying the results.

## 5. Open access to data and code

Question: Does the paper provide open access to the data and code, with sufficient instructions to faithfully reproduce the main experimental results, as described in supplemental material?
Answer: [Yes]
Justification: We plan to provide data and code via GitHub after the paper is made publicly available online.
Guidelines:

- The answer NA means that the paper does not include experiments requiring code.
- Please see the NeurIPS code and data submission guidelines (https://nips.cc/public/guides/CodeSubmissionPolicy) for more details.
- While we encourage the release of code and data, we understand that this might not be possible, so "No" is an acceptable answer. Papers cannot be rejected simply for not including code, unless this is central to the contribution (e.g., for a new open-source benchmark).
- The instructions should contain the exact command and environment needed to run to reproduce the results. See the NeurIPS code and data submission guidelines (https://nips.cc/public/guides/CodeSubmissionPolicy) for more details.
- The authors should provide instructions on data access and preparation, including how to access the raw data, preprocessed data, intermediate data, and generated data, etc.
- The authors should provide scripts to reproduce all experimental results for the new proposed method and baselines. If only a subset of experiments is reproducible, they should state which ones are omitted from the script and why.
- At submission time, to preserve anonymity, the authors should release anonymized versions (if applicable).
- Providing as much information as possible in supplemental material (appended to the paper) is recommended, but including URLs to data and code is permitted.

## 6. Experimental setting/details

Question: Does the paper specify all the training and test details (e.g., data splits, hyperparameters, how they were chosen, type of optimizer, etc.) necessary to understand the results?
Answer: [Yes]
Justification: Detailed descriptions of the experiments and the associated code are provided in the supplementary materials.
Guidelines:

- The answer NA means that the paper does not include experiments.
- The experimental setting should be presented in the core of the paper to a level of detail that is necessary to appreciate the results and make sense of them.
- The full details can be provided either with the code, in appendix, or as supplemental material.

## 7. Experiment statistical significance

Question: Does the paper report error bars suitably and correctly defined or other appropriate information about the statistical significance of the experiments?
Answer: [Yes]
Justification: To account for variability and support statistical validity, we report standard deviations computed over multiple runs.
Guidelines:

- The answer NA means that the paper does not include experiments.
- The authors should answer "Yes" if the results are accompanied by error bars, confidence intervals, or statistical significance tests, at least for the experiments that support the main claims of the paper.
- The factors of variability that the error bars are capturing should be clearly stated (for example, train/test split, initialization, random drawing of some parameter, or overall run with given experimental conditions).
- The method for calculating the error bars should be explained (closed form formula, call to a library function, bootstrap, etc.).
- The assumptions made should be given (e.g., Normally distributed errors).
- It should be clear whether the error bar is the standard deviation or the standard error of the mean.
- It is OK to report 1-sigma error bars, but one should state it. The authors should preferably report a 2-sigma error bar rather than state that they have a 96% CI, if the hypothesis of Normality of errors is not verified.
- For asymmetric distributions, the authors should be careful not to show in tables or figures symmetric error bars that would yield results that are out of range (e.g., negative error rates).
- If error bars are reported in tables or plots, the authors should explain in the text how they were calculated and reference the corresponding figures or tables in the text.

## 8. Experiments compute resources

Question: For each experiment, does the paper provide sufficient information on the computer resources (type of compute workers, memory, time of execution) needed to reproduce the experiments?
Answer: [Yes]
Justification: To facilitate reproducibility, we report accelerator type, memory capacity, and runtime for key experiments.
Guidelines:

- The answer NA means that the paper does not include experiments.
- The paper should indicate the type of compute workers (CPU or GPU, internal cluster, or cloud provider), including relevant memory and storage.
- The paper should provide the amount of compute required for each of the individual experimental runs as well as estimate the total compute.
- The paper should disclose whether the full research project required more compute than the experiments reported in the paper (e.g., preliminary or failed experiments that didn't make it into the paper).

## 9. Code of ethics

Question: Does the research conducted in the paper conform, in every respect, with the NeurIPS Code of Ethics (https://neurips.cc/public/EthicsGuidelines)?
Answer: [Yes]
Justification: There are no violations of the NeurIPS Code of Ethics in this paper.
Guidelines:

- The answer NA means that the authors have not reviewed the NeurIPS Code of Ethics.
- If the authors answer No, they should explain the special circumstances that require a deviation from the Code of Ethics.
- The authors should make sure to preserve anonymity (e.g., if there is a special consideration due to laws or regulations in their jurisdiction).

## 10. Broader impacts

Question: Does the paper discuss both potential positive societal impacts and negative societal impacts of the work performed?
Answer: [Yes]
Justification: We provide a discussion of the impact of the work in the Appendix.
Guidelines:

- The answer NA means that there is no societal impact of the work performed.
- If the authors answer NA or No, they should explain why their work has no societal impact or why the paper does not address societal impact.
- Examples of negative societal impacts include potential malicious or unintended uses (e.g., disinformation, generating fake profiles, surveillance), fairness considerations (e.g., deployment of technologies that could make decisions that unfairly impact specific groups), privacy considerations, and security considerations.
- The conference expects that many papers will be foundational research and not tied to particular applications, let alone deployments. However, if there is a direct path to any negative applications, the authors should point it out. For example, it is legitimate to point out that an improvement in the quality of generative models could be used to generate deepfakes for disinformation. On the other hand, it is not needed to point out that a generic algorithm for optimizing neural networks could enable people to train models that generate deepfakes faster.
- The authors should consider possible harms that could arise when the technology is being used as intended and functioning correctly, harms that could arise when the technology is being used as intended but gives incorrect results, and harms following from (intentional or unintentional) misuse of the technology.
- If there are negative societal impacts, the authors could also discuss possible mitigation strategies (e.g., gated release of models, providing defenses in addition to attacks, mechanisms for monitoring misuse, mechanisms to monitor how a system learns from feedback over time, improving the efficiency and accessibility of ML).

## 11. Safeguards

Question: Does the paper describe safeguards that have been put in place for responsible release of data or models that have a high risk for misuse (e.g., pretrained language models, image generators, or scraped datasets)?
Answer: [NA]
Justification: This paper poses no such risks.
Guidelines:

- The answer NA means that the paper poses no such risks.
- Released models that have a high risk for misuse or dual-use should be released with necessary safeguards to allow for controlled use of the model, for example by requiring that users adhere to usage guidelines or restrictions to access the model or implementing safety filters.
- Datasets that have been scraped from the Internet could pose safety risks. The authors should describe how they avoided releasing unsafe images.
- We recognize that providing effective safeguards is challenging, and many papers do not require this, but we encourage authors to take this into account and make a best-faith effort.

## 12. Licenses for existing assets

Question: Are the creators or original owners of assets (e.g., code, data, models) used in the paper properly credited, and are the license and terms of use explicitly mentioned and properly respected?
Answer: [Yes]
Justification: We properly credit all external assets and explicitly acknowledge their licenses and terms of use.
Guidelines:

- The answer NA means that the paper does not use existing assets.
- The authors should cite the original paper that produced the code package or dataset.
- The authors should state which version of the asset is used and, if possible, include a URL.
- The name of the license (e.g., CC-BY 4.0) should be included for each asset.
- For scraped data from a particular source (e.g., website), the copyright and terms of service of that source should be provided.
- If assets are released, the license, copyright information, and terms of use in the package should be provided. For popular datasets, paperswithcode.com/datasets has curated licenses for some datasets. Their licensing guide can help determine the license of a dataset.
- For existing datasets that are re-packaged, both the original license and the license of the derived asset (if it has changed) should be provided.
- If this information is not available online, the authors are encouraged to reach out to the asset's creators.

## 13. New assets

Question: Are new assets introduced in the paper well documented and is the documentation provided alongside the assets?
Answer: [NA]
Justification: This paper does not introduce new assets.
Guidelines:

- The answer NA means that the paper does not release new assets.
- Researchers should communicate the details of the dataset/code/model as part of their submissions via structured templates. This includes details about training, license, limitations, etc.
- The paper should discuss whether and how consent was obtained from people whose asset is used.
- At submission time, remember to anonymize your assets (if applicable). You can either create an anonymized URL or include an anonymized zip file.

## 14.
Crowdsourcing and research with human subjects\\n\\nQuestion: For crowdsourcing experiments and research with human subjects, does the paper include the full text of instructions given to participants and screenshots, if applicable, as well as details about compensation (if any)?\\nAnswer: [NA]\\nJustification: This paper does not involve such experiments.\\nGuidelines:\\n\\n- The answer NA means that the paper does not involve crowdsourcing nor research with human subjects.\\n- Including this information in the supplemental material is fine, but if the main contribution of the paper involves human subjects, then as much detail as possible should be included in the main paper.\\n- According to the NeurIPS Code of Ethics, workers involved in data collection, curation, or other labor should be paid at least the minimum wage in the country of the data collector.\\n\\n\\n## 15. Institutional review board (IRB) approvals or equivalent for research with human subjects\\n\\nQuestion: Does the paper describe potential risks incurred by study participants, whether such risks were disclosed to the subjects, and whether Institutional Review Board (IRB) approvals (or an equivalent approval/review based on the requirements of your country or institution) were obtained?\\nAnswer: [NA]\\nJustification: This paper does not involve any user studies.\\nGuidelines:\\n\\n- The answer NA means that the paper does not involve crowdsourcing nor research with human subjects.\\n- Depending on the country in which research is conducted, IRB approval (or equivalent) may be required for any human subjects research. If you obtained IRB approval, you should clearly state this in the paper.\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 22,\r\n      \"markdown\": \"- We recognize that the procedures for this may vary significantly between institutions and locations, and we expect authors to adhere to the NeurIPS Code of Ethics and the guidelines for their institution.\\n- For initial submissions, do not include any information that would break anonymity (if applicable), such as the institution conducting the review.\\n\\n\\n# 16. Declaration of LLM usage \\n\\nQuestion: Does the paper describe the usage of LLMs if it is an important, original, or non-standard component of the core methods in this research? Note that if the LLM is used only for writing, editing, or formatting purposes and does not impact the core methodology, scientific rigorousness, or originality of the research, declaration is not required.\\nAnswer: [NA]\\nJustification: No usage of LLMs in the methodology or experiments.\\nGuidelines:\\n\\n- The answer NA means that the core method development in this research does not involve LLMs as any important, original, or non-standard components.\\n- Please refer to our LLM policy (https://neurips.cc/Conferences/2025/LLM) for what should or should not be described.\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    }\r\n  ],\r\n  \"model\": \"mistral-ocr-2505-completion\",\r\n  \"document_annotation\": null,\r\n  \"usage_info\": {\r\n    \"pages_processed\": 23,\r\n    \"doc_size_bytes\": 2492847\r\n  }\r\n}"
  },
  {
    "path": "google_papers/Nested_Learning/Nested_Learning.md",
    "content": "PAGE 1\r\n# Nested Learning: The Illusion of Deep Learning Architectures \r\n\r\nAli Behrouz<br>Google Research<br>USA<br>alibehrouz@google.com\r\n\r\nMeisam Razaviyayn<br>Google Research<br>USA<br>rezavyayn@google.com\r\n\r\nPeiling Zhong<br>Google Research<br>USA<br>peilinz@google.com<br>Vahab Mirrokni<br>Google Research<br>USA<br>mirrokni@google.com\r\n\r\n## Abstract\r\n\r\nOver the last decades, developing more powerful neural architectures and simultaneously designing optimization algorithms to effectively train them have been the core of research efforts to enhance the capability of machine learning models. Despite the recent progresses, particularly in developing Language Models (LMs), there are fundamental challenges and unanswered questions about how such models can continually learn/memorize, self-improved, and find \"effective solutions,\". In this paper, we present a new learning paradigm, called Nested Learning (NL), that coherently represents a model with a set of nested, multi-level, and/or parallel optimization problems, each of which with its own \"context flow\". NL reveals that existing deep learning methods learns from data through compressing their own context flow, and explain how in-context learning emerges in large models. NL suggests a path (a new dimension to deep learning) to design more expressive learning algorithms with more \"levels\", resulting in higher-order in-context learning abilities. In addition to its neuroscientifically plausible and mathematically white-box nature, we advocate for its importance by presenting three core contributions: (1) Deep Optimizers: Based on NL, we show that well-known gradient-based optimizers (e.g., Adam, SGD with Momentum, etc.) are in fact associative memory modules that aim to compress the gradients with gradient descent. Building on this insight, we present a set of more expressive optimizers with deep memory and/or more powerful learning rules; (2) Self-Modifying Titans: Taking advantage of NL's insights on learning algorithms, we present a novel sequence model that learns how to modify itself by learning its own update algorithm; and (3) Continuum Memory System: We present a new formulation for memory system that generalizes the traditional viewpoint of \"long-term/short-term memory\". Combining our self-modifying sequence model with the continuum memory system, we present a learning module, called HOPE, showing promising results in language modeling, continual learning, and long-context reasoning tasks.\r\n\r\n## 1 Introduction\r\n\r\nThis version of the paper has been extensively summarized to fit the page limit of NeurIPS camera ready, and some materials, experiments, discussions, and methods are moved to appendix, which might make some parts hard to follow or cause inconsistencies. To avoid such cases, please read our arXiv version instead [1].\r\nPAGE 2\r\n![img-0.jpeg](img-0.jpeg)\r\n\r\nFigure 1: The uniform and reusable structure as well as multi time scale update in the brain are the key components to unlock the continual learning in humans. 
Nested Learning (NL) allows for multi-time-scale updates for each component of the brain, while showing that well-known architectures such as Transformers are in fact linear layers with different update frequencies.\r\n\r\nFor decades, AI research has focused on designing machine learning algorithms that learn from data [2–5] or experience [6–8], often by optimizing an objective $\\mathcal{L}(\\boldsymbol{\\theta})$ over parameters $\\boldsymbol{\\theta} \\in \\Theta$ with gradient-based methods. While traditional machine learning techniques required careful engineering and domain expertise to design feature extractors, limiting their ability to directly process and learn from natural data [9], deep representation learning offered a fully automated alternative to discover the representations needed for the task. Thereafter, deep learning has been an inseparable part of large-scale computational models with seminal success in chemistry and biology [10], games [11, 12], computer vision [13, 14], and multimodal and natural language understanding [15–17].\r\n\r\nStacking multiple layers, as is done in deep learning models, provides the models with larger capacity, better expressive power in representing complex features, and more internal computations (e.g., #FLOPS) [18–20], all of which are critical and desirable characteristics for static tasks that require in-distribution predictions over a previously fixed set. This deep design, however, is not a universal solution to all the challenges and cannot help the expressive power of the models in multiple aspects, for example: (i) The computational depth of deep models might not change with more layers [21, 22], leaving their ability to implement complex algorithms untouched compared to traditional shallow approaches [23]; (ii) The capacity of some classes of parameters might show marginal improvement with increasing the depth/width of the model [24]; (iii) The training process might converge to a suboptimal solution, mainly due to the suboptimal choice of the optimizer or its hyperparameters; and (iv) The model's ability to quickly adapt to a new task, continually learn, and/or generalize to out-of-distribution data might not change with stacking more layers and requires more careful design.\r\n\r\nThe core part of the efforts to overcome the above challenges and to enhance the capability of deep learning models concentrates on: (1) developing more expressive classes of parameters (i.e., neural architectures) [13, 25–28]; (2) introducing objectives that can better model the tasks [29–32]; (3) designing more efficient/effective optimization algorithms to find better solutions or with more resilience to forgetting [33–36]; and (4) scaling the model size to enhance its expressivity, when the \"right\" choices of architecture, objective, and optimization algorithm are made [24, 37, 38]. Collectively, these advancements and new findings on scaling patterns of deep models have established the foundations upon which Large Language Models (LLMs) have been built.\r\n\r\nThe development of LLMs marks a pivotal milestone in deep learning research: a paradigm shift from task-specific models to more general-purpose systems with various emergent capabilities as a result of scaling the \"right\" architectures [38, 39]. 
Despite all their success and remarkable capabilities in diverse sets of tasks [15, 40, 41], LLMs are largely static after their initial deployment phase, meaning that they successfully perform tasks learned during pre- or post-training, but are unable to continually acquire new capabilities beyond their immediate context. The only adaptable component of LLMs is their *in-context learning* ability–a (known to be emergent) characteristic of LLMs that enables fast adaptation to the context and thus zero- or few-shot task performance [38]. Beyond in-context learning, recent efforts to overcome the static nature of LLMs are either computationally expensive, require external components, lack generalization, and/or suffer from catastrophic forgetting [42–44], which has led researchers to question whether there is a need to revisit how to design machine learning\r\nPAGE 3\r\nmodels and whether a new learning paradigm beyond stacking of layers is required to unleash the capabilities of LLMs in continual setups.\r\n\r\nCurrent Models only Experience the Immediate Present. As an analogy and to better illustrate the static nature of LLMs, we use the example of anterograde amnesia-a neurological condition where a person cannot form new long-term memories after the onset of the disorder, while existing memories remain intact [45]. This condition limits the person's knowledge and experiences to a short window of present and long past-before the onset of the disorder-which results in continuously experiencing the immediate present as if it were always new. The memory processing system of current LLMs suffers from a similar pattern. Their knowledge is limited to either the immediate context that fits into their context window, or the knowledge in MLP layers that stores the long past, before the onset of the \"end of pre-training.\" This analogy has motivated us to take inspiration from the neurophysiology literature and how the brain consolidates its short-term memories:\r\n\r\n# 1.1 Human Brain Perspective and Neurophysiological Motivation \r\n\r\nThe human brain is highly efficient and effective when it comes to continual learning (a.k.a. effective context management), which is often attributed to neuroplasticity-the brain's remarkable capacity to change itself in response to new experiences, memories, learning, and even damage [46, 47]. Recent studies support that the formation of long-term memory involves at least two distinct but complementary consolidation processes [48-50]: (1) A rapid \"online\" consolidation (also known as synaptic consolidation) phase occurs immediately or soon after learning, even during wakefulness. This is when new and initially fragile memory traces are stabilized and begin transferring from short-term to long-term storage; (2) An \"offline\" consolidation (also known as systems consolidation) process repeatedly replays the recently encoded patterns-during sharp-wave ripples (SWRs) in the hippocampus, coordinated with cortical sleep spindles and slow oscillations-which strengthens and reorganizes the memory and supports transfer to cortical sites [51-53].\r\nComing back to the analogy of anterograde amnesia, evidence indicates that the condition can impact both stages, but especially the online consolidation phase, mainly because the hippocampus is the gateway for encoding new declarative memories, and so its damage means new information will never be stored in long-term memory. 
As mentioned above, the design of LLMs, and more specifically Transformer-based backbones, suffers from a similar condition after the pre-training phase. That is, the information provided in the context never impacts the long-term memory parameters (e.g., feedforward layers), and so the model is not capable of acquiring new knowledge or skills, unless the information is still stored in the short-term memory (e.g., attention). To this end, although the second stage is equally, or even more, crucial for the consolidation of memories, and its absence can damage the process and might cause loss of memory [54, 55], in this work, we focus on the first stage: memory consolidation as an online process. We provide additional discussion on the human brain perspective and its connection to NL in Appendix A.\r\n\r\nNotations. We let $x \\in \\mathbb{R}^{N \\times d_{h}}$ be the input, $\\mathcal{M}_{t}$ represent the state of memory/model $\\mathcal{M}$ at time $t$, $\\mathbf{K}$ be the keys, $\\mathbf{V}$ be the values, and $\\mathbf{Q}$ be the query matrices. We use bold lowercase letters with subscript $t$ to refer to the vector corresponding to input $t$ (i.e., $\\mathbf{k}_{t}, \\mathbf{v}_{t}$, and $\\mathbf{q}_{t}$ ). We further refer to the distribution of any entity $f$ as $p(f)$. Throughout the paper, we use simple MLPs with $\\mathcal{L}_{\\mathcal{M}} \\geq 1$ layers and residual connections as the architecture of the memory module $\\mathcal{M}(\\cdot)$. When needed, we parameterize the memory module with $\\boldsymbol{\\theta}_{\\mathcal{M}} \\supseteq\\left\\{W_{1}, W_{2}, \\ldots, W_{\\mathcal{L}_{\\mathcal{M}}}\\right\\}$, which at least includes the parameters of the linear layers in the MLP. We use superscripts with parentheses to refer to parameters in different levels of nested learning (different update frequencies): i.e., $W^{(\\ell)}$.\r\n\r\n## 2 Nested Learning\r\n\r\nThis section discusses the motivations, formal definitions, and general high-level implications of Nested Learning (NL). We start with a formulation of associative memory, and then, using step-by-step examples, we build the intuition behind architecture decomposition and its connection to modeling a neural network as an integrated system of optimization problems. We aim to first show how existing methods and concepts in deep learning fall under the NL paradigm, and then we present new formulations that go beyond traditional methods and/or provide insights on how to improve existing algorithms and designs.\r\nPAGE 4\r\n![img-1.jpeg](img-1.jpeg)\r\n\r\nFigure 2: The Nested Learning paradigm, which represents a machine learning model and its training procedure as a set of nested optimization problems. (Left) An example of a hybrid architecture. While the deep learning perspective, as the flattened image of NL, does not provide insight about the depth of computation in the blocks, NL transparently represents all the inner gradient flows. (Right) A Neural Learning Module: a computational model that learns how to compress its own context flow. For example, the first level corresponds to the model's outermost training loop, often referred to as the \"pre-training\" step.\r\n\r\n# 2.1 Associative Memory \r\n\r\nAssociative memory-the ability to form and retrieve connections between events-is a fundamental mental process and is an inseparable component of human learning [56]. 
Often in the literature, the concepts of memorization and learning are used interchangeably; in the neuropsychology literature, however, these two are clearly distinguished. More specifically, following the neuropsychology literature [57], we build our terminology based on the following definition of memory and learning:\r\n\r\n## Learning vs. Memorization:\r\n\r\nMemory is a neural update caused by an input, and learning is the process for acquiring effective and useful memory.\r\n\r\nIn this work, our goal is to first show that all the elements of a computational sequence model, including optimizers and neural networks, are associative memory systems that compress their own context flow. Broadly speaking, associative memory is an operator that maps a set of keys to a set of values. We follow the general definition of associative memory by Behrouz et al. [58]:\r\nDefinition 1 (Associative Memory). Given a set of keys $\\mathcal{K} \\subseteq \\mathbb{R}^{d_{k}}$ and values $\\mathcal{V} \\subseteq \\mathbb{R}^{d_{v}}$, associative memory is an operator $\\mathcal{M}: \\mathcal{K} \\rightarrow \\mathcal{V}$ that maps the set of keys $\\mathcal{K}$ to the set of values $\\mathcal{V}$. To learn such a mapping from the data, an objective $\\hat{\\mathcal{L}}(\\cdot ; \\cdot)$ measures the quality of the mapping and $\\mathcal{M}$ can be defined as:\r\n\r\n$$\r\n\\mathcal{M}^{*}=\\arg \\min _{\\mathcal{M}} \\quad \\hat{\\mathcal{L}}(\\mathcal{M}(\\mathcal{K}) ; \\mathcal{V})\r\n$$\r\n\r\nWhile the operator itself is a memory and the mapping acts as a memorization process (i.e., memorizing the connections of events in the context), acquiring such an effective operator from the data is a learning process. It is notable that, here, keys and values can be arbitrary events that the memory aims to map, and are not limited to tokens. Later in this section, we will discuss that, given a context flow, keys and values might be tokens, gradients, sub-sequences, etc. Furthermore, while the term associative memory is more common in the neuroscience and neuropsychology literature, the above formulation is also closely related to data compression and low-dimensional representation. That is, one can interpret the optimization process in Equation 1 as the training process of a network $\\mathcal{M}(\\cdot)$ that aims to compress the mappings into its parameters and so represent them in a lower-dimensional space. A minimal sketch of this definition is given below.\r\n\r\nIn sequence modeling, where keys and values are input tokens (e.g., tokenized text), the choice of objective and the optimization process for solving Equation 1 can result in distinct sequence\r\nPAGE 5\r\nmodeling architectures (see [59] and [58]) such as global/local softmax attention [27], or other modern recurrent models [28, 60, 61]. This simple formulation of sequence models provides a better understanding of their internal processes and a simple tool to compare their modeling power based on their objective and optimization process. In the following, using step-by-step examples, we discuss how this formulation can be applied to all components of a neural architecture (including its optimization process in pre-training) and, in fact, how a model is an integrated system of multi-level, nested, and/or parallel memories, each with its own context flow.
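To make Definition 1 concrete, the following minimal NumPy sketch (not from the paper; the linear memory, the $\ell_2$ objective, and all sizes are illustrative assumptions) fits an associative memory $\mathcal{M}: \mathcal{K} \rightarrow \mathcal{V}$ by gradient descent:

```python
import numpy as np

# Sketch of Definition 1: learn M : K -> V by minimizing L^(M(K); V).
# Here M is a single linear map and the objective is l2 regression;
# both choices are simplifying assumptions for illustration.
rng = np.random.default_rng(0)
d_k, d_v, n = 8, 4, 64
K = rng.normal(size=(n, d_k))            # keys
V = K @ rng.normal(size=(d_k, d_v))      # values (a learnable mapping exists)

M = np.zeros((d_k, d_v))                 # memory parameters
eta = 0.05
for _ in range(500):
    err = K @ M - V                      # mismatch of the current mapping
    M -= eta * K.T @ err / n             # GD step on ||K M - V||^2

print("fit error:", np.linalg.norm(K @ M - V))  # ~0: mapping memorized
```

The same scaffold reappears throughout the section: only the choice of keys, values, objective, and optimizer changes.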
A Simple Example of MLP Training. We start with a simple example, in which we aim to train a 1-layer MLP (parameterized with $W$) for a task $\\mathcal{T}$ on a dataset $\\mathcal{D}_{\\text {train }}=\\left\\{x_{1}, \\ldots, x_{\\left|\\mathcal{D}_{\\text {train }}\\right|}\\right\\}$ by optimizing the objective $\\mathcal{L}(\\cdot ; \\cdot)$ with gradient descent. In this case, the training process is equivalent to the following optimization problem:\r\n\r\n$$\r\nW^{*}=\\arg \\min _{W} \\mathcal{L}\\left(W ; \\mathcal{D}_{\\text {train }}\\right)\r\n$$\r\n\r\nwhose optimization by gradient descent results in a weight update rule equivalent to:\r\n\r\n$$\r\n\\begin{aligned}\r\nW_{t+1} & =W_{t}-\\eta_{t+1} \\nabla_{W_{t}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right) \\\\\r\n& =W_{t}-\\eta_{t+1} \\nabla_{y_{t+1}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right) \\otimes x_{t+1}, \\quad \\text { where } x_{t+1} \\sim \\mathcal{D}_{\\text {train }}\r\n\\end{aligned}\r\n$$\r\n\r\nwhere $y_{t+1}=W x_{t+1}$ is the output of the model for input $x_{t+1}$. Given this formulation, one can let $u_{t+1}=\\nabla_{y_{t+1}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right)$ and reformulate the backpropagation process as the solution to an optimization problem: finding an optimal associative memory that maps input data points $\\mathcal{D}_{\\text {train }}=\\left\\{x_{t}\\right\\}_{t=1}^{\\left|\\mathcal{D}_{\\text {train }}\\right|}$ to their corresponding $u_{t+1}=\\nabla_{y_{t+1}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right)$. That is, we let $W_{t}$ parametrize the memory $\\mathcal{M}(\\cdot)$, and use dot-product similarity to measure the quality of $W_{t}$ 's mapping between $x_{t+1}$ and $\\nabla_{y_{t+1}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right)$ :\r\n\r\n$$\r\n\\begin{aligned}\r\nW_{t+1} & =\\arg \\min _{W}\\left\\langle W x_{t+1}, u_{t+1}\\right\\rangle+\\frac{1}{2 \\eta_{t+1}}\\left\\|W-W_{t}\\right\\|_{2}^{2} \\\\\r\n& =\\arg \\min _{W}\\left\\langle W x_{t+1}, \\nabla_{y_{t+1}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right)\\right\\rangle+\\frac{1}{2 \\eta_{t+1}}\\left\\|W-W_{t}\\right\\|_{2}^{2}\r\n\\end{aligned}\r\n$$\r\n\r\nIn the above formulation, $u_{t+1}=\\nabla_{y_{t+1}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right)$ can be interpreted as a local surprise signal in representation space that quantifies the mismatch between the current output and the structure the objective $\\mathcal{L}(\\cdot ; \\cdot)$ enforces. Therefore, this formulation translates the training phase of the model into a process of acquiring an effective memory that maps data samples to their Local Surprise Signal (LSS) in representation space. The sketch below makes this reading concrete.
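The following minimal sketch (an illustration under an assumed $\ell_2$ loss, not the paper's code) checks the key fact behind Equations 3 and 4: for a linear layer, $\nabla_{W} \mathcal{L}$ factors as the outer product of the output-space surprise $u_{t+1}$ and the input $x_{t+1}$.

```python
import numpy as np

# Sketch of the LSS reading of one GD step on a linear layer (Eqs. 3-4).
# The l2 loss to a fixed target is an assumption for illustration; the
# point is only that grad_W L = u (x)^T with u = grad_y L.
rng = np.random.default_rng(1)
d_in, d_out, eta = 6, 3, 0.1
W = rng.normal(size=(d_out, d_in))
x = rng.normal(size=d_in)
target = rng.normal(size=d_out)

y = W @ x                        # model output
u = y - target                   # u = grad_y L for L = 0.5*||y - target||^2
grad_W = np.outer(u, x)          # grad_W L = u (x)^T: the LSS outer product
W_new = W - eta * grad_W         # memory update: map x to its surprise u

# Same step computed directly from the loss, as a check of the factorization.
W_direct = W - eta * np.outer(W @ x - target, x)
assert np.allclose(W_new, W_direct)
```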
Accordingly, in this example, our model has a single gradient flow over the data samples, which is only active over the dataset $\\mathcal{D}_{\\text {train }}=\\left\\{x_{1}, \\ldots, x_{\\left|\\mathcal{D}_{\\text {train }}\\right|}\\right\\}$ and will be frozen for any other data samples afterwards (a.k.a. inference or test time).\r\n\r\nNext, in the above example, we replace the gradient descent algorithm with its enhanced momentum-based variant, resulting in an update rule of:\r\n\r\n$$\r\n\\begin{aligned}\r\n& W_{t+1}=W_{t}-\\mathbf{m}_{t+1} \\\\\r\n& \\mathbf{m}_{t+1}=\\mathbf{m}_{t}-\\eta_{t+1} \\nabla_{W_{t}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right)=\\mathbf{m}_{t}-\\eta_{t+1} \\nabla_{y_{t+1}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right) \\otimes x_{t+1}\r\n\\end{aligned}\r\n$$\r\n\r\nIn Equation 8, given the previous state of Equation 7 (at time $t$ ), the values of $\\nabla_{W_{t}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right)$ or, similarly, $\\nabla_{y_{t+1}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right)$ are independent of the recurrence in Equation 8 and so can be pre-computed beforehand. To this end, we let $u_{t+1}=\\nabla_{W_{t}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right)$, and so Equation 8 can be reformulated as:\r\n\r\n$$\r\n\\begin{aligned}\r\nW_{t+1} & =W_{t}-\\mathbf{m}_{t+1} \\\\\r\n\\mathbf{m}_{t+1} & =\\arg \\min _{\\mathbf{m}}-\\left\\langle\\mathbf{m}, \\nabla_{W_{t}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right)\\right\\rangle+\\eta_{t+1}\\left\\|\\mathbf{m}-\\mathbf{m}_{t}\\right\\|_{2}^{2} \\\\\r\n& =\\arg \\min _{\\mathbf{m}}-\\left\\langle\\mathbf{m} x_{t+1}, \\nabla_{y_{t+1}} \\mathcal{L}\\left(W_{t} ; x_{t+1}\\right)\\right\\rangle+\\eta_{t+1}\\left\\|\\mathbf{m}-\\mathbf{m}_{t}\\right\\|_{2}^{2}\r\n\\end{aligned}\r\n$$\r\n\r\nwhere the optimization problem in Equation 10 is equivalent to one step of gradient descent with an adaptive learning rate of $\\eta_{t+1}$. Given this formulation, one can interpret the momentum term as either: (1) a key-less associative memory that compresses the gradients into its parameters, or (2) an associative memory that learns how to map data points to their corresponding LSS-value. Interestingly, this formulation reveals that gradient descent with momentum is indeed a two-level\r\nPAGE 6\r\noptimization process, where the memory is optimized by the simple gradient descent algorithm. This process is closely related to Fast Weight Programs (FWPs) [62], where the weight update process (i.e., Equation 9) is the slow network whose momentum weight is generated by a fast network (i.e., Equation 10).\r\nConcluding the above examples, we observed that the training process of a 1-layer MLP with: (1) Gradient descent is a 1-level associative memory that learns how to map data points to their corresponding LSS-value; and (2) Gradient descent with momentum is a 2-level associative memory (or optimization process) whose inner level learns to store gradient values into its parameters, and whose outer level updates the slow weight (i.e., $W_{t}$ ) with the value of the inner-level memory (see the sketch below). While these are the simplest examples with respect to both architecture and optimization algorithm, one might ask if similar conclusions can be made in more complex setups.
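The two-level reading of momentum can be sketched as follows (a minimal illustration, not the paper's code; the quadratic loss and rates are assumptions, and the signs follow the standard heavy-ball convention rather than Equations 7-8):

```python
import numpy as np

# Sketch of GD-with-momentum as a two-level process (Eqs. 7-10): the
# inner level is a key-less memory m that compresses gradients, and the
# outer level applies the memory's value to the slow weights W.
rng = np.random.default_rng(2)
d = 5
W = rng.normal(size=d)
m = np.zeros(d)                       # inner memory of past gradients
eta, alpha = 0.1, 0.9

for _ in range(100):
    x = rng.normal(size=d)
    grad = (W @ x) * x                # grad of 0.5*(W x)^2 w.r.t. W
    # Inner level: one GD step of the memory on its own objective,
    # i.e., compress the new gradient into m (cf. Eq. 10).
    m = alpha * m + eta * grad
    # Outer level: update the slow weights with the memory value (cf. Eq. 9).
    W = W - m

print("||W|| after training:", np.linalg.norm(W))
```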
An Example of Architectural Decomposition. In the next example, we replace the MLP module with a linear attention [60]. That is, we aim to train a 1-layer linear attention for a task $\\mathcal{T}$ on a sequence $\\mathcal{D}_{\\text {train }}=\\left\\{x_{1}, \\ldots, x_{\\left|\\mathcal{D}_{\\text {train }}\\right|}\\right\\}$ by optimizing the objective $\\mathcal{L}$ with gradient descent. Recalling the unnormalized linear attention formulation:\r\n\r\n$$\r\n\\begin{aligned}\r\n& \\mathbf{k}_{t}=x_{t} W_{\\mathbf{k}}, \\quad \\mathbf{v}_{t}=x_{t} W_{\\mathbf{v}}, \\quad \\mathbf{q}_{t}=x_{t} W_{\\mathbf{q}} \\\\\r\n& \\mathcal{M}_{t}=\\mathcal{M}_{t-1}+\\mathbf{v}_{t} \\mathbf{k}_{t}^{\\top} \\\\\r\n& y_{t}=\\mathcal{M}_{t} \\mathbf{q}_{t}\r\n\\end{aligned}\r\n$$\r\n\r\nAs discussed in earlier studies [58, 59], the recurrence in Equation 13 can be reformulated as the optimization process of a matrix-valued associative memory $\\mathcal{M}_{t}(\\cdot)$, in which it aims to compress the mappings of keys and values into its parameters. In more detail, in Definition 1, if we let $\\tilde{\\mathcal{L}}\\left(\\mathcal{M}_{t-1} ; \\mathbf{k}_{t}, \\mathbf{v}_{t}\\right):=-\\left\\langle\\mathcal{M}_{t-1} \\mathbf{k}_{t}, \\mathbf{v}_{t}\\right\\rangle$ and aim to optimize the memory with gradient descent, the memory update rule is: (Note that $\\nabla \\tilde{\\mathcal{L}}\\left(\\mathcal{M}_{t-1} ; \\mathbf{k}_{t}, \\mathbf{v}_{t}\\right)=-\\mathbf{v}_{t} \\mathbf{k}_{t}^{\\top}$ and we let the learning rate $\\eta_{t}=1$ )\r\n\r\n$$\r\n\\begin{aligned}\r\n\\mathcal{M}_{t+1} & =\\arg \\min _{\\mathcal{M}}-\\left\\langle\\mathcal{M} \\mathbf{k}_{t+1}, \\mathbf{v}_{t+1}\\right\\rangle+\\left\\|\\mathcal{M}-\\mathcal{M}_{t}\\right\\|_{2}^{2} \\quad \\text { with gradient descent } \\\\\r\n\\Rightarrow \\mathcal{M}_{t+1} & =\\mathcal{M}_{t}-\\nabla \\tilde{\\mathcal{L}}\\left(\\mathcal{M}_{t} ; \\mathbf{k}_{t+1}, \\mathbf{v}_{t+1}\\right)=\\mathcal{M}_{t}+\\mathbf{v}_{t+1} \\mathbf{k}_{t+1}^{\\top}\r\n\\end{aligned}\r\n$$\r\n\r\nwhich is equivalent to the update rule of an unnormalized linear attention in Equation 13. Also, note that, as we observed in the first example, training a linear layer with gradient descent is a 1-level optimization problem of an associative memory (see Equation 3), and so the general training/updating process of the projection layers (i.e., $W_{\\mathbf{k}}, W_{\\mathbf{v}}$, and $W_{\\mathbf{q}}$ ) is itself an optimization process of an associative memory. Accordingly, this setup, i.e., training a linear attention with gradient descent, can be seen as a two-level optimization process, where the outer loop (also known as the training process) optimizes the projection layers with gradient descent, while the inner loop optimizes the inner memory $\\mathcal{M}_{t}$ with gradient descent.\r\n\r\nNote that, as discussed above, here we have two associative memories, and so each of them has its own optimization process and gradient flow. That is, in the optimization of the outer-level parameters $W_{\\mathbf{k}}, W_{\\mathbf{v}}$, and $W_{\\mathbf{q}}$ there is no gradient with respect to the parameters of $\\mathcal{M}(\\cdot)$ and so there is no backpropagation through it. Similarly, in the inner level, there is no backpropagation through the projection layers and they are considered frozen. Furthermore, it is notable that in this example, the above formulation is also closely connected to the FWP perspective of linear attention [63], where the projections are considered slow weights, and the memory update in Equation 13 is the fast weight update rule. The sketch below checks this equivalence numerically.
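A minimal check of the equivalence between Equations 13 and 14-15 (sizes and random inputs are illustrative assumptions, not from the paper):

```python
import numpy as np

# Sketch checking Eqs. 13-15: the unnormalized linear-attention recurrence
# M_t = M_{t-1} + v_t k_t^T equals one gradient-descent step (learning
# rate 1) on the inner objective -<M k_t, v_t>.
rng = np.random.default_rng(3)
d_k, d_v, T = 4, 3, 10
M_rec = np.zeros((d_v, d_k))   # memory updated by the recurrence
M_gd = np.zeros((d_v, d_k))    # memory updated by explicit GD

for _ in range(T):
    k = rng.normal(size=d_k)
    v = rng.normal(size=d_v)
    M_rec = M_rec + np.outer(v, k)       # linear-attention update (Eq. 13)
    grad = -np.outer(v, k)               # grad_M of -<M k, v>
    M_gd = M_gd - 1.0 * grad             # one GD step on the inner objective

assert np.allclose(M_rec, M_gd)          # the two views coincide
q = rng.normal(size=d_k)
print("retrieval y = M q:", M_rec @ q)   # Eq. 13's readout
```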
Architectural Decomposition with More Levels. In both of the above examples, we discussed simple cases that can be translated into 2-level optimization processes, which also coincide with their FWP interpretations. In practice, however, we need to use more powerful optimization algorithms to train the model, and/or use more powerful recurrent update rules for the memory. As a simple example, assume we use gradient descent with momentum to train a linear attention model. In the above examples, we showed how the linear attention component can be decomposed into two nested optimization problems. Similarly, here the model can be represented as a 2-level optimization problem, where (1) the inner level optimizes the memory to compress the context using gradient descent (see Equation 15), and (2) the outer level optimizes the projection layers with gradient descent with momentum. Interestingly, from the first example, we know that the \"gradient descent with momentum\" algorithm itself is indeed a 2-level optimization problem, where the momentum term is an associative memory that compresses the past gradients into its parameters.\r\nPAGE 7\r\n# 2.2 Nested Optimization Problems \r\n\r\nIn the previous section, we provided examples to demonstrate how one can decompose a machine learning model into a set of nested or multi-level optimization problems. Next, we first aim to present a formal formulation for nested learning problems and then define the Neural Learning Module-an integrated computational system that learns from data.\r\n\r\nAs we observed in the previous section, while we decomposed the model into a set of optimization processes, it is still unclear if we can define a hierarchy (or order) over these problems, and uniquely represent the model in this format. Inspired by the hierarchy of brain waves that indicates the information processing frequency rate of each part (discussed in Section 1), we use the update rate of each optimization problem to order the components in multiple levels. To this end, we let one update step over one data point be the unit of time, and define the update frequency of each component as:\r\nDefinition 2 (Update Frequency). For any component $A$, which can be a parametric component (e.g., learnable weights or the momentum term in gradient descent with momentum) or a non-parametric component (e.g., an attention block), we define its frequency, denoted as $f_{A}$, as its number of updates per unit of time.\r\n\r\nGiven the above update frequency, we can order the components of a machine learning algorithm based on the operator $(\\cdot \\succ \\cdot)$. We say $A$ is faster than $B$ and denote $A \\succ B$ if: (1) $f_{A}>f_{B}$, or (2) $f_{A}=f_{B}$ but the computation of $B$'s state at time $t$ requires the computation of $A$'s state at time $t$. In this definition, when $A \\nsucc B$ and $B \\nsucc A$, we let $A \\stackrel{t}{\\sim} B$, which indicates that $A$ and $B$ have the same update frequency, but their computation is independent of each other (later, we provide an example of this case in the AdamW optimizer). Based on the above operator, we sort the components into an ordered set of \"levels\", where (1) components in the same level have the same update frequency, and (2) the higher the level is, the lower its frequency. Notably, based on the above definition, each component has its own optimization problem and so its own context. While we optimize each component's inner objective with gradient-based optimizers, the above statement is equivalent to having an exclusive gradient flow for each component in the model. In the general case, however, one can use a non-parametric solution (as we later discuss for attention). The short sketch below illustrates this ordering.
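As a toy illustration of Definition 2 (the components and frequencies below are invented for illustration, not values from the paper):

```python
# Sketch of Definition 2: ordering components of a model by update frequency
# (updates per unit of time = per data point), then grouping equal
# frequencies into "levels" (lower frequency -> higher level).
from itertools import groupby

components = {
    "inner memory M (linear attention)": 1.0,   # updated every token
    "momentum m": 1.0 / 8,                      # updated every batch of 8
    "projections W_k, W_v, W_q": 1.0 / 8,       # updated every batch of 8
    "pre-training loop": 1.0 / 30_000_000,      # updated once per corpus pass
}

ordered = sorted(components.items(), key=lambda kv: -kv[1])
for level, (freq, group) in enumerate(groupby(ordered, key=lambda kv: kv[1]), 1):
    names = [name for name, _ in group]
    print(f"level {level} (f = {freq:g}): {names}")
```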
Neural Learning Module. Given the above definition of nested learning problems, we define the neural learning module as a new representation of machine learning models that views the model as an interconnected system of components, each with its own gradient flow. Note that, orthogonal to deep learning, nested learning allows us to define neural learning models with more levels, resulting in more expressive architectures.\r\n\r\nNested learning allows computational models that are composed of multiple (multi-layer) levels to learn from and process data at different levels of abstraction and time-scales.\r\n\r\nNext, we study optimizers and well-known deep learning architectures from the nested learning perspective, and provide examples of how NL can help enhance these components.\r\n\r\n### 2.3 Optimizers as Learning Modules\r\n\r\nIn this section, we start by understanding how well-known optimizers and their variants are special instances of nested learning. Recall the gradient descent method with momentum,\r\n\r\n$$\r\n\\begin{aligned}\r\n& W_{i+1}=W_{i}+\\mathbf{m}_{i+1} \\\\\r\n& \\mathbf{m}_{i+1}=\\alpha_{i+1} \\mathbf{m}_{i}-\\eta_{i+1} \\nabla \\mathcal{L}\\left(W_{i} ; x_{i}\\right)\r\n\\end{aligned}\r\n$$\r\n\r\nwhere the matrix (or vector) $\\mathbf{m}_{i}$ is the momentum at state $i$, and $\\alpha_{i}$ and $\\eta_{i}$ are adaptive momentum and learning rates, respectively. Assuming $\\alpha_{i+1}=1$, the momentum term can be viewed as the result of optimizing the following objective with gradient descent:\r\n\r\n$$\r\n\\min _{\\mathbf{m}}\\left\\langle\\mathbf{m} \\nabla \\mathcal{L}\\left(W_{i} ; x_{i}\\right)^{\\top}, \\mathbf{I}\\right\\rangle\r\n$$\r\n\r\nThis interpretation shows that momentum can indeed be viewed as a meta memory module that learns how to memorize gradients of the objective into its parameters.\r\nPAGE 8
While any reasonable choice (e.g., random features) of preconditioning can improve the expressivity of the initial version of GD with momentum per se is a value-less memory (i.e., mapping all gradients to a single value), the above perspective gives more intuition about what preconditioning are more useful. That is, the momentum acts as a memory that aims to map gradients to their corresponding values, and so a function of gradients (e.g., information about Hessian) can provide the memory with a more meaningful mappings.\r\n\r\nExtension: More Expressive Objectives. As discussed by Behrouz et al. [58], optimizing an inner objective of dot-product similarity results in Hebbian-like update rule, which can cause the memory to be less effective. A natural extension of this internal objective is to use $\\ell_{2}(\\cdot)$ regression loss (for measuring the corresponding key-value mapping fitness) and minimize the loss function $\\|\\mathbf{m} \\nabla \\mathcal{L}\\left(W_{i} ; x_{i}\\right)^{\\top}-\\mathbf{P}_{i}\\|_{2}^{2}$, resulting in the update rule of:\r\n\r\n$$\r\n\\begin{aligned}\r\n& W_{i+1}=W_{i}+\\mathbf{m}_{i+1} \\\\\r\n& \\mathbf{m}_{i+1}=\\left(\\alpha_{i+1} \\mathbf{I}-\\nabla \\mathcal{L}\\left(W_{i} ; x_{i}\\right)^{\\top} \\nabla \\mathcal{L}\\left(W_{i} ; x_{i}\\right)\\right) \\mathbf{m}_{i}-\\eta_{t} \\mathbf{P}_{i} \\nabla \\mathcal{L}\\left(W_{i} ; x_{i}\\right)\r\n\\end{aligned}\r\n$$\r\n\r\nThis update is based on delta-rule [64] and so it allows the memory (momentum) to better manage its limited capacity and better memorize the series of past gradients.\r\n\r\nExtension: More Expressive Memory. As discussed earlier, momentum can be viewed as a meta memory model that uses a linear layer (i.e., matrix-valued) to compress the past gradient values. Due to the linear nature of momentum, only linear functions of past gradients can be learned by its internal objective. To increase the learning capacity of this module, one alternative is to use alternative powerful persistent learning modules: i.e., replacing a linear matrix-valued memory for momentum with an MLP. Therefore, momentum as the a memory for the past gradients, has more capacity to capture the underlying dynamics of the gradients. To this end, we extend the formulation in Equation 17 as:\r\n\r\n$$\r\nW_{i+1}=W_{i}+\\mathbf{m}_{i+1}\\left(\\mathbf{u}_{i}\\right), \\quad \\text { and } \\quad \\mathbf{m}_{i+1}=\\alpha_{i+1} \\mathbf{m}_{i}-\\eta_{t} \\nabla \\mathcal{L}^{(2)}\\left(\\mathbf{m}_{i} ; \\mathbf{u}_{i}, \\mathbf{I}\\right)\r\n$$\r\n\r\nwhere $\\mathbf{u}_{i}=\\nabla \\mathcal{L}\\left(W_{i} ; x_{i}\\right)$ and $\\nabla \\mathcal{L}^{(2)}(\\cdot)$ is the internal objective of momentum (e.g., dot product similarity $\\left\\langle\\mathbf{m}\\left(\\mathbf{u}_{i}^{\\top}\\right), \\mathbf{1}\\right\\rangle$ ). We refer to this variant as Deep Momentum Gradient Descent (DMGD).\r\n\r\nExtension: None Linear Outputs. Building upon the above perspective, in which we see the momentum as a neural architecture, one common technique to enhance the representation power of momentum memory module is to use non-linearity on top of its output [28, 65]. 
Extension: Non-Linear Outputs. Building upon the above perspective, in which we see the momentum as a neural architecture, one common technique to enhance the representation power of the momentum memory module is to apply a non-linearity on top of its output [28, 65]. That is, we re-formulate Equation 23 as:\r\n\r\n$$\r\nW_{i+1}=W_{i}+\\sigma\\left(\\mathbf{m}_{i+1}\\left(\\mathbf{u}_{i}\\right)\\right), \\quad \\text { and } \\quad \\mathbf{m}_{i+1}=\\alpha_{i+1} \\mathbf{m}_{i}-\\eta_{i+1} \\nabla \\mathcal{L}^{(2)}\\left(\\mathbf{m}_{i} ; \\mathbf{u}_{i}, \\mathbf{I}\\right)\r\n$$\r\n\r\nwhere $\\sigma(\\cdot)$ is an arbitrary non-linearity. As an example, if we let $\\sigma(\\cdot)=$ Newton-Schulz $(\\cdot)$, where Newton-Schulz $(\\cdot)$ is the iterative Newton-Schulz method [66], and $\\mathbf{m}(\\cdot)$ be a linear layer, the resulting optimizer is equivalent to the Muon optimizer [34].\r\nPAGE 9\r\nGoing Beyond Simple Backpropagation. As discussed earlier in Section 2.1, the pre-training process with backpropagation is a form of associative memory, where input data is mapped to the surprise caused by its predicted output $\\nabla_{y_{t}} \\mathcal{L}\\left(W_{t} ; x_{t}\\right)$ :\r\n\r\n$$\r\nW_{t+1}=W_{t}-\\eta_{t+1} \\nabla_{W_{t}} \\mathcal{L}\\left(W_{t} ; x_{t}\\right)=W_{t}-\\eta_{t+1} \\nabla_{y_{t}} \\mathcal{L}\\left(W_{t} ; x_{t}\\right) \\otimes x_{t}, \\quad \\text { where } x_{t} \\sim \\mathcal{D}_{\\text {train }}\r\n$$\r\n\r\nwhich from the associative memory perspective is equivalent to one step of gradient descent in the optimization process of:\r\n\r\n$$\r\n\\min _{W}\\left\\langle W x_{t}, \\nabla_{y_{t}} \\mathcal{L}\\left(W_{t} ; x_{t}\\right)\\right\\rangle\r\n$$\r\n\r\nAs we discuss in Appendix C, the above formulation ignores the dependencies among data samples such as $x_{t}$. To extend it to a more powerful formulation that also considers the dependencies among data points (which is extremely important when we use the optimizer in the token space, as tokens are not independent), we use an $L_{2}$ regression objective with one step of gradient descent as follows:\r\n\r\n$$\r\n\\min _{W}\\left\\|W x_{t}-\\nabla_{y_{t}} \\mathcal{L}\\left(W_{t} ; x_{t}\\right)\\right\\|_{2}^{2}\r\n$$\r\n\r\nThis formulation results in a new variant of gradient descent, which can be simplified as follows:\r\n\r\n$$\r\n\\begin{aligned}\r\nW_{t+1} & =W_{t}\\left(\\mathbf{I}-x_{t} x_{t}^{\\top}\\right)-\\eta_{t+1} \\nabla_{W_{t}} \\mathcal{L}\\left(W_{t} ; x_{t}\\right) \\\\\r\n& =W_{t}\\left(\\mathbf{I}-x_{t} x_{t}^{\\top}\\right)-\\eta_{t+1} \\nabla_{y_{t}} \\mathcal{L}\\left(W_{t} ; x_{t}\\right) \\otimes x_{t}, \\quad \\text { where } x_{t} \\sim \\mathcal{D}_{\\text {train }}\r\n\\end{aligned}\r\n$$\r\n\r\nLater, we use this optimizer as the internal optimizer of our HOPE architecture; a sketch of this update rule follows.
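The following sketch illustrates the variant in Equations 28-29 (the $\ell_2$ loss, normalization of $x_t$, and sizes are illustrative assumptions):

```python
import numpy as np

# Sketch of the GD variant in Eqs. 28-29, derived from one GD step on the
# l2 objective ||W x_t - u_t||^2 with u_t = grad_y L. Note the extra
# W_t (I - x_t x_t^T) factor that standard GD lacks: it first forgets the
# component of each row of W lying along x_t, then writes the association.
rng = np.random.default_rng(5)
d_in, d_out, eta = 6, 3, 0.1
W = rng.normal(size=(d_out, d_in))
x = rng.normal(size=d_in)
x = x / np.linalg.norm(x)            # keep I - x x^T a well-behaved projector
target = rng.normal(size=d_out)

u = W @ x - target                   # surprise signal grad_y L for l2 loss
W_new = W @ (np.eye(d_in) - np.outer(x, x)) - eta * np.outer(u, x)

# Compare with plain GD, which skips the forgetting term.
W_gd = W - eta * np.outer(u, x)
print("difference vs plain GD:", np.linalg.norm(W_new - W_gd))
```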
# 3 HOPE: A Self-Referential Learning Module with Continuum Memory \r\n\r\nExisting architectural backbones consist of (1) a working memory module (e.g., attention), which is responsible for actively fusing information across the sequence length, and (2) a feed-forward layer (e.g., MLP) that fuses information across features and acts as the persistent memory or knowledge storage of the pre-training phase. From the NL perspective, pre-training is the phase in which the outermost level of the learning module is updated over its limited context flow. Accordingly, in the continual setup, such a pre-training phase is rarely repeated over time, and so its corresponding knowledge storage rarely needs to be updated. Given this intuition, we extend the traditional viewpoint of the long-term/short-term memory system and suggest a feed-forward knowledge store for each level (update frequency).\r\n\r\nGiven the definition of frequency, the Continuum Memory System (CMS) is formalized as a chain of MLP blocks $\\operatorname{MLP}^{\\left(f_{1}\\right)}(\\cdot), \\ldots, \\operatorname{MLP}^{\\left(f_{k}\\right)}(\\cdot)$, each of which is associated with a chunk size of $C^{(\\ell)}:=\\frac{\\max _{\\ell^{\\prime}} f_{\\ell^{\\prime}}}{f_{\\ell}}$ such that, given input $x=\\left\\{x_{1}, \\ldots, x_{T}\\right\\}$, the output of the chain is calculated as (we disregard normalizations for the sake of clarity):\r\n\r\n$$\r\ny_{t}=\\operatorname{MLP}^{\\left(f_{k}\\right)}\\left(\\operatorname{MLP}^{\\left(f_{k-1}\\right)}\\left(\\cdots \\operatorname{MLP}^{\\left(f_{1}\\right)}\\left(x_{t}\\right)\\right)\\right)\r\n$$\r\n\r\nwhere the parameters of the $\\ell$-th MLP block, i.e., $\\boldsymbol{\\theta}^{\\left(f_{\\ell}\\right)}$, are updated every $C^{(\\ell)}$ steps:\r\n\r\n$$\r\n\\boldsymbol{\\theta}_{i+1}^{\\left(f_{\\ell}\\right)}=\\boldsymbol{\\theta}_{i}^{\\left(f_{\\ell}\\right)}-\\left\\{\\begin{array}{ll}\r\n\\sum_{t=i-C^{(\\ell)}}^{i} \\eta_{t}^{(\\ell)} f\\left(\\boldsymbol{\\theta}_{t}^{\\left(f_{\\ell}\\right)} ; x_{t}\\right) & \\text { if } i \\equiv 0\\left(\\bmod C^{(\\ell)}\\right) \\\\\r\n0 & \\text { otherwise }\r\n\\end{array}\\right.\r\n$$\r\n\r\nIn Appendix B.1, we discuss different variants of this formulation, including fully nested MLP layers. Here $\\eta_{t}^{(\\ell)}$ are learning rates corresponding to $\\boldsymbol{\\theta}^{\\left(f_{\\ell}\\right)}$, and $f(\\cdot)$ is the error component of an arbitrary optimizer (e.g., $\\nabla \\mathcal{L}\\left(\\boldsymbol{\\theta}_{t}^{\\left(f_{\\ell}\\right)} ; x_{t}\\right)$ in gradient descent). The conventional Transformer block [27] is a special instance of this formulation, where $k=1$. It is notable that Equation 31 provides an important interpretation: parameters $\\boldsymbol{\\theta}_{t}^{\\left(f_{\\ell}\\right)}$ are responsible for compressing their own context into their parameters, and so they are representative of the abstract knowledge of their context. A minimal sketch of this multi-frequency update schedule follows.
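The schedule in Equations 30-31 can be sketched as follows (frequencies, chunk sizes, and the counter standing in for a parameter update are illustrative assumptions; real blocks would be MLPs updated with gradients accumulated over the last chunk):

```python
# Sketch of the Continuum Memory System update schedule (Eqs. 30-31):
# a chain of blocks where level l is updated every C_l steps, with
# C_l = max frequency / f_l.
freqs = [16, 4, 1]                         # f_1 > f_2 > f_3 (per-level rates)
chunk = [max(freqs) // f for f in freqs]   # C_l -> [1, 4, 16]
updates = [0] * len(freqs)                 # how often each level was updated

T = 32                                     # tokens in the stream
for i in range(1, T + 1):
    # Forward pass y_t = MLP^(f_k)(... MLP^(f_1)(x_t)) would happen here.
    for level, C in enumerate(chunk):
        if i % C == 0:                     # Eq. 31: update every C_l steps,
            updates[level] += 1            # consuming the last C_l tokens

print("chunk sizes:", chunk)               # [1, 4, 16]
print("updates per level:", updates)       # [32, 8, 2]: faster levels update more
```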
HOPE. We further present a self-referential learning module based on Titans [28] and our variant of gradient descent in Section B.1. Combining this self-referential sequence model with the continuum memory system results in the HOPE architecture.\r\nPAGE 10\r\n![img-2.jpeg](img-2.jpeg)\r\n\r\nFigure 3: A comparison of the HOPE architectural backbone with Transformers (normalization and potential data-dependent components are removed for the sake of clarity).\r\n\r\nTable 1: Performance of HOPE and baselines on language modeling and common-sense reasoning tasks. Hybrid models are marked with *.\r\n\r\n|  Model | Wiki. <br> ppl $\\downarrow$ | LMB. <br> ppl $\\downarrow$ | LMB. <br> acc $\\uparrow$ | PIQA <br> acc $\\uparrow$ | Hella. <br> acc_n $\\uparrow$ | Wino. <br> acc $\\uparrow$ | ARC-e <br> acc $\\uparrow$ | ARC-c <br> acc_n $\\uparrow$ | SIQA <br> acc $\\uparrow$ | BoolQ <br> acc $\\uparrow$ | Avg. <br> $\\uparrow$  |\r\n| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\r\n|  HOPE (ours) | 26.05 | 29.38 | 35.40 | 64.62 | 40.11 | 51.19 | 56.92 | 28.49 | 38.33 | 60.12 | 46.90  |\r\n|  760M params / 30B tokens |  |  |  |  |  |  |  |  |  |  |   |\r\n|  Transformer++ | 25.21 | 27.64 | 35.78 | 66.92 | 42.19 | 51.95 | 60.38 | 32.46 | 39.51 | 60.37 | 48.69  |\r\n|  RetNet | 26.08 | 24.45 | 34.51 | 67.19 | 41.63 | 52.09 | 63.17 | 32.78 | 38.36 | 57.92 | 48.46  |\r\n|  DeltaNet | 24.37 | 24.60 | 37.06 | 66.93 | 41.98 | 50.65 | 64.87 | 31.39 | 39.88 | 59.02 | 48.97  |\r\n|  TTT | 24.17 | 23.51 | 34.74 | 67.25 | 43.92 | 50.99 | 64.53 | 33.81 | 40.16 | 59.58 | 47.32  |\r\n|  Samba* | 20.63 | 22.71 | 39.72 | 69.19 | 47.35 | 52.01 | 66.92 | 33.20 | 38.98 | 61.24 | 51.08  |\r\n|  Titans (LMM) | 20.04 | 21.96 | 37.40 | 69.28 | 48.46 | 52.27 | 66.31 | 35.84 | 40.13 | 62.76 | 51.56  |\r\n|  HOPE (ours) | 20.53 | 20.47 | 39.02 | 70.13 | 49.21 | 52.70 | 66.89 | 36.05 | 40.71 | 63.29 | 52.26  |\r\n|  1.3B params / 100B tokens |  |  |  |  |  |  |  |  |  |  |   |\r\n|  Transformer++ | 18.53 | 18.32 | 42.60 | 70.02 | 50.23 | 53.51 | 68.83 | 35.10 | 40.66 | 57.09 | 52.25  |\r\n|  RetNet | 19.08 | 17.27 | 40.52 | 70.07 | 49.16 | 54.14 | 67.34 | 33.78 | 40.78 | 60.39 | 52.02  |\r\n|  DeltaNet | 17.71 | 16.88 | 42.46 | 70.72 | 50.93 | 53.35 | 68.47 | 35.66 | 40.22 | 55.29 | 52.14  |\r\n|  Samba* | 16.13 | 13.29 | 44.94 | 70.94 | 53.42 | 55.56 | 68.81 | 36.17 | 39.96 | 62.11 | 54.00  |\r\n|  Titans (LMM) | 15.60 | 11.41 | 49.14 | 73.09 | 56.31 | 59.81 | 72.43 | 40.82 | 42.05 | 60.97 | 56.82  |\r\n|  HOPE (ours) | 15.11 | 11.63 | 50.01 | 73.29 | 56.84 | 60.19 | 72.30 | 41.24 | 42.52 | 61.46 | 57.23  |\r\n\r\n# 4 Experiments\r\n\r\nFor the sake of space, in the main paper, we report the results of HOPE's evaluation on language modeling and common-sense reasoning tasks. However, we report an extensive set of results, including experiments on optimizers, the emergence of in-context learning, the continual learning abilities of HOPE, ablation studies, long-context tasks, etc., in the appendix. Details about the experimental setups and the other datasets used are in Appendix G.\r\n\r\nLanguage Modeling and Common-sense Reasoning. We follow recent sequence modeling studies [28, 67, 68] and report the results of HOPE and baselines with sizes of 340M, 760M, and 1.3B parameters on language modeling and common-sense reasoning downstream tasks. These results are reported in Table 1. HOPE demonstrates strong performance across all the scales and benchmark tasks, outperforming both Transformers and modern recurrent neural networks, including Gated DeltaNet and Titans. Comparing HOPE to Titans and Gated DeltaNet, we can see that dynamically changing the key, value, and query projections based on the context, as well as a deep memory module, can result in a model with lower perplexity and higher accuracy in benchmark results.\r\nPAGE 11\r\n# References \r\n\r\n[1] Ali Behrouz, Meisam Razaviyayn, Peilin Zhong, and Vahab Mirrokni. Nested learning: The illusion of deep learning architectures. arXiv preprint arXiv.\r\n[2] Walter Pitts. The linear theory of neuron networks: The dynamic problem. The bulletin of mathematical biophysics, 5:23-31, 1943.\r\n[3] Warren S McCulloch. The brain computing machine. Electrical Engineering, 68(6):492-497, 1949.\r\n[4] Warren S McCulloch and Walter Pitts. The statistical organization of nervous activity. 
Biometrics, 4(2):91-99, 1948.\r\n[5] Arthur L Samuel. Some studies in machine learning using the game of checkers. IBM Journal of research and development, 3(3):210-229, 1959.\r\n[6] David Silver and Richard S Sutton. Welcome to the era of experience. Google AI, 1, 2025.\r\n[7] Richard S Sutton, Andrew G Barto, et al. Reinforcement learning: An introduction, volume 1. 1998.\r\n[8] Jonathan H. Connell and Sridhar Mahadevan. Robot learning. Robotica, 17(2):229-235, 1999. doi: 10.1017/S0263574799271172.\r\n[9] Yann LeCun, Yoshua Bengio, and Geoffrey Hinton. Deep learning. nature, 521(7553):436-444, 2015.\r\n[10] John Jumper, Richard Evans, Alexander Pritzel, Tim Green, Michael Figurnov, Olaf Ronneberger, Kathryn Tunyasuvunakool, Russ Bates, Augustin Žídek, Anna Potapenko, et al. Highly accurate protein structure prediction with alphafold. nature, 596(7873):583-589, 2021.\r\n[11] David Silver, Aja Huang, Chris J Maddison, Arthur Guez, Laurent Sifre, George Van Den Driessche, Julian Schrittwieser, Ioannis Antonoglou, Veda Panneershelvam, Marc Lanctot, et al. Mastering the game of go with deep neural networks and tree search. nature, 529(7587):484-489, 2016.\r\n[12] David Silver, Thomas Hubert, Julian Schrittwieser, Ioannis Antonoglou, Matthew Lai, Arthur Guez, Marc Lanctot, Laurent Sifre, Dharshan Kumaran, Thore Graepel, et al. A general reinforcement learning algorithm that masters chess, shogi, and go through self-play. Science, 362(6419):1140-1144, 2018.\r\n[13] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. Imagenet classification with deep convolutional neural networks. Advances in neural information processing systems, 25, 2012.\r\n[14] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=YicbFdNTTy.\r\n[15] Gheorghe Comanici, Eric Bieber, Mike Schaekermann, Ice Pasupat, Noveen Sachdeva, Inderjit Dhillon, Marcel Blistein, Ori Ram, Dan Zhang, Evan Rosen, et al. Gemini 2.5: Pushing the frontier with advanced reasoning, multimodality, long context, and next generation agentic capabilities. arXiv preprint arXiv:2507.06261, 2025.\r\n[16] Aixin Liu, Bei Feng, Bing Xue, Bingxuan Wang, Bochao Wu, Chengda Lu, Chenggang Zhao, Chengqi Deng, Chenyu Zhang, Chong Ruan, et al. Deepseek-v3 technical report. arXiv preprint arXiv:2412.19437, 2024.\r\n[17] Josh Achiam, Steven Adler, Sandhini Agarwal, Lama Ahmad, Ilge Akkaya, Florencia Leoni Aleman, Diogo Almeida, Janko Altenschmidt, Sam Altman, Shyamal Anadkat, et al. Gpt-4 technical report. arXiv preprint arXiv:2303.08774, 2023.\r\nPAGE 12\r\n[18] Guido Montúfar, Razvan Pascanu, Kyunghyun Cho, and Yoshua Bengio. On the number of linear regions of deep neural networks. Advances in neural information processing systems, 27, 2014.\r\n[19] Ben Poole, Subhaneil Lahiri, Maithra Raghu, Jascha Sohl-Dickstein, and Surya Ganguli. Exponential expressivity in deep neural networks through transient chaos. Advances in neural information processing systems, 29, 2016.\r\n[20] Joel Hestness, Sharan Narang, Newsha Ardalani, Gregory Diamos, Heewoo Jun, Hassan Kianinejad, Md Mostofa Ali Patwary, Yang Yang, and Yanqi Zhou. Deep learning scaling is predictable, empirically. 
arXiv preprint arXiv:1712.00409, 2017.\r\n[21] William Merrill, Ashish Sabharwal, and Noah A Smith. Saturated transformers are constant-depth threshold circuits. Transactions of the Association for Computational Linguistics, 10:843-856, 2022.\r\n[22] Clayton Sanford, Daniel Hsu, and Matus Telgarsky. Transformers, parallel computation, and logarithmic depth. In Forty-first International Conference on Machine Learning, 2024. URL https://openreview.net/forum?id=QCZabhKQhB.\r\n[23] William Merrill, Jackson Petty, and Ashish Sabharwal. The illusion of state in state-space models. In Forty-first International Conference on Machine Learning, 2024. URL https://openreview.net/forum?id=QZgo9JZpLq.\r\n[24] Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. Scaling laws for neural language models. arXiv preprint arXiv:2001.08361, 2020.\r\n[25] Sepp Hochreiter and Juergen Schmidhuber. Long short-term memory. Neural Computation, 9(8):1735-1780, 1997.\r\n[26] Kunihiko Fukushima. Neocognitron: A self-organizing neural network model for a mechanism of pattern recognition unaffected by shift in position. Biological Cybernetics, 36(4):193-202, 1980.\r\n[27] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In I. Guyon, U. Von Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017. URL https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf.\r\n[28] Ali Behrouz, Peilin Zhong, and Vahab Mirrokni. Titans: Learning to memorize at test time. arXiv preprint arXiv:2501.00663, 2024.\r\n[29] David E Rumelhart, Geoffrey E Hinton, and Ronald J Williams. Learning representations by back-propagating errors. Nature, 323(6088):533-536, 1986.\r\n[30] Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio. Generative adversarial networks. Communications of the ACM, 63(11):139-144, 2020.\r\n[31] Shaden Alshammari, John Hershey, Axel Feldmann, William T Freeman, and Mark Hamilton. I-Con: A unifying framework for representation learning. arXiv preprint arXiv:2504.16929, 2025.\r\n[32] R Devon Hjelm, Alex Fedorov, Samuel Lavoie-Marchildon, Karan Grewal, Phil Bachman, Adam Trischler, and Yoshua Bengio. Learning deep representations by mutual information estimation and maximization. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=Bklr3j0cKX.\r\n[33] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.\r\n[34] Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cecista, Laker Newhouse, and Jeremy Bernstein. Muon: An optimizer for hidden layers in neural networks, 2024. URL https://kellerjordan.github.io/posts/muon.\r\n[35] Vineet Gupta, Tomer Koren, and Yoram Singer. Shampoo: Preconditioned stochastic tensor optimization. In International Conference on Machine Learning, pages 1842-1850. PMLR, 2018.\r\n[36] Nikhil Vyas, Depen Morwani, Rosie Zhao, Itai Shapira, David Brandfonbrener, Lucas Janson, and Sham M. Kakade. SOAP: Improving and stabilizing Shampoo using Adam for language modeling. In The Thirteenth International Conference on Learning Representations, 2025. 
URL https://openreview.net/forum?id=IDxZhXrpNf.\r\n[37] Jordan Hoffmann, Sebastian Borgeaud, Arthur Mensch, Elena Buchatskaya, Trevor Cai, Eliza Rutherford, Diego de Las Casas, Lisa Anne Hendricks, Johannes Welbl, Aidan Clark, et al. Training compute-optimal large language models. arXiv preprint arXiv:2203.15556, 2022.\r\n[38] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in Neural Information Processing Systems, 33:1877-1901, 2020.\r\n[39] Rylan Schaeffer, Brando Miranda, and Sanmi Koyejo. Are emergent abilities of large language models a mirage? Advances in Neural Information Processing Systems, 36:55565-55581, 2023.\r\n[40] Erik Nijkamp, Bo Pang, Hiroaki Hayashi, Lifu Tu, Huan Wang, Yingbo Zhou, Silvio Savarese, and Caiming Xiong. CodeGen: An open large language model for code with multi-turn program synthesis. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=iaYcJKpY2B.\r\n[41] Wenhai Wang, Zhe Chen, Xiaokang Chen, Jiannan Wu, Xizhou Zhu, Gang Zeng, Ping Luo, Tong Lu, Jie Zhou, Yu Qiao, et al. VisionLLM: Large language model is also an open-ended decoder for vision-centric tasks. Advances in Neural Information Processing Systems, 36:61501-61513, 2023.\r\n[42] Sabri Eyuboglu, Ryan Ehrlich, Simran Arora, Neel Guha, Dylan Zinsley, Emily Liu, Will Tennien, Atri Rudra, James Zou, Azalia Mirhoseini, et al. Cartridges: Lightweight and general-purpose long context representations via self-study. arXiv preprint arXiv:2506.06266, 2025.\r\n[43] Hongzhou Yu, Tianhao Cheng, Yingwen Wang, Wen He, Qing Wang, Ying Cheng, Yuejie Zhang, Rui Feng, and Xiaobo Zhang. FineMedLM-o1: Enhancing medical knowledge reasoning ability of LLM from supervised fine-tuning to test-time training. In Second Conference on Language Modeling, 2025. URL https://openreview.net/forum?id=7ZwuGZCopw.\r\n[44] Ekin Akyürek, Mehul Damani, Adam Zweiger, Linlu Qiu, Han Guo, Jyothish Pari, Yoon Kim, and Jacob Andreas. The surprising effectiveness of test-time training for few-shot learning. In Forty-second International Conference on Machine Learning, 2024.\r\n[45] William Beecher Scoville and Brenda Milner. Loss of recent memory after bilateral hippocampal lesions. Journal of Neurology, Neurosurgery, and Psychiatry, 20(1):11, 1957.\r\n[46] Alvaro Pascual-Leone, Amir Amedi, Felipe Fregni, and Lotfi B Merabet. The plastic human brain cortex. Annu. Rev. Neurosci., 28(1):377-401, 2005.\r\n[47] Michael V Johnston. Plasticity in the developing brain: implications for rehabilitation. Developmental Disabilities Research Reviews, 15(2):94-101, 2009.\r\n[48] Akihiro Goto, Ayaka Bota, Ken Miya, Jingbo Wang, Suzune Tsukamoto, Xinzhi Jiang, Daichi Hirai, Masanori Murayama, Tomoki Matsuda, Thomas J. McHugh, Takeharu Nagai, and Yasunori Hayashi. Stepwise synaptic plasticity events drive the early phase of memory consolidation. Science, 374(6569):857-863, 2021. doi: 10.1126/science.abj9195. URL https://www.science.org/doi/abs/10.1126/science.abj9195.\r\n[49] Uwe Frey and Richard GM Morris. Synaptic tagging and long-term potentiation. Nature, 385(6616):533-536, 1997.\r\n[50] Wannan Yang, Chen Sun, Roman Huszár, Thomas Hainmueller, Kirill Kiselev, and György Buzsáki. Selection of experience for memory by hippocampal sharp wave ripples. Science, 383(6690):1478-1483, 2024.\r\n[51] Daoyun Ji and Matthew A Wilson. 
Coordinated memory replay in the visual cortex and hippocampus during sleep. Nature Neuroscience, 10(1):100-107, 2007.\r\n[52] Adrien Peyrache, Mehdi Khamassi, Karim Benchenane, Sidney I Wiener, and Francesco P Battaglia. Replay of rule-learning related neural patterns in the prefrontal cortex during sleep. Nature Neuroscience, 12(7):919-926, 2009.\r\n[53] David J Foster and Matthew A Wilson. Reverse replay of behavioural sequences in hippocampal place cells during the awake state. Nature, 440(7084):680-683, 2006.\r\n[54] Sean PA Drummond, Gregory G Brown, J Christian Gillin, John L Stricker, Eric C Wong, and Richard B Buxton. Altered brain response to verbal learning following sleep deprivation. Nature, 403(6770):655-657, 2000.\r\n[55] Seung-Schik Yoo, Peter T Hu, Ninad Gujar, Ferenc A Jolesz, and Matthew P Walker. A deficit in the ability to form new human memories without sleep. Nature Neuroscience, 10(3):385-392, 2007.\r\n[56] W Scott Terry. Learning and memory: Basic principles, processes, and procedures. Routledge, 2017.\r\n[57] Hideyuki Okano, Tomoo Hirano, and Evan Balaban. Learning and memory. Proceedings of the National Academy of Sciences, 97(23):12403-12404, 2000.\r\n[58] Ali Behrouz, Meisam Razaviyayn, Peilin Zhong, and Vahab Mirrokni. It's all connected: A journey through test-time memorization, attentional bias, retention, and online optimization. arXiv preprint arXiv:2504.13173, 2025.\r\n[59] Bo Liu, Rui Wang, Lemeng Wu, Yihao Feng, Peter Stone, and Qiang Liu. Longhorn: State space models are amortized online learners. arXiv preprint arXiv:2407.14207, 2024.\r\n[60] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are RNNs: Fast autoregressive transformers with linear attention. In International Conference on Machine Learning, pages 5156-5165. PMLR, 2020.\r\n[61] Yutao Sun, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, and Furu Wei. Retentive network: A successor to transformer for large language models. arXiv preprint arXiv:2307.08621, 2023.\r\n[62] Juergen Schmidhuber. Learning to control fast-weight memories: An alternative to dynamic recurrent networks. Neural Computation, 4(1):131-139, 1992.\r\n[63] Imanol Schlag, Kazuki Irie, and Juergen Schmidhuber. Linear transformers are secretly fast weight programmers. In International Conference on Machine Learning, pages 9355-9366. PMLR, 2021.\r\n[64] DL Prados and SC Kak. Neural network capacity using delta rule. Electronics Letters, 25(3):197-199, 1989.\r\n[65] Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, et al. Learning to (learn at test time): RNNs with expressive hidden states. arXiv preprint arXiv:2407.04620, 2024.\r\n[66] Nicholas J Higham. Functions of matrices: theory and computation. SIAM, 2008.\r\n[67] Songlin Yang, Jan Kautz, and Ali Hatamizadeh. Gated delta networks: Improving Mamba2 with delta rule. arXiv preprint arXiv:2412.06464, 2024.\r\n[68] Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, and Yoon Kim. Parallelizing linear transformers with the delta rule over sequence length. Advances in Neural Information Processing Systems, 37:115491-115522, 2024.\r\n[69] Matteo Tiezzi, Michele Casoni, Alessandro Betti, Tommaso Guidi, Marco Gori, and Stefano Melacci. On the resurgence of recurrent models for long sequences: Survey and research opportunities in the transformer era. 
arXiv preprint arXiv:2402.08132, 2024.\r\n[70] Bo Peng, Eric Alcaide, Quentin Gregory Anthony, Alon Albalak, Samuel Arcadinho, Stella Biderman, Huanqi Cao, Xin Cheng, Michael Nguyen Chung, Leon Derczynski, Xingjian Du, Matteo Grella, Kranthi Kiran GV, Xuzheng He, Haowen Hou, Przemyslaw Kazienko, Jan Kocon, Jiaming Kong, Bartłomiej Koptyra, Hayden Lau, Jiaju Lin, Krishna Sri Ipsit Mantri, Ferdinand Mom, Atsushi Saito, Guangyu Song, Xiangru Tang, Johan S. Wind, Stanisław Wozniak, Zhenyuan Zhang, Qinghua Zhou, Jian Zhu, and Rui-Jie Zhu. RWKV: Reinventing RNNs for the transformer era. In The 2023 Conference on Empirical Methods in Natural Language Processing, 2023. URL https://openreview.net/forum?id=7SaXczaBpG.\r\n[71] Jimmy T.H. Smith, Andrew Warrington, and Scott Linderman. Simplified state space layers for sequence modeling. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=Ai8Hw3AXqks.\r\n[72] Ramin Hasani, Mathias Lechner, Tsun-Hsuan Wang, Makram Chahine, Alexander Amini, and Daniela Rus. Liquid structural state-space models. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=g40TKRKfS7R.\r\n[73] Ali Behrouz, Michele Santacatterina, and Ramin Zabih. MambaMixer: Efficient selective state space models with dual token and channel selection. arXiv preprint arXiv:2403.19888, 2024.\r\n[74] Bo Peng, Daniel Goldstein, Quentin Anthony, Alon Albalak, Eric Alcaide, Stella Biderman, Eugene Cheah, Xingjian Du, Teddy Ferdinan, Haowen Hou, et al. Eagle and Finch: RWKV with matrix-valued states and dynamic recurrence. arXiv preprint arXiv:2404.05892, 2024.\r\n[75] Bo Peng, Ruichong Zhang, Daniel Goldstein, Eric Alcaide, Haowen Hou, Janna Lu, William Merrill, Guangyu Song, Kaifeng Tan, Saiteja Utpala, et al. RWKV-7 \"Goose\" with expressive dynamic state evolution. arXiv preprint arXiv:2503.14456, 2025.\r\n[76] Julien Siems, Timur Carstensen, Arber Zela, Frank Hutter, Massimiliano Pontil, and Riccardo Grazzi. DeltaProduct: Increasing the expressivity of DeltaNet through products of Householders. arXiv preprint arXiv:2502.10297, 2025.\r\n[77] John J Hopfield. Neural networks and physical systems with emergent collective computational abilities. Proceedings of the National Academy of Sciences, 79(8):2554-2558, 1982.\r\n[78] Juergen Schmidhuber. Reducing the ratio between learning complexity and number of time varying variables in fully recurrent nets. In ICANN'93: Proceedings of the International Conference on Artificial Neural Networks, Amsterdam, The Netherlands, 13-16 September 1993, pages 460-463. Springer, 1993.\r\n[79] Donald Olding Hebb. The organization of behavior: A neuropsychological theory. Psychology Press, 2005.\r\n[80] Tsendsuren Munkhdalai and Hong Yu. Neural semantic encoders. In Proceedings of the conference. Association for Computational Linguistics. Meeting, volume 1, page 397. NIH Public Access, 2017.\r\n[81] Tsendsuren Munkhdalai, Alessandro Sordoni, Tong Wang, and Adam Trischler. Metalearned neural memory. Advances in Neural Information Processing Systems, 32, 2019.\r\n[82] Kazuki Irie, Imanol Schlag, Robert Csordas, and Juergen Schmidhuber. Going beyond linear transformers with recurrent fast weight programmers. Advances in Neural Information Processing Systems, 34:7703-7717, 2021.\r\n[83] Ke Alexander Wang, Jiaxin Shi, and Emily B Fox. Test-time regression: a unifying framework for designing sequence models with associative memory. 
arXiv preprint arXiv:2501.12352, 2025.\r\n[84] Kazuki Irie, Robert Csordas, and Juergen Schmidhuber. The dual form of neural networks revisited: Connecting test time predictions to training patterns via spotlights of attention. In International Conference on Machine Learning, pages 9639-9659. PMLR, 2022.\r\n[85] Kazuki Irie, Imanol Schlag, Róbert Csordás, and Juergen Schmidhuber. A modern self-referential weight matrix that learns to modify itself. In International Conference on Machine Learning, pages 9660-9677. PMLR, 2022.\r\n[86] Jongho Park, Jaeseung Park, Zheyang Xiong, Nayoung Lee, Jaewoong Cho, Samet Oymak, Kangwook Lee, and Dimitris Papailiopoulos. Can mamba learn how to learn? A comparative study on in-context learning tasks. In Forty-first International Conference on Machine Learning, 2024. URL https://openreview.net/forum?id=GbFluKMmtE.\r\n[87] Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models. In International Conference on Learning Representations, 2017. URL https://openreview.net/forum?id=Byj72udxe.\r\n[88] Denis Paperno, German Kruszewski, Angeliki Lazaridou, Ngoc Quan Pham, Raffaella Bernardi, Sandro Pezzelle, Marco Baroni, Gemma Boleda, and Raquel Fernandez. The LAMBADA dataset: Word prediction requiring a broad discourse context. In Katrin Erk and Noah A. Smith, editors, Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1525-1534, Berlin, Germany, August 2016. Association for Computational Linguistics. doi: 10.18653/v1/P16-1144. URL https://aclanthology.org/P16-1144/.\r\n[89] Yonatan Bisk, Rowan Zellers, Jianfeng Gao, Yejin Choi, et al. PIQA: Reasoning about physical commonsense in natural language. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pages 7432-7439, 2020.\r\n[90] Rowan Zellers, Ari Holtzman, Yonatan Bisk, Ali Farhadi, and Yejin Choi. HellaSwag: Can a machine really finish your sentence? In Anna Korhonen, David Traum, and Lluis Marquez, editors, Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 4791-4800, Florence, Italy, July 2019. Association for Computational Linguistics. doi: 10.18653/v1/P19-1472. URL https://aclanthology.org/P19-1472/.\r\n[91] Keisuke Sakaguchi, Ronan Le Bras, Chandra Bhagavatula, and Yejin Choi. WinoGrande: An adversarial Winograd schema challenge at scale. Communications of the ACM, 64(9):99-106, 2021.\r\n[92] Peter Clark, Isaac Cowhey, Oren Etzioni, Tushar Khot, Ashish Sabharwal, Carissa Schoenick, and Oyvind Tafjord. Think you have solved question answering? Try ARC, the AI2 reasoning challenge. arXiv preprint arXiv:1803.05457, 2018.\r\n[93] Maarten Sap, Hannah Rashkin, Derek Chen, Ronan Le Bras, and Yejin Choi. Social IQa: Commonsense reasoning about social interactions. In Kentaro Inui, Jing Jiang, Vincent Ng, and Xiaojun Wan, editors, Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pages 4463-4473, Hong Kong, China, November 2019. Association for Computational Linguistics. doi: 10.18653/v1/D19-1454. URL https://aclanthology.org/D19-1454/.\r\n[94] Christopher Clark, Kenton Lee, Ming-Wei Chang, Tom Kwiatkowski, Michael Collins, and Kristina Toutanova. BoolQ: Exploring the surprising difficulty of natural yes/no questions. 
In Jill Burstein, Christy Doran, and Thamar Solorio, editors, Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 2924-2936, Minneapolis, Minnesota, June 2019. Association for Computational Linguistics. doi: 10.18653/v1/N19-1300. URL https://aclanthology.org/N19-1300/.\r\n[95] Michael Poli, Armin W Thomas, Eric Nguyen, Pragaash Ponnusamy, Björn Deiseroth, Kristian Kersting, Taiji Suzuki, Brian Hie, Stefano Ermon, Christopher Ré, et al. Mechanistic design and scaling of hybrid architectures. arXiv preprint arXiv:2403.17844, 2024.\r\n\r\n# NeurIPS Paper Checklist \r\n\r\n## 1. Claims\r\n\r\nQuestion: Do the main claims made in the abstract and introduction accurately reflect the paper's contributions and scope?\r\nAnswer: [Yes]\r\nJustification: The abstract and introduction both accurately state the main theoretical contributions of the paper and accurately discuss our experimental results.\r\nGuidelines:\r\n\r\n- The answer NA means that the abstract and introduction do not include the claims made in the paper.\r\n- The abstract and/or introduction should clearly state the claims made, including the contributions made in the paper and important assumptions and limitations. A No or NA answer to this question will not be perceived well by the reviewers.\r\n- The claims made should match theoretical and experimental results, and reflect how much the results can be expected to generalize to other settings.\r\n- It is fine to include aspirational goals as motivation as long as it is clear that these goals are not attained by the paper.\r\n\r\n\r\n## 2. Limitations\r\n\r\nQuestion: Does the paper discuss the limitations of the work performed by the authors?\r\nAnswer: [Yes]\r\nJustification: We discuss the limitations of our method in the main text, including its scope, assumptions, and potential areas for future improvement.\r\nGuidelines:\r\n\r\n- The answer NA means that the paper has no limitation while the answer No means that the paper has limitations, but those are not discussed in the paper.\r\n- The authors are encouraged to create a separate \"Limitations\" section in their paper.\r\n- The paper should point out any strong assumptions and how robust the results are to violations of these assumptions (e.g., independence assumptions, noiseless settings, model well-specification, asymptotic approximations only holding locally). The authors should reflect on how these assumptions might be violated in practice and what the implications would be.\r\n- The authors should reflect on the scope of the claims made, e.g., if the approach was only tested on a few datasets or with a few runs. In general, empirical results often depend on implicit assumptions, which should be articulated.\r\n- The authors should reflect on the factors that influence the performance of the approach. For example, a facial recognition algorithm may perform poorly when image resolution is low or images are taken in low lighting. 
Or a speech-to-text system might not be used reliably to provide closed captions for online lectures because it fails to handle technical jargon.\r\n- The authors should discuss the computational efficiency of the proposed algorithms and how they scale with dataset size.\r\n- If applicable, the authors should discuss possible limitations of their approach to address problems of privacy and fairness.\r\n- While the authors might fear that complete honesty about limitations might be used by reviewers as grounds for rejection, a worse outcome might be that reviewers discover limitations that aren't acknowledged in the paper. The authors should use their best judgment and recognize that individual actions in favor of transparency play an important role in developing norms that preserve the integrity of the community. Reviewers will be specifically instructed to not penalize honesty concerning limitations.\r\n\r\n\r\n## 3. Theory assumptions and proofs\r\n\r\nQuestion: For each theoretical result, does the paper provide the full set of assumptions and a complete (and correct) proof?\r\nAnswer: [Yes]\r\nJustification: All assumptions are stated in the body of the theorems directly. As much of the proofs as possible is provided in the main text, and we provide detailed proofs of all lemmas and theorems introduced in this work.\r\nGuidelines:\r\n\r\n- The answer NA means that the paper does not include theoretical results.\r\n- All the theorems, formulas, and proofs in the paper should be numbered and cross-referenced.\r\n- All assumptions should be clearly stated or referenced in the statement of any theorems.\r\n- The proofs can either appear in the main paper or the supplemental material, but if they appear in the supplemental material, the authors are encouraged to provide a short proof sketch to provide intuition.\r\n- Inversely, any informal proof provided in the core of the paper should be complemented by formal proofs provided in appendix or supplemental material.\r\n- Theorems and Lemmas that the proof relies upon should be properly referenced.\r\n\r\n\r\n## 4. Experimental result reproducibility\r\n\r\nQuestion: Does the paper fully disclose all the information needed to reproduce the main experimental results of the paper to the extent that it affects the main claims and/or conclusions of the paper (regardless of whether the code and data are provided or not)?\r\nAnswer: [Yes]\r\nJustification: This paper fully discloses all the information needed to reproduce the main experimental results.\r\nGuidelines:\r\n\r\n- The answer NA means that the paper does not include experiments.\r\n- If the paper includes experiments, a No answer to this question will not be perceived well by the reviewers: Making the paper reproducible is important, regardless of whether the code and data are provided or not.\r\n- If the contribution is a dataset and/or model, the authors should describe the steps taken to make their results reproducible or verifiable.\r\n- Depending on the contribution, reproducibility can be accomplished in various ways. For example, if the contribution is a novel architecture, describing the architecture fully might suffice, or if the contribution is a specific model and empirical evaluation, it may be necessary to either make it possible for others to replicate the model with the same dataset, or provide access to the model. In general, 
releasing code and data is often one good way to accomplish this, but reproducibility can also be provided via detailed instructions for how to replicate the results, access to a hosted model (e.g., in the case of a large language model), releasing of a model checkpoint, or other means that are appropriate to the research performed.\r\n- While NeurIPS does not require releasing code, the conference does require all submissions to provide some reasonable avenue for reproducibility, which may depend on the nature of the contribution. For example:\r\n(a) If the contribution is primarily a new algorithm, the paper should make it clear how to reproduce that algorithm.\r\n(b) If the contribution is primarily a new model architecture, the paper should describe the architecture clearly and fully.\r\n(c) If the contribution is a new model (e.g., a large language model), then there should either be a way to access this model for reproducing the results or a way to reproduce the model (e.g., with an open-source dataset or instructions for how to construct the dataset).\r\n(d) We recognize that reproducibility may be tricky in some cases, in which case authors are welcome to describe the particular way they provide for reproducibility. In the case of closed-source models, it may be that access to the model is limited in some way (e.g., to registered users), but it should be possible for other researchers to have some path to reproducing or verifying the results.\r\n\r\n## 5. Open access to data and code\r\n\r\nQuestion: Does the paper provide open access to the data and code, with sufficient instructions to faithfully reproduce the main experimental results, as described in supplemental material?\r\nAnswer: [Yes]\r\nJustification: We plan to provide data and code via GitHub after the paper is made publicly available online.\r\nGuidelines:\r\n\r\n- The answer NA means that the paper does not include experiments requiring code.\r\n- Please see the NeurIPS code and data submission guidelines (https://nips.cc/public/guides/CodeSubmissionPolicy) for more details.\r\n- While we encourage the release of code and data, we understand that this might not be possible, so \"No\" is an acceptable answer. Papers cannot be rejected simply for not including code, unless this is central to the contribution (e.g., for a new open-source benchmark).\r\n- The instructions should contain the exact command and environment needed to run to reproduce the results. See the NeurIPS code and data submission guidelines (https://nips.cc/public/guides/CodeSubmissionPolicy) for more details.\r\n- The authors should provide instructions on data access and preparation, including how to access the raw data, preprocessed data, intermediate data, and generated data, etc.\r\n- The authors should provide scripts to reproduce all experimental results for the new proposed method and baselines. If only a subset of experiments are reproducible, they should state which ones are omitted from the script and why.\r\n- At submission time, to preserve anonymity, the authors should release anonymized versions (if applicable).\r\n- Providing as much information as possible in supplemental material (appended to the paper) is recommended, but including URLs to data and code is permitted.\r\n\r\n\r\n## 6. Experimental setting/details\r\n\r\nQuestion: Does the paper specify all the training and test details (e.g., data splits, hyperparameters, how they were chosen, type of optimizer, etc.) 
necessary to understand the results?\r\nAnswer: [Yes]\r\nJustification: Detailed descriptions of the experiment and the associated code are provided in the supplementary materials.\r\nGuidelines:\r\n\r\n- The answer NA means that the paper does not include experiments.\r\n- The experimental setting should be presented in the core of the paper to a level of detail that is necessary to appreciate the results and make sense of them.\r\n- The full details can be provided either with the code, in appendix, or as supplemental material.\r\n\r\n\r\n## 7. Experiment statistical significance\r\n\r\nQuestion: Does the paper report error bars suitably and correctly defined or other appropriate information about the statistical significance of the experiments?\r\nAnswer: [Yes]\r\nJustification: To account for variability and support statistical validity, we report standard deviations computed over multiple runs.\r\nGuidelines:\r\n\r\n- The answer NA means that the paper does not include experiments.\r\n- The authors should answer \"Yes\" if the results are accompanied by error bars, confidence intervals, or statistical significance tests, at least for the experiments that support the main claims of the paper.\r\n- The factors of variability that the error bars are capturing should be clearly stated (for example, train/test split, initialization, random drawing of some parameter, or overall run with given experimental conditions).\r\n- The method for calculating the error bars should be explained (closed form formula, call to a library function, bootstrap, etc.)\r\n- The assumptions made should be given (e.g., Normally distributed errors).\r\n- It should be clear whether the error bar is the standard deviation or the standard error of the mean.\r\n- It is OK to report 1-sigma error bars, but one should state it. The authors should preferably report a 2-sigma error bar rather than state that they have a $96 \\%$ CI, if the hypothesis of Normality of errors is not verified.\r\n- For asymmetric distributions, the authors should be careful not to show in tables or figures symmetric error bars that would yield results that are out of range (e.g. negative error rates).\r\n- If error bars are reported in tables or plots, the authors should explain in the text how they were calculated and reference the corresponding figures or tables in the text.\r\n\r\n\r\n## 8. Experiments compute resources\r\n\r\nQuestion: For each experiment, does the paper provide sufficient information on the computer resources (type of compute workers, memory, time of execution) needed to reproduce the experiments?\r\nAnswer: [Yes]\r\nJustification: To facilitate reproducibility, we report accelerator type, memory capacity, and runtime for key experiments.\r\nGuidelines:\r\n\r\n- The answer NA means that the paper does not include experiments.\r\n- The paper should indicate the type of compute workers CPU or GPU, internal cluster, or cloud provider, including relevant memory and storage.\r\n- The paper should provide the amount of compute required for each of the individual experimental runs as well as estimate the total compute.\r\n- The paper should disclose whether the full research project required more compute than the experiments reported in the paper (e.g., preliminary or failed experiments that didn't make it into the paper).\r\n\r\n\r\n## 9. 
Code of ethics\r\n\r\nQuestion: Does the research conducted in the paper conform, in every respect, with the NeurIPS Code of Ethics https://neurips.cc/public/EthicsGuidelines?\r\nAnswer: [Yes]\r\nJustification: There are no violations of the NeurIPS Code of Ethics in this paper.\r\nGuidelines:\r\n\r\n- The answer NA means that the authors have not reviewed the NeurIPS Code of Ethics.\r\n- If the authors answer No, they should explain the special circumstances that require a deviation from the Code of Ethics.\r\n- The authors should make sure to preserve anonymity (e.g., if there is a special consideration due to laws or regulations in their jurisdiction).\r\n\r\n\r\n## 10. Broader impacts\r\n\r\nQuestion: Does the paper discuss both potential positive societal impacts and negative societal impacts of the work performed?\r\nAnswer: [Yes]\r\nJustification: We provide a discussion in the Appendix about the impact of the work.\r\nGuidelines:\r\n\r\n- The answer NA means that there is no societal impact of the work performed.\r\n- If the authors answer NA or No, they should explain why their work has no societal impact or why the paper does not address societal impact.\r\n- Examples of negative societal impacts include potential malicious or unintended uses (e.g., disinformation, generating fake profiles, surveillance), fairness considerations (e.g., deployment of technologies that could make decisions that unfairly impact specific groups), privacy considerations, and security considerations.\r\n- The conference expects that many papers will be foundational research and not tied to particular applications, let alone deployments. However, if there is a direct path to any negative applications, the authors should point it out. For example, it is legitimate to point out that an improvement in the quality of generative models could be used to generate deepfakes for disinformation. On the other hand, it is not needed to point out that a generic algorithm for optimizing neural networks could enable people to train models that generate Deepfakes faster.\r\n- The authors should consider possible harms that could arise when the technology is being used as intended and functioning correctly, harms that could arise when the technology is being used as intended but gives incorrect results, and harms following from (intentional or unintentional) misuse of the technology.\r\n- If there are negative societal impacts, the authors could also discuss possible mitigation strategies (e.g., gated release of models, providing defenses in addition to attacks, mechanisms for monitoring misuse, mechanisms to monitor how a system learns from feedback over time, improving the efficiency and accessibility of ML).\r\n\r\n\r\n## 11. Safeguards\r\n\r\nQuestion: Does the paper describe safeguards that have been put in place for responsible release of data or models that have a high risk for misuse (e.g., pretrained language models, image generators, or scraped datasets)?\r\nAnswer: [NA]\r\nJustification: This paper poses no such risks.\r\nGuidelines:\r\n\r\n- The answer NA means that the paper poses no such risks.\r\n- Released models that have a high risk for misuse or dual-use should be released with necessary safeguards to allow for controlled use of the model, for example by requiring that users adhere to usage guidelines or restrictions to access the model or implementing safety filters.\r\n- Datasets that have been scraped from the Internet could pose safety risks. 
The authors should describe how they avoided releasing unsafe images.\r\n- We recognize that providing effective safeguards is challenging, and many papers do not require this, but we encourage authors to take this into account and make a best faith effort.\r\n\r\n\r\n## 12. Licenses for existing assets\r\n\r\nQuestion: Are the creators or original owners of assets (e.g., code, data, models), used in the paper, properly credited and are the license and terms of use explicitly mentioned and properly respected?\r\nAnswer: [Yes]\r\nJustification: We properly credit all external assets and explicitly acknowledge their licenses and terms of use.\r\nGuidelines:\r\n\r\n- The answer NA means that the paper does not use existing assets.\r\n- The authors should cite the original paper that produced the code package or dataset.\r\n- The authors should state which version of the asset is used and, if possible, include a URL.\r\n- The name of the license (e.g., CC-BY 4.0) should be included for each asset.\r\n- For scraped data from a particular source (e.g., website), the copyright and terms of service of that source should be provided.\r\n- If assets are released, the license, copyright information, and terms of use in the package should be provided. For popular datasets, paperswithcode.com/datasets has curated licenses for some datasets. Their licensing guide can help determine the license of a dataset.\r\n- For existing datasets that are re-packaged, both the original license and the license of the derived asset (if it has changed) should be provided.\r\n- If this information is not available online, the authors are encouraged to reach out to the asset's creators.\r\n\r\n\r\n## 13. New assets\r\n\r\nQuestion: Are new assets introduced in the paper well documented and is the documentation provided alongside the assets?\r\nAnswer: [NA]\r\nJustification: This paper does not introduce new assets.\r\nGuidelines:\r\n\r\n- The answer NA means that the paper does not release new assets.\r\n- Researchers should communicate the details of the dataset/code/model as part of their submissions via structured templates. This includes details about training, license, limitations, etc.\r\n- The paper should discuss whether and how consent was obtained from people whose asset is used.\r\n- At submission time, remember to anonymize your assets (if applicable). You can either create an anonymized URL or include an anonymized zip file.\r\n\r\n\r\n## 14. Crowdsourcing and research with human subjects\r\n\r\nQuestion: For crowdsourcing experiments and research with human subjects, does the paper include the full text of instructions given to participants and screenshots, if applicable, as well as details about compensation (if any)?\r\nAnswer: [NA]\r\nJustification: This paper does not involve such experiments.\r\nGuidelines:\r\n\r\n- The answer NA means that the paper does not involve crowdsourcing nor research with human subjects.\r\n- Including this information in the supplemental material is fine, but if the main contribution of the paper involves human subjects, then as much detail as possible should be included in the main paper.\r\n- According to the NeurIPS Code of Ethics, workers involved in data collection, curation, or other labor should be paid at least the minimum wage in the country of the data collector.\r\n\r\n\r\n## 15. 
Institutional review board (IRB) approvals or equivalent for research with human subjects\r\n\r\nQuestion: Does the paper describe potential risks incurred by study participants, whether such risks were disclosed to the subjects, and whether Institutional Review Board (IRB) approvals (or an equivalent approval/review based on the requirements of your country or institution) were obtained?\r\nAnswer: [NA]\r\nJustification: This paper does not involve any user studies.\r\nGuidelines:\r\n\r\n- The answer NA means that the paper does not involve crowdsourcing nor research with human subjects.\r\n- Depending on the country in which research is conducted, IRB approval (or equivalent) may be required for any human subjects research. If you obtained IRB approval, you should clearly state this in the paper.\r\n- We recognize that the procedures for this may vary significantly between institutions and locations, and we expect authors to adhere to the NeurIPS Code of Ethics and the guidelines for their institution.\r\n- For initial submissions, do not include any information that would break anonymity (if applicable), such as the institution conducting the review.\r\n\r\n\r\n## 16. Declaration of LLM usage\r\n\r\nQuestion: Does the paper describe the usage of LLMs if it is an important, original, or non-standard component of the core methods in this research? Note that if the LLM is used only for writing, editing, or formatting purposes and does not impact the core methodology, scientific rigorousness, or originality of the research, declaration is not required.\r\nAnswer: [NA]\r\nJustification: No usage of LLMs in the methodology or experiments.\r\nGuidelines:\r\n\r\n- The answer NA means that the core method development in this research does not involve LLMs as any important, original, or non-standard components.\r\n- Please refer to our LLM policy (https://neurips.cc/Conferences/2025/LLM) for what should or should not be described."
  },
  {
    "path": "google_papers/TITANs/TITANs.json",
    "content": "{\r\n  \"pages\": [\r\n    {\r\n      \"index\": 0,\r\n      \"markdown\": \"# Titans: Learning to Memorize at Test Time \\n\\nAli Behrouz ${ }^{\\\\dagger}$, Peilin Zhong ${ }^{\\\\dagger}$, and Vahab Mirrokni ${ }^{\\\\dagger}$<br>$\\\\dagger$ Google Research<br>\\\\{alibehrouz, peilinz, mirrokni\\\\}@google.com\\n\\n\\n#### Abstract\\n\\nOver more than a decade there has been an extensive research effort of how effectively utilize recurrent models and attentions. While recurrent models aim to compress the data into a fixed-size memory (called hidden state), attention allows attending to the entire context window, capturing the direct dependencies of all tokens. This more accurate modeling of dependencies, however, comes with a quadratic cost, limiting the model to a fixed-length context. We present a new neural long-term memory module that learns to memorize historical context and helps an attention to attend to the current context while utilizing long past information. We show that this neural memory has the advantage of a fast parallelizable training while maintaining a fast inference. From a memory perspective, we argue that attention due to its limited context but accurate dependency modeling performs as a short-term memory, while neural memory due to its ability to memorize the data, acts as a long-term, more persistent, memory. Based on these two modules, we introduce a new family of architectures, called Titans, and present three variants to address how one can effectively incorporate memory into this architecture. Our experimental results on language modeling, common-sense reasoning, genomics, and time series tasks show that Titans are more effective than Transformers and recent modern linear recurrent models. They further can effectively scale to larger than 2 M context window size with higher accuracy in needle-in-haystack tasks compared to baselines.\\n\\n\\n## 1 Introduction\\n\\n\\\"The true art of memory is the art of attention!\\\"\\n\\n- Samuel Johnson, 1787\\n\\nTransformers, pure attention-based architectures (Vaswani et al. 2017), have been firmly established as state-of-the-art models in sequence modeling, mainly due to their in-context learning and ability to learn at scale (Kaplan et al. 2020). The primary building blocks of Transformers-attention modules-function as associative memory blocks (Bietti et al. 2024), where they learn to store key-value associations and retrieve them by computing pairwise similarity between queries (i.e., search signals) and keys (i.e., contexts). Accordingly, by design, the output of a Transformer is exclusively conditioned on the direct dependencies of tokens in the current context window. This accurate modeling of dependencies, however, comes with quadratic time and memory complexity in terms of the context length. In complex real-world tasks (e.g., language modeling (N. F. Liu et al. 2024), video understanding (C.-Y. Wu et al. 2019), long-term time series forecasting (H. Zhou et al. 2021)), the context window can become extremely large, making the applicability of Transformers challenging in these downstream tasks.\\n\\nTo overcome the scalability issue of Transformers, recent studies aim to design different variants of linear Transformers (Kacham, Mirrokni, and P. Zhong 2024; Katharopoulos et al. 2020; S. Yang, B. Wang, Shen, et al. 2024), where softmax is replaced by a kernel function in the attention (see $\\\\S 2.1$ for details), resulting in a significant drop in memory consumption. 
Despite efficiency and the ability to scale to longer contexts, linear Transformers do not show competitive performance compared to Transformers, as the kernel trick makes the model a linear recurrent network, in which the data is compressed into matrix-valued states (Katharopoulos et al. 2020). This, however, brings a contradictory fact about linear recurrent (or linear Transformer) models: On one hand, we use these linear models to enhance scalability and efficiency (linear vs. quadratic complexity), whose advantages appear for very long contexts; On the other hand, a very long context cannot be properly compressed in small vector-valued or matrix-valued states (S. Wang 2024).\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 1,\r\n      \"markdown\": \"Furthermore, beyond efficiency, most existing architectures-ranging from Hopfield Networks (Hopfield 1982) to LSTMs (Jürgen Schmidhuber and Hochreiter 1997) and Transformers (Vaswani et al. 2017)-face challenges when dealing with generalization, length extrapolation, and/or reasoning (Anil et al. 2022; Qin, Y. Zhong, and Deng 2024), all of which are inseparable parts of many hard real-world tasks. Although these architectures draw inspiration from the human brain, each of them is missing: (1) a crucial component of the learning process-such as short-term memory, long-term memory, meta-memory, attending to current context, etc. (Cowan 2008); (2) how these components are interconnected systems that can operate independently; and/or (3) the ability to actively learn from data and memorize the abstraction of past history. We argue that in an effective learning paradigm, similar to the human brain, there are distinct yet interconnected modules, each of which is responsible for a component crucial to the learning process.\\n\\n# Memory Perspective \\n\\nMemory is a fundamental mental process and is an inseparable component of human learning (Terry 2017). Without a properly functioning memory system, humans and animals would be restricted to basic reflexes and stereotyped behaviors. Accordingly, memory has been the inspiration for much seminal research in the machine learning literature; e.g., Hopfield Networks (Hopfield 1982), LSTMs (Jürgen Schmidhuber and Hochreiter 1997), and Transformers (Vaswani et al. 2017).\\n\\nTaking inspiration from the common definitions of memory and learning in the neuropsychology literature (Okano, Hirano, and Balaban 2000), most existing architectures consider memory as a neural update caused by an input, and define learning as a process for acquiring effective and useful memory, given an objective. In this perspective, Recurrent Neural Networks (RNNs) (Williams and Zipser 1989) can be defined as models with a vector-valued memory module $\\\\mathcal{M}$ (also called hidden state) with two main steps: Given a new input $x_{t}$ at time $t$, the model (1) updates the memory using a function $f\\\\left(\\\\mathcal{M}_{t-1}, x_{t}\\\\right)$ (with compression); and (2) retrieves the corresponding memory of the input using a function $g\\\\left(\\\\mathcal{M}_{t}, x_{t}\\\\right)$ (see $\\\\S 2.1$ for details). Similarly, Transformers can be seen as architectures with a growing memory and two similar steps. 
That is, the pair of key and value matrices acts as the model's memory, and the model: (1) updates the memory by appending the key and value to the memory (without compression), and (2) retrieves query vectors' corresponding memory by finding the similarity of query and key vectors, which is then used to weight the value vectors for the output.\\n\\nThis perspective can help us better understand existing paradigms, their critical differences, and design more effective architectures. For example, the main difference between Transformers (Vaswani et al. 2017) and linear Transformers (Katharopoulos et al. 2020) is the memory structure as well as the memory updating step, in which linear Transformers compress the historical data into a fixed-size matrix-valued memory while Transformers keep all historical data (within the context length) without any compression. While both linear Transformers and linear RNNs (including state space models) compress the information in the memory update step, the critical difference lies in the structure of the memory, where linear RNNs (vs. linear Transformers) use a vector-valued memory (vs. matrix-valued memory). Therefore, this perspective motivates us to ask: (Q1) What constitutes a good structure for the memory? (Q2) What is a proper memory update mechanism? and (Q3) What is a good memory retrieval process?\\n\\nRevisiting our understanding of human memory, it is neither a unitary process nor does it serve a single function (Cowan 2008). In fact, memory is a confederation of systems-e.g., short-term, working, and long-term memory-each serving a different function with different neural structures, and each capable of operating independently (Willingham 1997). This fact motivates us to ask: (Q4) How can we design an efficient architecture that incorporates different interconnected memory modules? Finally, storing a memory is a neural process that requires encoding and storing the abstraction of the past. It can be an over-simplification to assume that a single vector or matrix, whose parameters encode the data in a linear manner, is enough for storing long-term history. (Q5) Is a deep memory module needed to effectively store/remember the long past?\\n\\n## Contributions and Roadmap\\n\\nIn this paper, we aim to answer the above five questions by designing a long-term neural memory module that can efficiently and effectively learn to memorize at test time. Building upon its design, we discuss how it can be incorporated into an architecture.\\n\\nNeural Memory (§3). We present a (deep) neural long-term memory that (as a meta in-context model) learns how to memorize/store the data into its parameters at test time. Inspired by the human long-term memory system (Mandler 2014),\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 2,\r\n      \"markdown\": \"we design this memory module so that an event that violates expectations (i.e., is surprising) is more memorable. To this end, we measure the surprise of an input with the gradient of the neural network with respect to the input in the associative memory loss (see $\\\\S 3.1$ for details). To better handle the limited memory, we present a decaying mechanism that considers the proportion of the memory size and the amount of data surprise, resulting in better memory management. 
We show that this decay mechanism is in fact a generalization of the forgetting mechanism in modern recurrent models (Dao and Gu 2024; Gu and Dao 2024; S. Yang, Kautz, and Hatamizadeh 2024). Interestingly, we find that this mechanism is equivalent to optimizing a meta neural network with mini-batch gradient descent, momentum, and weight decay. Building upon tensorizing mini-batch gradient descent to use more matmul operations (Yu Sun et al. 2024), we present a fast and parallelizable algorithm to train our deep neural long-term memory.\\n\\nTitans Architectures (§4). After designing the long-term neural memory, an important remaining question is how to effectively and efficiently incorporate memory into a deep learning architecture. We present Titans, a family of deep models that consists of three hyper-heads: (1) Core: this module consists of the short-term memory, and is responsible for the main flow of processing the data (we use attention with limited window size); (2) Long-term Memory: this branch is our neural long-term memory module that is responsible for storing/remembering the long past; (3) Persistent Memory: this is a set of learnable but data-independent parameters that encodes the knowledge about a task. Finally, as a proof of concept, we present three variants of Titans, in which we incorporate memory as: (i) a context, (ii) a layer, and (iii) a gated branch.\\n\\nExperimental Results (§5). We perform experimental evaluations on language modeling, commonsense reasoning, recall-intensive, needle-in-haystack, time series forecasting, and DNA modeling tasks. We observe that our Titans architecture outperforms all modern recurrent models as well as their hybrid variants (combining with sliding-window attention) across a comprehensive set of benchmarks. Furthermore, Titans outperform Transformers with the same context window, and show competitive performance with Transformers that use the entire context. These results are achieved while, contrary to Transformers, Titans scale to context window sizes larger than 2 M.\\n\\n# 2 Preliminaries \\n\\nIn this section, we discuss the notation and some background concepts that we use throughout the paper. We let $x \\\\in \\\\mathbb{R}^{N \\\\times d_{\\\\text {in }}}$ be the input, $\\\\mathcal{M}$ be a neural network (neural memory module), $\\\\mathbf{Q}, \\\\mathbf{K}, \\\\mathbf{V}$ be the query, key and value of the attention mechanism, and $\\\\mathbf{M}$ be the attention mask. When segmenting the sequence, we use $S^{(i)}$ to refer to the $i$-th segment. Throughout the paper, we abuse the notation and use subscripts to refer to a specific element of a matrix, vector, or segments. For example, we let $S_{j}^{(i)}$ be the $j$-th token in the $i$-th segment. The only exception is subscripts with $t$, which we reserve to index recurrence over time, or the state of a neural network at time $t$. Given a neural network $\\\\mathcal{N}$ and a data sample $x$, we use $\\\\mathcal{N}(x)$ (resp. $\\\\mathcal{N}^{*}(x)$ ) to refer to the forward pass with (resp. without) weight adjustment. Also, we abuse the notation and use $\\\\mathcal{N}^{(k)}$ to refer to the $k$-th layer of the neural network. In the following, we first discuss the background for attention and its efficient variants, followed by a review of modern linear RNNs. Finally, we discuss a memory perspective of these architectures that motivates us to design Titans.\\n\\n### 2.1 Backgrounds\\n\\nAttention. Transformers (Vaswani et al. 2017), as the de facto backbone for many deep learning models, are based on the attention mechanism. Given input $x \\\\in \\\\mathbb{R}^{N \\\\times d_{\\\\text {in }}}$, causal attention computes output $\\\\mathbf{y} \\\\in \\\\mathbb{R}^{N \\\\times d_{\\\\text {in }}}$ based on softmax over input-dependent key, value, and query matrices:\\n\\n$$\\n\\\\begin{gathered}\\n\\\\mathbf{Q}=x \\\\mathbf{W}_{\\\\mathbf{Q}}, \\\\quad \\\\mathbf{K}=x \\\\mathbf{W}_{\\\\mathbf{K}}, \\\\quad \\\\mathbf{V}=x \\\\mathbf{W}_{\\\\mathbf{V}} \\\\\\\\\\n\\\\mathbf{y}_{i}=\\\\sum_{j=1}^{i} \\\\frac{\\\\exp \\\\left(\\\\mathbf{Q}_{i}^{\\\\top} \\\\mathbf{K}_{j} / \\\\sqrt{d_{\\\\text {in }}}\\\\right) \\\\mathbf{V}_{j}}{\\\\sum_{\\\\ell=1}^{i} \\\\exp \\\\left(\\\\mathbf{Q}_{i}^{\\\\top} \\\\mathbf{K}_{\\\\ell} / \\\\sqrt{d_{\\\\text {in }}}\\\\right)}\\n\\\\end{gathered}\\n$$\\n\\nwhere $\\\\mathbf{W}_{\\\\mathbf{Q}}, \\\\mathbf{W}_{\\\\mathbf{K}}$, and $\\\\mathbf{W}_{\\\\mathbf{V}} \\\\in \\\\mathbb{R}^{d_{\\\\text {in }} \\\\times d_{\\\\text {in }}}$ are learnable parameters. Despite the power and effectiveness in recall, transformers need at least $N \\\\times d$ operators to calculate the output, resulting in larger memory consumption and lower throughput for longer sequences.
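\\n\\nAs a concrete illustration, the causal attention above can be written in a few lines of NumPy. This is a minimal sketch of the equations (ours, not code from the paper); the function name and the dense N x N materialization of the score matrix are illustrative choices only:\\n\\n```python\\nimport numpy as np\\n\\ndef causal_softmax_attention(x, Wq, Wk, Wv):\\n    # x: (N, d_in); Wq, Wk, Wv: (d_in, d_in) learnable projections\\n    Q, K, V = x @ Wq, x @ Wk, x @ Wv\\n    scores = Q @ K.T / np.sqrt(x.shape[1])    # pairwise query-key similarity\\n    causal = np.tril(np.ones_like(scores))    # token i may attend to j <= i\\n    scores = np.where(causal > 0, scores, -np.inf)\\n    w = np.exp(scores - scores.max(axis=-1, keepdims=True))\\n    w = w / w.sum(axis=-1, keepdims=True)     # row-wise softmax\\n    return w @ V                              # y_i = sum_j a_ij V_j\\n```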
\\n\\nEfficient Attentions. To improve the memory consumption and throughput of softmax attention for longer sequences, various studies focused on I/O-aware implementations of attention (Dao 2024; Dao, D. Fu, et al. 2022), designing more\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 3,\r\n      \"markdown\": \"efficient attention mechanisms by sparsifying the attention matrix (B. Chen et al. 2021; Choromanski et al. 2021; Dai et al. 2019), approximating the softmax (Arora et al. 2024), or developing kernel-based (linear) attentions (Aksenov et al. 2024; Kacham, Mirrokni, and P. Zhong 2024; Schlag, Irie, and Jürgen Schmidhuber 2021; S. Yang, B. Wang, Shen, et al. 2024). In this part, we focus on the latter, i.e., linear attentions, where the softmax in standard attention is replaced with an alternative kernel function $\\\\phi(\\\\cdot, \\\\cdot)$, such that $\\\\phi(x, y)=\\\\phi(x) \\\\phi(y)$. Accordingly, the attention can be written as:\\n\\n$$\\n\\\\mathbf{y}_{i}=\\\\sum_{j=1}^{i} \\\\frac{\\\\phi\\\\left(Q_{i}^{\\\\top} K_{j}\\\\right)}{\\\\sum_{\\\\ell=1}^{i} \\\\phi\\\\left(Q_{i}^{\\\\top} K_{\\\\ell}\\\\right)} V_{j}=\\\\sum_{j=1}^{i} \\\\frac{\\\\phi\\\\left(Q_{i}\\\\right)^{\\\\top} \\\\phi\\\\left(K_{j}\\\\right)}{\\\\sum_{\\\\ell=1}^{i} \\\\phi\\\\left(Q_{i}\\\\right)^{\\\\top} \\\\phi\\\\left(K_{\\\\ell}\\\\right)} V_{j}=\\\\frac{\\\\phi\\\\left(Q_{i}\\\\right)^{\\\\top} \\\\sum_{j=1}^{i} \\\\phi\\\\left(K_{j}\\\\right) V_{j}}{\\\\phi\\\\left(Q_{i}\\\\right)^{\\\\top} \\\\sum_{\\\\ell=1}^{i} \\\\phi\\\\left(K_{\\\\ell}\\\\right)}\\n$$\\n\\nresulting in a higher throughput as the terms $\\\\sum_{j=1}^{i} \\\\phi\\\\left(K_{j}\\\\right)$ and $\\\\sum_{\\\\ell=1}^{i} \\\\phi\\\\left(K_{\\\\ell}\\\\right)$ are reused at each step. When choosing the kernel as the identity matrix (Yutao Sun et al. 2023), the above formulation can also be written in a recurrent format:\\n\\n$$\\n\\\\begin{aligned}\\n& \\\\mathcal{M}_{t}=\\\\mathcal{M}_{t-1}+K_{t}^{\\\\top} V_{t} \\\\\\\\\\n& \\\\mathbf{y}_{t}=Q_{t} \\\\mathcal{M}_{t}\\n\\\\end{aligned}\\n$$\\n\\nwhich allows efficient inference for linear attentions.
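\\n\\nTo make this recurrence concrete, the following is a minimal NumPy sketch (ours, not the paper's implementation; it assumes an identity kernel, a single head, and no normalization):\\n\\n```python\\nimport numpy as np\\n\\ndef linear_attention_recurrent(Q, K, V):\\n    # Q, K, V: (N, d) projections; M is the fixed-size matrix-valued memory\\n    N, d = Q.shape\\n    M = np.zeros((d, d))\\n    Y = np.zeros_like(V)\\n    for t in range(N):\\n        M = M + np.outer(K[t], V[t])   # write: M_t = M_{t-1} + K_t^T V_t\\n        Y[t] = Q[t] @ M                # read:  y_t = Q_t M_t\\n    return Y\\n```\\n\\nEach step touches only the constant-size state M, which is why the inference memory cost does not grow with the sequence length.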
Modern Linear Models and Their Memory Perspective. As discussed earlier, one can define learning as a process for acquiring effective and useful memory. Building upon this, one can see the hidden state of Recurrent Neural Networks (RNNs) as a memory unit into which the model aims to compress the information. Accordingly, in a general form of recurrent neural network, the hidden state can be treated as a memory unit and the recurrence can be split into read and write operations on that memory unit. That is, letting $x \in \mathbb{R}^{N \times d_{\text{in}}}$ be the input, $\mathcal{M} \in \mathbb{R}^{d}$ be the memory unit, and $\mathbf{y} \in \mathbb{R}^{d_{\text{in}}}$ be the output, the general form of the recurrent neural network is defined as:

$$
\begin{aligned}
& \mathcal{M}_{t}=f\left(\mathcal{M}_{t-1}, x_{t}\right) \quad \text{(Write Operation)} \\
& \mathbf{y}_{t}=g\left(\mathcal{M}_{t}, x_{t}\right) \quad \text{(Read Operation)}
\end{aligned}
$$

where $f(\cdot, \cdot)$ is the corresponding write function and $g(\cdot, \cdot)$ the read function. Note that here the subscript of $\mathcal{M}_{t}$ indicates the state of the memory at time $t$.

In this perspective, the recurrence formula of linear Transformers (see Equation 4) is equivalent to additively compressing and writing keys and values, $\left(K_{t}, V_{t}\right)$, into a matrix-valued memory unit $\mathcal{M}_{t}$. Therefore, when dealing with long-context data, the additive nature of this process results in memory overflow, significantly damaging the performance of the model. To address this, studies have focused on two promising directions: (1) Adding a forget mechanism: several studies have presented adaptive (data-dependent) forgetting gate mechanisms for linear models, which can erase the memory when needed. As examples of such models, we refer to GLA (S. Yang, B. Wang, Shen, et al. 2024), LRU (Orvieto et al. 2023), Griffin (De et al. 2024), xLSTM (Beck et al. 2024), and Mamba2 (Dao and Gu 2024), the latter of which is also connected to the discretized version of traditional state space models (Gu and Dao 2024). (2) Improving the write operation: To overcome the additive nature of the memory write operation in traditional recurrent models, Widrow and Hoff (1988) presented the Delta Rule, in which before adding a memory (i.e., a pair of key and value), the model first removes its past value. To enhance parallelizable training and scaling, S. Yang, B. Wang, Yu Zhang, et al. (2024) present a fast parallelizable algorithm. Finally, very recently, S. Yang, Kautz, and Hatamizadeh (2024) improved DeltaNets by adding a forget gate.
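To make the two directions concrete, the sketch below contrasts the additive write of linear attention with a delta-rule write that first removes the value currently stored for the incoming key. The row-vector convention and the helper names are ours.

```python
import torch

def additive_write(M, k, v):
    """Linear-attention style write: blindly accumulates k^T v, which can
    overflow the fixed-size memory over long contexts."""
    return M + torch.outer(k, v)

def delta_rule_write(M, k, v, lr=1.0):
    """Delta-rule write (Widrow and Hoff): retrieve the value currently stored
    for key k, remove it, then write the new value."""
    v_old = k @ M                              # what the memory currently returns for k
    return M + lr * torch.outer(k, v - v_old)  # M - lr k^T(kM) + lr k^T v
```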
Memory Modules. Memory has always been one of the core parts of neural network design (Graves, Wayne, and Danihelka 2014; JH Schmidhuber 1992; Jürgen Schmidhuber and Hochreiter 1997; J. Zhang et al. 2024). The idea of seeing linear layers as a key-value (associative) memory system goes back to fast weight programs, in which dynamic fast programs are incorporated into recurrent neural networks to serve as writable memory (JH Schmidhuber 1992). The two learning rules of Hebbian (Hebb 2005) and delta (Prados and Kak 1989) are the most popular learning rules for fast weight programs, which have been extensively explored in various studies (Irie, Schlag, et al. 2021; Munkhdalai, Sordoni, et al. 2019; Munkhdalai and H. Yu 2017; Schlag, Irie, and Jürgen Schmidhuber 2021; JH Schmidhuber 1992; S. Yang, Kautz, and Hatamizadeh 2024; S. Yang, B. Wang, Yu Zhang, et al. 2024). All these models, however, are based on momentary surprise, missing the token flow in the sequence (see Section 3.1), and most of them lack a forgetting gate, resulting in poor memory management.

We further discuss the connection of our architectures with recent models in Appendix C. Additional related work is discussed in Appendix A.

# 3 Learning to Memorize at Test Time

To overcome the lack of long-term memory and to enable the model to learn, forget, and retrieve information, in this section we present a neural long-term memory module, which is a meta model that learns to memorize at test time. In Section 3.1, we first discuss the motivation and the design of the neural memory. In Section 3.2, we discuss how our architecture design can benefit from fast and parallelizable training. Finally, in Section 3.3, we augment our architecture with a persistent memory module, in which we use learnable but data-independent parameters to learn meta information about the task.

### 3.1 Long-term Memory

To design a neural long-term memory module, we need a model that can encode the abstraction of the past history into its parameters. An example of this can be LLMs, which have been shown to memorize their training data (Leybzon and Kervadec 2024; Schwarzschild et al. 2024; Staab et al. 2024). Therefore, a simple idea is to train a neural network and expect it to memorize its training data. Memorization, however, has almost always been known as an undesirable phenomenon in neural networks, as it limits the model's generalization (Bayat et al. 2024), causes privacy concerns (Staab et al. 2024), and so results in poor performance at test time. Moreover, memorization of the training data might not be helpful at test time, where the data might be out-of-distribution. We argue that we need an online meta-model that learns how to memorize/forget the data at test time. In this setup, the model learns a function that is capable of memorization, but it does not overfit to the training data, resulting in better generalization at test time.

Learning Process and Surprise Metric. The key idea to train a long-term memory is to treat its training as an online learning problem, in which we aim to compress the past information $x_{1}, \ldots, x_{t-1}$ into the parameters of our long-term neural memory module $\mathcal{M}_{t}$. As discussed earlier, an event that violates expectations (i.e., is surprising) is more memorable for humans (Mandler 2014). Inspired by this, a simple definition of surprise for a model can be its gradient with respect to the input. The larger the gradient is, the more different the input data is from the past data.
Accordingly, using this surprise score, we can update the memory as:

$$
\mathcal{M}_{t}=\mathcal{M}_{t-1}-\theta_{t} \underbrace{\nabla \ell\left(\mathcal{M}_{t-1} ; x_{t}\right)}_{\text{Surprise}}
$$

This surprise metric, however, can result in missing important information that comes after a big surprising moment. That is, the gradient can become extremely small after several surprising steps, leading to getting stuck in a flat area (i.e., a local minimum) and missing information about some parts of the sequence. From the human memory perspective, an event might not consistently surprise us over a long period of time even though it is memorable. The reason is that the initial moment is surprising enough to hold our attention over a long time frame, leading to memorizing the entire time frame. To improve the above surprise metric (Equation 8), we break the surprise metric into (1) past surprise, which measures the surprise amount of the very recent past, and (2) momentary surprise, which measures the surprise of the incoming data:

$$
\begin{aligned}
& \mathcal{M}_{t}=\mathcal{M}_{t-1}+S_{t} \\
& S_{t}=\eta_{t} \underbrace{S_{t-1}}_{\text{Past Surprise}}-\theta_{t} \underbrace{\nabla \ell\left(\mathcal{M}_{t-1} ; x_{t}\right)}_{\text{Momentary Surprise}}
\end{aligned}
$$

Interestingly, this formulation is similar to gradient descent with momentum, where $S_{t}$ is the momentum term. Therefore, the momentum here acts as a memory of surprise across time (sequence length). In this formulation, the term $\eta_{t}$ is a data-dependent surprise decay (a function of $x_{t}$), controlling how surprise decays over time, and the term $\theta_{t}$ controls how much of the momentary surprise should be incorporated into the final surprise metric, again in a data-dependent manner. This data-dependency is particularly important in this design: while the surprise of previous tokens might be needed to affect the surprise of the next token, this is mostly valid if all tokens are relevant and in the same context. Accordingly, a data-dependent $\eta$ can control whether the memory needs to: (1) ignore the last surprise by setting $\eta_{t} \rightarrow 0$ (possibly due to a change of context), or (2) fully incorporate the last surprise by setting $\eta_{t} \rightarrow 1$ (possibly because the token is highly relevant to its recent past tokens).
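A minimal sketch of this update, assuming the memory and surprise states are plain tensors and that `loss_fn`, `eta_t`, and `theta_t` stand in for the data-dependent quantities defined above:

```python
import torch

def surprise_update(M, S, x_t, loss_fn, eta_t, theta_t):
    """One step of the two-part surprise rule:
        S_t = eta_t * S_{t-1} - theta_t * grad ell(M_{t-1}; x_t)
        M_t = M_{t-1} + S_t
    """
    M = M.detach().requires_grad_(True)
    grad = torch.autograd.grad(loss_fn(M, x_t), M)[0]   # momentary surprise
    S = eta_t * S - theta_t * grad                      # decayed past surprise + new
    return (M + S).detach(), S.detach()

# usage with a toy associative loss ||M x - x||^2
loss_fn = lambda M, x: ((M @ x - x) ** 2).sum()
M, S = torch.zeros(4, 4), torch.zeros(4, 4)
M, S = surprise_update(M, S, torch.randn(4), loss_fn, eta_t=0.9, theta_t=0.1)
```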
Objective. Our above surprise metric is based on a loss function $\ell(\cdot\,;\cdot)$, which is the objective that our memory learns to act upon at test time. That is, our memory module is a meta model that learns a function based on the loss function $\ell(\cdot\,;\cdot)$.

In this work, we focus on associative memory, in which we aim to store the past data as pairs of keys and values. Given $x_{t}$, similar to Transformers (Vaswani et al. 2017), we use two linear layers to project $x_{t}$ into a key and a value:

$$
\mathbf{k}_{t}=x_{t} W_{K}, \quad \mathbf{v}_{t}=x_{t} W_{V}
$$

where $W_{K}$ and $W_{V} \in \mathbb{R}^{d_{\text{in}} \times d_{\text{in}}}$. Next, we expect our memory module to learn the associations between keys and values. To this end, we define the loss as follows:

$$
\ell\left(\mathcal{M}_{t-1} ; x_{t}\right)=\left\|\mathcal{M}_{t-1}\left(\mathbf{k}_{t}\right)-\mathbf{v}_{t}\right\|_{2}^{2}
$$

By optimizing the above loss function in the inner loop of our meta model (memory), the model learns how to memorize the mapping between keys and values at test time. Note that, similar to meta-learning models (Nichol 2018; Zintgraf et al. 2019), training of the memory happens in the inner loop, and so the parameters $W_{K}$ and $W_{V}$ are hyperparameters in the above loss function. Accordingly, in the inner loop we optimize $\mathcal{M}$'s weights, while in the outer loop we optimize the other parameters of the entire architecture.

Forgetting Mechanism. When dealing with very large sequences (e.g., millions of tokens), it is crucial to manage which past information should be forgotten, even with a deep or a very large matrix-valued memory. To this end, we use an adaptive forgetting mechanism that allows the memory to forget information that is no longer needed, resulting in better management of the memory's limited capacity. That is, given the next token $x_{t}$, we modify the update rule as:

$$
\begin{aligned}
& \mathcal{M}_{t}=\left(1-\alpha_{t}\right) \mathcal{M}_{t-1}+S_{t} \\
& S_{t}=\eta_{t} S_{t-1}-\theta_{t} \nabla \ell\left(\mathcal{M}_{t-1} ; x_{t}\right)
\end{aligned}
$$

where $\alpha_{t} \in[0,1]$ is the gating mechanism that flexibly controls the memory, i.e., decides how much information should be forgotten. For example, it can update the memory without affecting the past abstraction by letting $\alpha_{t} \rightarrow 0$, and can clear the entire memory by letting $\alpha_{t} \rightarrow 1$. Later in this section, we show that this weight decay mechanism is closely related to the gating mechanism in modern RNNs (Dao and Gu 2024; Orvieto et al. 2023).
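Putting the pieces together, one inner-loop step with forgetting, momentum, and the associative loss can be sketched as follows; we use a linear memory for brevity (the paper's memory is an MLP), and retrieval is the forward pass without a weight update:

```python
import torch

def memory_step(M, S, k_t, v_t, alpha_t, eta_t, theta_t):
    """Full inner-loop step: associative loss ||M(k_t) - v_t||^2 with forgetting
    (alpha_t), momentum (eta_t), and surprise scale (theta_t). M: (d, d)."""
    M = M.detach().requires_grad_(True)
    loss = ((k_t @ M - v_t) ** 2).sum()            # ell(M_{t-1}; x_t)
    grad = torch.autograd.grad(loss, M)[0]
    S = eta_t * S - theta_t * grad                 # surprise with momentum
    M_new = (1 - alpha_t) * M + S                  # weight decay = forgetting gate
    return M_new.detach(), S.detach()

@torch.no_grad()
def retrieve(M, q_t):
    """Forward pass without weight update, i.e. M*(q)."""
    return q_t @ M
```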
Memory Architecture. In this paper, we focus on simple MLPs with $L_{M} \geq 1$ layers as the architecture of our long-term memory. The main reason behind this choice is that we want to focus on better motivating the design of the long-term memory and the ways it can be incorporated into an architecture. However, our formulation and architectural design open a new research direction for designing neural architectures that are more effective and efficient at memorizing data. Recently, there has been a promising line of work on designing such architectures (Berges et al. 2024; Cetin et al. 2024; J. Zhang et al. 2024), and incorporating them into our framework (i.e., replacing the simple MLPs with such architectures) can be an interesting direction for future work.

When using vector-valued or matrix-valued memory (De et al. 2024; Orvieto et al. 2023; S. Yang, B. Wang, Shen, et al. 2024), the memory module compresses the past data and fits it along a line. That is, from the meta-learning or online learning perspective (Yu Sun et al. 2024), using a matrix-valued memory $\mathcal{M}=W \in \mathbb{R}^{d_{\text{in}} \times d_{\text{in}}}$ is equivalent to optimizing $\ell\left(W_{t-1} ; x_{t}\right)=\left\|W_{t-1} \mathbf{k}_{t}-\mathbf{v}_{t}\right\|_{2}^{2}$, which is an online linear regression objective, and so the optimal solution assumes the underlying dependency of the historical data is linear. We, on the other hand, argue for deep memory modules (i.e., $L_{M} \geq 2$). Aligning with the theoretical results that MLPs with at least two layers are strictly more expressive than linear models (Hornik, Stinchcombe, and White 1989), in Section 5.5 we show that deep memory modules are also more effective in practice.

Retrieving a Memory. In the above, we discussed how one can design and train a long-term memory module that learns to memorize at test time. A key remaining question is: how can one retrieve information from the memory? We simply use the forward pass without weight update (i.e., inference) to retrieve the memory corresponding to a query. Formally, given an input $x_{t}$, we use a linear layer $W_{Q}$ to project the input, i.e., $\mathbf{q}_{t}=x_{t} W_{Q}$, and retrieve the corresponding (useful) information $y_{t}$ from the memory by:

$$
y_{t}=\mathcal{M}^{*}\left(\mathbf{q}_{t}\right)
$$

![img-0.jpeg](img-0.jpeg)

Figure 1: The illustration of how the training of the neural memory can be done in parallel using matmuls.

### 3.2 How to Parallelize the Long-term Memory Training

As discussed above, the design of our long-term memory module is equivalent to training a meta model by optimizing the associative memory loss function $\ell\left(\mathcal{M}_{t-1} ; x_{t}\right)=\left\|\mathcal{M}_{t-1}\left(\mathbf{k}_{t}\right)-\mathbf{v}_{t}\right\|_{2}^{2}$ using gradient descent with momentum and weight decay. Therefore, in theory, training the long-term memory module requires $O(N)$ FLOPs, where $N$ is the sequence length. However, in practice, we need to parallelize the training process and, to fully take advantage of hardware accelerators (e.g., TPUs, GPUs), we need to tensorize the process and use more matmuls.

Next, we show that calculating the weights in the inner loop with mini-batch gradient descent, a data-dependent learning rate, and weight decay can be reformulated so that it uses only matmuls and sums. We build upon the work of Yu Sun et al. (2024), which shows that the forward pass of a model optimized with mini-batch gradient descent (with a constant learning rate) can be calculated using matmuls. We can split the sequence into chunks of size $b \geq 1$ and write the mini-batch gradient descent as:

$$
\mathcal{M}_{t}=\left(1-\alpha_{t}\right) \mathcal{M}_{t-1}-\theta_{t} \nabla \ell\left(\mathcal{M}_{t-1} ; x_{t}\right)=\beta_{t} \mathcal{M}_{0}-\sum_{i=1}^{t} \theta_{i} \frac{\beta_{t}}{\beta_{i}} \nabla \ell\left(\mathcal{M}_{t^{\prime}} ; x_{i}\right)
$$

where $t^{\prime}=t-\bmod (t, b)$ and $\beta_{i}=\prod_{j=1}^{i}\left(1-\alpha_{j}\right)$. For the sake of simplicity, we focus on the first chunk, i.e., $t=b$ and so $t^{\prime}=0$. Also, we explain the process for the case that $\mathcal{M}_{t}=W_{t}$ is linear; the process for MLPs with $L_{M} \geq 2$ is similar. Using our loss function, we have:

$$
\nabla \ell\left(W_{0} ; x_{t}\right)=\left(W_{0} x_{t}-x_{t}\right) x_{t}^{\top} \Rightarrow \sum_{i=1}^{b} \theta_{i} \frac{\beta_{b}}{\beta_{i}} \nabla \ell\left(W_{0} ; x_{i}\right)=\Theta_{b} \mathbf{B}_{b}\left(W_{0} X-X\right) X^{\top}
$$

where $\Theta_{b}=\operatorname{diag}\left(\theta_{1}, \theta_{2}, \ldots, \theta_{b}\right)$ and $\mathbf{B}_{b}$ is defined analogously on the $\beta_{b} / \beta_{i}$ values. Note that we do not need to store all $\Theta_{k b}$ and $\mathbf{B}_{k b}$ for $k=1, \ldots, N / b$; instead, we store these matrices for each chunk, resulting in less memory use.
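The following sketch checks this tensorized form numerically for one chunk, assuming the simplified gradient above (keys and values taken as $x$); the variable layout (tokens as columns of $X$) is our convention:

```python
import torch

b, d = 4, 8                                       # chunk size and width
W0 = 0.1 * torch.randn(d, d)
X = torch.randn(d, b)                             # columns are the chunk's tokens
theta = torch.rand(b)                             # data-dependent learning rates
beta = torch.cumprod(1 - 0.1 * torch.rand(b), 0)  # beta_i = prod_{j<=i}(1 - alpha_j)

# Tensorized chunk gradient: sum_i theta_i (beta_b/beta_i) (W0 x_i - x_i) x_i^T
w = theta * (beta[-1] / beta)                     # Theta_b and B_b fused into one vector
grad_sum = ((W0 @ X - X) * w) @ X.T               # matmuls plus an elementwise scale

# sanity check against the per-token loop this replaces
ref = sum(w[i] * torch.outer(W0 @ X[:, i] - X[:, i], X[:, i]) for i in range(b))
assert torch.allclose(grad_sum, ref, atol=1e-5)
```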
Next, we extend this representation so we can also incorporate the momentum term. In chunk-wise gradient descent with momentum, if we look at the momentum term, we have:

$$
S_{t}=\eta_{t} S_{t-1}-\theta_{t} u_{t}
$$

where $u_{t}=\nabla \ell\left(\mathcal{M}_{t^{\prime}} ; x_{t}\right)$. Note that we can compute all $u_{t}$ at the same time, and so Equation 18 is a linear recurrence with $u_{t}$ as the input, $S_{t}$ as the hidden state, and $\eta_{t}$ as an input-dependent transition value. Accordingly, we can use a parallel associative scan (J. T. Smith, Warrington, and Linderman 2023) to calculate the $S_{t}$'s within the chunk.

Parameters as the Function of Chunks. Instead of making parameters like $\alpha_{t}, \theta_{t}$, and $\eta_{t}$ input-dependent (i.e., functions of the token $x_{t}$), we can make them functions of their chunk. Despite losing some expressive power, this formulation can make training even faster. In this case, we use the same value for each of $\alpha, \theta$, and $\eta$ within each chunk. Accordingly, in Equation 17, we can store $\Theta$ using a single scalar. Similarly, we can make Equation 18 faster: when $\eta$ and $\theta$ are learnable but time-invariant inside each chunk, this equation becomes a linear time-invariant (LTI) system, which can be computed by a global convolution (Gu, Goel, and Re 2022). In our experiments, we make these parameters functions of the tokens. However, such simplifications (i.e., as functions of chunks) can be of interest for future work on training larger models more efficiently.
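A sequential reference for this recurrence is sketched below; since all $u_t$ are precomputed, the loop is the only sequential part, and in practice it can be replaced by a parallel associative scan:

```python
import torch

def momentum_scan(u, eta, theta):
    """Sequential reference for the linear recurrence S_t = eta_t S_{t-1} - theta_t u_t,
    with the per-token gradients u_t assumed precomputed for the chunk."""
    S, out = torch.zeros_like(u[0]), []
    for t in range(len(u)):
        S = eta[t] * S - theta[t] * u[t]
        out.append(S)
    return torch.stack(out)

S_all = momentum_scan(torch.randn(16, 8, 8), torch.rand(16), torch.rand(16))
```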
![img-1.jpeg](img-1.jpeg)

Figure 2: Memory as a Context (MAC) Architecture. This architecture includes three branches: (1) core, (2) contextual (long-term) memory, and (3) persistent memory. The core branch concatenates the corresponding long-term and persistent memories with the input sequence. Next, attention is performed over the sequence and decides which parts of the information should be stored in the long-term memory. At test time, the parameters corresponding to the contextual memory are still learning, the parameters corresponding to the core branch are responsible for in-context learning, and the parameters of the persistent memory are responsible for storing the knowledge about tasks and so are fixed.

### 3.3 Persistent Memory

Our long-term memory can also be seen as a contextual memory, meaning that the output fully depends on the context. Therefore, in addition to our long-term memory, we also use a set of learnable but input-independent parameters to act as task-related memory. This type of memory has been referred to as persistent or meta-memory in the literature (X. Dong et al. 2024; Sukhbaatar, Grave, et al. 2019). Given $N_{p} \geq 1$, we use learnable parameters $P=\left[p_{1}\; p_{2}\; \ldots\; p_{N_{p}}\right]$ and append them to the start of our sequence: i.e., given a context window of size $N$, we modify the input as:

$$
x_{\text{new}}=\left[p_{1}\; p_{2}\; \ldots\; p_{N_{p}}\right] \,\|\, x
$$

where $\|$ is concatenation. Next, we discuss the motivation for persistent memory from three perspectives:

Memory Perspective. As discussed earlier, our neural long-term memory is a contextual memory, in which all parameters are input-dependent. An effective memory system, however, also needs input-independent parameters to store the abstraction of the task knowledge. That is, mastering a task requires memorizing the knowledge of how the task can be done, and these parameters are responsible for storing such knowledge.

Feedforward Network Perspective. In Transformer architectures, there are fully connected layers after the attention module, which have been shown to behave similarly to attention weights but with data-independent parameters. That is, Sukhbaatar, Grave, et al. (2019) showed that replacing the ReLU in fully connected layers with Softmax can result in attention-like weights, in which the weights are data-independent:

$$
FFN(x)=W_{V} \operatorname{Softmax}\left(W_{K} x\right)
$$

In fact, $W_{K}$ and $W_{V}$ act similarly to the $K$ and $V$ matrices in the attention module when they are input-independent. The persistent memory weights are expected to have the same functionality, meaning that using them in the first part of the sequence leads to input-independent attention weights (Sukhbaatar, Grave, et al. 2019).

Technical Perspective. Attention with a causal mask has an implicit bias toward initial tokens in the sequence, and so attention weights are almost always highly active for initial tokens, harming performance. From the technical perspective, these learnable parameters at the start of the sequence can mitigate this effect by redistributing the attention weights more effectively (Han et al. 2024; Xiao et al. 2024).
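As a sketch, prepending persistent tokens is a few lines in PyTorch; the module name and initialization scale are our choices:

```python
import torch
import torch.nn as nn

class PersistentMemory(nn.Module):
    """N_p learnable, input-independent tokens prepended to the sequence,
    i.e. x_new = [p_1 ... p_{N_p}] || x."""

    def __init__(self, n_persist: int, d_model: int):
        super().__init__()
        self.P = nn.Parameter(0.02 * torch.randn(n_persist, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (batch, N, d_model)
        P = self.P.unsqueeze(0).expand(x.size(0), -1, -1)
        return torch.cat([P, x], dim=1)                  # (batch, N_p + N, d_model)

out = PersistentMemory(4, 16)(torch.randn(2, 10, 16))    # (2, 14, 16)
```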
![img-2.jpeg](img-2.jpeg)
(a) Memory as a Context (MAC). We segment the sequence and use full causal attention in each window. The first $N_{p}$ tokens are persistent memory and the next $N_{l}$ are long-term memory tokens.
![img-3.jpeg](img-3.jpeg)
(b) Memory as Gating (MAG). We use sliding window attention (SWA) as a short-term memory and our neural memory module as a long-term memory, combined by a gating.

Figure 3: Attention masks for different variants of Titans.

# 4 How to Incorporate Memory?

An important question that remains unanswered is: how can one effectively and efficiently incorporate the designed neural memory into a deep learning architecture? As discussed earlier, from a memory perspective, the pair of K and V matrices in Transformers can be interpreted as an associative memory block. Due to their accurate modeling of dependencies but limited context window, we interpret them as short-term memory modules, attending to the current context window. On the other hand, our neural memory, with the ability to continuously learn from data and store it in its weights, can play the role of a long-term memory. In this section, we aim to answer the above question by proposing three different variants of Titans. Later, in our experiments, we show that each of these variants has its own advantages/disadvantages and can also exhibit a trade-off between efficiency and effectiveness in very long contexts.

### 4.1 Memory as a Context

In the first architecture design (see Figure 2), we treat the memory as context for the current information. That is, given a long sequence $x \in \mathbb{R}^{N \times d_{\text{in}}}$, we first chunk the sequence into fixed-size segments $S^{(i)}$, $i=1, \ldots, N / C$, each of size $C$. Given the incoming segment $S^{(t)}$, we consider it as the current context and its past segments as the historical information. Therefore, letting $\mathcal{M}_{t-1}$ be the state of the long-term memory before segment $S^{(t)}$, we use the input context as a query to the memory $\mathcal{M}_{t-1}$ to retrieve the corresponding information from the long-term memory. That is, we retrieve the past information that corresponds to $S^{(t)}$ as:

$$
h_{t}=\mathcal{M}_{t-1}^{*}\left(\mathbf{q}_{t}\right)
$$

where $\mathbf{q}_{t}=S^{(t)} W_{Q}$. Next, we use this historical information along with our persistent memory parameters as the input sequence to the attention module:

$$
\begin{aligned}
& \tilde{S}^{(t)}=\left[p_{1}\; p_{2}\; \ldots\; p_{N_{p}}\right] \,\|\, h_{t} \,\|\, S^{(t)} \\
& y_{t}=\operatorname{Attn}\left(\tilde{S}^{(t)}\right)
\end{aligned}
$$

The structure of the attention map over the entire sequence is shown in Figure 3a. We then use $y_{t}$ to update the long-term memory module for the next segment and to compute the final output:

$$
\begin{aligned}
& \mathcal{M}_{t}=\mathcal{M}_{t-1}\left(y_{t}\right) \\
& o_{t}=y_{t} \otimes \mathcal{M}_{t}^{*}\left(y_{t}\right)
\end{aligned}
$$

Note that, in the above, we update the weights of $\mathcal{M}_{t-1}$ through the forward pass.

This architecture has several key advantages: (1) Having both the historical and the current context, attention can decide whether, given the current data, the long-term memory information is needed. (2) The attention module helps the long-term memory to store only useful information from the current context. That is, not all tokens in each segment are useful, and memorizing all of them can result in memory overflow. Therefore, attention helps the memory understand which information is useful, better managing the memory capacity. (3) At test time: (i) the persistent memory parameters are fixed, as they encode the knowledge about the task, which should not be changed; (ii) the attention module weights are in-context learners; and (iii) the long-term memory module is still learning (memorizing) information at test time. That is, we update the weights of the neural memory even at test time, as the weights encode the abstraction of the long past.
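A toy sketch of one MAC step over a segment follows. The `LinearMemory` stand-in, the identity attention stub, and the choice of elementwise product for $\otimes$ are all illustrative assumptions, not the paper's implementation:

```python
import torch

class LinearMemory:
    """Toy stand-in for the neural memory: a (d, d) linear associative map.
    retrieve() is the no-update forward pass M*(.); update() takes one gradient
    step on ||M(k) - v||^2 with k = v = y, purely for illustration."""

    def __init__(self, d, lr=0.1):
        self.W, self.lr = torch.zeros(d, d), lr

    def retrieve(self, q):
        return q @ self.W

    def update(self, y):
        err = y @ self.W - y                        # M(k) - v with k = v = y
        self.W = self.W - self.lr * (y.T @ err) / len(y)

def mac_segment_step(memory, attend, persist, segment, W_Q):
    q = segment @ W_Q                               # q_t = S^(t) W_Q
    h = memory.retrieve(q)                          # h_t = M*_{t-1}(q_t)
    s = torch.cat([persist, h, segment], dim=0)     # [p_1 .. p_{N_p}] || h_t || S^(t)
    y = attend(s)                                   # y_t = Attn(S~^(t))
    memory.update(y)                                # M_t = M_{t-1}(y_t): inner-loop write
    return y * memory.retrieve(y)                   # o_t = y_t (x) M*_t(y_t)

# usage with an identity "attention" stub
d, n_p, C = 16, 4, 8
memory, P, W_Q = LinearMemory(d), torch.randn(n_p, d), 0.1 * torch.randn(d, d)
o = mac_segment_step(memory, lambda s: s, P, torch.randn(C, d), W_Q)
```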
![img-4.jpeg](img-4.jpeg)

Figure 4: Memory as a Gate (MAG) Architecture. This architecture, similarly, has the three branches of (1) core, (2) contextual memory, and (3) persistent memory. It, however, incorporates only the persistent memory into the context and combines the memory with the core branch using a gating mechanism. At test time, the behavior is the same as in Figure 2.

### 4.2 Gated Memory

In the next variant (see Figure 4), in one branch we directly use the input data to update the long-term memory, and in the second branch we use sliding window attention (SWA):

$$
\begin{aligned}
& \tilde{x}=\left[p_{1}\; p_{2}\; \ldots\; p_{N_{p}}\right] \,\|\, x \\
& y=\text{SW-Attn}^{*}(\tilde{x}) \\
& o=y \otimes \mathcal{M}(\tilde{x})
\end{aligned}
$$

where SW-Attn* is sliding window attention with prefix (see Figure 3b). Note that, contrary to the previous design, we do not segment the input data. Also, we abuse the notation and use $\mathcal{M}(\tilde{x})$ to refer to the final output of the memory after all recursion over the tokens of the sequence. In the above equation, $\otimes$ can be any non-linear gating. In our experiments, we normalize the outputs $y$ and $\mathcal{M}(\tilde{x})$ using learnable vector-valued weights, followed by a non-linearity $\sigma(\cdot)$.

The overall attention mask of this design is shown in Figure 3b. In this design, the sliding window attention acts as a precise short-term memory, while the neural memory module acts as a fading memory for the model. This architecture design can also be seen as a multi-head architecture where the heads have different structures (X. Dong et al. 2024).
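A sketch of the gated combination, with placeholder callables for the attention, memory, and normalization branches, and a sigmoid as one possible non-linearity $\sigma(\cdot)$:

```python
import torch

def mag_forward(x, persist, sw_attn, memory, norm_y, norm_m):
    """Persistent tokens are prepended; one branch is sliding-window attention,
    the other the neural memory; outputs are normalized and combined by a gate."""
    x_tilde = torch.cat([persist, x], dim=0)        # x~ = [p_1 .. p_{N_p}] || x
    y = sw_attn(x_tilde)                            # short-term, precise memory
    m = memory(x_tilde)                             # long-term, fading memory
    return norm_y(y) * torch.sigmoid(norm_m(m))     # o = y (x) M(x~)

# usage with trivial stubs for all four callables
o = mag_forward(torch.randn(8, 16), torch.randn(4, 16),
                lambda s: s, lambda s: s.cumsum(0), lambda s: s, lambda s: s)
```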
### 4.3 Memory as a Layer

The last variant uses the neural Memory As a Layer (MAL) of a deep neural network (see Figure 5). This architecture design is more common in the literature, where hybrid models stack recurrent models with full or sliding window attention. Given input $x$, we have:

$$
\begin{aligned}
& \tilde{x}=\left[p_{1}\; p_{2}\; \ldots\; p_{N_{p}}\right] \,\|\, x \\
& y=\mathcal{M}(\tilde{x}) \\
& o=\text{SW-Attn}(y)
\end{aligned}
$$

![img-5.jpeg](img-5.jpeg)

Figure 5: Memory as a Layer (MAL) Architecture. In this architecture, the memory layer is responsible for compressing the past and current context before the attention module.

where SW-Attn is sliding window attention. The main drawback of this design is that the power of the model is limited by each of its layers, and so it cannot take advantage of the complementary data processing of the attention and the neural memory module. In our experiments, to evaluate memory in this design, we use an architecture similar to H3 (D. Y. Fu et al. 2023), where we replace the sequence model with our neural memory module (LMM).

Memory Without Attention. Although in the above we discussed MAL as the combination of LMMs and attention in a sequential manner, one simple variant of MAL is to treat the LMM as a sequence model without any attention. From the memory perspective, as discussed in Section 1, we expect each part of the memory system to work independently, even if other components are disturbed. Therefore, a long-term memory module should still be a powerful model even without short-term memory (i.e., attention). We refer to this variant as LMM or Titans (LMM) in our experiments. We provide additional discussion on the connection of Titans and other modern recurrent models in Appendix C.

### 4.4 Architectural Details

For the sake of simplicity and presentation, we avoid discussing implementation details like residual connections, gating with a linear layer, and normalization. In all blocks, we use residual connections. In our implementation, we use the SiLU(.) activation (Elfwing, Uchibe, and Doya 2018) as the non-linear activation for computing the queries, keys, and values, and we normalize queries and keys using the $\ell_{2}$-norm.

Convolution. Following recent modern linear recurrent models (Gu and Dao 2024; S. Yang, Kautz, and Hatamizadeh 2024), we incorporate a 1D depthwise-separable convolution layer after each of the query, key, and value projections. While they do not significantly affect the performance, these 1D convolutions have shown performance improvements and are also computationally efficient.

Gating. We also follow recent architectures that use normalization and gating with a linear layer before the final output projection (Mehta et al. 2023).
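These details can be sketched as a projection block; the kernel size, causal padding, and module layout are our assumptions, with only the ingredients (depthwise 1D convolution, SiLU, $\ell_2$-normalized queries and keys) taken from the text:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QKVProjection(nn.Module):
    """Linear q/k/v projections followed by a causal 1D depthwise convolution and
    SiLU, with q and k scaled to unit l2-norm."""

    def __init__(self, d_model: int, kernel_size: int = 4):
        super().__init__()
        self.proj = nn.ModuleDict({n: nn.Linear(d_model, d_model, bias=False) for n in "qkv"})
        self.conv = nn.ModuleDict({
            n: nn.Conv1d(d_model, d_model, kernel_size,
                         padding=kernel_size - 1, groups=d_model)   # depthwise
            for n in "qkv"
        })

    def _branch(self, name, x):                     # x: (batch, N, d_model)
        h = self.proj[name](x).transpose(1, 2)      # (batch, d_model, N) for conv
        h = self.conv[name](h)[..., : x.size(1)]    # trim the causal padding
        return F.silu(h.transpose(1, 2))

    def forward(self, x):
        q, k, v = (self._branch(n, x) for n in "qkv")
        return F.normalize(q, dim=-1), F.normalize(k, dim=-1), v

q, k, v = QKVProjection(16)(torch.randn(2, 10, 16))
```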
Theorem 4.1. Contrary to Transformers, diagonal linear recurrent models, and DeltaNet, all of which are limited to $\mathrm{TC}^{0}$ (Merrill, Petty, and Sabharwal 2024), Titans are capable of solving problems beyond $\mathrm{TC}^{0}$, meaning that Titans are theoretically more expressive than Transformers and most modern linear recurrent models in state tracking tasks.

# 5 Experiments

Next, we evaluate the performance of Titans and its variants on language modeling, commonsense reasoning, needle in haystack, DNA modeling, and time series forecasting tasks[^1]. In more detail, in this section we answer the following empirical questions: (1) How do Titans perform compared to baselines in downstream tasks? (see §5.2, §5.6, and §5.7); (2) What is the actual context length of Titans? (see §5.3 and §5.4); (3) How do Titans scale with respect to context length? (see §5.8); (4) How does the depth of memory affect both performance and efficiency? (see §5.5); and (5) What is the contribution of each of Titans' components to its performance? (see §5.9).

[^1]: In the first version of the work, we aim to provide insights/evidence about why the learning paradigms of Titans are effective. We are working on finalizing the results of larger models and will report them in the next version.

### 5.1 Experimental Setup

Models. In our experiments, we focus on the three variants of Titans, which we refer to as Titans with (1) Memory as a Context (MAC), (2) Memory as a Gate (MAG), and (3) Memory as a Layer (MAL), as well as (4) the neural memory module alone. The reason for using our long-term memory as a separate module is based on our definition of learning: as discussed in Section 1, we define learning as a process for acquiring effective and useful memory. Accordingly, we expect our long-term memory to learn effectively from data, even without attention. For each of these models, we consider four scales with (i) 170M, (ii) 340M, (iii) 400M, and (iv) 760M parameters. While the first three are trained on 15B tokens sampled from the FineWeb-Edu dataset (Penedo et al. 2024), the last one is trained on 30B tokens from the same dataset.

Baselines. We compare our models with state-of-the-art linear recurrent models, Transformers, and hybrid models (recurrent + attention). More specifically, in language tasks, we compare with Transformer++ (Touvron et al. 2023), RetNet (Yutao Sun et al. 2023), Gated Linear Attention (GLA) (S. Yang, B. Wang, Shen, et al. 2024), Mamba (Gu and Dao 2024), Mamba2 (Dao and Gu 2024), DeltaNet (S. Yang, B. Wang, Yu Zhang, et al. 2024), TTT (Yu Sun et al. 2024), and Gated DeltaNet (S. Yang, Kautz, and Hatamizadeh 2024). In needle in haystack tasks, we also compare with GPT-4 (Achiam et al. 2023), Llama3 with RAG (Touvron et al. 2023), RecurrentGemma2-9B (Botev et al. 2024), and Mistral (Jiang et al. 2023) models, all of which are provided in the benchmark (Yuri Kuratov et al. 2024). In time series tasks, we compare with Mamba-based (Behrouz, Santacatterina, and Zabih 2024), Transformer-based (Y. Liu et al. 2023; Nie et al. 2022; Yunhao Zhang and Yan 2023), and linear models (Das et al. 2023; Z. Li et al. 2023; H. Wu et al. 2023; Zeng et al. 2023).

Training. For training, we follow the training procedure of S. Yang, Kautz, and Hatamizadeh (2024): we use the LLaMA 2 tokenizer with a vocabulary size of 32K and a training length of 4K tokens. We employ the AdamW optimizer with a learning rate of $4e{-}4$, a cosine annealing schedule, a batch size of 0.5M tokens, and a weight decay of 0.1.
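For reference, a hedged sketch of this recipe in PyTorch; the model stub and step count are placeholders, and warmup or other scheduler details are not specified in the text:

```python
import torch

model = torch.nn.Linear(8, 8)      # stand-in for a Titans model
total_steps = 1_000                # assumed; derived from tokens / batch size in practice
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4, weight_decay=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
```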
### 5.2 Language Modeling

We first focus on perplexity in language modeling and on commonsense reasoning tasks. The results for Titans' variants and baselines at three different sizes (340M, 400M, and 760M) are reported in Table 1. Among non-hybrid models, including Transformer++, our neural memory module achieves the best performance on both perplexity and accuracy measures. Comparing our neural memory module and TTT, which is also a gradient-based recurrent model, shows the importance of our weight decay as well as the momentum. As discussed earlier, the weight decay can be interpreted as a gating mechanism to forget past data when needed. Also, the momentum can help us better manage the memory by providing additional memory for the surprise metric. While some baselines also take advantage of a gating mechanism, e.g., Mamba, Mamba2, and Gated DeltaNet, the superior performance of our neural memory module shows the importance of both our surprise mechanism and having a deep and non-linear memory. We further discuss the latter in Section 5.5.

Comparing the hybrid models, we find that all three variants of Titans (MAC, MAG, and MAL) outperform both Samba (Mamba + attention) and Gated DeltaNet-H2 (Gated DeltaNet + attention). We attribute the superior performance of Titans (MAL) to the power of the neural memory module, as the architecture design and the attention used are otherwise the same. Comparing Titans (MAG) and (MAC), we find that while their performance is close, MAC performs better when dealing with longer dependencies in the data. Interestingly, both MAG and MAC outperform the MAL variant, which, since the same modules are used, we attribute to the architecture design of these models. This finding is particularly important as the current hybrid models in the literature (except Hymba (X. Dong et al. 2024)) use a MAL-style combination of recurrent models and attention.

### 5.3 Needle in a Haystack

Scaling a model to a longer context window is not always equivalent to being effective for very long sequences (Hsieh et al. 2024). The needle-in-a-haystack (NIAH) task is designed to measure the actual effective context length of models. In this task, we evaluate the model on retrieving a piece of information (i.e., the "needle") from long distractor texts (i.e., the "haystack").
} \\\\\\\\ & \\\\text { ppl } \\\\downarrow \\\\end{aligned}$ | $\\\\begin{aligned} & \\\\text { LMB. } \\\\\\\\ & \\\\text { acc } \\\\uparrow \\\\end{aligned}$ | PIQA acc $\\\\uparrow$ | $\\\\begin{aligned} & \\\\text { Hella. } \\\\\\\\ & \\\\text { acc_n } \\\\uparrow \\\\end{aligned}$ | Wino. acc $\\\\uparrow$ | ARC-e acc $\\\\uparrow$ | ARC-c acc_n $\\\\uparrow$ | $\\\\begin{aligned} & \\\\text { SIQA } \\\\\\\\ & \\\\text { acc } \\\\uparrow \\\\end{aligned}$ | $\\\\begin{aligned} & \\\\text { BoolQ } \\\\\\\\ & \\\\text { acc } \\\\uparrow \\\\end{aligned}$ | Avg. $\\\\uparrow$ |\\n| :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: |\\n| 340M params / 15B tokens |  |  |  |  |  |  |  |  |  |  |  |\\n| Transformer++ | 31.52 | 41.08 | 30.76 | 62.98 | 34.76 | 50.53 | 45.21 | 24.05 | 36.81 | 58.24 | 42.92 |\\n| RetNet | 32.50 | 49.73 | 28.24 | 62.61 | 34.15 | 50.91 | 44.27 | 23.62 | 36.79 | 59.72 | 42.54 |\\n| GLA | 28.51 | 43.02 | 28.73 | 64.05 | 35.96 | 50.00 | 54.19 | 24.29 | 37.13 | 58.39 | 44.09 |\\n| Mamba | 30.83 | 40.21 | 29.94 | 63.79 | 35.88 | 49.82 | 49.24 | 24.56 | 35.41 | 60.07 | 43.59 |\\n| DeltaNet | 28.65 | 47.30 | 28.43 | 63.52 | 35.95 | 49.63 | 52.68 | 25.37 | 37.96 | 58.79 | 44.04 |\\n| TTT | 27.44 | 34.19 | 30.06 | 63.97 | 35.71 | 50.08 | 53.01 | 26.11 | 37.32 | 59.83 | 44.51 |\\n| Gated DeltaNet | 27.01 | 30.94 | 34.11 | 63.08 | 38.12 | 51.60 | 55.28 | 26.77 | 34.89 | 59.54 | 45.42 |\\n| Titans (LMM) | 26.18 | 29.97 | 34.98 | 64.73 | 39.61 | 51.85 | 55.60 | 28.14 | 34.52 | 59.99 | 46.17 |\\n| Titans (MAC)* | 25.43 | 28.13 | 36.00 | 65.32 | 40.35 | 51.21 | 58.17 | 29.00 | 38.63 | 60.18 | 47.36 |\\n| Titans (MAG)* | 25.07 | 28.72 | 36.71 | 64.88 | 40.56 | 52.49 | 57.72 | 28.16 | 39.75 | 60.01 | 47.54 |\\n| Titans (MAL)* | 24.69 | 28.80 | 35.74 | 64.97 | 39.44 | 51.97 | 56.58 | 28.21 | 38.14 | 57.32 | 46.55 |\\n| 400M params / 15B tokens |  |  |  |  |  |  |  |  |  |  |  |\\n| Transformer++ | 30.63 | 37.37 | 29.64 | 64.27 | 37.72 | 51.53 | 54.95 | 27.36 | 38.07 | 61.59 | 45.64 |\\n| RetNet | 29.92 | 46.83 | 29.16 | 65.23 | 36.97 | 51.85 | 56.01 | 27.55 | 37.30 | 59.66 | 45.47 |\\n| HGRN2 | 32.33 | 47.14 | 26.12 | 64.52 | 35.45 | 52.24 | 55.97 | 25.51 | 37.35 | 59.02 | 44.52 |\\n| GLA | 27.96 | 36.66 | 27.86 | 65.94 | 37.41 | 49.56 | 56.01 | 26.36 | 38.94 | 59.84 | 45.24 |\\n| Mamba | 29.22 | 39.88 | 29.82 | 65.72 | 37.93 | 50.11 | 58.37 | 26.70 | 37.76 | 61.13 | 45.94 |\\n| Mamba2 | 26.34 | 33.19 | 32.03 | 65.77 | 39.73 | 52.48 | 59.00 | 27.64 | 37.92 | 60.72 | 46.91 |\\n| DeltaNet | 27.69 | 44.04 | 29.96 | 64.52 | 37.03 | 50.82 | 56.77 | 27.13 | 38.22 | 60.09 | 45.57 |\\n| TTT | 26.11 | 31.52 | 33.25 | 65.70 | 39.11 | 51.68 | 58.04 | 28.99 | 38.26 | 59.87 | 46.86 |\\n| Gated DeltaNet | 25.47 | 29.24 | 34.40 | 65.94 | 40.46 | 51.46 | 59.80 | 28.58 | 37.43 | 60.03 | 47.26 |\\n| Samba* | 25.32 | 29.47 | 36.86 | 66.09 | 39.24 | 51.45 | 60.12 | 27.20 | 38.68 | 58.22 | 47.23 |\\n| Gated DeltaNet-H2* | 24.19 | 28.09 | 36.77 | 66.43 | 40.79 | 52.17 | 59.55 | 29.09 | 39.04 | 58.56 | 47.69 |\\n| Titans (LMM) | 25.03 | 28.99 | 35.21 | 65.85 | 40.91 | 52.19 | 59.97 | 29.20 | 38.74 | 60.85 | 47.83 |\\n| Titans (MAC)* | 25.61 | 27.73 | 36.92 | 66.39 | 41.18 | 52.80 | 60.24 | 29.69 | 40.07 | 61.93 | 48.65 |\\n| Titans (MAG)* | 23.59 | 27.81 | 37.24 | 66.80 | 40.92 | 53.21 | 60.01 | 29.45 | 39.91 | 61.28 | 48.60 |\\n| Titans (MAL)* | 23.93 | 27.89 | 36.84 | 66.29 | 40.74 | 52.26 | 59.85 | 29.71 | 38.92 | 58.40 | 47.87 |\\n| 760M params / 30B tokens |  
| Transformer++ | 25.21 | 27.64 | 35.78 | 66.92 | 42.19 | 51.95 | 60.38 | 32.46 | 39.51 | 60.37 | 48.69 |
| RetNet | 26.08 | 24.45 | 34.51 | 67.19 | 41.63 | 52.09 | 63.17 | 32.78 | 38.36 | 57.92 | 48.46 |
| Mamba | 28.12 | 23.96 | 32.80 | 66.04 | 39.15 | 52.38 | 61.49 | 30.34 | 37.96 | 57.62 | 47.22 |
| Mamba2 | 22.94 | 28.37 | 33.54 | 67.90 | 42.71 | 49.77 | 63.48 | 31.09 | 40.06 | 58.15 | 48.34 |
| DeltaNet | 24.37 | 24.60 | 37.06 | 66.93 | 41.98 | 50.65 | 64.87 | 31.39 | 39.88 | 59.02 | 48.97 |
| TTT | 24.17 | 23.51 | 34.74 | 67.25 | 43.92 | 50.99 | 64.53 | 33.81 | 40.16 | 59.58 | 47.32 |
| Gated DeltaNet | 21.18 | 22.09 | 35.54 | 68.01 | 44.95 | 50.73 | 66.87 | 33.09 | 39.21 | 59.14 | 49.69 |
| Samba* | 20.63 | 22.71 | 39.72 | 69.19 | 47.35 | 52.01 | 66.92 | 33.20 | 38.98 | 61.24 | 51.08 |
| Gated DeltaNet-H2* | 19.88 | 20.83 | 39.18 | 68.95 | 48.22 | 52.57 | 67.01 | 35.49 | 39.39 | 61.11 | 51.49 |
| Titans (LMM) | 20.04 | 21.96 | 37.40 | 69.28 | 48.46 | 52.27 | 66.31 | 35.84 | 40.13 | 62.76 | 51.56 |
| Titans (MAC) | 19.93 | 20.12 | 39.62 | 70.46 | 49.01 | 53.18 | 67.86 | 36.01 | 41.87 | 62.05 | 52.51 |
| Titans (MAG) | 18.61 | 19.86 | 40.98 | 70.25 | 48.94 | 52.89 | 68.23 | 36.19 | 40.38 | 62.11 | 52.50 |
| Titans (MAL) | 19.07 | 20.33 | 40.05 | 69.99 | 48.82 | 53.02 | 67.54 | 35.65 | 30.98 | 61.72 | 50.97 |

In this part, we use the Single NIAH (S-NIAH) task from the RULER benchmark (Hsieh et al. 2024) and evaluate Titans and baselines on sequences of length 2K, 4K, 8K, and 16K. The results are reported in Table 2. The neural memory module achieves the best results compared to baselines in all three tasks. We attribute this superior performance to three key differences between Titans and existing sequence models: (1) Compared to TTT, our neural memory can better handle the memory capacity by using momentum and the forgetting mechanism (i.e., weight decay). Therefore, with increasing sequence length, the performance of the neural memory does not drop and shows a consistent trend. (2) Compared to Mamba2, which has a gating (forgetting) mechanism, Titans have a deep non-linear memory, resulting in better memory management. Also, contrary to our neural memory and DeltaNet, Mamba2 is not capable of removing a memory, and so we see a significant drop in its performance with increasing sequence length. (3) Compared to DeltaNet, although it is capable of removing a memory using the delta rule, it cannot erase memory, as it lacks a forgetting mechanism. Finally, as expected, we see on-par or better results when using the Titans variants, where the best results correspond to MAC.
Table 2: Performance of Titans and baselines on the S-NIAH task from the RULER benchmark. The best results among simple and hybrid models are highlighted.

| Model | S-NIAH-PK |  |  |  | S-NIAH-N |  |  |  | S-NIAH-W |  |  |  |
| :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: |
|  | 2K | 4K | 8K | 16K | 2K | 4K | 8K | 16K | 2K | 4K | 8K | 16K |
| TTT | 98.4 | 98.8 | 98.0 | 88.4 | 60.2 | 36.6 | 10.2 | 4.4 | 78.8 | 28.0 | 4.4 | 0.0 |
| Mamba2 | 98.6 | 61.4 | 31.0 | 5.4 | 98.4 | 55.8 | 14.2 | 0.0 | 42.2 | 4.2 | 0.0 | 0.0 |
| DeltaNet | 96.8 | 98.8 | 98.6 | 71.4 | 47.2 | 15.4 | 12.8 | 5.4 | 46.2 | 20.0 | 1.6 | 0.0 |
| Titans (LMM) | 99.8 | 98.4 | 98.2 | 96.2 | 100.0 | 99.8 | 93.4 | 80.2 | 90.4 | 89.4 | 85.8 | 80.6 |
| Titans (MAC) | 99.2 | 98.8 | 99.0 | 98.4 | 99.6 | 98.2 | 97.6 | 97.4 | 98.2 | 98.2 | 95.6 | 95.2 |
| Titans (MAG) | 99.4 | 98.0 | 97.4 | 97.4 | 99.2 | 98.8 | 97.2 | 98.6 | 98.0 | 98.0 | 90.2 | 88.2 |
| Titans (MAL) | 98.8 | 98.6 | 98.8 | 97.8 | 99.8 | 98.1 | 96.8 | 96.4 | 98.0 | 97.4 | 92.0 | 90.4 |

![img-6.jpeg](img-6.jpeg)

Figure 6: Performance of Titans and baselines on the BABILong benchmark. Titans (MAC) outperforms all baselines, including extremely large models, e.g., GPT-4.

### 5.4 BABILong Benchmark

In the previous section we discussed the results on simple NIAH tasks, where a single needle needs to be retrieved. Although Titans showed better performance compared to baselines, their true advantage over very long sequences is still hidden there. To this end, in this section we use a harder task from the BABILong benchmark (Yuri Kuratov et al. 2024), in which the model needs to reason across facts distributed in extremely long documents. We follow the original experimental setup and training process in the benchmark. There are two settings: (1) the few-shot setting, in which we use large pre-trained models, and (2) the fine-tuning setting, where we fine-tune the MAC variant of Titans to compare it with other fine-tuned baselines. The results for the few-shot setting are reported in Figure 6a. In this setup, Titans outperform all baselines, i.e., Mamba2.8B (Gu and Dao 2024), RWKV-6-7B (Peng, Goldstein, et al. 2024), RecurrentGemma-9B (Botev et al. 2024), Gemma-9B (Team et al. 2024), Llama3.1-8B (Touvron et al. 2023), GPT-4, and GPT-4o-mini (Achiam et al. 2023). These results are achieved while Titans (MAC) has far fewer parameters than the baselines.

In the fine-tuning setup, we compare the small fine-tuned version of Titans (MAC) with: (i) fine-tuned versions of small models (with almost the same number of parameters as Titans) such as Mamba (Gu and Dao 2024) and RMT (Bulatov, Yury Kuratov, and Burtsev 2022), (ii) large models with Retrieval-Augmented Generation (RAG) (P. Lewis et al. 2020) such as Llama3.1-8B (Touvron et al. 2023), and (iii) extremely large models such as GPT-4 (Achiam et al. 2023), GPT-4o-mini, Qwen2.5-72B (A. Yang et al. 2024), and Llama3.1-70B (Touvron et al. 2023). Baseline results are reported by Yuri Kuratov et al. (2024). The results of Titans and baselines are reported in Figure 6b. Titans outperform all models, even extremely large ones like GPT-4.
Also, compared to Transformer-based models with memory, like RMT, Titans show better performance, mainly due to their more powerful memory. That is, RMT compresses the historical data into a size-16 vector-valued memory, while Titans, with their in-context online memory learner, are capable of encoding the past into the parameters of the model. Interestingly, even augmenting Llama3.1-8B with RAG performs worse than Titans with about 70× fewer parameters.

![img-7.jpeg](img-7.jpeg)

Figure 7: The effect of memory depth on the perplexity. Deeper long-term memory results in better scaling over longer sequences.

Table 3: Performance on long-term forecasting. The best results are highlighted.

|  | Neural Memory |  | Simba |  | iTransformer |  | RLinear |  | PatchTST |  | Crossformer |  | TiDE |  | TimesNet |  | DLinear |  |
| :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: |
|  | MSE | MAE | MSE | MAE | MSE | MAE | MSE | MAE | MSE | MAE | MSE | MAE | MSE | MAE | MSE | MAE | MSE | MAE |
| ETTm1 | 0.358 | 0.387 | 0.383 | 0.396 | 0.407 | 0.410 | 0.414 | 0.407 | 0.387 | 0.400 | 0.513 | 0.496 | 0.419 | 0.419 | 0.400 | 0.406 | 0.403 | 0.407 |
| ETTm2 | 0.261 | 0.309 | 0.271 | 0.327 | 0.288 | 0.332 | 0.286 | 0.327 | 0.281 | 0.326 | 0.757 | 0.610 | 0.358 | 0.404 | 0.291 | 0.333 | 0.350 | 0.401 |
| ETTh1 | 0.420 | 0.421 | 0.441 | 0.432 | 0.454 | 0.447 | 0.446 | 0.434 | 0.469 | 0.454 | 0.529 | 0.522 | 0.541 | 0.507 | 0.458 | 0.450 | 0.456 | 0.452 |
| ETTh2 | 0.356 | 0.382 | 0.361 | 0.391 | 0.383 | 0.407 | 0.374 | 0.398 | 0.387 | 0.407 | 0.942 | 0.684 | 0.611 | 0.530 | 0.414 | 0.427 | 0.559 | 0.515 |
| ECL | 0.162 | 0.261 | 0.169 | 0.274 | 0.178 | 0.270 | 0.219 | 0.298 | 0.205 | 0.290 | 0.244 | 0.334 | 0.251 | 0.344 | 0.192 | 0.295 | 0.212 | 0.300 |
| Traffic | 0.415 | 0.289 | 0.493 | 0.291 | 0.428 | 0.282 | 0.626 | 0.378 | 0.481 | 0.304 | 0.550 | 0.304 | 0.760 | 0.473 | 0.620 | 0.336 | 0.625 | 0.383 |
| Weather | 0.231 | 0.265 | 0.255 | 0.280 | 0.258 | 0.278 | 0.272 | 0.291 | 0.259 | 0.281 | 0.259 | 0.315 | 0.271 | 0.320 | 0.259 | 0.287 | 0.265 | 0.317 |

### 5.5 The Effect of Deep Memory

In this section, we evaluate the effect of deep memory on both wall-clock training time and model performance[^2]. To this end, we focus on different variants of our neural memory module with $L_{M}=1,2,3,4$. We also use Mamba as a baseline for model performance. For a fair comparison, we use the same training process for all models and train them on a subset of the Pile dataset (L. Gao et al. 2020).

We report the perplexity of our models and the baselines as a function of sequence length in Figure 7. Interestingly, with increasing memory depth $L_{M}$, the model achieves better perplexity over all sequence lengths. Also, deeper memory modules are more robust to the sequence length when the model has fewer parameters.
With an increasing number of parameters, all models show better performance on longer sequences.

We also evaluate the effect of memory depth ($L_{M}=1,2,3,4$) on the training throughput. We report the training throughput (the number of tokens per second) as a function of sequence length in Figure 8. All models scale linearly with respect to the context length (i.e., a constant trend in the number of tokens per second with respect to sequence length). Also, by increasing the memory depth, as expected, we see that a deeper memory results in slower training. Therefore, it is not always efficient to use deeper memory modules, revealing a trade-off between effectiveness and efficiency.

![img-8.jpeg](img-8.jpeg)

Figure 8: The effect of memory depth on training throughput.

### 5.6 Time Series Forecasting

To show the effectiveness of our memory module on a broader set of tasks, we also evaluate its performance on time series forecasting tasks. To this end, we use the Simba framework (Patro and Agneeswaran 2024) for time series forecasting and replace its Mamba module with our neural memory. We report the results on common time series forecasting benchmark datasets: ETT, ECL, Traffic, and Weather (H. Zhou et al. 2021). The results are reported in Table 3. Our neural memory module outperforms all baselines, including Mamba-based, linear-based, and Transformer-based architectures.

[^2]: Note that, in this experiment, we only focus on the neural memory module to evaluate the effect of memory depth on the memorization process. Combining the neural memory with attention, as we do in the Titans variants, can additionally enhance the performance of the model over long sequences.

### 5.7 DNA Modeling

In order to understand the capability of Titans beyond natural language, we further evaluate the performance of our neural memory module on DNA modeling tasks. To this end, we evaluate pre-trained models on the downstream tasks in GenomicsBenchmarks (Grešová et al. 2023).

Table 4: Downstream evaluation of pre-trained DNA models on GenomicsBenchmarks (Grešová et al. 2023). We report top-1 classification accuracy (%).

| Model | Enhancer Cohn | Enhancer Ens | Human Reg. | Non-TATA Promoters | Human OCR Ens. |
| :-- | :--: | :--: | :--: | :--: | :--: |
| CNN | 69.5 | 68.9 | 93.3 | 84.6 | 68.0 |
| DNABERT | 74.0 | 85.7 | 88.1 | 85.6 | 75.1 |
| GPT | 70.5 | 83.5 | 91.5 | 87.7 | 73.0 |
| HyenaDNA | 74.2 | 89.2 | 93.8 | 96.6 | 80.9 |
| Transformer++ | 73.4 | 89.5 | 89.9 | 94.4 | 79.5 |
| Mamba | 73.0 | - | - | 96.6 | - |
| Based | 74.6 | 89.5 | 89.5 | 96.8 | 79.0 |
| Neural Memory Module | 75.2 | 89.6 | 89.3 | 96.6 | 79.9 |
2023). We follow the same experimental setups from Nguyen et al. (2024), and re-use the reported results of baselines by Arora et al. (2024). The performance of Titans (LMM) and baselines is reported in Table 4. We find that LMM is competitive with state-of-the-art architectures across different downstream genomics tasks.\\n\\n### 5.8 Efficiency\\n\\nIn this part, we compare the efficiency of our neural memory module and Titans with that of state-of-the-art sequence models. The training throughput of models for different sequence length $\\\\times$ batch size combinations is reported in Figure 9. Among recurrent models, including our neural memory module, our memory module is slightly slower than Mamba2 and Gated DeltaNet, mainly due to: (1) its deep memory and more expressive transition process (memory update), and (2) the highly optimized kernel in the implementation of Mamba2. Interestingly, Titans (MAL) are faster than the baselines as well as the memory module. The main reason for this better throughput is the highly optimized kernel of FlashAttention (Dao 2024), which is used to implement the SWA and full attention modules in Titans.\\n![img-9.jpeg](img-9.jpeg)\\n\\nFigure 9: Training throughput comparison of Titans and baselines.\\n\\n### 5.9 Ablation Study\\n\\nFinally, we perform ablation studies on the different architectural choices in Titans. We consider our neural memory module as a base model and then change one component at a time: (1) replacing deep memory with linear memory, and removing (2) convolution, (3) momentum in the surprise measure, (4) weight decay (or the forget mechanism), and (5) persistent memory. The results are reported in Table 5. All components of the neural memory design contribute positively to its performance, where the greatest contribution comes from weight decay, momentum, convolution, and persistent memory, respectively.\\n\\nThe Effect of Architectural Design. To evaluate the effect of architectural design, we compare the performance of the three representative variants of Titans on three aspects: (i) language modeling, (ii) common-sense reasoning, and (iii) long-context NIAH (BABILong) tasks. The results are reported in Table 5. We find that MAC and MAG have close performance in language modeling and common-sense reasoning tasks, while MAC achieves significantly better performance in long-context NIAH. Both of these models achieve better performance than MAL. These results, along with Figure 9, show a trade-off between fast training and a more expressive design.\",\r\n      \"images\": [\r\n        {\r\n          \"id\": \"img-9.jpeg\",\r\n          \"top_left_x\": 1108,\r\n          \"top_left_y\": 1113,\r\n          \"bottom_right_x\": 1562,\r\n          \"bottom_right_y\": 1423,\r\n          \"image_base64\": null,\r\n          \"image_annotation\": null\r\n        }\r\n      ],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 16,\r\n      \"markdown\": \"Table 5: Ablation Study on Titans. 
All components of Titans are positively contributing to its performance.\\n\\n| Model | Language Modeling <br> $\\\\mathrm{ppl} \\\\downarrow$ | Reasoning <br> $\\\\mathrm{acc} \\\\uparrow$ | Long Context <br> $\\\\mathrm{acc} \\\\uparrow$ |\\n| :-- | :--: | :--: | :--: |\\n| LMM | 27.01 | 47.83 | 92.68 |\\n| +Attn (MAC) | 26.67 | 48.65 | 97.95 |\\n| +Attn (MAG) | 25.70 | 48.60 | 96.70 |\\n| +Attn (MAL) | 25.91 | 47.87 | 96.91 |\\n| Linear Memory | 28.49 | 46.97 | 85.34 |\\n| w/o Convolution | 28.73 | 45.82 | 90.28 |\\n| w/o Momentum | 28.98 | 45.49 | 87.12 |\\n| w/o Weight Decay | 29.04 | 45.11 | 85.60 |\\n| w/o Persistent Memory | 27.63 | 46.35 | 92.49 |\\n\\n# 6 Conclusion \\n\\nIn this paper, we present a neural long-term memory that, as a meta in-context learner, learns to memorize at test time. The neural memory module is a recurrent model in nature, and is adaptively memorizing tokens that are more surprising or are close to surprising tokens. Comparing to modern recurrent models, it has more expressive memory update and storing mechanism. Using this memory, we present Titans architectures, and its three variants, in which we suggest to incorporate the memory module as (1) a context, (2) gating, and (3) a layer. Our experimental evaluation on diverse tasks tasks validate that Titans are more effective than Transformers and recent modern linear recurrent models, specifically for long context. That is, Titans can scale to larger than 2 M context window size with better accuracy than baselines.\\nTitans are implemented in Pytorch and JAX and we intend to make the code we used to train and evaluate our models available soon.\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 17,\r\n      \"markdown\": \"# References \\n\\n[1] Josh Achiam, Steven Adler, Sandhini Agarwal, Lama Ahmad, Ilge Akkaya, Florencia Leoni Aleman, Diogo Almeida, Janko Altenschmidt, Sam Altman, Shyamal Anadkat, et al. \\\"Gpt-4 technical report\\\". In: arXiv preprint arXiv:2303.08774 (2023).\\n[2] Yaroslav Aksenov, Nikita Balagansky, Sofia Maria Lo Cicero Vaina, Boris Shaposhnikov, Alexey Gorbatovski, and Daniil Gavrilov. \\\"Linear Transformers with Learnable Kernel Functions are Better In-Context Models\\\". In: arXiv preprint arXiv:2402.10644 (2024).\\n[3] Marcin Andrychowicz, Misha Denil, Sergio Gomez, Matthew W Hoffman, David Pfau, Tom Schaul, Brendan Shillingford, and Nando De Freitas. \\\"Learning to learn by gradient descent by gradient descent\\\". In: Advances in neural information processing systems 29 (2016).\\n[4] Cem Anil, Yuhuai Wu, Anders Andreassen, Aitor Lewkowycz, Vedant Misra, Vinay Ramasesh, Ambrose Slone, Guy Gur-Ari, Ethan Dyer, and Behnam Neyshabur. \\\"Exploring length generalization in large language models\\\". In: Advances in Neural Information Processing Systems 35 (2022), pp. 38546-38556.\\n[5] Simran Arora, Sabri Eyuboglu, Michael Zhang, Aman Timalsina, Silas Alberti, James Zou, Atri Rudra, and Christopher Re. \\\"Simple linear attention language models balance the recall-throughput tradeoff\\\". In: Forty-first International Conference on Machine Learning. 2024. URL: https://openreview.net/forum?id=e93ffDcpH3.\\n[6] Dzmitry Bahdanau. \\\"Neural machine translation by jointly learning to align and translate\\\". In: arXiv preprint arXiv:1409.0473 (2014).\\n[7] Reza Bayat, Mohammad Pezeshki, Elvis Dohmatob, David Lopez-Paz, and Pascal Vincent. 
\\\"The Pitfalls of Memorization: When Memorization Hurts Generalization\\\". In: arXiv preprint arXiv:2412.07684 (2024).\\n[8] Maximilian Beck, Korbinian Pöppel, Markus Spanring, Andreas Auer, Oleksandra Prudnikova, Michael Kopp, Günter Klambauer, Johannes Brandstetter, and Sepp Hochreiter. \\\"xLSTM: Extended Long Short-Term Memory\\\". In: arXiv preprint arXiv:2405.04517 (2024).\\n[9] Ali Behrouz, Michele Santacatterina, and Ramin Zabih. \\\"Mambamixer: Efficient selective state space models with dual token and channel selection\\\". In: arXiv preprint arXiv:2403.19888 (2024).\\n[10] Vincent-Pierre Berges, Barlas Oğuz, Daniel Haziza, Wen-tau Yih, Luke Zettlemoyer, and Gargi Gosh. \\\"Memory Layers at Scale\\\". In: arXiv preprint arXiv:2412.09764 (2024).\\n[11] Alberto Bietti, Vivien Cabannes, Diane Bouchacourt, Herve Jegou, and Leon Bottou. \\\"Birth of a transformer: A memory viewpoint\\\". In: Advances in Neural Information Processing Systems 36 (2024).\\n[12] Yonatan Bisk, Rowan Zellers, Jianfeng Gao, Yejin Choi, et al. \\\"Piqa: Reasoning about physical commonsense in natural language\\\". In: Proceedings of the AAAI conference on artificial intelligence. Vol. 34. 05. 2020, pp. 7432-7439.\\n[13] Aleksandar Botev, Soham De, Samuel L Smith, Anushan Fernando, George-Cristian Muraru, Ruba Haroun, Leonard Berrada, Razvan Pascanu, Pier Giuseppe Sessa, Robert Dadashi, et al. \\\"RecurrentGemma: Moving Past Transformers for Efficient Open Language Models\\\". In: arXiv preprint arXiv:2404.07839 (2024).\\n[14] Léon Bottou and Vladimir Vapnik. \\\"Local learning algorithms\\\". In: Neural computation 4.6 (1992), pp. 888-900.\\n[15] Aydar Bulatov, Yuri Kuratov, Yermek Kapushev, and Mikhail S Burtsev. \\\"Scaling transformer to 1m tokens and beyond with rmt\\\". In: arXiv preprint arXiv:2304.11062 (2023).\\n[16] Aydar Bulatov, Yury Kuratov, and Mikhail Burtsev. \\\"Recurrent memory transformer\\\". In: Advances in Neural Information Processing Systems 35 (2022), pp. 11079-11091.\\n[17] Edoardo Cetin, Qi Sun, Tianyu Zhao, and Yujin Tang. \\\"An Evolved Universal Transformer Memory\\\". In: arXiv preprint arXiv:2410.13166 (2024).\\n[18] Beidi Chen, Tri Dao, Eric Winsor, Zhao Song, Atri Rudra, and Christopher Ré. \\\"Scatterbrain: Unifying sparse and low-rank attention\\\". In: Advances in Neural Information Processing Systems 34 (2021), pp. 17413-17426.\\n[19] Krzysztof Marcin Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Quincy Davis, Afroz Mohiuddin, Lukasz Kaiser, David Benjamin Belanger, Lucy J Colwell, and Adrian Weller. \\\"Rethinking Attention with Performers\\\". In: International Conference on Learning Representations. 2021. URL: https://openreview.net/forum?id=Ua6zuk0WRH.\\n[20] Christopher Clark, Kenton Lee, Ming-Wei Chang, Tom Kwiatkowski, Michael Collins, and Kristina Toutanova. \\\"BoolQ: Exploring the Surprising Difficulty of Natural Yes/No Questions\\\". In: Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers). Ed. by Jill Burstein, Christy Doran, and Thamar Solorio. Minneapolis, Minnesota: Association for Computational Linguistics, June 2019, pp. 2924-2936. DOI: 10.18653/v1/N19-1300. 
URL: https: //aclanthology.org/N19-1300/.\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 18,\r\n      \"markdown\": \"[21] Peter Clark, Isaac Cowhey, Oren Etzioni, Tushar Khot, Ashish Sabharwal, Carissa Schoenick, and Oyvind Tafjord. \\\"Think you have solved question answering? try arc, the ai2 reasoning challenge\\\". In: arXiv preprint arXiv:1803.05457 (2018).\\n[22] Nelson Cowan. \\\"What are the differences between long-term, short-term, and working memory?\\\" In: Progress in brain research 169 (2008), pp. 323-338.\\n[23] Zihang Dai, Zhilin Yang, Yiming Yang, Jaime G. Carbonell, Quoc Viet Le, and Ruslan Salakhutdinov. \\\"TransformerXL: Attentive Language Models beyond a Fixed-Length Context\\\". In: ACL (1). Ed. by Anna Korhonen, David R. Traum, and Lluís Márquez. Association for Computational Linguistics, 2019, pp. 2978-2988. ISBN: 978-1-950737-48-2.\\n[24] Tri Dao. \\\"FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning\\\". In: The Twelfth International Conference on Learning Representations. 2024. URL: https://openreview.net/forum?id=mZn2Xyh9Ec.\\n[25] Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. \\\"FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness\\\". In: Advances in Neural Information Processing Systems. Ed. by S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh. Vol. 35. Curran Associates, Inc., 2022, pp. 16344-16359. URL: https://proceedings.neurips.cc/paper_files/paper/2022/file/67d57c32e20fd0a7a302cb81d36e40d5-Paper-Conference.pdf.\\n[26] Tri Dao and Albert Gu. \\\"Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality\\\". In: arXiv preprint arXiv:2405.21060 (2024).\\n[27] Abhimanyu Das, Weihao Kong, Andrew Leach, Shaan K Mathur, Rajat Sen, and Rose Yu. \\\"Long-term Forecasting with TiDE: Time-series Dense Encoder\\\". In: Transactions on Machine Learning Research (2023). ISSN: 2835-8856. URL: https://openreview.net/forum?id=pCbC3aQB5W.\\n[28] Soham De, Samuel L Smith, Anushan Fernando, Aleksandar Botev, George Cristian-Muraru, Albert Gu, Ruba Haroun, Leonard Berrada, Yutian Chen, Srivatsan Srinivasan, et al. \\\"Griffin: Mixing gated linear recurrences with local attention for efficient language models\\\". In: arXiv preprint arXiv:2402.19427 (2024).\\n[29] Juechu Dong, Boyuan Feng, Driss Guessous, Yanbo Liang, and Horace He. \\\"Flex Attention: A Programming Model for Generating Optimized Attention Kernels\\\". In: arXiv preprint arXiv:2412.05496 (2024).\\n[30] Xin Dong, Yonggan Fu, Shizhe Diao, Wonmin Byeon, Zijia Chen, Ameya Sunil Mahabaleshwarkar, Shih-Yang Liu, Matthijs Van Keirsbilck, Min-Hung Chen, Yoshi Suhara, et al. \\\"Hymba: A Hybrid-head Architecture for Small Language Models\\\". In: arXiv preprint arXiv:2411.13676 (2024).\\n[31] Stefan Elfwing, Eiji Uchibe, and Kenji Doya. \\\"Sigmoid-weighted linear units for neural network function approximation in reinforcement learning\\\". In: Neural networks 107 (2018), pp. 3-11.\\n[32] Yukun Feng, Feng Li, Ziang Song, Boyuan Zheng, and Philipp Koehn. \\\"Learn to remember: Transformer with recurrent memory for document-level machine translation\\\". In: arXiv preprint arXiv:2205.01546 (2022).\\n[33] Daniel Y Fu, Tri Dao, Khaled Kamal Saab, Armin W Thomas, Atri Rudra, and Christopher Re. 
\\\"Hungry Hungry Hippos: Towards Language Modeling with State Space Models\\\". In: The Eleventh International Conference on Learning Representations. 2023. URL: https://openreview.net/forum?id=COZDy0WYGg.\\n[34] Yossi Gandelsman, Yu Sun, Xinlei Chen, and Alexei Efros. \\\"Test-time training with masked autoencoders\\\". In: Advances in Neural Information Processing Systems 35 (2022), pp. 29374-29385.\\n[35] Leo Gao, Stella Biderman, Sid Black, Laurence Golding, Travis Hoppe, Charles Foster, Jason Phang, Horace He, Anish Thite, Noa Nabeshima, et al. \\\"The pile: An 800gb dataset of diverse text for language modeling\\\". In: arXiv preprint arXiv:2101.00027 (2020).\\n[36] Felix A Gers, Jürgen Schmidhuber, and Fred Cummins. \\\"Learning to forget: Continual prediction with LSTM\\\". In: Neural computation 12.10 (2000), pp. 2451-2471.\\n[37] Alex Graves, Greg Wayne, and Ivo Danihelka. Neural Turing Machines. 2014. arXiv: 1410.5401 [cs.NE]. URL: https://arxiv.org/abs/1410.5401.\\n[38] Klaus Greff, Rupesh K Srivastava, Jan Koutník, Bas R Steunebrink, and Jürgen Schmidhuber. \\\"LSTM: A search space odyssey\\\". In: IEEE transactions on neural networks and learning systems 28.10 (2016), pp. 2222-2232.\\n[39] Katarína Grešová, Vlastimil Martinek, David Čechák, Petr Šimeček, and Panagiotis Alexiou. \\\"Genomic benchmarks: a collection of datasets for genomic sequence classification\\\". In: BMC Genomic Data 24.1 (2023), p. 25.\\n[40] Albert Gu and Tri Dao. \\\"Mamba: Linear-Time Sequence Modeling with Selective State Spaces\\\". In: First Conference on Language Modeling. 2024. URL: https://openreview.net/forum?id=tEYskw1VY2.\\n[41] Albert Gu, Karan Goel, and Christopher Re. \\\"Efficiently Modeling Long Sequences with Structured State Spaces\\\". In: International Conference on Learning Representations. 2022. URL: https : //openreview . net / forum?id= uYLFoz1v1AC.\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 19,\r\n      \"markdown\": \"[42] Chi Han, Qifan Wang, Hao Peng, Wenhan Xiong, Yu Chen, Heng Ji, and Sinong Wang. \\\"LM-Infinite: Zero-Shot Extreme Length Generalization for Large Language Models\\\". In: Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers). Ed. by Kevin Duh, Helena Gomez, and Steven Bethard. Mexico City, Mexico: Association for Computational Linguistics, June 2024, pp. 3991-4008. DOI: 10.18653/v1/2024.naacl-long.222. URL: https://aclanthology. org/2024.naacl-long. 222.\\n[43] Ramin Hasani, Mathias Lechner, Tsun-Hsuan Wang, Makram Chahine, Alexander Amini, and Daniela Rus. \\\"Liquid Structural State-Space Models\\\". In: The Eleventh International Conference on Learning Representations. 2023. URL: https://openreview.net/forum?id=g40TKRKfS7R.\\n[44] Zexue He, Leonid Karlinsky, Donghyun Kim, Julian McAuley, Dmitry Krotov, and Rogerio Feris. \\\"CAMELoT: Towards Large Language Models with Training-Free Consolidated Associative Memory\\\". In: arXiv preprint arXiv:2402.13449 (2024).\\n[45] Donald Olding Hebb. The organization of behavior: A neuropsychological theory. Psychology press, 2005.\\n[46] John J Hopfield. \\\"Neural networks and physical systems with emergent collective computational abilities.\\\" In: Proceedings of the national academy of sciences 79.8 (1982), pp. 
2554-2558.\\n[47] Kurt Hornik, Maxwell Stinchcombe, and Halbert White. \\\"Multilayer feedforward networks are universal approximators\\\". In: Neural networks 2.5 (1989), pp. 359-366.\\n[48] Cheng-Ping Hsieh, Simeng Sun, Samuel Kriman, Shantanu Acharya, Dima Rekesh, Fei Jia, and Boris Ginsburg. \\\"RULER: What's the Real Context Size of Your Long-Context Language Models?\\\" In: First Conference on Language Modeling. 2024. URL: https://openreview.net/forum?id=kIoBbc76Sy.\\n[49] DeLesley Hutchins, Imanol Schlag, Yuhuai Wu, Ethan Dyer, and Behnam Neyshabur. \\\"Block-recurrent transformers\\\". In: Advances in neural information processing systems 35 (2022), pp. 33248-33261.\\n[50] Kazuki Irie, Róbert Csordás, and Jürgen Schmidhuber. \\\"The dual form of neural networks revisited: Connecting test time predictions to training patterns via spotlights of attention\\\". In: International Conference on Machine Learning. PMLR. 2022, pp. 9639-9659.\\n[51] Kazuki Irie, Imanol Schlag, Róbert Csordás, and Jürgen Schmidhuber. \\\"Going beyond linear transformers with recurrent fast weight programmers\\\". In: Advances in neural information processing systems 34 (2021), pp. 7703-7717.\\n[52] Vidit Jain and Erik Learned-Miller. \\\"Online domain adaptation of a pre-trained cascade of classifiers\\\". In: CVPR 2011. IEEE. 2011, pp. 577-584.\\n[53] Albert Q Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lucile Saulnier, et al. \\\"Mistral 7B\\\". In: arXiv preprint arXiv:2310.06825 (2023).\\n[54] Praneeth Kacham, Vahab Mirrokni, and Peilin Zhong. \\\"PolySketchFormer: Fast Transformers via Sketching Polynomial Kernels\\\". In: Forty-first International Conference on Machine Learning. 2024. URL: https://openreview.net/ forum?id=ghYrfdJfjK.\\n[55] Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. \\\"Scaling laws for neural language models\\\". In: arXiv preprint arXiv:2001.08361 (2020).\\n[56] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. \\\"Transformers are rnns: Fast autoregressive transformers with linear attention\\\". In: International conference on machine learning. PMLR. 2020, pp. 5156-5165.\\n[57] Urvashi Khandelwal, Omer Levy, Dan Jurafsky, Luke Zettlemoyer, and Mike Lewis. \\\"Generalization through Memorization: Nearest Neighbor Language Models\\\". In: International Conference on Learning Representations. 2020. URL: https://openreview.net/forum?id=HkIBjCEKvH.\\n[58] Yuri Kuratov, Aydar Bulatov, Petr Anokhin, Ivan Rodkin, Dmitry Igorevich Sorokin, Artyom Sorokin, and Mikhail Burtsev. \\\"BABILong: Testing the Limits of LLMs with Long Context Reasoning-in-a-Haystack\\\". In: The Thirtyeight Conference on Neural Information Processing Systems Datasets and Benchmarks Track. 2024. URL: https : //openreview.net/forum?id=u7m2CG84BQ.\\n[59] Hung Le, Truyen Tran, and Svetha Venkatesh. \\\"Self-attentive associative memory\\\". In: International conference on machine learning. PMLR. 2020, pp. 5682-5691.\\n[60] Patrick Lewis, Ethan Perez, Aleksandra Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, et al. \\\"Retrieval-augmented generation for knowledge-intensive nlp tasks\\\". In: Advances in Neural Information Processing Systems 33 (2020), pp. 
9459-9474.\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 20,\r\n      \"markdown\": \"[61] Danny Leybzon and Corentin Kervadec. \\\"Learning, Forgetting, Remembering: Insights From Tracking LLM Memorization During Training\\\". In: Proceedings of the 7th BlackboxNLP Workshop: Analyzing and Interpreting Neural Networks for NLP. 2024, pp. 43-57.\\n[62] Zhe Li, Shiyi Qi, Yiduo Li, and Zenglin Xu. \\\"Revisiting long-term time series forecasting: An investigation on linear mapping\\\". In: arXiv preprint arXiv:2305.10721 (2023).\\n[63] Bo Liu, Rui Wang, Lemeng Wu, Yihao Feng, Peter Stone, and Qiang Liu. \\\"Longhorn: State space models are amortized online learners\\\". In: arXiv preprint arXiv:2407.14207 (2024).\\n[64] Nelson F Liu, Kevin Lin, John Hewitt, Ashwin Paranjape, Michele Bevilacqua, Fabio Petroni, and Percy Liang. \\\"Lost in the middle: How language models use long contexts\\\". In: Transactions of the Association for Computational Linguistics 12 (2024), pp. 157-173.\\n[65] Yong Liu, Tengge Hu, Haoran Zhang, Haixu Wu, Shiyu Wang, Lintao Ma, and Mingsheng Long. \\\"itransformer: Inverted transformers are effective for time series forecasting\\\". In: arXiv preprint arXiv:2310.06625 (2023).\\n[66] George Mandler. \\\"The structure of value: Accounting for taste\\\". In: Affect and cognition. Psychology Press, 2014, pp. 3-36.\\n[67] Harsh Mehta, Ankit Gupta, Ashok Cutkosky, and Behnam Neyshabur. \\\"Long Range Language Modeling via Gated State Spaces\\\". In: The Eleventh International Conference on Learning Representations. 2023. URL: https : //openreview.net/forum?id=5MkYIYCbva.\\n[68] Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. \\\"Pointer Sentinel Mixture Models\\\". In: International Conference on Learning Representations. 2017. URL: https://openreview.net/forum?id=Byj72udxe.\\n[69] William Merrill, Jackson Petty, and Ashish Sabharwal. \\\"The Illusion of State in State-Space Models\\\". In: Forty-first International Conference on Machine Learning. 2024. URL: https://openreview.net/forum?id=QZgo9JZpLq.\\n[70] Ravi Teja Mullapudi, Steven Chen, Keyi Zhang, Deva Ramanan, and Kayvon Fatahalian. \\\"Online model distillation for efficient video inference\\\". In: Proceedings of the IEEE/CVF International conference on computer vision. 2019, pp. 3573-3582.\\n[71] Tsendsuren Munkhdalai, Manaal Faruqui, and Siddharth Gopal. \\\"Leave no context behind: Efficient infinite context transformers with infini-attention\\\". In: arXiv preprint arXiv:2404.07143 (2024).\\n[72] Tsendsuren Munkhdalai, Alessandro Sordoni, Tong Wang, and Adam Trischler. \\\"Metalearned neural memory\\\". In: Advances in Neural Information Processing Systems 32 (2019).\\n[73] Tsendsuren Munkhdalai and Hong Yu. \\\"Neural semantic encoders\\\". In: Proceedings of the conference. Association for Computational Linguistics. Meeting. Vol. 1. NIH Public Access. 2017, p. 397.\\n[74] Eric Nguyen, Michael Poli, Marjan Faizi, Armin Thomas, Michael Wornow, Callum Birch-Sykes, Stefano Massaroli, Aman Patel, Clayton Rabideau, Yoshua Bengio, et al. \\\"Hyenadna: Long-range genomic sequence modeling at single nucleotide resolution\\\". In: Advances in neural information processing systems 36 (2024).\\n[75] A Nichol. \\\"On first-order meta-learning algorithms\\\". 
In: arXiv preprint arXiv:1803.02999 (2018).\\n[76] Yuqi Nie, Nam H Nguyen, Phanwadee Sinthong, and Jayant Kalagnanam. \\\"A time series is worth 64 words: Long-term forecasting with transformers\\\". In: arXiv preprint arXiv:2211.14730 (2022).\\n[77] Hideyuki Okano, Tomoo Hirano, and Evan Balaban. \\\"Learning and memory\\\". In: Proceedings of the National Academy of Sciences 97.23 (2000), pp. 12403-12404.\\n[78] Antonio Orvieto, Samuel L Smith, Albert Gu, Anushan Fernando, Caglar Gulcehre, Razvan Pascanu, and Soham De. \\\"Resurrecting recurrent neural networks for long sequences\\\". In: International Conference on Machine Learning. PMLR. 2023, pp. 26670-26698.\\n[79] Denis Paperno, Germán Kruszewski, Angeliki Lazaridou, Ngoc Quan Pham, Raffaella Bernardi, Sandro Pezzelle, Marco Baroni, Gemma Boleda, and Raquel Fernández. \\\"The LAMBADA dataset: Word prediction requiring a broad discourse context\\\". In: Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). Ed. by Katrin Erk and Noah A. Smith. Berlin, Germany: Association for Computational Linguistics, Aug. 2016, pp. 1525-1534. DOI: 10.18653/v1/P16-1144. URL: https://aclanthology.org/P16-1144/.\\n[80] Badri N. Patro and Vijay S. Agneeswaran. SiMBA: Simplified Mamba-Based Architecture for Vision and Multivariate Time series. 2024. arXiv: 2403.15360 [cs.CV].\\n[81] Guilherme Penedo, Hynek Kydliček, Loubna Ben allal, Anton Lozhkov, Margaret Mitchell, Colin Raffel, Leandro Von Werra, and Thomas Wolf. \\\"The FineWeb Datasets: Decanting the Web for the Finest Text Data at Scale\\\". In: The Thirty-eight Conference on Neural Information Processing Systems Datasets and Benchmarks Track. 2024. URL: https://openreview.net/forum?id=n6SCkn2QaG.\\n[82] Bo Peng. RWKV-LM. Version 1.0.0. Aug. 2021. DOI: 10.5281 / zenodo. 5196577. URL: https://github.com/ BlinkDL/RWKV-LM.\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 21,\r\n      \"markdown\": \"[83] Bo Peng, Eric Alcaide, Quentin Gregory Anthony, Alon Albalak, Samuel Arcadinho, Stella Biderman, Huanqi Cao, Xin Cheng, Michael Nguyen Chung, Leon Derczynski, Xingjian Du, Matteo Grella, Kranthi Kiran GV, Xuzheng He, Haowen Hou, Przemyslaw Kazienko, Jan Kocon, Jiaming Kong, Bartłomiej Koptyra, Hayden Lau, Jiaju Lin, Krishna Sri Ipsit Mantri, Ferdinand Mom, Atsushi Saito, Guangyu Song, Xiangru Tang, Johan S. Wind, Stanisław Woźniak, Zhenyuan Zhang, Qinghua Zhou, Jian Zhu, and Rui-Jie Zhu. \\\"RWKV: Reinventing RNNs for the Transformer Era\\\". In: The 2023 Conference on Empirical Methods in Natural Language Processing. 2023. URL: https://openreview. net/forum?id=7SaXcza8pG.\\n[84] Bo Peng, Daniel Goldstein, Quentin Anthony, Alon Albalak, Eric Alcaide, Stella Biderman, Eugene Cheah, Xingjian Du, Teddy Ferdinan, Haowen Hou, et al. \\\"Eagle and finch: Rwkv with matrix-valued states and dynamic recurrence\\\". In: arXiv preprint arXiv:2404.05892 (2024).\\n[85] DL Prados and SC Kak. \\\"Neural network capacity using delta rule\\\". In: Electronics Letters 25.3 (1989), pp. 197-199.\\n[86] Zhen Qin, Yiran Zhong, and Hui Deng. \\\"Exploring Transformer Extrapolation\\\". In: Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. 17. 2024, pp. 18897-18905.\\n[87] Liliang Ren, Yang Liu, Yadong Lu, Yelong Shen, Chen Liang, and Weizhu Chen. 
\\\"Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling\\\". In: arXiv preprint arXiv:2406.07522 (2024).\\n[88] Ivan Rodkin, Yuri Kuratov, Aydar Bulatov, and Mikhail Burtsev. \\\"Associative recurrent memory transformer\\\". In: arXiv preprint arXiv:2407.04841 (2024).\\n[89] Aurko Roy, Mohammad Saffar, Ashish Vaswani, and David Grangier. \\\"Efficient content-based sparse attention with routing transformers\\\". In: Transactions of the Association for Computational Linguistics 9 (2021), pp. 53-68.\\n[90] Keisuke Sakaguchi, Ronan Le Bras, Chandra Bhagavatula, and Yejin Choi. \\\"Winogrande: An adversarial winograd schema challenge at scale\\\". In: Communications of the ACM 64.9 (2021), pp. 99-106.\\n[91] Maarten Sap, Hannah Rashkin, Derek Chen, Ronan Le Bras, and Yejin Choi. \\\"Social IQa: Commonsense Reasoning about Social Interactions\\\". In: Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP). Ed. by Kentaro Inui, Jing Jiang, Vincent Ng, and Xiaojun Wan. Hong Kong, China: Association for Computational Linguistics, Nov. 2019, pp. 4463-4473. DOI: 10.18653/v1/D19-1454. URL: https://aclanthology.org/D19-1454/.\\n[92] Imanol Schlag, Kazuki Irie, and Jürgen Schmidhuber. \\\"Linear transformers are secretly fast weight programmers\\\". In: International Conference on Machine Learning. PMLR. 2021, pp. 9355-9366.\\n[93] JH Schmidhuber. \\\"Learning to control fast-weight memories: An alternative to recurrent nets. Accepted for publication in\\\". In: Neural Computation (1992).\\n[94] Jürgen Schmidhuber. \\\"Reducing the ratio between learning complexity and number of time varying variables in fully recurrent nets\\\". In: ICANN'93: Proceedings of the International Conference on Artificial Neural Networks Amsterdam, The Netherlands 13-16 September 1993 3. Springer. 1993, pp. 460-463.\\n[95] Jürgen Schmidhuber and Sepp Hochreiter. \\\"Long Short-term Memory\\\". In: Neural Computation MIT-Press (1997).\\n[96] Avi Schwarzschild, Zhili Feng, Pratyush Maini, Zachary C Lipton, and J Zico Kolter. \\\"Rethinking llm memorization through the lens of adversarial compression\\\". In: arXiv preprint arXiv:2404.15146 (2024).\\n[97] Jimmy T.H. Smith, Andrew Warrington, and Scott Linderman. \\\"Simplified State Space Layers for Sequence Modeling\\\". In: The Eleventh International Conference on Learning Representations. 2023. URL: https://openreview. net/forum? id=AiBHw3AXqks.\\n[98] Robin Staab, Mark Vero, Mislav Balunovic, and Martin Vechev. \\\"Beyond Memorization: Violating Privacy via Inference with Large Language Models\\\". In: The Twelfth International Conference on Learning Representations. 2024. URL: https://openreview. net/forum?id=kmn0BhQk7p.\\n[99] Sainbayar Sukhbaatar, Edouard Grave, Guillaume Lample, Herve Jegou, and Armand Joulin. \\\"Augmenting selfattention with persistent memory\\\". In: arXiv preprint arXiv:1907.01470 (2019).\\n[100] Sainbayar Sukhbaatar, Jason Weston, Rob Fergus, et al. \\\"End-to-end memory networks\\\". In: Advances in neural information processing systems 28 (2015).\\n[101] Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, et al. \\\"Learning to (learn at test time): Rnns with expressive hidden states\\\". 
In: arXiv preprint arXiv:2407.04620 (2024).\\n[102] Yutao Sun, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, and Furu Wei. \\\"Retentive network: A successor to transformer for large language models\\\". In: arXiv preprint arXiv:2307.08621 (2023).\\n[103] Gemma Team, Thomas Mesnard, Cassidy Hardin, Robert Dadashi, Surya Bhupatiraju, Shreya Pathak, Laurent Sifre, Morgane Rivière, Mihir Sanjay Kale, Juliette Love, et al. \\\"Gemma: Open models based on gemini research and technology\\\". In: arXiv preprint arXiv:2403.08295 (2024).\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 22,\r\n      \"markdown\": \"[104] W Scott Terry. Learning and memory: Basic principles, processes, and procedures. Routledge, 2017.\\n[105] Matteo Tiezzi, Michele Casoni, Alessandro Betti, Tommaso Guidi, Marco Gori, and Stefano Melacci. \\\"On the resurgence of recurrent models for long sequences: Survey and research opportunities in the transformer era\\\". In: arXiv preprint arXiv:2402.08132 (2024).\\n[106] Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, et al. \\\"Llama: Open and efficient foundation language models\\\". In: arXiv preprint arXiv:2302.13971 (2023).\\n[107] Jos Van Der Westhuizen and Joan Lasenby. \\\"The unreasonable effectiveness of the forget gate\\\". In: arXiv preprint arXiv:1804.04849 (2018).\\n[108] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. \\\"Attention is All you Need\\\". In: Advances in Neural Information Processing Systems. Ed. by I. Guyon, U. Von Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett. Vol. 30. Curran Associates, Inc., 2017. URL: https : / / proceedings . neurips . cc / paper_files / paper / 2017 / file / 3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf.\\n[109] Shida Wang. \\\"LongSSM: On the Length Extension of State-space Models in Language Modelling\\\". In: arXiv preprint arXiv:2406.02080 (2024).\\n[110] Yu Wang, Yifan Gao, Xiusi Chen, Haoming Jiang, Shiyang Li, Jingfeng Yang, Qingyu Yin, Zheng Li, Xian Li, Bing Yin, Jingbo Shang, and Julian McAuley. \\\"MEMORYLLM: Towards Self-Updatable Large Language Models\\\". In: Forty-first International Conference on Machine Learning. 2024. URL: https: //openreview. net/forum?id=p01KWzdikQ.\\n[111] Yu Wang, Chi Han, Tongtong Wu, Xiaoxin He, Wangchunshu Zhou, Nafis Sadeq, Xiusi Chen, Zexue He, Wei Wang, Gholamreza Haffari, et al. \\\"Towards LifeSpan Cognitive Systems\\\". In: arXiv preprint arXiv:2409.13265 (2024).\\n[112] Zhiwei Wang, Yao Ma, Zitao Liu, and Jiliang Tang. \\\"R-transformer: Recurrent neural network enhanced transformer\\\". In: arXiv preprint arXiv:1907.05572 (2019).\\n[113] Jason Weston, Sumit Chopra, and Antoine Bordes. \\\"Memory networks\\\". In: arXiv preprint arXiv:1410.3916 (2014).\\n[114] Bernard Widrow and Marcian E Hoff. \\\"Adaptive switching circuits\\\". In: Neurocomputing: foundations of research. 1988, pp. 123-134.\\n[115] Ronald J Williams and David Zipser. \\\"A learning algorithm for continually running fully recurrent neural networks\\\". In: Neural computation 1.2 (1989), pp. 270-280.\\n[116] Daniel B Willingham. \\\"Systems of memory in the human brain\\\". In: Neuron 18.1 (1997), pp. 
5-8.\\n[117] Chao-Yuan Wu, Christoph Feichtenhofer, Haoqi Fan, Kaiming He, Philipp Krahenbuhl, and Ross Girshick. \\\"Longterm feature banks for detailed video understanding\\\". In: Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019, pp. 284-293.\\n[118] Haixu Wu, Tengge Hu, Yong Liu, Hang Zhou, Jianmin Wang, and Mingsheng Long. \\\"TimesNet: Temporal 2DVariation Modeling for General Time Series Analysis\\\". In: The Eleventh International Conference on Learning Representations. 2023. URL: https://openreview. net/forum?id=ju_Uqw3840q.\\n[119] Qingyang Wu, Zhenzhong Lan, Kun Qian, Jing Gu, Alborz Geramifard, and Zhou Yu. \\\"Memformer: A memoryaugmented transformer for sequence modeling\\\". In: arXiv preprint arXiv:2010.06891 (2020).\\n[120] Guangxuan Xiao, Yuandong Tian, Beidi Chen, Song Han, and Mike Lewis. \\\"Efficient Streaming Language Models with Attention Sinks\\\". In: The Twelfth International Conference on Learning Representations. 2024. URL: https: //openreview. net/forum?id=NG7sS51zVF.\\n[121] An Yang, Baosong Yang, Beichen Zhang, Binyuan Hui, Bo Zheng, Bowen Yu, Chengyuan Li, Dayiheng Liu, Fei Huang, Haoran Wei, et al. \\\"Qwen2. 5 Technical Report\\\". In: arXiv preprint arXiv:2412.15115 (2024).\\n[122] Songlin Yang, Jan Kautz, and Ali Hatamizadeh. \\\"Gated Delta Networks: Improving Mamba2 with Delta Rule\\\". In: arXiv preprint arXiv:2412.06464 (2024).\\n[123] Songlin Yang, Bailin Wang, Yikang Shen, Rameswar Panda, and Yoon Kim. \\\"Gated Linear Attention Transformers with Hardware-Efficient Training\\\". In: Forty-first International Conference on Machine Learning. 2024. URL: https: //openreview. net/forum?id=ia5XvxFUJ7.\\n[124] Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, and Yoon Kim. \\\"Parallelizing Linear Transformers with the Delta Rule over Sequence Length\\\". In: The Thirty-eighth Annual Conference on Neural Information Processing Systems. 2024. URL: https://openreview. net/forum?id=y8Rm4VNRPH.\\n[125] Luca Zancato, Arjun Seshadri, Yonatan Dukler, Aditya Golatkar, Yantao Shen, Benjamin Bowman, Matthew Trager, Alessandro Achille, and Stefano Soatto. \\\"B'MOJO: Hybrid State Space Realizations of Foundation Models with Eidetic and Fading Memory\\\". In: The Thirty-eighth Annual Conference on Neural Information Processing Systems. 2024. URL: https://openreview. net/forum?id=RnQdRY1h5v.\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 23,\r\n      \"markdown\": \"[126] Rowan Zellers, Ari Holtzman, Yonatan Bisk, Ali Farhadi, and Yejin Choi. \\\"HellaSwag: Can a Machine Really Finish Your Sentence?\\\" In: Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics. Ed. by Anna Korhonen, David Traum, and Lluís Márquez. Florence, Italy: Association for Computational Linguistics, July 2019, pp. 4791-4800. DOI: 10.18653/v1/P19-1472. URL: https://aclanthology.org/P19-1472/.\\n[127] Ailing Zeng, Muxi Chen, Lei Zhang, and Qiang Xu. \\\"Are transformers effective for time series forecasting?\\\" In: Proceedings of the AAAI conference on artificial intelligence. Vol. 37. 2023, pp. 11121-11128.\\n[128] Hao Zhang, Alexander C Berg, Michael Maire, and Jitendra Malik. \\\"SVM-KNN: Discriminative nearest neighbor classification for visual category recognition\\\". In: 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'06). Vol. 2. IEEE. 
2006, pp. 2126-2136.\\n[129] Jianyu Zhang, Niklas Nolte, Ranajoy Sadhukhan, Beidi Chen, and Léon Bottou. \\\"Memory Mosaics\\\". In: arXiv preprint arXiv:2405.06394 (2024).\\n[130] Yunhao Zhang and Junchi Yan. \\\"Crossformer: Transformer utilizing cross-dimension dependency for multivariate time series forecasting\\\". In: The eleventh international conference on learning representations. 2023.\\n[131] Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang. \\\"Informer: Beyond efficient transformer for long sequence time-series forecasting\\\". In: Proceedings of the AAAI conference on artificial intelligence. Vol. 35. 12. 2021, pp. 11106-11115.\\n[132] Luisa Zintgraf, Kyriacos Shiarli, Vitaly Kurin, Katja Hofmann, and Shimon Whiteson. \\\"Fast context adaptation via meta-learning\\\". In: International Conference on Machine Learning. PMLR. 2019, pp. 7693-7702.\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 24,\r\n      \"markdown\": \"# A Related Work \\n\\nThere are diverse perspectives that can independently lead to the design of Titans or its components. Accordingly, to further situate our work in a broader context, we review three categories of studies:\\n\\n## A. 1 Linear Recurrent Models\\n\\nRecently, to address the computational cost of Transformers in both training and inference, linear recurrent models have attracted much attention (Tiezzi et al. 2024), mainly due to their fast inference and training. The first generation of models-such as RetNet (Yutao Sun et al. 2023), LRU (Orvieto et al. 2023), RWKV (Peng, Alcaide, et al. 2023), S5 (J. T. Smith, Warrington, and Linderman 2023), and S4 (Gu, Goel, and Re 2022)-uses data-independent transition matrix/decay mechanism. The second generation of such models started to incorporate gating mechanism, a widely used techniques in traditional RNNs (Gers, Jürgen Schmidhuber, and Cummins 2000; Greff et al. 2016; Van Der Westhuizen and Lasenby 2018), into such linear architectures-e.g., Griffin (De et al. 2024), SSMs (Behrouz, Santacatterina, and Zabih 2024; Dao and Gu 2024; Gu and Dao 2024; Hasani et al. 2023), RWKV6 (Peng, Goldstein, et al. 2024). The third generation of linear recurrent models are based on more complex memory updating rule based on meta-learning, online learning, and/or delta-rule, resulting in more expressive and effective models such as: Longhorn (B. Liu et al. 2024), Gated DeltaNet (S. Yang, Kautz, and Hatamizadeh 2024), TTT (Yu Sun et al. 2024), and DeltaNet (S. Yang, B. Wang, Yu Zhang, et al. 2024). Our LMM model can be seen as the next generation of such models, in which we incorporate the token flow into the memory updating mechanism, having more powerful memory updating process. See Appendix C for a detailed discussion of different recurrent models and Titans.\\n\\n## A. 2 Transformer-based Architectures\\n\\nTransformers. Transformers (Vaswani et al. 2017) as the de facto backbone for many deep learning models are based on attention mechanism (Bahdanau 2014). They, however, suffer from quadratic computational cost, limiting their ability to scale to long context window. To improve the memory consumption and throughput of softmax attention for longer sequences, various studies focused on I/O aware implementations of attention (Dao 2024; Dao, D. Fu, et al. 
2022), designing more efficient attention mechanisms by sparsifying the attention matrix (B. Chen et al. 2021; Choromanski et al. 2021; Dai et al. 2019; J. Dong et al. 2024; Roy et al. 2021), approximating the softmax (Arora et al. 2024), or developing kernel-based (linear) attentions (Aksenov et al. 2024; Kacham, Mirrokni, and P. Zhong 2024; Schlag, Irie, and Jürgen Schmidhuber 2021; S. Yang, B. Wang, Shen, et al. 2024).\\n\\nSegment-based Transformers. Another line of research to improve the efficiency of Transformers is segment-based or Chunk Transformers (Dai et al. 2019). The main drawback of chunk Transformers is that segments are fully separated and so the context window is limited to the length of the chunks. To address this issue, various studies discuss the importance of a memory so it can help the model to transfer information across chunks (Bulatov, Yuri Kuratov, et al. 2023; Bulatov, Yury Kuratov, and Burtsev 2022; Feng et al. 2022; Hutchins et al. 2022; Rodkin et al. 2024; Z. Wang et al. 2019; Q. Wu et al. 2020; Zancato et al. 2024). The key differences of Titans with these models are: (1) The memory in such models are simple small size vectors, lacking expressive power to compress complex information; (2) The memory module lacks forget mechanism, leading to a fast memory overflow; (3) only focus on momentary surprise, missing the information flow. More specifically, recalling Recurrent Memory Transformers (RMT) (Bulatov, Yuri Kuratov, et al. 2023; Bulatov, Yury Kuratov, and Burtsev 2022; Rodkin et al. 2024), one can treat Titans (MAC) as the generalization of RMT, where we use a neural memory module instead of a vector-valued small size memory.\\n\\nMemory for Large Language Models. Another interesting research direction has been to incorporate external memory modules to LLMs after training (Z. He et al. 2024; Khandelwal et al. 2020; Y. Wang, Y. Gao, et al. 2024). Such models are different from our approach as we incorporate the memory as a part of initial architecture and so we train it in an end-to-end manner. Also, most of these explicit memory modules suffer from the same limitations as chunk-based Transformers (mentioned above). For a detailed discussion of such models, we refer to the recent study of Y. Wang, Han, et al. (2024).\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 25,\r\n      \"markdown\": \"# A. 3 Test Time Training and Fast Weight Programs \\n\\nMemory Design and Augmentation with Memory. In the literature, a substantial research effort have been toward designing memory modules that are capable of either memorizing the knowledge abstraction (e.g., persistent memory) (Sukhbaatar, Grave, et al. 2019), or memorizing the data-dependent information (also known as contextual memory), through recurrence (Bulatov, Yury Kuratov, and Burtsev 2022; Rodkin et al. 2024; Zancato et al. 2024), Transformers (Berges et al. 2024; Cetin et al. 2024; Feng et al. 2022; Le, Tran, and Venkatesh 2020; Munkhdalai, Faruqui, and Gopal 2024; J. Zhang et al. 2024), gradient (Irie, Csordás, and Jürgen Schmidhuber 2022; Munkhdalai, Sordoni, et al. 2019), or other learning paradigms (Sukhbaatar, Weston, Fergus, et al. 2015; Weston, Chopra, and Bordes 2014). 
These memory models, however, either (1) are based on momentary surprise, missing the data flow and events, (2) lack forget mechanisms to remove the memory, leading to a fast memory overflow, (3) use a fixed-size shallow (matrix-valued) memory, resulting in poor performance in long context, and (4) are based on fixed parameters at test time, lacking test-time adaptation.\\n\\nFast Weight Programs. The idea of seeing linear layers as a key-value (associative) memory system dates back to fast weight programs, in which dynamic fast programs are incorporated into recurrent neural networks to serve as writable memory (Schlag, Irie, and Jürgen Schmidhuber 2021; JH Schmidhuber 1992; Jürgen Schmidhuber 1993). The Hebbian (Hebb 2005) and delta (Prados and Kak 1989) rules are the most popular learning rules for fast weight programs, and they have been extensively explored in various studies (Irie, Schlag, et al. 2021; Munkhdalai, Sordoni, et al. 2019; Munkhdalai and H. Yu 2017; Schlag, Irie, and Jürgen Schmidhuber 2021; JH Schmidhuber 1992; S. Yang, Kautz, and Hatamizadeh 2024; S. Yang, B. Wang, Yu Zhang, et al. 2024). All these models, however, are based on momentary surprise, missing the token flow in the sequences (see Section 3.1), and most of them lack a forgetting gate, resulting in poor memory management.\\n\\nTest Time Training. The key idea of learning at test time, or learning to learn (Andrychowicz et al. 2016), dates back to early studies on local learning (Bottou and Vapnik 1992), in which each test data sample is trained on its neighbors before making a prediction (Gandelsman et al. 2022; H. Zhang et al. 2006). This approach has further shown promising performance in vision tasks (Jain and Learned-Miller 2011; Mullapudi et al. 2019), mostly due to its ability to mitigate out-of-distribution samples. The most similar studies to ours in this direction are MNM (Munkhdalai, Sordoni, et al. 2019) and the TTT-layer (Yu Sun et al. 2024), whose key differences from our work we discuss in Appendix C.\\n\\n## B Language Modeling and Common-sense Reasoning Datasets\\n\\nFollowing recent studies on linear recurrent models (Dao and Gu 2024; S. Yang, Kautz, and Hatamizadeh 2024; S. Yang, B. Wang, Yu Zhang, et al. 2024), we use Wikitext (Merity et al. 2017), LMB (Paperno et al. 2016), PIQA (Bisk et al. 2020), HellaSwag (Zellers et al. 2019), WinoGrande (Sakaguchi et al. 2021), ARC-easy (ARC-e) and ARC-challenge (ARC-c) (P. Clark et al. 2018), SIQA (Sap et al. 2019), and BoolQ (C. Clark et al. 2019). Also, the baseline results for the 400M models are taken from those reported by S. Yang, Kautz, and Hatamizadeh (2024).\\n\\n## C Long-term Memory Module (LMM) as a Sequence Model\\n\\nIn this section, we discuss how LMM as a sequence model is connected to modern linear recurrent models. For the sake of simplicity, we start with a linear memory, where $\\\\mathcal{M}_{t}=W_{t} \\\\in \\\\mathbb{R}^{d_{m} \\\\times d_{m}}$. In this case, our objective function becomes $\\\\ell\\\\left(\\\\mathcal{M} ; x_{t}\\\\right)=\\\\frac{1}{2}\\\\left\\\\|\\\\mathcal{M}_{t} \\\\mathbf{k}_{t}-\\\\mathbf{v}_{t}\\\\right\\\\|_{2}^{2}$, in which we use gradient descent with momentum and weight decay for the optimization. 
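As a concrete illustration of this optimization step, the following minimal PyTorch sketch applies one token update to a matrix-valued memory using a momentum-accumulated surprise and a weight-decay forget gate. It is only a sketch under stated assumptions: the gates `alpha`, `eta`, and `theta` are treated as scalars here (they are data-dependent in the paper), and all function and variable names are illustrative placeholders rather than the released implementation.

```python
import torch

def lmm_linear_step(M, S, k, v, alpha, eta, theta):
    """One token update of a linear (matrix-valued) memory with momentum and forgetting.

    M:     (d, d) memory matrix M_{t-1}
    S:     (d, d) past-surprise (momentum) state S_{t-1}
    k, v:  (d,)   key and value of the current token
    alpha, eta, theta: scalar gates in this sketch (data-dependent in the paper)
    """
    # Momentary surprise: gradient of 1/2 * ||M k - v||^2 with respect to M.
    grad = torch.outer(M @ k - v, k)
    # Past surprise decays with eta and is corrected by the new gradient.
    S_new = eta * S - theta * grad
    # Forget gate (weight decay) on the old memory, then add the accumulated surprise.
    M_new = (1.0 - alpha) * M + S_new
    return M_new, S_new

def lmm_retrieve(M, q):
    # Retrieval is a forward pass of the memory, without a gradient step.
    return M @ q

if __name__ == "__main__":
    d = 8
    M, S = torch.zeros(d, d), torch.zeros(d, d)
    for _ in range(16):
        k, v = torch.randn(d), torch.randn(d)
        M, S = lmm_linear_step(M, S, k, v, alpha=0.05, eta=0.9, theta=0.1)
    print(lmm_retrieve(M, torch.randn(d)).shape)
```

Setting `eta = 0` in this sketch reduces the update to a gated delta-rule-style step, which is the connection to Gated DeltaNet made below.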
Accordingly, revisiting the recurrent formula in Equation 13:\\n\\n$$\\n\\\\begin{aligned}\\n& \\\\mathcal{M}_{t}=\\\\operatorname{diag}\\\\left(1-\\\\alpha_{t}\\\\right) \\\\mathcal{M}_{t-1}+S_{t} \\\\\\\\\\n& S_{t}=\\\\operatorname{diag}\\\\left(\\\\eta_{t}\\\\right) S_{t-1}-\\\\operatorname{diag}\\\\left(\\\\theta_{t}\\\\right)\\\\left(\\\\mathcal{M}_{t-1} \\\\mathbf{k}_{t}^{\\\\top} \\\\mathbf{k}_{t}-\\\\mathbf{v}_{t}^{\\\\top} \\\\mathbf{k}_{t}\\\\right)\\n\\\\end{aligned}\\n$$\\n\\nLMM is Generalized Gated DeltaNet. As discussed by S. Yang, Kautz, and Hatamizadeh (2024), DeltaNet (S. Yang, B. Wang, Yu Zhang, et al. 2024) can alternatively be interpreted as an online learning problem that optimizes $\\\\mathcal{L}=\\\\frac{1}{2}\\\\left\\\\|\\\\mathbf{S}_{t} \\\\mathbf{k}_{t}-\\\\mathbf{v}_{t}\\\\right\\\\|_{2}^{2}$, resulting in:\\n\\n$$\\n\\\\mathbf{S}_{t+1}=\\\\mathbf{S}_{t}-\\\\theta_{t} \\\\nabla \\\\mathcal{L}=\\\\mathbf{S}_{t}\\\\left(\\\\mathbf{I}-\\\\theta_{t} \\\\mathbf{k}_{t} \\\\mathbf{k}_{t}^{\\\\top}\\\\right)+\\\\theta_{t} \\\\mathbf{v}_{t} \\\\mathbf{k}_{t}^{\\\\top}\\n$$\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    },\r\n    {\r\n      \"index\": 26,\r\n      \"markdown\": \"In this formulation, Gated DeltaNet is the same as above but with an additional weight decay term (S. Yang, Kautz, and Hatamizadeh 2024). Comparing Equation 32 and Equation 34, we can see that setting $\\\\eta_{t}=0$ makes both formulations equivalent. Accordingly, we can say LMM generalizes the very recent Gated DeltaNet (S. Yang, Kautz, and Hatamizadeh 2024) in three aspects:\\n\\n- Momentum-based Rule: The Delta Rule is based on momentary surprise, meaning that the flow of tokens cannot affect the memory update rule. LMM, however, is based on a momentum rule, which considers both past and momentary surprise.\\n- Deep Memory: While Gated DeltaNet is limited to a linear (matrix-valued) memory, as it requires finding the closed recurrence form, LMM allows using a deep memory module via a gradient-based formulation, resulting in higher expressive power.\\n- Non-Linear Recurrence: While DeltaNet and Gated DeltaNet are based on linear recurrence, our LMM uses inter-chunk non-linear recurrence and intra-chunk linear recurrence. This design gives LMM higher expressive power.\\n\\nHere, we discussed Gated DeltaNet as a representative of the recent generation of recurrent models. Similar approaches such as RWKV-7 (Peng 2021) also use the same formulation and loss function, and so LMM generalizes all such models.\\n\\nLMM is Generalized Longhorn. Similar to DeltaNet, Longhorn (B. Liu et al. 2024) uses the same loss function but derives the closed form using implicit online learning:\\n\\n$$\\n\\\\mathbf{S}_{t+1}=\\\\mathbf{S}_{t}\\\\left(\\\\mathbf{I}-\\\\delta_{t} \\\\mathbf{k}_{t} \\\\mathbf{k}_{t}^{\\\\top}\\\\right)+\\\\delta_{t} \\\\mathbf{v}_{t} \\\\mathbf{k}_{t}^{\\\\top}\\n$$\\n\\nwhere $\\\\delta_{t}=\\\\frac{\\\\theta_{t}}{1+\\\\theta_{t} \\\\mathbf{k}_{t}^{\\\\top} \\\\mathbf{k}_{t}}$. It, however, lacks a forgetting gate, resulting in a faster memory overflow. Therefore, in addition to the abovementioned aspects of (1) Momentum-based Rule, (2) Deep Memory, and (3) Non-Linear Recurrence, LMM has the advantage of using an additional (4) Forget Gate, leading to better memory management.\\n\\nLMM is Generalized TTT Layer. 
To the best of our knowledge, TTT (Yu Sun et al. 2024) is the only modern linear recurrent model with a gradient-based updating rule. In addition to different architectural designs and objective functions, our LMM has three key differences from the presented TTT layers (Yu Sun et al. 2024):\\n\\n1. Forgetting Mechanism: TTT layers update the memory at every step, without the chance to forget past data. Accordingly, with a fixed memory size, the model cannot manage the memory over long sequences. A forget mechanism, such as LMM's, allows clearing the memory when distant past information is no longer needed. We show that, in a general case, this forget mechanism is equivalent to weight decay, and we provide a fast method to incorporate it into parallel training.\\n2. Momentum-based Update Rule: TTT layers are based on momentary surprise, meaning that the flow of tokens cannot affect the memory update rule. LMM, however, is based on a momentum rule, which considers both past and momentary surprise. See Section 3.1 for the motivation of this design.\\n3. Deep Memory: While TTT layers allow for deeper memory, the advantages/disadvantages of such deeper memory modules have not been experimentally evaluated.\\n\\nTo the best of our knowledge, our neural long-term memory module is the first linear recurrent model with a momentum-based update rule.\\n\\nFinally, as a key difference from all the above and other recent linear recurrent studies, note that the hybrid variants of modern linear models-such as Griffin (De et al. 2024), DeltaNet (S. Yang, B. Wang, Yu Zhang, et al. 2024), Gated DeltaNet (S. Yang, Kautz, and Hatamizadeh 2024), H3 (D. Y. Fu et al. 2023), Mamba2 (Dao and Gu 2024), Samba (Ren et al. 2024), etc.-are all based on a sequential layer-wise design. We present Titans to show how effectively one can incorporate such memory modules into an architecture.\",\r\n      \"images\": [],\r\n      \"dimensions\": {\r\n        \"dpi\": 200,\r\n        \"height\": 2200,\r\n        \"width\": 1700\r\n      }\r\n    }\r\n  ],\r\n  \"model\": \"mistral-ocr-2505-completion\",\r\n  \"document_annotation\": null,\r\n  \"usage_info\": {\r\n    \"pages_processed\": 27,\r\n    \"doc_size_bytes\": 3657065\r\n  }\r\n}"
  },
  {
    "path": "google_papers/TITANs/TITANs.md",
    "content": "PAGE 1\r\n# Titans: Learning to Memorize at Test Time \r\n\r\nAli Behrouz ${ }^{\\dagger}$, Peilin Zhong ${ }^{\\dagger}$, and Vahab Mirrokni ${ }^{\\dagger}$<br>$\\dagger$ Google Research<br>\\{alibehrouz, peilinz, mirrokni\\}@google.com\r\n\r\n\r\n#### Abstract\r\n\r\nOver more than a decade there has been an extensive research effort of how effectively utilize recurrent models and attentions. While recurrent models aim to compress the data into a fixed-size memory (called hidden state), attention allows attending to the entire context window, capturing the direct dependencies of all tokens. This more accurate modeling of dependencies, however, comes with a quadratic cost, limiting the model to a fixed-length context. We present a new neural long-term memory module that learns to memorize historical context and helps an attention to attend to the current context while utilizing long past information. We show that this neural memory has the advantage of a fast parallelizable training while maintaining a fast inference. From a memory perspective, we argue that attention due to its limited context but accurate dependency modeling performs as a short-term memory, while neural memory due to its ability to memorize the data, acts as a long-term, more persistent, memory. Based on these two modules, we introduce a new family of architectures, called Titans, and present three variants to address how one can effectively incorporate memory into this architecture. Our experimental results on language modeling, common-sense reasoning, genomics, and time series tasks show that Titans are more effective than Transformers and recent modern linear recurrent models. They further can effectively scale to larger than 2 M context window size with higher accuracy in needle-in-haystack tasks compared to baselines.\r\n\r\n\r\n## 1 Introduction\r\n\r\n\"The true art of memory is the art of attention!\"\r\n\r\n- Samuel Johnson, 1787\r\n\r\nTransformers, pure attention-based architectures (Vaswani et al. 2017), have been firmly established as state-of-the-art models in sequence modeling, mainly due to their in-context learning and ability to learn at scale (Kaplan et al. 2020). The primary building blocks of Transformers-attention modules-function as associative memory blocks (Bietti et al. 2024), where they learn to store key-value associations and retrieve them by computing pairwise similarity between queries (i.e., search signals) and keys (i.e., contexts). Accordingly, by design, the output of a Transformer is exclusively conditioned on the direct dependencies of tokens in the current context window. This accurate modeling of dependencies, however, comes with quadratic time and memory complexity in terms of the context length. In complex real-world tasks (e.g., language modeling (N. F. Liu et al. 2024), video understanding (C.-Y. Wu et al. 2019), long-term time series forecasting (H. Zhou et al. 2021)), the context window can become extremely large, making the applicability of Transformers challenging in these downstream tasks.\r\n\r\nTo overcome the scalability issue of Transformers, recent studies aim to design different variants of linear Transformers (Kacham, Mirrokni, and P. Zhong 2024; Katharopoulos et al. 2020; S. Yang, B. Wang, Shen, et al. 2024), where softmax is replaced by a kernel function in the attention (see $\\S 2.1$ for details), resulting in a significant drop in memory consumption. 
Despite efficiency and the ability to scale to longer contexts, linear Transformers do not show competitive performance compared to Transformers, as the kernel trick makes the model a linear recurrent network, in which the data is compressed into matrix-valued states (Katharopoulos et al. 2020). This, however, exposes a contradictory fact about linear recurrent models (or linear Transformers): on one hand, we use these linear models to enhance scalability and efficiency (linear vs. quadratic complexity), advantages that appear only for very long contexts; on the other hand, a very long context cannot be properly compressed into small vector-valued or matrix-valued states (S. Wang 2024).\r\nFurthermore, beyond efficiency, most existing architectures-ranging from Hopfield Networks (Hopfield 1982) to LSTMs (Jürgen Schmidhuber and Hochreiter 1997) and Transformers (Vaswani et al. 2017)-face challenges when dealing with generalization, length extrapolation, and/or reasoning (Anil et al. 2022; Qin, Y. Zhong, and Deng 2024), all of which are inseparable parts of many hard real-world tasks. Although these architectures draw inspiration from the human brain, each of them is missing: (1) a crucial component of the learning process, such as short-term memory, long-term memory, meta-memory, or attending to the current context (Cowan 2008); (2) an account of how these components are interconnected systems that can operate independently; and/or (3) the ability to actively learn from data and memorize the abstraction of past history. We argue that in an effective learning paradigm, similar to the human brain, there are distinct yet interconnected modules, each of which is responsible for a component crucial to the learning process.\r\n\r\n# Memory Perspective\r\n\r\nMemory is a fundamental mental process and an inseparable component of human learning (Terry 2017). Without a properly functioning memory system, humans and animals would be restricted to basic reflexes and stereotyped behaviors. Accordingly, memory has been the inspiration for much seminal research in the machine learning literature; e.g., Hopfield Networks (Hopfield 1982), LSTMs (Jürgen Schmidhuber and Hochreiter 1997), and Transformers (Vaswani et al. 2017).\r\n\r\nTaking inspiration from the common definitions of memory and learning in the neuropsychology literature (Okano, Hirano, and Balaban 2000), most existing architectures consider memory as a neural update caused by an input, and define learning as a process for acquiring effective and useful memory, given an objective. In this perspective, Recurrent Neural Networks (RNNs) (Williams and Zipser 1989) can be defined as models with a vector-valued memory module $\\mathcal{M}$ (also called the hidden state) and two main steps: given a new input $x_{t}$ at time $t$, the model (1) updates the memory using a function $f\\left(\\mathcal{M}_{t-1}, x_{t}\\right)$ (with compression); and (2) retrieves the corresponding memory of the input using a function $g\\left(\\mathcal{M}_{t}, x_{t}\\right)$ (see $\\S 2.1$ for details). Similarly, Transformers can be seen as architectures with a growing memory and two similar steps. 
That is, the pair of key and value matrices acts as the model's memory, and the model (1) updates the memory by appending the key and value to the memory (without compression), and (2) retrieves the memory corresponding to the query vectors by computing the similarity of query and key vectors, which is then used to weight the value vectors for the output.\r\n\r\nThis perspective can help us better understand existing paradigms and their critical differences, and design more effective architectures. For example, the main difference between Transformers (Vaswani et al. 2017) and linear Transformers (Katharopoulos et al. 2020) is the memory structure as well as the memory updating step, in which linear Transformers compress the historical data into a fixed-size matrix-valued memory while Transformers keep all historical data (within the context length) without any compression. While both linear Transformers and linear RNNs (including state space models) compress the information in the memory update step, the critical difference lies in the structure of the memory, where linear RNNs (vs. linear Transformers) use a vector-valued memory (vs. matrix-valued memory). Therefore, this perspective motivates us to ask: (Q1) What constitutes a good structure for the memory? (Q2) What is a proper memory update mechanism? and (Q3) What is a good memory retrieval process?\r\n\r\nRevisiting our understanding of human memory, it is neither a unitary process nor does it serve a single function (Cowan 2008). In fact, memory is a confederation of systems-e.g., short-term, working, and long-term memory-each serving a different function with different neural structures, and each capable of operating independently (Willingham 1997). This fact motivates us to ask: (Q4) How can we design an efficient architecture that incorporates different interconnected memory modules? Finally, storing a memory is a neural process that requires encoding and storing the abstraction of the past. It can be an over-simplification to assume that a single vector or matrix, whose parameters encode the data in a linear manner, is enough for storing long-term history. (Q5) Is a deep memory module needed to effectively store/remember the long past?\r\n\r\n## Contributions and Roadmap\r\n\r\nIn this paper, we aim to answer the above five questions by designing a long-term neural memory module that can efficiently and effectively learn to memorize at test time. Building upon its design, we discuss how it can be incorporated into an architecture.\r\n\r\nNeural Memory (§3). We present a (deep) neural long-term memory that (as a meta in-context model) learns how to memorize/store the data into its parameters at test time. Inspired by the human long-term memory system (Mandler 2014),\r\nwe design this memory module so that an event that violates expectations (i.e., is surprising) is more memorable. To this end, we measure the surprise of an input with the gradient of the neural network with respect to the input in the associative memory loss (see $\\S 3.1$ for details). To better handle the limited memory, we present a decaying mechanism that considers the proportion of memory size and the amount of data surprise, resulting in better memory management. We show that this decay mechanism is in fact a generalization of the forgetting mechanism in modern recurrent models (Dao and Gu 2024; Gu and Dao 2024; S. Yang, Kautz, and Hatamizadeh 2024). 
Interestingly, we find that this mechanism is equivalent to optimizing a meta neural network with mini-batch gradient descent, momentum, and weight decay. Building upon tensorizing mini-batch gradient descent to use more matmul operations (Yu Sun et al. 2024), we present a fast and parallelizable algorithm to train our deep neural long-term memory.\r\n\r\nTitans Architectures (§4). After designing the long-term neural memory, an important remaining question is how to effectively and efficiently incorporate memory into a deep learning architecture. We present Titans, a family of deep models that consists of three hyper-heads: (1) Core: this module consists of the short-term memory and is responsible for the main flow of processing the data (we use attention with a limited window size); (2) Long-term Memory: this branch is our neural long-term memory module, responsible for storing/remembering the long past; (3) Persistent Memory: this is a set of learnable but data-independent parameters that encodes the knowledge about a task. Finally, as a proof of concept, we present three variants of Titans, in which we incorporate memory as: (i) a context, (ii) a layer, and (iii) a gated branch.\r\n\r\nExperimental Results (§5). We perform experimental evaluations on language modeling, commonsense reasoning, recall-intensive, needle-in-haystack, time series forecasting, and DNA modeling tasks. We observe that our Titans architectures outperform all modern recurrent models as well as their hybrid variants (combined with sliding-window attention) across a comprehensive set of benchmarks. Furthermore, Titans outperform Transformers with the same context window and show competitive performance with Transformers that use the entire context. These results are achieved while, contrary to Transformers, Titans scale to context windows larger than 2M tokens.\r\n\r\n# 2 Preliminaries\r\n\r\nIn this section, we discuss the notation and some background concepts that we use throughout the paper. We let $x \\in \\mathbb{R}^{N \\times d_{\\text {in }}}$ be the input, $\\mathcal{M}$ be a neural network (neural memory module), $\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}$ be the query, key, and value of the attention mechanism, and $\\mathbf{M}$ be the attention mask. When segmenting the sequence, we use $S^{(i)}$ to refer to the $i$-th segment. Throughout the paper, we abuse the notation and use subscripts to refer to a specific element of a matrix, vector, or segment. For example, we let $S_{j}^{(i)}$ be the $j$-th token in the $i$-th segment. The only exception is subscripts with $t$, which we reserve to index recurrence over time, or the state of a neural network at time $t$. Given a neural network $\\mathcal{N}$ and a data sample $x$, we use $\\mathcal{N}(x)$ (resp. $\\mathcal{N}^{*}(x)$ ) to refer to the forward pass with (resp. without) weight adjustment. Also, we abuse the notation and use $\\mathcal{N}^{(k)}$ to refer to the $k$-th layer of the neural network. In the following, we first discuss the background on attention and its efficient variants, followed by a review of modern linear RNNs. Finally, we discuss the memory perspective of these architectures that motivates us to design Titans.\r\n\r\n### 2.1 Backgrounds\r\n\r\nAttention. Transformers (Vaswani et al. 2017), as the de facto backbone for many deep learning models, are based on the attention mechanism. 
Given input $x \\in \\mathbb{R}^{N \\times d_{\\text {in }}}$, causal attention computes output $\\mathbf{y} \\in \\mathbb{R}^{N \\times d_{\\text {in }}}$ based on a softmax over input-dependent key, value, and query matrices:\r\n\r\n$$\r\n\\begin{gathered}\r\n\\mathbf{Q}=x \\mathbf{W}_{\\mathbf{Q}}, \\quad \\mathbf{K}=x \\mathbf{W}_{\\mathbf{K}}, \\quad \\mathbf{V}=x \\mathbf{W}_{\\mathbf{V}} \\\\\r\n\\mathbf{y}_{i}=\\sum_{j=1}^{i} \\frac{\\exp \\left(\\mathbf{Q}_{i}^{\\top} \\mathbf{K}_{j} / \\sqrt{d_{\\text {in }}}\\right) \\mathbf{V}_{j}}{\\sum_{\\ell=1}^{i} \\exp \\left(\\mathbf{Q}_{i}^{\\top} \\mathbf{K}_{\\ell} / \\sqrt{d_{\\text {in }}}\\right)}\r\n\\end{gathered}\r\n$$\r\n\r\nwhere $\\mathbf{W}_{\\mathbf{Q}}, \\mathbf{W}_{\\mathbf{K}}$, and $\\mathbf{W}_{\\mathbf{V}} \\in \\mathbb{R}^{d_{\\text {in }} \\times d_{\\text {in }}}$ are learnable parameters. Despite their power and effectiveness in recall, Transformers need at least $N \\times d$ operations to calculate the output, resulting in larger memory consumption and lower throughput for longer sequences.\r\n\r\nEfficient Attentions. To improve the memory consumption and throughput of softmax attention for longer sequences, various studies have focused on I/O-aware implementations of attention (Dao 2024; Dao, D. Fu, et al. 2022), designing more efficient attention mechanisms by sparsifying the attention matrix (B. Chen et al. 2021; Choromanski et al. 2021; Dai et al. 2019), approximating the softmax (Arora et al. 2024), or developing kernel-based (linear) attentions (Aksenov et al. 2024; Kacham, Mirrokni, and P. Zhong 2024; Schlag, Irie, and Jürgen Schmidhuber 2021; S. Yang, B. Wang, Shen, et al. 2024). In this part, we focus on the latter, i.e., linear attentions, where the softmax in standard attention is replaced with an alternative kernel function $\\phi(\\cdot, \\cdot)$, such that $\\phi(x, y)=\\phi(x) \\phi(y)$. Accordingly, the attention can be written as:\r\n\r\n$$\r\n\\mathbf{y}_{i}=\\sum_{j=1}^{i} \\frac{\\phi\\left(Q_{i}^{\\top} K_{j}\\right)}{\\sum_{\\ell=1}^{i} \\phi\\left(Q_{i}^{\\top} K_{\\ell}\\right)} V_{j}=\\sum_{j=1}^{i} \\frac{\\phi\\left(Q_{i}\\right)^{\\top} \\phi\\left(K_{j}\\right)}{\\sum_{\\ell=1}^{i} \\phi\\left(Q_{i}\\right)^{\\top} \\phi\\left(K_{\\ell}\\right)} V_{j}=\\frac{\\phi\\left(Q_{i}\\right)^{\\top} \\sum_{j=1}^{i} \\phi\\left(K_{j}\\right) V_{j}}{\\phi\\left(Q_{i}\\right)^{\\top} \\sum_{\\ell=1}^{i} \\phi\\left(K_{\\ell}\\right)}\r\n$$\r\n\r\nresulting in higher throughput, as the terms $\\sum_{j=1}^{i} \\phi\\left(K_{j}\\right) V_{j}$ and $\\sum_{\\ell=1}^{i} \\phi\\left(K_{\\ell}\\right)$ are reused at each step. When the kernel is chosen as the identity (Yutao Sun et al. 2023), the above formulation can also be written in a recurrent format:\r\n\r\n$$\r\n\\begin{aligned}\r\n& \\mathcal{M}_{t}=\\mathcal{M}_{t-1}+K_{t}^{\\top} V_{t} \\\\\r\n& \\mathbf{y}_{t}=Q_{t} \\mathcal{M}_{t}\r\n\\end{aligned}\r\n$$\r\n\r\nwhich allows efficient inference for linear attentions; a minimal sketch of this recurrence appears below.
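To make the recurrent view concrete, here is a minimal sketch (ours, not the paper's code) of the identity-kernel linear-attention recurrence, with the normalization term omitted for brevity:

```python
# Minimal sketch (ours): linear attention in its recurrent form,
#   M_t = M_{t-1} + k_t^T v_t   (write)
#   y_t = q_t M_t               (read)
# The normalization term is omitted for brevity.
import torch

def linear_attention_recurrent(q, k, v):
    """q, k, v: (seq_len, d). Returns outputs of shape (seq_len, d)."""
    seq_len, d = q.shape
    memory = torch.zeros(d, d)                      # matrix-valued memory state
    outputs = []
    for t in range(seq_len):
        memory = memory + torch.outer(k[t], v[t])   # write (additive, no forgetting)
        outputs.append(q[t] @ memory)               # read
    return torch.stack(outputs)

y = linear_attention_recurrent(torch.randn(16, 8), torch.randn(16, 8), torch.randn(16, 8))
```

The purely additive write here is exactly what the forgetting and delta-rule mechanisms discussed next are designed to improve.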
Modern Linear Models and Their Memory Perspective. As discussed earlier, one can define learning as a process for acquiring effective and useful memory. Building upon this, one can see the hidden state of Recurrent Neural Networks (RNNs) as a memory unit, into which the model aims to compress the information. Accordingly, in a general form of recurrent neural network, the hidden state can be treated as a memory unit and the recurrence process can be split into read and write operations on the memory unit. That is, letting $x \\in \\mathbb{R}^{N \\times d_{\\text {in }}}$ be the input, $\\mathcal{M} \\in \\mathbb{R}^{d}$ the memory unit, and $\\mathbf{y} \\in \\mathbb{R}^{d_{\\text {in }}}$ the output, the general form of the recurrent neural network is defined as:\r\n\r\n$$\r\n\\begin{aligned}\r\n& \\mathcal{M}_{t}=f\\left(\\mathcal{M}_{t-1}, x_{t}\\right), \\quad \\text{(Write Operation)} \\\\\r\n& \\mathbf{y}_{t}=g\\left(\\mathcal{M}_{t}, x_{t}\\right), \\quad \\text{(Read Operation)}\r\n\\end{aligned}\r\n$$\r\n\r\nwhere $f(\\cdot, \\cdot)$ is the corresponding write function and $g(\\cdot, \\cdot)$ is the read function. Note that here the subscript of $\\mathcal{M}_{t}$ shows the state of the memory at time $t$.\r\n\r\nIn this perspective, the recurrence formula of linear Transformers (see Equation 4) is equivalent to additively compressing and writing keys and values, $\\left(K_{t}, V_{t}\\right)$, into a matrix-valued memory unit $\\mathcal{M}_{t}$. Therefore, when dealing with long-context data, this additive nature of the process results in memory overflow, significantly damaging the performance of the model. To address this, studies have focused on two promising directions: (1) Adding a forget mechanism: several studies have presented adaptive (data-dependent) forgetting gate mechanisms for linear models, which can erase the memory when needed. As examples of such models, we refer to GLA (S. Yang, B. Wang, Shen, et al. 2024), LRU (Orvieto et al. 2023), Griffin (De et al. 2024), xLSTM (Beck et al. 2024), and Mamba2 (Dao and Gu 2024), the latter of which is also connected to the discretized version of traditional state space models (Gu and Dao 2024). (2) Improving the write operation: To overcome the additive nature of the memory write operation in traditional recurrent models, Widrow and Hoff (1988) presented the delta rule, in which, before adding a memory (i.e., a pair of key and value), the model first removes its past value (see the sketch at the end of this section). To enhance parallelizable training and scaling, S. Yang, B. Wang, Yu Zhang, et al. (2024) present a fast parallelizable algorithm. Finally, very recently, S. Yang, Kautz, and Hatamizadeh (2024) improved DeltaNet by adding a forget gate.\r\n\r\nMemory Modules. Memory has always been one of the core parts of neural network design (Graves, Wayne, and Danihelka 2014; JH Schmidhuber 1992; Jürgen Schmidhuber and Hochreiter 1997; J. Zhang et al. 2024). The idea of seeing linear layers as a key-value (associative) memory system dates back to fast weight programs, in which dynamic fast programs are incorporated into recurrent neural networks to serve as writable memory (JH Schmidhuber 1992). The Hebbian (Hebb 2005) and delta (Prados and Kak 1989) learning rules are the most popular learning rules for fast weight programs, and they have been extensively explored in various studies (Irie, Schlag, et al. 2021; Munkhdalai, Sordoni, et al. 2019; Munkhdalai and H. Yu 2017; Schlag, Irie, and Jürgen Schmidhuber 2021; JH Schmidhuber 1992; S. Yang, Kautz, and Hatamizadeh 2024; S. Yang, B. Wang, Yu Zhang, et al. 2024). All these models, however, are based on momentary surprise, missing the token flow in the sequences (see Section 3.1), and most of them lack a forgetting gate, resulting in poor memory management.\r\n\r\nWe further discuss the connection of our architectures with recent models in Appendix C. Additional related work is discussed in Appendix A.
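As a reference point for the write-operation discussion above, the following is a minimal sketch (ours) of the classic delta-rule write, in which the memory's current value for a key is removed before the new value is written:

```python
# Minimal sketch (ours): one delta-rule write step,
#   M_t = M_{t-1} - beta_t (M_{t-1} k_t - v_t) k_t^T,
# which first removes the memory's past value for k_t before writing v_t.
import torch

def delta_rule_write(memory, k, v, beta=1.0):
    """memory: (d, d); k, v: (d,); beta: write strength in [0, 1]."""
    error = memory @ k - v                        # current prediction minus target
    return memory - beta * torch.outer(error, k)

memory = torch.zeros(8, 8)
k, v = torch.randn(8), torch.randn(8)
memory = delta_rule_write(memory, k, v)
# With beta=1 and an empty memory, reading k back returns v scaled by ||k||^2:
assert torch.allclose(memory @ k, v * (k @ k))
```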
# 3 Learning to Memorize at Test Time\r\n\r\nTo overcome the lack of long-term memory and to enable the model to learn, forget, and retrieve information, in this section we present a neural long-term memory module, which is a meta model that learns to memorize at test time. In Section 3.1, we first discuss the motivation and the design of the neural memory. In Section 3.2, we discuss how our architecture design can benefit from fast and parallelizable training. Finally, in Section 3.3, we augment our architecture using a persistent memory module, in which we use learnable but data-independent parameters to learn meta information about the task.\r\n\r\n### 3.1 Long-term Memory\r\n\r\nTo design a neural long-term memory module, we need a model that can encode the abstraction of past history into its parameters. An example of this is LLMs, which have been shown to memorize their training data (Leybzon and Kervadec 2024; Schwarzschild et al. 2024; Staab et al. 2024). Therefore, a simple idea is to train a neural network and expect it to memorize its training data. Memorization, however, has almost always been known as an undesirable phenomenon in neural networks, as it limits model generalization (Bayat et al. 2024), causes privacy concerns (Staab et al. 2024), and so results in poor performance at test time. Moreover, the memorization of the training data might not be helpful at test time, where the data might be out-of-distribution. We argue that we need an online meta-model that learns how to memorize/forget the data at test time. In this setup, the model learns a function that is capable of memorization, but it does not overfit to the training data, resulting in better generalization at test time.\r\n\r\nLearning Process and Surprise Metric. The key idea to train a long-term memory is to treat its training as an online learning problem, in which we aim to compress the past information $x_{1}, \\ldots, x_{t-1}$ into the parameters of our long-term neural memory module $\\mathcal{M}_{t}$. As discussed earlier, an event that violates expectations (i.e., is surprising) is more memorable for humans (Mandler 2014). Inspired by this, a simple definition of surprise for a model can be its gradient with respect to the input. The larger the gradient is, the more different the input data is from the past data. Accordingly, using this surprise score, we can update the memory as:\r\n\r\n$$\r\n\\mathcal{M}_{t}=\\mathcal{M}_{t-1}-\\theta_{t} \\underbrace{\\nabla \\ell\\left(\\mathcal{M}_{t-1} ; x_{t}\\right)}_{\\text {Surprise }}\r\n$$\r\n\r\nThis surprise metric, however, can result in missing important information that comes after a big surprising moment. That is, the gradient can become extremely small after several surprising steps, leading to getting stuck in a flat area (i.e., a local minimum) and missing information about some parts of the sequence. From the human memory perspective, an event might not consistently surprise us over a long period of time even though it is memorable. The reason is that the initial moment is surprising enough to hold our attention across a long time frame, leading to memorizing the entire time frame. 
To improve the above surprise metric (Equation 8), we break the surprise metric into (1) past surprise, which measures the surprise amount of the very recent past; and (2) momentary surprise, which measures the surprise of incoming data:\r\n\r\n$$\r\n\\begin{aligned}\r\n& \\mathcal{M}_{t}=\\mathcal{M}_{t-1}+S_{t} \\\\\r\n& S_{t}=\\eta_{t} \\underbrace{S_{t-1}}_{\\text {Past Surprise }}-\\theta_{t} \\underbrace{\\nabla \\ell\\left(\\mathcal{M}_{t-1} ; x_{t}\\right)}_{\\text {Momentary Surprise }}\r\n\\end{aligned}\r\n$$\r\n\r\nInterestingly, this formulation is similar to gradient descent with momentum, where $S_{t}$ is the momentum element. Therefore, the momentum here acts as a memory of surprise across time (sequence length). In this formulation, the term $\\eta_{t}$ is a data-dependent surprise decay (a function of $x_{t}$ ), controlling how surprise decays over time, and the term $\\theta_{t}$ controls how much of the momentary surprise should be incorporated into the final surprise metric in a data-dependent manner. This data-dependency is particularly important in this design: while the surprise of previous tokens might be needed to affect the surprise of the next token, this is mostly valid if all tokens are relevant and in the same context. Accordingly, a data-dependent $\\eta$ can control whether the memory needs to: (1) ignore the last surprise by setting $\\eta_{t} \\rightarrow 0$ (possibly due to a change of context), or (2) fully incorporate the last surprise by setting $\\eta_{t} \\rightarrow 1$ (possibly because the token is highly relevant to its recent past tokens).\r\n\r\nObjective. Our above surprise metric is based on a loss function $\\ell(\\cdot ; \\cdot)$, which is the objective that our memory learns to optimize at test time. That is, our memory module is a meta model that learns a function based on the loss function $\\ell(\\cdot ; \\cdot)$.\r\nIn this work, we focus on associative memory, in which we aim to store the past data as pairs of keys and values. Given $x_{t}$, similar to Transformers (Vaswani et al. 2017), we use two linear layers to project $x_{t}$ into a key and a value:\r\n\r\n$$\r\n\\mathbf{k}_{t}=x_{t} W_{K}, \\quad \\mathbf{v}_{t}=x_{t} W_{V}\r\n$$\r\n\r\nwhere $W_{K}$ and $W_{V} \\in \\mathbb{R}^{d_{\\text {in }} \\times d_{\\text {in }}}$. Next, we expect our memory module to learn the associations between keys and values. To this end, we define the loss as follows:\r\n\r\n$$\r\n\\ell\\left(\\mathcal{M}_{t-1} ; x_{t}\\right)=\\left\\|\\mathcal{M}_{t-1}\\left(\\mathbf{k}_{t}\\right)-\\mathbf{v}_{t}\\right\\|_{2}^{2}\r\n$$\r\n\r\nBy optimizing the above loss function in the inner loop of our meta model (memory), the model learns how to memorize the mapping between keys and values at test time. Note that, similar to meta-learning models (Nichol 2018; Zintgraf et al. 2019), training of the memory is in the inner loop, and so the parameters $W_{K}$ and $W_{V}$ are hyperparameters in the above loss function. Accordingly, in the inner loop, we optimize $\\mathcal{M}$'s weights, while in the outer loop, we optimize the other parameters of the entire architecture.\r\n\r\nForgetting Mechanism. When dealing with very large sequences (e.g., millions of tokens), it is crucial to manage which past information should be forgotten-even with a deep or very large matrix-valued memory. To this end, we use an adaptive forgetting mechanism that allows the memory to forget the information that is no longer needed, resulting in better management of the memory's limited capacity. 
That is, given the next token $x_{t}$, we modify the update rule as:\r\n\r\n$$\r\n\\begin{aligned}\r\n& \\mathcal{M}_{t}=\\left(1-\\alpha_{t}\\right) \\mathcal{M}_{t-1}+S_{t} \\\\\r\n& S_{t}=\\eta_{t} S_{t-1}-\\theta_{t} \\nabla \\ell\\left(\\mathcal{M}_{t-1} ; x_{t}\\right)\r\n\\end{aligned}\r\n$$\r\n\r\nwhere $\\alpha_{t} \\in[0,1]$ is the gating mechanism that flexibly controls the memory; i.e., it decides how much information should be forgotten. For example, it can update the memory without affecting the past abstraction by letting $\\alpha_{t} \\rightarrow 0$, and it can clear the entire memory by letting $\\alpha_{t} \\rightarrow 1$. Later in this section, we show that this weight decay mechanism is closely related to the gating mechanism in modern RNNs (Dao and Gu 2024; Orvieto et al. 2023).\r\n\r\nMemory Architecture. In this paper, we focus on simple MLPs with $L_{M} \\geq 1$ layers as the architecture of our long-term memory. The main reason behind this choice is that we want to focus on better motivating the design of the long-term memory and the ways it can be incorporated into an architecture. However, our formulation and architectural design open a new research direction for designing neural architectures that are more effective and efficient at memorizing data. Recently, there has been a promising line of work to design such architectures (Berges et al. 2024; Cetin et al. 2024; J. Zhang et al. 2024), and incorporating them into our framework (i.e., replacing the simple MLPs with such architectures) can be interesting future work.\r\n\r\nWhen using a vector-valued or matrix-valued memory (De et al. 2024; Orvieto et al. 2023; S. Yang, B. Wang, Shen, et al. 2024), the memory module compresses the past data and fits it into a line. That is, from the meta-learning or online learning perspective (Yu Sun et al. 2024), using a matrix-valued memory $\\mathcal{M}=W \\in \\mathbb{R}^{d_{\\text {in }} \\times d_{\\text {in }}}$ is equivalent to optimizing $\\ell\\left(W_{t-1} ; x_{t}\\right)=\\left\\|W_{t-1} \\mathbf{k}_{t}-\\mathbf{v}_{t}\\right\\|_{2}^{2}$, which is an online linear regression objective, and so the optimal solution assumes that the underlying dependency of the historical data is linear. On the other hand, we argue that deep memory modules (i.e., $L_{M} \\geq 2$ ) can overcome this limitation. Aligning with the theoretical results that MLPs with at least two layers are strictly more expressive than linear models (Hornik, Stinchcombe, and White 1989), in Section 5.5 we show that deep memory modules are more effective in practice.
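To make the update rule above concrete, here is a compact sketch (ours, not the paper's code) of one test-time step of the neural memory with forgetting and momentum. For readability, the memory is a single weight matrix and the gates alpha_t, eta_t, theta_t are scalars, whereas in the paper they are data-dependent functions of the token:

```python
# Minimal sketch (ours): one neural-memory step,
#   S_t = eta_t * S_{t-1} - theta_t * grad ell(M_{t-1}; x_t)
#   M_t = (1 - alpha_t) * M_{t-1} + S_t
# with a single weight matrix as the memory and scalar gates for readability.
import torch

def memory_step(M, S, k, v, alpha=0.01, eta=0.9, theta=0.1):
    """M: (d, d) memory; S: (d, d) surprise momentum; k, v: (d,) key/value."""
    M = M.detach().requires_grad_(True)
    loss = ((M @ k - v) ** 2).sum()            # associative loss ||M(k_t) - v_t||^2
    grad = torch.autograd.grad(loss, M)[0]     # momentary surprise
    S = eta * S - theta * grad                 # decayed past surprise + new surprise
    M = (1 - alpha) * M.detach() + S           # forgetting gate (weight decay) + update
    return M, S

d = 8
M, S = torch.zeros(d, d), torch.zeros(d, d)
for x in torch.randn(32, d):                   # a toy token stream
    k, v = x, x.roll(1)                        # stand-ins for x_t W_K and x_t W_V
    M, S = memory_step(M, S, k, v)
```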
Retrieving a Memory. Above, we discussed how one can design and train a long-term memory module that learns to memorize at test time. A key remaining question is: how can one retrieve information from the memory? We simply use the forward pass without weight update (i.e., inference) to retrieve the memory corresponding to a query. Formally, given an input $x_{t}$, we use a linear layer $W_{Q}$ to project the input, i.e., $\\mathbf{q}_{t}=x_{t} W_{Q}$, and retrieve the corresponding (or useful) information $y_{t}$ from the memory by:\r\n\r\n$$\r\ny_{t}=\\mathcal{M}^{*}\\left(\\mathbf{q}_{t}\\right)\r\n$$\r\n\r\n![img-0.jpeg](img-0.jpeg)\r\n\r\nFigure 1: The illustration of how the training of the neural memory can be done in parallel using matmuls.\r\n\r\n# 3.2 How to Parallelize the Long-term Memory Training\r\n\r\nAs discussed above, the design of our long-term memory module is equivalent to training a meta model by optimizing the associative memory loss function $\\ell\\left(\\mathcal{M}_{t-1} ; x_{t}\\right)=\\left\\|\\mathcal{M}_{t-1}\\left(\\mathbf{k}_{t}\\right)-\\mathbf{v}_{t}\\right\\|_{2}^{2}$ using gradient descent with momentum and weight decay. Therefore, in theory, training the long-term memory module requires $O(N)$ FLOPs, where $N$ is the sequence length. In practice, however, we need to parallelize the training process, and to fully take advantage of hardware accelerators (e.g., TPUs, GPUs), we need to tensorize the process and use more matmuls.\r\n\r\nNext, we show that calculating the weights in the inner loop with mini-batch gradient descent, a data-dependent learning rate, and weight decay can be reformulated so that it uses only matmuls and sums. We build upon the work of Yu Sun et al. (2024), which shows that the forward pass of a model optimized with mini-batch gradient descent (with a constant learning rate) can be calculated using matmuls. We can split the sequence into chunks of size $b \\geq 1$, and write the mini-batch gradient descent as:\r\n\r\n$$\r\n\\mathcal{M}_{t}=\\left(1-\\alpha_{t}\\right) \\mathcal{M}_{t-1}-\\theta_{t} \\nabla \\ell\\left(\\mathcal{M}_{t-1} ; x_{t}\\right)=\\beta_{t} \\mathcal{M}_{0}-\\sum_{i=1}^{t} \\theta_{i} \\frac{\\beta_{t}}{\\beta_{i}} \\nabla \\ell\\left(\\mathcal{M}_{t^{\\prime}} ; x_{i}\\right)\r\n$$\r\n\r\nwhere $t^{\\prime}=t-\\bmod (t, b)$ and $\\beta_{i}=\\prod_{j=1}^{i}\\left(1-\\alpha_{j}\\right)$. For the sake of simplicity, we focus on the first chunk, i.e., $t=b$ and so $t^{\\prime}=0$. Also, we explain the process for the case where $\\mathcal{M}_{t}=W_{t}$ is linear. The process for MLPs with $L_{M} \\geq 2$ layers is similar. Using our loss function, we have:\r\n\r\n$$\r\n\\nabla \\ell\\left(W_{0} ; x_{t}\\right)=\\left(W_{0} \\mathbf{k}_{t}-\\mathbf{v}_{t}\\right) \\mathbf{k}_{t}^{\\top} \\Rightarrow \\sum_{i=1}^{b} \\theta_{i} \\frac{\\beta_{b}}{\\beta_{i}} \\nabla \\ell\\left(W_{0} ; x_{i}\\right)=\\left(W_{0} K-V\\right) \\Theta_{b} \\mathbf{B}_{b} K^{\\top}\r\n$$\r\n\r\nwhere $K=\\left[\\mathbf{k}_{1} \\ldots \\mathbf{k}_{b}\\right]$ and $V=\\left[\\mathbf{v}_{1} \\ldots \\mathbf{v}_{b}\\right]$ stack the chunk's keys and values column-wise, $\\Theta_{b}=\\operatorname{diag}\\left(\\theta_{1}, \\theta_{2}, \\ldots, \\theta_{b}\\right)$, and $\\mathbf{B}_{b}$ is defined analogously on the $\\frac{\\beta_{b}}{\\beta_{i}}$s. Note that we do not need to store all $\\Theta_{k b}$ and $\\mathbf{B}_{k b}$ for $k=1, \\ldots, N / b$; instead, we store these matrices for one chunk at a time, resulting in less memory use. Next, we extend this representation so that we can also incorporate the momentum term. In chunk-wise gradient descent with momentum, if we look at the momentum term, we have:\r\n\r\n$$\r\nS_{t}=\\eta_{t} S_{t-1}-\\theta_{t} u_{t}\r\n$$\r\n\r\nwhere $u_{t}=\\nabla \\ell\\left(\\mathcal{M}_{t^{\\prime}} ; x_{t}\\right)$. Note that we can compute all $u_{t}$ at the same time, and so Equation 18 is a linear recurrence with $u_{t}$ as an input, $S_{t}$ as the hidden state, and $\\eta_{t}$ as an input-dependent transition value. Accordingly, we can use a parallel associative scan (J. T. Smith, Warrington, and Linderman 2023) to calculate the $S_{t}$s in this chunk, as sketched below.
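The following is a minimal sketch (ours, not the paper's code) of such a scan: each step is the pair (eta_t, a_t) with a_t = -theta_t * u_t, and combining (e1, a1) then (e2, a2) into (e2*e1, e2*a1 + a2) is associative, so the recurrence can be evaluated in O(log b) parallel sweeps:

```python
# Minimal sketch (ours): inclusive scan of S_t = eta_t * S_{t-1} + a_t via
# recursive doubling. The combine is associative, so each sweep parallelizes.
import torch

def combine(left, right):
    """Compose two (decay, value) elements: apply `left` first, then `right`."""
    (e1, a1), (e2, a2) = left, right
    return (e2 * e1, e2 * a1 + a2)

def parallel_scan(etas, inputs):
    """etas: (b,); inputs: (b, d, d) with a_t = -theta_t * u_t. Returns all S_t."""
    elems = [(etas[t], inputs[t]) for t in range(len(etas))]
    shift = 1
    while shift < len(elems):                       # O(log b) sweeps
        elems = [combine(elems[t - shift], elems[t]) if t >= shift else elems[t]
                 for t in range(len(elems))]
        shift *= 2
    return torch.stack([a for _, a in elems])

b, d = 8, 4
S = parallel_scan(torch.rand(b), -0.1 * torch.randn(b, d, d))   # (b, d, d)
```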
Parameters as the Function of Chunks. Instead of making parameters like $\\alpha_{t}, \\theta_{t}$, and $\\eta_{t}$ input-dependent (i.e., functions of the token $x_{t}$ ), we can make them functions of their chunk. Despite losing some expressive power, this formulation can help make training even faster. In this case, we use the same value for each of $\\alpha$, $\\theta$, and $\\eta$ within each chunk. Accordingly, in Equation 17, we can store $\\Theta$ using a single scalar. Similarly, we can make Equation 18 faster. That is, when $\\eta$ and $\\theta$ are learnable but time-invariant inside each chunk, this equation becomes a linear time-invariant (LTI) system, which can be computed by a global convolution (Gu, Goel, and Re 2022). In our experiments, we make these parameters functions of tokens. However, such simplifications (i.e., making them functions of chunks) can be of interest for future work on training larger models more efficiently.\r\n![img-1.jpeg](img-1.jpeg)\r\n\r\nFigure 2: Memory as a Context (MAC) Architecture. This architecture includes three branches: (1) core, (2) contextual (long-term) memory, and (3) persistent memory. The core branch concatenates the corresponding long-term and persistent memories with the input sequence. Next, attention is performed over the sequence and decides which parts of the information should be stored in the long-term memory. At test time, the parameters corresponding to the contextual memory are still learning, the parameters corresponding to the core branch are responsible for in-context learning, and the parameters of the persistent memory are responsible for storing the knowledge about tasks and so are fixed.\r\n\r\n# 3.3 Persistent Memory\r\n\r\nOur long-term memory can also be seen as a contextual memory, meaning that its output fully depends on the context. Therefore, in addition to our long-term memory, we also use a set of learnable but input-independent parameters to act as task-related memory. This type of memory has been referred to as persistent or meta-memory in the literature (X. Dong et al. 2024; Sukhbaatar, Grave, et al. 2019). Given $N_{p} \\geq 1$, we use learnable parameters $P=\\left[\\begin{array}{llll}p_{1} & p_{2} & \\ldots & p_{N_{p}}\\end{array}\\right]$ and prepend them to the start of our sequence: i.e., given a context window of size $N$, we modify the input as:\r\n\r\n$$\r\nx_{\\text {new }}=\\left[\\begin{array}{llll}\r\np_{1} & p_{2} & \\ldots & p_{N_{p}}\r\n\\end{array}\\right] \\| x\r\n$$\r\n\r\nwhere $\\|$ is concatenation (a sketch follows below). Next, we discuss the motivation of persistent memory from three perspectives:\r\nMemory Perspective. As discussed earlier, our neural long-term memory is a contextual memory, in which all parameters are input-dependent. An effective memory system, however, also needs input-independent parameters to store the abstraction of the task knowledge. That is, mastering a task requires memorizing the knowledge of how the task can be done, and these parameters are responsible for storing such knowledge.
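A minimal sketch (ours) of this prepending step, with `PersistentMemory` as an illustrative module name:

```python
# Minimal sketch (ours): prepend N_p learnable, input-independent persistent
# memory tokens to the input, x_new = [p_1 ... p_{N_p}] || x.
import torch
from torch import nn

class PersistentMemory(nn.Module):
    def __init__(self, n_persist: int, dim: int):
        super().__init__()
        self.tokens = nn.Parameter(torch.randn(n_persist, dim) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, seq_len, dim) -> (batch, n_persist + seq_len, dim)."""
        p = self.tokens.expand(x.shape[0], -1, -1)   # same tokens for every input
        return torch.cat([p, x], dim=1)

pm = PersistentMemory(n_persist=4, dim=16)
out = pm(torch.randn(2, 10, 16))                      # shape: (2, 14, 16)
```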
Feedforward Network Perspective. In Transformer architectures, there are fully connected layers after the attention module, which have been shown to be similar to attention weights but with data-independent parameters. That is, Sukhbaatar, Grave, et al. (2019) showed that replacing the ReLU in fully connected layers with Softmax can result in attention-like weights, in which the weights are data-independent:\r\n\r\n$$\r\nF F N(x)=W_{V} \\operatorname{Softmax}\\left(W_{K} x\\right)\r\n$$\r\n\r\nIn fact, $W_{K}$ and $W_{V}$ act similarly to the $K$ and $V$ matrices in the attention module when they are input-independent. The persistent memory weights are expected to have the same functionality, meaning that using them in the first part of the sequence leads to input-independent attention weights (Sukhbaatar, Grave, et al. 2019).\r\n\r\nTechnical Perspective. Attention with a causal mask has an implicit bias toward the initial tokens of the sequence, and so attention weights are almost always highly active for initial tokens, damaging performance. From the technical perspective, these learnable parameters at the start of the sequence can mitigate this effect by redistributing the attention weights more effectively (Han et al. 2024; Xiao et al. 2024).\r\n![img-2.jpeg](img-2.jpeg)\r\n(a) Memory as a Context (MAC). We segment the sequence and use full causal attention in each window. Again, the first $N_{p}$ tokens are persistent memory and the next $N_{l}$ are long-term memory tokens.\r\n![img-3.jpeg](img-3.jpeg)\r\n(b) Memory as Gating (MAG). We use sliding window attention (SWA) as a short-term memory and our neural memory module as a long-term memory, combined by a gating mechanism.\r\n\r\nFigure 3: Attention masks for different variants of Titans.\r\n\r\n# 4 How to Incorporate Memory?\r\n\r\nAn important question that remains unanswered is: how can one effectively and efficiently incorporate the designed neural memory into a deep learning architecture? As discussed earlier, from a memory perspective, the pair of K and V matrices in Transformers can be interpreted as an associative memory block. Due to their accurate modeling of dependencies, and so their limited context window, we interpret them as short-term memory modules, attending to the current context window. On the other hand, our neural memory, with its ability to continuously learn from data and store it in its weights, can play the role of a long-term memory. In this section, we aim to answer the above question by proposing three different variants of Titans. Later, in our experiments, we show that each of these variants has its own advantages/disadvantages and can also exhibit a trade-off between efficiency and effectiveness for very long contexts.\r\n\r\n### 4.1 Memory as a Context\r\n\r\nIn the first architecture design (see Figure 2), we treat the memory as context for the current information. That is, given a long sequence $x \\in \\mathbb{R}^{N \\times d_{\\text {in }}}$, we first chunk the sequence into fixed-size segments $S^{(i)}$ for $i=1, \\ldots, N / C$. Given the incoming segment $S^{(t)}$, we consider it as the current context and its past segments as the historical information. Therefore, letting $\\mathcal{M}_{t-1}$ be the state of the long-term memory before segment $S^{(t)}$, we use the input context as a query to the memory $\\mathcal{M}_{t-1}$ to retrieve the corresponding information from the long-term memory. 
That is, we retrieve the past information that corresponds to $S^{(t)}$ as:\r\n\r\n$$\r\nh_{t}=\\mathcal{M}_{t-1}^{*}\\left(\\mathbf{q}_{t}\\right)\r\n$$\r\n\r\nwhere $\\mathbf{q}_{t}=S^{(t)} W_{Q}$. Next, we use this historical information along with our persistent memory parameters as the input sequence to the attention module:\r\n\r\n$$\r\n\\begin{aligned}\r\n& \\tilde{S}^{(t)}=\\left[\\begin{array}{llll}\r\np_{1} & p_{2} & \\ldots & p_{N_{p}}\r\n\\end{array}\\right]\\left\\|h_{t}\\right\\| S^{(t)} \\\\\r\n& y_{t}=\\operatorname{Attn}\\left(\\tilde{S}^{(t)}\\right)\r\n\\end{aligned}\r\n$$\r\n\r\nThe structure of the attention map over the entire sequence is shown in Figure 3a. We then use $y_{t}$ to update the long-term memory module for the next segment and the final output:\r\n\r\n$$\r\n\\begin{aligned}\r\n& \\mathcal{M}_{t}=\\mathcal{M}_{t-1}\\left(y_{t}\\right) \\\\\r\n& o_{t}=y_{t} \\otimes \\mathcal{M}_{t}^{*}\\left(y_{t}\\right)\r\n\\end{aligned}\r\n$$\r\n\r\nNote that, in the above, we are updating the weights of $\\mathcal{M}_{t-1}$ through the forward pass.\r\nThis architecture has three key advantages: (1) Attention, by having both the historical and the current context, can decide whether, given the current data, the long-term memory information is needed. (2) The attention module helps the long-term memory store only useful information from the current context. That is, not all tokens in each segment are useful, and memorizing all of them can result in memory overflow. Therefore, attention helps the memory understand which information is useful, better managing the memory capacity. (3) At test time: (i) the persistent memory parameters are fixed, as they encode the knowledge about the task, which should not be changed; (ii) the attention module weights are in-context learners; and (iii) the long-term memory module is still learning (memorizing) information at test time. That is, we update the weights of the neural memory even at test time, as the weights encode the abstraction of the long past.\r\n![img-4.jpeg](img-4.jpeg)\r\n\r\nFigure 4: Memory as a Gate (MAG) Architecture. This architecture, similarly, has the three branches of (1) core, (2) contextual memory, and (3) persistent memory. It, however, incorporates only the persistent memory into the context and combines the memory with the core branch using a gating mechanism. At test time, the behavior is the same as in Figure 2.\r\n\r\n# 4.2 Gated Memory\r\n\r\nIn the next variant (see Figure 4), in one branch we directly use the input data to update the long-term memory, and in the second branch we use sliding window attention (SWA):\r\n\r\n$$\r\n\\begin{aligned}\r\n& \\tilde{x}=\\left[\\begin{array}{llll}\r\np_{1} & p_{2} & \\ldots & p_{N_{p}}\r\n\\end{array}\\right] \\| x \\\\\r\n& y=\\text { SW-Attn }^{*}(\\tilde{x}) \\\\\r\n& o=y \\otimes \\mathcal{M}(\\tilde{x})\r\n\\end{aligned}\r\n$$\r\n\r\nwhere SW-Attn* is sliding window attention with prefix (see Figure 3b). Note that, contrary to the previous design, we are not segmenting the input data. Also, we abuse the notation and use $\\mathcal{M}(\\tilde{x})$ to refer to the final output of the memory after all recursion over the tokens of the sequence. In the above equation, $\\otimes$ can be any non-linear gating. In our experiments, we normalize the outputs $y$ and $\\mathcal{M}(\\tilde{x})$ using learnable vector-valued weights, followed by a non-linearity $\\sigma(\\cdot)$; a minimal sketch of this gated combination is given below.\r\n\r\nThe overall attention mask of this design is shown in Figure 3b. In this design, the sliding window attention acts as a precise short-term memory, while the neural memory module acts as a fading memory for the model. This architecture design can also be seen as a multi-head architecture where the structures of the heads differ (X. Dong et al. 2024).
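The following is a minimal sketch (ours, not the paper's code) of this gated combination; the use of LayerNorm for the learnable vector-valued normalization and sigmoid as the non-linearity are illustrative stand-ins:

```python
# Minimal sketch (ours): the MAG-style gate o = y (*) M(x~). Both branch
# outputs are normalized; one is passed through a non-linearity sigma(.) and
# used to gate the other element-wise. LayerNorm and sigmoid are stand-ins.
import torch
from torch import nn

class GatedCombine(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm_attn = nn.LayerNorm(dim)   # learnable vector-valued weights
        self.norm_mem = nn.LayerNorm(dim)

    def forward(self, y_attn: torch.Tensor, y_mem: torch.Tensor) -> torch.Tensor:
        """y_attn: sliding-window attention branch; y_mem: neural memory branch."""
        gate = torch.sigmoid(self.norm_mem(y_mem))
        return self.norm_attn(y_attn) * gate

combine = GatedCombine(dim=16)
o = combine(torch.randn(2, 10, 16), torch.randn(2, 10, 16))   # (2, 10, 16)
```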
### 4.3 Memory as a Layer\r\n\r\nThe last variant uses the neural Memory As a Layer (MAL) of a deep neural network (see Figure 5). This architecture design is more common in the literature, where hybrid models stack recurrent models with full or sliding window attention. Given input $x$, we have:\r\n\r\n$$\r\n\\begin{aligned}\r\n& \\tilde{x}=\\left[\\begin{array}{llll}\r\np_{1} & p_{2} & \\ldots & p_{N_{p}}\r\n\\end{array}\\right] \\| x \\\\\r\n& y=\\mathcal{M}(\\tilde{x}) \\\\\r\n& o=\\text { SW-Attn }(y)\r\n\\end{aligned}\r\n$$\r\n![img-5.jpeg](img-5.jpeg)\r\n\r\nFigure 5: Memory as a Layer (MAL) Architecture. In this architecture, the memory layer is responsible for compressing the past and current context before the attention module.\r\nwhere SW-Attn is sliding window attention. The main drawback of this design is that the power of the model is limited by each of its layers, and so it cannot take advantage of the complementary data processing of the attention and the neural memory module. In our experiments, to evaluate memory in this design, we use an architecture similar to H3 (D. Y. Fu et al. 2023), where we replace the sequence model with our neural memory module (LMM).\r\n\r\nMemory Without Attention. Although above we discussed MAL as the combination of LMMs and attention in a sequential manner, one simple variant of MAL is to treat the LMM as a sequence model without any attention. From the memory perspective, as discussed in Section 1, we expect each part of the memory system to work independently, even if other components are disturbed. Therefore, a long-term memory module should still be a powerful model even without short-term memory (i.e., attention). We refer to this variant as LMM or Titans (LMM) in our experiments. We provide additional discussion on the connection of Titans to other modern recurrent models in Appendix C.\r\n\r\n# 4.4 Architectural Details\r\n\r\nFor the sake of simplicity and presentation, we avoid discussing implementation details like residual connections, gating with a linear layer, and normalization. In all blocks, we use residual connections. In our implementation, we use the SiLU(.) activation (Elfwing, Uchibe, and Doya 2018) as the non-linear activation for computing queries, keys, and values, and we normalize queries and keys using the $\\ell_{2}$-norm.\r\n\r\nConvolution. Following recent modern linear recurrent models (Gu and Dao 2024; S. Yang, Kautz, and Hatamizadeh 2024), we incorporate a 1D depthwise-separable convolution layer after each of the query, key, and value projections. While they do not significantly affect performance, these 1D convolutions have shown performance improvements and are also computationally efficient.\r\n\r\nGating. We also follow the recent architectures that use normalization and gating with a linear layer before the final output projection (Mehta et al. 2023).
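A minimal sketch (ours) of the projection details just described; the kernel size is an assumption, and the pointwise half of a depthwise-separable convolution is treated as folded into the linear projection:

```python
# Minimal sketch (ours): a query/key/value projection with SiLU activation, a
# causal 1D depthwise convolution after the projection, and optional
# l2-normalization (used for queries and keys). Kernel size is illustrative.
import torch
import torch.nn.functional as F
from torch import nn

class QKVProjection(nn.Module):
    def __init__(self, dim: int, kernel_size: int = 4):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)
        self.conv = nn.Conv1d(dim, dim, kernel_size, groups=dim,  # depthwise
                              padding=kernel_size - 1)

    def forward(self, x: torch.Tensor, l2_normalize: bool) -> torch.Tensor:
        """x: (batch, seq_len, dim)."""
        h = F.silu(self.proj(x)).transpose(1, 2)              # (batch, dim, seq)
        h = self.conv(h)[..., : x.shape[1]].transpose(1, 2)   # trim -> causal
        return F.normalize(h, dim=-1) if l2_normalize else h

q_proj = QKVProjection(dim=16)
q = q_proj(torch.randn(2, 10, 16), l2_normalize=True)         # (2, 10, 16)
```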
Theorem 4.1. Contrary to Transformers, diagonal linear recurrent models, and DeltaNet, all of which are limited to $\\mathrm{TC}^{0}$ (Merrill, Petty, and Sabharwal 2024), Titans are capable of solving problems beyond $\\mathrm{TC}^{0}$, meaning that Titans are theoretically more expressive than Transformers and most modern linear recurrent models in state tracking tasks.\r\n\r\n## 5 Experiments\r\n\r\nNext, we evaluate the performance of Titans and its variants on language modeling, commonsense reasoning, needle in haystack, DNA modeling, and time series forecasting tasks ${ }^{1}$. In more detail, in this section we answer the following empirical questions: (1) How do Titans perform compared to baselines in downstream tasks? (see §5.2, $\\S 5.6$, and $\\S 5.7$ ); (2) What is the actual context length of Titans? (see $\\S 5.3$ and $\\S 5.4$ ); (3) How do Titans scale with respect to context length? (see §5.8); (4) How does the depth of memory affect both performance and efficiency? (see §5.5); and (5) What is the contribution of each of Titans' components to its performance? (see §5.9).\r\n\r\n[^0]:    ${ }^{1}$ In the first version of this work, we aim to provide insights/evidence about why the learning paradigms of Titans are effective. We are working on finalizing the results of larger models and will report them in the next version.\r\n\r\n# 5.1 Experimental Setup\r\n\r\nModels. In our experiments, we focus on the three variants of Titans, which we refer to as Titans with (1) Memory as a Context (MAC), (2) Memory as a Gate (MAG), and (3) Memory as a Layer (MAL), as well as (4) the neural memory module alone. The reason for using our long-term memory as a separate module is based on our definition of learning. As discussed in Section 1, we define learning as a process for acquiring effective and useful memory. Accordingly, we expect our long-term memory to learn effectively from data, even without attention. For each of these models, we consider four scales with (i) 170M, (ii) 340M, (iii) 400M, and (iv) 760M parameters. While the first three are trained on 15B tokens sampled from the FineWeb-Edu dataset (Penedo et al. 2024), the last one is trained on 30B tokens from the same dataset.\r\n\r\nBaselines. We compare our models with the state-of-the-art linear recurrent models, Transformers, and hybrid models (recurrent + attention). More specifically, in language tasks, we compare with Transformer++ (Touvron et al. 2023), RetNet (Yutao Sun et al. 2023), Gated Linear Attention (GLA) (S. Yang, B. Wang, Shen, et al. 2024), Mamba (Gu and Dao 2024), Mamba2 (Dao and Gu 2024), DeltaNet (S. Yang, B. Wang, Yu Zhang, et al. 2024), TTT (Yu Sun et al. 2024), and Gated DeltaNet (S. Yang, Kautz, and Hatamizadeh 2024). In needle in haystack tasks, we also compare with GPT-4 (Achiam et al. 2023), Llama3 with RAG (Touvron et al. 2023), RecurrentGemma2-9B (Botev et al. 2024), and Mistral (Jiang et al. 2023) models, all of which are provided in the benchmark (Yuri Kuratov et al. 2024). In time series tasks, we compare with Mamba-based (Behrouz, Santacatterina, and Zabih 2024), Transformer-based (Y. Liu et al. 2023; Nie et al. 2022; Yunhao Zhang and Yan 2023), and linear models (Das et al. 2023; Z. Li et al. 2023; H. Wu et al. 2023; Zeng et al. 2023).\r\n\r\nTraining. We follow the training procedure of S. Yang, Kautz, and Hatamizadeh (2024): we use the Llama 2 tokenizer with a vocabulary size of 32K and a training length of 4K tokens. We employ the AdamW optimizer with a learning rate of $4e-4$, a cosine annealing schedule, a batch size of 0.5M tokens, and a weight decay of 0.1.
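Expressed as a PyTorch setup, the stated hyperparameters would look roughly as follows (a sketch under stated assumptions: `model` and `total_steps` are placeholders, and the remaining AdamW settings are not specified in the paper):

```python
# Minimal sketch (ours): the training hyperparameters above as an
# optimizer/scheduler setup. `model` and `total_steps` are placeholders.
import torch
from torch import nn

model = nn.Linear(8, 8)        # placeholder for a Titans model
total_steps = 1000             # placeholder; depends on total tokens / batch size

optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4, weight_decay=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
```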
### 5.2 Language Modeling\r\n\r\nWe first focus on perplexity in language modeling and also on commonsense reasoning tasks. The results for Titans' variants and baselines at three different sizes (340M, 400M, and 760M) are reported in Table 1. Among non-hybrid models, including Transformer++, our neural memory module achieves the best performance on both perplexity and accuracy measures. Comparing our neural memory module with TTT, which is also a gradient-based recurrent model, shows the importance of our weight decay as well as our momentum. As discussed earlier, the weight decay can be interpreted as a gating mechanism that forgets the past data when needed. Also, momentum can help us better manage the memory by providing additional memory for the surprise metric. While some baselines also take advantage of a gating mechanism, e.g., Mamba, Mamba2, and Gated DeltaNet, the superior performance of our neural memory module shows the importance of both our surprise mechanism and having a deep, non-linear memory. We discuss the latter further in Section 5.5.\r\n\r\nComparing the hybrid models, we found that all three variants of Titans (MAC, MAG, and MAL) outperform both Samba (Mamba + attention) and Gated DeltaNet-H2 (Gated DeltaNet + attention). We attribute the superior performance of Titans (MAL) to the power of the neural memory module, as the architecture design and the attention used are otherwise the same. Comparing Titans (MAG) and (MAC), we find that while their performance is close, MAC performs better when dealing with longer dependencies in the data. Interestingly, both MAG and MAC outperform the MAL variant; since they use the same modules, we attribute this to the architectural design of these models. This finding is particularly important, as the current hybrid models in the literature (except Hymba (X. Dong et al. 2024)) use a MAL-style combination of recurrent models and attention.\r\n\r\n### 5.3 Needle in a Haystack\r\n\r\nScaling a model to a longer context window is not always equivalent to being effective for very long sequences (Hsieh et al. 2024). The needle-in-a-haystack (NIAH) task is designed to measure the actual effective context length of models. In this task, we evaluate the model on retrieving a piece of information (i.e., the \"needle\") from long distractor texts (i.e.,\r\nTable 1: Performance of Titans and recurrent- and Transformer-based baselines on language modeling and common-sense reasoning tasks. Hybrid models are marked with *. The best results among simple and hybrid models are highlighted.\r\n\r\n| Model | Wiki. ppl $\\downarrow$ | LMB. ppl $\\downarrow$ | LMB. acc $\\uparrow$ | PIQA acc $\\uparrow$ | Hella. acc_n $\\uparrow$ | Wino. acc $\\uparrow$ | ARC-e acc $\\uparrow$ | ARC-c acc_n $\\uparrow$ | SIQA acc $\\uparrow$ | BoolQ acc $\\uparrow$ | Avg. 
$\\uparrow$ |\r\n| :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: |\r\n| 340M params / 15B tokens |  |  |  |  |  |  |  |  |  |  |  |\r\n| Transformer++ | 31.52 | 41.08 | 30.76 | 62.98 | 34.76 | 50.53 | 45.21 | 24.05 | 36.81 | 58.24 | 42.92 |\r\n| RetNet | 32.50 | 49.73 | 28.24 | 62.61 | 34.15 | 50.91 | 44.27 | 23.62 | 36.79 | 59.72 | 42.54 |\r\n| GLA | 28.51 | 43.02 | 28.73 | 64.05 | 35.96 | 50.00 | 54.19 | 24.29 | 37.13 | 58.39 | 44.09 |\r\n| Mamba | 30.83 | 40.21 | 29.94 | 63.79 | 35.88 | 49.82 | 49.24 | 24.56 | 35.41 | 60.07 | 43.59 |\r\n| DeltaNet | 28.65 | 47.30 | 28.43 | 63.52 | 35.95 | 49.63 | 52.68 | 25.37 | 37.96 | 58.79 | 44.04 |\r\n| TTT | 27.44 | 34.19 | 30.06 | 63.97 | 35.71 | 50.08 | 53.01 | 26.11 | 37.32 | 59.83 | 44.51 |\r\n| Gated DeltaNet | 27.01 | 30.94 | 34.11 | 63.08 | 38.12 | 51.60 | 55.28 | 26.77 | 34.89 | 59.54 | 45.42 |\r\n| Titans (LMM) | 26.18 | 29.97 | 34.98 | 64.73 | 39.61 | 51.85 | 55.60 | 28.14 | 34.52 | 59.99 | 46.17 |\r\n| Titans (MAC)* | 25.43 | 28.13 | 36.00 | 65.32 | 40.35 | 51.21 | 58.17 | 29.00 | 38.63 | 60.18 | 47.36 |\r\n| Titans (MAG)* | 25.07 | 28.72 | 36.71 | 64.88 | 40.56 | 52.49 | 57.72 | 28.16 | 39.75 | 60.01 | 47.54 |\r\n| Titans (MAL)* | 24.69 | 28.80 | 35.74 | 64.97 | 39.44 | 51.97 | 56.58 | 28.21 | 38.14 | 57.32 | 46.55 |\r\n| 400M params / 15B tokens |  |  |  |  |  |  |  |  |  |  |  |\r\n| Transformer++ | 30.63 | 37.37 | 29.64 | 64.27 | 37.72 | 51.53 | 54.95 | 27.36 | 38.07 | 61.59 | 45.64 |\r\n| RetNet | 29.92 | 46.83 | 29.16 | 65.23 | 36.97 | 51.85 | 56.01 | 27.55 | 37.30 | 59.66 | 45.47 |\r\n| HGRN2 | 32.33 | 47.14 | 26.12 | 64.52 | 35.45 | 52.24 | 55.97 | 25.51 | 37.35 | 59.02 | 44.52 |\r\n| GLA | 27.96 | 36.66 | 27.86 | 65.94 | 37.41 | 49.56 | 56.01 | 26.36 | 38.94 | 59.84 | 45.24 |\r\n| Mamba | 29.22 | 39.88 | 29.82 | 65.72 | 37.93 | 50.11 | 58.37 | 26.70 | 37.76 | 61.13 | 45.94 |\r\n| Mamba2 | 26.34 | 33.19 | 32.03 | 65.77 | 39.73 | 52.48 | 59.00 | 27.64 | 37.92 | 60.72 | 46.91 |\r\n| DeltaNet | 27.69 | 44.04 | 29.96 | 64.52 | 37.03 | 50.82 | 56.77 | 27.13 | 38.22 | 60.09 | 45.57 |\r\n| TTT | 26.11 | 31.52 | 33.25 | 65.70 | 39.11 | 51.68 | 58.04 | 28.99 | 38.26 | 59.87 | 46.86 |\r\n| Gated DeltaNet | 25.47 | 29.24 | 34.40 | 65.94 | 40.46 | 51.46 | 59.80 | 28.58 | 37.43 | 60.03 | 47.26 |\r\n| Samba* | 25.32 | 29.47 | 36.86 | 66.09 | 39.24 | 51.45 | 60.12 | 27.20 | 38.68 | 58.22 | 47.23 |\r\n| Gated DeltaNet-H2* | 24.19 | 28.09 | 36.77 | 66.43 | 40.79 | 52.17 | 59.55 | 29.09 | 39.04 | 58.56 | 47.69 |\r\n| Titans (LMM) | 25.03 | 28.99 | 35.21 | 65.85 | 40.91 | 52.19 | 59.97 | 29.20 | 38.74 | 60.85 | 47.83 |\r\n| Titans (MAC)* | 25.61 | 27.73 | 36.92 | 66.39 | 41.18 | 52.80 | 60.24 | 29.69 | 40.07 | 61.93 | 48.65 |\r\n| Titans (MAG)* | 23.59 | 27.81 | 37.24 | 66.80 | 40.92 | 53.21 | 60.01 | 29.45 | 39.91 | 61.28 | 48.60 |\r\n| Titans (MAL)* | 23.93 | 27.89 | 36.84 | 66.29 | 40.74 | 52.26 | 59.85 | 29.71 | 38.92 | 58.40 | 47.87 |\r\n| 760M params / 30B tokens |  |  |  |  |  |  |  |  |  |  |  |\r\n| Transformer++ | 25.21 | 27.64 | 35.78 | 66.92 | 42.19 | 51.95 | 60.38 | 32.46 | 39.51 | 60.37 | 48.69 |\r\n| RetNet | 26.08 | 24.45 | 34.51 | 67.19 | 41.63 | 52.09 | 63.17 | 32.78 | 38.36 | 57.92 | 48.46 |\r\n| Mamba | 28.12 | 23.96 | 32.80 | 66.04 | 39.15 | 52.38 | 61.49 | 30.34 | 37.96 | 57.62 | 47.22 |\r\n| Mamba2 | 22.94 | 28.37 | 33.54 | 67.90 | 42.71 | 49.77 | 63.48 | 31.09 | 40.06 | 58.15 | 48.34 |\r\n| DeltaNet | 24.37 | 24.60 | 37.06 | 66.93 | 41.98 | 50.65 | 64.87 | 31.39 | 39.88 | 59.02 | 48.97 
|\r\n| TTT | 24.17 | 23.51 | 34.74 | 67.25 | 43.92 | 50.99 | 64.53 | 33.81 | 40.16 | 59.58 | 47.32 |\r\n| Gated DeltaNet | 21.18 | 22.09 | 35.54 | 68.01 | 44.95 | 50.73 | 66.87 | 33.09 | 39.21 | 59.14 | 49.69 |\r\n| Samba* | 20.63 | 22.71 | 39.72 | 69.19 | 47.35 | 52.01 | 66.92 | 33.20 | 38.98 | 61.24 | 51.08 |\r\n| Gated DeltaNet-H2* | 19.88 | 20.83 | 39.18 | 68.95 | 48.22 | 52.57 | 67.01 | 35.49 | 39.39 | 61.11 | 51.49 |\r\n| Titans (LMM) | 20.04 | 21.96 | 37.40 | 69.28 | 48.46 | 52.27 | 66.31 | 35.84 | 40.13 | 62.76 | 51.56 |\r\n| Titans (MAC) | 19.93 | 20.12 | 39.62 | 70.46 | 49.01 | 53.18 | 67.86 | 36.01 | 41.87 | 62.05 | 52.51 |\r\n| Titans (MAG) | 18.61 | 19.86 | 40.98 | 70.25 | 48.94 | 52.89 | 68.23 | 36.19 | 40.38 | 62.11 | 52.50 |\r\n| Titans (MAL) | 19.07 | 20.33 | 40.05 | 69.99 | 48.82 | 53.02 | 67.54 | 35.65 | 30.98 | 61.72 | 50.97 |\r\n\r\nthe \"haystack\"). In this part, we use Single NIAH (S-NIAH) task from RULER benchmark (Hsieh et al. 2024) and evaluate Titans and baselines on sequences with length $2 \\mathrm{~K}, 4 \\mathrm{~K}, 8 \\mathrm{~K}$, and 16 K . The results are reported in Table 2. Neural Memory module achieves the best results compare to baselines in all three tasks. We attribute this superior performance to three key differences of Titans with existing sequence models: (1) Compared to TTT, our Neural Memory can better handle the memory capacity by using momentum and also the forgetting mechanism (i.e., weight decay). Therefore, with increasing the sequence length, the performance of Neural Memory does not drop and show a consistent trend; (2) Compared to Mamba2, which has the gating (forgetting) mechanism, Titans have deep non-linear memory, resulting in better memory management. Also, contrary to our neural memory and DeltaNet, Mamba2 is not capable of removing a memory and so\r\nPAGE 14\r\nTable 2: Performance of Titans and baselines on S-NIAH task from RULER benchmark. The best results among simple and hybrid models are highlighted.\r\n\r\n| Model | S-NIAH-PK |  |  |  | S-NIAH-N |  |  |  | S-NIAH-W |  |  |  |\r\n| :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: |\r\n|  | 2 K | 4 K | 8 K | 16 K | 2 K | 4 K | 8 K | 16 K | 2 K | 4 K | 8 K | 16 K |\r\n| TTT | 98.4 | 98.8 | 98.0 | 88.4 | 60.2 | 36.6 | 10.2 | 4.4 | 78.8 | 28.0 | 4.4 | 0.0 |\r\n| Mamba2 | 98.6 | 61.4 | 31.0 | 5.4 | 98.4 | 55.8 | 14.2 | 0.0 | 42.2 | 4.2 | 0.0 | 0.0 |\r\n| DeltaNet | 96.8 | 98.8 | 98.6 | 71.4 | 47.2 | 15.4 | 12.8 | 5.4 | 46.2 | 20.0 | 1.6 | 0.0 |\r\n| Titans (LMM) | 99.8 | 98.4 | 98.2 | 96.2 | 100.0 | 99.8 | 93.4 | 80.2 | 90.4 | 89.4 | 85.8 | 80.6 |\r\n| Titans (MAC) | 99.2 | 98.8 | 99.0 | 98.4 | 99.6 | 98.2 | 97.6 | 97.4 | 98.2 | 98.2 | 95.6 | 95.2 |\r\n| Titans (MAG) | 99.4 | 98.0 | 97.4 | 97.4 | 99.2 | 98.8 | 97.2 | 98.6 | 98.0 | 98.0 | 90.2 | 88.2 |\r\n| Titans (MAL) | 98.8 | 98.6 | 98.8 | 97.8 | 99.8 | 98.1 | 96.8 | 96.4 | 98.0 | 97.4 | 92.0 | 90.4 |\r\n\r\n![img-6.jpeg](img-6.jpeg)\r\n\r\nFigure 6: Performance of Titans and baselines on BABILong benchmark. Titans (MAC) outperforms all baselines, including extremely large models, e.g., GPT4.\r\nwe can see a significant drop in performance when increasing the sequence length; (3) Compared to DeltaNet, although it is capable of removing memory using delta rule, it cannot erase the memory, lacking forgetting mechanism. 
Finally, as expected, we see on-par or better results when using the Titans variants, where the best results correspond to MAC.\r\n\r\nTable 2: Performance of Titans and baselines on the S-NIAH task from the RULER benchmark. The best results among simple and hybrid models are highlighted.\r\n\r\n| Model | S-NIAH-PK |  |  |  | S-NIAH-N |  |  |  | S-NIAH-W |  |  |  |\r\n| :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: |\r\n|  | 2K | 4K | 8K | 16K | 2K | 4K | 8K | 16K | 2K | 4K | 8K | 16K |\r\n| TTT | 98.4 | 98.8 | 98.0 | 88.4 | 60.2 | 36.6 | 10.2 | 4.4 | 78.8 | 28.0 | 4.4 | 0.0 |\r\n| Mamba2 | 98.6 | 61.4 | 31.0 | 5.4 | 98.4 | 55.8 | 14.2 | 0.0 | 42.2 | 4.2 | 0.0 | 0.0 |\r\n| DeltaNet | 96.8 | 98.8 | 98.6 | 71.4 | 47.2 | 15.4 | 12.8 | 5.4 | 46.2 | 20.0 | 1.6 | 0.0 |\r\n| Titans (LMM) | 99.8 | 98.4 | 98.2 | 96.2 | 100.0 | 99.8 | 93.4 | 80.2 | 90.4 | 89.4 | 85.8 | 80.6 |\r\n| Titans (MAC) | 99.2 | 98.8 | 99.0 | 98.4 | 99.6 | 98.2 | 97.6 | 97.4 | 98.2 | 98.2 | 95.6 | 95.2 |\r\n| Titans (MAG) | 99.4 | 98.0 | 97.4 | 97.4 | 99.2 | 98.8 | 97.2 | 98.6 | 98.0 | 98.0 | 90.2 | 88.2 |\r\n| Titans (MAL) | 98.8 | 98.6 | 98.8 | 97.8 | 99.8 | 98.1 | 96.8 | 96.4 | 98.0 | 97.4 | 92.0 | 90.4 |\r\n\r\n### 5.4 BABILong Benchmark\r\n\r\nIn the previous section we discussed the results on simple NIAH tasks, where a single needle needs to be retrieved. Although Titans showed better performance compared to baselines, their true advantage over very long sequences is still hidden there. To this end, in this section we use a harder task from the BABILong benchmark (Yuri Kuratov et al. 2024), in which the model needs to reason across facts distributed in extremely long documents. We follow the original experimental setup and training process of the benchmark. There are two settings: (1) the few-shot setting, in which we use large pre-trained models, and (2) the fine-tuning setting, where we fine-tune the MAC variant of Titans to compare it with other fine-tuned baselines. The results for the few-shot setting are reported in Figure 6a. In this setup, Titans outperform all baselines, i.e., Mamba2.8B (Gu and Dao 2024), RWKV-6-7B (Peng, Goldstein, et al. 2024), RecurrentGemma-9B (Botev et al. 2024), Gemma-9B (Team et al. 2024), Llama3.1-8B (Touvron et al. 2023), GPT-4, and GPT-4o-mini (Achiam et al. 2023). These results are achieved while Titans (MAC) has far fewer parameters than the baselines.\r\n\r\n![img-6.jpeg](img-6.jpeg)\r\n\r\nFigure 6: Performance of Titans and baselines on the BABILong benchmark. Titans (MAC) outperforms all baselines, including extremely large models, e.g., GPT-4.\r\n\r\nIn the fine-tuning setup, we compare the small fine-tuned version of Titans (MAC) with: (i) fine-tuned versions of small models (with almost the same number of parameters as Titans), such as Mamba (Gu and Dao 2024) and RMT (Bulatov, Yury Kuratov, and Burtsev 2022); (ii) large models with Retrieval-Augmented Generation (RAG) (P. Lewis et al. 2020), such as Llama3.1-8B (Touvron et al. 2023); and (iii) extremely large models, such as GPT-4 (Achiam et al. 2023), GPT-4o-mini, Qwen2.5-72B (A. Yang et al. 2024), and Llama3.1-70B (Touvron et al. 2023). Baseline results are those reported by Yuri Kuratov et al. (2024). The results of Titans and baselines are reported in Figure 6b. Titans outperform all models, even extremely large ones like GPT-4. Also, compared to memory-augmented Transformer-based models like RMT, Titans show better performance, mainly due to their more powerful memory: RMT compresses the historical data into a size-16 vector-valued memory, whereas Titans, with their in-context online memory learner, can encode the past into the parameters of the model. Interestingly, even augmenting the Llama3.1-8B model with RAG performs worse than Titans, which have about 70× fewer parameters.
### 5.5 The Effect of Deep Memory\r\n\r\nIn this section, we evaluate the effect of deep memory on both wall-clock training time and model performance ${ }^{2}$. To this end, we focus on different variants of our neural memory module, where $L_{M}=1,2,3,4$. We also use Mamba as a baseline for model performance. For a fair comparison, we use the same training process for all models and train them on a subset of the Pile dataset (L. Gao et al. 2020).\r\n\r\nWe report the perplexity of our models and baselines as a function of the sequence length in Figure 7. Interestingly, increasing the memory depth $L_{M}$ lets the model achieve better perplexity across all sequence lengths. Also, deeper memory modules are more robust to the sequence length when the model has fewer parameters. As the number of parameters grows, all models show better performance on longer sequences.\r\n\r\n![img-7.jpeg](img-7.jpeg)\r\n\r\nFigure 7: The effect of memory depth on the perplexity. Deeper long-term memory results in better scaling in longer sequences.\r\n\r\nWe also evaluate the effect of memory depth $\\left(L_{M}=1,2,3,4\\right)$ on the training throughput. We report the training throughput (the number of tokens per second) as a function of sequence length in Figure 8. All models scale linearly with respect to the context length (i.e., the number of tokens per second stays constant as the sequence length grows). Also, as expected, increasing the memory depth slows training. It is therefore not always efficient to use deeper memory modules, revealing a trade-off between effectiveness and efficiency.\r\n\r\n![img-8.jpeg](img-8.jpeg)\r\n\r\nFigure 8: The effect of memory depth on training throughput.\r\n\r\n[^0]: ${ }^{2}$ Note that, in this experiment, we only focus on the neural memory module to evaluate the effect of memory depth on the memorization process. Combining neural memory with attention, as we do in the Titans variants, can additionally enhance the performance of the model over long sequences.
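For intuition on what the depth knob $L_{M}$ controls, the sketch below shows a deep memory as an $L_{M}$-layer MLP whose weights serve as the memory; this is our own illustration (the class name, sizing, and activation choice are assumptions, not the paper's released code):\r\n\r\n```python\r\nimport torch.nn as nn\r\n\r\n\r\nclass DeepMemory(nn.Module):\r\n    # Memory-as-MLP: the parameters of this module are the long-term memory\r\n    # that the test-time update rule (see Appendix C) writes into.\r\n    # depth = 1 degenerates to the linear (matrix-valued) memory.\r\n    def __init__(self, dim: int, depth: int):\r\n        super().__init__()\r\n        layers: list[nn.Module] = []\r\n        for i in range(depth):\r\n            layers.append(nn.Linear(dim, dim, bias=False))\r\n            if i < depth - 1:  # no activation after the last layer\r\n                layers.append(nn.SiLU())\r\n        self.net = nn.Sequential(*layers)\r\n\r\n    def forward(self, k):\r\n        # Retrieval: query the memory with a key to read out a value.\r\n        return self.net(k)\r\n```\r\n\r\nDeeper memories add parameters and sequential work per update, which matches the throughput trade-off shown in Figure 8.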
### 5.6 Time Series Forecasting\r\n\r\nTo show the effectiveness of our memory module on a broader set of tasks, we also evaluate its performance on time series forecasting. To this end, we use the Simba framework (Patro and Agneeswaran 2024) for time series forecasting and replace its Mamba module with our neural memory. We report the results on common time series forecasting benchmark datasets: ETT, ECL, Traffic, and Weather (H. Zhou et al. 2021). The results are reported in Table 3. Our neural memory module outperforms all baselines, including Mamba-based, linear-based, and Transformer-based architectures.\r\n\r\nTable 3: Performance on long-term forecasting. The best results are highlighted.\r\n\r\n|  | Neural Memory |  | Simba |  | iTransformer |  | RLinear |  | PatchTST |  | Crossformer |  | TiDE |  | TimesNet |  | DLinear |  |\r\n| :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: | :--: |\r\n|  | MSE | MAE | MSE | MAE | MSE | MAE | MSE | MAE | MSE | MAE | MSE | MAE | MSE | MAE | MSE | MAE | MSE | MAE |\r\n| ETTm1 | 0.358 | 0.387 | 0.383 | 0.396 | 0.407 | 0.410 | 0.414 | 0.407 | 0.387 | 0.400 | 0.513 | 0.496 | 0.419 | 0.419 | 0.400 | 0.406 | 0.403 | 0.407 |\r\n| ETTm2 | 0.261 | 0.309 | 0.271 | 0.327 | 0.288 | 0.332 | 0.286 | 0.327 | 0.281 | 0.326 | 0.757 | 0.610 | 0.358 | 0.404 | 0.291 | 0.333 | 0.350 | 0.401 |\r\n| ETTh1 | 0.420 | 0.421 | 0.441 | 0.432 | 0.454 | 0.447 | 0.446 | 0.434 | 0.469 | 0.454 | 0.529 | 0.522 | 0.541 | 0.507 | 0.458 | 0.450 | 0.456 | 0.452 |\r\n| ETTh2 | 0.356 | 0.382 | 0.361 | 0.391 | 0.383 | 0.407 | 0.374 | 0.398 | 0.387 | 0.407 | 0.942 | 0.684 | 0.611 | 0.530 | 0.414 | 0.427 | 0.559 | 0.515 |\r\n| ECL | 0.162 | 0.261 | 0.169 | 0.274 | 0.178 | 0.270 | 0.219 | 0.298 | 0.205 | 0.290 | 0.244 | 0.334 | 0.251 | 0.344 | 0.192 | 0.295 | 0.212 | 0.300 |\r\n| Traffic | 0.415 | 0.289 | 0.493 | 0.291 | 0.428 | 0.282 | 0.626 | 0.378 | 0.481 | 0.304 | 0.550 | 0.304 | 0.760 | 0.473 | 0.620 | 0.336 | 0.625 | 0.383 |\r\n| Weather | 0.231 | 0.265 | 0.255 | 0.280 | 0.258 | 0.278 | 0.272 | 0.291 | 0.259 | 0.281 | 0.259 | 0.315 | 0.271 | 0.320 | 0.259 | 0.287 | 0.265 | 0.317 |\r\n\r\n### 5.7 DNA Modeling\r\n\r\nTo understand the capability of Titans beyond natural language, we further evaluate the performance of our neural memory module on DNA modeling tasks. To this end, we evaluate pre-trained models on the downstream tasks in GenomicsBenchmarks (Grešová et al. 2023). We follow the same experimental setup as Nguyen et al. (2024) and re-use the baseline results reported by Arora et al. (2024). The performance of Titans (LMM) and baselines is reported in Table 4. We find that LMM is competitive with state-of-the-art architectures across different downstream genomics tasks.\r\n\r\nTable 4: Downstream evaluation of pre-trained DNA models on GenomicsBenchmarks (Grešová et al. 2023). We report top-1 classification accuracy ( $\\%$ ).\r\n\r\n| Model | Enhancer Cohn | Enhancer Ens | Human Reg. | Non-TATA Promoters | Human OCR Ens. |\r\n| :-- | :--: | :--: | :--: | :--: | :--: |\r\n| CNN | 69.5 | 68.9 | 93.3 | 84.6 | 68.0 |\r\n| DNABERT | 74.0 | 85.7 | 88.1 | 85.6 | 75.1 |\r\n| GPT | 70.5 | 83.5 | 91.5 | 87.7 | 73.0 |\r\n| HyenaDNA | 74.2 | 89.2 | 93.8 | 96.6 | 80.9 |\r\n| Transformer++ | 73.4 | 89.5 | 89.9 | 94.4 | 79.5 |\r\n| Mamba | 73.0 | - | - | 96.6 | - |\r\n| Based | 74.6 | 89.5 | 89.5 | 96.8 | 79.0 |\r\n| Neural Memory Module | 75.2 | 89.6 | 89.3 | 96.6 | 79.9 |\r\n\r\n### 5.8 Efficiency\r\n\r\nIn this part, we compare the efficiency of our neural memory, as well as Titans, with state-of-the-art sequence models. The training throughput of the models for different sequence length $\\times$ batch size settings is reported in Figure 9. Comparing recurrent models, including our neural memory module, we see that our memory module is slightly slower than Mamba2 and Gated DeltaNet, mainly due to (1) its deep memory and more expressive transition process (memory update), and (2) the highly optimized kernel in the Mamba2 implementation. Interestingly, Titans (MAL) is faster than the baselines as well as the memory module. The main reason for this better throughput is the highly optimized kernel of FlashAttention (Dao 2024), which is used to implement the SWA and full attention modules in Titans.\r\n\r\n![img-9.jpeg](img-9.jpeg)\r\n\r\nFigure 9: Training throughput comparison of Titans and baselines.\r\n\r\n### 5.9 Ablation Study\r\n\r\nFinally, we perform ablation studies on the different architectural choices in Titans. We consider our neural memory module as a base model and then change one component at a time: (1) replacing deep memory with linear memory, and removing (2) convolution, (3) momentum in the surprise measure, (4) weight decay (i.e., the forget mechanism), and (5) persistent memory; a config-style summary of these toggles is sketched below. The results are reported in Table 5.
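The five toggles map naturally onto a small ablation config; the sketch below is our own illustration (the field names are hypothetical), not the paper's code:\r\n\r\n```python\r\nfrom dataclasses import dataclass\r\n\r\n\r\n@dataclass\r\nclass MemoryAblation:\r\n    # One row of Table 5: flip a single field from the defaults.\r\n    deep_memory: bool = True        # False -> linear (matrix-valued) memory\r\n    convolution: bool = True        # depthwise 1D convolution branch\r\n    momentum: bool = True           # past surprise S_t in the update rule\r\n    weight_decay: bool = True       # forget gate alpha_t\r\n    persistent_memory: bool = True  # learnable input-independent tokens\r\n\r\n\r\nABLATIONS = {\r\n    \"Linear Memory\": MemoryAblation(deep_memory=False),\r\n    \"w/o Convolution\": MemoryAblation(convolution=False),\r\n    \"w/o Momentum\": MemoryAblation(momentum=False),\r\n    \"w/o Weight Decay\": MemoryAblation(weight_decay=False),\r\n    \"w/o Persistent Memory\": MemoryAblation(persistent_memory=False),\r\n}\r\n```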
All components of the neural memory design contribute positively to its performance, with the largest contributions coming from weight decay, momentum, convolution, and persistent memory, in that order.\r\n\r\nThe Effect of Architectural Design. To evaluate the effect of the architectural design, we compare the performance of the three representative variants of Titans on three aspects: (i) language modeling, (ii) common-sense reasoning, and (iii) long-context NIAH (BABILong) tasks. The results are reported in Table 5. We find that MAC and MAG have close performance in language modeling and common-sense reasoning tasks, while MAC achieves significantly better performance on long-context NIAH. Both of these models outperform MAL. These results, along with Figure 9, show a trade-off between fast training and a more expressive design.\r\n\r\nTable 5: Ablation study on Titans. All components of Titans contribute positively to its performance.\r\n\r\n| Model | Language Modeling <br> $\\mathrm{ppl} \\downarrow$ | Reasoning <br> $\\mathrm{acc} \\uparrow$ | Long Context <br> $\\mathrm{acc} \\uparrow$ |\r\n| :-- | :--: | :--: | :--: |\r\n| LMM | 27.01 | 47.83 | 92.68 |\r\n| +Attn (MAC) | 26.67 | 48.65 | 97.95 |\r\n| +Attn (MAG) | 25.70 | 48.60 | 96.70 |\r\n| +Attn (MAL) | 25.91 | 47.87 | 96.91 |\r\n| Linear Memory | 28.49 | 46.97 | 85.34 |\r\n| w/o Convolution | 28.73 | 45.82 | 90.28 |\r\n| w/o Momentum | 28.98 | 45.49 | 87.12 |\r\n| w/o Weight Decay | 29.04 | 45.11 | 85.60 |\r\n| w/o Persistent Memory | 27.63 | 46.35 | 92.49 |\r\n\r\n# 6 Conclusion\r\n\r\nIn this paper, we present a neural long-term memory that, as a meta in-context learner, learns to memorize at test time. The neural memory module is recurrent in nature and adaptively memorizes tokens that are more surprising or are close to surprising tokens. Compared to modern recurrent models, it has a more expressive memory update and storage mechanism. Using this memory, we present the Titans architecture and its three variants, in which we propose incorporating the memory module as (1) a context, (2) a gate, and (3) a layer. Our experimental evaluation on diverse tasks validates that Titans are more effective than Transformers and recent modern linear recurrent models, especially for long context. That is, Titans can scale to context windows larger than 2M tokens with better accuracy than baselines.\r\nTitans are implemented in PyTorch and JAX, and we intend to make the code we used to train and evaluate our models available soon.\r\n\r\n# References\r\n\r\n[1] Josh Achiam, Steven Adler, Sandhini Agarwal, Lama Ahmad, Ilge Akkaya, Florencia Leoni Aleman, Diogo Almeida, Janko Altenschmidt, Sam Altman, Shyamal Anadkat, et al. \"GPT-4 technical report\". In: arXiv preprint arXiv:2303.08774 (2023).\r\n[2] Yaroslav Aksenov, Nikita Balagansky, Sofia Maria Lo Cicero Vaina, Boris Shaposhnikov, Alexey Gorbatovski, and Daniil Gavrilov. \"Linear Transformers with Learnable Kernel Functions are Better In-Context Models\". In: arXiv preprint arXiv:2402.10644 (2024).\r\n[3] Marcin Andrychowicz, Misha Denil, Sergio Gomez, Matthew W Hoffman, David Pfau, Tom Schaul, Brendan Shillingford, and Nando De Freitas. \"Learning to learn by gradient descent by gradient descent\". In: Advances in neural information processing systems 29 (2016).\r\n[4] Cem Anil, Yuhuai Wu, Anders Andreassen, Aitor Lewkowycz, Vedant Misra, Vinay Ramasesh, Ambrose Slone, Guy Gur-Ari, Ethan Dyer, and Behnam Neyshabur. 
\"Exploring length generalization in large language models\". In: Advances in Neural Information Processing Systems 35 (2022), pp. 38546-38556.\r\n[5] Simran Arora, Sabri Eyuboglu, Michael Zhang, Aman Timalsina, Silas Alberti, James Zou, Atri Rudra, and Christopher Re. \"Simple linear attention language models balance the recall-throughput tradeoff\". In: Forty-first International Conference on Machine Learning. 2024. URL: https://openreview.net/forum?id=e93ffDcpH3.\r\n[6] Dzmitry Bahdanau. \"Neural machine translation by jointly learning to align and translate\". In: arXiv preprint arXiv:1409.0473 (2014).\r\n[7] Reza Bayat, Mohammad Pezeshki, Elvis Dohmatob, David Lopez-Paz, and Pascal Vincent. \"The Pitfalls of Memorization: When Memorization Hurts Generalization\". In: arXiv preprint arXiv:2412.07684 (2024).\r\n[8] Maximilian Beck, Korbinian Pöppel, Markus Spanring, Andreas Auer, Oleksandra Prudnikova, Michael Kopp, Günter Klambauer, Johannes Brandstetter, and Sepp Hochreiter. \"xLSTM: Extended Long Short-Term Memory\". In: arXiv preprint arXiv:2405.04517 (2024).\r\n[9] Ali Behrouz, Michele Santacatterina, and Ramin Zabih. \"Mambamixer: Efficient selective state space models with dual token and channel selection\". In: arXiv preprint arXiv:2403.19888 (2024).\r\n[10] Vincent-Pierre Berges, Barlas Oğuz, Daniel Haziza, Wen-tau Yih, Luke Zettlemoyer, and Gargi Gosh. \"Memory Layers at Scale\". In: arXiv preprint arXiv:2412.09764 (2024).\r\n[11] Alberto Bietti, Vivien Cabannes, Diane Bouchacourt, Herve Jegou, and Leon Bottou. \"Birth of a transformer: A memory viewpoint\". In: Advances in Neural Information Processing Systems 36 (2024).\r\n[12] Yonatan Bisk, Rowan Zellers, Jianfeng Gao, Yejin Choi, et al. \"Piqa: Reasoning about physical commonsense in natural language\". In: Proceedings of the AAAI conference on artificial intelligence. Vol. 34. 05. 2020, pp. 7432-7439.\r\n[13] Aleksandar Botev, Soham De, Samuel L Smith, Anushan Fernando, George-Cristian Muraru, Ruba Haroun, Leonard Berrada, Razvan Pascanu, Pier Giuseppe Sessa, Robert Dadashi, et al. \"RecurrentGemma: Moving Past Transformers for Efficient Open Language Models\". In: arXiv preprint arXiv:2404.07839 (2024).\r\n[14] Léon Bottou and Vladimir Vapnik. \"Local learning algorithms\". In: Neural computation 4.6 (1992), pp. 888-900.\r\n[15] Aydar Bulatov, Yuri Kuratov, Yermek Kapushev, and Mikhail S Burtsev. \"Scaling transformer to 1m tokens and beyond with rmt\". In: arXiv preprint arXiv:2304.11062 (2023).\r\n[16] Aydar Bulatov, Yury Kuratov, and Mikhail Burtsev. \"Recurrent memory transformer\". In: Advances in Neural Information Processing Systems 35 (2022), pp. 11079-11091.\r\n[17] Edoardo Cetin, Qi Sun, Tianyu Zhao, and Yujin Tang. \"An Evolved Universal Transformer Memory\". In: arXiv preprint arXiv:2410.13166 (2024).\r\n[18] Beidi Chen, Tri Dao, Eric Winsor, Zhao Song, Atri Rudra, and Christopher Ré. \"Scatterbrain: Unifying sparse and low-rank attention\". In: Advances in Neural Information Processing Systems 34 (2021), pp. 17413-17426.\r\n[19] Krzysztof Marcin Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Quincy Davis, Afroz Mohiuddin, Lukasz Kaiser, David Benjamin Belanger, Lucy J Colwell, and Adrian Weller. \"Rethinking Attention with Performers\". In: International Conference on Learning Representations. 2021. 
URL: https://openreview.net/forum?id=Ua6zuk0WRH.\r\n[20] Christopher Clark, Kenton Lee, Ming-Wei Chang, Tom Kwiatkowski, Michael Collins, and Kristina Toutanova. \"BoolQ: Exploring the Surprising Difficulty of Natural Yes/No Questions\". In: Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers). Ed. by Jill Burstein, Christy Doran, and Thamar Solorio. Minneapolis, Minnesota: Association for Computational Linguistics, June 2019, pp. 2924-2936. DOI: 10.18653/v1/N19-1300. URL: https://aclanthology.org/N19-1300/.\r\n[21] Peter Clark, Isaac Cowhey, Oren Etzioni, Tushar Khot, Ashish Sabharwal, Carissa Schoenick, and Oyvind Tafjord. \"Think you have solved question answering? try arc, the ai2 reasoning challenge\". In: arXiv preprint arXiv:1803.05457 (2018).\r\n[22] Nelson Cowan. \"What are the differences between long-term, short-term, and working memory?\" In: Progress in brain research 169 (2008), pp. 323-338.\r\n[23] Zihang Dai, Zhilin Yang, Yiming Yang, Jaime G. Carbonell, Quoc Viet Le, and Ruslan Salakhutdinov. \"Transformer-XL: Attentive Language Models beyond a Fixed-Length Context\". In: ACL (1). Ed. by Anna Korhonen, David R. Traum, and Lluís Márquez. Association for Computational Linguistics, 2019, pp. 2978-2988. ISBN: 978-1-950737-48-2.\r\n[24] Tri Dao. \"FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning\". In: The Twelfth International Conference on Learning Representations. 2024. URL: https://openreview.net/forum?id=mZn2Xyh9Ec.\r\n[25] Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. \"FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness\". In: Advances in Neural Information Processing Systems. Ed. by S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh. Vol. 35. Curran Associates, Inc., 2022, pp. 16344-16359. URL: https://proceedings.neurips.cc/paper_files/paper/2022/file/67d57c32e20fd0a7a302cb81d36e40d5-Paper-Conference.pdf.\r\n[26] Tri Dao and Albert Gu. \"Transformers are SSMs: Generalized models and efficient algorithms through structured state space duality\". In: arXiv preprint arXiv:2405.21060 (2024).\r\n[27] Abhimanyu Das, Weihao Kong, Andrew Leach, Shaan K Mathur, Rajat Sen, and Rose Yu. \"Long-term Forecasting with TiDE: Time-series Dense Encoder\". In: Transactions on Machine Learning Research (2023). ISSN: 2835-8856. URL: https://openreview.net/forum?id=pCbC3aQB5W.\r\n[28] Soham De, Samuel L Smith, Anushan Fernando, Aleksandar Botev, George Cristian-Muraru, Albert Gu, Ruba Haroun, Leonard Berrada, Yutian Chen, Srivatsan Srinivasan, et al. \"Griffin: Mixing gated linear recurrences with local attention for efficient language models\". In: arXiv preprint arXiv:2402.19427 (2024).\r\n[29] Juechu Dong, Boyuan Feng, Driss Guessous, Yanbo Liang, and Horace He. \"Flex Attention: A Programming Model for Generating Optimized Attention Kernels\". In: arXiv preprint arXiv:2412.05496 (2024).\r\n[30] Xin Dong, Yonggan Fu, Shizhe Diao, Wonmin Byeon, Zijia Chen, Ameya Sunil Mahabaleshwarkar, Shih-Yang Liu, Matthijs Van Keirsbilck, Min-Hung Chen, Yoshi Suhara, et al. \"Hymba: A Hybrid-head Architecture for Small Language Models\". In: arXiv preprint arXiv:2411.13676 (2024).\r\n[31] Stefan Elfwing, Eiji Uchibe, and Kenji Doya. \"Sigmoid-weighted linear units for neural network function approximation in reinforcement learning\". In: Neural networks 107 (2018), pp. 
3-11.\r\n[32] Yukun Feng, Feng Li, Ziang Song, Boyuan Zheng, and Philipp Koehn. \"Learn to remember: Transformer with recurrent memory for document-level machine translation\". In: arXiv preprint arXiv:2205.01546 (2022).\r\n[33] Daniel Y Fu, Tri Dao, Khaled Kamal Saab, Armin W Thomas, Atri Rudra, and Christopher Re. \"Hungry Hungry Hippos: Towards Language Modeling with State Space Models\". In: The Eleventh International Conference on Learning Representations. 2023. URL: https://openreview.net/forum?id=COZDy0WYGg.\r\n[34] Yossi Gandelsman, Yu Sun, Xinlei Chen, and Alexei Efros. \"Test-time training with masked autoencoders\". In: Advances in Neural Information Processing Systems 35 (2022), pp. 29374-29385.\r\n[35] Leo Gao, Stella Biderman, Sid Black, Laurence Golding, Travis Hoppe, Charles Foster, Jason Phang, Horace He, Anish Thite, Noa Nabeshima, et al. \"The pile: An 800gb dataset of diverse text for language modeling\". In: arXiv preprint arXiv:2101.00027 (2020).\r\n[36] Felix A Gers, Jürgen Schmidhuber, and Fred Cummins. \"Learning to forget: Continual prediction with LSTM\". In: Neural computation 12.10 (2000), pp. 2451-2471.\r\n[37] Alex Graves, Greg Wayne, and Ivo Danihelka. Neural Turing Machines. 2014. arXiv: 1410.5401 [cs.NE]. URL: https://arxiv.org/abs/1410.5401.\r\n[38] Klaus Greff, Rupesh K Srivastava, Jan Koutník, Bas R Steunebrink, and Jürgen Schmidhuber. \"LSTM: A search space odyssey\". In: IEEE transactions on neural networks and learning systems 28.10 (2016), pp. 2222-2232.\r\n[39] Katarína Grešová, Vlastimil Martinek, David Čechák, Petr Šimeček, and Panagiotis Alexiou. \"Genomic benchmarks: a collection of datasets for genomic sequence classification\". In: BMC Genomic Data 24.1 (2023), p. 25.\r\n[40] Albert Gu and Tri Dao. \"Mamba: Linear-Time Sequence Modeling with Selective State Spaces\". In: First Conference on Language Modeling. 2024. URL: https://openreview.net/forum?id=tEYskw1VY2.\r\n[41] Albert Gu, Karan Goel, and Christopher Re. \"Efficiently Modeling Long Sequences with Structured State Spaces\". In: International Conference on Learning Representations. 2022. URL: https://openreview.net/forum?id=uYLFoz1v1AC.\r\n[42] Chi Han, Qifan Wang, Hao Peng, Wenhan Xiong, Yu Chen, Heng Ji, and Sinong Wang. \"LM-Infinite: Zero-Shot Extreme Length Generalization for Large Language Models\". In: Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers). Ed. by Kevin Duh, Helena Gomez, and Steven Bethard. Mexico City, Mexico: Association for Computational Linguistics, June 2024, pp. 3991-4008. DOI: 10.18653/v1/2024.naacl-long.222. URL: https://aclanthology.org/2024.naacl-long.222.\r\n[43] Ramin Hasani, Mathias Lechner, Tsun-Hsuan Wang, Makram Chahine, Alexander Amini, and Daniela Rus. \"Liquid Structural State-Space Models\". In: The Eleventh International Conference on Learning Representations. 2023. URL: https://openreview.net/forum?id=g40TKRKfS7R.\r\n[44] Zexue He, Leonid Karlinsky, Donghyun Kim, Julian McAuley, Dmitry Krotov, and Rogerio Feris. \"CAMELoT: Towards Large Language Models with Training-Free Consolidated Associative Memory\". In: arXiv preprint arXiv:2402.13449 (2024).\r\n[45] Donald Olding Hebb. The organization of behavior: A neuropsychological theory. Psychology press, 2005.\r\n[46] John J Hopfield. 
\"Neural networks and physical systems with emergent collective computational abilities.\" In: Proceedings of the national academy of sciences 79.8 (1982), pp. 2554-2558.\r\n[47] Kurt Hornik, Maxwell Stinchcombe, and Halbert White. \"Multilayer feedforward networks are universal approximators\". In: Neural networks 2.5 (1989), pp. 359-366.\r\n[48] Cheng-Ping Hsieh, Simeng Sun, Samuel Kriman, Shantanu Acharya, Dima Rekesh, Fei Jia, and Boris Ginsburg. \"RULER: What's the Real Context Size of Your Long-Context Language Models?\" In: First Conference on Language Modeling. 2024. URL: https://openreview.net/forum?id=kIoBbc76Sy.\r\n[49] DeLesley Hutchins, Imanol Schlag, Yuhuai Wu, Ethan Dyer, and Behnam Neyshabur. \"Block-recurrent transformers\". In: Advances in neural information processing systems 35 (2022), pp. 33248-33261.\r\n[50] Kazuki Irie, Róbert Csordás, and Jürgen Schmidhuber. \"The dual form of neural networks revisited: Connecting test time predictions to training patterns via spotlights of attention\". In: International Conference on Machine Learning. PMLR. 2022, pp. 9639-9659.\r\n[51] Kazuki Irie, Imanol Schlag, Róbert Csordás, and Jürgen Schmidhuber. \"Going beyond linear transformers with recurrent fast weight programmers\". In: Advances in neural information processing systems 34 (2021), pp. 7703-7717.\r\n[52] Vidit Jain and Erik Learned-Miller. \"Online domain adaptation of a pre-trained cascade of classifiers\". In: CVPR 2011. IEEE. 2011, pp. 577-584.\r\n[53] Albert Q Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lucile Saulnier, et al. \"Mistral 7B\". In: arXiv preprint arXiv:2310.06825 (2023).\r\n[54] Praneeth Kacham, Vahab Mirrokni, and Peilin Zhong. \"PolySketchFormer: Fast Transformers via Sketching Polynomial Kernels\". In: Forty-first International Conference on Machine Learning. 2024. URL: https://openreview.net/ forum?id=ghYrfdJfjK.\r\n[55] Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. \"Scaling laws for neural language models\". In: arXiv preprint arXiv:2001.08361 (2020).\r\n[56] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. \"Transformers are rnns: Fast autoregressive transformers with linear attention\". In: International conference on machine learning. PMLR. 2020, pp. 5156-5165.\r\n[57] Urvashi Khandelwal, Omer Levy, Dan Jurafsky, Luke Zettlemoyer, and Mike Lewis. \"Generalization through Memorization: Nearest Neighbor Language Models\". In: International Conference on Learning Representations. 2020. URL: https://openreview.net/forum?id=HkIBjCEKvH.\r\n[58] Yuri Kuratov, Aydar Bulatov, Petr Anokhin, Ivan Rodkin, Dmitry Igorevich Sorokin, Artyom Sorokin, and Mikhail Burtsev. \"BABILong: Testing the Limits of LLMs with Long Context Reasoning-in-a-Haystack\". In: The Thirtyeight Conference on Neural Information Processing Systems Datasets and Benchmarks Track. 2024. URL: https : //openreview.net/forum?id=u7m2CG84BQ.\r\n[59] Hung Le, Truyen Tran, and Svetha Venkatesh. \"Self-attentive associative memory\". In: International conference on machine learning. PMLR. 2020, pp. 5682-5691.\r\n[60] Patrick Lewis, Ethan Perez, Aleksandra Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, et al. \"Retrieval-augmented generation for knowledge-intensive nlp tasks\". 
In: Advances in Neural Information Processing Systems 33 (2020), pp. 9459-9474.\r\n[61] Danny Leybzon and Corentin Kervadec. \"Learning, Forgetting, Remembering: Insights From Tracking LLM Memorization During Training\". In: Proceedings of the 7th BlackboxNLP Workshop: Analyzing and Interpreting Neural Networks for NLP. 2024, pp. 43-57.\r\n[62] Zhe Li, Shiyi Qi, Yiduo Li, and Zenglin Xu. \"Revisiting long-term time series forecasting: An investigation on linear mapping\". In: arXiv preprint arXiv:2305.10721 (2023).\r\n[63] Bo Liu, Rui Wang, Lemeng Wu, Yihao Feng, Peter Stone, and Qiang Liu. \"Longhorn: State space models are amortized online learners\". In: arXiv preprint arXiv:2407.14207 (2024).\r\n[64] Nelson F Liu, Kevin Lin, John Hewitt, Ashwin Paranjape, Michele Bevilacqua, Fabio Petroni, and Percy Liang. \"Lost in the middle: How language models use long contexts\". In: Transactions of the Association for Computational Linguistics 12 (2024), pp. 157-173.\r\n[65] Yong Liu, Tengge Hu, Haoran Zhang, Haixu Wu, Shiyu Wang, Lintao Ma, and Mingsheng Long. \"iTransformer: Inverted transformers are effective for time series forecasting\". In: arXiv preprint arXiv:2310.06625 (2023).\r\n[66] George Mandler. \"The structure of value: Accounting for taste\". In: Affect and cognition. Psychology Press, 2014, pp. 3-36.\r\n[67] Harsh Mehta, Ankit Gupta, Ashok Cutkosky, and Behnam Neyshabur. \"Long Range Language Modeling via Gated State Spaces\". In: The Eleventh International Conference on Learning Representations. 2023. URL: https://openreview.net/forum?id=5MkYIYCbva.\r\n[68] Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. \"Pointer Sentinel Mixture Models\". In: International Conference on Learning Representations. 2017. URL: https://openreview.net/forum?id=Byj72udxe.\r\n[69] William Merrill, Jackson Petty, and Ashish Sabharwal. \"The Illusion of State in State-Space Models\". In: Forty-first International Conference on Machine Learning. 2024. URL: https://openreview.net/forum?id=QZgo9JZpLq.\r\n[70] Ravi Teja Mullapudi, Steven Chen, Keyi Zhang, Deva Ramanan, and Kayvon Fatahalian. \"Online model distillation for efficient video inference\". In: Proceedings of the IEEE/CVF International conference on computer vision. 2019, pp. 3573-3582.\r\n[71] Tsendsuren Munkhdalai, Manaal Faruqui, and Siddharth Gopal. \"Leave no context behind: Efficient infinite context transformers with infini-attention\". In: arXiv preprint arXiv:2404.07143 (2024).\r\n[72] Tsendsuren Munkhdalai, Alessandro Sordoni, Tong Wang, and Adam Trischler. \"Metalearned neural memory\". In: Advances in Neural Information Processing Systems 32 (2019).\r\n[73] Tsendsuren Munkhdalai and Hong Yu. \"Neural semantic encoders\". In: Proceedings of the conference. Association for Computational Linguistics. Meeting. Vol. 1. NIH Public Access. 2017, p. 397.\r\n[74] Eric Nguyen, Michael Poli, Marjan Faizi, Armin Thomas, Michael Wornow, Callum Birch-Sykes, Stefano Massaroli, Aman Patel, Clayton Rabideau, Yoshua Bengio, et al. \"Hyenadna: Long-range genomic sequence modeling at single nucleotide resolution\". In: Advances in neural information processing systems 36 (2024).\r\n[75] A Nichol. \"On first-order meta-learning algorithms\". In: arXiv preprint arXiv:1803.02999 (2018).\r\n[76] Yuqi Nie, Nam H Nguyen, Phanwadee Sinthong, and Jayant Kalagnanam. \"A time series is worth 64 words: Long-term forecasting with transformers\". 
In: arXiv preprint arXiv:2211.14730 (2022).\r\n[77] Hideyuki Okano, Tomoo Hirano, and Evan Balaban. \"Learning and memory\". In: Proceedings of the National Academy of Sciences 97.23 (2000), pp. 12403-12404.\r\n[78] Antonio Orvieto, Samuel L Smith, Albert Gu, Anushan Fernando, Caglar Gulcehre, Razvan Pascanu, and Soham De. \"Resurrecting recurrent neural networks for long sequences\". In: International Conference on Machine Learning. PMLR. 2023, pp. 26670-26698.\r\n[79] Denis Paperno, Germán Kruszewski, Angeliki Lazaridou, Ngoc Quan Pham, Raffaella Bernardi, Sandro Pezzelle, Marco Baroni, Gemma Boleda, and Raquel Fernández. \"The LAMBADA dataset: Word prediction requiring a broad discourse context\". In: Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). Ed. by Katrin Erk and Noah A. Smith. Berlin, Germany: Association for Computational Linguistics, Aug. 2016, pp. 1525-1534. DOI: 10.18653/v1/P16-1144. URL: https://aclanthology.org/P16-1144/.\r\n[80] Badri N. Patro and Vijay S. Agneeswaran. SiMBA: Simplified Mamba-Based Architecture for Vision and Multivariate Time Series. 2024. arXiv: 2403.15360 [cs.CV].\r\n[81] Guilherme Penedo, Hynek Kydliček, Loubna Ben allal, Anton Lozhkov, Margaret Mitchell, Colin Raffel, Leandro Von Werra, and Thomas Wolf. \"The FineWeb Datasets: Decanting the Web for the Finest Text Data at Scale\". In: The Thirty-eighth Conference on Neural Information Processing Systems Datasets and Benchmarks Track. 2024. URL: https://openreview.net/forum?id=n6SCkn2QaG.\r\n[82] Bo Peng. RWKV-LM. Version 1.0.0. Aug. 2021. DOI: 10.5281/zenodo.5196577. URL: https://github.com/BlinkDL/RWKV-LM.\r\n[83] Bo Peng, Eric Alcaide, Quentin Gregory Anthony, Alon Albalak, Samuel Arcadinho, Stella Biderman, Huanqi Cao, Xin Cheng, Michael Nguyen Chung, Leon Derczynski, Xingjian Du, Matteo Grella, Kranthi Kiran GV, Xuzheng He, Haowen Hou, Przemyslaw Kazienko, Jan Kocon, Jiaming Kong, Bartłomiej Koptyra, Hayden Lau, Jiaju Lin, Krishna Sri Ipsit Mantri, Ferdinand Mom, Atsushi Saito, Guangyu Song, Xiangru Tang, Johan S. Wind, Stanisław Woźniak, Zhenyuan Zhang, Qinghua Zhou, Jian Zhu, and Rui-Jie Zhu. \"RWKV: Reinventing RNNs for the Transformer Era\". In: The 2023 Conference on Empirical Methods in Natural Language Processing. 2023. URL: https://openreview.net/forum?id=7SaXcza8pG.\r\n[84] Bo Peng, Daniel Goldstein, Quentin Anthony, Alon Albalak, Eric Alcaide, Stella Biderman, Eugene Cheah, Xingjian Du, Teddy Ferdinan, Haowen Hou, et al. \"Eagle and finch: Rwkv with matrix-valued states and dynamic recurrence\". In: arXiv preprint arXiv:2404.05892 (2024).\r\n[85] DL Prados and SC Kak. \"Neural network capacity using delta rule\". In: Electronics Letters 25.3 (1989), pp. 197-199.\r\n[86] Zhen Qin, Yiran Zhong, and Hui Deng. \"Exploring Transformer Extrapolation\". In: Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. 17. 2024, pp. 18897-18905.\r\n[87] Liliang Ren, Yang Liu, Yadong Lu, Yelong Shen, Chen Liang, and Weizhu Chen. \"Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling\". In: arXiv preprint arXiv:2406.07522 (2024).\r\n[88] Ivan Rodkin, Yuri Kuratov, Aydar Bulatov, and Mikhail Burtsev. \"Associative recurrent memory transformer\". In: arXiv preprint arXiv:2407.04841 (2024).\r\n[89] Aurko Roy, Mohammad Saffar, Ashish Vaswani, and David Grangier. \"Efficient content-based sparse attention with routing transformers\". 
In: Transactions of the Association for Computational Linguistics 9 (2021), pp. 53-68.\r\n[90] Keisuke Sakaguchi, Ronan Le Bras, Chandra Bhagavatula, and Yejin Choi. \"Winogrande: An adversarial winograd schema challenge at scale\". In: Communications of the ACM 64.9 (2021), pp. 99-106.\r\n[91] Maarten Sap, Hannah Rashkin, Derek Chen, Ronan Le Bras, and Yejin Choi. \"Social IQa: Commonsense Reasoning about Social Interactions\". In: Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP). Ed. by Kentaro Inui, Jing Jiang, Vincent Ng, and Xiaojun Wan. Hong Kong, China: Association for Computational Linguistics, Nov. 2019, pp. 4463-4473. DOI: 10.18653/v1/D19-1454. URL: https://aclanthology.org/D19-1454/.\r\n[92] Imanol Schlag, Kazuki Irie, and Jürgen Schmidhuber. \"Linear transformers are secretly fast weight programmers\". In: International Conference on Machine Learning. PMLR. 2021, pp. 9355-9366.\r\n[93] JH Schmidhuber. \"Learning to control fast-weight memories: An alternative to dynamic recurrent networks\". In: Neural Computation (1992).\r\n[94] Jürgen Schmidhuber. \"Reducing the ratio between learning complexity and number of time varying variables in fully recurrent nets\". In: ICANN'93: Proceedings of the International Conference on Artificial Neural Networks Amsterdam, The Netherlands 13-16 September 1993 3. Springer. 1993, pp. 460-463.\r\n[95] Sepp Hochreiter and Jürgen Schmidhuber. \"Long short-term memory\". In: Neural Computation 9.8 (1997), pp. 1735-1780.\r\n[96] Avi Schwarzschild, Zhili Feng, Pratyush Maini, Zachary C Lipton, and J Zico Kolter. \"Rethinking llm memorization through the lens of adversarial compression\". In: arXiv preprint arXiv:2404.15146 (2024).\r\n[97] Jimmy T.H. Smith, Andrew Warrington, and Scott Linderman. \"Simplified State Space Layers for Sequence Modeling\". In: The Eleventh International Conference on Learning Representations. 2023. URL: https://openreview.net/forum?id=AiBHw3AXqks.\r\n[98] Robin Staab, Mark Vero, Mislav Balunovic, and Martin Vechev. \"Beyond Memorization: Violating Privacy via Inference with Large Language Models\". In: The Twelfth International Conference on Learning Representations. 2024. URL: https://openreview.net/forum?id=kmn0BhQk7p.\r\n[99] Sainbayar Sukhbaatar, Edouard Grave, Guillaume Lample, Herve Jegou, and Armand Joulin. \"Augmenting self-attention with persistent memory\". In: arXiv preprint arXiv:1907.01470 (2019).\r\n[100] Sainbayar Sukhbaatar, Jason Weston, Rob Fergus, et al. \"End-to-end memory networks\". In: Advances in neural information processing systems 28 (2015).\r\n[101] Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, et al. \"Learning to (learn at test time): Rnns with expressive hidden states\". In: arXiv preprint arXiv:2407.04620 (2024).\r\n[102] Yutao Sun, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, and Furu Wei. \"Retentive network: A successor to transformer for large language models\". In: arXiv preprint arXiv:2307.08621 (2023).\r\n[103] Gemma Team, Thomas Mesnard, Cassidy Hardin, Robert Dadashi, Surya Bhupatiraju, Shreya Pathak, Laurent Sifre, Morgane Rivière, Mihir Sanjay Kale, Juliette Love, et al. \"Gemma: Open models based on gemini research and technology\". In: arXiv preprint arXiv:2403.08295 (2024).\r\n[104] W Scott Terry. 
Learning and memory: Basic principles, processes, and procedures. Routledge, 2017.\r\n[105] Matteo Tiezzi, Michele Casoni, Alessandro Betti, Tommaso Guidi, Marco Gori, and Stefano Melacci. \"On the resurgence of recurrent models for long sequences: Survey and research opportunities in the transformer era\". In: arXiv preprint arXiv:2402.08132 (2024).\r\n[106] Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, et al. \"Llama: Open and efficient foundation language models\". In: arXiv preprint arXiv:2302.13971 (2023).\r\n[107] Jos Van Der Westhuizen and Joan Lasenby. \"The unreasonable effectiveness of the forget gate\". In: arXiv preprint arXiv:1804.04849 (2018).\r\n[108] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. \"Attention is All you Need\". In: Advances in Neural Information Processing Systems. Ed. by I. Guyon, U. Von Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett. Vol. 30. Curran Associates, Inc., 2017. URL: https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf.\r\n[109] Shida Wang. \"LongSSM: On the Length Extension of State-space Models in Language Modelling\". In: arXiv preprint arXiv:2406.02080 (2024).\r\n[110] Yu Wang, Yifan Gao, Xiusi Chen, Haoming Jiang, Shiyang Li, Jingfeng Yang, Qingyu Yin, Zheng Li, Xian Li, Bing Yin, Jingbo Shang, and Julian McAuley. \"MEMORYLLM: Towards Self-Updatable Large Language Models\". In: Forty-first International Conference on Machine Learning. 2024. URL: https://openreview.net/forum?id=p01KWzdikQ.\r\n[111] Yu Wang, Chi Han, Tongtong Wu, Xiaoxin He, Wangchunshu Zhou, Nafis Sadeq, Xiusi Chen, Zexue He, Wei Wang, Gholamreza Haffari, et al. \"Towards LifeSpan Cognitive Systems\". In: arXiv preprint arXiv:2409.13265 (2024).\r\n[112] Zhiwei Wang, Yao Ma, Zitao Liu, and Jiliang Tang. \"R-transformer: Recurrent neural network enhanced transformer\". In: arXiv preprint arXiv:1907.05572 (2019).\r\n[113] Jason Weston, Sumit Chopra, and Antoine Bordes. \"Memory networks\". In: arXiv preprint arXiv:1410.3916 (2014).\r\n[114] Bernard Widrow and Marcian E Hoff. \"Adaptive switching circuits\". In: Neurocomputing: foundations of research. 1988, pp. 123-134.\r\n[115] Ronald J Williams and David Zipser. \"A learning algorithm for continually running fully recurrent neural networks\". In: Neural computation 1.2 (1989), pp. 270-280.\r\n[116] Daniel B Willingham. \"Systems of memory in the human brain\". In: Neuron 18.1 (1997), pp. 5-8.\r\n[117] Chao-Yuan Wu, Christoph Feichtenhofer, Haoqi Fan, Kaiming He, Philipp Krahenbuhl, and Ross Girshick. \"Long-term feature banks for detailed video understanding\". In: Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019, pp. 284-293.\r\n[118] Haixu Wu, Tengge Hu, Yong Liu, Hang Zhou, Jianmin Wang, and Mingsheng Long. \"TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis\". In: The Eleventh International Conference on Learning Representations. 2023. URL: https://openreview.net/forum?id=ju_Uqw3840q.\r\n[119] Qingyang Wu, Zhenzhong Lan, Kun Qian, Jing Gu, Alborz Geramifard, and Zhou Yu. \"Memformer: A memory-augmented transformer for sequence modeling\". In: arXiv preprint arXiv:2010.06891 (2020).\r\n[120] Guangxuan Xiao, Yuandong Tian, Beidi Chen, Song Han, and Mike Lewis. 
\"Efficient Streaming Language Models with Attention Sinks\". In: The Twelfth International Conference on Learning Representations. 2024. URL: https: //openreview. net/forum?id=NG7sS51zVF.\r\n[121] An Yang, Baosong Yang, Beichen Zhang, Binyuan Hui, Bo Zheng, Bowen Yu, Chengyuan Li, Dayiheng Liu, Fei Huang, Haoran Wei, et al. \"Qwen2. 5 Technical Report\". In: arXiv preprint arXiv:2412.15115 (2024).\r\n[122] Songlin Yang, Jan Kautz, and Ali Hatamizadeh. \"Gated Delta Networks: Improving Mamba2 with Delta Rule\". In: arXiv preprint arXiv:2412.06464 (2024).\r\n[123] Songlin Yang, Bailin Wang, Yikang Shen, Rameswar Panda, and Yoon Kim. \"Gated Linear Attention Transformers with Hardware-Efficient Training\". In: Forty-first International Conference on Machine Learning. 2024. URL: https: //openreview. net/forum?id=ia5XvxFUJ7.\r\n[124] Songlin Yang, Bailin Wang, Yu Zhang, Yikang Shen, and Yoon Kim. \"Parallelizing Linear Transformers with the Delta Rule over Sequence Length\". In: The Thirty-eighth Annual Conference on Neural Information Processing Systems. 2024. URL: https://openreview. net/forum?id=y8Rm4VNRPH.\r\n[125] Luca Zancato, Arjun Seshadri, Yonatan Dukler, Aditya Golatkar, Yantao Shen, Benjamin Bowman, Matthew Trager, Alessandro Achille, and Stefano Soatto. \"B'MOJO: Hybrid State Space Realizations of Foundation Models with Eidetic and Fading Memory\". In: The Thirty-eighth Annual Conference on Neural Information Processing Systems. 2024. URL: https://openreview. net/forum?id=RnQdRY1h5v.\r\nPAGE 24\r\n[126] Rowan Zellers, Ari Holtzman, Yonatan Bisk, Ali Farhadi, and Yejin Choi. \"HellaSwag: Can a Machine Really Finish Your Sentence?\" In: Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics. Ed. by Anna Korhonen, David Traum, and Lluís Márquez. Florence, Italy: Association for Computational Linguistics, July 2019, pp. 4791-4800. DOI: 10.18653/v1/P19-1472. URL: https://aclanthology.org/P19-1472/.\r\n[127] Ailing Zeng, Muxi Chen, Lei Zhang, and Qiang Xu. \"Are transformers effective for time series forecasting?\" In: Proceedings of the AAAI conference on artificial intelligence. Vol. 37. 2023, pp. 11121-11128.\r\n[128] Hao Zhang, Alexander C Berg, Michael Maire, and Jitendra Malik. \"SVM-KNN: Discriminative nearest neighbor classification for visual category recognition\". In: 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'06). Vol. 2. IEEE. 2006, pp. 2126-2136.\r\n[129] Jianyu Zhang, Niklas Nolte, Ranajoy Sadhukhan, Beidi Chen, and Léon Bottou. \"Memory Mosaics\". In: arXiv preprint arXiv:2405.06394 (2024).\r\n[130] Yunhao Zhang and Junchi Yan. \"Crossformer: Transformer utilizing cross-dimension dependency for multivariate time series forecasting\". In: The eleventh international conference on learning representations. 2023.\r\n[131] Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang. \"Informer: Beyond efficient transformer for long sequence time-series forecasting\". In: Proceedings of the AAAI conference on artificial intelligence. Vol. 35. 12. 2021, pp. 11106-11115.\r\n[132] Luisa Zintgraf, Kyriacos Shiarli, Vitaly Kurin, Katja Hofmann, and Shimon Whiteson. \"Fast context adaptation via meta-learning\". In: International Conference on Machine Learning. PMLR. 2019, pp. 7693-7702.\r\nPAGE 25\r\n# A Related Work \r\n\r\nThere are diverse perspectives that can independently lead to the design of Titans or its components. 
Accordingly, to further situate our work in a broader context, we review three categories of studies:\r\n\r\n## A. 1 Linear Recurrent Models\r\n\r\nRecently, to address the computational cost of Transformers in both training and inference, linear recurrent models have attracted much attention (Tiezzi et al. 2024), mainly due to their fast inference and training. The first generation of models-such as RetNet (Yutao Sun et al. 2023), LRU (Orvieto et al. 2023), RWKV (Peng, Alcaide, et al. 2023), S5 (J. T. Smith, Warrington, and Linderman 2023), and S4 (Gu, Goel, and Re 2022)-use a data-independent transition matrix/decay mechanism. The second generation of such models started to incorporate a gating mechanism, a widely used technique in traditional RNNs (Gers, Jürgen Schmidhuber, and Cummins 2000; Greff et al. 2016; Van Der Westhuizen and Lasenby 2018), into such linear architectures-e.g., Griffin (De et al. 2024), SSMs (Behrouz, Santacatterina, and Zabih 2024; Dao and Gu 2024; Gu and Dao 2024; Hasani et al. 2023), and RWKV6 (Peng, Goldstein, et al. 2024). The third generation of linear recurrent models is based on more complex memory update rules built on meta-learning, online learning, and/or the delta rule, resulting in more expressive and effective models such as Longhorn (B. Liu et al. 2024), Gated DeltaNet (S. Yang, Kautz, and Hatamizadeh 2024), TTT (Yu Sun et al. 2024), and DeltaNet (S. Yang, B. Wang, Yu Zhang, et al. 2024). Our LMM can be seen as the next generation of such models, in which we incorporate the token flow into the memory update mechanism, yielding a more powerful memory update process. See Appendix C for a detailed discussion of different recurrent models and Titans.\r\n\r\n## A. 2 Transformer-based Architectures\r\n\r\nTransformers. Transformers (Vaswani et al. 2017), the de facto backbone of many deep learning models, are based on the attention mechanism (Bahdanau 2014). They, however, suffer from quadratic computational cost, limiting their ability to scale to long context windows. To improve the memory consumption and throughput of softmax attention for longer sequences, various studies have focused on I/O-aware implementations of attention (Dao 2024; Dao, D. Fu, et al. 2022), designing more efficient attention mechanisms by sparsifying the attention matrix (B. Chen et al. 2021; Choromanski et al. 2021; Dai et al. 2019; J. Dong et al. 2024; Roy et al. 2021), approximating the softmax (Arora et al. 2024), or developing kernel-based (linear) attentions (Aksenov et al. 2024; Kacham, Mirrokni, and P. Zhong 2024; Schlag, Irie, and Jürgen Schmidhuber 2021; S. Yang, B. Wang, Shen, et al. 2024).\r\n\r\nSegment-based Transformers. Another line of research to improve the efficiency of Transformers is segment-based, or chunk, Transformers (Dai et al. 2019). The main drawback of chunk Transformers is that segments are fully separated, and so the context window is limited to the length of the chunks. To address this issue, various studies have discussed the importance of a memory that can help the model transfer information across chunks (Bulatov, Yuri Kuratov, et al. 2023; Bulatov, Yury Kuratov, and Burtsev 2022; Feng et al. 2022; Hutchins et al. 2022; Rodkin et al. 2024; Z. Wang et al. 2019; Q. Wu et al. 2020; Zancato et al. 2024). 
The key differences between Titans and these models are: (1) the memory in such models is a simple, small vector, lacking the expressive power to compress complex information; (2) the memory module lacks a forget mechanism, leading to fast memory overflow; and (3) these models focus only on momentary surprise, missing the information flow. More specifically, recalling Recurrent Memory Transformers (RMT) (Bulatov, Yuri Kuratov, et al. 2023; Bulatov, Yury Kuratov, and Burtsev 2022; Rodkin et al. 2024), one can treat Titans (MAC) as a generalization of RMT, where we use a neural memory module instead of a small vector-valued memory.\r\n\r\nMemory for Large Language Models. Another interesting research direction has been to incorporate external memory modules into LLMs after training (Z. He et al. 2024; Khandelwal et al. 2020; Y. Wang, Y. Gao, et al. 2024). Such models differ from our approach in that we incorporate the memory as part of the initial architecture and so train it in an end-to-end manner. Also, most of these explicit memory modules suffer from the same limitations as chunk-based Transformers (mentioned above). For a detailed discussion of such models, we refer to the recent study of Y. Wang, Han, et al. (2024).\r\n\r\n## A. 3 Test Time Training and Fast Weight Programs\r\n\r\nMemory Design and Augmentation with Memory. In the literature, a substantial research effort has gone toward designing memory modules that are capable of either memorizing knowledge abstractions (e.g., persistent memory) (Sukhbaatar, Grave, et al. 2019) or memorizing data-dependent information (also known as contextual memory), through recurrence (Bulatov, Yury Kuratov, and Burtsev 2022; Rodkin et al. 2024; Zancato et al. 2024), Transformers (Berges et al. 2024; Cetin et al. 2024; Feng et al. 2022; Le, Tran, and Venkatesh 2020; Munkhdalai, Faruqui, and Gopal 2024; J. Zhang et al. 2024), gradients (Irie, Csordás, and Jürgen Schmidhuber 2022; Munkhdalai, Sordoni, et al. 2019), or other learning paradigms (Sukhbaatar, Weston, Fergus, et al. 2015; Weston, Chopra, and Bordes 2014). These memory models, however, either (1) are based on momentary surprise, missing the data flow and events; (2) lack a forget mechanism to remove memories, leading to fast memory overflow; (3) have a fixed-size shallow (matrix-valued) memory, resulting in poor performance on long context; or (4) are based on fixed parameters at test time, lacking test-time adaptation.\r\n\r\nFast Weight Programs. The idea of seeing linear layers as a key-value (associative) memory system dates back to fast weight programs, in which dynamic fast programs are incorporated into recurrent neural networks to serve as writable memory (Schlag, Irie, and Jürgen Schmidhuber 2021; JH Schmidhuber 1992; Jürgen Schmidhuber 1993). The Hebbian (Hebb 2005) and delta (Prados and Kak 1989) rules are the two most popular learning rules for fast weight programs, and both have been extensively explored in various studies (Irie, Schlag, et al. 2021; Munkhdalai, Sordoni, et al. 2019; Munkhdalai and H. Yu 2017; Schlag, Irie, and Jürgen Schmidhuber 2021; JH Schmidhuber 1992; S. Yang, Kautz, and Hatamizadeh 2024; S. Yang, B. Wang, Yu Zhang, et al. 2024). All these models, however, are based on momentary surprise, missing the token flow in the sequence (see Section 3.1), and most of them lack a forgetting gate, resulting in poor memory management.
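For intuition, here is a minimal sketch (our own, not from the cited works) contrasting the two classical fast-weight learning rules on a matrix memory $M \\in \\mathbb{R}^{d \\times d}$:\r\n\r\n```python\r\nimport torch\r\n\r\n\r\ndef hebbian_update(M, k, v, lr=1.0):\r\n    # Hebbian rule: always add the new association; nothing is ever erased.\r\n    return M + lr * torch.outer(v, k)\r\n\r\n\r\ndef delta_update(M, k, v, lr=1.0):\r\n    # Delta rule (Widrow-Hoff): write only the prediction error, so the\r\n    # value stored under key k is replaced rather than accumulated.\r\n    return M + lr * torch.outer(v - M @ k, k)\r\n```\r\n\r\nBoth rules react only to the current pair $(\\mathbf{k}, \\mathbf{v})$, which is exactly the momentary-surprise limitation discussed above.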
Test Time Training. The key ideas of learning at test time, or learning to learn (Andrychowicz et al. 2016), go back to very early studies on local learning (Bottou and Vapnik 1992), in which each test data sample is trained on its neighbors before a prediction is made (Gandelsman et al. 2022; H. Zhang et al. 2006). This approach has further shown promising performance in vision tasks (Jain and Learned-Miller 2011; Mullapudi et al. 2019), mostly due to its ability to mitigate out-of-distribution samples. The most similar studies to ours in this direction are MNM (Munkhdalai, Sordoni, et al. 2019) and the TTT-layer (Yu Sun et al. 2024); we discuss the key differences in Appendix C.\r\n\r\n## B Language Modeling and Common-sense Reasoning Datasets\r\n\r\nFollowing recent studies on linear recurrent models (Dao and Gu 2024; S. Yang, Kautz, and Hatamizadeh 2024; S. Yang, B. Wang, Yu Zhang, et al. 2024), we use Wikitext (Merity et al. 2017), LMB (Paperno et al. 2016), PIQA (Bisk et al. 2020), HellaSwag (Zellers et al. 2019), WinoGrande (Sakaguchi et al. 2021), ARC-easy (ARC-e) and ARC-challenge (ARC-c) (P. Clark et al. 2018), SIQA (Sap et al. 2019), and BoolQ (C. Clark et al. 2019). Also, the baseline results for the 400M models are those reported by S. Yang, Kautz, and Hatamizadeh (2024).\r\n\r\n## C Long-term Memory Module (LMM) as a Sequence Model\r\n\r\nIn this section, we discuss how LMM, as a sequence model, is connected to modern linear recurrent models. For the sake of simplicity, we start with a linear memory, where $\\mathcal{M}_{t}=W_{t} \\in \\mathbb{R}^{d_{m} \\times d_{m}}$. In this case, our objective function becomes $\\ell\\left(\\mathcal{M} ; x_{t}\\right)=\\frac{1}{2}\\left\\|\\mathcal{M}_{t} \\mathbf{k}_{t}-\\mathbf{v}_{t}\\right\\|_{2}^{2}$, in which we use gradient descent with momentum and weight decay for the optimization. Accordingly, revisiting the recurrent formula in Equation 13:\r\n\r\n$$\r\n\\begin{aligned}\r\n& \\mathcal{M}_{t}=\\operatorname{diag}\\left(1-\\alpha_{t}\\right) \\mathcal{M}_{t-1}+S_{t} \\\\\r\n& S_{t}=\\operatorname{diag}\\left(\\eta_{t}\\right) S_{t-1}-\\operatorname{diag}\\left(\\theta_{t}\\right)\\left(\\mathcal{M}_{t-1} \\mathbf{k}_{t}^{\\top} \\mathbf{k}_{t}-\\mathbf{v}_{t}^{\\top} \\mathbf{k}_{t}\\right)\r\n\\end{aligned}\r\n$$
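Read as code, one step of this linear recurrence could look as follows (a non-vectorized sketch under our own shape conventions; the actual implementation is chunk-parallel and uses per-channel gates via $\\operatorname{diag}(\\cdot)$ rather than the scalars assumed here):\r\n\r\n```python\r\nimport torch\r\n\r\n\r\ndef lmm_step(M, S, k, v, alpha, eta, theta):\r\n    # M: (d, d) linear memory; S: (d, d) surprise/momentum state;\r\n    # k, v: (d,) key and value; alpha, eta, theta: data-dependent scalars.\r\n    grad = torch.outer(k, k @ M - v)  # grad of 0.5 * ||k M - v||^2 w.r.t. M\r\n    S = eta * S - theta * grad        # momentum over past surprise\r\n    M = (1.0 - alpha) * M + S         # weight decay acts as the forget gate\r\n    return M, S\r\n```\r\n\r\nSetting $\\eta_{t}=0$ collapses this to the (Gated) DeltaNet-style update discussed next, and additionally setting $\\alpha_{t}=0$ removes forgetting, as in TTT.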
LMM is Generalized Gated DeltaNet. As discussed by S. Yang, Kautz, and Hatamizadeh (2024), DeltaNet (S. Yang, B. Wang, Yu Zhang, et al. 2024) can alternatively be interpreted as an online learning problem that optimizes $\\mathcal{L}=\\frac{1}{2}\\left\\|\\mathbf{S}_{t} \\mathbf{k}_{t}-\\mathbf{v}_{t}\\right\\|_{2}^{2}$, resulting in:\r\n\r\n$$\r\n\\mathbf{S}_{t+1}=\\mathbf{S}_{t}-\\theta_{t} \\nabla \\mathcal{L}=\\mathbf{S}_{t}\\left(\\mathbf{I}-\\theta_{t} \\mathbf{k}_{t} \\mathbf{k}_{t}^{\\top}\\right)+\\theta_{t} \\mathbf{v}_{t} \\mathbf{k}_{t}^{\\top}\r\n$$\r\n\r\nIn this formulation, Gated DeltaNet is the same as the above but with an additional weight decay term (S. Yang, Kautz, and Hatamizadeh 2024). Comparing Equation 32 and Equation 34, we can see that setting $\\eta_{t}=0$ makes the two formulations equivalent. Accordingly, LMM generalizes the very recent Gated DeltaNet (S. Yang, Kautz, and Hatamizadeh 2024) in three aspects:\r\n\r\n- Momentum-based Rule: The delta rule is based on momentary surprise, meaning that the flow of tokens cannot affect the memory update rule. LMM, however, is based on a momentum rule, which considers both past and momentary surprise.\r\n- Deep Memory: While Gated DeltaNet is limited to a linear (matrix-valued) memory, as it requires finding a closed recurrence form, LMM allows a deep memory module through its gradient-based formulation, resulting in higher expressive power.\r\n- Non-Linear Recurrence: While DeltaNet and Gated DeltaNet are based on linear recurrence, our LMM uses inter-chunk non-linear recurrence and intra-chunk linear recurrence. This design gives LMM higher expressive power.\r\n\r\nHere, we discussed Gated DeltaNet as a representative of the recent generation of recurrent models. Similar approaches, such as RWKV-7 (Peng 2021), use the same formulation and loss function, and so LMM generalizes all such models as well.\r\n\r\nLMM is Generalized Longhorn. Similar to DeltaNet, Longhorn (B. Liu et al. 2024) uses the same loss function, but it derives the closed form using implicit online learning:\r\n\r\n$$\r\n\\mathbf{S}_{t+1}=\\mathbf{S}_{t}\\left(\\mathbf{I}-\\delta_{t} \\mathbf{k}_{t} \\mathbf{k}_{t}^{\\top}\\right)+\\delta_{t} \\mathbf{v}_{t} \\mathbf{k}_{t}^{\\top}\r\n$$\r\n\r\nwhere $\\delta_{t}=\\frac{\\theta_{t}}{1+\\theta_{t} \\mathbf{k}_{t}^{\\top} \\mathbf{k}_{t}}$. It, however, lacks a forgetting gate, resulting in faster memory overflow. Therefore, in addition to the abovementioned aspects of (1) a momentum-based rule, (2) deep memory, and (3) non-linear recurrence, LMM has the advantage of an additional (4) forget gate, leading to better memory management.\r\n\r\nLMM is Generalized TTT Layer. To the best of our knowledge, TTT (Yu Sun et al. 2024) is the only modern linear recurrent model with a gradient-based update rule. In addition to different architectural designs and objective functions, our LMM has three key differences from the TTT layers presented by Yu Sun et al. (2024):\r\n\r\n1. Forgetting Mechanism: TTT layers update the memory at each step without the chance to forget past data. Accordingly, with a fixed memory size, the model cannot manage the memory over long sequences. A forget mechanism, such as LMM's, allows clearing the memory when very old information is no longer needed. We show that, in the general case, this forget mechanism is equivalent to weight decay, and we provide a fast method to incorporate it into parallel training.\r\n2. Momentum-based Update Rule: TTT layers are based on momentary surprise, meaning that the flow of tokens cannot affect the memory update rule. LMM, however, is based on a momentum rule, which considers both past and momentary surprise. See Section 3.1 for the motivation of this design.\r\n3. Deep Memory: While TTT layers allow for deeper memory, the advantages/disadvantages of such deeper memory modules have not been experimentally evaluated.\r\n\r\nTo the best of our knowledge, our neural long-term memory module is the first linear recurrent model with a momentum-based update rule.\r\n\r\nFinally, as a key difference from all the above and other recent linear recurrent studies, note that the hybrid variants of modern linear models-such as Griffin (De et al. 2024), DeltaNet (S. Yang, B. Wang, Yu Zhang, et al. 2024), Gated DeltaNet (S. Yang, Kautz, and Hatamizadeh 2024), H3 (D. Y. Fu et al. 2023), Mamba2 (Dao and Gu 2024), Samba (Ren et al. 2024), etc.-are all based on a sequential layer-wise design. We present Titans to show how effectively one can incorporate such memory modules into an architecture."
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"nested-learning\"\nversion = \"0.2.0\"\ndescription = \"Reproduction of Google's Nested Learning (HOPE) architecture\"\nlicense = {text = \"Apache-2.0\"}\nauthors = [\n  {name = \"Nested Learning Team\", email = \"nested-learning@example.com\"}\n]\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\ndependencies = [\n  \"torch>=2.9,<3\",\n  \"einops>=0.7.0\",\n  \"numpy>=1.26\",\n  \"hydra-core>=1.3.2\",\n  \"omegaconf>=2.3.0\",\n  \"pyyaml>=6.0\",\n  \"tqdm>=4.66\",\n  \"typing-extensions>=4.9\",\n  \"datasets>=2.19,<3.0\",\n  \"sentencepiece>=0.2.0\",\n  \"huggingface-hub>=0.23,<1.0\",\n  \"zstandard>=0.22.0\",\n  \"langdetect>=1.0.9\",\n  \"typer>=0.12\",\n]\n\n[project.scripts]\nnl = \"nested_learning.cli:app\"\n\n[project.optional-dependencies]\ngpu = [\n  \"torchvision>=0.24,<1\",\n  \"torchaudio>=2.9,<3\",\n]\nlogging = [\n  \"wandb>=0.18.0\",\n]\nviz = [\n  \"matplotlib>=3.8\",\n]\ndev = [\n  \"pytest>=7.4\",\n  \"pytest-cov>=4.1\",\n  \"ruff>=0.6.8\",\n  \"mypy>=1.11\",\n  \"types-PyYAML\",\n]\n\n[build-system]\nrequires = [\"hatchling\"]\nbuild-backend = \"hatchling.build\"\n\n[tool.hatch.build.targets.wheel]\npackages = [\"src/nested_learning\"]\n\n[tool.hatch.build.targets.wheel.force-include]\n\"configs\" = \"nested_learning/configs\"\n\n[tool.pytest.ini_options]\nminversion = \"7.0\"\naddopts = \"-ra -q\"\ntestpaths = [\"tests\"]\n\n[tool.ruff]\nline-length = 100\ntarget-version = \"py310\"\n\n[tool.ruff.lint]\nselect = [\"E\", \"F\", \"I\"]\nignore = []\n\n[tool.mypy]\npython_version = \"3.10\"\nwarn_unused_configs = true\n# PyTorch-style codebases inevitably interact with untyped third-party deps and dynamic module calls.\n# Keep mypy enabled, but avoid noisy `Any` return warnings for nn.Module forward calls.\nignore_missing_imports = true\nwarn_return_any = false\nstrict_optional = true\nshow_error_codes = true\npretty = true\npackages = [\"nested_learning\"]\n"
  },
  {
    "path": "reports/ablations.md",
    "content": "# Planned Ablations – Pilot Run\n\nThis document tracks the ablation studies we intend to run once the 3 B-token pilot checkpoint is available. The goal is to isolate the contributions of teach-signal scaling, CMS chunk accumulation, self-modifiers, and optimizer choices (AdamW vs Muon) before moving to larger configs.\n\n## 1. Teach-signal schedule\n| Variant | Description | Status | Notes |\n|---------|-------------|--------|-------|\n| Baseline | Warmup 2 k → decay 120 k→140 k (current pilot config) | ✅ (step 230 k) | Metrics in `eval/zeroshot_pilot_step230000.json` + friends. |\n| Low scale (0.05) | Reduce teach_scale to 0.05, 2 k-step pilot ablation | ✅ | `artifacts/checkpoints/pilot_teach05/step_002000.pt`, JSON log `logs/pilot-teach05-20251114010549.json`, evals under `eval/*_pilot_teach05_step2000.json`. |\n| High scale (0.15) | Increase teach_scale to 0.15, runs at 2 k and 25 k steps | ✅ | Short run: `artifacts/checkpoints/pilot_teach15/step_002000.pt`; long run: `artifacts/checkpoints/pilot_teach15_long/step_025000.pt`; logs/evals `logs/pilot-teach15-20251114012109.json`, `logs/pilot-teach15-long-20251114185448.json`, `eval/*_pilot_teach15_long_step25000.json`. |\n| No decay | Warmup only, no decay | ⏳ | Expect higher plasticity, risk of instability |\n| Per-level scale | Different teach_scale per CMS level | ⏳ | Requires config changes |\n\n## 2. CMS chunk accumulation\n| Variant | Description | Status | Notes |\n|---------|-------------|--------|-------|\n| Full CMS | Chunk accumulation + telemetry (default) | ✅ smoke-tested | Verified via `tests/test_cms.py`. Baseline checkpoint: `artifacts/checkpoints/pilot/step_230000.pt`. |\n| No chunking | Update each token (Transformer-like) | ✅ | Run `pilot-cms-nochunk` (5 k steps) with overrides `model.cms_levels.*.update_period=1`. Outputs: `logs/pilot-cms-nochunk-20251114124501.json`, eval JSONs `eval/*_pilot_cms_nochunk_step5000.json`. |\n| Sparse chunks | Update every 512 tokens only | ✅ | Config `configs/ablations/cms_sparse.yaml` (dim 384, layers 8, seq 1024, batch 2, chunk periods 8/32/128/512). Run `pilot-cms-sparse` (5 k steps) w/ resolved config `configs/resolved/cms_sparse_eval.yaml`. Metrics: PIQA 0.516, BoolQ 0.367, continual CE ≈25 across segments (see `eval/*_pilot_cms_sparse_step5000.json`). |\n\nTo keep chunk buffers tractable we reduced the CMS-hidden multiplier to 2, halved the batch size, and exported a resolved Hydra config at `configs/resolved/cms_sparse_eval.yaml` so that evaluation scripts can load the composed settings without Hydra. The highest-frequency buffer now tops out at ~3 GB (inputs + targets) instead of the 12 GB spikes we observed during the initial 2048-token attempt.\n\n## 3. Self-modifier toggles\n| Variant | Description | Status | Notes |\n|---------|-------------|--------|-------|\n| Enabled | SelfModifier active (default) | ✅ | Baseline pilot + long-run checkpoints. |\n| Disabled | Freeze self-modifier params | ✅ | Run `pilot-selfmod-off` (5 k steps). Continual CE jumped to ~45; see `eval/*_pilot_selfmod_off_step5000.json`. |\n| Teach-only | Teach signal applied but self-mod not updated | ⏳ | Planned follow-up once optimizer ablation finishes. |\n\n## 4. 
Optimizer swaps\n| Variant | Description | Status | Notes |\n|---------|-------------|--------|-------|\n| AdamW fused (control) | `optim.type=adamw` with fused kernels (override the Muon default) | ✅ | Run `pilot-opt-adamw-20251115173858` (5 k steps, batch 6, seq 2048) → checkpoint `artifacts/checkpoints/pilot-opt-adamw-20251115173858/step_005000.pt`. Eval highlights: PIQA 0.559, HellaSwag 0.273, Winogrande 0.500, BoolQ 0.367 (`eval/zeroshot_pilot_opt_adamw_step5000.json`); NIAH accuracies {0.75, 1.0, 0.5, 0.75, 0.5, 0.25}; continual CE ≈50/43/39/39 across segments. |\n| Muon hybrid | `optim.type=muon` for ≥2D params, AdamW for embeddings/bias | ✅ | Run `pilot-opt-muon-20251115180139` (identical setup) → checkpoint `artifacts/checkpoints/pilot-opt-muon-20251115180139/step_005000.pt`. Eval highlights: PIQA 0.531, HellaSwag 0.313, Winogrande 0.484, BoolQ 0.570 (`eval/zeroshot_pilot_opt_muon_step5000.json`); NIAH {0.5, 0.5, 0.25, 0.75, 0.75, 0.75}; continual CE ≈11 across all segments (`eval/continual_pilot_opt_muon_step5000.json`). |\n| Full Muon | Force Muon everywhere | ⏳ | Pending stability run; expect to require per-layer LR tuning. |\n\nAt 5 k pilot steps the hybrid Muon optimizer trades a small PIQA drop (0.559→0.531) for markedly better BoolQ (0.37→0.57) and 4× lower continual losses. Muon also cuts training loss faster (final CE ≈6.8 vs 8.5). Based on this we plan to adopt Muon for the resumed long HOPE run while keeping AdamW checkpoints for baseline comparisons.\n\n## 5. Automation hooks\n| Tool | Purpose | Status | Notes |\n|------|---------|--------|-------|\n| `scripts/package_pilot_release.sh` | Copies latest pilot checkpoint/config/logs into `artifacts/pilot_release/` and updates metadata | ✅ | Use after every significant checkpoint (e.g., 1k-step milestones) so collaborators can download a coherent bundle. |\n| `scripts/eval/run_pilot_suite.sh` | Runs zero-shot, NIAH (up to 64k), and continual harnesses (plus optional TITAN baseline) with memorization flags enabled | ✅ | Set `HOPE_CHECKPOINT`, `TITAN_*`, etc., to reuse for each ablation. Outputs land under `eval/`. |\n\n## 6. Evaluation checklist per ablation\n1. Run zero-shot suite (`scripts/eval/zeroshot.py --tasks all --memorize ...`).\n2. Run extended NIAH (`--context-lengths 2048 --context-lengths 4096 --context-lengths 8192 --context-lengths 16384 --context-lengths 32768 --context-lengths 65536`).\n3. Run continual-learning harness with memorization toggles (`--memorize --memorize-steps 2 --memorize-no-reset` and baseline run without memorization).\n4. Record metrics in `artifacts/pilot_release/` (JSON/CSV) and summarize deltas here (see the sketch below).\n5. Long-context extras: `scripts/eval/passkey.py` (default 64 prompts, memorize on) and `scripts/eval/pg19_perplexity.py` (streaming PG-19, 2048-token truncation). Archive outputs alongside zero-shot/NIAH JSONs.\n6. Continual plot: for multi-checkpoint evals, run `scripts/eval/plot_forgetting.py` and stash PNGs under `reports/plots/`.\n\n_Status legend:_ ✅ complete, ⏳ pending, 🔄 running, ⚠️ blocked.\n\n
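As referenced in step 4, a minimal diffing sketch for comparing an ablation's zero-shot JSON against the baseline (a hypothetical helper, not a checked-in script; it assumes a flat `task -> accuracy` mapping, so adjust the keys to the actual schema):\n\n```python\nimport json\nimport sys\n\n\ndef load_scores(path):\n    # Assumes a flat {task: accuracy} mapping; adapt if the files nest metrics.\n    with open(path) as fh:\n        return json.load(fh)\n\n\ndef main(baseline_path, ablation_path):\n    base = load_scores(baseline_path)\n    abl = load_scores(ablation_path)\n    for task in sorted(set(base) & set(abl)):\n        delta = abl[task] - base[task]\n        print(f\"{task}: {base[task]:.3f} -> {abl[task]:.3f} ({delta:+.3f})\")\n\n\nif __name__ == \"__main__\":\n    main(sys.argv[1], sys.argv[2])\n```\n\n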
## 7. Reference snapshot – Pilot step 230k (HOPE)\n| Eval | Command | Output | Notes |\n|------|---------|--------|-------|\n| Zero-shot (full suite, memorize on) | `UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy uv run python scripts/eval/zeroshot.py --config configs/pilot.yaml --checkpoint artifacts/checkpoints/pilot/step_230000.pt --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model --tasks all --max-samples 256 --device cuda:1 --output eval/zeroshot_pilot_step230000.json --memorize --memorize-steps 2 --memorize-use-correct-answer` | `eval/zeroshot_pilot_step230000.json` | PIQA 0.496, HellaSwag 0.297, Winogrande 0.473, ARC-E/C 0.285/0.234, BoolQ 0.367, SIQA 0.316, CSQA 0.180, OpenBookQA 0.113. |\n| NIAH (2k→65k, memorize on) | `UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy uv run python scripts/eval/niah.py --config configs/pilot.yaml --checkpoint artifacts/checkpoints/pilot/step_230000.pt --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model --context-lengths 2048 --context-lengths 4096 --context-lengths 8192 --context-lengths 16384 --context-lengths 32768 --context-lengths 65536 --samples-per-length 8 --device cuda:1 --output eval/niah_pilot_step230000.json --memorize --memorize-steps 2 --memorize-use-correct-answer` | `eval/niah_pilot_step230000.json` | Accuracies 0.625 / 0.50 / 0.375 / 0.50 / 0.75 / 0.50 (2k→65k contexts). |\n| Continual segments (memorize 1 step) | `UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy uv run python scripts/eval/continual.py --config configs/pilot.yaml --checkpoints artifacts/checkpoints/pilot/step_230000.pt --segments-yaml configs/data/continual_segments_sample.yaml --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model --batch-size 4 --max-batches 20 --device cuda:1 --output eval/continual_pilot_step230000.json --memorize --memorize-steps 1` | `eval/continual_pilot_step230000.json` | CE ≈8.06 / 7.79 / 7.68 / 7.95 for RefinedWeb / Wikipedia / C4 / RedPajama sample segments. 
|\n\nAll outputs are copied to `artifacts/pilot_release/` via `scripts/package_pilot_release.sh` for reproducibility.\n\n### Teach-scale=0.05 short-run notes\n- **Run:** `uv run python train.py --config-name pilot model.teach_scale=0.05 train.steps=2000 ...` on GPU 0, checkpoints in `artifacts/checkpoints/pilot_teach05/`.\n- **Training log:** `logs/pilot-teach05-20251114010549.json` (40 records, final loss 10.49, teach_signal_norm 7.8e‑3 at step 1950).\n- **Zeroshot (128 samples, memorize on):** `eval/zeroshot_pilot_teach05_step2000.json` → PIQA 0.453, HellaSwag 0.273, Winogrande 0.508, ARC-E 0.250, ARC-C 0.227, BoolQ 0.664, SIQA 0.289, CSQA 0.188, OBQA 0.180.\n- **NIAH:** `eval/niah_pilot_teach05_step2000.json` → 0.50 / 0.75 / 1.00 / 0.75 / 0.25 / 1.00 at 2k→65k.\n- **Continual:** `eval/continual_pilot_teach05_step2000.json` → CE ≈37.4 / 33.2 / 35.9 / 32.9 on RefinedWeb/Wiki/C4/RedPajama segments (expectedly high because the run saw only 2 k steps).\n\n### Teach-scale=0.15 short-run notes\n- **Run:** `uv run python train.py --config-name pilot model.teach_scale=0.15 train.steps=2000 ...` on GPU 1, checkpoints at `artifacts/checkpoints/pilot_teach15/step_{001000,002000}.pt`.\n- **Training log:** `logs/pilot-teach15-20251114012109.json` (40 records, final loss 8.70, ppl ≈6.0e3, teach_signal_norm ≈7.3e‑3).\n- **Zeroshot (128 samples, memorize on):** `eval/zeroshot_pilot_teach15_step2000.json` → PIQA 0.484, HellaSwag 0.258, Winogrande 0.461, ARC-E 0.203, ARC-C 0.219, BoolQ 0.336, SIQA 0.344, CSQA 0.211, OBQA 0.148.\n- **NIAH:** `eval/niah_pilot_teach15_step2000.json` → 0.75 / 0.75 / 0.75 / 0.50 / 0.25 / 0.50 (2k→65k).\n- **Continual:** `eval/continual_pilot_teach15_step2000.json` → CE ≈69.4 / 66.6 / 66.5 / 68.6 (substantially higher than baseline because the run barely saw data; will normalize once longer steps are run).\n\n### Teach-scale=0.05 long-run notes\n- **Run:** 25 k-step job (`pilot-teach05-long-20251114155521`) on GPU 1. Checkpoints under `artifacts/checkpoints/pilot_teach05_long/step_*.pt`.\n- **Training log:** `logs/pilot-teach05-long-20251114155521.json` (500 records; final loss ≈7.76, teach_sig_norm ≈7.3e‑3).\n- **Zeroshot:** `eval/zeroshot_pilot_teach05_long_step25000.json` → PIQA 0.508, HellaSwag 0.285, Winogrande 0.477, ARC-E 0.320, ARC-C 0.238, BoolQ 0.367, SIQA 0.328, CSQA 0.199, OBQA 0.145.\n- **NIAH:** `eval/niah_pilot_teach05_long_step25000.json` → 0.25 / 0.50 / 0.375 / 0.75 / 0.75 / 0.75.\n- **Continual:** `eval/continual_pilot_teach05_long_step25000.json` → CE ≈52.1 / 49.4 / 48.9 / 50.9 (much higher than baseline despite the long run).\n\n### Teach-scale=0.15 long-run notes\n- **Run:** 25 k-step job (`pilot-teach15-long-20251114185448`) on GPU 1. 
Checkpoints under `artifacts/checkpoints/pilot_teach15_long/step_*.pt`.\n- **Training log:** `logs/pilot-teach15-long-20251114185448.json` (500 records; final loss ≈7.76, teach_sig_norm ≈7.3e‑3).\n- **Zeroshot:** `eval/zeroshot_pilot_teach15_long_step25000.json` → PIQA 0.496, HellaSwag 0.305, Winogrande 0.500, ARC-E 0.301, ARC-C 0.238, BoolQ 0.367, SIQA 0.316, CSQA 0.176, OBQA 0.125.\n- **NIAH:** `eval/niah_pilot_teach15_long_step25000.json` → 0.75 / 0.625 / 0.375 / 0.75 / 0.50 / 0.75.\n- **Continual:** `eval/continual_pilot_teach15_long_step25000.json` → CE ≈7.91 / 7.63 / 7.56 / 7.79 (comparable to the HOPE baseline).\n\n### CMS chunk ablation – update_period=1\n- **Run:** `pilot-cms-nochunk-20251114232720` on GPU 1 (5 k steps) with all CMS levels set to `update_period=1`.\n- **Training log:** `logs/pilot-cms-nochunk-20251114232720.json` (100 records, final loss 8.65, teach_signal_norm ≈7.3e‑3).\n- **Zeroshot:** `eval/zeroshot_pilot_cms_nochunk_step5000.json` → PIQA 0.520, HellaSwag 0.277, Winogrande 0.473, ARC-E 0.320, ARC-C 0.242, BoolQ 0.633, SIQA 0.301, CSQA 0.191, OpenBookQA 0.148.\n- **NIAH:** `eval/niah_pilot_cms_nochunk_step5000.json` → 0.75 / 0.25 / 0.25 / 0.25 / 0.75 / 0.50.\n- **Continual:** `eval/continual_pilot_cms_nochunk_step5000.json` → CE ≈46.6 / 47.8 / 49.7 / 52.1 (forgetting worsens without chunk accumulation).\n\n### Self-modifier off (self_mod_lr=0)\n- **Run:** `pilot-selfmod-off-20251115132848` on GPU 1 (5 k steps) with `model.self_mod_lr=0`.\n- **Training log:** `logs/pilot-selfmod-off-20251115132848.json` (100 records, final loss 8.14, teach_signal_norm ≈7.3e‑3).\n- **Zeroshot:** `eval/zeroshot_pilot_selfmod_off_step5000.json` → PIQA 0.516, HellaSwag 0.266, Winogrande 0.465, ARC-E 0.289, ARC-C 0.207, BoolQ 0.633, SIQA 0.332, CSQA 0.164, OpenBookQA 0.164.\n- **NIAH:** `eval/niah_pilot_selfmod_off_step5000.json` → 0.75 / 0.75 / 0.50 / 0.75 / 0.25 / 0.75.\n- **Continual:** `eval/continual_pilot_selfmod_off_step5000.json` → CE ≈45.7 / 44.9 / 44.4 / 45.5 (self-modifier appears critical for continual learning even at short horizons).\n\n## 8. Upcoming experiments queue\n| ID | Variant | Command seed | Notes |\n|----|---------|--------------|-------|\n| Q1 | TITAN baseline (9k steps) | `uv run python train.py --config-name mid_titan_baseline ... train.steps=9000` | ✅ W&B `titan-short-20251112195149`; metrics stored as `eval/*_titan.json`. |\n| Q2 | Pilot long run (3 B tokens) | `tmux new -s pilot_full \"... train.steps=246667 train.checkpoint.save_interval=1000\"` | 🔄 Paused at step 246 667; release bundle now tracks step 230 000 (`artifacts/pilot_release/`). Resume after TITAN catches up. |\n| Q3 | Teach-scale ablation | `+model.teach_scale=0.05/0.15` (pilot config) | Run 2 k-step jobs to quantify stability vs accuracy. |\n| Q4 | CMS chunk toggle | `+model.cms_levels[].update_period=1` (Transformer-like) | Compare zero-shot/NIAH vs default chunking. |\n| Q5 | Muon vs AdamW | `optim.type=muon` vs `adamw` | Use 5 k-step runs, document speed/quality in `docs/experiments_report.md`. |\n| Q6 | TITAN long run (25 k steps) | `TMPDIR=/mnt/drive_4/tmp_titan UV_CACHE_DIR=/tmp/uv-cache uv run python train.py --config-name mid_titan_baseline ... train.steps=25000 train.checkpoint.save_interval=1000` | 🔄 Running on `cuda:0` (W&B `titan-long-20251113192738`); monitor `logs/titan_long.log` + wandb for checkpoints every 1 000 steps. |\n\nMark each queue item ✅/⏳/⚠️ as it progresses so we know which ablations have data ready for reporting.\n\n
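Before launching a queue item, it can help to preview how an override composes; a minimal sketch (illustrative only) using the same Hydra compose API as `scripts/checks/compliance_report.py`, assuming it runs from the repo root:\n\n```python\nfrom pathlib import Path\n\nfrom hydra import compose, initialize_config_dir\n\n# Resolve the pilot config with a queue override (here Q3's teach-scale change).\nconfig_dir = str(Path(\"configs\").resolve())\nwith initialize_config_dir(version_base=None, config_dir=config_dir):\n    cfg = compose(config_name=\"pilot\", overrides=[\"model.teach_scale=0.15\"])\nprint(cfg.model.teach_scale)\n```\n\n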
## 9. Phase 3 – Self-modifying Titans (paper HOPE scaffold)\n\nThe paper-defined `hope_selfmod` scaffold (`Self-modifying Titans → CMS`) has its own knobs and\nablation-ready configs. These are intended for **implementation validation** and small-scale\nexperiments (not paper-scale reproduction).\n\n| Config | Variant | Key override(s) |\n|--------|---------|-----------------|\n| `configs/hope/pilot_selfmod.yaml` | Pilot defaults | `model.block_variant=hope_selfmod`, `self_mod_chunk_size=8`, `self_mod_chunk_size_memory=64` |\n| `configs/ablations/selfmod_rank1_precond_off.yaml` | No DGD preconditioner | `model.self_mod_use_rank1_precond=false` |\n| `configs/ablations/selfmod_no_alpha.yaml` | No alpha/decay | `model.self_mod_use_alpha=false` |\n| `configs/ablations/selfmod_chunked_8_64.yaml` | Explicit chunking | `model.self_mod_chunk_size=8`, `model.self_mod_chunk_size_memory=64` |\n| `configs/ablations/selfmod_no_cms.yaml` | Selfmod-only | `model.cms_levels=[]` |\n| `configs/ablations/selfmod_momentum_on.yaml` | Momentum on | `model.self_mod_momentum=0.9` |\n| `configs/ablations/selfmod_momentum_off.yaml` | Momentum off | `model.self_mod_momentum=0.0` |\n\nThese configs require `src/nested_learning/training.py:build_model_from_cfg()` to plumb the self-mod\nfields through `ModelConfig`; this is covered by `tests/test_build_model_from_cfg_selfmod.py`.\n\n## 10. Baseline comparison (HOPE step 230k vs TITAN step 25k)\n| Metric | HOPE | TITAN | Notes |\n|--------|------|-------|-------|\n| PIQA / HellaSwag / Winogrande | 0.496 / 0.297 / 0.473 | 0.484 / 0.293 / 0.480 | `eval/zeroshot_pilot_step230000.json` vs `eval/zeroshot_titan_step25000.json`. |\n| ARC-E / ARC-C / BoolQ / SIQA / CSQA / OpenBookQA | 0.285 / 0.234 / 0.367 / 0.316 / 0.180 / 0.113 | 0.281 / 0.250 / 0.398 / 0.293 / 0.188 / 0.145 | Same zero-shot outputs as above. |\n| NIAH (2k / 4k / 8k / 16k / 32k / 65k) | 0.625 / 0.50 / 0.375 / 0.50 / 0.75 / 0.50 | 0.50 / 0.625 / 0.125 / 0.75 / 0.50 / 0.125 | `eval/niah_pilot_step230000.json` vs `eval/niah_titan_step25000.json`. |\n| Continual CE (RefinedWeb / Wiki / C4 / RedPajama) | 8.06 / 7.79 / 7.68 / 7.95 | 8.36 / 8.12 / 7.85 / 8.11 | `eval/continual_pilot_step230000.json` vs `eval/continual_titan_step25000.json`. |\n\nUse these values as the reference when logging ablations; refresh the table whenever a new HOPE or TITAN checkpoint is evaluated.\n\n### Additional snapshots (surprise-gated relaunch)\n- **HOPE pilot relaunch (step 477k):** `reports/checkpoints/pilot_relaunch_step477000.md` with eval outputs `eval/*_pilot.json`.\n- **TITAN long relaunch (step 32k):** `reports/checkpoints/titan_long_step32000.md` with eval outputs `eval/*_titan.json`.\n\nThese runs use `model.surprise_threshold=0.02` and therefore show ≈0 memorization deltas on short eval prompts (the gate did not trigger update events). Treat them as “pipeline / plumbing” validations rather than evidence of long-context advantages.\n"
  },
  {
    "path": "reports/cadence_mechanism_audit_smoke.json",
    "content": "{\n  \"ok\": true,\n  \"metric_prefix\": \"layer0.cms.cms_mid\",\n  \"log_path\": \"logs/mechanism_audit_smoke.json\",\n  \"flush_partial\": false,\n  \"total_tokens\": 8,\n  \"update_period\": 4,\n  \"expected\": {\n    \"updates_applied\": 2.0,\n    \"chunk_tokens\": 8.0,\n    \"tokens_flushed\": 0.0,\n    \"pending_tokens\": 0.0,\n    \"remainder_tokens\": 0.0\n  },\n  \"observed\": {\n    \"updates_applied\": 2.0,\n    \"chunk_tokens\": 8.0,\n    \"tokens_flushed\": 0.0,\n    \"pending_tokens\": 0.0\n  },\n  \"checks\": {\n    \"updates_applied\": true,\n    \"chunk_tokens\": true,\n    \"tokens_flushed\": true,\n    \"pending_tokens\": true\n  }\n}"
  },
  {
    "path": "reports/compliance_mechanism_audit_smoke.json",
    "content": "{\n  \"config\": \"configs/pilot.yaml\",\n  \"overall_ok\": true,\n  \"checks\": [\n    {\n      \"name\": \"strict_streaming_contract_observed\",\n      \"ok\": true,\n      \"detail\": \"train.strict_streaming_contract=False\"\n    },\n    {\n      \"name\": \"algorithm_mode_supported\",\n      \"ok\": true,\n      \"detail\": \"train.algorithm_mode='two_pass_stopgrad_updates'; allowed=['boundary_state_grad_through_write', 'two_pass_stopgrad_updates']\"\n    },\n    {\n      \"name\": \"variant_recorded\",\n      \"ok\": true,\n      \"detail\": \"model.block_variant='hope_hybrid'\"\n    },\n    {\n      \"name\": \"lm_head_embed_tied\",\n      \"ok\": true,\n      \"detail\": \"lm_head.weight aliases embed.weight\"\n    },\n    {\n      \"name\": \"boundary_requires_online_updates\",\n      \"ok\": true,\n      \"detail\": \"online_updates=True, online_boundary_targets=False\"\n    },\n    {\n      \"name\": \"carry_attention_requires_boundary\",\n      \"ok\": true,\n      \"detail\": \"online_carry_attention_cache requires online_boundary_targets (carry=False, boundary=False)\"\n    },\n    {\n      \"name\": \"fast_state_batch_semantics\",\n      \"ok\": true,\n      \"detail\": \"use_fast_state=False, data.batch_size=6\"\n    },\n    {\n      \"name\": \"cadence_report_ok\",\n      \"ok\": true,\n      \"detail\": \"cadence_report=reports/cadence_mechanism_audit_smoke.json\"\n    }\n  ],\n  \"cadence_report\": {\n    \"ok\": true,\n    \"metric_prefix\": \"layer0.cms.cms_mid\",\n    \"log_path\": \"logs/mechanism_audit_smoke.json\",\n    \"flush_partial\": false,\n    \"total_tokens\": 8,\n    \"update_period\": 4,\n    \"expected\": {\n      \"updates_applied\": 2.0,\n      \"chunk_tokens\": 8.0,\n      \"tokens_flushed\": 0.0,\n      \"pending_tokens\": 0.0,\n      \"remainder_tokens\": 0.0\n    },\n    \"observed\": {\n      \"updates_applied\": 2.0,\n      \"chunk_tokens\": 8.0,\n      \"tokens_flushed\": 0.0,\n      \"pending_tokens\": 0.0\n    },\n    \"checks\": {\n      \"updates_applied\": true,\n      \"chunk_tokens\": true,\n      \"tokens_flushed\": true,\n      \"pending_tokens\": true\n    }\n  }\n}"
  },
  {
    "path": "reports/compliance_summary_pilot.json",
    "content": "{\n  \"config\": \"configs/pilot.yaml\",\n  \"overall_ok\": true,\n  \"checks\": [\n    {\n      \"name\": \"strict_streaming_contract_observed\",\n      \"ok\": true,\n      \"detail\": \"train.strict_streaming_contract=False\"\n    },\n    {\n      \"name\": \"algorithm_mode_supported\",\n      \"ok\": true,\n      \"detail\": \"train.algorithm_mode='two_pass_stopgrad_updates'; allowed=['boundary_state_grad_through_write', 'two_pass_stopgrad_updates']\"\n    },\n    {\n      \"name\": \"variant_recorded\",\n      \"ok\": true,\n      \"detail\": \"model.block_variant='hope_hybrid'\"\n    },\n    {\n      \"name\": \"lm_head_embed_tied\",\n      \"ok\": true,\n      \"detail\": \"lm_head.weight aliases embed.weight\"\n    },\n    {\n      \"name\": \"boundary_requires_online_updates\",\n      \"ok\": true,\n      \"detail\": \"online_updates=True, online_boundary_targets=False\"\n    },\n    {\n      \"name\": \"carry_attention_requires_boundary\",\n      \"ok\": true,\n      \"detail\": \"online_carry_attention_cache requires online_boundary_targets (carry=False, boundary=False)\"\n    },\n    {\n      \"name\": \"fast_state_batch_semantics\",\n      \"ok\": true,\n      \"detail\": \"use_fast_state=False, data.batch_size=6\"\n    },\n    {\n      \"name\": \"cadence_report_ok\",\n      \"ok\": true,\n      \"detail\": \"cadence_report=reports/cadence_mechanism_audit_smoke.json\"\n    }\n  ],\n  \"cadence_report\": {\n    \"ok\": true,\n    \"metric_prefix\": \"layer0.cms.cms_mid\",\n    \"log_path\": \"logs/mechanism_audit_smoke.json\",\n    \"flush_partial\": false,\n    \"total_tokens\": 8,\n    \"update_period\": 4,\n    \"expected\": {\n      \"updates_applied\": 2.0,\n      \"chunk_tokens\": 8.0,\n      \"tokens_flushed\": 0.0,\n      \"pending_tokens\": 0.0,\n      \"remainder_tokens\": 0.0\n    },\n    \"observed\": {\n      \"updates_applied\": 2.0,\n      \"chunk_tokens\": 8.0,\n      \"tokens_flushed\": 0.0,\n      \"pending_tokens\": 0.0\n    },\n    \"checks\": {\n      \"updates_applied\": true,\n      \"chunk_tokens\": true,\n      \"tokens_flushed\": true,\n      \"pending_tokens\": true\n    }\n  }\n}"
  },
  {
    "path": "reports/compliance_summary_pilot_paper_faithful.json",
    "content": "{\n  \"config\": \"configs/pilot_paper_faithful.yaml\",\n  \"overall_ok\": true,\n  \"checks\": [\n    {\n      \"name\": \"strict_streaming_contract_observed\",\n      \"ok\": true,\n      \"detail\": \"train.strict_streaming_contract=True\"\n    },\n    {\n      \"name\": \"algorithm_mode_supported\",\n      \"ok\": true,\n      \"detail\": \"train.algorithm_mode='two_pass_stopgrad_updates'; allowed=['boundary_state_grad_through_write', 'two_pass_stopgrad_updates']\"\n    },\n    {\n      \"name\": \"strict_variant_is_paper_defined\",\n      \"ok\": true,\n      \"detail\": \"model.block_variant='hope_attention'; allowed=['hope_attention', 'hope_selfmod']\"\n    },\n    {\n      \"name\": \"lm_head_embed_tied\",\n      \"ok\": true,\n      \"detail\": \"lm_head.weight aliases embed.weight\"\n    },\n    {\n      \"name\": \"boundary_requires_online_updates\",\n      \"ok\": true,\n      \"detail\": \"online_updates=True, online_boundary_targets=True\"\n    },\n    {\n      \"name\": \"carry_attention_requires_boundary\",\n      \"ok\": true,\n      \"detail\": \"online_carry_attention_cache requires online_boundary_targets (carry=True, boundary=True)\"\n    },\n    {\n      \"name\": \"fast_state_batch_semantics\",\n      \"ok\": true,\n      \"detail\": \"use_fast_state=True, data.batch_size=1\"\n    }\n  ],\n  \"cadence_report\": null\n}"
  },
  {
    "path": "reports/next_backlog_scoped.md",
    "content": "# Next Backlog (Scoped, Non-Feature-Creep)\n\n1. Stabilize boundary-state mode for longer single-GPU runs (memory profiling + guardrail docs).\n2. Add optional `require_strict` compliance gate job in CI for paper-faithful config only.\n3. Expand packaging tests to assert required sidecars for both HOPE and TITAN checkpoints.\n4. Add deterministic mini-run harness that stores and compares two short metric traces.\n5. Add doc automation that validates CLI flags mentioned in README against `--help` output snapshots.\n6. Keep result-scale work explicitly optional (no default 100B+ reproduction commitments).\n"
  },
  {
    "path": "reports/security_release_gate.md",
    "content": "# Security / Release Gate Log\n\nExecuted at: `2026-02-24T00:40:32Z` (UTC)\n\n## Commands Run\n- `rg -n --hidden --glob '!.git' --glob '!*.pt' --glob '!*.bin' \"(AKIA[0-9A-Z]{16}|ghp_[A-Za-z0-9]{30,}|hf_[A-Za-z0-9]{20,}|BEGIN PRIVATE KEY|SECRET_KEY|API_KEY|PASSWORD=|token=)\" .`\n- `git ls-files | rg -n \"(\\\\.pt$|\\\\.ckpt$|\\\\.safetensors$|\\\\.npy$|\\\\.zip$|git\\\\.env|docs_tmp|artifacts/|data/raw/)\"`\n- `git ls-files -s | awk '{print $4}' | xargs -r du -h | sort -h | tail -n 30`\n- `git check-ignore -v artifacts/checkpoints/pilot/step_000001.pt logs/mechanism_audit_smoke.json docs_tmp/placeholder.txt data/raw/example.txt git.env docs/POSTS.md`\n- `bash scripts/checks/check_git_tracked_sizes.sh`\n\n## Findings\n- Secret-pattern scan: no credentials/tokens detected in tracked content.\n- Tracked artifact scan: no forbidden checkpoint/artifact extensions are tracked.\n- `.gitignore` coverage confirmed for:\n  - `artifacts/`\n  - `logs/`\n  - `data/`\n  - `docs_tmp/`\n  - `git.env`\n  - `docs/POSTS.md`\n- Largest tracked files are paper/reference artifacts and lockfile; all below size gate threshold (`5 MiB`) enforced by `scripts/checks/check_git_tracked_sizes.sh`.\n\n## Remediation / Guardrails Added\n- Added `scripts/checks/check_git_tracked_sizes.sh`:\n  - fails CI for forbidden tracked artifact extensions (`.pt`, `.ckpt`, `.safetensors`, `.npy`, `.npz`, `.zip`)\n  - fails CI for tracked files above `MAX_TRACKED_FILE_BYTES` (default `5 MiB`)\n- Added CI step in `.github/workflows/ci.yml` to run size/artifact gate.\n- Added test `tests/test_git_tracked_sizes_check.py`.\n- Added package-script test `tests/test_package_release_script.py` validating manifest train flags and that raw data is not included in release bundle.\n\n## Status\n- Gate result: `PASS`\n"
  },
  {
    "path": "reports/sprint_completion_report.md",
    "content": "# Sprint Completion Report (Mechanism Fidelity Focus)\n\nDate: `2026-02-24`\n\n## What Closed This Sprint\n\n- Boundary-state mechanism path now has:\n  - training-loop coverage\n  - explicit startup warning\n  - fail-fast constraints\n  - config-level assertions\n- Paper-faithful config behavior is explicit (`online_updates=true`) and tested.\n- Compliance/traceability improved:\n  - algorithm/online flags persisted in checkpoint metadata\n  - release manifest includes train flags\n  - compliance summaries generated for pilot configs\n- Docs and usability hardening:\n  - markdown link+anchor validation\n  - data split fallback deterministic order (`train -> validation -> test -> first available`)\n  - data script `--help` checks automated in CI\n- Security/release hygiene:\n  - tracked-file size and forbidden-extension gate in CI\n  - release packaging exclusion behavior tested\n  - explicit security gate log recorded\n\n## Residual Risks\n\n1. Boundary-state mode remains experimental and single-process only.\n2. Distributed mechanism-auditing parity (online + per-layer + boundary/cache semantics) remains deferred.\n3. Full paper-scale result reproduction remains compute-limited and out of current scope.\n4. Some warnings in CPU tests come from upstream `torch` pin-memory deprecations (non-blocking for correctness).\n\n## Sprint Definition of Done\n\n- Mechanism-level compliance claims are aligned with code/tests.\n- CI catches key regressions in docs references, data-script usability, and tracked artifact hygiene.\n- Reproducibility path is explicit (commands, configs, compliance outputs).\n"
  },
  {
    "path": "reports/stage2_smoke.md",
    "content": "# Stage 2 Smoke Artifact Summary\n\n## Hardware\n- 2× NVIDIA RTX 6000 Ada (49 GB VRAM each)\n- PyTorch 2.9.0 (LTS), CUDA 12.4\n- Python 3.12 via `uv`\n\n## Data Prep\n```bash\nuv run bash scripts/data/run_sample.sh\n# full pipeline (tmux recommended)\nRW_LIMIT=20000 WIKI_LIMIT=10000 C4_LIMIT=8000 RPJ_LIMIT=8000 CODE_LIMIT=8000 \\\n  tmux new -s data_full 'uv run bash scripts/data/run_full.sh'\n```\n\nKey artifacts:\n- Tokenizer: `artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model`\n- Sample shards: `data/shards/*_filtered`\n- Full shards: `data/shards/*_full`\n- Stats: `data/mixtures/refinedweb_mix_filtered_shards.json`, `data/mixtures/refinedweb_mix_full_shards.json`\n\n## Dual-GPU Smoke Run\n```bash\nuv run torchrun --nproc_per_node=2 train_dist.py --config-name mid_stage2_smoke\n```\n- Checkpoint: `artifacts/checkpoints/mid_stage2_smoke/step_000060.pt`\n- Log: `logs/mid_stage2_smoke.json`\n\n### Evaluations\n```bash\nuv run python scripts/eval/zeroshot.py --config configs/mid_stage2_smoke.yaml \\\n  --checkpoint artifacts/checkpoints/mid_stage2_smoke/step_000060.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --tasks piqa,hellaswag,winogrande,arc_easy,arc_challenge,boolq,siqa \\\n  --max-samples 64 --device cuda:1\n\nuv run python scripts/eval/niah.py --config configs/mid_stage2_smoke.yaml \\\n  --checkpoint artifacts/checkpoints/mid_stage2_smoke/step_000060.pt \\\n  --tokenizer-path artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model \\\n  --context-lengths 2048 --samples-per-length 5 --device cuda:1\n\nuv run python scripts/eval/continual.py --config configs/mid_stage2_smoke.yaml \\\n  --checkpoints artifacts/checkpoints/mid_stage2_smoke/step_000060.pt \\\n  --segments-yaml configs/data/continual_segments_sample.yaml \\\n  --batch-size 4 --max-batches 5 --device cuda:1\n```\n- Zero-shot metrics: `eval/zeroshot_mid_stage2_smoke.json`\n- NIAH: `eval/niah_mid_stage2_smoke.json`\n- Continual: `eval/continual_mid_stage2_smoke.json`\n\n## Mid-Scale Reference Run\n```bash\ntmux new -s mid_stage2_run \"uv run torchrun --nproc_per_node=2 train_dist.py --config-name mid_stage2\"\n```\n- Checkpoint: `artifacts/checkpoints/mid_stage2/step_000100.pt`\n- Log: `logs/mid_stage2.json`\n- Eval summaries: `eval/zeroshot_mid_stage2.json`, `eval/niah_mid_stage2.json`, `eval/continual_mid_stage2.json`\n\n## Teach-scale Sweep (single GPU, batch=4)\nReference runs for teach_scale ∈ {0.05, 0.10, 0.20}:\n```bash\nuv run python train.py --config-name mid_stage2 \\\n  model.teach_scale=0.10 model.teach_clip=5.0 \\\n  data.batch_size=4 train.steps=40 train.device=cuda:1 \\\n  logging.path=logs/mid_stage2_single_ts10.json \\\n  train.checkpoint.dir=artifacts/checkpoints/mid_stage2_single_ts10\n```\nLogs and checkpoints are stored under `logs/mid_stage2_single_ts*.json`, `artifacts/checkpoints/mid_stage2_single_ts*/`.\n\n## Extended Single-GPU Run (teach_scale=0.10)\n```bash\ntmux new -s mid_stage2_ts10_single \"uv run python train.py --config-name mid_stage2 \\\n  model.teach_scale=0.10 model.teach_clip=4.0 \\\n  model.teach_schedule.warmup_steps=60 \\\n  model.teach_schedule.decay_start=140 \\\n  model.teach_schedule.decay_duration=80 \\\n  data.batch_size=4 optim.lr=1e-5 train.device=cuda:1 \\\n  train.steps=220 train.log_interval=20 \\\n  logging.path=logs/mid_stage2_ts10_single220_schedD.json \\\n  train.checkpoint.dir=artifacts/checkpoints/mid_stage2_ts10_single220_schedD\"\n```\n- Checkpoint: 
## Extended Single-GPU Run (teach_scale=0.10)\n```bash\ntmux new -s mid_stage2_ts10_single \"uv run python train.py --config-name mid_stage2 \\\n  model.teach_scale=0.10 model.teach_clip=4.0 \\\n  model.teach_schedule.warmup_steps=60 \\\n  model.teach_schedule.decay_start=140 \\\n  model.teach_schedule.decay_duration=80 \\\n  data.batch_size=4 optim.lr=1e-5 train.device=cuda:1 \\\n  train.steps=220 train.log_interval=20 \\\n  logging.path=logs/mid_stage2_ts10_single220_schedD.json \\\n  train.checkpoint.dir=artifacts/checkpoints/mid_stage2_ts10_single220_schedD\"\n```\n- Checkpoint: `artifacts/checkpoints/mid_stage2_ts10_single220_schedD/step_000220.pt`\n- Log: `logs/mid_stage2_ts10_single220_schedD.json`\n- Evaluations: `eval/zeroshot_mid_stage2_ts10_single220_schedD.json`, `eval/niah_mid_stage2_ts10_single220_schedD.json`, `eval/continual_mid_stage2_ts10_single220_schedD.json`\n\n## TITAN Baseline (single GPU)\n```bash\nuv run python train.py --config-name mid_titan_baseline\n```\n- Checkpoint: `artifacts/checkpoints/mid_titan_baseline/step_000200.pt`\n- Log: `logs/mid_titan_baseline.json`\n- Evaluations: `eval/zeroshot_mid_titan_baseline.json`, `eval/niah_mid_titan_baseline.json`, `eval/continual_mid_titan_baseline.json`\n\nThis document should accompany the release tag so others can reproduce the exact smoke workflow in a few commands.\n"
  },
  {
    "path": "scripts/__init__.py",
    "content": "# Makes `scripts` a package for intra-eval imports.\n"
  },
  {
    "path": "scripts/checkpoint/verify.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\nimport typer\n\nfrom nested_learning.training import verify_checkpoint_integrity\n\napp = typer.Typer(help=\"Verify checkpoint metadata hashes, config, and RNG sidecars.\")\n\n\n@app.command()\ndef main(\n    checkpoint: Path = typer.Option(..., help=\"Path to checkpoint .pt file.\"),\n) -> None:\n    metadata = verify_checkpoint_integrity(checkpoint)\n    typer.echo(f\"[verify] {checkpoint} OK (step {metadata.get('step')})\")\n    typer.echo(json.dumps(metadata, indent=2))\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "scripts/checks/check_data_script_help.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\nROOT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")/../..\" && pwd)\"\ncd \"${ROOT_DIR}\"\n\ntmp_out=\"$(mktemp)\"\ntrap 'rm -f \"${tmp_out}\"' EXIT\n\nfail=0\n\nfor script in scripts/data/*.py; do\n  base=\"$(basename \"${script}\")\"\n  if [[ \"${base}\" == \"__init__.py\" ]]; then\n    continue\n  fi\n  if ! uv run python \"${script}\" --help >\"${tmp_out}\" 2>&1; then\n    echo \"[help-check] FAIL ${script}\"\n    cat \"${tmp_out}\"\n    fail=1\n  fi\ndone\n\nfor script in scripts/data/*.sh; do\n  if ! bash \"${script}\" --help >\"${tmp_out}\" 2>&1; then\n    echo \"[help-check] FAIL ${script}\"\n    cat \"${tmp_out}\"\n    fail=1\n  fi\ndone\n\nif [[ \"${fail}\" -ne 0 ]]; then\n  exit 1\nfi\n\necho \"[help-check] OK scripts/data --help\"\n"
  },
  {
    "path": "scripts/checks/check_git_tracked_sizes.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\nMAX_BYTES=\"${MAX_TRACKED_FILE_BYTES:-5242880}\"  # 5 MiB default\nFORBIDDEN_EXT_REGEX='(\\.pt|\\.ckpt|\\.safetensors|\\.npy|\\.npz|\\.zip)$'\n\nfail=0\n\nwhile IFS= read -r path; do\n  [[ -f \"${path}\" ]] || continue\n  if [[ \"${path}\" =~ ${FORBIDDEN_EXT_REGEX} ]]; then\n    echo \"[size-check] forbidden tracked artifact extension: ${path}\"\n    fail=1\n  fi\n  size=$(wc -c < \"${path}\")\n  if (( size > MAX_BYTES )); then\n    echo \"[size-check] tracked file exceeds ${MAX_BYTES} bytes: ${path} (${size})\"\n    fail=1\n  fi\ndone < <(git ls-files)\n\nif [[ \"${fail}\" -ne 0 ]]; then\n  exit 1\nfi\n\necho \"[size-check] OK (max_bytes=${MAX_BYTES})\"\n"
  },
  {
    "path": "scripts/checks/check_readme_commands.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\n# Keep README's core CLI guidance executable in CI.\nuv run nl --help >/dev/null\nuv run nl doctor --json >/dev/null\nuv run nl smoke --help >/dev/null\nuv run python -m nested_learning --help >/dev/null\n\necho \"README command smoke checks passed.\"\n\n"
  },
  {
    "path": "scripts/checks/compliance_report.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom dataclasses import asdict, dataclass\nfrom pathlib import Path\nfrom typing import Any\n\nimport typer\nfrom hydra import compose, initialize_config_dir\nfrom hydra.core.global_hydra import GlobalHydra\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.training import build_model_from_cfg, unwrap_config\n\napp = typer.Typer(\n    add_completion=False,\n    help=\"Aggregate implementation compliance checks into a machine-readable report.\",\n)\n\n\ndef _load_resolved_config(config_path: Path):\n    cfg = OmegaConf.load(config_path)\n    cfg = unwrap_config(cfg)\n    model_cfg = cfg.get(\"model\")\n    needs_compose = bool(\n        cfg.get(\"defaults\") is not None\n        and model_cfg is not None\n        and model_cfg.get(\"titan_level\") is None\n    )\n    if not needs_compose:\n        return cfg\n    config_dir = config_path.resolve().parent\n    config_name = config_path.stem\n    GlobalHydra.instance().clear()\n    with initialize_config_dir(version_base=None, config_dir=str(config_dir)):\n        composed = compose(config_name=config_name)\n    return unwrap_config(composed)\n\n\n@dataclass\nclass CheckResult:\n    name: str\n    ok: bool\n    detail: str\n\n\ndef _append(results: list[CheckResult], name: str, ok: bool, detail: str) -> None:\n    results.append(CheckResult(name=name, ok=ok, detail=detail))\n\n\n@app.command()\ndef main(\n    config: Path = typer.Option(..., help=\"Hydra config path.\"),\n    output: Path = typer.Option(\n        Path(\"eval/compliance_report.json\"),\n        help=\"Path to write the compliance JSON.\",\n    ),\n    require_strict: bool = typer.Option(\n        False,\n        help=\"Require train.strict_streaming_contract=true.\",\n    ),\n    cadence_report: Path | None = typer.Option(\n        None,\n        help=\"Optional JSON emitted by scripts/checks/verify_update_cadence.py.\",\n    ),\n) -> None:\n    cfg = _load_resolved_config(config)\n    results: list[CheckResult] = []\n\n    strict = bool(cfg.train.get(\"strict_streaming_contract\", False))\n    algorithm_mode = str(cfg.train.get(\"algorithm_mode\", \"two_pass_stopgrad_updates\")).strip()\n    allowed_algorithm_modes = {\"two_pass_stopgrad_updates\", \"boundary_state_grad_through_write\"}\n    if require_strict:\n        _append(\n            results,\n            \"strict_streaming_contract_enabled\",\n            strict,\n            f\"train.strict_streaming_contract={strict}\",\n        )\n    else:\n        _append(\n            results,\n            \"strict_streaming_contract_observed\",\n            True,\n            f\"train.strict_streaming_contract={strict}\",\n        )\n    _append(\n        results,\n        \"algorithm_mode_supported\",\n        algorithm_mode in allowed_algorithm_modes,\n        f\"train.algorithm_mode={algorithm_mode!r}; allowed={sorted(allowed_algorithm_modes)}\",\n    )\n\n    block_variant = str(cfg.model.get(\"block_variant\", \"hope_hybrid\")).strip().lower()\n    allowed_variants = {\"hope_attention\", \"hope_selfmod\"}\n    if strict:\n        _append(\n            results,\n            \"strict_variant_is_paper_defined\",\n            block_variant in allowed_variants,\n            f\"model.block_variant={block_variant!r}; allowed={sorted(allowed_variants)}\",\n        )\n    else:\n        _append(\n            results,\n            \"variant_recorded\",\n            True,\n            f\"model.block_variant={block_variant!r}\",\n        
)\n\n    model = build_model_from_cfg(cfg.model)\n    lm_head = getattr(model, \"lm_head\", None)\n    embed = getattr(model, \"embed\", None)\n    tied = False\n    if lm_head is not None and embed is not None:\n        lm_weight = getattr(lm_head, \"weight\", None)\n        emb_weight = getattr(embed, \"weight\", None)\n        if lm_weight is not None and emb_weight is not None:\n            tied = lm_weight.data_ptr() == emb_weight.data_ptr()\n    _append(\n        results,\n        \"lm_head_embed_tied\",\n        tied,\n        \"lm_head.weight aliases embed.weight\" if tied else \"weights are not tied\",\n    )\n\n    online_updates = bool(cfg.train.get(\"online_updates\", False))\n    boundary = bool(cfg.train.get(\"online_boundary_targets\", False))\n    carry_attn = bool(cfg.train.get(\"online_carry_attention_cache\", False))\n    per_layer_teach = bool(cfg.train.get(\"per_layer_teach_signal\", False))\n    use_fast_state = bool(cfg.train.get(\"use_fast_state\", False))\n    if algorithm_mode == \"boundary_state_grad_through_write\":\n        boundary_mode_ok = online_updates and per_layer_teach and use_fast_state\n        _append(\n            results,\n            \"boundary_algorithm_mode_constraints\",\n            boundary_mode_ok,\n            (\n                \"boundary mode requires online_updates, per_layer_teach_signal, and use_fast_state \"\n                f\"(online_updates={online_updates}, per_layer_teach_signal={per_layer_teach}, \"\n                f\"use_fast_state={use_fast_state})\"\n            ),\n        )\n    _append(\n        results,\n        \"boundary_requires_online_updates\",\n        (not boundary) or online_updates,\n        f\"online_updates={online_updates}, online_boundary_targets={boundary}\",\n    )\n    _append(\n        results,\n        \"carry_attention_requires_boundary\",\n        (not carry_attn) or boundary,\n        (\n            \"online_carry_attention_cache requires online_boundary_targets\"\n            f\" (carry={carry_attn}, boundary={boundary})\"\n        ),\n    )\n\n    batch_size = int(cfg.data.get(\"batch_size\", 1))\n    fast_state_semantics_ok = (not use_fast_state) or batch_size <= 1\n    _append(\n        results,\n        \"fast_state_batch_semantics\",\n        fast_state_semantics_ok,\n        f\"use_fast_state={use_fast_state}, data.batch_size={batch_size}\",\n    )\n\n    cadence_payload: dict[str, Any] | None = None\n    if cadence_report is not None:\n        cadence_payload = json.loads(cadence_report.read_text())\n        cadence_ok = bool(cadence_payload.get(\"ok\", False))\n        _append(\n            results,\n            \"cadence_report_ok\",\n            cadence_ok,\n            f\"cadence_report={cadence_report}\",\n        )\n\n    summary = {\n        \"config\": str(config),\n        \"overall_ok\": all(item.ok for item in results),\n        \"checks\": [asdict(item) for item in results],\n        \"cadence_report\": cadence_payload,\n    }\n    output.parent.mkdir(parents=True, exist_ok=True)\n    output.write_text(json.dumps(summary, indent=2))\n    print(json.dumps(summary, indent=2))\n    if not summary[\"overall_ok\"]:\n        raise typer.Exit(code=1)\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "scripts/checks/run_fidelity_ci_subset.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\nexport UV_LINK_MODE=\"${UV_LINK_MODE:-copy}\"\nexport UV_CACHE_DIR=\"${UV_CACHE_DIR:-/tmp/uv-cache}\"\n\nuv run python scripts/checks/verify_docs_refs.py\nbash scripts/checks/check_git_tracked_sizes.sh\nbash scripts/checks/check_data_script_help.sh\n\nuv run pytest \\\n  tests/test_algorithm_mode_grad.py \\\n  tests/test_boundary_state_mode.py \\\n  tests/test_attention_cache.py \\\n  tests/test_teach_signal.py \\\n  tests/test_cms.py \\\n  tests/test_cms_cross_call.py \\\n  tests/test_cms_flush_partial.py \\\n  tests/test_online_chunking.py \\\n  tests/test_surprise_override.py \\\n  tests/test_model_streaming_cadence.py \\\n  tests/test_verify_update_cadence.py \\\n  tests/test_eval_state.py \\\n  tests/test_optim.py \\\n  tests/test_distributed_fail_fast.py \\\n  tests/test_fast_state_batch_semantics.py \\\n  tests/test_strict_streaming_contract.py \\\n  tests/test_tied_weight_guard.py \\\n  tests/test_verify_docs_refs.py \\\n  tests/test_paper_faithful_configs.py \\\n  tests/test_compliance_report.py \\\n  tests/test_compile_toggle.py\n\nuv run python scripts/checks/compliance_report.py \\\n  --config configs/pilot.yaml \\\n  --output /tmp/compliance_report_ci.json\n\nbash scripts/run_mechanism_audit_smoke.sh\n"
  },
  {
    "path": "scripts/checks/tokenizer_coverage_guard.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\nfrom typing import Optional\n\nimport typer\n\nfrom nested_learning.tokenizer_coverage import compute_tokenizer_coverage_stats\n\napp = typer.Typer(\n    add_completion=False,\n    help=\"Regress coverage stats against a recorded baseline to catch tokenizer drift.\",\n)\n\n\n@app.command()\ndef main(\n    baseline: Path = typer.Option(\n        ...,\n        help=\"Reference JSON produced by scripts/data/check_tokenizer_coverage.py.\",\n    ),\n    tokenizer_path: Path = typer.Option(..., help=\"SentencePiece tokenizer to evaluate.\"),\n    sample_file: Path = typer.Option(..., help=\"Representative text sample.\"),\n    max_lines: int = typer.Option(10_000, help=\"Maximum lines to consume.\"),\n    avg_tokens_tolerance: float = typer.Option(\n        0.05,\n        help=\"Allowed increase in avg tokens per word before failing.\",\n    ),\n    single_token_drop_tolerance: float = typer.Option(\n        0.02,\n        help=\"Allowed decrease in pct_single_token_words before failing.\",\n    ),\n    two_token_drop_tolerance: float = typer.Option(\n        0.02,\n        help=\"Allowed decrease in pct_two_or_less_tokens_words before failing.\",\n    ),\n    output: Optional[Path] = typer.Option(\n        None,\n        help=\"Optional path to write the freshly computed coverage JSON.\",\n    ),\n) -> None:\n    if not baseline.exists():\n        raise typer.BadParameter(f\"Baseline JSON {baseline} was not found.\")\n    baseline_stats = json.loads(baseline.read_text())\n    current_stats = compute_tokenizer_coverage_stats(\n        tokenizer_path, sample_file, max_lines=max_lines\n    )\n    violations: list[str] = []\n\n    delta_avg = current_stats[\"avg_tokens_per_word\"] - baseline_stats[\"avg_tokens_per_word\"]\n    if delta_avg > avg_tokens_tolerance:\n        violations.append(\n            f\"avg_tokens_per_word regressed by {delta_avg:.4f} (limit {avg_tokens_tolerance:.4f}).\"\n        )\n\n    delta_single = (\n        baseline_stats[\"pct_single_token_words\"] - current_stats[\"pct_single_token_words\"]\n    )\n    if delta_single > single_token_drop_tolerance:\n        violations.append(\n            f\"pct_single_token_words dropped by {delta_single:.4f} \"\n            f\"(limit {single_token_drop_tolerance:.4f}).\"\n        )\n\n    delta_two = (\n        baseline_stats[\"pct_two_or_less_tokens_words\"]\n        - current_stats[\"pct_two_or_less_tokens_words\"]\n    )\n    if delta_two > two_token_drop_tolerance:\n        violations.append(\n            f\"pct_two_or_less_tokens_words dropped by {delta_two:.4f} \"\n            f\"(limit {two_token_drop_tolerance:.4f}).\"\n        )\n\n    payload = json.dumps(current_stats, indent=2)\n    typer.echo(\"# Tokenizer coverage guard\")\n    typer.echo(f\"- Baseline: {baseline}\")\n    typer.echo(f\"- Tokenizer: {tokenizer_path}\")\n    typer.echo(f\"- Sample: {sample_file}\")\n    typer.echo(payload)\n\n    if output:\n        output.parent.mkdir(parents=True, exist_ok=True)\n        output.write_text(payload)\n\n    if violations:\n        typer.echo(\"Guard failed:\")\n        for violation in violations:\n            typer.echo(f\"  - {violation}\")\n        raise typer.Exit(code=1)\n\n    typer.echo(\"Guard passed: tokenizer coverage within tolerance.\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "scripts/checks/verify_docs_refs.py",
    "content": "#!/usr/bin/env python3\nfrom __future__ import annotations\n\nimport argparse\nimport json\nimport re\nfrom pathlib import Path\nfrom typing import Iterable\n\nDEFAULT_DOCS = (\n    \"README.md\",\n    \"docs/PAPER_COMPLIANCE.md\",\n    \"docs/STREAMING_CONTRACT.md\",\n    \"docs/IMPLEMENTATION_STATUS.md\",\n    \"docs/release_checklist.md\",\n    \"docs/BUG_REPORT_CHECKLIST.md\",\n)\n\nPATH_PREFIXES = (\n    \"src/\",\n    \"scripts/\",\n    \"tests/\",\n    \"configs/\",\n    \"docs/\",\n    \".github/\",\n)\n\nSCHEME_PREFIXES = (\"http://\", \"https://\", \"mailto:\")\n\n\ndef _iter_code_spans(text: str) -> Iterable[str]:\n    for match in re.finditer(r\"`([^`\\n]+)`\", text):\n        yield match.group(1)\n\n\ndef _iter_link_targets(text: str) -> Iterable[str]:\n    for match in re.finditer(r\"\\[[^\\]]+\\]\\(([^)]+)\\)\", text):\n        yield match.group(1)\n\n\ndef _normalize_path_candidate(token: str) -> str | None:\n    token = token.strip().strip(\"`\")\n    token = token.strip(\"()[]{}'\\\".,;:\")\n    if not token:\n        return None\n    if token.startswith(SCHEME_PREFIXES):\n        return None\n    if token.startswith((\"-\", \"--\")):\n        return None\n    if any(ch in token for ch in (\"<\", \">\", \"{\", \"}\", \"*\", \"|\", \"$\")):\n        return None\n    token = re.sub(r\":\\d+(?::\\d+)?$\", \"\", token)\n    if \"#\" in token and not token.startswith(\"#\"):\n        token = token.split(\"#\", 1)[0]\n    if token.startswith(\"./\"):\n        token = token[2:]\n    if token.startswith(\"../\"):\n        return None\n    if token == \"README.md\":\n        return token\n    if \"/\" not in token:\n        return None\n    if not token.startswith(PATH_PREFIXES):\n        return None\n    return token\n\n\ndef parse_referenced_paths(doc_text: str) -> set[str]:\n    refs: set[str] = set()\n    for span in _iter_code_spans(doc_text):\n        for piece in span.split():\n            normalized = _normalize_path_candidate(piece)\n            if normalized is not None:\n                refs.add(normalized)\n    for target in _iter_link_targets(doc_text):\n        normalized = _normalize_path_candidate(target)\n        if normalized is not None:\n            refs.add(normalized)\n    return refs\n\n\ndef _slugify_heading(heading: str) -> str:\n    slug = heading.strip().lower()\n    slug = re.sub(r\"[`*_~]\", \"\", slug)\n    slug = re.sub(r\"[^\\w\\s-]\", \"\", slug)\n    slug = re.sub(r\"\\s+\", \"-\", slug)\n    slug = re.sub(r\"-{2,}\", \"-\", slug)\n    return slug.strip(\"-\")\n\n\ndef _extract_markdown_anchors(path: Path) -> set[str]:\n    anchors: set[str] = set()\n    counts: dict[str, int] = {}\n    for line in path.read_text(encoding=\"utf-8\").splitlines():\n        match = re.match(r\"^\\s{0,3}#{1,6}\\s+(.*)\\s*$\", line)\n        if not match:\n            continue\n        base = _slugify_heading(match.group(1))\n        if not base:\n            continue\n        idx = counts.get(base, 0)\n        counts[base] = idx + 1\n        anchor = base if idx == 0 else f\"{base}-{idx}\"\n        anchors.add(anchor)\n    return anchors\n\n\ndef parse_anchor_references(doc_text: str) -> list[tuple[str, str]]:\n    refs: list[tuple[str, str]] = []\n    for target in _iter_link_targets(doc_text):\n        cleaned = target.strip()\n        if cleaned.startswith(SCHEME_PREFIXES) or cleaned.startswith(\"#\"):\n            continue\n        if \"#\" not in cleaned:\n            continue\n        path_part, anchor = cleaned.split(\"#\", 1)\n        if not 
anchor:\n            continue\n        normalized = _normalize_path_candidate(path_part)\n        if normalized is None:\n            continue\n        refs.append((normalized, anchor.strip()))\n    return refs\n\n\ndef verify_docs_refs(\n    *,\n    repo_root: Path,\n    docs: list[Path],\n) -> tuple[dict[str, list[str]], dict[str, list[str]]]:\n    missing: dict[str, list[str]] = {}\n    missing_anchors: dict[str, list[str]] = {}\n    anchor_cache: dict[Path, set[str]] = {}\n    for doc in docs:\n        text = doc.read_text(encoding=\"utf-8\")\n        refs = sorted(parse_referenced_paths(text))\n        missing_for_doc = [ref for ref in refs if not (repo_root / ref).exists()]\n        if missing_for_doc:\n            missing[str(doc)] = missing_for_doc\n        bad_anchors: list[str] = []\n        for ref_path, anchor in parse_anchor_references(text):\n            candidate = repo_root / ref_path\n            if not candidate.exists() or candidate.suffix.lower() != \".md\":\n                continue\n            anchors = anchor_cache.get(candidate)\n            if anchors is None:\n                anchors = _extract_markdown_anchors(candidate)\n                anchor_cache[candidate] = anchors\n            if anchor not in anchors:\n                bad_anchors.append(f\"{ref_path}#{anchor}\")\n        if bad_anchors:\n            missing_anchors[str(doc)] = sorted(set(bad_anchors))\n    return missing, missing_anchors\n\n\ndef main() -> int:\n    parser = argparse.ArgumentParser(\n        description=\"Verify that code/documentation references in docs resolve to existing files.\"\n    )\n    parser.add_argument(\n        \"--docs\",\n        nargs=\"+\",\n        default=list(DEFAULT_DOCS),\n        help=\"Docs to scan (repo-relative paths).\",\n    )\n    parser.add_argument(\n        \"--output\",\n        type=Path,\n        default=None,\n        help=\"Optional JSON report output path.\",\n    )\n    args = parser.parse_args()\n\n    repo_root = Path(__file__).resolve().parents[2]\n    docs = [repo_root / Path(p) for p in args.docs]\n    missing, missing_anchors = verify_docs_refs(repo_root=repo_root, docs=docs)\n    payload = {\n        \"ok\": len(missing) == 0 and len(missing_anchors) == 0,\n        \"docs_checked\": [str(d.relative_to(repo_root)) for d in docs],\n        \"missing\": missing,\n        \"missing_anchors\": missing_anchors,\n    }\n\n    if args.output is not None:\n        args.output.parent.mkdir(parents=True, exist_ok=True)\n        args.output.write_text(json.dumps(payload, indent=2), encoding=\"utf-8\")\n\n    if missing or missing_anchors:\n        print(json.dumps(payload, indent=2))\n        return 1\n\n    print(\n        json.dumps(\n            {\n                \"ok\": True,\n                \"docs_checked\": payload[\"docs_checked\"],\n                \"message\": \"all referenced repo paths and markdown anchors exist\",\n            },\n            indent=2,\n        )\n    )\n    return 0\n\n\nif __name__ == \"__main__\":\n    raise SystemExit(main())\n"
  },
  {
    "path": "scripts/checks/verify_update_cadence.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport argparse\nimport json\nfrom pathlib import Path\nfrom typing import Any\n\n\ndef _expected_counts(\n    *,\n    total_tokens: int,\n    update_period: int,\n    flush_partial: bool,\n) -> dict[str, float]:\n    if update_period <= 0:\n        raise ValueError(\"update_period must be > 0\")\n    if total_tokens < 0:\n        raise ValueError(\"total_tokens must be >= 0\")\n    full_updates = total_tokens // update_period\n    remainder = total_tokens % update_period\n    updates = full_updates + (1 if flush_partial and remainder > 0 else 0)\n    chunk_tokens = float(total_tokens - (0 if flush_partial else remainder))\n    return {\n        \"updates_applied\": float(updates),\n        \"chunk_tokens\": chunk_tokens,\n        \"tokens_flushed\": float(remainder if flush_partial else 0),\n        \"pending_tokens\": float(0 if flush_partial else remainder),\n        \"remainder_tokens\": float(remainder),\n    }\n\n\ndef _load_records(path: Path) -> list[dict[str, Any]]:\n    records = json.loads(path.read_text())\n    if not isinstance(records, list):\n        raise ValueError(\"JSON log must be a list of records\")\n    return [rec for rec in records if isinstance(rec, dict)]\n\n\ndef _find_last_with_prefix(records: list[dict[str, Any]], prefix: str) -> dict[str, Any]:\n    suffixes = (\"updates_applied\", \"chunk_tokens\", \"tokens_flushed\", \"pending_tokens\", \"gate_hits\")\n    for record in reversed(records):\n        for suffix in suffixes:\n            if f\"{prefix}.{suffix}\" in record:\n                return record\n    raise ValueError(f\"No record found for metric prefix {prefix!r}\")\n\n\ndef verify_cadence(\n    *,\n    log_path: Path,\n    metric_prefix: str,\n    total_tokens: int,\n    update_period: int,\n    flush_partial: bool,\n    atol: float = 1e-6,\n) -> dict[str, Any]:\n    records = _load_records(log_path)\n    record = _find_last_with_prefix(records, metric_prefix)\n    expected = _expected_counts(\n        total_tokens=total_tokens,\n        update_period=update_period,\n        flush_partial=flush_partial,\n    )\n    observed = {\n        \"updates_applied\": float(record.get(f\"{metric_prefix}.updates_applied\", 0.0)),\n        \"chunk_tokens\": float(record.get(f\"{metric_prefix}.chunk_tokens\", 0.0)),\n        \"tokens_flushed\": float(record.get(f\"{metric_prefix}.tokens_flushed\", 0.0)),\n        \"pending_tokens\": float(record.get(f\"{metric_prefix}.pending_tokens\", 0.0)),\n    }\n    checks = {\n        key: abs(observed[key] - expected[key]) <= atol\n        for key in (\"updates_applied\", \"chunk_tokens\", \"tokens_flushed\", \"pending_tokens\")\n    }\n    ok = all(checks.values())\n    return {\n        \"ok\": ok,\n        \"metric_prefix\": metric_prefix,\n        \"log_path\": str(log_path),\n        \"flush_partial\": flush_partial,\n        \"total_tokens\": total_tokens,\n        \"update_period\": update_period,\n        \"expected\": expected,\n        \"observed\": observed,\n        \"checks\": checks,\n    }\n\n\ndef _build_parser() -> argparse.ArgumentParser:\n    parser = argparse.ArgumentParser(description=\"Verify CMS update cadence from JSON logs.\")\n    parser.add_argument(\"--log-path\", required=True, type=Path, help=\"Path to JSON metrics log.\")\n    parser.add_argument(\n        \"--metric-prefix\",\n        required=True,\n        help=\"Metric prefix, e.g. 
layer0.cms.cms_fast\",\n    )\n    parser.add_argument(\"--total-tokens\", required=True, type=int)\n    parser.add_argument(\"--update-period\", required=True, type=int)\n    parser.add_argument(\n        \"--flush-partial\",\n        action=\"store_true\",\n        help=\"Use ceil(T/C) expectation and zero pending tokens.\",\n    )\n    parser.add_argument(\n        \"--atol\",\n        type=float,\n        default=1e-6,\n        help=\"Absolute tolerance used for float comparisons.\",\n    )\n    parser.add_argument(\n        \"--output\",\n        type=Path,\n        default=None,\n        help=\"Optional output path for JSON report.\",\n    )\n    return parser\n\n\ndef main() -> int:\n    parser = _build_parser()\n    args = parser.parse_args()\n    report = verify_cadence(\n        log_path=args.log_path,\n        metric_prefix=args.metric_prefix,\n        total_tokens=args.total_tokens,\n        update_period=args.update_period,\n        flush_partial=args.flush_partial,\n        atol=args.atol,\n    )\n    if args.output is not None:\n        args.output.parent.mkdir(parents=True, exist_ok=True)\n        args.output.write_text(json.dumps(report, indent=2))\n    print(json.dumps(report, indent=2))\n    return 0 if report[\"ok\"] else 1\n\n\nif __name__ == \"__main__\":\n    raise SystemExit(main())\n"
  },
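  {
    "path": "scripts/checks/verify_update_cadence_example.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch, not part of the checked pipeline.\n\nMirrors the cadence arithmetic that scripts/checks/verify_update_cadence.py\ncompares against logged metrics: with update period C and T total tokens,\nfloor(T / C) full updates occur, plus one partial update when --flush-partial\nis set and T % C > 0. The function name here is hypothetical; only the\narithmetic is taken from _expected_counts.\n\"\"\"\nfrom __future__ import annotations\n\n\ndef expected_updates(total_tokens: int, update_period: int, flush_partial: bool) -> dict[str, int]:\n    # Integer form of _expected_counts in verify_update_cadence.py.\n    full_updates = total_tokens // update_period\n    remainder = total_tokens % update_period\n    updates = full_updates + (1 if flush_partial and remainder > 0 else 0)\n    return {\n        \"updates_applied\": updates,\n        \"pending_tokens\": 0 if flush_partial else remainder,\n        \"tokens_flushed\": remainder if flush_partial else 0,\n    }\n\n\nif __name__ == \"__main__\":\n    # 10 tokens with period 4: two full chunks plus remainder 2.\n    print(expected_updates(10, 4, flush_partial=False))  # 2 updates, 2 pending\n    print(expected_updates(10, 4, flush_partial=True))  # 3 updates, 2 flushed\n"
  },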
  {
    "path": "scripts/compute/create_reservations.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\n# Example Slurm reservations for Stage 2 (edit dates/times as needed).\n\nPARTITION=\"gpu-a6000\"\nACCOUNT=\"${ACCOUNT:-research}\"\n\nfunction reserve() {\n  local name=\"$1\"\n  local start=\"$2\"\n  local duration=\"$3\"\n  local nodes=\"$4\"\n  scontrol create reservation=\"Name=${name},StartTime=${start},Duration=${duration},Nodes=${nodes},PartitionName=${PARTITION},Users=${USER},Accounts=${ACCOUNT}\"\n}\n\n# Pilot run (1 node, 2 GPUs)\nreserve \"NL_Pilot\" \"2025-02-10T08:00:00\" \"3-00:00:00\" 1\n\n# Ablations (1 node)\nreserve \"NL_Ablations\" \"2025-02-13T08:00:00\" \"2-00:00:00\" 1\n\n# Mid-scale (2 nodes)\nreserve \"NL_Mid\" \"2025-02-17T08:00:00\" \"10-00:00:00\" 2\n\n# Mid evals (1 node)\nreserve \"NL_MidEval\" \"2025-02-27T08:00:00\" \"2-00:00:00\" 1\n\n# Target warmup (2 nodes)\nreserve \"NL_TargetWarmup\" \"2025-03-03T08:00:00\" \"3-00:00:00\" 2\n\n# Target full run (2 nodes)\nreserve \"NL_TargetFull\" \"2025-03-06T08:00:00\" \"14-00:00:00\" 2\n\n# Final evals (1 node)\nreserve \"NL_FinalEval\" \"2025-03-20T08:00:00\" \"3-00:00:00\" 1\n\necho \"Submitted reservations for Stage 2 (check with scontrol show reservation).\"\n"
  },
  {
    "path": "scripts/data/__init__.py",
    "content": "\"\"\"Data preparation scripts (tokenizer/filtering/sharding).\"\"\"\n\n"
  },
  {
    "path": "scripts/data/check_tokenizer.py",
    "content": "#!/usr/bin/env python3\n\"\"\"Utility to record and verify tokenizer artifact checksums.\"\"\"\n\nimport argparse\nimport hashlib\nimport json\nfrom pathlib import Path\nfrom typing import Optional\n\n\ndef compute_sha256(path: Path) -> str:\n    hasher = hashlib.sha256()\n    with path.open(\"rb\") as f:\n        for chunk in iter(lambda: f.read(1024 * 1024), b\"\"):\n            hasher.update(chunk)\n    return hasher.hexdigest()\n\n\ndef dump_metadata(path: Path, sha256: str, output: Optional[Path]) -> None:\n    if not output:\n        return\n    payload = {\n        \"tokenizer_path\": str(path),\n        \"sha256\": sha256,\n    }\n    output.write_text(json.dumps(payload, indent=2) + \"\\n\")\n\n\ndef parse_args() -> argparse.Namespace:\n    parser = argparse.ArgumentParser(description=__doc__)\n    parser.add_argument(\n        \"--tokenizer-path\",\n        type=Path,\n        required=True,\n        help=\"Path to the SentencePiece tokenizer model (.model).\",\n    )\n    parser.add_argument(\n        \"--expected-sha256\",\n        type=str,\n        default=None,\n        help=(\n            \"Optional expected checksum; if provided and mismatch occurs, \"\n            \"exits with non-zero status.\"\n        ),\n    )\n    parser.add_argument(\n        \"--metadata-json\",\n        type=Path,\n        default=None,\n        help=\"Optional path to write checksum metadata as JSON.\",\n    )\n    parser.add_argument(\n        \"--quiet\",\n        action=\"store_true\",\n        help=\"Only emit errors; suppress the default stdout line.\",\n    )\n    return parser.parse_args()\n\n\ndef main() -> None:\n    args = parse_args()\n    tokenizer_path = args.tokenizer_path\n    if not tokenizer_path.exists():\n        raise SystemExit(f\"Tokenizer file not found: {tokenizer_path}\")\n\n    sha256 = compute_sha256(tokenizer_path)\n    if not args.quiet:\n        print(f\"{sha256}  {tokenizer_path}\")\n\n    dump_metadata(tokenizer_path, sha256, args.metadata_json)\n\n    expected = args.expected_sha256\n    if expected and expected.lower() != sha256.lower():\n        raise SystemExit(\n            f\"Checksum mismatch for {tokenizer_path} (expected {expected}, got {sha256})\"\n        )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "scripts/data/check_tokenizer_coverage.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\nfrom typing import Optional\n\nimport typer\n\nfrom nested_learning.tokenizer_coverage import compute_tokenizer_coverage_stats\n\napp = typer.Typer(add_completion=False, help=\"Compute tokenizer coverage stats on a text sample.\")\n\n\n@app.command()\ndef main(\n    tokenizer_path: Path = typer.Option(..., help=\"SentencePiece model path.\"),\n    sample_file: Path = typer.Option(..., help=\"Text file with representative lines.\"),\n    max_lines: int = typer.Option(10000, help=\"Maximum lines to process.\"),\n    output: Optional[Path] = typer.Option(None, help=\"Optional JSON output path.\"),\n) -> None:\n    try:\n        result = compute_tokenizer_coverage_stats(tokenizer_path, sample_file, max_lines=max_lines)\n    except ValueError as exc:\n        raise typer.BadParameter(str(exc)) from exc\n    payload = json.dumps(result, indent=2)\n    typer.echo(payload)\n    if output:\n        output.parent.mkdir(parents=True, exist_ok=True)\n        output.write_text(payload)\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "scripts/data/filter_corpus.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport os\nfrom collections import deque\nfrom pathlib import Path\nfrom typing import Optional\n\nimport typer\nfrom datasets import load_dataset\nfrom langdetect import DetectorFactory, LangDetectException, detect_langs\nfrom tqdm import tqdm\n\nDetectorFactory.seed = 0\n\napp = typer.Typer(\n    add_completion=False, help=\"Filter datasets by language/length and deduplicate lines.\"\n)\n\n\ndef _select_fallback_split(available: list[str]) -> str:\n    for candidate in (\"train\", \"validation\", \"test\"):\n        if candidate in available:\n            return candidate\n    return available[0]\n\n\ndef normalize_text(text: str) -> str:\n    return \" \".join(text.strip().split())\n\n\ndef is_target_language(text: str, target_lang: str, threshold: float) -> bool:\n    try:\n        langs = detect_langs(text)\n    except LangDetectException:\n        return False\n    return any(lang.lang == target_lang and lang.prob >= threshold for lang in langs)\n\n\n@app.command()\ndef main(\n    dataset: str = typer.Option(..., help=\"HF dataset name, e.g. HuggingFaceFW/fineweb\"),\n    subset: Optional[str] = typer.Option(None, help=\"Optional dataset subset/config name.\"),\n    split: str = typer.Option(\"train\", help=\"Dataset split.\"),\n    text_column: str = typer.Option(\"text\", help=\"Column containing text.\"),\n    target_lang: str = typer.Option(\"en\", help=\"Language code to keep.\"),\n    lang_threshold: float = typer.Option(0.80, help=\"Minimum probability for language detection.\"),\n    min_chars: int = typer.Option(200, help=\"Minimum character count.\"),\n    max_chars: int = typer.Option(10000, help=\"Maximum character count.\"),\n    output_path: Path = typer.Option(\n        Path(\"data/filtered/output.jsonl\"), help=\"Destination JSONL file.\"\n    ),\n    dedup_window: int = typer.Option(\n        50000, help=\"Number of recent hashes to retain for deduplication.\"\n    ),\n    limit: Optional[int] = typer.Option(None, help=\"Optional limit on records processed.\"),\n    streaming: bool = typer.Option(True, help=\"Use HF streaming mode.\"),\n    data_files: Optional[str] = typer.Option(\n        None, help=\"Optional data_files argument (e.g., local text file).\"\n    ),\n    force_exit: bool = typer.Option(\n        False, help=\"Force os._exit(0) to avoid async finalization issues.\"\n    ),\n) -> None:\n    output_path.parent.mkdir(parents=True, exist_ok=True)\n    load_kwargs = {}\n    if data_files is not None:\n        # Ensure the requested split exists for local files (HF `text` dataset defaults can be odd).\n        load_kwargs[\"data_files\"] = {split: data_files}\n    try:\n        dataset_obj = load_dataset(\n            dataset, subset, split=split, streaming=streaming, **load_kwargs\n        )\n    except ValueError as err:\n        msg = str(err)\n        if \"Bad split\" not in msg:\n            raise\n        ds_dict = load_dataset(dataset, subset, streaming=streaming, **load_kwargs)\n        if not hasattr(ds_dict, \"keys\"):\n            raise\n        available = list(ds_dict.keys())\n        if not available:\n            raise\n        fallback = _select_fallback_split(available)\n        typer.echo(\n            f\"[Filter] Requested split '{split}' unavailable; \"\n            f\"using '{fallback}' (available={available})\"\n        )\n        dataset_obj = ds_dict[fallback]\n    iterator = dataset_obj if streaming else iter(dataset_obj)\n    seen_hashes = set()\n    
hash_queue = deque()\n    kept = 0\n    total = 0\n    with output_path.open(\"w\", encoding=\"utf-8\") as writer:\n        for row in tqdm(iterator, desc=\"Filtering dataset\"):\n            total += 1\n            text = row.get(text_column)\n            if not isinstance(text, str):\n                continue\n            normalized = normalize_text(text)\n            if len(normalized) < min_chars or len(normalized) > max_chars:\n                continue\n            if not is_target_language(normalized, target_lang, lang_threshold):\n                continue\n            hashed = hash(normalized)\n            if hashed in seen_hashes:\n                continue\n            writer.write(normalized + \"\\n\")\n            kept += 1\n            seen_hashes.add(hashed)\n            hash_queue.append(hashed)\n            if len(hash_queue) > dedup_window:\n                old_hash = hash_queue.popleft()\n                seen_hashes.discard(old_hash)\n            if limit and kept >= limit:\n                break\n    typer.echo(f\"[Filter] Processed={total} kept={kept} -> {output_path}\")\n    if force_exit:\n        os._exit(0)\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
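  {
    "path": "scripts/data/filter_corpus_example.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch, not part of the pipeline.\n\nDemonstrates the bounded deduplication window used by\nscripts/data/filter_corpus.py: a set gives O(1) membership tests while a\ndeque evicts the oldest hash once --dedup-window is exceeded, so memory\nstays constant on streaming corpora. The trade-off: an evicted line can be\nadmitted again later. Names below are hypothetical.\n\"\"\"\nfrom __future__ import annotations\n\nfrom collections import deque\nfrom typing import Iterable, Iterator\n\n\ndef dedup_stream(lines: Iterable[str], window: int) -> Iterator[str]:\n    seen: set[int] = set()\n    order: deque[int] = deque()\n    for line in lines:\n        hashed = hash(line)\n        if hashed in seen:\n            continue\n        yield line\n        seen.add(hashed)\n        order.append(hashed)\n        if len(order) > window:\n            seen.discard(order.popleft())\n\n\nif __name__ == \"__main__\":\n    stream = [\"a\", \"b\", \"a\", \"c\", \"d\", \"a\"]\n    # With window=2 the first \"a\" is evicted before the last one arrives,\n    # so the final \"a\" is kept again.\n    print(list(dedup_stream(stream, window=2)))  # ['a', 'b', 'c', 'd', 'a']\n"
  },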
  {
    "path": "scripts/data/process_mixture.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\nfrom typing import Any, Dict, List\n\nimport typer\nimport yaml\nfrom shard_corpus import ShardConfig, shard_dataset\n\napp = typer.Typer(\n    add_completion=False, help=\"Process a dataset manifest to shard multiple corpora.\"\n)\n\n\n@app.command()\ndef main(\n    manifest: Path = typer.Argument(..., help=\"YAML manifest describing datasets.\"),\n    tokenizer_path: Path = typer.Option(..., help=\"SentencePiece model to tokenize with.\"),\n    log_file: Path = typer.Option(\n        Path(\"data/mixtures/mixture_stats.json\"), help=\"Output stats JSON.\"\n    ),\n) -> None:\n    data = yaml.safe_load(manifest.read_text())\n    datasets = data.get(\"datasets\", data)\n    stats: List[Dict[str, Any]] = []\n    for entry in datasets:\n        name = entry[\"name\"]\n        config = ShardConfig(\n            name=name,\n            dataset=entry[\"dataset\"],\n            split=entry.get(\"split\", \"train\"),\n            subset=entry.get(\"subset\"),\n            text_column=entry.get(\"text_column\", \"text\"),\n            tokenizer_path=tokenizer_path,\n            seq_len=entry.get(\"seq_len\", 2048),\n            sequences_per_shard=entry.get(\"sequences_per_shard\", 1024),\n            output_dir=Path(entry.get(\"output_dir\", f\"data/shards/{name}\")),\n            eos_id=entry.get(\"eos_id\", -1),\n            max_records=entry.get(\"max_records\"),\n            data_files=entry.get(\"data_files\"),\n        )\n        stats.append(shard_dataset(config))\n    log_file.parent.mkdir(parents=True, exist_ok=True)\n    log_file.write_text(json.dumps({\"manifest\": str(manifest), \"stats\": stats}, indent=2))\n    typer.echo(f\"[Mixture] Logged stats for {len(stats)} datasets -> {log_file}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
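  {
    "path": "scripts/data/process_mixture_example.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch, not part of the pipeline.\n\nShows the manifest shape scripts/data/process_mixture.py consumes: a\n`datasets` list whose entries map onto ShardConfig fields, with unset keys\nfalling back to the defaults in shard_corpus.py. The manifest below is a\nhypothetical single-source example, not a shipped config.\n\"\"\"\nfrom __future__ import annotations\n\nimport yaml\n\nMANIFEST = \"\"\"\ndatasets:\n  - name: tinystories\n    dataset: roneneldan/TinyStories\n    split: train\n    text_column: text\n    seq_len: 2048\n    sequences_per_shard: 1024\n    output_dir: data/shards/tinystories\n    max_records: 1000\n\"\"\"\n\nif __name__ == \"__main__\":\n    for entry in yaml.safe_load(MANIFEST)[\"datasets\"]:\n        out = entry.get(\"output_dir\", f\"data/shards/{entry['name']}\")\n        print(f\"{entry['name']}: {entry['dataset']} -> {out}\")\n"
  },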
  {
    "path": "scripts/data/run_full.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\nusage() {\n  cat <<'EOF'\nUsage: scripts/data/run_full.sh\n\nRuns the full data pipeline:\n  1) filter configured datasets\n  2) train tokenizer (unless present)\n  3) shard filtered corpora\n\nControls are provided via environment variables (see script header defaults).\nEOF\n}\n\nif [[ \"${1:-}\" == \"--help\" || \"${1:-}\" == \"-h\" ]]; then\n  usage\n  exit 0\nfi\n\nif [[ $# -gt 0 ]]; then\n  usage\n  exit 2\nfi\n\n# General controls\nTOKENIZER_MANIFEST=${TOKENIZER_MANIFEST:-configs/data/refinedweb_mixture.yaml}\nTOKENIZER_OUTPUT_DIR=${TOKENIZER_OUTPUT_DIR:-artifacts/tokenizer/refinedweb_mix}\nTOKENIZER_MODEL=${TOKENIZER_MODEL:-${TOKENIZER_OUTPUT_DIR}/spm_32000_unigram.model}\nVOCAB_SIZE=${VOCAB_SIZE:-32000}\nTOKENIZER_LOG=${TOKENIZER_LOG:-data/mixtures/refinedweb_mix_tokenizer_full.json}\nMIXTURE_CONFIG=${MIXTURE_CONFIG:-configs/data/refinedweb_mixture_full.yaml}\nSHARD_LOG=${SHARD_LOG:-data/mixtures/refinedweb_mix_full_shards.json}\nFORCE_FILTER=${FORCE_FILTER:-0}\nRETRAIN_TOKENIZER=${RETRAIN_TOKENIZER:-0}\nFALLBACK_SPLIT=${FALLBACK_SPLIT:-test}\n\nmkdir -p data/filtered data/shards artifacts/tokenizer data/mixtures\n\nfilter_dataset() {\n  local name=$1\n  local dataset=$2\n  local subset=$3\n  local split=$4\n  local text_column=$5\n  local limit=$6\n  local output=$7\n  local target_lang=${8:-en}\n  local lang_threshold=${9:-0.85}\n  local min_chars=${10:-200}\n  local max_chars=${11:-12000}\n\n  if [[ \"${FORCE_FILTER}\" != \"1\" && -f \"${output}\" ]]; then\n    echo \"[Data][${name}] Found existing ${output}, skipping filter step (set FORCE_FILTER=1 to rebuild)\"\n    return\n  fi\n\n  echo \"[Data][${name}] Filtering ${dataset}${subset:+/${subset}} -> ${output}\"\n  run_filter() {\n    local split_value=$1\n    cmd=(uv run python scripts/data/filter_corpus.py\n      --dataset \"${dataset}\"\n      --split \"${split_value}\"\n      --text-column \"${text_column}\"\n      --target-lang \"${target_lang}\"\n      --lang-threshold \"${lang_threshold}\"\n      --min-chars \"${min_chars}\"\n      --max-chars \"${max_chars}\"\n      --output-path \"${output}\"\n      --force-exit)\n    if [[ -n \"${subset}\" ]]; then\n      cmd+=(--subset \"${subset}\")\n    fi\n    if [[ -n \"${limit}\" ]]; then\n      cmd+=(--limit \"${limit}\")\n    fi\n    \"${cmd[@]}\"\n  }\n\n  if ! 
run_filter \"${split}\"; then\n    if [[ -n \"${FALLBACK_SPLIT}\" && \"${FALLBACK_SPLIT}\" != \"${split}\" ]]; then\n      echo \"[Data][${name}] Primary split '${split}' failed; retrying with fallback '${FALLBACK_SPLIT}'\"\n      run_filter \"${FALLBACK_SPLIT}\"\n    else\n      exit 1\n    fi\n  fi\n}\n\necho \"[Data] === Stage 1: Filtering corpora ===\"\nfilter_dataset \"refinedweb\" \\\n  \"${RW_DATASET:-HuggingFaceFW/fineweb}\" \\\n  \"${RW_SUBSET:-sample-10BT}\" \\\n  \"${RW_SPLIT:-train}\" \\\n  \"${RW_TEXT_COLUMN:-text}\" \\\n  \"${RW_LIMIT:-100000}\" \\\n  \"${RW_OUTPUT:-data/filtered/refinedweb_en_full.txt}\" \\\n  \"${RW_LANG:-en}\" \\\n  \"${RW_LANG_THRESHOLD:-0.85}\" \\\n  \"${RW_MIN_CHARS:-200}\" \\\n  \"${RW_MAX_CHARS:-8000}\"\n\nfilter_dataset \"wikipedia\" \\\n  \"${WIKI_DATASET:-wikimedia/wikipedia}\" \\\n  \"${WIKI_SUBSET:-20231101.en}\" \\\n  \"${WIKI_SPLIT:-train}\" \\\n  \"${WIKI_TEXT_COLUMN:-text}\" \\\n  \"${WIKI_LIMIT:-50000}\" \\\n  \"${WIKI_OUTPUT:-data/filtered/wikipedia_en_full.txt}\" \\\n  \"${WIKI_LANG:-en}\" \\\n  \"${WIKI_LANG_THRESHOLD:-0.85}\" \\\n  \"${WIKI_MIN_CHARS:-200}\" \\\n  \"${WIKI_MAX_CHARS:-8000}\"\n\nfilter_dataset \"c4\" \\\n  \"${C4_DATASET:-allenai/c4}\" \\\n  \"${C4_SUBSET:-en}\" \\\n  \"${C4_SPLIT:-train}\" \\\n  \"${C4_TEXT_COLUMN:-text}\" \\\n  \"${C4_LIMIT:-50000}\" \\\n  \"${C4_OUTPUT:-data/filtered/c4_en_full.txt}\" \\\n  \"${C4_LANG:-en}\" \\\n  \"${C4_LANG_THRESHOLD:-0.85}\" \\\n  \"${C4_MIN_CHARS:-200}\" \\\n  \"${C4_MAX_CHARS:-8000}\"\n\nfilter_dataset \"redpajama\" \\\n  \"${RPJ_DATASET:-cerebras/SlimPajama-627B}\" \\\n  \"${RPJ_SUBSET:-}\" \\\n  \"${RPJ_SPLIT:-train}\" \\\n  \"${RPJ_TEXT_COLUMN:-text}\" \\\n  \"${RPJ_LIMIT:-50000}\" \\\n  \"${RPJ_OUTPUT:-data/filtered/redpajama_en_full.txt}\" \\\n  \"${RPJ_LANG:-en}\" \\\n  \"${RPJ_LANG_THRESHOLD:-0.85}\" \\\n  \"${RPJ_MIN_CHARS:-200}\" \\\n  \"${RPJ_MAX_CHARS:-8000}\"\n\nfilter_dataset \"code\" \\\n  \"${CODE_DATASET:-codeparrot/codeparrot-clean-train}\" \\\n  \"${CODE_SUBSET:-}\" \\\n  \"${CODE_SPLIT:-train}\" \\\n  \"${CODE_TEXT_COLUMN:-content}\" \\\n  \"${CODE_LIMIT:-50000}\" \\\n  \"${CODE_OUTPUT:-data/filtered/code_en_full.txt}\" \\\n  \"${CODE_LANG:-en}\" \\\n  \"${CODE_LANG_THRESHOLD:-0.50}\" \\\n  \"${CODE_MIN_CHARS:-200}\" \\\n  \"${CODE_MAX_CHARS:-16000}\"\n\necho \"[Data] === Stage 2: Tokenizer training ===\"\nif [[ ! -f \"${TOKENIZER_MODEL}\" || \"${RETRAIN_TOKENIZER}\" == \"1\" ]]; then\n  uv run python scripts/data/train_tokenizer.py \\\n    --manifest \"${TOKENIZER_MANIFEST}\" \\\n    --vocab-size \"${VOCAB_SIZE}\" \\\n    --output-dir \"${TOKENIZER_OUTPUT_DIR}\" \\\n    --log-file \"${TOKENIZER_LOG}\"\nelse\n  echo \"[Data] Tokenizer already exists at ${TOKENIZER_MODEL}; set RETRAIN_TOKENIZER=1 to rebuild.\"\nfi\n\necho \"[Data] === Stage 3: Sharding filtered corpora ===\"\nuv run python scripts/data/process_mixture.py \\\n  \"${MIXTURE_CONFIG}\" \\\n  --tokenizer-path \"${TOKENIZER_MODEL}\" \\\n  --log-file \"${SHARD_LOG}\"\n\necho \"[Data] Full pipeline complete.\"\n"
  },
  {
    "path": "scripts/data/run_sample.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\nusage() {\n  cat <<'EOF'\nUsage: scripts/data/run_sample.sh [TOKENIZER_MODEL_PATH]\n\nBuilds a small filtered corpus sample, trains a tokenizer if missing, and shards it.\n\nArgs:\n  TOKENIZER_MODEL_PATH  Optional tokenizer model path.\n                        Default: artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model\nEOF\n}\n\nif [[ \"${1:-}\" == \"--help\" || \"${1:-}\" == \"-h\" ]]; then\n  usage\n  exit 0\nfi\n\nif [[ $# -gt 1 ]]; then\n  usage\n  exit 2\nfi\n\nTOKENIZER_MODEL=${1:-artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model}\nTOKENIZER_DIR=\"$(dirname -- \"${TOKENIZER_MODEL}\")\"\n\nif [[ ! -f \"data/filtered/refinedweb_en_sample.txt\" ]]; then\n  echo \"[Data] Creating filtered RefinedWeb sample\"\n  uv run python scripts/data/filter_corpus.py \\\n    --dataset HuggingFaceFW/fineweb \\\n    \"--subset=sample-10BT\" \\\n    --split train \\\n    --text-column text \\\n    --target-lang en \\\n    --lang-threshold 0.85 \\\n    --min-chars 200 \\\n    --max-chars 8000 \\\n    --limit 2000 \\\n    --output-path data/filtered/refinedweb_en_sample.txt \\\n    --force-exit\nfi\n\nif [[ ! -f \"data/filtered/wikipedia_en_sample.txt\" ]]; then\n  echo \"[Data] Creating filtered Wikipedia sample\"\n  uv run python scripts/data/filter_corpus.py \\\n    --dataset wikimedia/wikipedia \\\n    \"--subset=20231101.en\" \\\n    --split train \\\n    --text-column text \\\n    --target-lang en \\\n    --lang-threshold 0.85 \\\n    --min-chars 200 \\\n    --max-chars 8000 \\\n    --limit 1000 \\\n    --output-path data/filtered/wikipedia_en_sample.txt \\\n    --force-exit\nfi\n\nif [[ ! -f \"data/filtered/c4_en_sample.txt\" ]]; then\n  echo \"[Data] Creating filtered C4 sample\"\n  uv run python scripts/data/filter_corpus.py \\\n    --dataset allenai/c4 --subset en --split train \\\n    --text-column text --target-lang en --lang-threshold 0.85 \\\n    --min-chars 200 --max-chars 8000 --limit 1000 \\\n    --output-path data/filtered/c4_en_sample.txt --force-exit\nfi\n\nif [[ ! -f \"data/filtered/redpajama_en_sample.txt\" ]]; then\n  echo \"[Data] Creating filtered SlimPajama sample\"\n  uv run python scripts/data/filter_corpus.py \\\n    \"--dataset=cerebras/SlimPajama-627B\" \\\n    --split train \\\n    --text-column text \\\n    --target-lang en \\\n    --lang-threshold 0.85 \\\n    --min-chars 200 \\\n    --max-chars 8000 \\\n    --limit 1000 \\\n    --output-path data/filtered/redpajama_en_sample.txt \\\n    --force-exit\nfi\n\nif [[ ! -f \"data/filtered/code_en_sample.txt\" ]]; then\n  echo \"[Data] Creating filtered code sample\"\n  uv run python scripts/data/filter_corpus.py \\\n    --dataset codeparrot/codeparrot-clean-train --split train \\\n    --text-column content --target-lang en --lang-threshold 0.5 \\\n    --min-chars 200 --max-chars 12000 --limit 1000 \\\n    --output-path data/filtered/code_en_sample.txt --force-exit\nfi\n\nif [[ ! 
-f \"${TOKENIZER_MODEL}\" ]]; then\n  echo \"[Data] Training tokenizer (sample) -> ${TOKENIZER_DIR}\"\n  uv run python scripts/data/train_tokenizer.py \\\n    --manifest configs/data/refinedweb_mixture_filtered.yaml \\\n    --vocab-size 32000 \\\n    --no-hard-vocab-limit \\\n    --output-dir \"${TOKENIZER_DIR}\" \\\n    --log-file data/mixtures/refinedweb_mix_tokenizer_sample.json\nfi\n\necho \"[Data] Sharding filtered samples\"\nuv run python scripts/data/process_mixture.py \\\n  configs/data/refinedweb_mixture_filtered.yaml \\\n  --tokenizer-path ${TOKENIZER_MODEL} \\\n  --log-file data/mixtures/refinedweb_mix_filtered_shards.json\n\necho \"[Data] Sample pipeline complete\"\n"
  },
  {
    "path": "scripts/data/shard_corpus.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import List, Optional\n\nimport numpy as np\nimport sentencepiece as spm\nimport typer\nfrom datasets import load_dataset\nfrom tqdm import tqdm\n\napp = typer.Typer(add_completion=False, help=\"Shard datasets into tokenized numpy binaries.\")\n\n\n@dataclass\nclass ShardConfig:\n    name: str\n    dataset: str\n    split: str = \"train\"\n    subset: str | None = None\n    text_column: str = \"text\"\n    tokenizer_path: Path = Path()\n    seq_len: int = 2048\n    sequences_per_shard: int = 1024\n    output_dir: Path = Path(\"data/shards\")\n    eos_id: int = -1\n    max_records: Optional[int] = None\n    data_files: Optional[str] = None\n\n\ndef _select_fallback_split(available: list[str]) -> str:\n    for candidate in (\"train\", \"validation\", \"test\"):\n        if candidate in available:\n            return candidate\n    return available[0]\n\n\ndef shard_dataset(config: ShardConfig) -> dict:\n    config.output_dir.mkdir(parents=True, exist_ok=True)\n    processor = spm.SentencePieceProcessor(model_file=str(config.tokenizer_path))\n    eos = config.eos_id if config.eos_id >= 0 else processor.eos_id()\n    load_kwargs = {}\n    if config.data_files is not None:\n        load_kwargs[\"data_files\"] = {config.split: config.data_files}\n    try:\n        ds = load_dataset(\n            config.dataset, config.subset, split=config.split, streaming=True, **load_kwargs\n        )\n    except ValueError as err:\n        msg = str(err)\n        if \"Bad split\" not in msg:\n            raise\n        ds_dict = load_dataset(config.dataset, config.subset, streaming=True, **load_kwargs)\n        available = list(ds_dict.keys())\n        if not available:\n            raise\n        fallback = _select_fallback_split(available)\n        typer.echo(\n            f\"[Shard] Requested split '{config.split}' unavailable; \"\n            f\"using '{fallback}' (available={available})\"\n        )\n        ds = ds_dict[fallback]\n\n    buffer: List[int] = []\n    sequences: List[List[int]] = []\n    shard_idx = 0\n    records = 0\n    sequences_total = 0\n    tokens_total = 0\n\n    for row in tqdm(ds, desc=f\"Sharding {config.name}\", unit=\"record\"):\n        text = row.get(config.text_column)\n        if not isinstance(text, str):\n            continue\n        tokens = processor.encode(text)\n        tokens.append(eos)\n        tokens_total += len(tokens)\n        buffer.extend(tokens)\n        records += 1\n        while len(buffer) >= config.seq_len:\n            seq = buffer[: config.seq_len]\n            buffer = buffer[config.seq_len :]\n            sequences.append(seq)\n            sequences_total += 1\n            if len(sequences) >= config.sequences_per_shard:\n                _write_shard(sequences, config.output_dir, shard_idx)\n                shard_idx += 1\n                sequences = []\n        if config.max_records and records >= config.max_records:\n            break\n    if sequences:\n        _write_shard(sequences, config.output_dir, shard_idx)\n        shard_idx += 1\n\n    stats = {\n        \"name\": config.name,\n        \"dataset\": config.dataset,\n        \"subset\": config.subset,\n        \"records\": records,\n        \"sequences\": sequences_total,\n        \"tokens\": tokens_total,\n        \"shards\": shard_idx,\n        \"output_dir\": str(config.output_dir),\n    }\n    typer.echo(\n        f\"[Shard] 
{config.name}: records={records} sequences={sequences_total} \"\n        f\"shards={shard_idx} -> {config.output_dir}\"\n    )\n    return stats\n\n\n@app.command()\ndef main(\n    dataset: str = typer.Option(\"roneneldan/TinyStories\", help=\"HF dataset name.\"),\n    split: str = typer.Option(\"train\", help=\"Dataset split.\"),\n    subset: Optional[str] = typer.Option(None, help=\"Optional dataset subset/config.\"),\n    text_column: str = typer.Option(\"text\", help=\"Text column.\"),\n    tokenizer_path: Path = typer.Option(..., help=\"Path to SentencePiece model.\"),\n    seq_len: int = typer.Option(2048, help=\"Sequence length (tokens per sample).\"),\n    sequences_per_shard: int = typer.Option(1024, help=\"Number of sequences per shard.\"),\n    output_dir: Path = typer.Option(Path(\"data/shards\"), help=\"Directory for shard files.\"),\n    eos_id: int = typer.Option(-1, help=\"EOS token id (defaults to tokenizer default).\"),\n    max_records: Optional[int] = typer.Option(None, help=\"Optional max records to process.\"),\n    name: Optional[str] = typer.Option(None, help=\"Friendly name for logging.\"),\n    log_file: Optional[Path] = typer.Option(\n        Path(\"data/mixtures/shard_stats.json\"), help=\"Where to save shard stats JSON.\"\n    ),\n    data_files: Optional[str] = typer.Option(None, help=\"Optional data_files argument.\"),\n) -> None:\n    config = ShardConfig(\n        name=name or dataset.split(\"/\")[-1],\n        dataset=dataset,\n        split=split,\n        subset=subset,\n        text_column=text_column,\n        tokenizer_path=tokenizer_path,\n        seq_len=seq_len,\n        sequences_per_shard=sequences_per_shard,\n        output_dir=output_dir,\n        eos_id=eos_id,\n        max_records=max_records,\n        data_files=data_files,\n    )\n    stats = shard_dataset(config)\n    if log_file is not None:\n        log_file.parent.mkdir(parents=True, exist_ok=True)\n        log_file.write_text(json.dumps(stats, indent=2))\n        typer.echo(f\"[Shard] Stats logged to {log_file}\")\n\n\ndef _write_shard(sequences: List[List[int]], output_dir: Path, shard_idx: int) -> None:\n    array = np.asarray(sequences, dtype=np.int32)\n    target = output_dir / f\"shard_{shard_idx:05d}.npy\"\n    np.save(target, array)\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
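  {
    "path": "scripts/data/shard_corpus_example.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch, not part of the pipeline.\n\nShows the fixed-length packing scripts/data/shard_corpus.py performs: token\nstreams (EOS appended per document) are concatenated into a buffer and cut\ninto seq_len-sized rows; each shard is saved as a (sequences, seq_len) int32\narray. The EOS id and tiny documents below are made up for the demo.\n\"\"\"\nfrom __future__ import annotations\n\nimport tempfile\nfrom pathlib import Path\n\nimport numpy as np\n\nSEQ_LEN = 4\nEOS = 0\n\n\ndef pack(docs: list[list[int]], seq_len: int) -> list[list[int]]:\n    buffer: list[int] = []\n    sequences: list[list[int]] = []\n    for doc in docs:\n        buffer.extend(doc + [EOS])\n        while len(buffer) >= seq_len:\n            sequences.append(buffer[:seq_len])\n            buffer = buffer[seq_len:]\n    return sequences  # any trailing partial buffer is dropped, as in the script\n\n\nif __name__ == \"__main__\":\n    rows = pack([[1, 2, 3], [4, 5], [6, 7, 8, 9]], SEQ_LEN)\n    target = Path(tempfile.mkdtemp(prefix=\"shard-example-\")) / \"shard_00000.npy\"\n    np.save(target, np.asarray(rows, dtype=np.int32))\n    print(np.load(target).shape)  # (3, 4): 9 tokens + 3 EOS pack into 3 full rows\n"
  },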
  {
    "path": "scripts/data/train_tokenizer.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nimport tempfile\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import List, Optional\n\nimport sentencepiece as spm\nimport typer\nimport yaml\nfrom datasets import load_dataset\n\napp = typer.Typer(add_completion=False, help=\"Train a SentencePiece tokenizer from HF datasets.\")\n\n\n@dataclass\nclass DatasetSpec:\n    name: str\n    dataset: str\n    split: str = \"train\"\n    subset: str | None = None\n    text_column: str = \"text\"\n    sample_limit: int = 100_000\n    data_files: str | None = None\n\n\ndef _select_fallback_split(available: list[str]) -> str:\n    for candidate in (\"train\", \"validation\", \"test\"):\n        if candidate in available:\n            return candidate\n    return available[0]\n\n\ndef _load_specs_from_manifest(manifest: Path) -> List[DatasetSpec]:\n    data = yaml.safe_load(manifest.read_text())\n    entries = data.get(\"datasets\", data)\n    specs = []\n    for entry in entries:\n        specs.append(\n            DatasetSpec(\n                name=entry.get(\"name\") or entry[\"dataset\"].split(\"/\")[-1],\n                dataset=entry[\"dataset\"],\n                split=entry.get(\"split\", \"train\"),\n                subset=entry.get(\"subset\"),\n                text_column=entry.get(\"text_column\", \"text\"),\n                sample_limit=entry.get(\"sample_limit\", 100_000),\n                data_files=entry.get(\"data_files\"),\n            )\n        )\n    return specs\n\n\ndef _write_samples(spec: DatasetSpec, handle) -> int:\n    load_kwargs = {}\n    if spec.data_files is not None:\n        load_kwargs[\"data_files\"] = {spec.split: spec.data_files}\n    try:\n        ds = load_dataset(\n            spec.dataset, spec.subset, split=spec.split, streaming=True, **load_kwargs\n        )\n    except ValueError as err:\n        msg = str(err)\n        if \"Bad split\" not in msg:\n            raise\n        ds_dict = load_dataset(spec.dataset, spec.subset, streaming=True, **load_kwargs)\n        available = list(ds_dict.keys())\n        if not available:\n            raise\n        fallback = _select_fallback_split(available)\n        typer.echo(\n            f\"[Tokenizer] Requested split '{spec.split}' unavailable for {spec.dataset}; \"\n            f\"using '{fallback}' (available={available})\"\n        )\n        ds = ds_dict[fallback]\n    count = 0\n    for row in ds:\n        text = row.get(spec.text_column)\n        if not isinstance(text, str):\n            continue\n        handle.write(text.replace(\"\\n\", \" \") + \"\\n\")\n        count += 1\n        if spec.sample_limit > 0 and count >= spec.sample_limit:\n            break\n    return count\n\n\n@app.command()\ndef main(\n    dataset: str = typer.Option(\n        \"roneneldan/TinyStories\", help=\"HF dataset name (ignored if manifest set).\"\n    ),\n    split: str = typer.Option(\"train\", help=\"Dataset split (ignored if manifest set).\"),\n    text_column: str = typer.Option(\"text\", help=\"Text column (ignored if manifest set).\"),\n    sample_limit: int = typer.Option(\n        100_000, help=\"Sample limit per dataset (ignored if manifest set).\"\n    ),\n    vocab_size: int = typer.Option(32_000, help=\"SentencePiece vocabulary size.\"),\n    model_type: str = typer.Option(\"unigram\", help=\"SentencePiece model type.\"),\n    character_coverage: float = typer.Option(0.9995, help=\"Character coverage target.\"),\n    hard_vocab_limit: bool = 
typer.Option(\n        True,\n        help=(\n            \"Require the trained vocab to match vocab_size exactly. \"\n            \"Disable for tiny sample corpora where vocab_size is unattainable.\"\n        ),\n    ),\n    output_dir: Path = typer.Option(\n        Path(\"artifacts/tokenizer\"), help=\"Directory for tokenizer artifacts.\"\n    ),\n    manifest: Optional[Path] = typer.Option(\n        None, help=\"YAML manifest describing multiple datasets.\"\n    ),\n    log_file: Optional[Path] = typer.Option(\n        Path(\"data/mixtures/tokenizer_samples.json\"), help=\"Where to log dataset sample stats.\"\n    ),\n) -> None:\n    output_dir.mkdir(parents=True, exist_ok=True)\n    specs = (\n        _load_specs_from_manifest(manifest)\n        if manifest is not None\n        else [\n            DatasetSpec(\n                name=dataset.split(\"/\")[-1],\n                dataset=dataset,\n                split=split,\n                text_column=text_column,\n                sample_limit=sample_limit,\n            )\n        ]\n    )\n    model_prefix = output_dir / f\"spm_{vocab_size}_{model_type}\"\n    stats = []\n    with tempfile.NamedTemporaryFile(\"w+\", encoding=\"utf-8\", delete=False) as tmp:\n        tmp_path = Path(tmp.name)\n        typer.echo(f\"[Tokenizer] Writing samples to {tmp_path}\")\n        with tmp_path.open(\"w\", encoding=\"utf-8\") as handle:\n            for spec in specs:\n                typer.echo(\n                    f\"[Tokenizer] Streaming {spec.name} ({spec.dataset}) limit={spec.sample_limit}\"\n                )\n                count = _write_samples(spec, handle)\n                stats.append({\"name\": spec.name, \"dataset\": spec.dataset, \"samples\": count})\n    typer.echo(f\"[Tokenizer] Training SentencePiece -> {model_prefix}\")\n    total_samples = sum(s[\"samples\"] for s in stats)\n    spm.SentencePieceTrainer.train(\n        input=str(tmp_path),\n        model_prefix=str(model_prefix),\n        vocab_size=vocab_size,\n        model_type=model_type,\n        character_coverage=character_coverage,\n        hard_vocab_limit=hard_vocab_limit,\n        # SentencePiece requires input_sentence_size <= 0 or > 100.\n        input_sentence_size=(total_samples if total_samples > 100 else 0),\n        shuffle_input_sentence=True,\n        train_extremely_large_corpus=True,\n    )\n    typer.echo(f\"[Tokenizer] Saved model to {model_prefix}.model\")\n    if log_file is not None:\n        log_file.parent.mkdir(parents=True, exist_ok=True)\n        log_payload = {\"model\": str(model_prefix), \"datasets\": stats}\n        log_file.write_text(json.dumps(log_payload, indent=2))\n        typer.echo(f\"[Tokenizer] Logged sample stats to {log_file}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
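  {
    "path": "scripts/data/train_tokenizer_example.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch, not part of the pipeline.\n\nRound-trips text through a tokenizer produced by\nscripts/data/train_tokenizer.py. MODEL points at the default artifact path\nused elsewhere in this repo; adjust it to wherever your model lives.\n\"\"\"\nfrom __future__ import annotations\n\nfrom pathlib import Path\n\nimport sentencepiece as spm\n\nMODEL = \"artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model\"\n\nif __name__ == \"__main__\":\n    if not Path(MODEL).exists():\n        print(f\"tokenizer model not found: {MODEL} (run the data pipeline first)\")\n        raise SystemExit(0)\n    sp = spm.SentencePieceProcessor(model_file=MODEL)\n    ids = sp.encode(\"Nested learning needs a stable tokenizer.\")\n    print(len(ids), ids[:8])\n    print(sp.decode(ids))\n    print(\"vocab:\", sp.vocab_size(), \"eos:\", sp.eos_id())\n"
  },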
  {
    "path": "scripts/data/validate_mixture.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom itertools import combinations\nfrom pathlib import Path\nfrom typing import Optional\n\nimport typer\n\napp = typer.Typer(add_completion=False, help=\"Validate mixture manifests and shard inventories.\")\n\n\ndef _dir_stats(path: Path, sample_limit: int = 2000) -> tuple[dict[str, float], set[str]]:\n    total_bytes = 0\n    file_count = 0\n    sampled_names: set[str] = set()\n    for entry in sorted(path.rglob(\"*.npy\")):\n        total_bytes += entry.stat().st_size\n        file_count += 1\n        if len(sampled_names) < sample_limit:\n            sampled_names.add(f\"{path.name}/{entry.relative_to(path).as_posix()}\")\n    return {\"files\": file_count, \"bytes\": total_bytes}, sampled_names\n\n\n@app.command()\ndef main(\n    manifest: Path = typer.Option(..., help=\"Path to data/manifest/*.json file.\"),\n    output: Optional[Path] = typer.Option(None, help=\"Optional JSON output path for the report.\"),\n    overlap_threshold: float = typer.Option(\n        0.05, help=\"Warn when filename overlap exceeds this Jaccard.\"\n    ),\n) -> None:\n    spec = json.loads(manifest.read_text())\n    report = {\"manifest\": spec.get(\"name\"), \"sources\": []}\n    sampled_sets: dict[str, set[str]] = {}\n    for entry in spec.get(\"sources\", []):\n        shards_dir = Path(entry[\"shards_dir\"])\n        source_report = dict(entry)\n        source_report[\"exists\"] = shards_dir.exists()\n        if shards_dir.exists():\n            stats, sampled = _dir_stats(shards_dir)\n            source_report.update(stats)\n            sampled_sets[entry[\"name\"]] = sampled\n        stats_file = entry.get(\"stats_file\")\n        if stats_file and Path(stats_file).exists():\n            try:\n                stats_payload = json.loads(Path(stats_file).read_text())\n                source_report[\"stats_snapshot\"] = stats_payload.get(entry[\"name\"])\n            except json.JSONDecodeError:\n                source_report[\"stats_snapshot\"] = \"unreadable\"\n        report[\"sources\"].append(source_report)\n    overlaps = []\n    for a, b in combinations(sampled_sets.keys(), 2):\n        set_a = sampled_sets[a]\n        set_b = sampled_sets[b]\n        if not set_a or not set_b:\n            continue\n        jaccard = len(set_a & set_b) / len(set_a | set_b)\n        entry = {\"pair\": [a, b], \"jaccard\": jaccard}\n        if jaccard >= overlap_threshold:\n            entry[\"warning\"] = True\n        overlaps.append(entry)\n    report[\"filename_overlap\"] = overlaps\n    summary = json.dumps(report, indent=2)\n    typer.echo(summary)\n    if output:\n        output.parent.mkdir(parents=True, exist_ok=True)\n        output.write_text(summary)\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
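  {
    "path": "scripts/data/validate_mixture_example.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch, not part of the pipeline.\n\nShows the filename-overlap heuristic in scripts/data/validate_mixture.py:\nsampled shard names from each source are compared pairwise with the Jaccard\nindex |A & B| / |A | B| and pairs at or above the threshold are flagged.\nThe sample sets below are made up.\n\"\"\"\nfrom __future__ import annotations\n\nfrom itertools import combinations\n\n\ndef jaccard(a: set[str], b: set[str]) -> float:\n    return len(a & b) / len(a | b) if (a or b) else 0.0\n\n\nif __name__ == \"__main__\":\n    samples = {\n        \"refinedweb\": {\"rw/shard_00000.npy\", \"rw/shard_00001.npy\"},\n        \"c4\": {\"c4/shard_00000.npy\"},\n        \"c4_copy\": {\"c4/shard_00000.npy\"},\n    }\n    for a, b in combinations(samples, 2):\n        score = jaccard(samples[a], samples[b])\n        flag = \" <- warning\" if score >= 0.05 else \"\"\n        print(f\"{a} vs {b}: jaccard={score:.2f}{flag}\")\n"
  },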
  {
    "path": "scripts/eval/__init__.py",
    "content": "# Eval utilities package marker.\n"
  },
  {
    "path": "scripts/eval/compare_variants.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nimport random\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any, Dict, List\n\nimport torch\nimport typer\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\n\nfrom nested_learning.device import resolve_device\nfrom nested_learning.memorize import (\n    MemorizeConfig,\n    memorize_sequence,\n    memorize_tokens,\n    restore_state_dict,\n    snapshot_state_dict,\n)\nfrom nested_learning.tokenizer import SentencePieceTokenizer\nfrom nested_learning.training import build_model_from_cfg, unwrap_config\n\napp = typer.Typer(\n    add_completion=False, help=\"Compare long-context metrics across two model variants.\"\n)\n\n\n@dataclass(frozen=True)\nclass ModelSpec:\n    name: str\n    config: Path\n    checkpoint: Path\n\n\ndef _load_model(spec: ModelSpec, device: torch.device) -> torch.nn.Module:\n    cfg = OmegaConf.load(spec.config)\n    cfg = unwrap_config(cfg)\n    model = build_model_from_cfg(cfg.model)\n    state = torch.load(spec.checkpoint, map_location=\"cpu\", weights_only=False)\n    state_dict = state[\"model\"] if \"model\" in state else state\n    missing, unexpected = model.load_state_dict(state_dict, strict=False)\n    if missing or unexpected:\n        print(\n            f\"[compare] {spec.name}: state_dict mismatch \"\n            f\"(missing={len(missing)} unexpected={len(unexpected)}) – continuing.\"\n        )\n    return model.to(device).eval()\n\n\ndef _logprob_answer(\n    model: torch.nn.Module,\n    tokenizer: SentencePieceTokenizer,\n    prompt: str,\n    answer: str,\n    device: torch.device,\n    *,\n    fast_state=None,\n) -> float:\n    prompt_ids = tokenizer.encode(prompt, add_bos=True)\n    answer_ids = tokenizer.encode(\" \" + answer, add_bos=False)\n    inputs = torch.cat([prompt_ids, answer_ids], dim=0).to(device)\n    with torch.no_grad():\n        logits = (\n            model(inputs.unsqueeze(0), fast_state=fast_state)\n            if fast_state is not None\n            else model(inputs.unsqueeze(0))\n        )\n        log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)\n        target = inputs.unsqueeze(0)[:, 1:]\n        gathered = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)\n        prompt_len = prompt_ids.numel()\n        answer_logprob = gathered[0, prompt_len - 1 :].sum().item()\n    return float(answer_logprob)\n\n\ndef _memorize_prompt_answer_only(\n    model: torch.nn.Module,\n    tokenizer: SentencePieceTokenizer,\n    prompt: str,\n    answer: str,\n    device: torch.device,\n    memorize_cfg: MemorizeConfig,\n    *,\n    fast_state=None,\n) -> Dict[str, float]:\n    \"\"\"\n    Memorize using gradients for the answer tokens only.\n\n    This avoids updating on long filler/haystack tokens when `use_correct_answer=True`,\n    which otherwise makes the comparison noisy for randomly-initialized checkpoints.\n    \"\"\"\n    prompt_ids = tokenizer.encode(prompt, add_bos=True)\n    answer_ids = tokenizer.encode(\" \" + answer, add_bos=False)\n    inputs = torch.cat([prompt_ids, answer_ids], dim=0).to(device)\n    batch = inputs.unsqueeze(0)\n    teach_mask = torch.zeros((1, batch.size(1)), device=device)\n    start = max(0, prompt_ids.numel() - 1)\n    end = min(batch.size(1), start + answer_ids.numel())\n    teach_mask[:, start:end] = 1.0\n    return memorize_tokens(\n        model,\n        batch,\n        memorize_cfg,\n        fast_state=fast_state,\n        teach_mask=teach_mask,\n    )\n\n\ndef 
_make_passkey_prompt(*, filler_sentences: int, key: str) -> str:\n    sentences = [f\"This is filler sentence number {idx}.\" for idx in range(filler_sentences)]\n    random.shuffle(sentences)\n    filler = \" \".join(sentences)\n    return (\n        f\"{filler}\\nRemember that the passkey for this document is {key}. \"\n        \"Later we will ask about it.\\nQuestion: What is the passkey?\\nAnswer:\"\n    )\n\n\ndef _run_passkey(\n    model: torch.nn.Module,\n    tokenizer: SentencePieceTokenizer,\n    device: torch.device,\n    *,\n    samples: int,\n    filler_sentences: int,\n    memorize_cfg: MemorizeConfig,\n) -> Dict[str, Any]:\n    base_state: Dict[str, torch.Tensor] | None = None\n    fast_state = None\n    if memorize_cfg.enabled and (not memorize_cfg.use_fast_state) and memorize_cfg.reset:\n        base_state = snapshot_state_dict(model)\n\n    correct_base = 0\n    correct_mem = 0\n    true_lp_base_sum = 0.0\n    false_lp_base_sum = 0.0\n    margin_base_sum = 0.0\n    true_lp_mem_sum = 0.0\n    false_lp_mem_sum = 0.0\n    margin_mem_sum = 0.0\n    path_stats: Dict[str, float] = {}\n    for _ in tqdm(range(samples), desc=\"passkey\"):\n        key = f\"PASSKEY-{random.randint(1000, 9999)}\"\n        prompt = _make_passkey_prompt(filler_sentences=filler_sentences, key=key)\n        distractor = f\"PASSKEY-{random.randint(1000, 9999)}\"\n\n        lp_true = _logprob_answer(model, tokenizer, prompt, key, device, fast_state=fast_state)\n        lp_false = _logprob_answer(\n            model, tokenizer, prompt, distractor, device, fast_state=fast_state\n        )\n        correct_base += int(lp_true > lp_false)\n        true_lp_base_sum += lp_true\n        false_lp_base_sum += lp_false\n        margin_base_sum += lp_true - lp_false\n\n        if memorize_cfg.enabled:\n            if memorize_cfg.use_fast_state:\n                if fast_state is None or memorize_cfg.reset:\n                    if not hasattr(model, \"init_fast_state\"):\n                        raise RuntimeError(\"Model does not support fast state memorization\")\n                    fast_state = model.init_fast_state()\n                if memorize_cfg.use_correct_answer:\n                    stats = _memorize_prompt_answer_only(\n                        model,\n                        tokenizer,\n                        prompt,\n                        key,\n                        device,\n                        memorize_cfg,\n                        fast_state=fast_state,\n                    )\n                else:\n                    stats = memorize_sequence(\n                        model,\n                        tokenizer,\n                        prompt,\n                        device,\n                        memorize_cfg,\n                        fast_state=fast_state,\n                    )\n                for k, v in stats.items():\n                    path_stats[k] = path_stats.get(k, 0.0) + v\n                lp_true_mem = _logprob_answer(\n                    model, tokenizer, prompt, key, device, fast_state=fast_state\n                )\n                lp_false_mem = _logprob_answer(\n                    model, tokenizer, prompt, distractor, device, fast_state=fast_state\n                )\n                correct_mem += int(lp_true_mem > lp_false_mem)\n                true_lp_mem_sum += lp_true_mem\n                false_lp_mem_sum += lp_false_mem\n                margin_mem_sum += lp_true_mem - lp_false_mem\n            else:\n                memorize_text = prompt if not 
memorize_cfg.use_correct_answer else f\"{prompt} {key}\"\n                stats = memorize_sequence(model, tokenizer, memorize_text, device, memorize_cfg)\n                for k, v in stats.items():\n                    path_stats[k] = path_stats.get(k, 0.0) + v\n                lp_true_mem = _logprob_answer(model, tokenizer, prompt, key, device)\n                lp_false_mem = _logprob_answer(model, tokenizer, prompt, distractor, device)\n                correct_mem += int(lp_true_mem > lp_false_mem)\n                true_lp_mem_sum += lp_true_mem\n                false_lp_mem_sum += lp_false_mem\n                margin_mem_sum += lp_true_mem - lp_false_mem\n                if memorize_cfg.reset and base_state is not None:\n                    restore_state_dict(model, base_state)\n        else:\n            correct_mem += int(lp_true > lp_false)\n            true_lp_mem_sum += lp_true\n            false_lp_mem_sum += lp_false\n            margin_mem_sum += lp_true - lp_false\n\n    denom = float(samples) if samples else 1.0\n    base_acc = correct_base / denom\n    mem_acc = correct_mem / denom\n    payload: Dict[str, Any] = {\n        \"samples\": samples,\n        \"filler_sentences\": filler_sentences,\n        \"accuracy_base\": base_acc,\n        \"accuracy_memorize\": mem_acc,\n        \"accuracy_delta\": mem_acc - base_acc,\n        \"mean_logprob_true_base\": true_lp_base_sum / denom,\n        \"mean_logprob_true_memorize\": true_lp_mem_sum / denom,\n        \"mean_logprob_true_delta\": (true_lp_mem_sum - true_lp_base_sum) / denom,\n        \"mean_logprob_false_base\": false_lp_base_sum / denom,\n        \"mean_logprob_false_memorize\": false_lp_mem_sum / denom,\n        \"mean_logprob_false_delta\": (false_lp_mem_sum - false_lp_base_sum) / denom,\n        \"mean_margin_base\": margin_base_sum / denom,\n        \"mean_margin_memorize\": margin_mem_sum / denom,\n        \"mean_margin_delta\": (margin_mem_sum - margin_base_sum) / denom,\n    }\n    if memorize_cfg.enabled:\n        payload[\"memorize_paths\"] = (\n            \"all\" if memorize_cfg.paths is None else \",\".join(memorize_cfg.paths)\n        )\n        if memorize_cfg.surprise_threshold is not None:\n            payload[\"memorize_surprise_threshold\"] = memorize_cfg.surprise_threshold\n        payload[\"memorize_use_correct_answer\"] = bool(memorize_cfg.use_correct_answer)\n        if path_stats:\n            payload[\"memorize_stats\"] = path_stats\n    return payload\n\n\ndef _make_niah_prompt(*, needle: str, filler_tokens: int) -> str:\n    filler_chunks = [\"This is filler sentence number {}.\".format(i) for i in range(filler_tokens)]\n    random.shuffle(filler_chunks)\n    haystack = \" \".join(filler_chunks)\n    prompt = (\n        f\"{haystack} Remember that the secret key is {needle}. Later you might be asked about it. \"\n    )\n    prompt += \"Now answer the question truthfully. What is the secret key? 
Answer:\"\n    return prompt\n\n\ndef _run_niah(\n    model: torch.nn.Module,\n    tokenizer: SentencePieceTokenizer,\n    device: torch.device,\n    *,\n    context_lengths: List[int],\n    samples_per_length: int,\n    memorize_cfg: MemorizeConfig,\n) -> Dict[str, Any]:\n    base_state: Dict[str, torch.Tensor] | None = None\n    fast_state = None\n    if memorize_cfg.enabled and (not memorize_cfg.use_fast_state) and memorize_cfg.reset:\n        base_state = snapshot_state_dict(model)\n\n    results: Dict[str, Any] = {}\n    path_stats: Dict[str, float] = {}\n    for length in context_lengths:\n        correct_base = 0\n        correct_mem = 0\n        true_lp_base_sum = 0.0\n        false_lp_base_sum = 0.0\n        margin_base_sum = 0.0\n        true_lp_mem_sum = 0.0\n        false_lp_mem_sum = 0.0\n        margin_mem_sum = 0.0\n        for _ in tqdm(range(samples_per_length), desc=f\"niah@{length}\"):\n            needle = f\"KEY-{random.randint(1000, 9999)}\"\n            prompt = _make_niah_prompt(needle=needle, filler_tokens=max(1, length // 128))\n            distractor = f\"KEY-{random.randint(1000, 9999)}\"\n\n            lp_true_base = _logprob_answer(\n                model, tokenizer, prompt, needle, device, fast_state=fast_state\n            )\n            lp_false_base = _logprob_answer(\n                model, tokenizer, prompt, distractor, device, fast_state=fast_state\n            )\n            correct_base += int(lp_true_base > lp_false_base)\n            true_lp_base_sum += lp_true_base\n            false_lp_base_sum += lp_false_base\n            margin_base_sum += lp_true_base - lp_false_base\n\n            if memorize_cfg.enabled:\n                if memorize_cfg.use_fast_state:\n                    if fast_state is None or memorize_cfg.reset:\n                        if not hasattr(model, \"init_fast_state\"):\n                            raise RuntimeError(\"Model does not support fast state memorization\")\n                        fast_state = model.init_fast_state()\n                    if memorize_cfg.use_correct_answer:\n                        stats = _memorize_prompt_answer_only(\n                            model,\n                            tokenizer,\n                            prompt,\n                            needle,\n                            device,\n                            memorize_cfg,\n                            fast_state=fast_state,\n                        )\n                    else:\n                        stats = memorize_sequence(\n                            model,\n                            tokenizer,\n                            prompt,\n                            device,\n                            memorize_cfg,\n                            fast_state=fast_state,\n                        )\n                    for k, v in stats.items():\n                        path_stats[k] = path_stats.get(k, 0.0) + v\n                    lp_true_mem = _logprob_answer(\n                        model, tokenizer, prompt, needle, device, fast_state=fast_state\n                    )\n                    lp_false_mem = _logprob_answer(\n                        model, tokenizer, prompt, distractor, device, fast_state=fast_state\n                    )\n                    correct_mem += int(lp_true_mem > lp_false_mem)\n                    true_lp_mem_sum += lp_true_mem\n                    false_lp_mem_sum += lp_false_mem\n                    margin_mem_sum += lp_true_mem - lp_false_mem\n                else:\n                    memorize_text 
= (\n                        prompt if not memorize_cfg.use_correct_answer else f\"{prompt} {needle}\"\n                    )\n                    stats = memorize_sequence(model, tokenizer, memorize_text, device, memorize_cfg)\n                    for k, v in stats.items():\n                        path_stats[k] = path_stats.get(k, 0.0) + v\n                    lp_true_mem = _logprob_answer(model, tokenizer, prompt, needle, device)\n                    lp_false_mem = _logprob_answer(model, tokenizer, prompt, distractor, device)\n                    correct_mem += int(lp_true_mem > lp_false_mem)\n                    true_lp_mem_sum += lp_true_mem\n                    false_lp_mem_sum += lp_false_mem\n                    margin_mem_sum += lp_true_mem - lp_false_mem\n                    if memorize_cfg.reset and base_state is not None:\n                        restore_state_dict(model, base_state)\n            else:\n                correct_mem += int(lp_true_base > lp_false_base)\n                true_lp_mem_sum += lp_true_base\n                false_lp_mem_sum += lp_false_base\n                margin_mem_sum += lp_true_base - lp_false_base\n\n        base_acc = correct_base / samples_per_length if samples_per_length else 0.0\n        mem_acc = correct_mem / samples_per_length if samples_per_length else 0.0\n        results[f\"niah_{length}_baseline_accuracy\"] = base_acc\n        results[f\"niah_{length}_memorize_accuracy\"] = mem_acc\n        results[f\"niah_{length}_memorize_delta\"] = mem_acc - base_acc\n        denom = float(samples_per_length) if samples_per_length else 1.0\n        results[f\"niah_{length}_mean_logprob_true_base\"] = true_lp_base_sum / denom\n        results[f\"niah_{length}_mean_logprob_true_memorize\"] = true_lp_mem_sum / denom\n        results[f\"niah_{length}_mean_logprob_true_delta\"] = (\n            true_lp_mem_sum - true_lp_base_sum\n        ) / denom\n        results[f\"niah_{length}_mean_margin_base\"] = margin_base_sum / denom\n        results[f\"niah_{length}_mean_margin_memorize\"] = margin_mem_sum / denom\n        results[f\"niah_{length}_mean_margin_delta\"] = (margin_mem_sum - margin_base_sum) / denom\n\n    if memorize_cfg.enabled:\n        results[\"memorize_paths\"] = (\n            \"all\" if memorize_cfg.paths is None else \",\".join(memorize_cfg.paths)\n        )\n        if memorize_cfg.surprise_threshold is not None:\n            results[\"memorize_surprise_threshold\"] = memorize_cfg.surprise_threshold\n        results[\"memorize_use_correct_answer\"] = bool(memorize_cfg.use_correct_answer)\n        if path_stats:\n            results[\"memorize_stats\"] = path_stats\n    return results\n\n\n@app.command()\ndef main(\n    a_config: Path = typer.Option(..., help=\"Hydra config for model A.\"),\n    a_checkpoint: Path = typer.Option(..., help=\"Checkpoint for model A.\"),\n    b_config: Path = typer.Option(..., help=\"Hydra config for model B.\"),\n    b_checkpoint: Path = typer.Option(..., help=\"Checkpoint for model B.\"),\n    tokenizer_path: Path = typer.Option(..., help=\"SentencePiece tokenizer path.\"),\n    device: str = typer.Option(\"cuda:0\" if torch.cuda.is_available() else \"cpu\"),\n    output: Path = typer.Option(Path(\"eval/compare_variants.json\")),\n    seed: int = typer.Option(0, help=\"PRNG seed for prompt generation.\"),\n    smoke: bool = typer.Option(False, help=\"Use tiny settings for quick sanity checks.\"),\n    passkey_samples: int = typer.Option(64, help=\"Passkey prompts per model.\"),\n    passkey_filler_sentences: 
int = typer.Option(200, help=\"Filler sentences for passkey.\"),\n    niah_context_lengths: List[int] = typer.Option(\n        [2048, 4096, 8192], help=\"Context lengths for NIAH.\"\n    ),\n    niah_samples_per_length: int = typer.Option(50, help=\"Samples per NIAH length.\"),\n    memorize: bool = typer.Option(False, help=\"Enable test-time memorization for both models.\"),\n    memorize_steps: int = typer.Option(1, help=\"Memorization passes per prompt.\"),\n    memorize_use_correct_answer: bool = typer.Option(\n        False, help=\"Append ground truth during memorization.\"\n    ),\n    memorize_no_reset: bool = typer.Option(False, help=\"Retain memory between samples.\"),\n    memorize_surprise_threshold: float = typer.Option(\n        None, help=\"Minimum teach-signal norm required to trigger memorization.\"\n    ),\n    memorize_layers: str = typer.Option(\n        \"all\",\n        help=(\n            \"Comma-separated layer indices to update during memorization \"\n            \"(e.g., '11' or '0,11'), or 'last', or 'all'.\"\n        ),\n    ),\n    memorize_paths: str = typer.Option(\n        \"all\",\n        help=(\n            \"Comma-separated memory paths to update (e.g., 'titan,cms_fast'); \"\n            \"use 'all' for no restriction.\"\n        ),\n    ),\n) -> None:\n    random.seed(seed)\n    torch_device = resolve_device(device)\n    tokenizer = SentencePieceTokenizer(tokenizer_path)\n\n    if smoke:\n        passkey_samples = min(passkey_samples, 8)\n        passkey_filler_sentences = min(passkey_filler_sentences, 20)\n        niah_context_lengths = [256]\n        niah_samples_per_length = min(niah_samples_per_length, 8)\n\n    if memorize_paths.lower() == \"all\":\n        allowed_paths = None\n    else:\n        allowed_paths = tuple(path.strip() for path in memorize_paths.split(\",\") if path.strip())\n\n    layers_raw = memorize_layers.strip().lower()\n    if layers_raw == \"all\":\n        allowed_layers = None\n    elif layers_raw == \"last\":\n        allowed_layers = (-1,)\n    else:\n        parsed: list[int] = []\n        for part in memorize_layers.split(\",\"):\n            part = part.strip()\n            if not part:\n                continue\n            parsed.append(int(part))\n        allowed_layers = tuple(parsed) if parsed else None\n    memorize_cfg = MemorizeConfig(\n        enabled=memorize,\n        steps=max(1, memorize_steps),\n        reset=not memorize_no_reset,\n        use_correct_answer=memorize_use_correct_answer,\n        surprise_threshold=memorize_surprise_threshold,\n        paths=allowed_paths,\n        layers=allowed_layers,\n    )\n\n    spec_a = ModelSpec(name=\"A\", config=a_config, checkpoint=a_checkpoint)\n    spec_b = ModelSpec(name=\"B\", config=b_config, checkpoint=b_checkpoint)\n    model_a = _load_model(spec_a, torch_device)\n    model_b = _load_model(spec_b, torch_device)\n\n    payload: Dict[str, Any] = {\n        \"seed\": seed,\n        \"device\": str(torch_device),\n        \"tokenizer_path\": str(tokenizer_path),\n        \"a\": {\"config\": str(a_config), \"checkpoint\": str(a_checkpoint)},\n        \"b\": {\"config\": str(b_config), \"checkpoint\": str(b_checkpoint)},\n        \"memorize\": {\n            \"enabled\": memorize_cfg.enabled,\n            \"steps\": memorize_cfg.steps,\n            \"reset\": memorize_cfg.reset,\n            \"use_correct_answer\": bool(memorize_cfg.use_correct_answer),\n            \"paths\": \"all\" if memorize_cfg.paths is None else \",\".join(memorize_cfg.paths),\n            
\"surprise_threshold\": memorize_cfg.surprise_threshold,\n        },\n    }\n\n    payload[\"a\"][\"passkey\"] = _run_passkey(\n        model_a,\n        tokenizer,\n        torch_device,\n        samples=passkey_samples,\n        filler_sentences=passkey_filler_sentences,\n        memorize_cfg=memorize_cfg,\n    )\n    payload[\"b\"][\"passkey\"] = _run_passkey(\n        model_b,\n        tokenizer,\n        torch_device,\n        samples=passkey_samples,\n        filler_sentences=passkey_filler_sentences,\n        memorize_cfg=memorize_cfg,\n    )\n\n    payload[\"a\"][\"niah\"] = _run_niah(\n        model_a,\n        tokenizer,\n        torch_device,\n        context_lengths=niah_context_lengths,\n        samples_per_length=niah_samples_per_length,\n        memorize_cfg=memorize_cfg,\n    )\n    payload[\"b\"][\"niah\"] = _run_niah(\n        model_b,\n        tokenizer,\n        torch_device,\n        context_lengths=niah_context_lengths,\n        samples_per_length=niah_samples_per_length,\n        memorize_cfg=memorize_cfg,\n    )\n\n    output.parent.mkdir(parents=True, exist_ok=True)\n    output.write_text(json.dumps(payload, indent=2))\n    typer.echo(f\"[compare] Saved comparison to {output}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "scripts/eval/continual.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom collections import defaultdict\nfrom pathlib import Path\nfrom typing import Dict, List\n\nimport torch\nimport typer\nimport yaml\nfrom omegaconf import OmegaConf\nfrom torch.utils.data import DataLoader\n\nfrom nested_learning.data import TokenShardDataset, collate_batch\nfrom nested_learning.device import resolve_device\nfrom nested_learning.eval_state import (\n    EvalStreamingState,\n    forward_with_eval_state,\n    init_eval_streaming_state,\n    parse_eval_state_mode,\n)\nfrom nested_learning.memorize import (\n    MemorizeConfig,\n    memorize_tokens,\n    restore_state_dict,\n    snapshot_state_dict,\n)\nfrom nested_learning.training import build_model_from_cfg, unwrap_config\n\napp = typer.Typer(add_completion=False, help=\"Continual learning evaluation harness.\")\n\n\ndef load_segments(yaml_path: Path) -> List[Dict[str, str]]:\n    payload = yaml.safe_load(yaml_path.read_text())\n    return payload.get(\"segments\", [])\n\n\ndef evaluate_segment(\n    model,\n    dataloader: DataLoader,\n    device: torch.device,\n    max_batches: int | None,\n    memorize_cfg: MemorizeConfig,\n    *,\n    eval_state_mode: str,\n    eval_use_fast_state: bool,\n    eval_use_attention_cache: bool,\n) -> tuple[float, float, Dict[str, float]]:\n    model.eval()\n    total_loss_base = 0.0\n    total_loss_mem = 0.0\n    total_tokens = 0\n    batches = 0\n    path_stats: Dict[str, float] = defaultdict(float)\n    base_state: Dict[str, torch.Tensor] | None = None\n    fast_state = None\n    carry_eval_state = parse_eval_state_mode(eval_state_mode)\n    eval_state_base: EvalStreamingState | None = None\n    eval_state_mem: EvalStreamingState | None = None\n    if eval_use_fast_state or eval_use_attention_cache:\n        eval_state_base = init_eval_streaming_state(\n            model,\n            use_fast_state=eval_use_fast_state,\n            use_attention_cache=eval_use_attention_cache,\n        )\n        eval_state_mem = init_eval_streaming_state(\n            model,\n            use_fast_state=eval_use_fast_state,\n            use_attention_cache=eval_use_attention_cache,\n        )\n    if memorize_cfg.enabled and (not memorize_cfg.use_fast_state) and memorize_cfg.reset:\n        base_state = snapshot_state_dict(model)\n    for batch in dataloader:\n        tokens = batch.to(device)\n        if (eval_use_fast_state or eval_use_attention_cache) and not carry_eval_state:\n            eval_state_base = init_eval_streaming_state(\n                model,\n                use_fast_state=eval_use_fast_state,\n                use_attention_cache=eval_use_attention_cache,\n            )\n            eval_state_mem = init_eval_streaming_state(\n                model,\n                use_fast_state=eval_use_fast_state,\n                use_attention_cache=eval_use_attention_cache,\n            )\n        with torch.no_grad():\n            logits, eval_state_base = forward_with_eval_state(\n                model,\n                tokens,\n                state=eval_state_base,\n            )\n            loss = torch.nn.functional.cross_entropy(\n                logits[:, :-1].reshape(-1, logits.size(-1)),\n                tokens[:, 1:].reshape(-1),\n                reduction=\"sum\",\n            )\n        total_loss_base += loss.item()\n        if memorize_cfg.enabled:\n            if memorize_cfg.use_fast_state:\n                if fast_state is None or memorize_cfg.reset:\n                    if not 
hasattr(model, \"init_fast_state\"):\n                        raise RuntimeError(\"Model does not support fast state memorization\")\n                    fast_state = model.init_fast_state()\n                stats = memorize_tokens(model, tokens, memorize_cfg, fast_state=fast_state)\n            else:\n                stats = memorize_tokens(model, tokens, memorize_cfg)\n            for key, value in stats.items():\n                path_stats[key] += value\n            with torch.no_grad():\n                if eval_use_fast_state or eval_use_attention_cache:\n                    logits_mem, eval_state_mem = forward_with_eval_state(\n                        model,\n                        tokens,\n                        state=eval_state_mem,\n                    )\n                else:\n                    logits_mem = (\n                        model(tokens, fast_state=fast_state)\n                        if memorize_cfg.use_fast_state\n                        else model(tokens)\n                    )\n                loss_mem = torch.nn.functional.cross_entropy(\n                    logits_mem[:, :-1].reshape(-1, logits_mem.size(-1)),\n                    tokens[:, 1:].reshape(-1),\n                    reduction=\"sum\",\n                )\n            total_loss_mem += loss_mem.item()\n            if (not memorize_cfg.use_fast_state) and memorize_cfg.reset and base_state is not None:\n                restore_state_dict(model, base_state)\n        else:\n            total_loss_mem += loss.item()\n        total_tokens += tokens[:, 1:].numel()\n        batches += 1\n        if max_batches and batches >= max_batches:\n            break\n    base_ce = total_loss_base / total_tokens if total_tokens > 0 else float(\"nan\")\n    mem_ce = total_loss_mem / total_tokens if total_tokens > 0 else float(\"nan\")\n    return base_ce, mem_ce, path_stats\n\n\n@app.command()\ndef main(\n    config: Path = typer.Option(..., help=\"Hydra model config for HOPE.\"),\n    checkpoints: List[Path] = typer.Option(\n        ..., help=\"Ordered list of checkpoints (chronological).\"\n    ),\n    segments_yaml: Path = typer.Option(..., help=\"YAML describing shard directories per segment.\"),\n    tokenizer_path: Path = typer.Option(..., help=\"SentencePiece model path (unused for now).\"),\n    batch_size: int = typer.Option(4, help=\"Batch size for evaluation.\"),\n    max_batches: int = typer.Option(50, help=\"Max batches per segment (0 = entire dataset).\"),\n    device: str = typer.Option(\"cuda:0\" if torch.cuda.is_available() else \"cpu\"),\n    output: Path = typer.Option(Path(\"eval/continual_results.json\")),\n    memorize: bool = typer.Option(False, help=\"Enable memorization while evaluating segments.\"),\n    memorize_steps: int = typer.Option(1, help=\"Memorization passes per batch.\"),\n    memorize_no_reset: bool = typer.Option(True, help=\"Keep memory between segments by default.\"),\n    memorize_surprise_threshold: float = typer.Option(\n        None, help=\"Minimum teach-signal norm needed to memorize a batch.\"\n    ),\n    memorize_paths: str = typer.Option(\n        \"all\",\n        help=(\n            \"Comma-separated memory paths to update (e.g., 'titan,cms_fast'); \"\n            \"use 'all' for default behavior.\"\n        ),\n    ),\n    eval_state_mode: str = typer.Option(\n        \"reset_per_sample\",\n        help=\"Streaming eval state mode: 'reset_per_sample' or 'carry_across_samples'.\",\n    ),\n    eval_use_fast_state: bool = typer.Option(\n        False,\n        help=\"Use 
model fast state during inference scoring (independent from memorization state).\",\n    ),\n    eval_use_attention_cache: bool = typer.Option(\n        False,\n        help=\"Use attention KV cache during inference scoring.\",\n    ),\n) -> None:\n    segments = load_segments(segments_yaml)\n    if not segments:\n        raise typer.BadParameter(\"No segments found in YAML.\")\n\n    cfg = OmegaConf.load(config)\n    cfg = unwrap_config(cfg)\n    device_obj = resolve_device(device)\n    results = []\n\n    if memorize_paths.lower() == \"all\":\n        allowed_paths = None\n    else:\n        allowed_paths = tuple(path.strip() for path in memorize_paths.split(\",\") if path.strip())\n    memorize_cfg = MemorizeConfig(\n        enabled=memorize,\n        steps=max(1, memorize_steps),\n        reset=not memorize_no_reset,\n        use_correct_answer=False,\n        surprise_threshold=memorize_surprise_threshold,\n        paths=allowed_paths,\n    )\n\n    for step_idx, ckpt_path in enumerate(checkpoints):\n        state = torch.load(ckpt_path, map_location=\"cpu\", weights_only=False)\n        model = build_model_from_cfg(cfg.model)\n        state_dict = state[\"model\"] if \"model\" in state else state\n        missing, unexpected = model.load_state_dict(state_dict, strict=False)\n        if missing or unexpected:\n            print(\n                \"[continual] Warning: state_dict mismatch \"\n                f\"(missing={len(missing)} unexpected={len(unexpected)}) – continuing.\"\n            )\n        model = model.to(device_obj)\n\n        segment_losses = {}\n        baseline_losses = {}\n        segment_stats = {}\n        for segment in segments:\n            name = segment[\"name\"]\n            shards_dir = Path(segment[\"shards_dir\"])\n            dataset = TokenShardDataset(shards_dir)\n            loader = DataLoader(\n                dataset,\n                batch_size=batch_size,\n                shuffle=False,\n                num_workers=0,\n                collate_fn=collate_batch,\n            )\n            base_loss, mem_loss, stats = evaluate_segment(\n                model,\n                loader,\n                device_obj,\n                None if max_batches <= 0 else max_batches,\n                memorize_cfg,\n                eval_state_mode=eval_state_mode,\n                eval_use_fast_state=eval_use_fast_state,\n                eval_use_attention_cache=eval_use_attention_cache,\n            )\n            baseline_losses[name] = base_loss\n            segment_losses[name] = mem_loss\n            if stats:\n                segment_stats[name] = stats\n\n        entry = {\"checkpoint\": str(ckpt_path), \"segment_losses\": segment_losses}\n        if memorize_cfg.enabled:\n            entry[\"segment_baseline_losses\"] = baseline_losses\n            entry[\"segment_memorize_delta\"] = {\n                name: baseline_losses[name] - segment_losses[name] for name in segment_losses\n            }\n            if segment_stats:\n                entry[\"memorize_stats\"] = segment_stats\n        entry[\"eval_state_mode\"] = eval_state_mode\n        entry[\"eval_use_fast_state\"] = bool(eval_use_fast_state)\n        entry[\"eval_use_attention_cache\"] = bool(eval_use_attention_cache)\n        results.append(entry)\n\n    output.parent.mkdir(parents=True, exist_ok=True)\n    output.write_text(json.dumps(results, indent=2))\n    typer.echo(f\"[Continual] Saved results to {output}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "scripts/eval/continual_classification.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\nfrom typing import List\n\nimport torch\nimport typer\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.continual_classification import (\n    ClassificationExample,\n    load_banking77,\n    load_clinc_oos,\n    load_dbpedia14,\n)\nfrom nested_learning.continual_streaming import (\n    ContinualEvalConfig,\n    build_streaming_tasks,\n    evaluate_continual_classification,\n)\nfrom nested_learning.device import resolve_device\nfrom nested_learning.memorize import MemorizeConfig\nfrom nested_learning.tokenizer import SentencePieceTokenizer\nfrom nested_learning.training import build_model_from_cfg, unwrap_config\n\napp = typer.Typer(\n    add_completion=False,\n    help=\"Class-incremental continual-learning harness (CLINC/Banking/DBpedia).\",\n)\n\n\ndef _load_local_jsonl(path: Path) -> List[ClassificationExample]:\n    examples: List[ClassificationExample] = []\n    for line in path.read_text().splitlines():\n        if not line.strip():\n            continue\n        row = json.loads(line)\n        examples.append(ClassificationExample(text=str(row[\"text\"]), label=str(row[\"label\"])))\n    return examples\n\n\ndef _load_examples(\n    dataset: str, *, split: str, max_samples: int | None\n) -> List[ClassificationExample]:\n    dataset = dataset.strip().lower()\n    if dataset == \"clinc\":\n        return load_clinc_oos(split=split, max_samples=max_samples).examples\n    if dataset == \"banking77\":\n        return load_banking77(split=split, max_samples=max_samples).examples\n    if dataset == \"dbpedia14\":\n        return load_dbpedia14(split=split, max_samples=max_samples).examples\n    raise typer.BadParameter(\"dataset must be one of: clinc, banking77, dbpedia14\")\n\n\n@app.command()\ndef main(\n    config: Path = typer.Option(..., help=\"Hydra model config path.\"),\n    checkpoint: Path = typer.Option(..., help=\"Checkpoint path.\"),\n    tokenizer_path: Path = typer.Option(..., help=\"SentencePiece tokenizer path.\"),\n    dataset: str = typer.Option(\"clinc\", help=\"Dataset: clinc | banking77 | dbpedia14.\"),\n    split: str = typer.Option(\"test\", help=\"HF split to load.\"),\n    local_jsonl: Path = typer.Option(\n        None,\n        help=\"Optional local JSONL (each line: {'text':..., 'label':...}); bypasses HF datasets.\",\n    ),\n    task_size: int = typer.Option(10, help=\"Number of classes per task.\"),\n    train_per_label: int = typer.Option(25, help=\"Streaming examples per label.\"),\n    eval_per_label: int = typer.Option(25, help=\"Eval examples per label.\"),\n    seed: int = typer.Option(0, help=\"Label/task shuffle seed.\"),\n    task_aware: bool = typer.Option(True, help=\"Restrict candidates to current task labels.\"),\n    max_samples: int = typer.Option(\n        0, help=\"Max dataset samples (0 = no limit); recommended for smoke runs.\"\n    ),\n    device: str = typer.Option(\"cuda:0\" if torch.cuda.is_available() else \"cpu\"),\n    output: Path = typer.Option(Path(\"eval/continual_classification.json\")),\n    smoke: bool = typer.Option(False, help=\"Tiny settings for quick sanity checks.\"),\n    memorize: bool = typer.Option(False, help=\"Enable test-time memorization during streaming.\"),\n    memorize_steps: int = typer.Option(1, help=\"Memorization passes per example.\"),\n    memorize_no_reset: bool = typer.Option(\n        True, help=\"Keep memory across examples/tasks by default.\"\n    ),\n    
memorize_surprise_threshold: float = typer.Option(\n        None, help=\"Minimum teach-signal norm required to trigger memorization.\"\n    ),\n    memorize_paths: str = typer.Option(\n        \"all\",\n        help=(\n            \"Comma-separated memory paths to update (e.g., 'titan,cms_fast'); \"\n            \"use 'all' for no restriction.\"\n        ),\n    ),\n) -> None:\n    torch_device = resolve_device(device)\n    cfg = OmegaConf.load(config)\n    cfg = unwrap_config(cfg)\n    model = build_model_from_cfg(cfg.model)\n    state = torch.load(checkpoint, map_location=\"cpu\", weights_only=False)\n    state_dict = state[\"model\"] if \"model\" in state else state\n    missing, unexpected = model.load_state_dict(state_dict, strict=False)\n    if missing or unexpected:\n        print(\n            \"[continual_cls] Warning: state_dict mismatch \"\n            f\"(missing={len(missing)} unexpected={len(unexpected)}) – continuing.\"\n        )\n    model = model.to(torch_device).eval()\n    tokenizer = SentencePieceTokenizer(tokenizer_path)\n\n    resolved_max = None if max_samples <= 0 else int(max_samples)\n    if smoke:\n        resolved_max = 500\n        task_size = min(task_size, 3)\n        train_per_label = min(train_per_label, 2)\n        eval_per_label = min(eval_per_label, 2)\n\n    if local_jsonl is not None:\n        examples = _load_local_jsonl(local_jsonl)\n    else:\n        examples = _load_examples(dataset, split=split, max_samples=resolved_max)\n\n    eval_cfg = ContinualEvalConfig(\n        task_size=task_size,\n        seed=seed,\n        train_per_label=train_per_label,\n        eval_per_label=eval_per_label,\n        task_aware=task_aware,\n    )\n    tasks = build_streaming_tasks(examples, cfg=eval_cfg)\n\n    if memorize_paths.lower() == \"all\":\n        allowed_paths = None\n    else:\n        allowed_paths = tuple(path.strip() for path in memorize_paths.split(\",\") if path.strip())\n    memorize_cfg = MemorizeConfig(\n        enabled=memorize,\n        steps=max(1, memorize_steps),\n        reset=not memorize_no_reset,\n        use_correct_answer=True,\n        surprise_threshold=memorize_surprise_threshold,\n        paths=allowed_paths,\n    )\n\n    result, meta = evaluate_continual_classification(\n        model,\n        tokenizer,\n        tasks,\n        torch_device,\n        cfg=eval_cfg,\n        memorize_cfg=memorize_cfg,\n    )\n    payload = {\n        \"dataset\": dataset if local_jsonl is None else str(local_jsonl),\n        \"split\": split,\n        \"config\": str(config),\n        \"checkpoint\": str(checkpoint),\n        \"tokenizer_path\": str(tokenizer_path),\n        \"device\": str(torch_device),\n        \"tasks\": [\n            {\"task_id\": t.task_id, \"labels\": t.labels, \"train\": len(t.train), \"eval\": len(t.eval)}\n            for t in tasks\n        ],\n        \"result\": {\n            \"avg_accuracy_final\": result.avg_accuracy_final,\n            \"avg_forgetting\": result.avg_forgetting,\n            \"per_task_forgetting\": result.per_task_forgetting,\n            \"task_accuracy_matrix\": result.task_accuracy_matrix,\n        },\n        \"meta\": meta,\n    }\n    output.parent.mkdir(parents=True, exist_ok=True)\n    output.write_text(json.dumps(payload, indent=2))\n    typer.echo(f\"[continual_cls] Saved results to {output}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "scripts/eval/niah.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nimport random\nfrom collections import defaultdict\nfrom pathlib import Path\nfrom typing import Dict, List\n\nimport torch\nimport typer\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\n\nfrom nested_learning.device import resolve_device\nfrom nested_learning.memorize import (\n    MemorizeConfig,\n    memorize_sequence,\n    restore_state_dict,\n    snapshot_state_dict,\n)\nfrom nested_learning.model import HOPEModel\nfrom nested_learning.tokenizer import SentencePieceTokenizer\nfrom nested_learning.training import build_model_from_cfg, unwrap_config\n\napp = typer.Typer(add_completion=False, help=\"Needle-in-a-haystack evaluation scaffolding.\")\n\n\ndef load_model(config_path: Path, checkpoint: Path, device: torch.device) -> HOPEModel:\n    cfg = OmegaConf.load(config_path)\n    cfg = unwrap_config(cfg)\n    model = build_model_from_cfg(cfg.model)\n    state = torch.load(checkpoint, map_location=\"cpu\", weights_only=False)\n    state_dict = state[\"model\"] if \"model\" in state else state\n    missing, unexpected = model.load_state_dict(state_dict, strict=False)\n    if missing or unexpected:\n        print(\n            \"[eval] Warning: state_dict mismatch \"\n            f\"(missing={len(missing)} unexpected={len(unexpected)}) – continuing.\"\n        )\n    return model.to(device).eval()\n\n\ndef make_prompt(needle: str, filler_tokens: int) -> str:\n    filler_chunks = [\"This is filler sentence number {}.\".format(i) for i in range(filler_tokens)]\n    random.shuffle(filler_chunks)\n    haystack = \" \".join(filler_chunks)\n    prompt = (\n        f\"{haystack} Remember that the secret key is {needle}. Later you might be asked about it. \"\n    )\n    prompt += \"Now answer the question truthfully. What is the secret key? 
Answer:\"\n    return prompt\n\n\ndef logprob_answer(\n    model: HOPEModel,\n    tokenizer: SentencePieceTokenizer,\n    prompt: str,\n    answer: str,\n    device: torch.device,\n    *,\n    fast_state=None,\n) -> float:\n    prompt_ids = tokenizer.encode(prompt, add_bos=True)\n    answer_ids = tokenizer.encode(\" \" + answer, add_bos=False)\n    inputs = torch.cat([prompt_ids, answer_ids], dim=0).to(device)\n    with torch.no_grad():\n        logits = (\n            model(inputs.unsqueeze(0), fast_state=fast_state)\n            if fast_state is not None\n            else model(inputs.unsqueeze(0))\n        )\n        log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)\n        target = inputs.unsqueeze(0)[:, 1:]\n        gathered = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)\n        prompt_len = prompt_ids.numel()\n        answer_logprob = gathered[0, prompt_len - 1 :].sum().item()\n    return answer_logprob\n\n\n@app.command()\ndef main(\n    config: Path = typer.Option(..., help=\"Hydra config path.\"),\n    checkpoint: Path = typer.Option(..., help=\"Checkpoint to evaluate.\"),\n    tokenizer_path: Path = typer.Option(..., help=\"SentencePiece tokenizer path.\"),\n    context_lengths: List[int] = typer.Option([2048, 4096, 8192], help=\"Context lengths to probe.\"),\n    samples_per_length: int = typer.Option(50, help=\"Samples per context length.\"),\n    device: str = typer.Option(\"cuda:0\" if torch.cuda.is_available() else \"cpu\"),\n    output: Path = typer.Option(Path(\"eval/niah_results.json\")),\n    memorize: bool = typer.Option(False, help=\"Enable test-time memorization for each prompt.\"),\n    memorize_steps: int = typer.Option(1, help=\"Memorization passes per prompt.\"),\n    memorize_use_correct_answer: bool = typer.Option(\n        False, help=\"Include correct key when memorizing.\"\n    ),\n    memorize_no_reset: bool = typer.Option(False, help=\"Retain memory between samples.\"),\n    memorize_surprise_threshold: float = typer.Option(\n        None, help=\"Minimum teach-signal norm required to trigger memorization.\"\n    ),\n    memorize_paths: str = typer.Option(\n        \"all\",\n        help=(\n            \"Comma-separated memory paths to update (e.g., 'titan,cms_fast'); \"\n            \"use 'all' for no restriction.\"\n        ),\n    ),\n    eval_state_mode: str = typer.Option(\n        \"reset_per_sample\",\n        help=\"Streaming eval state mode. 
Currently only 'reset_per_sample' is supported here.\",\n    ),\n) -> None:\n    if eval_state_mode.strip().lower() not in {\"reset\", \"isolated\", \"reset_per_sample\"}:\n        raise typer.BadParameter(\n            \"NIAH binary-choice scoring currently supports eval_state_mode=reset_per_sample only.\"\n        )\n    torch_device = resolve_device(device)\n    model = load_model(config, checkpoint, torch_device)\n    tokenizer = SentencePieceTokenizer(tokenizer_path)\n    if memorize_paths.lower() == \"all\":\n        allowed_paths = None\n    else:\n        allowed_paths = tuple(path.strip() for path in memorize_paths.split(\",\") if path.strip())\n    memorize_cfg = MemorizeConfig(\n        enabled=memorize,\n        steps=max(1, memorize_steps),\n        reset=not memorize_no_reset,\n        use_correct_answer=memorize_use_correct_answer,\n        surprise_threshold=memorize_surprise_threshold,\n        paths=allowed_paths,\n    )\n    base_state: Dict[str, torch.Tensor] | None = None\n    fast_state = None\n    results = {}\n    path_stats: Dict[str, float] = defaultdict(float)\n    for length in context_lengths:\n        correct_base = 0\n        correct_mem = 0\n        for _ in tqdm(range(samples_per_length), desc=f\"NIAH@{length}\"):\n            needle = f\"KEY-{random.randint(1000, 9999)}\"\n            prompt = make_prompt(needle, filler_tokens=max(1, length // 128))\n            distractor = f\"KEY-{random.randint(1000, 9999)}\"\n            logprob_true_base = logprob_answer(model, tokenizer, prompt, needle, torch_device)\n            logprob_false_base = logprob_answer(model, tokenizer, prompt, distractor, torch_device)\n            correct_base += int(logprob_true_base > logprob_false_base)\n            if memorize_cfg.enabled:\n                memorize_text = prompt\n                if memorize_cfg.use_correct_answer:\n                    memorize_text = f\"{prompt} {needle}\"\n                if memorize_cfg.use_fast_state:\n                    if fast_state is None or memorize_cfg.reset:\n                        if not hasattr(model, \"init_fast_state\"):\n                            raise RuntimeError(\"Model does not support fast state memorization\")\n                        fast_state = model.init_fast_state()\n                    stats = memorize_sequence(\n                        model,\n                        tokenizer,\n                        memorize_text,\n                        torch_device,\n                        memorize_cfg,\n                        fast_state=fast_state,\n                    )\n                    for key, value in stats.items():\n                        path_stats[key] += value\n                    logprob_true_mem = logprob_answer(\n                        model, tokenizer, prompt, needle, torch_device, fast_state=fast_state\n                    )\n                    logprob_false_mem = logprob_answer(\n                        model, tokenizer, prompt, distractor, torch_device, fast_state=fast_state\n                    )\n                    correct_mem += int(logprob_true_mem > logprob_false_mem)\n                else:\n                    if memorize_cfg.reset and base_state is None:\n                        base_state = snapshot_state_dict(model)\n                    stats = memorize_sequence(\n                        model, tokenizer, memorize_text, torch_device, memorize_cfg\n                    )\n                    for key, value in stats.items():\n                        path_stats[key] += value\n                    
logprob_true_mem = logprob_answer(\n                        model, tokenizer, prompt, needle, torch_device\n                    )\n                    logprob_false_mem = logprob_answer(\n                        model, tokenizer, prompt, distractor, torch_device\n                    )\n                    correct_mem += int(logprob_true_mem > logprob_false_mem)\n                    if memorize_cfg.reset and base_state is not None:\n                        restore_state_dict(model, base_state)\n            else:\n                correct_mem += int(logprob_true_base > logprob_false_base)\n        base_acc = correct_base / samples_per_length if samples_per_length else 0.0\n        mem_acc = correct_mem / samples_per_length if samples_per_length else 0.0\n        results[f\"niah_{length}\"] = mem_acc\n        if memorize_cfg.enabled:\n            results[f\"niah_{length}_baseline_accuracy\"] = base_acc\n            results[f\"niah_{length}_memorize_accuracy\"] = mem_acc\n            results[f\"niah_{length}_memorize_delta\"] = mem_acc - base_acc\n    if memorize_cfg.enabled:\n        for key, value in path_stats.items():\n            results[f\"niah_{key}\"] = value\n        results[\"niah_memorize_paths\"] = (\n            \"all\" if memorize_cfg.paths is None else \",\".join(memorize_cfg.paths)\n        )\n        if memorize_cfg.surprise_threshold is not None:\n            results[\"niah_memorize_surprise_threshold\"] = memorize_cfg.surprise_threshold\n    output.parent.mkdir(parents=True, exist_ok=True)\n    output.write_text(json.dumps(results, indent=2))\n    typer.echo(f\"[Eval] Saved NIAH metrics to {output}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "scripts/eval/niah_suite.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nimport random\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any, Dict, List\n\nimport torch\nimport typer\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\n\nfrom nested_learning.device import resolve_device\nfrom nested_learning.memorize import (\n    MemorizeConfig,\n    memorize_sequence,\n    restore_state_dict,\n    snapshot_state_dict,\n)\nfrom nested_learning.tokenizer import SentencePieceTokenizer\nfrom nested_learning.training import build_model_from_cfg, unwrap_config\n\napp = typer.Typer(add_completion=False, help=\"RULER-ish NIAH suite (multiple retrieval variants).\")\n\n\ndef load_model(config_path: Path, checkpoint: Path, device: torch.device):\n    cfg = OmegaConf.load(config_path)\n    cfg = unwrap_config(cfg)\n    model = build_model_from_cfg(cfg.model)\n    state = torch.load(checkpoint, map_location=\"cpu\", weights_only=False)\n    state_dict = state[\"model\"] if \"model\" in state else state\n    missing, unexpected = model.load_state_dict(state_dict, strict=False)\n    if missing or unexpected:\n        print(\n            \"[niah_suite] Warning: state_dict mismatch \"\n            f\"(missing={len(missing)} unexpected={len(unexpected)}) – continuing.\"\n        )\n    return model.to(device).eval()\n\n\ndef _logprob_answer(\n    model,\n    tokenizer: SentencePieceTokenizer,\n    prompt: str,\n    answer: str,\n    device: torch.device,\n    *,\n    fast_state=None,\n) -> float:\n    prompt_ids = tokenizer.encode(prompt, add_bos=True)\n    answer_ids = tokenizer.encode(\" \" + answer, add_bos=False)\n    inputs = torch.cat([prompt_ids, answer_ids], dim=0).to(device)\n    with torch.no_grad():\n        logits = (\n            model(inputs.unsqueeze(0), fast_state=fast_state)\n            if fast_state is not None\n            else model(inputs.unsqueeze(0))\n        )\n        log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)\n        target = inputs.unsqueeze(0)[:, 1:]\n        gathered = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)\n        prompt_len = prompt_ids.numel()\n        return float(gathered[0, prompt_len - 1 :].sum().item())\n\n\ndef _filler_sentences(count: int) -> List[str]:\n    return [f\"This is filler sentence number {idx}.\" for idx in range(count)]\n\n\ndef _ensure_prompt_length(\n    tokenizer: SentencePieceTokenizer,\n    *,\n    base_lines: List[str],\n    target_tokens: int,\n    rng: random.Random,\n    max_filler: int = 50_000,\n) -> str:\n    filler = []\n    filler_count = max(1, target_tokens // 32)\n    while True:\n        filler = _filler_sentences(filler_count)\n        rng.shuffle(filler)\n        prompt = \"\\n\".join([*filler, *base_lines])\n        token_len = int(tokenizer.encode(prompt, add_bos=True).numel())\n        if token_len >= target_tokens:\n            return prompt\n        if filler_count >= max_filler:\n            return prompt\n        missing = target_tokens - token_len\n        filler_count += max(1, missing // 16)\n\n\n@dataclass(frozen=True)\nclass VariantCase:\n    prompt: str\n    answer: str\n    distractor: str\n\n\ndef _case_single_needle(rng: random.Random) -> VariantCase:\n    needle = f\"KEY-{rng.randint(1000, 9999)}\"\n    prompt_lines = [\n        f\"Remember that the secret key is {needle}.\",\n        \"Later you might be asked about it.\",\n        \"Question: What is the secret key?\",\n        \"Answer:\",\n    ]\n    distractor = 
f\"KEY-{rng.randint(1000, 9999)}\"\n    return VariantCase(prompt=\"\\n\".join(prompt_lines), answer=needle, distractor=distractor)\n\n\ndef _case_multi_needle(rng: random.Random, *, needles: int) -> VariantCase:\n    keys = [f\"KEY-{rng.randint(1000, 9999)}\" for _ in range(max(2, needles))]\n    query_idx = rng.randrange(len(keys))\n    prompt_lines = [\"Memorize the following secret keys:\"]\n    for idx, key in enumerate(keys, start=1):\n        prompt_lines.append(f\"Key {idx}: {key}.\")\n    prompt_lines.extend(\n        [\n            f\"Question: What is Key {query_idx + 1}?\",\n            \"Answer:\",\n        ]\n    )\n    distractor = f\"KEY-{rng.randint(1000, 9999)}\"\n    return VariantCase(\n        prompt=\"\\n\".join(prompt_lines), answer=keys[query_idx], distractor=distractor\n    )\n\n\ndef _case_kv_single(rng: random.Random) -> VariantCase:\n    key = f\"ITEM-{rng.randint(100, 999)}\"\n    value = f\"VALUE-{rng.randint(1000, 9999)}\"\n    prompt_lines = [\n        \"Memorize this key-value pair:\",\n        f\"{key} -> {value}.\",\n        f\"Question: What is the value for {key}?\",\n        \"Answer:\",\n    ]\n    distractor = f\"VALUE-{rng.randint(1000, 9999)}\"\n    return VariantCase(prompt=\"\\n\".join(prompt_lines), answer=value, distractor=distractor)\n\n\ndef _case_kv_multi(rng: random.Random, *, pairs: int) -> VariantCase:\n    pairs = max(2, pairs)\n    keys = [f\"ITEM-{rng.randint(100, 999)}\" for _ in range(pairs)]\n    values = [f\"VALUE-{rng.randint(1000, 9999)}\" for _ in range(pairs)]\n    query_idx = rng.randrange(pairs)\n    prompt_lines = [\"Memorize the following key-value pairs:\"]\n    for k, v in zip(keys, values, strict=True):\n        prompt_lines.append(f\"{k} -> {v}.\")\n    prompt_lines.extend(\n        [\n            f\"Question: What is the value for {keys[query_idx]}?\",\n            \"Answer:\",\n        ]\n    )\n    distractor = f\"VALUE-{rng.randint(1000, 9999)}\"\n    return VariantCase(\n        prompt=\"\\n\".join(prompt_lines), answer=values[query_idx], distractor=distractor\n    )\n\n\ndef _case_positioned_needle(rng: random.Random, *, position: str) -> VariantCase:\n    needle = f\"KEY-{rng.randint(1000, 9999)}\"\n    prompt_lines = [\n        f\"Remember that the secret key is {needle}.\",\n        \"Question: What is the secret key?\",\n        \"Answer:\",\n    ]\n    distractor = f\"KEY-{rng.randint(1000, 9999)}\"\n    return VariantCase(prompt=\"\\n\".join(prompt_lines), answer=needle, distractor=distractor)\n\n\ndef _variant_cases(rng: random.Random, *, variant: str) -> VariantCase:\n    if variant == \"single_needle\":\n        return _case_single_needle(rng)\n    if variant == \"multi_needle\":\n        return _case_multi_needle(rng, needles=4)\n    if variant == \"kv_single\":\n        return _case_kv_single(rng)\n    if variant == \"kv_multi\":\n        return _case_kv_multi(rng, pairs=6)\n    if variant in {\"needle_early\", \"needle_mid\", \"needle_late\"}:\n        pos = variant.split(\"_\", 1)[1]\n        return _case_positioned_needle(rng, position=pos)\n    raise ValueError(f\"Unknown variant: {variant}\")\n\n\ndef _evaluate_variant(\n    model,\n    tokenizer: SentencePieceTokenizer,\n    device: torch.device,\n    *,\n    variant: str,\n    context_tokens: int,\n    samples: int,\n    rng: random.Random,\n    memorize_cfg: MemorizeConfig,\n) -> Dict[str, Any]:\n    base_state: Dict[str, torch.Tensor] | None = None\n    fast_state = None\n    if memorize_cfg.enabled and (not memorize_cfg.use_fast_state) and 
memorize_cfg.reset:\n        base_state = snapshot_state_dict(model)\n\n    correct_base = 0\n    correct_mem = 0\n    path_stats: Dict[str, float] = {}\n    for _ in tqdm(range(samples), desc=f\"{variant}@{context_tokens}\"):\n        case = _variant_cases(rng, variant=variant)\n        if variant in {\"needle_early\", \"needle_mid\", \"needle_late\"}:\n            memory_line, question_line, answer_line = case.prompt.split(\"\\n\", 2)\n            if variant == \"needle_early\":\n                ratio = 0.1\n            elif variant == \"needle_late\":\n                ratio = 0.9\n            else:\n                ratio = 0.5\n            filler_count = max(1, context_tokens // 32)\n            while True:\n                filler = _filler_sentences(filler_count)\n                rng.shuffle(filler)\n                insert_at = int(ratio * max(1, len(filler)))\n                insert_at = max(0, min(insert_at, len(filler)))\n                with_memory = filler[:insert_at] + [memory_line] + filler[insert_at:]\n                prompt = \"\\n\".join([*with_memory, question_line, answer_line])\n                token_len = int(tokenizer.encode(prompt, add_bos=True).numel())\n                if token_len >= context_tokens:\n                    break\n                filler_count += max(1, (context_tokens - token_len) // 16)\n        else:\n            prompt = _ensure_prompt_length(\n                tokenizer,\n                base_lines=[case.prompt],\n                target_tokens=context_tokens,\n                rng=rng,\n            )\n        lp_true_base = _logprob_answer(\n            model, tokenizer, prompt, case.answer, device, fast_state=fast_state\n        )\n        lp_false_base = _logprob_answer(\n            model, tokenizer, prompt, case.distractor, device, fast_state=fast_state\n        )\n        correct_base += int(lp_true_base > lp_false_base)\n        if memorize_cfg.enabled:\n            memorize_text = (\n                prompt if not memorize_cfg.use_correct_answer else f\"{prompt} {case.answer}\"\n            )\n            if memorize_cfg.use_fast_state:\n                if fast_state is None or memorize_cfg.reset:\n                    if not hasattr(model, \"init_fast_state\"):\n                        raise RuntimeError(\"Model does not support fast state memorization\")\n                    fast_state = model.init_fast_state()\n                stats = memorize_sequence(\n                    model, tokenizer, memorize_text, device, memorize_cfg, fast_state=fast_state\n                )\n                for k, v in stats.items():\n                    path_stats[k] = path_stats.get(k, 0.0) + v\n                lp_true_mem = _logprob_answer(\n                    model, tokenizer, prompt, case.answer, device, fast_state=fast_state\n                )\n                lp_false_mem = _logprob_answer(\n                    model, tokenizer, prompt, case.distractor, device, fast_state=fast_state\n                )\n                correct_mem += int(lp_true_mem > lp_false_mem)\n            else:\n                stats = memorize_sequence(model, tokenizer, memorize_text, device, memorize_cfg)\n                for k, v in stats.items():\n                    path_stats[k] = path_stats.get(k, 0.0) + v\n                lp_true_mem = _logprob_answer(model, tokenizer, prompt, case.answer, device)\n                lp_false_mem = _logprob_answer(model, tokenizer, prompt, case.distractor, device)\n                correct_mem += int(lp_true_mem > lp_false_mem)\n                if 
memorize_cfg.reset and base_state is not None:\n                    restore_state_dict(model, base_state)\n        else:\n            correct_mem += int(lp_true_base > lp_false_base)\n\n    base_acc = correct_base / samples if samples else 0.0\n    mem_acc = correct_mem / samples if samples else 0.0\n    payload: Dict[str, Any] = {\n        \"variant\": variant,\n        \"context_tokens\": context_tokens,\n        \"samples\": samples,\n        \"baseline_accuracy\": base_acc,\n        \"memorize_accuracy\": mem_acc,\n        \"memorize_delta\": mem_acc - base_acc,\n    }\n    if memorize_cfg.enabled:\n        payload[\"memorize_paths\"] = (\n            \"all\" if memorize_cfg.paths is None else \",\".join(memorize_cfg.paths)\n        )\n        payload[\"memorize_use_correct_answer\"] = bool(memorize_cfg.use_correct_answer)\n        if memorize_cfg.surprise_threshold is not None:\n            payload[\"memorize_surprise_threshold\"] = memorize_cfg.surprise_threshold\n        if path_stats:\n            payload[\"memorize_stats\"] = path_stats\n    return payload\n\n\n@app.command()\ndef main(\n    config: Path = typer.Option(..., help=\"Hydra config path.\"),\n    checkpoint: Path = typer.Option(..., help=\"Checkpoint to evaluate.\"),\n    tokenizer_path: Path = typer.Option(..., help=\"SentencePiece tokenizer path.\"),\n    context_tokens: List[int] = typer.Option(\n        [2048, 4096, 8192], help=\"Target prompt token lengths.\"\n    ),\n    samples_per_length: int = typer.Option(50, help=\"Samples per (variant, length).\"),\n    variants: List[str] = typer.Option(\n        [\n            \"single_needle\",\n            \"multi_needle\",\n            \"kv_single\",\n            \"kv_multi\",\n            \"needle_early\",\n            \"needle_mid\",\n            \"needle_late\",\n        ],\n        help=\"Variant names to run.\",\n    ),\n    seed: int = typer.Option(0, help=\"Random seed.\"),\n    device: str = typer.Option(\"cuda:0\" if torch.cuda.is_available() else \"cpu\"),\n    output: Path = typer.Option(Path(\"eval/niah_suite_results.json\")),\n    smoke: bool = typer.Option(False, help=\"Tiny settings for quick sanity checks.\"),\n    memorize: bool = typer.Option(False, help=\"Enable test-time memorization for each prompt.\"),\n    memorize_steps: int = typer.Option(1, help=\"Memorization passes per prompt.\"),\n    memorize_use_correct_answer: bool = typer.Option(\n        False, help=\"Append ground truth during memorization.\"\n    ),\n    memorize_no_reset: bool = typer.Option(False, help=\"Retain memory between samples.\"),\n    memorize_surprise_threshold: float = typer.Option(\n        None, help=\"Minimum teach-signal norm required to trigger memorization.\"\n    ),\n    memorize_paths: str = typer.Option(\n        \"all\",\n        help=(\n            \"Comma-separated memory paths to update (e.g., 'titan,cms_fast'); \"\n            \"use 'all' for no restriction.\"\n        ),\n    ),\n) -> None:\n    rng = random.Random(seed)\n    torch_device = resolve_device(device)\n    model = load_model(config, checkpoint, torch_device)\n    tokenizer = SentencePieceTokenizer(tokenizer_path)\n\n    if smoke:\n        context_tokens = [256]\n        samples_per_length = min(samples_per_length, 8)\n        variants = [\"single_needle\", \"kv_single\"]\n\n    if memorize_paths.lower() == \"all\":\n        allowed_paths = None\n    else:\n        allowed_paths = tuple(path.strip() for path in memorize_paths.split(\",\") if path.strip())\n    memorize_cfg = MemorizeConfig(\n    
    enabled=memorize,\n        steps=max(1, memorize_steps),\n        reset=not memorize_no_reset,\n        use_correct_answer=memorize_use_correct_answer,\n        surprise_threshold=memorize_surprise_threshold,\n        paths=allowed_paths,\n    )\n\n    results: List[Dict[str, Any]] = []\n    for variant in variants:\n        for length in context_tokens:\n            results.append(\n                _evaluate_variant(\n                    model,\n                    tokenizer,\n                    torch_device,\n                    variant=variant,\n                    context_tokens=length,\n                    samples=samples_per_length,\n                    rng=rng,\n                    memorize_cfg=memorize_cfg,\n                )\n            )\n\n    payload = {\n        \"seed\": seed,\n        \"device\": str(torch_device),\n        \"config\": str(config),\n        \"checkpoint\": str(checkpoint),\n        \"tokenizer_path\": str(tokenizer_path),\n        \"variants\": variants,\n        \"context_tokens\": context_tokens,\n        \"samples_per_length\": samples_per_length,\n        \"memorize\": {\n            \"enabled\": memorize_cfg.enabled,\n            \"steps\": memorize_cfg.steps,\n            \"reset\": memorize_cfg.reset,\n            \"use_correct_answer\": bool(memorize_cfg.use_correct_answer),\n            \"paths\": \"all\" if memorize_cfg.paths is None else \",\".join(memorize_cfg.paths),\n            \"surprise_threshold\": memorize_cfg.surprise_threshold,\n        },\n        \"results\": results,\n    }\n\n    output.parent.mkdir(parents=True, exist_ok=True)\n    output.write_text(json.dumps(payload, indent=2))\n    typer.echo(f\"[niah_suite] Saved results to {output}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "scripts/eval/passkey.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nimport random\nfrom pathlib import Path\n\nimport torch\nimport typer\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.device import resolve_device\nfrom nested_learning.memorize import (\n    MemorizeConfig,\n    memorize_sequence,\n    restore_state_dict,\n    snapshot_state_dict,\n)\nfrom nested_learning.tokenizer import SentencePieceTokenizer\nfrom nested_learning.training import build_model_from_cfg, unwrap_config\n\napp = typer.Typer(add_completion=False, help=\"Synthetic passkey evaluation (LongBench-style).\")\n\nPROMPT_TEMPLATE = (\n    \"{filler}\\nRemember that the passkey for this document is {key}. \"\n    \"Later we will ask about it.\\nQuestion: What is the passkey?\\nAnswer:\"\n)\n\n\ndef load_model(config: Path, checkpoint: Path, device: torch.device):\n    cfg = OmegaConf.load(config)\n    cfg = unwrap_config(cfg)\n    model = build_model_from_cfg(cfg.model)\n    state = torch.load(checkpoint, map_location=\"cpu\", weights_only=False)\n    state_dict = state[\"model\"] if \"model\" in state else state\n    missing, unexpected = model.load_state_dict(state_dict, strict=False)\n    if missing or unexpected:\n        print(f\"[passkey] Warning: mismatch missing={len(missing)} unexpected={len(unexpected)}\")\n    return model.to(device).eval()\n\n\ndef make_prompt(context_tokens: int, key: str) -> str:\n    sentences = [f\"This is filler sentence number {idx}.\" for idx in range(context_tokens)]\n    random.shuffle(sentences)\n    filler = \" \".join(sentences)\n    return PROMPT_TEMPLATE.format(filler=filler, key=key)\n\n\ndef logprob(\n    model, tokenizer, prompt: str, answer: str, device: torch.device, *, fast_state=None\n) -> float:\n    prompt_ids = tokenizer.encode(prompt, add_bos=True)\n    answer_ids = tokenizer.encode(\" \" + answer, add_bos=False, add_eos=True)\n    tokens = torch.cat([prompt_ids, answer_ids], dim=0).unsqueeze(0).to(device)\n    with torch.no_grad():\n        logits = model(tokens, fast_state=fast_state) if fast_state is not None else model(tokens)\n        log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)\n        targets = tokens[:, 1:]\n        gathered = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)\n        prompt_len = prompt_ids.numel()\n        return gathered[:, prompt_len - 1 :].sum().item()\n\n\n@app.command()\ndef main(\n    config: Path = typer.Option(..., help=\"Hydra model config.\"),\n    checkpoint: Path = typer.Option(..., help=\"Checkpoint path.\"),\n    tokenizer_path: Path = typer.Option(..., help=\"SentencePiece tokenizer.\"),\n    samples: int = typer.Option(64, help=\"Number of synthetic prompts.\"),\n    filler_sentences: int = typer.Option(200, help=\"Number of filler sentences (controls length).\"),\n    device: str = typer.Option(\"cuda:0\" if torch.cuda.is_available() else \"cpu\"),\n    output: Path = typer.Option(Path(\"eval/passkey_results.json\")),\n    memorize: bool = typer.Option(False, help=\"Enable memorization before answering.\"),\n    memorize_steps: int = typer.Option(1, help=\"Memorization iterations.\"),\n    memorize_no_reset: bool = typer.Option(False, help=\"Retain memory between prompts.\"),\n    memorize_surprise_threshold: float = typer.Option(\n        None, help=\"Minimum teach-signal norm required before memorizing a prompt.\"\n    ),\n    memorize_paths: str = typer.Option(\n        \"all\",\n        help=(\n            \"Comma-separated memory paths to update (e.g., 
'titan,cms_fast'); \"\n            \"use 'all' for unrestricted paths.\"\n        ),\n    ),\n    eval_state_mode: str = typer.Option(\n        \"reset_per_sample\",\n        help=\"Streaming eval state mode. Currently only 'reset_per_sample' is supported here.\",\n    ),\n) -> None:\n    if eval_state_mode.strip().lower() not in {\"reset\", \"isolated\", \"reset_per_sample\"}:\n        raise typer.BadParameter(\n            \"Passkey binary-choice scoring currently supports \"\n            \"eval_state_mode=reset_per_sample only.\"\n        )\n    torch_device = resolve_device(device)\n    model = load_model(config, checkpoint, torch_device)\n    tokenizer = SentencePieceTokenizer(tokenizer_path)\n    if memorize_paths.lower() == \"all\":\n        allowed_paths = None\n    else:\n        allowed_paths = tuple(path.strip() for path in memorize_paths.split(\",\") if path.strip())\n    memorize_cfg = MemorizeConfig(\n        enabled=memorize,\n        steps=max(1, memorize_steps),\n        reset=not memorize_no_reset,\n        use_correct_answer=True,\n        surprise_threshold=memorize_surprise_threshold,\n        paths=allowed_paths,\n    )\n    base_state = (\n        snapshot_state_dict(model)\n        if memorize_cfg.enabled and (not memorize_cfg.use_fast_state) and memorize_cfg.reset\n        else None\n    )\n    fast_state = None\n    correct_base = 0\n    correct_mem = 0\n    path_stats: dict[str, float] = {}\n    for _ in range(samples):\n        key = f\"PASSKEY-{random.randint(1000, 9999)}\"\n        prompt = make_prompt(filler_sentences, key)\n        distractor = f\"PASSKEY-{random.randint(1000, 9999)}\"\n        lp_true = logprob(model, tokenizer, prompt, key, torch_device)\n        lp_false = logprob(model, tokenizer, prompt, distractor, torch_device)\n        correct_base += int(lp_true > lp_false)\n        if memorize_cfg.enabled:\n            if memorize_cfg.use_fast_state:\n                if fast_state is None or memorize_cfg.reset:\n                    if not hasattr(model, \"init_fast_state\"):\n                        raise RuntimeError(\"Model does not support fast state memorization\")\n                    fast_state = model.init_fast_state()\n                stats = memorize_sequence(\n                    model, tokenizer, prompt, torch_device, memorize_cfg, fast_state=fast_state\n                )\n                for k, v in stats.items():\n                    path_stats[k] = path_stats.get(k, 0.0) + v\n                lp_true_mem = logprob(\n                    model, tokenizer, prompt, key, torch_device, fast_state=fast_state\n                )\n                lp_false_mem = logprob(\n                    model, tokenizer, prompt, distractor, torch_device, fast_state=fast_state\n                )\n                correct_mem += int(lp_true_mem > lp_false_mem)\n            else:\n                stats = memorize_sequence(model, tokenizer, prompt, torch_device, memorize_cfg)\n                for k, v in stats.items():\n                    path_stats[k] = path_stats.get(k, 0.0) + v\n                lp_true_mem = logprob(model, tokenizer, prompt, key, torch_device)\n                lp_false_mem = logprob(model, tokenizer, prompt, distractor, torch_device)\n                correct_mem += int(lp_true_mem > lp_false_mem)\n                if memorize_cfg.reset and base_state is not None:\n                    restore_state_dict(model, base_state)\n        else:\n            correct_mem += int(lp_true > lp_false)\n    base_acc = correct_base / samples if samples else 
0.0\n    mem_acc = correct_mem / samples if samples else 0.0\n    result = {\n        \"samples\": samples,\n        \"filler_sentences\": filler_sentences,\n        \"accuracy_base\": base_acc,\n        \"accuracy_memorize\": mem_acc,\n        \"accuracy_delta\": mem_acc - base_acc,\n        \"path_stats\": path_stats,\n    }\n    if memorize_cfg.enabled:\n        result[\"memorize_paths\"] = (\n            \"all\" if memorize_cfg.paths is None else \",\".join(memorize_cfg.paths)\n        )\n        if memorize_cfg.surprise_threshold is not None:\n            result[\"memorize_surprise_threshold\"] = memorize_cfg.surprise_threshold\n    output.parent.mkdir(parents=True, exist_ok=True)\n    output.write_text(json.dumps(result, indent=2))\n    typer.echo(f\"[passkey] Saved results to {output}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "scripts/eval/pg19_perplexity.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\nimport torch\nimport torch.nn.functional as F\nimport typer\nfrom datasets import load_dataset\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.device import resolve_device\nfrom nested_learning.memorize import (\n    MemorizeConfig,\n    memorize_sequence,\n    restore_state_dict,\n    snapshot_state_dict,\n)\nfrom nested_learning.tokenizer import SentencePieceTokenizer\nfrom nested_learning.training import build_model_from_cfg, unwrap_config\n\napp = typer.Typer(add_completion=False, help=\"Compute PG-19 perplexity for a checkpoint.\")\n\n\ndef load_model(config: Path, checkpoint: Path, device: torch.device):\n    cfg = OmegaConf.load(config)\n    cfg = unwrap_config(cfg)\n    model = build_model_from_cfg(cfg.model)\n    state = torch.load(checkpoint, map_location=\"cpu\", weights_only=False)\n    state_dict = state[\"model\"] if \"model\" in state else state\n    missing, unexpected = model.load_state_dict(state_dict, strict=False)\n    if missing or unexpected:\n        print(f\"[pg19] Warning: mismatch missing={len(missing)} unexpected={len(unexpected)}\")\n    return model.to(device).eval()\n\n\ndef _nll_for_text(\n    model,\n    tokenizer,\n    text: str,\n    device: torch.device,\n    max_seq: int,\n    *,\n    fast_state=None,\n) -> tuple[float, int] | None:\n    tokens = tokenizer.encode(text, add_bos=True, add_eos=True)\n    if tokens.size(0) < 2:\n        return None\n    if tokens.size(0) > max_seq:\n        tokens = tokens[:max_seq]\n    tokens = tokens.to(device).unsqueeze(0)\n    with torch.no_grad():\n        logits = model(tokens, fast_state=fast_state) if fast_state is not None else model(tokens)\n        log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)\n        targets = tokens[:, 1:]\n        gathered = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)\n        return -gathered.sum().item(), targets.numel()\n\n\n@app.command()\ndef main(\n    config: Path = typer.Option(..., help=\"Hydra model config.\"),\n    checkpoint: Path = typer.Option(..., help=\"Checkpoint path.\"),\n    tokenizer_path: Path = typer.Option(..., help=\"SentencePiece model path.\"),\n    max_samples: int = typer.Option(64, help=\"Number of PG-19 samples.\"),\n    device: str = typer.Option(\"cuda:0\" if torch.cuda.is_available() else \"cpu\"),\n    output: Path = typer.Option(Path(\"eval/pg19_perplexity.json\")),\n    context_tokens: int = typer.Option(\n        2048, help=\"Truncate text to this many tokens before scoring.\"\n    ),\n    memorize: bool = typer.Option(False, help=\"Apply test-time memorization to each excerpt.\"),\n    memorize_steps: int = typer.Option(1, help=\"Memorization passes per excerpt.\"),\n    memorize_no_reset: bool = typer.Option(False, help=\"Retain memory between excerpts.\"),\n    memorize_paths: str = typer.Option(\n        \"all\",\n        help=\"Comma-separated memory paths to update during memorization (e.g., 'titan,cms_fast').\",\n    ),\n    memorize_surprise_threshold: float = typer.Option(\n        None, help=\"Minimum teach-signal norm required before memorizing an excerpt.\"\n    ),\n    eval_state_mode: str = typer.Option(\n        \"reset_per_sample\",\n        help=\"Streaming eval state mode. 
Currently only 'reset_per_sample' is supported here.\",\n    ),\n) -> None:\n    if eval_state_mode.strip().lower() not in {\"reset\", \"isolated\", \"reset_per_sample\"}:\n        raise typer.BadParameter(\n            \"PG-19 script currently supports eval_state_mode=reset_per_sample only.\"\n        )\n    torch_device = resolve_device(device)\n    model = load_model(config, checkpoint, torch_device)\n    tokenizer = SentencePieceTokenizer(tokenizer_path)\n    dataset = load_dataset(\"pg19\", split=\"test\", streaming=True, trust_remote_code=True).shuffle(\n        seed=42\n    )\n    total_tokens = 0\n    total_nll_base = 0.0\n    total_nll_mem = 0.0\n    if memorize_paths.lower() == \"all\":\n        allowed_paths = None\n    else:\n        allowed_paths = tuple(path.strip() for path in memorize_paths.split(\",\") if path.strip())\n    memorize_cfg = MemorizeConfig(\n        enabled=memorize,\n        steps=max(1, memorize_steps),\n        reset=not memorize_no_reset,\n        use_correct_answer=False,\n        surprise_threshold=memorize_surprise_threshold,\n        paths=allowed_paths,\n    )\n    base_state = (\n        snapshot_state_dict(model)\n        if memorize_cfg.enabled and (not memorize_cfg.use_fast_state) and memorize_cfg.reset\n        else None\n    )\n    fast_state = None\n    processed = 0\n    for idx, sample in enumerate(dataset):\n        if idx >= max_samples:\n            break\n        text = sample.get(\"text\") or sample.get(\"passage\")\n        if not text:\n            continue\n        nll_tot = _nll_for_text(model, tokenizer, text, torch_device, context_tokens)\n        if nll_tot is None:\n            continue\n        nll_base, tokens_seen = nll_tot\n        total_nll_base += nll_base\n        total_tokens += tokens_seen\n        if memorize_cfg.enabled:\n            if memorize_cfg.use_fast_state:\n                if fast_state is None or memorize_cfg.reset:\n                    if not hasattr(model, \"init_fast_state\"):\n                        raise RuntimeError(\"Model does not support fast state memorization\")\n                    fast_state = model.init_fast_state()\n                memorize_sequence(\n                    model, tokenizer, text[:1024], torch_device, memorize_cfg, fast_state=fast_state\n                )\n                nll_mem = _nll_for_text(\n                    model, tokenizer, text, torch_device, context_tokens, fast_state=fast_state\n                )\n                if nll_mem is not None:\n                    total_nll_mem += nll_mem[0]\n            else:\n                memorize_sequence(model, tokenizer, text[:1024], torch_device, memorize_cfg)\n                nll_mem = _nll_for_text(model, tokenizer, text, torch_device, context_tokens)\n                if nll_mem is not None:\n                    total_nll_mem += nll_mem[0]\n                if memorize_cfg.reset and base_state is not None:\n                    restore_state_dict(model, base_state)\n        else:\n            total_nll_mem += nll_base\n        processed += 1\n    ppl_base = float(torch.exp(torch.tensor(total_nll_base / max(1, total_tokens))))\n    ppl_mem = float(torch.exp(torch.tensor(total_nll_mem / max(1, total_tokens))))\n    payload = {\n        \"samples\": processed,\n        \"tokens\": total_tokens,\n        \"ppl_base\": ppl_base,\n        \"ppl_memorize\": ppl_mem,\n        \"ppl_delta\": ppl_base - ppl_mem,\n    }\n    if memorize_cfg.enabled:\n        payload[\"memorize_paths\"] = (\n            \"all\" if memorize_cfg.paths is None 
else \",\".join(memorize_cfg.paths)\n        )\n        payload[\"memorize_surprise_threshold\"] = memorize_cfg.surprise_threshold\n    output.parent.mkdir(parents=True, exist_ok=True)\n    output.write_text(json.dumps(payload, indent=2))\n    typer.echo(f\"[pg19] Saved perplexity to {output}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
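  {
    "path": "examples/sketches/pg19_perplexity_math.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch (hypothetical file, not wired into CI): the perplexity\nbookkeeping used by scripts/eval/pg19_perplexity.py. Perplexity is the\nexponential of the token-averaged negative log-likelihood, and the reported\ndelta is ppl_base - ppl_memorize. All numbers below are made up.\n\"\"\"\nfrom __future__ import annotations\n\nimport math\n\n# Pretend a handful of documents contributed these summed NLLs and tokens.\ntotal_nll_base = 1404.0  # summed per-token NLL without memorization\ntotal_nll_mem = 1320.0  # summed per-token NLL after memorization updates\ntotal_tokens = 600\n\nppl_base = math.exp(total_nll_base / max(1, total_tokens))\nppl_mem = math.exp(total_nll_mem / max(1, total_tokens))\n\n# A positive delta means memorization lowered perplexity on the stream.\nprint(f\"ppl_base={ppl_base:.3f} ppl_memorize={ppl_mem:.3f} delta={ppl_base - ppl_mem:.3f}\")\n"
  },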
  {
    "path": "scripts/eval/phase2_memorization_delta_smoke.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\nimport torch\nimport typer\n\nfrom nested_learning.device import resolve_device\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.memorize import MemorizeConfig, memorize_tokens\nfrom nested_learning.model import HOPEModel, ModelConfig\n\napp = typer.Typer(\n    add_completion=False,\n    help=(\n        \"CPU-friendly smoke: show HOPE-Attention adapts via CMS updates while Transformer does not.\"\n    ),\n)\n\n\ndef _build_model(*, variant: str, vocab_size: int, dim: int, layers: int, heads: int) -> HOPEModel:\n    titan = LevelSpec(name=\"titan\", update_period=1, optimizer_key=\"titan_opt\")\n    cms = (LevelSpec(name=\"cms_fast\", update_period=1, optimizer_key=\"cms_opt\"),)\n    cfg = ModelConfig(\n        vocab_size=vocab_size,\n        dim=dim,\n        num_layers=layers,\n        heads=heads,\n        titan_level=titan,\n        cms_levels=cms,\n        optimizers=None,\n        teach_scale=0.1,\n        block_variant=variant,\n    )\n    return HOPEModel(cfg).eval()\n\n\ndef _run_once(\n    *,\n    variant: str,\n    tokens: torch.Tensor,\n    seed: int,\n) -> dict:\n    torch.manual_seed(seed)\n    model = _build_model(\n        variant=variant,\n        vocab_size=int(tokens.max().item() + 1),\n        dim=16,\n        layers=1,\n        heads=4,\n    ).to(tokens.device)\n    fast_state = model.init_fast_state()\n    with torch.no_grad():\n        before = model(tokens, fast_state=fast_state).detach()\n    cfg = MemorizeConfig(enabled=True, steps=1, use_fast_state=True, paths=(\"cms_fast\",))\n    stats = memorize_tokens(model, tokens, cfg, fast_state=fast_state)\n    with torch.no_grad():\n        after = model(tokens, fast_state=fast_state).detach()\n    return {\n        \"delta_mean_abs\": float((after - before).abs().mean().item()),\n        \"outputs_identical\": bool(torch.allclose(before, after, atol=0.0, rtol=0.0)),\n        \"cms_fast_update_events\": float(stats.get(\"cms_fast_update_events\", 0.0)),\n        \"cms_fast_updates\": float(stats.get(\"cms_fast_updates\", 0.0)),\n        \"titan_update_events\": float(stats.get(\"titan_update_events\", 0.0)),\n    }\n\n\n@app.command()\ndef main(\n    seed: int = typer.Option(0, help=\"Torch RNG seed (affects weights).\"),\n    vocab_size: int = typer.Option(32, help=\"Synthetic vocab size.\"),\n    seq_len: int = typer.Option(16, help=\"Token sequence length.\"),\n    batch_size: int = typer.Option(1, help=\"Batch size.\"),\n    device: str = typer.Option(\"cpu\", help=\"cpu or cuda:<idx>.\"),\n    output: Path = typer.Option(\n        Path(\"eval/phase2_memorization_delta_smoke.json\"), help=\"Where to write results.\"\n    ),\n) -> None:\n    torch_device = resolve_device(device)\n    token_gen = torch.Generator(device=\"cpu\").manual_seed(1337)\n    tokens = torch.randint(0, vocab_size, (batch_size, seq_len), generator=token_gen).to(\n        torch_device\n    )\n    results = {\n        \"seed\": int(seed),\n        \"vocab_size\": int(vocab_size),\n        \"seq_len\": int(seq_len),\n        \"batch_size\": int(batch_size),\n        \"hope_attention\": _run_once(variant=\"hope_attention\", tokens=tokens, seed=seed),\n        \"transformer\": _run_once(variant=\"transformer\", tokens=tokens, seed=seed),\n    }\n    output.parent.mkdir(parents=True, exist_ok=True)\n    output.write_text(json.dumps(results, indent=2))\n    typer.echo(f\"[phase2] wrote {output}\")\n\n\nif __name__ == 
\"__main__\":\n    app()\n"
  },
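  {
    "path": "examples/sketches/read_phase2_smoke.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch (hypothetical file): reading the JSON emitted by\nscripts/eval/phase2_memorization_delta_smoke.py. The smoke's intent is that\nthe hope_attention variant changes its outputs after CMS fast-state updates\nwhile the plain transformer does not; this reader only surfaces the relevant\nfields and assumes the smoke ran with its default --output path.\n\"\"\"\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\npayload = json.loads(Path(\"eval/phase2_memorization_delta_smoke.json\").read_text())\nfor variant in (\"hope_attention\", \"transformer\"):\n    row = payload[variant]\n    print(\n        f\"{variant}: delta_mean_abs={row['delta_mean_abs']:.6f} \"\n        f\"outputs_identical={row['outputs_identical']} \"\n        f\"cms_fast_update_events={row['cms_fast_update_events']}\"\n    )\n"
  },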
  {
    "path": "scripts/eval/plot_continual_classification.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\nfrom typing import List\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport typer\n\napp = typer.Typer(\n    add_completion=False, help=\"Plot continual classification task matrix + forgetting bars.\"\n)\n\n\n@app.command()\ndef main(\n    continual_json: Path = typer.Option(\n        ..., help=\"Output JSON from scripts/eval/continual_classification.py\"\n    ),\n    output: Path = typer.Option(Path(\"reports/plots/continual_classification.png\")),\n    title: str = typer.Option(\"Continual Classification\", help=\"Plot title\"),\n) -> None:\n    payload = json.loads(continual_json.read_text())\n    tasks = payload.get(\"tasks\", [])\n    matrix = payload.get(\"result\", {}).get(\"task_accuracy_matrix\", [])\n    forgetting = payload.get(\"result\", {}).get(\"per_task_forgetting\", [])\n\n    task_ids: List[str] = [str(t.get(\"task_id\", idx)) for idx, t in enumerate(tasks)]\n    data = np.array(matrix, dtype=np.float32)\n    mask = np.isnan(data)\n    masked = np.ma.array(data, mask=mask)\n\n    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 4), gridspec_kw={\"width_ratios\": [2, 1]})\n    im = ax0.imshow(masked, vmin=0.0, vmax=1.0, cmap=\"viridis\")\n    ax0.set_title(f\"{title} – Task Accuracy Matrix\")\n    ax0.set_xlabel(\"After Task\")\n    ax0.set_ylabel(\"Eval Task\")\n    ax0.set_xticks(range(len(task_ids)))\n    ax0.set_yticks(range(len(task_ids)))\n    ax0.set_xticklabels(task_ids, rotation=90)\n    ax0.set_yticklabels(task_ids)\n    fig.colorbar(im, ax=ax0, fraction=0.046, pad=0.04, label=\"Accuracy\")\n\n    f = (\n        np.array(forgetting, dtype=np.float32)\n        if forgetting\n        else np.zeros((len(task_ids),), dtype=np.float32)\n    )\n    ax1.bar(range(len(task_ids)), f)\n    ax1.set_title(\"Forgetting per Task\")\n    ax1.set_xlabel(\"Task\")\n    ax1.set_ylabel(\"Max - Final Acc\")\n    ax1.set_xticks(range(len(task_ids)))\n    ax1.set_xticklabels(task_ids, rotation=90)\n\n    fig.tight_layout()\n    output.parent.mkdir(parents=True, exist_ok=True)\n    fig.savefig(output, dpi=160)\n    plt.close(fig)\n    typer.echo(f\"[plot] Wrote {output}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
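  {
    "path": "examples/sketches/forgetting_from_matrix.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch (hypothetical file): the 'Max - Final Acc' forgetting\nbars drawn by scripts/eval/plot_continual_classification.py, assuming\nmatrix[i][j] holds accuracy on eval task i after training task j, with NaN\nbefore task i has been seen. The accuracy values are made up.\n\"\"\"\nfrom __future__ import annotations\n\nimport numpy as np\n\nmatrix = np.array(\n    [\n        [0.90, 0.80, 0.70],  # task 0 peaks at 0.90, ends at 0.70\n        [np.nan, 0.85, 0.83],  # task 1 is trained second\n        [np.nan, np.nan, 0.88],  # task 2 is trained last\n    ],\n    dtype=np.float32,\n)\n\n# Forgetting per task: best accuracy ever observed minus final accuracy.\nforgetting = np.nanmax(matrix, axis=1) - matrix[:, -1]\nprint(forgetting)  # ~[0.2, 0.02, 0.0] up to float32 precision\n"
  },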
  {
    "path": "scripts/eval/plot_forgetting.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\n\nimport matplotlib.pyplot as plt\nimport typer\n\napp = typer.Typer(add_completion=False, help=\"Plot continual-learning forgetting curves.\")\n\n\n@app.command()\ndef main(\n    continual_json: Path = typer.Option(..., help=\"Path to eval/continual_*.json output.\"),\n    output: Path = typer.Option(Path(\"reports/plots/continual_forgetting.png\")),\n    segment: str = typer.Option(None, help=\"Specific segment to plot (default: all).\"),\n) -> None:\n    data = json.loads(continual_json.read_text())\n    checkpoints = []\n    baseline = []\n    memorize = []\n    for entry in data:\n        checkpoints.append(entry.get(\"checkpoint\"))\n        seg_losses = entry.get(\"segment_losses\", {})\n        base_losses = entry.get(\"segment_baseline_losses\", seg_losses)\n        key = segment or next(iter(seg_losses))\n        baseline.append(base_losses.get(key))\n        memorize.append(seg_losses.get(key))\n    plt.figure(figsize=(8, 4))\n    plt.plot(checkpoints, baseline, label=\"baseline CE\", marker=\"o\")\n    plt.plot(checkpoints, memorize, label=\"memorize CE\", marker=\"o\")\n    plt.xticks(rotation=45, ha=\"right\")\n    plt.ylabel(\"Cross-entropy\")\n    plt.title(f\"Continual forgetting ({segment or 'default segment'})\")\n    plt.legend()\n    output.parent.mkdir(parents=True, exist_ok=True)\n    plt.tight_layout()\n    plt.savefig(output)\n    typer.echo(f\"[plot] Saved plot to {output}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
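  {
    "path": "examples/sketches/continual_json_shape.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch (hypothetical file): the JSON shape consumed by\nscripts/eval/plot_forgetting.py -- a list of per-checkpoint entries carrying\nsegment_losses and, optionally, segment_baseline_losses. Checkpoint names\nand loss values below are placeholders.\n\"\"\"\nfrom __future__ import annotations\n\nentries = [\n    {\n        \"checkpoint\": \"step_000100.pt\",\n        \"segment_losses\": {\"refinedweb_2018\": 3.4},\n        \"segment_baseline_losses\": {\"refinedweb_2018\": 3.6},\n    },\n    {\n        \"checkpoint\": \"step_000200.pt\",\n        \"segment_losses\": {\"refinedweb_2018\": 3.1},\n        \"segment_baseline_losses\": {\"refinedweb_2018\": 3.5},\n    },\n]\n\nsegment = \"refinedweb_2018\"\nfor entry in entries:\n    # Baseline losses fall back to segment_losses, mirroring the plot script.\n    base = entry.get(\"segment_baseline_losses\", entry[\"segment_losses\"])[segment]\n    mem = entry[\"segment_losses\"][segment]\n    print(entry[\"checkpoint\"], f\"baseline={base}\", f\"memorize={mem}\")\n"
  },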
  {
    "path": "scripts/eval/plot_niah_suite.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom collections import defaultdict\nfrom pathlib import Path\nfrom typing import Dict, List, Tuple\n\nimport matplotlib.pyplot as plt\nimport typer\n\napp = typer.Typer(add_completion=False, help=\"Plot NIAH suite accuracy vs context length.\")\n\n\n@app.command()\ndef main(\n    niah_suite_json: Path = typer.Option(..., help=\"Output JSON from scripts/eval/niah_suite.py\"),\n    output: Path = typer.Option(Path(\"reports/plots/niah_suite.png\")),\n    title: str = typer.Option(\"NIAH Suite\", help=\"Plot title\"),\n) -> None:\n    payload = json.loads(niah_suite_json.read_text())\n    results = payload.get(\"results\", [])\n    grouped: Dict[str, List[Tuple[int, float, float]]] = defaultdict(list)\n    for row in results:\n        variant = str(row.get(\"variant\", \"unknown\"))\n        length = int(row.get(\"context_tokens\", 0))\n        base = float(row.get(\"baseline_accuracy\", 0.0))\n        mem = float(row.get(\"memorize_accuracy\", base))\n        grouped[variant].append((length, base, mem))\n\n    variants = sorted(grouped.keys())\n    ncols = 2\n    nrows = (len(variants) + ncols - 1) // ncols\n    fig, axes = plt.subplots(nrows, ncols, figsize=(12, max(3, 3 * nrows)), squeeze=False)\n    axes_flat = axes.flatten()\n    for ax, variant in zip(axes_flat, variants, strict=False):\n        series = sorted(grouped[variant], key=lambda t: t[0])\n        xs = [t[0] for t in series]\n        base = [t[1] for t in series]\n        mem = [t[2] for t in series]\n        ax.plot(xs, base, label=\"baseline\")\n        ax.plot(xs, mem, label=\"memorize\")\n        ax.set_title(variant)\n        ax.set_xlabel(\"context_tokens\")\n        ax.set_ylabel(\"accuracy\")\n        ax.set_ylim(0.0, 1.0)\n        ax.legend()\n\n    for ax in axes_flat[len(variants) :]:\n        ax.axis(\"off\")\n\n    fig.suptitle(title)\n    fig.tight_layout()\n    output.parent.mkdir(parents=True, exist_ok=True)\n    fig.savefig(output, dpi=160)\n    plt.close(fig)\n    typer.echo(f\"[plot] Wrote {output}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "scripts/eval/run_pilot_suite.sh",
    "content": "#!/usr/bin/env bash\n#\n# Convenience wrapper to run the Stage 2 evaluation suite (zero-shot, NIAH, continual)\n# on the pilot HOPE checkpoint and optional TITAN baseline.\n#\n# Environment variables (override as needed):\n#   HOPE_CONFIG          (default configs/pilot.yaml)\n#   HOPE_CHECKPOINT      (default artifacts/checkpoints/pilot/step_latest.pt)\n#   TITAN_CONFIG         (optional)\n#   TITAN_CHECKPOINT     (optional)\n#   TOKENIZER_PATH       (default artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model)\n#   DEVICE               (default cuda:1)\n#   MAX_SAMPLES          (default 256 for zero-shot)\n#   NIAH_CONTEXTS        (space-separated list, default \"2048 4096 8192 16384 32768 65536\")\n#   NIAH_SAMPLES         (default 8 per context)\n#   CONT_BATCH           (default 4)\n#   CONT_MAX_BATCHES     (default 20)\n\nset -euo pipefail\n\nHOPE_CONFIG=${HOPE_CONFIG:-configs/pilot.yaml}\nTOKENIZER_PATH=${TOKENIZER_PATH:-artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model}\nDEVICE=${DEVICE:-cuda:1}\nMAX_SAMPLES=${MAX_SAMPLES:-256}\nNIAH_CONTEXTS=${NIAH_CONTEXTS:-\"2048 4096 8192 16384 32768 65536\"}\nNIAH_SAMPLES=${NIAH_SAMPLES:-8}\nCONT_BATCH=${CONT_BATCH:-4}\nCONT_MAX_BATCHES=${CONT_MAX_BATCHES:-20}\nSEGMENTS_YAML=${SEGMENTS_YAML:-configs/data/continual_segments_sample.yaml}\nHOPE_CONT_CHECKPOINTS=${HOPE_CONT_CHECKPOINTS:-}\nPASSKEY_SAMPLES=${PASSKEY_SAMPLES:-64}\nPASSKEY_FILLER=${PASSKEY_FILLER:-256}\nPG19_SAMPLES=${PG19_SAMPLES:-32}\nCONT_PLOT_SEGMENT=${CONT_PLOT_SEGMENT:-refinedweb_2018}\nMEMORIZE_PATHS=${MEMORIZE_PATHS:-titan,cms_fast}\nHOPE_MEMORIZE_PATHS=${HOPE_MEMORIZE_PATHS:-${MEMORIZE_PATHS}}\nTITAN_MEMORIZE_PATHS=${TITAN_MEMORIZE_PATHS:-titan}\nMEMORIZE_SURPRISE_THRESHOLD=${MEMORIZE_SURPRISE_THRESHOLD:-0.02}\n\nresolve_checkpoint() {\n  local path=\"$1\"\n  if [[ -n \"${path}\" ]]; then\n    echo \"${path}\"\n    return\n  fi\n  local latest\n  latest=$(ls -1t artifacts/checkpoints/pilot/step_*.pt 2>/dev/null | head -n 1 || true)\n  if [[ -z \"${latest}\" ]]; then\n    echo \"\"\n  else\n    echo \"${latest}\"\n  fi\n}\n\nHOPE_CHECKPOINT=${HOPE_CHECKPOINT:-$(resolve_checkpoint \"\")}\nif [[ -z \"${HOPE_CHECKPOINT}\" ]]; then\n  echo \"[eval] No HOPE checkpoint supplied and none found under artifacts/checkpoints/pilot.\"\n  exit 1\nfi\nif [[ -z \"${HOPE_CONT_CHECKPOINTS}\" ]]; then\n  HOPE_CONT_CHECKPOINTS=\"${HOPE_CHECKPOINT}\"\nfi\n\nmkdir -p eval\nIFS=' ' read -r -a HOPE_CONT_LIST <<< \"${HOPE_CONT_CHECKPOINTS}\"\n\nrun_zero_shot() {\n  local config=$1\n  local ckpt=$2\n  local tag=$3\n  local memorize_paths=$4\n  UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy uv run python scripts/eval/zeroshot.py \\\n    --config \"${config}\" \\\n    --checkpoint \"${ckpt}\" \\\n    --tokenizer-path \"${TOKENIZER_PATH}\" \\\n    --tasks all \\\n    --max-samples \"${MAX_SAMPLES}\" \\\n    --device \"${DEVICE}\" \\\n    --output \"eval/zeroshot_${tag}.json\" \\\n    --memorize \\\n    --memorize-steps 2 \\\n    --memorize-use-correct-answer \\\n    --memorize-paths \"${memorize_paths}\" \\\n    --memorize-surprise-threshold \"${MEMORIZE_SURPRISE_THRESHOLD}\"\n}\n\nrun_niah() {\n  local config=$1\n  local ckpt=$2\n  local tag=$3\n  local memorize_paths=$4\n  local args=()\n  for ctx in ${NIAH_CONTEXTS}; do\n    args+=(--context-lengths \"${ctx}\")\n  done\n  UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy uv run python scripts/eval/niah.py \\\n    --config \"${config}\" \\\n    --checkpoint \"${ckpt}\" \\\n    --tokenizer-path \"${TOKENIZER_PATH}\" \\\n    
\"${args[@]}\" \\\n    --samples-per-length \"${NIAH_SAMPLES}\" \\\n    --device \"${DEVICE}\" \\\n    --output \"eval/niah_${tag}.json\" \\\n    --memorize \\\n    --memorize-steps 2 \\\n    --memorize-use-correct-answer \\\n    --memorize-paths \"${memorize_paths}\" \\\n    --memorize-surprise-threshold \"${MEMORIZE_SURPRISE_THRESHOLD}\"\n}\n\nrun_continual() {\n  local config=$1\n  local tag=$2\n  local memorize_paths=$3\n  shift 3\n  local ckpts=(\"$@\")\n  if [[ ${#ckpts[@]} -eq 0 ]]; then\n    echo \"[eval] No checkpoints provided for continual eval (${tag}); skipping.\"\n    return\n  fi\n  UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy uv run python scripts/eval/continual.py \\\n    --config \"${config}\" \\\n    --checkpoints \"${ckpts[@]}\" \\\n    --segments-yaml \"${SEGMENTS_YAML}\" \\\n    --tokenizer-path \"${TOKENIZER_PATH}\" \\\n    --batch-size \"${CONT_BATCH}\" \\\n    --max-batches \"${CONT_MAX_BATCHES}\" \\\n    --device \"${DEVICE}\" \\\n    --output \"eval/continual_${tag}.json\" \\\n    --memorize \\\n    --memorize-steps 1 \\\n    --memorize-paths \"${memorize_paths}\" \\\n    --memorize-surprise-threshold \"${MEMORIZE_SURPRISE_THRESHOLD}\"\n  if [[ ${#ckpts[@]} -gt 1 ]]; then\n    local plot_target=\"reports/plots/continual_${tag}_${CONT_PLOT_SEGMENT}.png\"\n    mkdir -p reports/plots\n    UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy uv run python scripts/eval/plot_forgetting.py \\\n      --continual-json \"eval/continual_${tag}.json\" \\\n      --segment \"${CONT_PLOT_SEGMENT}\" \\\n      --output \"${plot_target}\"\n    echo \"[eval] Forgetting plot saved to ${plot_target}\"\n  fi\n}\n\nrun_passkey() {\n  local config=$1\n  local ckpt=$2\n  local tag=$3\n  local memorize_paths=$4\n  UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy uv run python scripts/eval/passkey.py \\\n    --config \"${config}\" \\\n    --checkpoint \"${ckpt}\" \\\n    --tokenizer-path \"${TOKENIZER_PATH}\" \\\n    --samples \"${PASSKEY_SAMPLES}\" \\\n    --filler-sentences \"${PASSKEY_FILLER}\" \\\n    --device \"${DEVICE}\" \\\n    --output \"eval/passkey_${tag}.json\" \\\n    --memorize \\\n    --memorize-steps 2 \\\n    --memorize-paths \"${memorize_paths}\" \\\n    --memorize-surprise-threshold \"${MEMORIZE_SURPRISE_THRESHOLD}\"\n}\n\nrun_pg19() {\n  local config=$1\n  local ckpt=$2\n  local tag=$3\n  local memorize_paths=$4\n  UV_CACHE_DIR=/tmp/uv-cache UV_LINK_MODE=copy uv run python scripts/eval/pg19_perplexity.py \\\n    --config \"${config}\" \\\n    --checkpoint \"${ckpt}\" \\\n    --tokenizer-path \"${TOKENIZER_PATH}\" \\\n    --max-samples \"${PG19_SAMPLES}\" \\\n    --device \"${DEVICE}\" \\\n    --output \"eval/pg19_${tag}.json\" \\\n    --memorize \\\n    --memorize-paths \"${memorize_paths}\" \\\n    --memorize-surprise-threshold \"${MEMORIZE_SURPRISE_THRESHOLD}\"\n}\n\necho \"[eval] Running suite for HOPE (${HOPE_CHECKPOINT})\"\nrun_zero_shot \"${HOPE_CONFIG}\" \"${HOPE_CHECKPOINT}\" \"pilot\" \"${HOPE_MEMORIZE_PATHS}\"\nrun_niah \"${HOPE_CONFIG}\" \"${HOPE_CHECKPOINT}\" \"pilot\" \"${HOPE_MEMORIZE_PATHS}\"\nrun_continual \"${HOPE_CONFIG}\" \"pilot\" \"${HOPE_MEMORIZE_PATHS}\" \"${HOPE_CONT_LIST[@]}\"\nrun_passkey \"${HOPE_CONFIG}\" \"${HOPE_CHECKPOINT}\" \"pilot\" \"${HOPE_MEMORIZE_PATHS}\"\nrun_pg19 \"${HOPE_CONFIG}\" \"${HOPE_CHECKPOINT}\" \"pilot\" \"${HOPE_MEMORIZE_PATHS}\"\n\nif [[ -n \"${TITAN_CONFIG:-}\" && -n \"${TITAN_CHECKPOINT:-}\" ]]; then\n  echo \"[eval] Running suite for TITAN baseline (${TITAN_CHECKPOINT})\"\n  run_zero_shot \"${TITAN_CONFIG}\" 
\"${TITAN_CHECKPOINT}\" \"titan\" \"${TITAN_MEMORIZE_PATHS}\"\n  run_niah \"${TITAN_CONFIG}\" \"${TITAN_CHECKPOINT}\" \"titan\" \"${TITAN_MEMORIZE_PATHS}\"\n  IFS=' ' read -r -a TITAN_CONT_LIST <<< \"${TITAN_CHECKPOINTS:-$TITAN_CHECKPOINT}\"\n  run_continual \"${TITAN_CONFIG}\" \"titan\" \"${TITAN_MEMORIZE_PATHS}\" \"${TITAN_CONT_LIST[@]}\"\n  run_passkey \"${TITAN_CONFIG}\" \"${TITAN_CHECKPOINT}\" \"titan\" \"${TITAN_MEMORIZE_PATHS}\"\n  run_pg19 \"${TITAN_CONFIG}\" \"${TITAN_CHECKPOINT}\" \"titan\" \"${TITAN_MEMORIZE_PATHS}\"\nelse\n  echo \"[eval] TITAN baseline skipped (set TITAN_CONFIG and TITAN_CHECKPOINT to enable).\"\nfi\n\necho \"[eval] Pilot suite complete. Outputs saved under eval/.\"\n"
  },
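  {
    "path": "examples/sketches/launch_pilot_suite.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch (hypothetical file): launching\nscripts/eval/run_pilot_suite.sh with environment overrides from Python. The\nvariable names follow the contract documented at the top of that script; the\ncheckpoint path and device below are placeholders. Run from the repo root so\nthe script's relative paths resolve.\n\"\"\"\nfrom __future__ import annotations\n\nimport os\nimport subprocess\n\nenv = dict(os.environ)\nenv.update(\n    {\n        \"DEVICE\": \"cpu\",  # default is cuda:1\n        \"MAX_SAMPLES\": \"32\",  # shrink zero-shot for a quick pass\n        \"NIAH_CONTEXTS\": \"2048 4096\",  # space-separated context lengths\n        \"HOPE_CHECKPOINT\": \"artifacts/checkpoints/pilot/step_000100.pt\",  # placeholder\n    }\n)\nsubprocess.run([\"bash\", \"scripts/eval/run_pilot_suite.sh\"], env=env, check=True)\n"
  },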
  {
    "path": "scripts/eval/summarize_eval.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom pathlib import Path\nfrom typing import Any, Dict, Iterable, List, Tuple\n\nimport typer\n\napp = typer.Typer(add_completion=False, help=\"Summarize eval JSONs into a small markdown table.\")\n\n\ndef _flatten_numeric(obj: Any, *, prefix: str = \"\") -> Dict[str, float]:\n    out: Dict[str, float] = {}\n    if isinstance(obj, dict):\n        for k, v in obj.items():\n            key = f\"{prefix}.{k}\" if prefix else str(k)\n            out.update(_flatten_numeric(v, prefix=key))\n        return out\n    if isinstance(obj, list):\n        # Avoid exploding large lists; only summarize scalar numeric lists.\n        if obj and all(isinstance(v, (int, float)) for v in obj):\n            out[prefix] = float(sum(float(v) for v in obj) / len(obj))\n        return out\n    if isinstance(obj, (int, float)):\n        out[prefix] = float(obj)\n    return out\n\n\ndef _expand_keys(flat: Dict[str, float], keys: Iterable[str]) -> List[str]:\n    resolved: List[str] = []\n    for key in keys:\n        key = key.strip()\n        if not key:\n            continue\n        if key.endswith(\"*\"):\n            prefix = key[:-1]\n            matches = sorted(k for k in flat.keys() if k.startswith(prefix))\n            resolved.extend(matches)\n        else:\n            resolved.append(key)\n    # De-duplicate while preserving order.\n    seen = set()\n    ordered: List[str] = []\n    for k in resolved:\n        if k in seen:\n            continue\n        seen.add(k)\n        ordered.append(k)\n    return ordered\n\n\ndef _render_table(rows: List[Tuple[str, Dict[str, float]]], keys: List[str]) -> str:\n    header = [\"file\", *keys]\n    lines = [\"| \" + \" | \".join(header) + \" |\", \"| \" + \" | \".join([\"---\"] * len(header)) + \" |\"]\n    for name, flat in rows:\n        cells = [name]\n        for key in keys:\n            value = flat.get(key)\n            if value is None:\n                cells.append(\"\")\n            else:\n                cells.append(f\"{value:.6g}\")\n        lines.append(\"| \" + \" | \".join(cells) + \" |\")\n    return \"\\n\".join(lines) + \"\\n\"\n\n\n@app.command()\ndef main(\n    inputs: List[Path] = typer.Option(..., help=\"Eval JSON files to summarize.\"),\n    keys: List[str] = typer.Option(\n        [],\n        help=(\n            \"Dotted numeric keys to include (supports '*' suffix prefix expansion). 
\"\n            \"If omitted, uses a small default set.\"\n        ),\n    ),\n    output: Path = typer.Option(Path(\"eval/summary.md\"), help=\"Markdown output path.\"),\n) -> None:\n    rows: List[Tuple[str, Dict[str, float]]] = []\n    for path in inputs:\n        payload = json.loads(path.read_text())\n        flat = _flatten_numeric(payload)\n        rows.append((path.name, flat))\n\n    if not rows:\n        raise typer.BadParameter(\"No input files provided.\")\n\n    if not keys:\n        # Reasonable defaults across our eval scripts.\n        keys = [\n            \"accuracy\",\n            \"accuracy_base\",\n            \"accuracy_memorize\",\n            \"accuracy_delta\",\n            \"avg_accuracy_final\",\n            \"avg_forgetting\",\n        ]\n\n    expanded = _expand_keys(rows[0][1], keys)\n    for _name, flat in rows[1:]:\n        expanded = sorted(set(expanded) | set(_expand_keys(flat, keys)))\n\n    output.parent.mkdir(parents=True, exist_ok=True)\n    output.write_text(_render_table(rows, expanded))\n    typer.echo(f\"[summary] Wrote {output}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
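  {
    "path": "examples/sketches/summarize_keys_demo.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch (hypothetical file): the key model behind\nscripts/eval/summarize_eval.py. Nested JSON is flattened into dotted numeric\nkeys, scalar-numeric lists are averaged, and a requested key ending in '*'\nexpands to every flattened key sharing that prefix.\n\"\"\"\nfrom __future__ import annotations\n\npayload = {\n    \"piqa_accuracy\": 0.61,\n    \"niah\": {\"baseline\": [1.0, 0.5], \"memorize\": [1.0, 1.0]},\n}\n\n\ndef flatten(obj, prefix=\"\"):\n    out = {}\n    if isinstance(obj, dict):\n        for k, v in obj.items():\n            out.update(flatten(v, f\"{prefix}.{k}\" if prefix else str(k)))\n    elif isinstance(obj, list) and obj and all(isinstance(v, (int, float)) for v in obj):\n        out[prefix] = sum(map(float, obj)) / len(obj)  # average scalar lists\n    elif isinstance(obj, (int, float)):\n        out[prefix] = float(obj)\n    return out\n\n\nflat = flatten(payload)\nprint(flat)  # {'piqa_accuracy': 0.61, 'niah.baseline': 0.75, 'niah.memorize': 1.0}\nprint(sorted(k for k in flat if k.startswith(\"niah.\")))  # what 'niah.*' expands to\n"
  },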
  {
    "path": "scripts/eval/zeroshot.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import annotations\n\nimport json\nfrom collections import defaultdict\nfrom pathlib import Path\nfrom typing import Callable, Dict, Iterable, List, Tuple\n\nimport torch\nimport typer\nfrom datasets import load_dataset\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\n\nfrom nested_learning.device import resolve_device\nfrom nested_learning.memorize import (\n    MemorizeConfig,\n    memorize_sequence,\n    restore_state_dict,\n    snapshot_state_dict,\n)\nfrom nested_learning.tokenizer import SentencePieceTokenizer\nfrom nested_learning.training import build_model_from_cfg, unwrap_config\n\napp = typer.Typer(add_completion=False, help=\"Zero-shot evaluation harness for HOPE.\")\nHF_DATASET_KWARGS = {\"trust_remote_code\": True}\n\n\ndef load_model(config_path: Path, checkpoint: Path, device: torch.device):\n    cfg = OmegaConf.load(config_path)\n    cfg = unwrap_config(cfg)\n    model = build_model_from_cfg(cfg.model)\n    state = torch.load(checkpoint, map_location=\"cpu\", weights_only=False)\n    state_dict = state[\"model\"] if \"model\" in state else state\n    missing, unexpected = model.load_state_dict(state_dict, strict=False)\n    if missing or unexpected:\n        print(\n            \"[eval] Warning: state_dict mismatch \"\n            f\"(missing={len(missing)} unexpected={len(unexpected)}) – continuing.\"\n        )\n    return model.to(device).eval()\n\n\ndef score_text(\n    model, tokenizer: SentencePieceTokenizer, text: str, device: torch.device, *, fast_state=None\n) -> float:\n    tokens = tokenizer.encode(text)\n    tokens = tokens.to(device)\n    with torch.no_grad():\n        logits = (\n            model(tokens.unsqueeze(0), fast_state=fast_state)\n            if fast_state is not None\n            else model(tokens.unsqueeze(0))\n        )\n        log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)\n        target = tokens.unsqueeze(0)[:, 1:]\n        gathered = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)\n        return gathered.sum().item()\n\n\ndef evaluate_multiple_choice(\n    task_name: str,\n    dataset_iter: Iterable[dict],\n    build_texts_fn: Callable[[dict], Tuple[str, List[str], int]],\n    tokenizer: SentencePieceTokenizer,\n    model,\n    device: torch.device,\n    max_samples: int | None,\n    memorize_cfg: MemorizeConfig,\n) -> Dict[str, float]:\n    correct_mem = 0\n    correct_base = 0\n    total = 0\n    base_state: Dict[str, torch.Tensor] | None = None\n    fast_state = None\n    path_stats: Dict[str, float] = defaultdict(float)\n    for sample in tqdm(dataset_iter, desc=task_name.upper()):\n        prompt, texts, answer_idx = build_texts_fn(sample)\n        scores_base = [score_text(model, tokenizer, t, device) for t in texts]\n        pred_base = int(max(range(len(scores_base)), key=lambda i: scores_base[i]))\n        correct_base += int(pred_base == answer_idx)\n        if memorize_cfg.enabled:\n            memorize_text = prompt\n            if memorize_cfg.use_correct_answer:\n                memorize_text = f\"{prompt} {texts[answer_idx]}\".strip()\n            if memorize_cfg.use_fast_state:\n                if fast_state is None or memorize_cfg.reset:\n                    if not hasattr(model, \"init_fast_state\"):\n                        raise RuntimeError(\"Model does not support fast state memorization\")\n                    fast_state = model.init_fast_state()\n                stats = memorize_sequence(\n                    model, tokenizer, 
memorize_text, device, memorize_cfg, fast_state=fast_state\n                )\n                for key, value in stats.items():\n                    path_stats[key] += value\n                scores_eval = [\n                    score_text(model, tokenizer, t, device, fast_state=fast_state) for t in texts\n                ]\n                pred_eval = int(max(range(len(scores_eval)), key=lambda i: scores_eval[i]))\n                correct_mem += int(pred_eval == answer_idx)\n            else:\n                if memorize_cfg.reset and base_state is None:\n                    base_state = snapshot_state_dict(model)\n                stats = memorize_sequence(model, tokenizer, memorize_text, device, memorize_cfg)\n                for key, value in stats.items():\n                    path_stats[key] += value\n                scores_eval = [score_text(model, tokenizer, t, device) for t in texts]\n                pred_eval = int(max(range(len(scores_eval)), key=lambda i: scores_eval[i]))\n                correct_mem += int(pred_eval == answer_idx)\n        else:\n            correct_mem += int(pred_base == answer_idx)\n        total += 1\n        if (\n            memorize_cfg.enabled\n            and (not memorize_cfg.use_fast_state)\n            and memorize_cfg.reset\n            and base_state is not None\n        ):\n            restore_state_dict(model, base_state)\n        if max_samples and total >= max_samples:\n            break\n    accuracy = correct_mem / total if total else 0.0\n    result: Dict[str, float] = {f\"{task_name}_accuracy\": accuracy, f\"{task_name}_samples\": total}\n    if memorize_cfg.enabled:\n        baseline_acc = correct_base / total if total else 0.0\n        result[f\"{task_name}_baseline_accuracy\"] = baseline_acc\n        result[f\"{task_name}_memorize_accuracy\"] = accuracy\n        result[f\"{task_name}_memorize_delta\"] = accuracy - baseline_acc\n        if memorize_cfg.paths is None:\n            result[f\"{task_name}_memorize_paths\"] = \"all\"\n        else:\n            result[f\"{task_name}_memorize_paths\"] = \",\".join(memorize_cfg.paths)\n        if memorize_cfg.surprise_threshold is not None:\n            result[f\"{task_name}_memorize_surprise_threshold\"] = memorize_cfg.surprise_threshold\n        for key, value in path_stats.items():\n            result[f\"{task_name}_{key}\"] = value\n    return result\n\n\ndef build_piqa_texts(sample: dict) -> Tuple[str, List[str], int]:\n    prompt = sample[\"goal\"].strip()\n    options = [sample[\"sol1\"].strip(), sample[\"sol2\"].strip()]\n    texts = [f\"{prompt} {opt}\" for opt in options]\n    target = sample[\"label\"]\n    return prompt, texts, target\n\n\ndef eval_piqa(model, tokenizer, device, max_samples, memorize_cfg):\n    dataset = load_dataset(\"piqa\", split=\"validation\", **HF_DATASET_KWARGS)\n    return evaluate_multiple_choice(\n        \"piqa\", dataset, build_piqa_texts, tokenizer, model, device, max_samples, memorize_cfg\n    )\n\n\ndef build_hellaswag_texts(sample: dict) -> Tuple[str, List[str], int]:\n    prompt = f\"{sample['ctx_a'].strip()} {sample['ctx_b'].strip()}\".strip()\n    endings = [ending.strip() for ending in sample[\"endings\"]]\n    texts = [f\"{prompt} {ending}\" for ending in endings]\n    label = sample[\"label\"]\n    target = int(label) if not isinstance(label, int) else label\n    return prompt, texts, target\n\n\ndef eval_hellaswag(model, tokenizer, device, max_samples, memorize_cfg):\n    dataset = load_dataset(\"hellaswag\", split=\"validation\", 
**HF_DATASET_KWARGS)\n    return evaluate_multiple_choice(\n        \"hellaswag\",\n        dataset,\n        build_hellaswag_texts,\n        tokenizer,\n        model,\n        device,\n        max_samples,\n        memorize_cfg,\n    )\n\n\ndef build_winogrande_texts(sample: dict) -> Tuple[str, List[str], int]:\n    sentence = sample[\"sentence\"]\n    options = [sample[\"option1\"].strip(), sample[\"option2\"].strip()]\n    texts = [sentence.replace(\"_\", opt) for opt in options]\n    target = int(sample[\"answer\"]) - 1\n    return sentence, texts, target\n\n\ndef eval_winogrande(model, tokenizer, device, max_samples, memorize_cfg):\n    dataset = load_dataset(\"winogrande\", \"winogrande_xl\", split=\"validation\", **HF_DATASET_KWARGS)\n    return evaluate_multiple_choice(\n        \"winogrande\",\n        dataset,\n        build_winogrande_texts,\n        tokenizer,\n        model,\n        device,\n        max_samples,\n        memorize_cfg,\n    )\n\n\ndef build_arc_texts(sample: dict) -> Tuple[str, List[str], int]:\n    prompt = sample[\"question\"].strip()\n    choice_texts = sample[\"choices\"][\"text\"]\n    labels = sample[\"choices\"][\"label\"]\n    texts = [f\"{prompt} {choice.strip()}\" for choice in choice_texts]\n    target = labels.index(sample[\"answerKey\"])\n    return prompt, texts, target\n\n\ndef eval_arc(\n    model, tokenizer, device, max_samples, difficulty: str, memorize_cfg: MemorizeConfig\n) -> Dict[str, float]:\n    dataset = load_dataset(\"ai2_arc\", difficulty, split=\"validation\", **HF_DATASET_KWARGS)\n    return evaluate_multiple_choice(\n        f\"arc_{difficulty.lower()}\",\n        dataset,\n        build_arc_texts,\n        tokenizer,\n        model,\n        device,\n        max_samples,\n        memorize_cfg,\n    )\n\n\ndef build_boolq_texts(sample: dict) -> Tuple[str, List[str], int]:\n    prompt = f\"{sample['passage'].strip()}\\nQuestion: {sample['question'].strip()}\\nAnswer:\"\n    texts = [f\"{prompt} yes\", f\"{prompt} no\"]\n    target = 0 if sample[\"answer\"] else 1\n    return prompt, texts, target\n\n\ndef eval_boolq(model, tokenizer, device, max_samples, memorize_cfg):\n    dataset = load_dataset(\"boolq\", split=\"validation\", **HF_DATASET_KWARGS)\n    return evaluate_multiple_choice(\n        \"boolq\", dataset, build_boolq_texts, tokenizer, model, device, max_samples, memorize_cfg\n    )\n\n\ndef build_siqa_texts(sample: dict) -> Tuple[str, List[str], int]:\n    prompt = f\"Context: {sample['context'].strip()} Question: {sample['question'].strip()} Answer:\"\n    options = [sample[\"answerA\"].strip(), sample[\"answerB\"].strip(), sample[\"answerC\"].strip()]\n    texts = [f\"{prompt} {opt}\" for opt in options]\n    target = int(sample[\"label\"]) - 1\n    return prompt, texts, target\n\n\ndef eval_siqa(model, tokenizer, device, max_samples, memorize_cfg):\n    dataset = load_dataset(\"social_i_qa\", split=\"validation\", **HF_DATASET_KWARGS)\n    return evaluate_multiple_choice(\n        \"siqa\", dataset, build_siqa_texts, tokenizer, model, device, max_samples, memorize_cfg\n    )\n\n\ndef build_commonsenseqa_texts(sample: dict) -> Tuple[str, List[str], int]:\n    prompt = sample[\"question\"].strip()\n    choice_texts = sample[\"choices\"][\"text\"]\n    labels = sample[\"choices\"][\"label\"]\n    texts = [f\"{prompt} {choice.strip()}\" for choice in choice_texts]\n    target = labels.index(sample[\"answerKey\"])\n    return prompt, texts, target\n\n\ndef eval_commonsenseqa(model, tokenizer, device, max_samples, 
memorize_cfg):\n    dataset = load_dataset(\"commonsense_qa\", split=\"validation\", **HF_DATASET_KWARGS)\n    return evaluate_multiple_choice(\n        \"commonsenseqa\",\n        dataset,\n        build_commonsenseqa_texts,\n        tokenizer,\n        model,\n        device,\n        max_samples,\n        memorize_cfg,\n    )\n\n\ndef build_openbookqa_texts(sample: dict) -> Tuple[str, List[str], int]:\n    prompt = sample[\"question_stem\"].strip()\n    choice_texts = sample[\"choices\"][\"text\"]\n    labels = sample[\"choices\"][\"label\"]\n    texts = [f\"{prompt} {choice.strip()}\" for choice in choice_texts]\n    target = labels.index(sample[\"answerKey\"])\n    return prompt, texts, target\n\n\ndef eval_openbookqa(model, tokenizer, device, max_samples, memorize_cfg):\n    dataset = load_dataset(\"openbookqa\", \"main\", split=\"validation\", **HF_DATASET_KWARGS)\n    return evaluate_multiple_choice(\n        \"openbookqa\",\n        dataset,\n        build_openbookqa_texts,\n        tokenizer,\n        model,\n        device,\n        max_samples,\n        memorize_cfg,\n    )\n\n\nTASK_EVALUATORS = {\n    \"piqa\": eval_piqa,\n    \"hellaswag\": eval_hellaswag,\n    \"winogrande\": eval_winogrande,\n    \"arc_easy\": lambda model, tok, dev, n, mem: eval_arc(model, tok, dev, n, \"ARC-Easy\", mem),\n    \"arc_challenge\": lambda model, tok, dev, n, mem: eval_arc(\n        model, tok, dev, n, \"ARC-Challenge\", mem\n    ),\n    \"boolq\": eval_boolq,\n    \"siqa\": eval_siqa,\n    \"commonsenseqa\": eval_commonsenseqa,\n    \"openbookqa\": eval_openbookqa,\n}\n\n\n@app.command()\ndef main(\n    config: Path = typer.Option(..., help=\"Hydra model config path.\"),\n    checkpoint: Path = typer.Option(..., help=\"Checkpoint file (state dict).\"),\n    tokenizer_path: Path = typer.Option(..., help=\"SentencePiece model path.\"),\n    tasks: str = typer.Option(\"piqa\", help=\"Comma-separated list of tasks or 'all'.\"),\n    max_samples: int = typer.Option(500, help=\"Max samples per task (0 = entire split).\"),\n    output: Path = typer.Option(Path(\"eval/zeroshot_results.json\"), help=\"Output JSON file.\"),\n    device: str = typer.Option(\n        \"cuda:0\" if torch.cuda.is_available() else \"cpu\", help=\"Device to run eval on.\"\n    ),\n    list_tasks: bool = typer.Option(False, \"--list-tasks\", help=\"List available tasks and exit.\"),\n    memorize: bool = typer.Option(False, help=\"Enable test-time memorization updates.\"),\n    memorize_steps: int = typer.Option(1, help=\"Number of memorize passes per sample.\"),\n    memorize_use_correct_answer: bool = typer.Option(\n        False, help=\"When memorizing, include the correct answer text (for ablations).\"\n    ),\n    memorize_no_reset: bool = typer.Option(\n        False, help=\"If set, retain memorization across samples.\"\n    ),\n    memorize_surprise_threshold: float = typer.Option(\n        None, help=\"Minimum teach-signal norm required before applying memorization.\"\n    ),\n    memorize_paths: str = typer.Option(\n        \"all\",\n        help=(\n            \"Comma-separated memory paths to update (e.g., 'titan,cms_fast'); \"\n            \"use 'all' to allow every path.\"\n        ),\n    ),\n    eval_state_mode: str = typer.Option(\n        \"reset_per_sample\",\n        help=\"Streaming eval state mode. 
Currently only 'reset_per_sample' is supported here.\",\n    ),\n) -> None:\n    available = list(TASK_EVALUATORS.keys())\n    if eval_state_mode.strip().lower() not in {\"reset\", \"isolated\", \"reset_per_sample\"}:\n        raise typer.BadParameter(\n            \"zeroshot multiple-choice scoring only supports eval_state_mode=reset_per_sample \"\n            \"in this implementation.\"\n        )\n    if list_tasks:\n        typer.echo(\"Available tasks: \" + \", \".join(available))\n        raise typer.Exit(0)\n\n    selected_tasks = (\n        available if tasks.lower() == \"all\" else [t.strip().lower() for t in tasks.split(\",\")]\n    )\n    torch_device = resolve_device(device)\n    model = load_model(config, checkpoint, torch_device)\n    tokenizer = SentencePieceTokenizer(tokenizer_path)\n    if memorize_paths.lower() == \"all\":\n        allowed_paths = None\n    else:\n        allowed_paths = tuple(path.strip() for path in memorize_paths.split(\",\") if path.strip())\n    memorize_cfg = MemorizeConfig(\n        enabled=memorize,\n        steps=max(1, memorize_steps),\n        reset=not memorize_no_reset,\n        use_correct_answer=memorize_use_correct_answer,\n        surprise_threshold=memorize_surprise_threshold,\n        paths=allowed_paths,\n    )\n\n    results: Dict[str, float] = {}\n    for task in selected_tasks:\n        evaluator = TASK_EVALUATORS.get(task)\n        if evaluator is None:\n            raise ValueError(f\"Unsupported task '{task}'. Valid tasks: {available}\")\n        metrics = evaluator(\n            model,\n            tokenizer,\n            torch_device,\n            None if max_samples <= 0 else max_samples,\n            memorize_cfg,\n        )\n        results.update(metrics)\n\n    output.parent.mkdir(parents=True, exist_ok=True)\n    output.write_text(json.dumps(results, indent=2))\n    typer.echo(f\"[Eval] Saved metrics for tasks {selected_tasks} -> {output}\")\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
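  {
    "path": "examples/sketches/mc_scoring_demo.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch (hypothetical file): the summed log-likelihood scoring\nused by score_text() in scripts/eval/zeroshot.py. Each candidate completion\nis scored by the sum of next-token log-probs; the 'model' here is random\nlogits, so the ranking is meaningless, but the tensor bookkeeping matches.\n\"\"\"\nfrom __future__ import annotations\n\nimport torch\n\ntorch.manual_seed(0)\nvocab = 16\n\n\ndef score(tokens: torch.Tensor) -> float:\n    # Stand-in for model(tokens.unsqueeze(0)): random logits per position.\n    logits = torch.randn(1, tokens.numel(), vocab)\n    log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)\n    target = tokens.unsqueeze(0)[:, 1:]\n    return log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1).sum().item()\n\n\ncandidates = [torch.randint(0, vocab, (6,)) for _ in range(2)]\nscores = [score(t) for t in candidates]\npred = max(range(len(scores)), key=lambda i: scores[i])\nprint(f\"scores={scores} predicted option={pred}\")\n"
  },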
  {
    "path": "scripts/package_pilot_release.sh",
    "content": "#!/usr/bin/env bash\n#\n# Bundle the latest pilot checkpoint + metadata into artifacts/pilot_release/.\n# Usage:\n#   scripts/package_pilot_release.sh [hope_checkpoint_path] [titan_checkpoint_path]\n# If no path is provided, the newest file under artifacts/checkpoints/pilot is used.\n\nset -euo pipefail\n\nRELEASE_DIR=\"artifacts/pilot_release\"\nCHECKPOINT_DIR=\"artifacts/checkpoints/pilot\"\nCONFIG_PATH=\"configs/pilot.yaml\"\nLOG_PATTERNS=( \"logs/pilot_train*.log\" \"logs/pilot_train*.json\" \"logs/pilot_relaunch*.log\" \"logs/pilot_relaunch*.json\" )\nMETADATA_PATH=\"${RELEASE_DIR}/metadata.json\"\nMANIFEST_PATH=\"${RELEASE_DIR}/MANIFEST.txt\"\nEVAL_PATTERNS=( \"eval/*_pilot.json\" \"eval/*_titan.json\" )\nPLOT_PATTERNS=( \"reports/plots/continual_pilot_*.png\" \"reports/plots/continual_titan_*.png\" )\n\nmkdir -p \"${RELEASE_DIR}\"\n\ncopy_sidecars() {\n  local ckpt_path=\"$1\"\n  local dest_prefix=\"$2\"\n  local src_prefix=\"${ckpt_path%.pt}\"\n  local exts=(\"sha256\" \"meta.json\" \"yaml\")\n  for ext in \"${exts[@]}\"; do\n    local src=\"${src_prefix}.${ext}\"\n    if [[ -f \"${src}\" ]]; then\n      cp \"${src}\" \"${dest_prefix}.${ext}\"\n    fi\n  done\n}\n\ncopy_patterns() {\n  local dest_dir=\"$1\"\n  shift\n  mkdir -p \"${dest_dir}\"\n  shopt -s nullglob\n  for pattern in \"$@\"; do\n    for path in ${pattern}; do\n      cp \"${path}\" \"${dest_dir}/\"\n    done\n  done\n  shopt -u nullglob\n}\n\nif [[ $# -ge 1 ]]; then\n  HOPE_CHECKPOINT=\"$1\"\nelse\n  HOPE_CHECKPOINT=$(ls -1t ${CHECKPOINT_DIR}/step_*.pt 2>/dev/null | head -n 1 || true)\nfi\n\nTITAN_CHECKPOINT=\"${2:-}\"\n\nif [[ -z \"${HOPE_CHECKPOINT}\" ]]; then\n  echo \"[package] No checkpoint found. Pass the path explicitly or ensure ${CHECKPOINT_DIR}/step_*.pt exists.\"\n  exit 1\nfi\n\nHOPE_CHECKPOINT_BASENAME=$(basename \"${HOPE_CHECKPOINT}\")\nDEST_CKPT=\"${RELEASE_DIR}/checkpoint.pt\"\ncp \"${HOPE_CHECKPOINT}\" \"${DEST_CKPT}\"\ncopy_sidecars \"${HOPE_CHECKPOINT}\" \"${RELEASE_DIR}/checkpoint\"\n\n# Copy config snapshot\ncp \"${CONFIG_PATH}\" \"${RELEASE_DIR}/config.yaml\"\n\n# Copy relevant logs (if they exist)\nLOG_DEST=\"${RELEASE_DIR}/logs\"\nmkdir -p \"${LOG_DEST}\"\nshopt -s nullglob\nfor pattern in \"${LOG_PATTERNS[@]}\"; do\n  for log_path in ${pattern}; do\n    cp \"${log_path}\" \"${LOG_DEST}/\"\n  done\ndone\nshopt -u nullglob\n\n# Copy latest eval outputs / plots if present.\ncopy_patterns \"${RELEASE_DIR}\" \"${EVAL_PATTERNS[@]}\"\ncopy_patterns \"${RELEASE_DIR}/plots\" \"${PLOT_PATTERNS[@]}\"\n\nTITAN_RELEASE_BASENAME=\"\"\nif [[ -n \"${TITAN_CHECKPOINT}\" ]]; then\n  if [[ ! -f \"${TITAN_CHECKPOINT}\" ]]; then\n    echo \"[package] TITAN checkpoint not found: ${TITAN_CHECKPOINT}\"\n    exit 1\n  fi\n  TITAN_BASENAME=$(basename \"${TITAN_CHECKPOINT}\")\n  TITAN_RELEASE_BASENAME=\"titan_${TITAN_BASENAME}\"\n  cp \"${TITAN_CHECKPOINT}\" \"${RELEASE_DIR}/${TITAN_RELEASE_BASENAME}\"\n  copy_sidecars \"${TITAN_CHECKPOINT}\" \"${RELEASE_DIR}/titan_${TITAN_BASENAME%.pt}\"\nfi\n\nsummarize_train_flags() {\n  local ckpt_path=\"$1\"\n  local meta_path=\"${ckpt_path%.pt}.meta.json\"\n  if [[ ! 
-f \"${meta_path}\" ]]; then\n    echo \"n/a\"\n    return\n  fi\n  python - \"$meta_path\" <<'PY'\nimport json, pathlib, sys\nmeta = json.loads(pathlib.Path(sys.argv[1]).read_text())\nkeys = [\n    (\"algorithm_mode\", \"algorithm_mode\"),\n    (\"online_updates\", \"online_updates\"),\n    (\"online_boundary_targets\", \"online_boundary_targets\"),\n    (\"online_carry_attention_cache\", \"online_carry_attention_cache\"),\n    (\"use_fast_state\", \"use_fast_state\"),\n]\nparts = [f\"{label}={meta.get(key)!r}\" for key, label in keys]\nprint(\", \".join(parts))\nPY\n}\n\nHOPE_TRAIN_FLAGS=$(summarize_train_flags \"${HOPE_CHECKPOINT}\")\nTITAN_TRAIN_FLAGS=\"\"\nif [[ -n \"${TITAN_CHECKPOINT}\" ]]; then\n  TITAN_TRAIN_FLAGS=$(summarize_train_flags \"${TITAN_CHECKPOINT}\")\nfi\n\n# Update metadata stub with checkpoint information if present\nif [[ -f \"${METADATA_PATH}\" ]]; then\n  python - \"$HOPE_CHECKPOINT_BASENAME\" \"$TITAN_RELEASE_BASENAME\" \"$METADATA_PATH\" <<'PY' || true\nimport json, sys, pathlib\nckpt = sys.argv[1]\ntitan = sys.argv[2]\npath = pathlib.Path(sys.argv[3])\nmeta = json.loads(path.read_text())\nmeta[\"checkpoint_step\"] = ckpt\nif titan:\n    meta[\"titan_checkpoint_step\"] = titan\npath.write_text(json.dumps(meta, indent=2))\nPY\nfi\n\n# Emit manifest with quick reference info\n{\n  echo \"Pilot Release Manifest\"\n  echo \"======================\"\n  echo \"HOPE Checkpoint: ${HOPE_CHECKPOINT_BASENAME}\"\n  echo \"HOPE Train Flags: ${HOPE_TRAIN_FLAGS}\"\n  if [[ -n \"${TITAN_RELEASE_BASENAME}\" ]]; then\n    echo \"TITAN Checkpoint: ${TITAN_RELEASE_BASENAME}\"\n    echo \"TITAN Train Flags: ${TITAN_TRAIN_FLAGS}\"\n  fi\n  echo \"Config: ${CONFIG_PATH}\"\n  echo \"Logs copied from patterns: ${LOG_PATTERNS[*]}\"\n  echo \"Eval copied from patterns: ${EVAL_PATTERNS[*]}\"\n  echo \"Plots copied from patterns: ${PLOT_PATTERNS[*]}\"\n  date \"+Packaged at: %Y-%m-%d %H:%M:%S\"\n} > \"${MANIFEST_PATH}\"\n\necho \"[package] Release bundle updated:\"\necho \"  - ${DEST_CKPT}\"\necho \"  - ${RELEASE_DIR}/checkpoint.* (sidecars, when available)\"\nif [[ -n \"${TITAN_RELEASE_BASENAME}\" ]]; then\n  echo \"  - ${RELEASE_DIR}/${TITAN_RELEASE_BASENAME}\"\n  echo \"  - ${RELEASE_DIR}/titan_${TITAN_BASENAME%.pt}.* (sidecars, when available)\"\nfi\necho \"  - ${RELEASE_DIR}/config.yaml\"\necho \"  - ${LOG_DEST}/\"\necho \"  - ${RELEASE_DIR}/*_pilot.json and ${RELEASE_DIR}/*_titan.json (when available)\"\necho \"  - ${RELEASE_DIR}/plots/ (when available)\"\necho \"  - ${METADATA_PATH} (if present)\"\n"
  },
  {
    "path": "scripts/run_cpu_ddp_smoke.sh",
    "content": "#!/usr/bin/env bash\n\nset -euo pipefail\n\n# Force CPU execution so torchrun selects the gloo backend.\nexport CUDA_VISIBLE_DEVICES=\"\"\n\nuv run torchrun --standalone --nproc_per_node=2 train_dist.py --config-name pilot_smoke \"$@\"\n\n# Strict mechanism-auditing guardrails should fail fast under DDP.\nif uv run torchrun --standalone --nproc_per_node=2 train_dist.py \\\n  --config-name pilot_smoke \\\n  train.strict_streaming_contract=true \\\n  train.fail_if_paper_faithful_disabled=true \\\n  >/tmp/ddp_strict_expected_fail.log 2>&1; then\n  echo \"[cpu-ddp-smoke] expected strict mode failure did not occur\"\n  exit 1\nfi\n"
  },
  {
    "path": "scripts/run_e2e_smoke.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\nDEVICE=${DEVICE:-cpu}\nTRAIN_CONFIG=${TRAIN_CONFIG:-pilot_smoke}\nDEFAULT_MODEL_CONFIG=\"configs/${TRAIN_CONFIG}.yaml\"\nif [[ ! -f \"${DEFAULT_MODEL_CONFIG}\" ]]; then\n  DEFAULT_MODEL_CONFIG=\"configs/hope/pilot.yaml\"\nfi\nMODEL_CONFIG=${MODEL_CONFIG:-${DEFAULT_MODEL_CONFIG}}\nTOKENIZER_PATH=${TOKENIZER_PATH:-artifacts/tokenizer/refinedweb_mix/spm_32000_unigram.model}\nCHECKPOINT_DIR=${CHECKPOINT_DIR:-artifacts/checkpoints/${TRAIN_CONFIG}}\nLOG_PATH=${LOG_PATH:-logs/${TRAIN_CONFIG}_release.json}\nEVAL_OUTPUT=${EVAL_OUTPUT:-eval/zeroshot_smoke.json}\nTASKS=${TASKS:-piqa}\nMAX_SAMPLES=${MAX_SAMPLES:-32}\n\nstep() {\n  echo\n  echo \"[$(date +%H:%M:%S)] $1\"\n  echo \"------------------------------------------------------------\"\n}\n\nstep \"1/4: Syncing environment (uv sync --all-extras)\"\nuv sync --all-extras\n\nstep \"2/4: Preparing filtered sample data\"\nuv run bash scripts/data/run_sample.sh\n\nstep \"3/4: Running ${TRAIN_CONFIG} smoke training on device=${DEVICE}\"\nmkdir -p \"$(dirname \"${LOG_PATH}\")\"\nuv run python train.py \\\n  --config-name \"${TRAIN_CONFIG}\" \\\n  train.device=\"${DEVICE}\" \\\n  logging.enabled=true \\\n  logging.backend=json \\\n  logging.path=\"${LOG_PATH}\" \\\n  train.checkpoint.enable=true \\\n  train.checkpoint.dir=\"${CHECKPOINT_DIR}\" \\\n  train.checkpoint.save_interval=999999 \\\n  train.checkpoint.save_last=true\n\nif ! ls \"${CHECKPOINT_DIR}\"/step_*.pt >/dev/null 2>&1; then\n  echo \"No checkpoints found in ${CHECKPOINT_DIR}. Training may have failed.\"\n  exit 1\nfi\nLATEST_CKPT=$(ls -1 \"${CHECKPOINT_DIR}\"/step_*.pt | sort | tail -n 1)\necho \"[Info] Using checkpoint ${LATEST_CKPT}\"\n\nstep \"4/4: Running zero-shot eval (${TASKS})\"\nmkdir -p \"$(dirname \"${EVAL_OUTPUT}\")\"\nuv run python scripts/eval/zeroshot.py \\\n  --config \"${MODEL_CONFIG}\" \\\n  --checkpoint \"${LATEST_CKPT}\" \\\n  --tokenizer-path \"${TOKENIZER_PATH}\" \\\n  --tasks \"${TASKS}\" \\\n  --max-samples \"${MAX_SAMPLES}\" \\\n  --output \"${EVAL_OUTPUT}\" \\\n  --device \"${DEVICE}\"\n\necho\necho \"[Done] Logs -> ${LOG_PATH}\"\necho \"[Done] Checkpoint -> ${LATEST_CKPT}\"\necho \"[Done] Eval metrics -> ${EVAL_OUTPUT}\"\n"
  },
  {
    "path": "scripts/run_mechanism_audit_smoke.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\nLOG_PATH=\"${LOG_PATH:-logs/mechanism_audit_smoke.json}\"\nCADENCE_OUT=\"${CADENCE_OUT:-reports/cadence_mechanism_audit_smoke.json}\"\nCOMPLIANCE_OUT=\"${COMPLIANCE_OUT:-reports/compliance_mechanism_audit_smoke.json}\"\n\nexport UV_LINK_MODE=\"${UV_LINK_MODE:-copy}\"\nexport UV_CACHE_DIR=\"${UV_CACHE_DIR:-/tmp/uv-cache}\"\n\nuv run python train.py --config-name pilot_paper_faithful \\\n  train.steps=1 \\\n  train.device=cpu \\\n  train.online_chunk_size=9 \\\n  model.dim=128 \\\n  model.num_layers=2 \\\n  model.heads=4 \\\n  data.source=synthetic \\\n  +data.vocab_size=32000 \\\n  data.seq_len=9 \\\n  +data.dataset_size=8 \\\n  data.batch_size=1 \\\n  data.num_workers=0 \\\n  train.mixed_precision.enabled=false \\\n  train.compile.enable=false \\\n  logging.enabled=true \\\n  logging.backend=json \\\n  logging.path=\"${LOG_PATH}\"\n\nuv run python scripts/checks/verify_update_cadence.py \\\n  --log-path \"${LOG_PATH}\" \\\n  --metric-prefix layer0.cms.cms_mid \\\n  --total-tokens 8 \\\n  --update-period 4 \\\n  --output \"${CADENCE_OUT}\"\n\nuv run python scripts/checks/compliance_report.py \\\n  --config configs/pilot.yaml \\\n  --cadence-report \"${CADENCE_OUT}\" \\\n  --output \"${COMPLIANCE_OUT}\"\n\necho \"[mechanism-audit-smoke] completed: ${LOG_PATH} + ${CADENCE_OUT} + ${COMPLIANCE_OUT}\"\n"
  },
  {
    "path": "scripts/run_smoke.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\nMODE=${1:-pilot}\n\nif [[ \"${MODE}\" == \"pilot\" ]]; then\n  echo \"[Smoke] Running pilot config on CPU\"\n  uv run python train.py --config-name pilot_smoke\nelif [[ \"${MODE}\" == \"mid\" ]]; then\n  echo \"[Smoke] Ensuring filtered shards exist\"\n  if [[ ! -d \"data/shards/refinedweb_filtered\" ]]; then\n    echo \"Filtered shards missing. Generate them first via configs/data/refinedweb_mixture_filtered.yaml\"\n    exit 1\n  fi\n  echo \"[Smoke] Running mid mixture config on CPU\"\n  uv run python train.py --config-name mid_smoke\nelse\n  echo \"Usage: scripts/run_smoke.sh [pilot|mid]\"\n  exit 1\nfi\n"
  },
  {
    "path": "scripts/tests/run_passkey_smoke.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\nCHECKPOINT_DIR=\"artifacts/checkpoints/pilot_smoke\"\nCHECKPOINT_PATH=\"${CHECKPOINT_DIR}/step_000010.pt\"\nTOKENIZER=\"tests/data/tiny_tokenizer.model\"\nOUTPUT_JSON=\"eval/passkey_ci.json\"\n\nrm -rf \"${CHECKPOINT_DIR}\"\n\necho \"[passkey-ci] training pilot_smoke for 10 steps\"\nuv run python train.py --config-name pilot_smoke\n\necho \"[passkey-ci] running synthetic passkey eval with memorization\"\nuv run python scripts/eval/passkey.py \\\n  --config configs/pilot_smoke.yaml \\\n  --checkpoint \"${CHECKPOINT_PATH}\" \\\n  --tokenizer-path \"${TOKENIZER}\" \\\n  --samples 8 \\\n  --filler-sentences 32 \\\n  --device cpu \\\n  --output \"${OUTPUT_JSON}\" \\\n  --memorize \\\n  --memorize-steps 1\n\nuv run python - <<'PY'\nimport json\nfrom pathlib import Path\n\ndata = json.loads(Path(\"eval/passkey_ci.json\").read_text())\ndelta = data.get(\"accuracy_delta\", 0.0)\nif delta < 0:\n    raise SystemExit(f\"Memorization delta negative: {delta}\")\nprint(f\"[passkey-ci] Memorization delta OK ({delta:.3f})\")\nPY\n"
  },
  {
    "path": "src/nested_learning/__init__.py",
    "content": "\"\"\"Nested Learning (HOPE) reproduction package.\"\"\"\n\nfrom importlib.metadata import PackageNotFoundError, version\n\nfrom .levels import LevelClock, LevelSpec  # noqa: F401\n\ntry:\n    __version__ = version(\"nested-learning\")\nexcept PackageNotFoundError:  # pragma: no cover - editable/local source tree\n    __version__ = \"0.2.0\"\n\n__all__ = [\"LevelClock\", \"LevelSpec\", \"__version__\"]\n"
  },
  {
    "path": "src/nested_learning/__main__.py",
    "content": "from __future__ import annotations\n\nfrom .cli import app\n\n\ndef main() -> None:\n    app()\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "src/nested_learning/assoc_memory.py",
    "content": "from __future__ import annotations\n\nfrom typing import Protocol\n\nimport torch\nimport torch.nn as nn\n\n\nclass AssocMemory(nn.Module):\n    \"\"\"Base class for associative memories with explicit update hooks.\"\"\"\n\n    def forward(self, query: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        raise NotImplementedError\n\n    @torch.no_grad()\n    def update(\n        self, *, key: torch.Tensor, value: torch.Tensor, error_signal: torch.Tensor | None = None\n    ) -> None:\n        raise NotImplementedError\n\n\nclass SupportsReset(Protocol):\n    def reset_state(self) -> None: ...\n"
  },
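  {
    "path": "examples/sketches/linear_delta_memory.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch (hypothetical file): a minimal concrete AssocMemory.\nThis is NOT the repo's production memory module; it is a plain linear map\ntrained with a delta-rule outer-product update, shown only to demonstrate\nthe forward/update contract in src/nested_learning/assoc_memory.py.\n\"\"\"\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn as nn\n\nfrom nested_learning.assoc_memory import AssocMemory\n\n\nclass LinearDeltaMemory(AssocMemory):\n    def __init__(self, dim: int, lr: float = 0.1):\n        super().__init__()\n        self.weight = nn.Parameter(torch.zeros(dim, dim), requires_grad=False)\n        self.lr = lr\n\n    def forward(self, query: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        # Retrieval is a plain linear readout.\n        return query @ self.weight.T\n\n    @torch.no_grad()\n    def update(self, *, key, value, error_signal=None) -> None:\n        # Delta rule: nudge the readout for `key` toward `value`.\n        err = error_signal if error_signal is not None else value - self.forward(key)\n        self.weight += self.lr * err.T @ key\n\n\ntorch.manual_seed(0)\nmem = LinearDeltaMemory(dim=4)\nkeys = torch.randn(3, 4)\nvalues = torch.randn(3, 4)\nfor _ in range(50):\n    mem.update(key=keys, value=values)\nprint((mem(keys) - values).abs().mean())  # shrinks toward 0 for a small lr\n"
  },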
  {
    "path": "src/nested_learning/backbones.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .fast_state import AttentionKVCache\n\n\n@dataclass\nclass AttentionConfig:\n    dim: int\n    heads: int\n    dropout: float = 0.0\n    use_flash: bool = True\n    causal: bool = True\n    qk_l2_norm: bool = False\n    qk_norm_eps: float = 1e-6\n    local_conv_window: int | None = None\n\n\nclass SelfAttention(nn.Module):\n    def __init__(self, config: AttentionConfig):\n        super().__init__()\n        if config.dim % config.heads != 0:\n            msg = f\"dim must be divisible by heads (got dim={config.dim}, heads={config.heads})\"\n            raise ValueError(msg)\n        self.config = config\n        self.heads = config.heads\n        self.head_dim = config.dim // config.heads\n        self.qkv = nn.Linear(config.dim, config.dim * 3, bias=False)\n        self.out_proj = nn.Linear(config.dim, config.dim, bias=False)\n        self.resid_dropout = nn.Dropout(config.dropout)\n        self.norm = nn.LayerNorm(config.dim)\n        self.local_conv: nn.Conv1d | None = None\n        if config.local_conv_window is not None:\n            window = int(config.local_conv_window)\n            if window <= 0:\n                raise ValueError(\"local_conv_window must be positive\")\n            self.local_conv = nn.Conv1d(\n                config.dim,\n                config.dim,\n                kernel_size=window,\n                groups=config.dim,\n                padding=0,\n                bias=False,\n            )\n\n    def forward(  # type: ignore[override]\n        self,\n        x: torch.Tensor,\n        *,\n        kv_cache: AttentionKVCache | None = None,\n        return_kv_cache: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache]:\n        residual = x\n        attn_inp = x\n        if kv_cache is not None and self.local_conv is not None:\n            raise RuntimeError(\n                \"kv_cache with local_conv_window is not supported in this implementation.\"\n            )\n        if self.local_conv is not None:\n            kernel = self.local_conv.kernel_size[0]\n            attn_inp = attn_inp.transpose(1, 2)\n            # Causal depthwise conv: only attends to past tokens.\n            attn_inp = F.pad(attn_inp, (kernel - 1, 0))\n            attn_inp = self.local_conv(attn_inp).transpose(1, 2)\n        q, k, v = self._compute_qkv(attn_inp)\n        past_len = 0\n        k_all = k\n        v_all = v\n        if kv_cache is not None:\n            if kv_cache.key.size(0) != k.size(0):\n                raise ValueError(\"kv_cache batch dimension must match input batch dimension\")\n            if kv_cache.key.size(1) != k.size(1) or kv_cache.key.size(-1) != k.size(-1):\n                raise ValueError(\"kv_cache shape is incompatible with attention heads/head_dim\")\n            past_len = int(kv_cache.key.size(2))\n            k_all = torch.cat([kv_cache.key, k], dim=2)\n            v_all = torch.cat([kv_cache.value, v], dim=2)\n        attn_output = self._scaled_dot_product_attn(q, k_all, v_all, past_len=past_len)\n        attn_output = attn_output.transpose(1, 2).contiguous().view(x.size(0), x.size(1), -1)\n        attn_output = self.out_proj(attn_output)\n        attn_output = self.resid_dropout(attn_output)\n        out = self.norm(residual + attn_output)\n        if return_kv_cache:\n            return out, AttentionKVCache(key=k_all.detach(), value=v_all.detach())\n        
return out\n\n    def _compute_qkv(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n        qkv = self.qkv(x)\n        q, k, v = qkv.chunk(3, dim=-1)\n        shape = (x.size(0), x.size(1), self.heads, self.head_dim)\n        q = q.view(*shape).transpose(1, 2)\n        k = k.view(*shape).transpose(1, 2)\n        v = v.view(*shape).transpose(1, 2)\n        if self.config.qk_l2_norm:\n            q = F.normalize(q, dim=-1, eps=self.config.qk_norm_eps)\n            k = F.normalize(k, dim=-1, eps=self.config.qk_norm_eps)\n        return q, k, v\n\n    def _scaled_dot_product_attn(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        v: torch.Tensor,\n        *,\n        past_len: int = 0,\n    ) -> torch.Tensor:\n        dropout_p = self.config.dropout if self.training else 0.0\n        attn_mask = None\n        if self.config.causal and past_len > 0:\n            query_len = int(q.size(-2))\n            key_len = int(k.size(-2))\n            key_positions = torch.arange(key_len, device=q.device)\n            query_positions = past_len + torch.arange(query_len, device=q.device)\n            attn_mask = key_positions.unsqueeze(0) <= query_positions.unsqueeze(1)\n        is_causal = self.config.causal and attn_mask is None\n        device_type = q.device.type\n        if (\n            device_type == \"cuda\"\n            and torch.cuda.is_available()\n            and hasattr(torch.backends, \"cuda\")\n            and hasattr(torch.backends.cuda, \"sdp_kernel\")\n        ):\n            with torch.backends.cuda.sdp_kernel(  # type: ignore[attr-defined]\n                enable_flash=self.config.use_flash,\n                enable_mem_efficient=True,\n                enable_math=not self.config.use_flash,\n            ):\n                return F.scaled_dot_product_attention(\n                    q,\n                    k,\n                    v,\n                    attn_mask=attn_mask,\n                    dropout_p=dropout_p,\n                    is_causal=is_causal,\n                )\n        return F.scaled_dot_product_attention(\n            q,\n            k,\n            v,\n            attn_mask=attn_mask,\n            dropout_p=dropout_p,\n            is_causal=is_causal,\n        )\n"
  },
  {
    "path": "src/nested_learning/capabilities.py",
    "content": "from __future__ import annotations\n\nimport platform\nimport sys\nfrom dataclasses import asdict, dataclass, field\nfrom typing import Any\n\nimport torch\n\n\n@dataclass\nclass RuntimeCapabilities:\n    python_version: str\n    platform: str\n    machine: str\n    torch_version: str\n    cuda_available: bool\n    cuda_device_count: int\n    cuda_devices: list[str] = field(default_factory=list)\n    mps_available: bool = False\n    mps_built: bool = False\n    distributed_available: bool = False\n    compile_available: bool = False\n    sdpa_flash_available: bool = False\n    sdpa_mem_efficient_available: bool = False\n    sdpa_math_available: bool = True\n    bf16_supported: bool = False\n    fp16_supported: bool = False\n    default_device: str = \"cpu\"\n    warnings: list[str] = field(default_factory=list)\n\n    def to_dict(self) -> dict[str, Any]:\n        return asdict(self)\n\n\ndef collect_runtime_capabilities() -> RuntimeCapabilities:\n    cuda_available = bool(torch.cuda.is_available())\n    cuda_device_count = int(torch.cuda.device_count() if cuda_available else 0)\n    cuda_devices: list[str] = []\n    warnings: list[str] = []\n\n    if cuda_available:\n        for idx in range(cuda_device_count):\n            try:\n                name = torch.cuda.get_device_name(idx)\n                cuda_devices.append(f\"cuda:{idx} {name}\")\n            except Exception as err:  # pragma: no cover - backend specific\n                warnings.append(f\"failed to query cuda:{idx}: {err}\")\n\n    mps_backend = getattr(torch.backends, \"mps\", None)\n    mps_available = bool(mps_backend and mps_backend.is_available())\n    mps_built = bool(mps_backend and mps_backend.is_built())\n\n    distributed_available = bool(torch.distributed.is_available())\n    compile_available = bool(hasattr(torch, \"compile\"))\n\n    flash_enabled = False\n    mem_eff_enabled = False\n    math_enabled = True\n    if hasattr(torch.backends, \"cuda\") and torch.backends.cuda.is_built():\n        try:\n            flash_enabled = bool(torch.backends.cuda.flash_sdp_enabled())\n            mem_eff_enabled = bool(torch.backends.cuda.mem_efficient_sdp_enabled())\n            math_enabled = bool(torch.backends.cuda.math_sdp_enabled())\n        except Exception as err:  # pragma: no cover - backend specific\n            warnings.append(f\"failed to query SDPA backend flags: {err}\")\n\n    bf16_supported = False\n    fp16_supported = False\n    if cuda_available:\n        try:\n            bf16_supported = bool(torch.cuda.is_bf16_supported())\n            fp16_supported = True\n        except Exception as err:  # pragma: no cover\n            warnings.append(f\"failed to query CUDA dtype support: {err}\")\n    elif mps_available:\n        fp16_supported = True\n\n    default_device = \"cpu\"\n    if cuda_available:\n        default_device = \"cuda:0\"\n    elif mps_available:\n        default_device = \"mps\"\n\n    return RuntimeCapabilities(\n        python_version=sys.version.split()[0],\n        platform=platform.platform(),\n        machine=platform.machine(),\n        torch_version=torch.__version__,\n        cuda_available=cuda_available,\n        cuda_device_count=cuda_device_count,\n        cuda_devices=cuda_devices,\n        mps_available=mps_available,\n        mps_built=mps_built,\n        distributed_available=distributed_available,\n        compile_available=compile_available,\n        sdpa_flash_available=flash_enabled,\n        sdpa_mem_efficient_available=mem_eff_enabled,\n        
sdpa_math_available=math_enabled,\n        bf16_supported=bf16_supported,\n        fp16_supported=fp16_supported,\n        default_device=default_device,\n        warnings=warnings,\n    )\n"
  },
  {
    "path": "src/nested_learning/cli.py",
    "content": "from __future__ import annotations\n\nimport json\nfrom pathlib import Path\nfrom typing import Annotated\n\nimport torch\nimport typer\nfrom omegaconf import OmegaConf\n\nfrom .capabilities import collect_runtime_capabilities\nfrom .config_utils import compose_config\nfrom .device import resolve_device\nfrom .training import build_model_from_cfg\n\napp = typer.Typer(\n    add_completion=False,\n    no_args_is_help=True,\n    help=\"Nested Learning CLI (training, diagnostics, and smoke checks).\",\n)\n\n\ndef _resolve_cli_device(device: str) -> torch.device:\n    lowered = device.strip().lower()\n    if lowered == \"auto\":\n        caps = collect_runtime_capabilities()\n        return resolve_device(caps.default_device)\n    return resolve_device(device)\n\n\n@app.command(\"doctor\")\ndef doctor(\n    as_json: Annotated[\n        bool,\n        typer.Option(\"--json\", help=\"Emit machine-readable JSON only.\"),\n    ] = False,\n    output: Annotated[\n        Path | None,\n        typer.Option(\n            \"--output\",\n            \"-o\",\n            help=\"Optional path for writing doctor output JSON.\",\n            dir_okay=False,\n            writable=True,\n        ),\n    ] = None,\n) -> None:\n    \"\"\"Inspect runtime capabilities for backend/device compatibility.\"\"\"\n    payload = collect_runtime_capabilities().to_dict()\n    rendered = json.dumps(payload, indent=2, sort_keys=True)\n    if output is not None:\n        output.parent.mkdir(parents=True, exist_ok=True)\n        output.write_text(rendered + \"\\n\", encoding=\"utf-8\")\n    if as_json:\n        typer.echo(rendered)\n        return\n\n    typer.echo(\"Runtime Doctor\")\n    typer.echo(f\"python: {payload['python_version']}\")\n    typer.echo(f\"platform: {payload['platform']} ({payload['machine']})\")\n    typer.echo(f\"torch: {payload['torch_version']}\")\n    typer.echo(f\"default_device: {payload['default_device']}\")\n    typer.echo(\n        \"cuda_available: {available} ({count} device(s))\".format(\n            available=payload[\"cuda_available\"],\n            count=payload[\"cuda_device_count\"],\n        )\n    )\n    for name in payload[\"cuda_devices\"]:\n        typer.echo(f\"  - {name}\")\n    typer.echo(f\"mps_available: {payload['mps_available']} (built={payload['mps_built']})\")\n    typer.echo(f\"distributed_available: {payload['distributed_available']}\")\n    typer.echo(f\"compile_available: {payload['compile_available']}\")\n    typer.echo(\n        \"sdpa backends: flash={flash} mem_efficient={mem} math={math}\".format(\n            flash=payload[\"sdpa_flash_available\"],\n            mem=payload[\"sdpa_mem_efficient_available\"],\n            math=payload[\"sdpa_math_available\"],\n        )\n    )\n    typer.echo(f\"dtype support: bf16={payload['bf16_supported']} fp16={payload['fp16_supported']}\")\n    if payload[\"warnings\"]:\n        typer.echo(\"warnings:\")\n        for warning in payload[\"warnings\"]:\n            typer.echo(f\"  - {warning}\")\n\n\n@app.command(\"smoke\")\ndef smoke(\n    config_name: Annotated[\n        str,\n        typer.Option(\"--config-name\", \"-c\", help=\"Hydra config name (e.g. 
pilot, hope/mid).\"),\n    ] = \"pilot_smoke\",\n    override: Annotated[\n        list[str] | None,\n        typer.Option(\n            \"--override\",\n            \"-O\",\n            help=\"Hydra override(s), may be passed multiple times.\",\n        ),\n    ] = None,\n    config_dir: Annotated[\n        Path | None,\n        typer.Option(\n            \"--config-dir\",\n            help=\"Optional explicit config directory.\",\n            exists=True,\n            file_okay=False,\n            dir_okay=True,\n            readable=True,\n        ),\n    ] = None,\n    device: Annotated[\n        str,\n        typer.Option(\n            \"--device\",\n            help=\"Device string for smoke pass (cpu, cuda:0, mps, auto).\",\n        ),\n    ] = \"cpu\",\n    batch_size: Annotated[\n        int,\n        typer.Option(\"--batch-size\", help=\"Synthetic smoke batch size.\"),\n    ] = 1,\n    seq_len: Annotated[\n        int,\n        typer.Option(\"--seq-len\", help=\"Synthetic smoke sequence length.\"),\n    ] = 32,\n) -> None:\n    \"\"\"Run a lightweight forward-pass smoke test with composed config.\"\"\"\n    cfg = compose_config(config_name, overrides=override or [], config_dir=config_dir)\n    model_cfg = cfg.model\n    torch_device = _resolve_cli_device(device)\n    model = build_model_from_cfg(model_cfg).to(torch_device)\n    model.eval()\n\n    vocab_size = int(model_cfg.vocab_size)\n    tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch_device)\n    with torch.no_grad():\n        outputs = model(tokens)\n    if isinstance(outputs, tuple):\n        logits = outputs[0]\n    else:\n        logits = outputs\n    typer.echo(\n        json.dumps(\n            {\n                \"status\": \"ok\",\n                \"config_name\": config_name,\n                \"device\": str(torch_device),\n                \"batch_size\": batch_size,\n                \"seq_len\": seq_len,\n                \"logits_shape\": list(logits.shape),\n                \"dtype\": str(logits.dtype),\n            },\n            sort_keys=True,\n        )\n    )\n\n\n@app.command(\"train\")\ndef train(\n    config_name: Annotated[\n        str,\n        typer.Option(\"--config-name\", \"-c\", help=\"Hydra config name for training.\"),\n    ] = \"pilot\",\n    override: Annotated[\n        list[str] | None,\n        typer.Option(\"--override\", \"-O\", help=\"Hydra override(s), may be passed multiple times.\"),\n    ] = None,\n    config_dir: Annotated[\n        Path | None,\n        typer.Option(\n            \"--config-dir\",\n            help=\"Optional explicit config directory.\",\n            exists=True,\n            file_okay=False,\n            dir_okay=True,\n            readable=True,\n        ),\n    ] = None,\n    device: Annotated[\n        str | None,\n        typer.Option(\n            \"--device\",\n            help=\"Override cfg.train.device (e.g. 
cpu, cuda:1, auto).\",\n        ),\n    ] = None,\n    dry_run: Annotated[\n        bool,\n        typer.Option(\"--dry-run\", help=\"Print resolved config and exit.\"),\n    ] = False,\n) -> None:\n    \"\"\"Launch a local (single-process) training loop.\"\"\"\n    from .training import run_training_loop\n\n    cfg = compose_config(config_name, overrides=override or [], config_dir=config_dir)\n    if device is not None:\n        cfg.train.device = device\n    if dry_run:\n        typer.echo(OmegaConf.to_yaml(cfg))\n        return\n    train_device = _resolve_cli_device(str(cfg.train.device))\n    run_training_loop(cfg, device=train_device, distributed=False, dist_ctx=None)\n\n\n@app.command(\"audit\")\ndef audit(\n    config_name: Annotated[\n        str,\n        typer.Option(\"--config-name\", \"-c\", help=\"Hydra config name to audit.\"),\n    ] = \"pilot_paper_faithful\",\n    override: Annotated[\n        list[str] | None,\n        typer.Option(\"--override\", \"-O\", help=\"Hydra override(s), may be passed multiple times.\"),\n    ] = None,\n    config_dir: Annotated[\n        Path | None,\n        typer.Option(\n            \"--config-dir\",\n            help=\"Optional explicit config directory.\",\n            exists=True,\n            file_okay=False,\n            dir_okay=True,\n            readable=True,\n        ),\n    ] = None,\n) -> None:\n    \"\"\"Run static architecture checks on a composed config.\"\"\"\n    cfg = compose_config(config_name, overrides=override or [], config_dir=config_dir)\n    model = build_model_from_cfg(cfg.model)\n    has_embed = hasattr(model, \"embed\")\n    has_lm_head = hasattr(model, \"lm_head\")\n    tied_weights = False\n    if has_embed and has_lm_head:\n        embed = getattr(model, \"embed\")\n        lm_head = getattr(model, \"lm_head\")\n        tied_weights = bool(embed.weight.data_ptr() == lm_head.weight.data_ptr())\n\n    report = {\n        \"status\": \"ok\",\n        \"config_name\": config_name,\n        \"model_type\": str(cfg.model.get(\"type\", \"hope\")),\n        \"block_variant\": str(cfg.model.get(\"block_variant\", \"hope_hybrid\")),\n        \"surprise_metric\": str(cfg.model.get(\"surprise_metric\", \"l2\")),\n        \"surprise_threshold\": cfg.model.get(\"surprise_threshold\"),\n        \"teach_scale\": float(cfg.model.get(\"teach_scale\", 1.0)),\n        \"teach_clip\": float(cfg.model.get(\"teach_clip\", 0.0)),\n        \"freeze_backbone\": bool(cfg.model.get(\"freeze_backbone\", False)),\n        \"has_embed\": has_embed,\n        \"has_lm_head\": has_lm_head,\n        \"lm_tied_to_embedding\": tied_weights,\n    }\n    typer.echo(json.dumps(report, sort_keys=True))\n"
  },
  {
    "path": "src/nested_learning/cms.py",
    "content": "from __future__ import annotations\n\nfrom typing import Dict, Sequence\n\nimport torch\nimport torch.nn as nn\n\nfrom .levels import LevelSpec, ensure_level_specs\n\n\nclass CMSBlock(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        hidden_multiplier: int = 4,\n        activation: str = \"gelu\",\n        grad_clip: float = 1.0,\n        use_layernorm: bool = True,\n    ):\n        super().__init__()\n        hidden = dim * hidden_multiplier\n        act: nn.Module\n        if activation == \"relu\":\n            act = nn.ReLU()\n        elif activation == \"silu\":\n            act = nn.SiLU()\n        else:\n            act = nn.GELU()\n        norm: nn.Module = nn.LayerNorm(dim) if use_layernorm else nn.Identity()\n        self.net = nn.Sequential(\n            norm,\n            nn.Linear(dim, hidden),\n            act,\n            nn.Linear(hidden, dim),\n        )\n        self.grad_clip = grad_clip\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        delta = self.net(x)\n        if self.training and self.grad_clip > 0:\n            with torch.no_grad():\n                norm = delta.norm(dim=-1, keepdim=True)\n                scale = torch.clamp(norm / self.grad_clip, min=1.0)\n            delta = delta / scale\n        return x + delta\n\n\nclass CMS(nn.Module):\n    \"\"\"Continuum Memory System with multi-frequency updates.\"\"\"\n\n    def __init__(\n        self,\n        *,\n        dim: int,\n        levels: Sequence[LevelSpec],\n        hidden_multiplier: int = 4,\n        activation: str = \"gelu\",\n        use_layernorm: bool = True,\n    ) -> None:\n        super().__init__()\n        ordered = ensure_level_specs(levels)\n        self.level_specs: Sequence[LevelSpec] = tuple(ordered)\n        self.blocks = nn.ModuleDict(\n            {\n                spec.name: CMSBlock(\n                    dim,\n                    hidden_multiplier=hidden_multiplier,\n                    activation=activation,\n                    grad_clip=1.0,\n                    use_layernorm=use_layernorm,\n                )\n                for spec in self.level_specs\n            }\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        *,\n        return_intermediates: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:\n        current = x\n        inputs: Dict[str, torch.Tensor] = {}\n        outputs: Dict[str, torch.Tensor] = {}\n        for spec in self.level_specs:\n            block = self.blocks[spec.name]\n            inputs[spec.name] = current\n            current = block(current)\n            outputs[spec.name] = current\n        if return_intermediates:\n            return current, inputs, outputs\n        return current\n"
  },
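  {
    "path": "src/nested_learning/cms_example.py",
    "content": "\"\"\"Illustrative CMSBlock sketch (added for exposition; not part of the paper spec).\n\nMinimal, self-contained demo: each CMS block adds a residual delta whose\nper-token norm is clamped to `grad_clip` while in train mode, so online writes\ncannot blow up activations. All sizes below are arbitrary.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom .cms import CMSBlock\n\n\ndef demo() -> None:\n    torch.manual_seed(0)\n    block = CMSBlock(dim=8, hidden_multiplier=2, activation=\"gelu\", grad_clip=0.5)\n    x = torch.randn(2, 4, 8)  # [batch, seq, dim]\n\n    block.train()\n    with torch.no_grad():\n        delta = block(x) - x\n    # In train mode the residual delta norm is clamped to grad_clip.\n    assert float(delta.norm(dim=-1).max()) <= 0.5 + 1e-4\n\n    block.eval()\n    y = block(x)  # eval mode applies the raw residual delta\n    print(\"output shape:\", tuple(y.shape))\n\n\nif __name__ == \"__main__\":\n    demo()\n"
  },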
  {
    "path": "src/nested_learning/config_utils.py",
    "content": "from __future__ import annotations\n\nfrom contextlib import contextmanager\nfrom importlib.resources import as_file, files\nfrom pathlib import Path\nfrom typing import Iterator\n\nfrom hydra import compose, initialize_config_dir\nfrom hydra.core.global_hydra import GlobalHydra\nfrom omegaconf import DictConfig\n\nfrom .training import unwrap_config\n\n\ndef find_repo_root(start: Path | None = None) -> Path | None:\n    cursor = (start or Path.cwd()).resolve()\n    for candidate in (cursor, *cursor.parents):\n        if (candidate / \".git\").exists() and (candidate / \"configs\").exists():\n            return candidate\n    return None\n\n\n@contextmanager\ndef resolved_config_dir(config_dir: Path | None = None) -> Iterator[Path]:\n    if config_dir is not None:\n        yield config_dir.resolve()\n        return\n\n    module_path = Path(__file__).resolve()\n    repo_config_dir = module_path.parents[2] / \"configs\"\n    if repo_config_dir.exists():\n        yield repo_config_dir\n        return\n\n    package_configs = files(\"nested_learning\").joinpath(\"configs\")\n    with as_file(package_configs) as pkg_dir:\n        yield Path(pkg_dir)\n\n\ndef compose_config(\n    config_name: str,\n    *,\n    overrides: list[str] | None = None,\n    config_dir: Path | None = None,\n) -> DictConfig:\n    with resolved_config_dir(config_dir) as cfg_dir:\n        GlobalHydra.instance().clear()\n        with initialize_config_dir(version_base=None, config_dir=str(cfg_dir)):\n            cfg = compose(config_name=config_name, overrides=overrides or [])\n    return unwrap_config(cfg)\n"
  },
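  {
    "path": "src/nested_learning/config_utils_example.py",
    "content": "\"\"\"Illustrative config-composition sketch (added for exposition only).\n\n`compose_config` resolves the config directory (the repo `configs/` tree when\nrunning from a checkout, else the packaged copy) and composes a Hydra config.\nHydra-style `key=value` overrides are optional; none are assumed here.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom omegaconf import OmegaConf\n\nfrom .config_utils import compose_config\n\n\ndef demo() -> None:\n    cfg = compose_config(\"pilot_smoke\")\n    print(OmegaConf.to_yaml(cfg))\n\n\nif __name__ == \"__main__\":\n    demo()\n"
  },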
  {
    "path": "src/nested_learning/continual_classification.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Iterable, List, Sequence\n\n\n@dataclass(frozen=True)\nclass ClassificationExample:\n    text: str\n    label: str\n\n\n@dataclass(frozen=True)\nclass LoadedClassificationDataset:\n    name: str\n    split: str\n    examples: List[ClassificationExample]\n    label_names: List[str]\n\n\ndef load_hf_classification_dataset(\n    dataset: str,\n    *,\n    split: str,\n    text_field: str,\n    label_field: str,\n    name: str | None = None,\n    max_samples: int | None = None,\n) -> LoadedClassificationDataset:\n    \"\"\"\n    Load a HuggingFace `datasets` text classification dataset into a simple in-memory format.\n\n    This is used by the Phase 4 continual-learning harness (CLINC/Banking/DBpedia).\n    \"\"\"\n    try:\n        from datasets import load_dataset  # type: ignore[import-not-found]\n    except Exception as exc:  # pragma: no cover\n        raise RuntimeError(\n            \"`datasets` dependency is required for continual classification.\"\n        ) from exc\n\n    ds = load_dataset(dataset, name=name, split=split)\n    features = getattr(ds, \"features\", None)\n    label_names: List[str] = []\n    if features is not None and label_field in features:\n        feature = features[label_field]\n        if getattr(feature, \"names\", None) is not None:\n            label_names = list(feature.names)\n\n    examples: List[ClassificationExample] = []\n    count = 0\n    for row in ds:\n        if max_samples is not None and count >= max_samples:\n            break\n        text = str(row[text_field])\n        raw_label = row[label_field]\n        if isinstance(raw_label, int) and label_names:\n            label = label_names[raw_label]\n        else:\n            label = str(raw_label)\n        examples.append(ClassificationExample(text=text, label=label))\n        count += 1\n\n    if not label_names:\n        label_names = sorted({ex.label for ex in examples})\n\n    return LoadedClassificationDataset(\n        name=dataset if name is None else f\"{dataset}:{name}\",\n        split=split,\n        examples=examples,\n        label_names=label_names,\n    )\n\n\ndef load_clinc_oos(\n    *,\n    split: str = \"test\",\n    max_samples: int | None = None,\n) -> LoadedClassificationDataset:\n    # HF dataset: \"clinc_oos\" with fields {\"text\", \"intent\"}.\n    return load_hf_classification_dataset(\n        \"clinc_oos\",\n        split=split,\n        text_field=\"text\",\n        label_field=\"intent\",\n        max_samples=max_samples,\n    )\n\n\ndef load_banking77(\n    *,\n    split: str = \"test\",\n    max_samples: int | None = None,\n) -> LoadedClassificationDataset:\n    # HF dataset: \"banking77\" with fields {\"text\", \"label\"}.\n    return load_hf_classification_dataset(\n        \"banking77\",\n        split=split,\n        text_field=\"text\",\n        label_field=\"label\",\n        max_samples=max_samples,\n    )\n\n\ndef load_dbpedia14(\n    *,\n    split: str = \"test\",\n    max_samples: int | None = None,\n) -> LoadedClassificationDataset:\n    # HF dataset: \"dbpedia_14\" with fields {\"content\", \"label\"}.\n    return load_hf_classification_dataset(\n        \"dbpedia_14\",\n        split=split,\n        text_field=\"content\",\n        label_field=\"label\",\n        max_samples=max_samples,\n    )\n\n\ndef unique_labels(examples: Iterable[ClassificationExample]) -> List[str]:\n    seen = set()\n    ordered: List[str] = []\n    for ex in 
examples:\n        if ex.label in seen:\n            continue\n        seen.add(ex.label)\n        ordered.append(ex.label)\n    return ordered\n\n\ndef filter_examples_by_labels(\n    examples: Sequence[ClassificationExample],\n    *,\n    allowed: set[str],\n) -> List[ClassificationExample]:\n    return [ex for ex in examples if ex.label in allowed]\n"
  },
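  {
    "path": "src/nested_learning/continual_classification_example.py",
    "content": "\"\"\"Illustrative sketch for the continual-classification helpers (exposition only).\n\nBuilds a tiny in-memory dataset and exercises `unique_labels` and\n`filter_examples_by_labels` without any HuggingFace downloads. All labels and\ntexts below are arbitrary.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom .continual_classification import (\n    ClassificationExample,\n    filter_examples_by_labels,\n    unique_labels,\n)\n\n\ndef demo() -> None:\n    examples = [\n        ClassificationExample(text=\"transfer money\", label=\"transfer\"),\n        ClassificationExample(text=\"card is lost\", label=\"card_lost\"),\n        ClassificationExample(text=\"send funds\", label=\"transfer\"),\n    ]\n    # First-seen ordering: ['transfer', 'card_lost'].\n    print(unique_labels(examples))\n    # Keep only the 'transfer' examples.\n    kept = filter_examples_by_labels(examples, allowed={\"transfer\"})\n    print(len(kept))  # -> 2\n\n\nif __name__ == \"__main__\":\n    demo()\n"
  },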
  {
    "path": "src/nested_learning/continual_streaming.py",
    "content": "from __future__ import annotations\n\nimport math\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Sequence\n\nimport torch\n\nfrom .continual_classification import ClassificationExample, unique_labels\nfrom .memorize import MemorizeConfig, memorize_sequence\nfrom .tokenizer import SentencePieceTokenizer\n\n\n@dataclass(frozen=True)\nclass StreamingTask:\n    task_id: int\n    labels: List[str]\n    train: List[ClassificationExample]\n    eval: List[ClassificationExample]\n\n\n@dataclass(frozen=True)\nclass ContinualEvalConfig:\n    task_size: int = 10\n    seed: int = 0\n    train_per_label: int = 50\n    eval_per_label: int = 50\n    prompt_template: str = \"Text: {text}\\nLabel:\"\n    label_template: str = \"{label}\"\n    task_aware: bool = True\n\n\ndef _logprob_completion(\n    model,\n    tokenizer: SentencePieceTokenizer,\n    prompt: str,\n    completion: str,\n    device: torch.device,\n    *,\n    fast_state=None,\n) -> float:\n    prompt_ids = tokenizer.encode(prompt, add_bos=True)\n    completion_ids = tokenizer.encode(\" \" + completion, add_bos=False)\n    tokens = torch.cat([prompt_ids, completion_ids], dim=0).unsqueeze(0).to(device)\n    with torch.no_grad():\n        logits = model(tokens, fast_state=fast_state) if fast_state is not None else model(tokens)\n        log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)\n        target = tokens[:, 1:]\n        gathered = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)\n        prompt_len = prompt_ids.numel()\n        return float(gathered[0, prompt_len - 1 :].sum().item())\n\n\ndef predict_label(\n    model,\n    tokenizer: SentencePieceTokenizer,\n    text: str,\n    candidates: Sequence[str],\n    device: torch.device,\n    *,\n    prompt_template: str,\n    label_template: str,\n    fast_state=None,\n) -> str:\n    if not candidates:\n        raise ValueError(\"predict_label requires at least one candidate label\")\n    prompt = prompt_template.format(text=text)\n    best_label = candidates[0]\n    best_score = -math.inf\n    for label in candidates:\n        label_str = label_template.format(label=label)\n        score = _logprob_completion(\n            model, tokenizer, prompt, label_str, device, fast_state=fast_state\n        )\n        if score > best_score:\n            best_score = score\n            best_label = label\n    return best_label\n\n\ndef _balanced_split(\n    examples: Sequence[ClassificationExample],\n    *,\n    labels: Sequence[str],\n    train_per_label: int,\n    eval_per_label: int,\n) -> tuple[List[ClassificationExample], List[ClassificationExample]]:\n    train: List[ClassificationExample] = []\n    eval_: List[ClassificationExample] = []\n    counts_train: Dict[str, int] = {lbl: 0 for lbl in labels}\n    counts_eval: Dict[str, int] = {lbl: 0 for lbl in labels}\n    for ex in examples:\n        lbl = ex.label\n        if lbl not in counts_train:\n            continue\n        if counts_train[lbl] < train_per_label:\n            train.append(ex)\n            counts_train[lbl] += 1\n        elif counts_eval[lbl] < eval_per_label:\n            eval_.append(ex)\n            counts_eval[lbl] += 1\n        if all(v >= train_per_label for v in counts_train.values()) and all(\n            v >= eval_per_label for v in counts_eval.values()\n        ):\n            break\n    return train, eval_\n\n\ndef build_streaming_tasks(\n    examples: Sequence[ClassificationExample],\n    *,\n    cfg: ContinualEvalConfig,\n    label_order: Sequence[str] | None = 
None,\n) -> List[StreamingTask]:\n    labels = list(label_order) if label_order is not None else unique_labels(examples)\n    if label_order is None:\n        import random\n\n        rng = random.Random(cfg.seed)\n        rng.shuffle(labels)\n    if cfg.task_size <= 0:\n        raise ValueError(\"task_size must be positive\")\n    tasks: List[StreamingTask] = []\n    for task_id, start in enumerate(range(0, len(labels), cfg.task_size)):\n        task_labels = labels[start : start + cfg.task_size]\n        if not task_labels:\n            break\n        task_examples = [ex for ex in examples if ex.label in set(task_labels)]\n        train, eval_ = _balanced_split(\n            task_examples,\n            labels=task_labels,\n            train_per_label=cfg.train_per_label,\n            eval_per_label=cfg.eval_per_label,\n        )\n        tasks.append(\n            StreamingTask(task_id=task_id, labels=list(task_labels), train=train, eval=eval_)\n        )\n    return tasks\n\n\n@dataclass(frozen=True)\nclass ContinualEvalResult:\n    task_accuracy_matrix: List[List[float]]\n    per_task_forgetting: List[float]\n    avg_accuracy_final: float\n    avg_forgetting: float\n\n\ndef evaluate_continual_classification(\n    model,\n    tokenizer: SentencePieceTokenizer,\n    tasks: Sequence[StreamingTask],\n    device: torch.device,\n    *,\n    cfg: ContinualEvalConfig,\n    memorize_cfg: MemorizeConfig,\n) -> tuple[ContinualEvalResult, Dict[str, Any]]:\n    \"\"\"\n    Streaming class-incremental evaluation using generative classification + optional\n    test-time memorization.\n\n    - If `memorize_cfg.enabled`, each training example is memorized by appending the correct\n      label string.\n    - Accuracy is computed after each task on each task's eval set, producing a task-accuracy\n      matrix.\n    \"\"\"\n    meta_snapshot: Dict[str, torch.Tensor] | None = None\n    if memorize_cfg.enabled and (not memorize_cfg.use_fast_state) and memorize_cfg.reset:\n        from .memorize import snapshot_state_dict  # local import to avoid cycles\n\n        meta_snapshot = snapshot_state_dict(model)\n\n    fast_state = None\n    if memorize_cfg.enabled and memorize_cfg.use_fast_state:\n        if not hasattr(model, \"init_fast_state\"):\n            raise RuntimeError(\"Model does not support fast state memorization\")\n        fast_state = model.init_fast_state()\n\n    task_acc: List[List[float]] = [[float(\"nan\") for _ in tasks] for _ in tasks]\n    best_acc: List[float] = [0.0 for _ in tasks]\n\n    memorize_stats_total: Dict[str, float] = {}\n\n    def _eval_task(task_idx: int) -> float:\n        task = tasks[task_idx]\n        candidates = (\n            task.labels\n            if cfg.task_aware\n            else [lbl for t in tasks[: current_task + 1] for lbl in t.labels]\n        )\n        if not task.eval:\n            return float(\"nan\")\n        correct = 0\n        for ex in task.eval:\n            pred = predict_label(\n                model,\n                tokenizer,\n                ex.text,\n                candidates,\n                device,\n                prompt_template=cfg.prompt_template,\n                label_template=cfg.label_template,\n                fast_state=fast_state,\n            )\n            correct += int(pred == ex.label)\n        return correct / len(task.eval) if task.eval else float(\"nan\")\n\n    for current_task, task in enumerate(tasks):\n        # Online \"training\" on this task's examples via optional memorization.\n        for ex in 
task.train:\n            candidates = (\n                task.labels\n                if cfg.task_aware\n                else [lbl for t in tasks[: current_task + 1] for lbl in t.labels]\n            )\n            _ = predict_label(\n                model,\n                tokenizer,\n                ex.text,\n                candidates,\n                device,\n                prompt_template=cfg.prompt_template,\n                label_template=cfg.label_template,\n                fast_state=fast_state,\n            )\n            if memorize_cfg.enabled:\n                prompt = cfg.prompt_template.format(text=ex.text)\n                target = cfg.label_template.format(label=ex.label)\n                memorize_text = f\"{prompt} {target}\"\n                if memorize_cfg.use_fast_state and memorize_cfg.reset:\n                    fast_state = model.init_fast_state()\n                stats = memorize_sequence(\n                    model, tokenizer, memorize_text, device, memorize_cfg, fast_state=fast_state\n                )\n                for k, v in stats.items():\n                    memorize_stats_total[k] = memorize_stats_total.get(k, 0.0) + v\n                if (\n                    (not memorize_cfg.use_fast_state)\n                    and memorize_cfg.reset\n                    and meta_snapshot is not None\n                ):\n                    from .memorize import restore_state_dict  # local import to avoid cycles\n\n                    restore_state_dict(model, meta_snapshot)\n\n        # Evaluate on all tasks seen so far.\n        for task_idx in range(current_task + 1):\n            acc = _eval_task(task_idx)\n            task_acc[task_idx][current_task] = acc\n            if not math.isnan(acc):\n                best_acc[task_idx] = max(best_acc[task_idx], acc)\n\n    final_accs = [task_acc[i][-1] for i in range(len(tasks)) if not math.isnan(task_acc[i][-1])]\n    avg_accuracy_final = sum(final_accs) / len(final_accs) if final_accs else float(\"nan\")\n\n    per_task_forgetting: List[float] = []\n    for i in range(len(tasks)):\n        last = task_acc[i][-1]\n        if math.isnan(last):\n            per_task_forgetting.append(float(\"nan\"))\n            continue\n        per_task_forgetting.append(best_acc[i] - last)\n    valid_forgetting = [f for f in per_task_forgetting if not math.isnan(f)]\n    avg_forgetting = (\n        sum(valid_forgetting) / len(valid_forgetting) if valid_forgetting else float(\"nan\")\n    )\n\n    result = ContinualEvalResult(\n        task_accuracy_matrix=task_acc,\n        per_task_forgetting=per_task_forgetting,\n        avg_accuracy_final=avg_accuracy_final,\n        avg_forgetting=avg_forgetting,\n    )\n    meta = {\n        \"task_size\": cfg.task_size,\n        \"train_per_label\": cfg.train_per_label,\n        \"eval_per_label\": cfg.eval_per_label,\n        \"task_aware\": cfg.task_aware,\n        \"prompt_template\": cfg.prompt_template,\n        \"label_template\": cfg.label_template,\n        \"memorize_stats\": memorize_stats_total,\n    }\n    return result, meta\n"
  },
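  {
    "path": "src/nested_learning/continual_streaming_example.py",
    "content": "\"\"\"Illustrative sketch for the streaming-task builder (exposition only).\n\nShows how `build_streaming_tasks` partitions labels into class-incremental\ntasks and balances train/eval examples per label. No model or tokenizer is\nneeded for this step; all values below are arbitrary.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom .continual_classification import ClassificationExample\nfrom .continual_streaming import ContinualEvalConfig, build_streaming_tasks\n\n\ndef demo() -> None:\n    examples = [\n        ClassificationExample(text=f\"{label} sample {i}\", label=label)\n        for label in (\"a\", \"b\", \"c\", \"d\")\n        for i in range(2)\n    ]\n    cfg = ContinualEvalConfig(task_size=2, train_per_label=1, eval_per_label=1)\n    # A fixed label_order keeps the demo deterministic (otherwise labels are\n    # shuffled with cfg.seed).\n    tasks = build_streaming_tasks(examples, cfg=cfg, label_order=[\"a\", \"b\", \"c\", \"d\"])\n    for task in tasks:\n        print(task.task_id, task.labels, len(task.train), len(task.eval))\n    # -> 0 ['a', 'b'] 2 2\n    # -> 1 ['c', 'd'] 2 2\n\n\nif __name__ == \"__main__\":\n    demo()\n"
  },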
  {
    "path": "src/nested_learning/data.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Iterator, List, Sequence\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import Dataset, IterableDataset, get_worker_info\n\n\n@dataclass\nclass SyntheticTextConfig:\n    vocab_size: int\n    seq_len: int\n    dataset_size: int\n\n\nclass SyntheticTextDataset(Dataset[torch.Tensor]):\n    def __init__(self, config: SyntheticTextConfig):\n        self.config = config\n\n    def __len__(self) -> int:\n        return self.config.dataset_size\n\n    def __getitem__(self, idx: int) -> torch.Tensor:\n        g = torch.Generator().manual_seed(idx)\n        return torch.randint(0, self.config.vocab_size, (self.config.seq_len,), generator=g)\n\n\nclass TokenShardDataset(Dataset[torch.Tensor]):\n    \"\"\"Memory-mapped dataset over NumPy shards produced by shard_corpus.py.\"\"\"\n\n    def __init__(self, shard_dir: str | Path):\n        self.shard_dir = Path(shard_dir)\n        if not self.shard_dir.exists():\n            msg = f\"Shard directory {self.shard_dir} does not exist\"\n            raise FileNotFoundError(msg)\n        self.paths = sorted(self.shard_dir.glob(\"*.npy\"))\n        if not self.paths:\n            msg = f\"No shard files found in {self.shard_dir}\"\n            raise ValueError(msg)\n        self.metadata: List[tuple[int, int]] = []\n        self._cache: dict[int, np.memmap] = {}\n        total = 0\n        for idx, path in enumerate(self.paths):\n            arr = np.load(path, mmap_mode=\"r\")\n            length = arr.shape[0]\n            self.metadata.append((total, length))\n            total += length\n        self.total_sequences = total\n\n    def __len__(self) -> int:\n        return self.total_sequences\n\n    def _load_array(self, shard_idx: int) -> np.memmap:\n        if shard_idx not in self._cache:\n            self._cache[shard_idx] = np.load(self.paths[shard_idx], mmap_mode=\"r\")\n        return self._cache[shard_idx]\n\n    def __getitem__(self, idx: int) -> torch.Tensor:\n        if idx < 0 or idx >= self.total_sequences:\n            raise IndexError(idx)\n        shard_idx = self._find_shard(idx)\n        start_offset = self.metadata[shard_idx][0]\n        arr = self._load_array(shard_idx)\n        local_idx = idx - start_offset\n        tokens = torch.from_numpy(arr[local_idx])\n        return tokens.long()\n\n    def _find_shard(self, idx: int) -> int:\n        lo, hi = 0, len(self.metadata) - 1\n        while lo <= hi:\n            mid = (lo + hi) // 2\n            start, length = self.metadata[mid]\n            if idx < start:\n                hi = mid - 1\n            elif idx >= start + length:\n                lo = mid + 1\n            else:\n                return mid\n        return len(self.metadata) - 1\n\n\n@dataclass\nclass ShardSourceConfig:\n    name: str\n    shards_dir: str\n    weight: float\n\n\nclass ShardSource:\n    def __init__(self, config: ShardSourceConfig):\n        self.name = config.name\n        self.weight = config.weight\n        self.dir = Path(config.shards_dir)\n        if not self.dir.exists():\n            msg = f\"Shard directory {self.dir} missing for source {self.name}\"\n            raise FileNotFoundError(msg)\n        self.paths = sorted(self.dir.glob(\"*.npy\"))\n        if not self.paths:\n            raise ValueError(f\"No shard files in {self.dir}\")\n        self._cache: dict[Path, np.memmap] = {}\n\n    def sample(self, rng: np.random.Generator) -> np.ndarray:\n    
    shard_path = self.paths[rng.integers(0, len(self.paths))]\n        if shard_path not in self._cache:\n            self._cache[shard_path] = np.load(shard_path, mmap_mode=\"r\")\n        shard = self._cache[shard_path]\n        idx = rng.integers(0, shard.shape[0])\n        return shard[idx]\n\n\nclass MixtureShardDataset(IterableDataset[torch.Tensor]):\n    def __init__(\n        self,\n        sources: Sequence[ShardSourceConfig],\n        *,\n        samples_per_epoch: int,\n        seed: int = 0,\n    ):\n        super().__init__()\n        self.sources = [ShardSource(cfg) for cfg in sources]\n        total_weight = sum(max(src.weight, 0.0) for src in self.sources)\n        if total_weight <= 0:\n            raise ValueError(\"Mixture weights must sum to > 0\")\n        self.weights = np.array([max(src.weight, 0.0) / total_weight for src in self.sources])\n        self.samples_per_epoch = samples_per_epoch\n        self.seed = seed\n\n    def __len__(self) -> int:\n        return self.samples_per_epoch\n\n    def __iter__(self) -> Iterator[torch.Tensor]:\n        worker = get_worker_info()\n        if worker is None:\n            start = 0\n            end = self.samples_per_epoch\n            worker_seed = self.seed\n        else:\n            per_worker = (self.samples_per_epoch + worker.num_workers - 1) // worker.num_workers\n            start = worker.id * per_worker\n            end = min(start + per_worker, self.samples_per_epoch)\n            worker_seed = self.seed + worker.id\n        rng = np.random.default_rng(worker_seed)\n        for _ in range(start, end):\n            idx = rng.choice(len(self.sources), p=self.weights)\n            sample = np.array(self.sources[idx].sample(rng), copy=True)\n            yield torch.from_numpy(sample).long()\n\n\ndef collate_batch(batch: list[torch.Tensor]) -> torch.Tensor:\n    return torch.stack(batch, dim=0)\n"
  },
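  {
    "path": "src/nested_learning/data_example.py",
    "content": "\"\"\"Illustrative data-pipeline sketch (exposition only).\n\nWires `SyntheticTextDataset` into a `DataLoader` with `collate_batch`; batches\ncome out as [batch, seq_len] token tensors. The synthetic dataset is\ndeterministic per index, so this runs anywhere without shards on disk.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom torch.utils.data import DataLoader\n\nfrom .data import SyntheticTextConfig, SyntheticTextDataset, collate_batch\n\n\ndef demo() -> None:\n    config = SyntheticTextConfig(vocab_size=256, seq_len=16, dataset_size=8)\n    loader = DataLoader(SyntheticTextDataset(config), batch_size=4, collate_fn=collate_batch)\n    for batch in loader:\n        print(tuple(batch.shape))  # -> (4, 16)\n\n\nif __name__ == \"__main__\":\n    demo()\n"
  },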
  {
    "path": "src/nested_learning/device.py",
    "content": "from __future__ import annotations\n\nimport torch\n\n\ndef resolve_device(device_str: str) -> torch.device:\n    normalized = str(device_str).strip().lower()\n    if normalized.startswith(\"cuda\"):\n        if not torch.cuda.is_available():\n            return torch.device(\"cpu\")\n        parts = normalized.split(\":\")\n        idx = int(parts[1]) if len(parts) > 1 else 0\n        if idx >= torch.cuda.device_count():\n            idx = max(torch.cuda.device_count() - 1, 0)\n        return torch.device(f\"cuda:{idx}\")\n    if normalized.startswith(\"mps\"):\n        if not (hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available()):\n            return torch.device(\"cpu\")\n        return torch.device(\"mps\")\n    return torch.device(device_str)\n\n"
  },
  {
    "path": "src/nested_learning/eval_state.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\n\nimport torch\n\n\n@dataclass\nclass EvalStreamingState:\n    fast_state: object | None = None\n    attention_cache: object | None = None\n\n\ndef parse_eval_state_mode(mode: str) -> bool:\n    \"\"\"\n    Returns True when eval state should be carried across samples.\n    \"\"\"\n    normalized = str(mode).strip().lower()\n    if normalized in {\"reset\", \"reset_per_sample\", \"isolated\"}:\n        return False\n    if normalized in {\"carry\", \"carry_across_samples\", \"stream\"}:\n        return True\n    raise ValueError(\n        \"Unsupported eval_state_mode={!r}; expected one of \"\n        \"['reset_per_sample', 'carry_across_samples']\".format(mode)\n    )\n\n\ndef init_eval_streaming_state(\n    model,\n    *,\n    use_fast_state: bool,\n    use_attention_cache: bool,\n) -> EvalStreamingState:\n    state = EvalStreamingState()\n    if use_fast_state:\n        init_fast_state = getattr(model, \"init_fast_state\", None)\n        if not callable(init_fast_state):\n            raise RuntimeError(\n                \"Requested fast-state eval mode, but model.init_fast_state() is missing\"\n            )\n        state.fast_state = init_fast_state()\n    if use_attention_cache:\n        init_attention_cache = getattr(model, \"init_attention_cache\", None)\n        if not callable(init_attention_cache):\n            raise RuntimeError(\n                \"Requested attention-cache eval mode, but model.init_attention_cache() is missing\"\n            )\n        state.attention_cache = init_attention_cache()\n    return state\n\n\ndef forward_with_eval_state(\n    model,\n    tokens: torch.Tensor,\n    *,\n    state: EvalStreamingState | None = None,\n) -> tuple[torch.Tensor, EvalStreamingState | None]:\n    if state is None:\n        return model(tokens), None\n    if state.attention_cache is not None:\n        logits, next_cache = model(\n            tokens,\n            fast_state=state.fast_state,\n            attention_cache=state.attention_cache,\n            return_attention_cache=True,\n        )\n        state.attention_cache = next_cache\n        return logits, state\n    if state.fast_state is not None:\n        return model(tokens, fast_state=state.fast_state), state\n    return model(tokens), state\n"
  },
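  {
    "path": "src/nested_learning/eval_state_example.py",
    "content": "\"\"\"Illustrative eval-state sketch (exposition only).\n\nDemonstrates `init_eval_streaming_state` / `forward_with_eval_state` against a\ntoy stand-in model; the real models expose the same `init_fast_state` /\n`fast_state=` hooks. `_ToyModel` below is hypothetical and exists only for\nthis demo.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch import nn\n\nfrom .eval_state import (\n    forward_with_eval_state,\n    init_eval_streaming_state,\n    parse_eval_state_mode,\n)\n\n\nclass _ToyModel(nn.Module):\n    \"\"\"Hypothetical stand-in exposing the fast-state protocol.\"\"\"\n\n    def __init__(self, vocab_size: int = 32, dim: int = 8):\n        super().__init__()\n        self.embed = nn.Embedding(vocab_size, dim)\n        self.lm_head = nn.Linear(dim, vocab_size, bias=False)\n\n    def init_fast_state(self) -> dict:\n        return {}\n\n    def forward(self, tokens: torch.Tensor, *, fast_state: dict | None = None) -> torch.Tensor:\n        return self.lm_head(self.embed(tokens))\n\n\ndef demo() -> None:\n    model = _ToyModel()\n    carry = parse_eval_state_mode(\"carry_across_samples\")  # -> True\n    state = init_eval_streaming_state(model, use_fast_state=carry, use_attention_cache=False)\n    tokens = torch.randint(0, 32, (1, 4))\n    logits, state = forward_with_eval_state(model, tokens, state=state)\n    print(tuple(logits.shape), carry)\n\n\nif __name__ == \"__main__\":\n    demo()\n"
  },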
  {
    "path": "src/nested_learning/fast_state.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass, field\nfrom typing import Dict, cast\n\nimport torch\nfrom torch import nn\n\nfrom .optim.manager import LevelConfig, LevelOptimizerManager\nfrom .titan.self_modifying import SelfModifyingTitansState\n\nParamDict = Dict[str, torch.Tensor]\n\n\n@dataclass\nclass CMSChunkBuffer:\n    \"\"\"\n    Streaming CMS chunk buffer persisted across multiple model calls.\n\n    This is required to preserve update-period cadence when a logical sequence is\n    processed in several chunked forward/update calls.\n    \"\"\"\n\n    inputs: list[torch.Tensor] = field(default_factory=list)\n    teach: list[torch.Tensor] = field(default_factory=list)\n    active: list[torch.Tensor] = field(default_factory=list)\n    count: int = 0\n\n\ndef init_module_deltas(module: nn.Module) -> ParamDict:\n    \"\"\"\n    Initialize a per-parameter \"fast state\" delta dict for meta+delta fast state.\n\n    The fast state stores *deltas* (initialized to 0) rather than detached parameter clones so that\n    forward passes can use `meta_param + delta`, allowing outer gradients to flow to meta params\n    while keeping online updates as stop-grad writes into the delta tensors.\n    \"\"\"\n\n    return {name: torch.zeros_like(param).detach() for name, param in module.named_parameters()}\n\n\n@dataclass\nclass BlockFastState:\n    titan_params: ParamDict | None\n    cms_params: Dict[str, ParamDict]\n    cms_online_buffers: Dict[str, CMSChunkBuffer]\n    level_manager: LevelOptimizerManager\n    selfmod_state: SelfModifyingTitansState | None = None\n\n\ndef build_block_fast_state(\n    *,\n    titan_module: nn.Module | None,\n    cms_blocks: Dict[str, nn.Module],\n    selfmod_module: nn.Module | None = None,\n    specs,\n    optimizer_configs: Dict[str, dict],\n    default_lr: float,\n) -> BlockFastState:\n    titan_params = None\n    if titan_module is not None:\n        titan_params = init_module_deltas(titan_module)\n    cms_params = {name: init_module_deltas(block) for name, block in cms_blocks.items()}\n    cms_online_buffers = {name: CMSChunkBuffer() for name in cms_blocks}\n    level_cfg = LevelConfig(specs=specs, optimizer_configs=optimizer_configs, default_lr=default_lr)\n    level_manager = LevelOptimizerManager(level_cfg)\n    selfmod_state = None\n    if selfmod_module is not None:\n        init_fn = getattr(selfmod_module, \"init_fast_state\", None)\n        if callable(init_fn):\n            selfmod_state = cast(SelfModifyingTitansState, init_fn())\n    return BlockFastState(\n        titan_params=titan_params,\n        cms_params=cms_params,\n        cms_online_buffers=cms_online_buffers,\n        level_manager=level_manager,\n        selfmod_state=selfmod_state,\n    )\n\n\n@dataclass\nclass ModelFastState:\n    blocks: list[BlockFastState]\n\n\n@dataclass\nclass AttentionKVCache:\n    \"\"\"\n    Per-layer autoregressive attention cache.\n\n    Shapes:\n    - key:   [batch, heads, cached_tokens, head_dim]\n    - value: [batch, heads, cached_tokens, head_dim]\n    \"\"\"\n\n    key: torch.Tensor\n    value: torch.Tensor\n\n\n@dataclass\nclass ModelAttentionCache:\n    \"\"\"\n    Model-level container for per-block attention caches.\n\n    Blocks without attention store `None` entries.\n    \"\"\"\n\n    blocks: list[AttentionKVCache | None]\n"
  },
  {
    "path": "src/nested_learning/functional.py",
    "content": "from __future__ import annotations\n\nfrom typing import Any, Dict, Mapping, Tuple\n\nimport torch\nfrom torch import nn\nfrom torch.func import functional_call\n\nParamDict = Dict[str, torch.Tensor]\n\n\ndef params_with_deltas(module: nn.Module, deltas: ParamDict) -> ParamDict:\n    params: ParamDict = {}\n    missing: list[str] = []\n    for name, param in module.named_parameters():\n        delta = deltas.get(name)\n        if delta is None:\n            missing.append(name)\n            continue\n        params[name] = param + delta\n    if missing:\n        raise KeyError(\n            f\"Missing fast-state delta(s) for {module.__class__.__name__}: {sorted(missing)[:10]}\"\n        )\n    return params\n\n\ndef module_buffers(module: nn.Module) -> ParamDict:\n    return {name: buf for name, buf in module.named_buffers()}\n\n\ndef call_with_params(\n    module: nn.Module,\n    params: ParamDict,\n    *args: Any,\n    **kwargs: Any,\n) -> Any:\n    buffers = module_buffers(module)\n    return functional_call(module, (params, buffers), args, kwargs, strict=True)\n\n\ndef call_with_deltas(\n    module: nn.Module,\n    deltas: ParamDict,\n    *args: Any,\n    **kwargs: Any,\n) -> Any:\n    return call_with_params(module, params_with_deltas(module, deltas), *args, **kwargs)\n\n\ndef require_grad_params(\n    params: Mapping[str, torch.Tensor], *, detach: bool = True\n) -> ParamDict:\n    out: ParamDict = {}\n    for name, value in params.items():\n        if detach:\n            out[name] = value.detach().requires_grad_(True)\n        else:\n            out[name] = value.requires_grad_(True)\n    return out\n\n\ndef grads_to_dict(params: ParamDict, grads: Tuple[torch.Tensor | None, ...]) -> ParamDict:\n    out: ParamDict = {}\n    for (name, _), grad in zip(params.items(), grads, strict=True):\n        if grad is None:\n            continue\n        out[name] = grad\n    return out\n"
  },
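  {
    "path": "src/nested_learning/functional_example.py",
    "content": "\"\"\"Illustrative meta+delta fast-state sketch (exposition only).\n\nShows the pattern documented in `fast_state.init_module_deltas` and used via\n`functional.call_with_deltas`: forward passes run on `meta_param + delta`, so\nouter gradients reach the meta parameters while online updates are stop-grad\nwrites into the delta tensors.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch import nn\n\nfrom .fast_state import init_module_deltas\nfrom .functional import call_with_deltas\n\n\ndef demo() -> None:\n    torch.manual_seed(0)\n    module = nn.Linear(4, 4)\n    deltas = init_module_deltas(module)\n    x = torch.randn(2, 4)\n\n    # Zero deltas reproduce the plain forward pass exactly.\n    assert torch.allclose(call_with_deltas(module, deltas, x), module(x))\n\n    # An online (stop-grad) write: nudge the weight delta; the meta parameters\n    # themselves stay untouched.\n    with torch.no_grad():\n        deltas[\"weight\"] += 0.01 * torch.randn_like(deltas[\"weight\"])\n\n    out = call_with_deltas(module, deltas, x)\n    out.sum().backward()  # outer gradients still flow to module.weight\n    print(module.weight.grad is not None)  # -> True\n\n\nif __name__ == \"__main__\":\n    demo()\n"
  },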
  {
    "path": "src/nested_learning/hope/__init__.py",
    "content": ""
  },
  {
    "path": "src/nested_learning/hope/block.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass, field\nfrom typing import Dict, Sequence, Set\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom ..backbones import AttentionConfig, SelfAttention\nfrom ..cms import CMS\nfrom ..fast_state import AttentionKVCache, BlockFastState, CMSChunkBuffer\nfrom ..functional import (\n    call_with_deltas,\n    call_with_params,\n    grads_to_dict,\n    params_with_deltas,\n    require_grad_params,\n)\nfrom ..levels import LevelSpec\nfrom ..optim.manager import LevelConfig, LevelOptimizerManager\nfrom ..titan.memory import TitanMemory, TitanMemoryConfig\nfrom ..titan.self_modifying import SelfModifyingTitans, SelfModifyingTitansConfig\nfrom .self_mod import SelfModifier\n\n\ndef _chunk_loss(\n    prediction: torch.Tensor,\n    delta_target: torch.Tensor,\n    mask_f: torch.Tensor,\n    *,\n    reduction: str,\n    differentiable_target: bool = False,\n) -> torch.Tensor:\n    if differentiable_target:\n        target = prediction.detach() - delta_target\n    else:\n        target = (prediction.detach() - delta_target).detach()\n    diff_sq = (prediction - target).pow(2)\n    masked = diff_sq * mask_f\n    if reduction == \"mean\":\n        return masked.sum() / mask_f.sum().clamp(min=1.0)\n    if reduction == \"sum\":\n        return masked.sum()\n    raise ValueError(f\"Unsupported cms_chunk_reduction={reduction}\")\n\n\ndef _min_update_period(levels: Sequence[LevelSpec]) -> int:\n    periods = [int(spec.update_period) for spec in levels if int(spec.update_period) > 0]\n    return min(periods) if periods else 1\n\n\n@dataclass\nclass _CmsBuffer:\n    inputs: list[torch.Tensor]\n    teach: list[torch.Tensor]\n    active: list[torch.Tensor]\n    count: int = 0\n\n\ndef _clear_buffer(buffer: _CmsBuffer | CMSChunkBuffer) -> None:\n    buffer.inputs.clear()\n    buffer.teach.clear()\n    buffer.active.clear()\n    buffer.count = 0\n\n\ndef _fast_state_buffers(\n    fast_state: BlockFastState, levels: Sequence[LevelSpec]\n) -> dict[str, CMSChunkBuffer]:\n    buffers = fast_state.cms_online_buffers\n    for spec in levels:\n        if spec.name not in buffers:\n            buffers[spec.name] = CMSChunkBuffer()\n    return buffers\n\n\ndef _pop_buffer_chunk(\n    buffer: _CmsBuffer | CMSChunkBuffer,\n    count: int,\n) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    if count <= 0:\n        raise ValueError(\"count must be positive\")\n    result_inputs: list[torch.Tensor] = []\n    result_teach: list[torch.Tensor] = []\n    result_active: list[torch.Tensor] = []\n    remaining = count\n    while remaining > 0:\n        first = buffer.inputs[0]\n        chunk_len = first.size(1)\n        take = min(remaining, chunk_len)\n        src_inputs = buffer.inputs[0]\n        src_teach = buffer.teach[0]\n        src_active = buffer.active[0]\n        result_inputs.append(src_inputs[:, :take])\n        result_teach.append(src_teach[:, :take])\n        result_active.append(src_active[:, :take])\n        if take == chunk_len:\n            buffer.inputs.pop(0)\n            buffer.teach.pop(0)\n            buffer.active.pop(0)\n        else:\n            buffer.inputs[0] = src_inputs[:, take:]\n            buffer.teach[0] = src_teach[:, take:]\n            buffer.active[0] = src_active[:, take:]\n        remaining -= take\n    return (\n        torch.cat(result_inputs, dim=1),\n        torch.cat(result_teach, dim=1),\n        torch.cat(result_active, dim=1),\n    )\n\n\n@dataclass\nclass 
HOPEBlockConfig:\n    dim: int\n    heads: int\n    titan_level: LevelSpec\n    cms_levels: Sequence[LevelSpec]\n    titan_hidden_multiplier: int = 4\n    cms_hidden_multiplier: int = 4\n    cms_use_layernorm: bool = True\n    activation: str = \"gelu\"\n    qk_l2_norm: bool = False\n    local_conv_window: int | None = None\n    self_mod_hidden: int = 4\n    self_mod_lr: float = 1e-3\n    cms_chunk_reduction: str = \"sum\"\n    cms_online_updates: bool = True\n    cms_flush_partial_at_end: bool = False\n    optimizer_configs: Dict[str, dict] = field(default_factory=dict)\n\n\n@dataclass\nclass HOPEAttentionBlockConfig:\n    dim: int\n    heads: int\n    cms_levels: Sequence[LevelSpec]\n    cms_hidden_multiplier: int = 4\n    cms_use_layernorm: bool = True\n    activation: str = \"gelu\"\n    qk_l2_norm: bool = False\n    local_conv_window: int | None = None\n    self_mod_lr: float = 1e-3\n    cms_chunk_reduction: str = \"sum\"\n    cms_online_updates: bool = True\n    cms_flush_partial_at_end: bool = False\n    optimizer_configs: Dict[str, dict] = field(default_factory=dict)\n\n\nclass HOPEAttentionBlock(nn.Module):\n    \"\"\"\n    Paper-defined HOPE-Attention variant: softmax attention followed by CMS.\n\n    Reference: Nested Learning paper, HOPE-Attention note under Eqs. 94–97.\n    \"\"\"\n\n    def __init__(self, config: HOPEAttentionBlockConfig):\n        super().__init__()\n        self.config = config\n        self.last_update_stats: Dict[str, Dict[str, float]] = {}\n        self.surprise_threshold: float | None = None\n        self.surprise_metric: str = \"l2\"\n        self.allowed_levels: Set[str] | None = None\n        self.attn = SelfAttention(\n            AttentionConfig(\n                dim=config.dim,\n                heads=config.heads,\n                qk_l2_norm=config.qk_l2_norm,\n                local_conv_window=config.local_conv_window,\n            )\n        )\n        self.cms = CMS(\n            dim=config.dim,\n            levels=config.cms_levels,\n            hidden_multiplier=config.cms_hidden_multiplier,\n            activation=config.activation,\n            use_layernorm=config.cms_use_layernorm,\n        )\n        level_config = LevelConfig(\n            specs=config.cms_levels,\n            optimizer_configs=config.optimizer_configs,\n            default_lr=config.self_mod_lr,\n        )\n        self.level_manager = LevelOptimizerManager(level_config)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        *,\n        teach_signal: torch.Tensor | None = None,\n        surprise_value: float | None = None,\n        fast_state: BlockFastState | None = None,\n        finalize_updates: bool = True,\n        attention_cache: AttentionKVCache | None = None,\n        return_attention_cache: bool = False,\n        differentiable_updates: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache]:\n        next_attn_cache: AttentionKVCache | None = None\n        if return_attention_cache:\n            attn_out, next_attn_cache = self.attn(\n                x,\n                kv_cache=attention_cache,\n                return_kv_cache=True,\n            )\n        else:\n            attn_out = self.attn(x, kv_cache=attention_cache)\n        if fast_state is None:\n            if teach_signal is not None and self.config.cms_online_updates:\n                cms_out = self._cms_forward_online(\n                    attn_out,\n                    teach_signal,\n                    surprise_value,\n                    
finalize_updates=finalize_updates,\n                )\n            else:\n                cms_result = self.cms(attn_out, return_intermediates=True)\n                cms_out, cms_inputs, cms_outputs = cms_result\n                if teach_signal is not None:\n                    self._update_cms(cms_inputs, cms_outputs, teach_signal, surprise_value)\n            self.level_manager.tick()\n            if return_attention_cache:\n                # Honor the cache contract even without fast state (streaming\n                # eval may request a KV cache while fast state is disabled).\n                assert next_attn_cache is not None\n                return cms_out, next_attn_cache\n            return cms_out\n        if teach_signal is not None and self.config.cms_online_updates:\n            cms_out = self._cms_forward_online_fast(\n                attn_out,\n                fast_state,\n                teach_signal,\n                surprise_value,\n                finalize_updates=finalize_updates,\n                differentiable_updates=differentiable_updates,\n            )\n        else:\n            cms_out, cms_inputs = self._cms_forward_fast(attn_out, fast_state)\n            if teach_signal is not None:\n                self._update_cms_fast(\n                    fast_state,\n                    cms_inputs,\n                    teach_signal,\n                    surprise_value,\n                    differentiable_updates=differentiable_updates,\n                )\n        fast_state.level_manager.tick()\n        if return_attention_cache:\n            assert next_attn_cache is not None\n            return cms_out, next_attn_cache\n        return cms_out\n\n    def set_surprise_threshold(self, threshold: float | None) -> None:\n        self.surprise_threshold = threshold\n\n    def set_surprise_metric(self, metric: str) -> None:\n        self.surprise_metric = str(metric).strip().lower()\n\n    def set_allowed_levels(self, allowed: Set[str] | None) -> None:\n        self.allowed_levels = allowed.copy() if allowed is not None else None\n\n    def pop_update_stats(self) -> Dict[str, Dict[str, float]]:\n        stats = self.last_update_stats\n        self.last_update_stats = {}\n        return stats\n\n    def _cms_forward_fast(\n        self,\n        x: torch.Tensor,\n        fast_state: BlockFastState,\n    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:\n        current = x\n        inputs: dict[str, torch.Tensor] = {}\n        for spec in self.config.cms_levels:\n            level_name = spec.name\n            inputs[level_name] = current\n            params = fast_state.cms_params[level_name]\n            current = call_with_deltas(self.cms.blocks[level_name], params, current)\n        return current, inputs\n\n    def _cms_forward_online(\n        self,\n        x: torch.Tensor,\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n        *,\n        finalize_updates: bool = True,\n    ) -> torch.Tensor:\n        seq_len = x.shape[1]\n        base_chunk = _min_update_period(self.config.cms_levels)\n        active_mask = teach_signal.detach().abs().sum(dim=-1) > 0\n        outputs: list[torch.Tensor] = []\n        stats: dict[str, Dict[str, float]] = {}\n        buffers: dict[str, _CmsBuffer] = {}\n        for spec in self.config.cms_levels:\n            buffers[spec.name] = _CmsBuffer(inputs=[], teach=[], active=[], count=0)\n            stats[spec.name] = {\n                \"grad_norm\": 0.0,\n                \"chunk_tokens\": 0.0,\n                \"gate_hit\": 0.0,\n                \"gate_hits\": 0.0,\n                \"updates_applied\": 0.0,\n                \"tokens_flushed\": 0.0,\n                \"pending_tokens\": 0.0,\n            }\n\n        for start in range(0, seq_len, base_chunk):\n            end = min(start + base_chunk, seq_len)
\n            chunk_in = x[:, start:end, :]\n            chunk_teach = teach_signal[:, start:end, :]\n            chunk_active = active_mask[:, start:end]\n\n            current = chunk_in\n            level_inputs: dict[str, torch.Tensor] = {}\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                level_inputs[level_name] = current\n                current = self.cms.blocks[level_name](current)\n            outputs.append(current)\n\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                buffer = buffers[level_name]\n                buffer.inputs.append(level_inputs[level_name].detach())\n                buffer.teach.append(chunk_teach)\n                buffer.active.append(chunk_active)\n                buffer.count += end - start\n                update_period = int(spec.update_period)\n                while update_period > 0 and buffer.count >= update_period:\n                    # Pop under fresh names so later levels in this chunk still\n                    # append the original chunk_teach / chunk_active tensors.\n                    pop_inputs, pop_teach, pop_active = _pop_buffer_chunk(\n                        buffer, update_period\n                    )\n                    buffer.count -= update_period\n                    magnitude = self._update_cms_chunk(\n                        level_name,\n                        pop_inputs,\n                        pop_teach,\n                        pop_active,\n                        surprise_value,\n                    )\n                    if magnitude > 0:\n                        stats[level_name][\"grad_norm\"] += magnitude\n                        stats[level_name][\"chunk_tokens\"] += float(update_period)\n                        stats[level_name][\"gate_hit\"] += 1.0\n                        stats[level_name][\"gate_hits\"] += 1.0\n                        stats[level_name][\"updates_applied\"] += 1.0\n        if self.config.cms_flush_partial_at_end and finalize_updates:\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                buffer = buffers[level_name]\n                remaining = int(buffer.count)\n                if remaining <= 0:\n                    continue\n                pop_inputs, pop_teach, pop_active = _pop_buffer_chunk(buffer, remaining)\n                buffer.count -= remaining\n                if not bool(pop_active.any()):\n                    continue\n                magnitude = self._update_cms_chunk(\n                    level_name,\n                    pop_inputs,\n                    pop_teach,\n                    pop_active,\n                    surprise_value,\n                )\n                if magnitude > 0:\n                    stats[level_name][\"grad_norm\"] += magnitude\n                    stats[level_name][\"chunk_tokens\"] += float(remaining)\n                    stats[level_name][\"gate_hit\"] += 1.0\n                    stats[level_name][\"gate_hits\"] += 1.0\n                    stats[level_name][\"updates_applied\"] += 1.0\n                    stats[level_name][\"tokens_flushed\"] += float(remaining)\n        for spec in self.config.cms_levels:\n            stats[spec.name][\"pending_tokens\"] = float(buffers[spec.name].count)\n        for level_name, payload in stats.items():\n            if (\n                payload[\"updates_applied\"] <= 0\n                and payload[\"pending_tokens\"] <= 0\n                and payload[\"tokens_flushed\"] <= 0\n            ):\n                continue\n            if surprise_value is not None:
payload[\"surprise_value\"] = surprise_value\n            self.last_update_stats[f\"cms.{level_name}\"] = payload\n        return torch.cat(outputs, dim=1)\n\n    def _cms_forward_online_fast(\n        self,\n        x: torch.Tensor,\n        fast_state: BlockFastState,\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n        *,\n        finalize_updates: bool = True,\n        differentiable_updates: bool = False,\n    ) -> torch.Tensor:\n        seq_len = x.shape[1]\n        base_chunk = _min_update_period(self.config.cms_levels)\n        active_mask = teach_signal.detach().abs().sum(dim=-1) > 0\n        outputs: list[torch.Tensor] = []\n        stats: dict[str, Dict[str, float]] = {}\n        buffers = _fast_state_buffers(fast_state, self.config.cms_levels)\n        for spec in self.config.cms_levels:\n            stats[spec.name] = {\n                \"grad_norm\": 0.0,\n                \"chunk_tokens\": 0.0,\n                \"gate_hit\": 0.0,\n                \"gate_hits\": 0.0,\n                \"updates_applied\": 0.0,\n                \"tokens_flushed\": 0.0,\n                \"pending_tokens\": 0.0,\n            }\n\n        for start in range(0, seq_len, base_chunk):\n            end = min(start + base_chunk, seq_len)\n            chunk_in = x[:, start:end, :]\n            chunk_teach = teach_signal[:, start:end, :]\n            chunk_active = active_mask[:, start:end]\n\n            current = chunk_in\n            level_inputs: dict[str, torch.Tensor] = {}\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                level_inputs[level_name] = current\n                params = fast_state.cms_params[level_name]\n                current = call_with_deltas(self.cms.blocks[level_name], params, current)\n            outputs.append(current)\n\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                buffer = buffers[level_name]\n                if differentiable_updates:\n                    buffer.inputs.append(level_inputs[level_name])\n                else:\n                    buffer.inputs.append(level_inputs[level_name].detach())\n                buffer.teach.append(chunk_teach)\n                buffer.active.append(chunk_active)\n                buffer.count += end - start\n                update_period = int(spec.update_period)\n                while update_period > 0 and buffer.count >= update_period:\n                    # Keep the outer chunk_* slices intact for later levels.\n                    popped_inputs, popped_teach, popped_active = _pop_buffer_chunk(\n                        buffer, update_period\n                    )\n                    buffer.count -= update_period\n                    magnitude = self._update_cms_chunk_fast(\n                        fast_state,\n                        level_name,\n                        popped_inputs,\n                        popped_teach,\n                        popped_active,\n                        surprise_value,\n                        differentiable_updates=differentiable_updates,\n                    )\n                    if magnitude > 0:\n                        stats[level_name][\"grad_norm\"] += magnitude\n                        stats[level_name][\"chunk_tokens\"] += float(update_period)\n                        stats[level_name][\"gate_hit\"] += 1.0\n                        stats[level_name][\"gate_hits\"] += 1.0\n                        stats[level_name][\"updates_applied\"] += 1.0\n        if finalize_updates:\n            if self.config.cms_flush_partial_at_end:\n       
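         # Drain whatever is left in each level's buffer as one final partial chunk.\n       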
         for spec in self.config.cms_levels:\n                    level_name = spec.name\n                    buffer = buffers[level_name]\n                    remaining = int(buffer.count)\n                    if remaining <= 0:\n                        continue\n                    chunk_inputs, chunk_teach, chunk_active = _pop_buffer_chunk(buffer, remaining)\n                    buffer.count -= remaining\n                    if not bool(chunk_active.any()):\n                        continue\n                    magnitude = self._update_cms_chunk_fast(\n                        fast_state,\n                        level_name,\n                        chunk_inputs,\n                        chunk_teach,\n                        chunk_active,\n                        surprise_value,\n                        differentiable_updates=differentiable_updates,\n                    )\n                    if magnitude > 0:\n                        stats[level_name][\"grad_norm\"] += magnitude\n                        stats[level_name][\"chunk_tokens\"] += float(remaining)\n                        stats[level_name][\"gate_hit\"] += 1.0\n                        stats[level_name][\"gate_hits\"] += 1.0\n                        stats[level_name][\"updates_applied\"] += 1.0\n                        stats[level_name][\"tokens_flushed\"] += float(remaining)\n            for spec in self.config.cms_levels:\n                _clear_buffer(buffers[spec.name])\n        for spec in self.config.cms_levels:\n            stats[spec.name][\"pending_tokens\"] = float(buffers[spec.name].count)\n        for level_name, payload in stats.items():\n            if (\n                payload[\"updates_applied\"] <= 0\n                and payload[\"pending_tokens\"] <= 0\n                and payload[\"tokens_flushed\"] <= 0\n            ):\n                continue\n            if surprise_value is not None:\n                payload[\"surprise_value\"] = surprise_value\n            self.last_update_stats[f\"cms.{level_name}\"] = payload\n        return torch.cat(outputs, dim=1)\n\n    def _update_cms_fast(\n        self,\n        fast_state: BlockFastState,\n        cms_inputs: dict[str, torch.Tensor],\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n        *,\n        differentiable_updates: bool = False,\n    ) -> None:\n        teach = teach_signal if differentiable_updates else teach_signal.detach()\n        active_mask = teach.abs().sum(dim=-1) > 0\n        for spec in self.config.cms_levels:\n            level_name = spec.name\n            if not self._is_level_allowed(level_name):\n                continue\n            if not self._passes_surprise(surprise_value):\n                self._record_gate(level_name, hit=False)\n                continue\n            inputs = cms_inputs[level_name]\n            seq_len = inputs.shape[1]\n            chunk_size = int(spec.update_period)\n            if chunk_size <= 0:\n                continue\n            total_norm = 0.0\n            update_events = 0\n            token_events = 0\n            for start in range(0, seq_len, chunk_size):\n                end = min(start + chunk_size, seq_len)\n                chunk_len = end - start\n                chunk_inputs = (\n                    inputs[:, start:end, :]\n                    if differentiable_updates\n                    else inputs[:, start:end, :].detach()\n                )\n                chunk_teach = teach[:, start:end, :]\n                chunk_active = active_mask[:, start:end]\n     
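           # Skip chunks whose teach signal is entirely zero; there is nothing to fit.\n     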
           if not bool(chunk_active.any()):\n                    continue\n                magnitude = self._update_cms_chunk_fast(\n                    fast_state,\n                    level_name,\n                    chunk_inputs,\n                    chunk_teach,\n                    chunk_active,\n                    surprise_value,\n                    differentiable_updates=differentiable_updates,\n                )\n                if magnitude <= 0:\n                    continue\n                total_norm += magnitude\n                token_events += chunk_len\n                update_events += 1\n            if update_events == 0:\n                continue\n            stats_payload: Dict[str, float] = {\n                \"grad_norm\": total_norm,\n                \"chunk_tokens\": float(token_events),\n                \"gate_hit\": float(update_events),\n            }\n            if surprise_value is not None:\n                stats_payload[\"surprise_value\"] = surprise_value\n            self.last_update_stats[f\"cms.{level_name}\"] = stats_payload\n\n    def _is_level_allowed(self, level_name: str) -> bool:\n        if self.allowed_levels is None:\n            return True\n        return level_name in self.allowed_levels\n\n    def _passes_surprise(self, surprise_value: float | None) -> bool:\n        if self.surprise_threshold is None:\n            return True\n        if surprise_value is None:\n            return False\n        return surprise_value >= self.surprise_threshold\n\n    def _record_gate(self, level_name: str, *, hit: bool) -> None:\n        stats_key = f\"gate.{level_name}\"\n        self.last_update_stats.setdefault(stats_key, {})\n        self.last_update_stats[stats_key][\"gate_hit\"] = 1.0 if hit else 0.0\n\n    def _update_cms(\n        self,\n        cms_inputs: dict[str, torch.Tensor],\n        cms_outputs: dict[str, torch.Tensor],\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n    ) -> None:\n        teach = teach_signal.detach()\n        active_mask = teach.abs().sum(dim=-1) > 0\n        for spec in self.config.cms_levels:\n            level_name = spec.name\n            if not self._is_level_allowed(level_name):\n                continue\n            if not self._passes_surprise(surprise_value):\n                self._record_gate(level_name, hit=False)\n                continue\n            inputs = cms_inputs[level_name]\n            seq_len = inputs.shape[1]\n            chunk_size = int(spec.update_period)\n            if chunk_size <= 0:\n                continue\n            total_norm = 0.0\n            update_events = 0\n            token_events = 0\n            for start in range(0, seq_len, chunk_size):\n                end = min(start + chunk_size, seq_len)\n                chunk_len = end - start\n                chunk_inputs = inputs[:, start:end, :].detach()\n                chunk_teach = teach[:, start:end, :]\n                chunk_active = active_mask[:, start:end]\n                if not bool(chunk_active.any()):\n                    continue\n                magnitude = self._update_cms_chunk(\n                    level_name,\n                    chunk_inputs,\n                    chunk_teach,\n                    chunk_active,\n                    surprise_value,\n                )\n                if magnitude <= 0:\n                    continue\n                total_norm += magnitude\n                token_events += chunk_len\n                update_events += 1\n            if update_events == 0:\n  
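              # No chunk produced an update for this level; leave stats untouched.\n  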
              continue\n            stats_payload: Dict[str, float] = {\n                \"grad_norm\": total_norm,\n                \"chunk_tokens\": float(token_events),\n                \"gate_hit\": float(update_events),\n            }\n            if surprise_value is not None:\n                stats_payload[\"surprise_value\"] = surprise_value\n            self.last_update_stats[f\"cms.{level_name}\"] = stats_payload\n\n    def _update_cms_chunk(\n        self,\n        level_name: str,\n        chunk_inputs: torch.Tensor,\n        chunk_teach: torch.Tensor,\n        chunk_active: torch.Tensor,\n        surprise_value: float | None,\n    ) -> float:\n        if not self._is_level_allowed(level_name):\n            return 0.0\n        if not self._passes_surprise(surprise_value):\n            self._record_gate(level_name, hit=False)\n            return 0.0\n        mask_f = chunk_active.unsqueeze(-1).float()\n        with torch.enable_grad():\n            prediction = self.cms.blocks[level_name](chunk_inputs)\n            loss = _chunk_loss(\n                prediction,\n                chunk_teach,\n                mask_f,\n                reduction=self.config.cms_chunk_reduction,\n                differentiable_target=False,\n            )\n        context_vec = chunk_inputs.mean(dim=(0, 1))\n        magnitude = self.level_manager.optimize(\n            level_name,\n            self.cms.blocks[level_name],\n            loss,\n            context=context_vec,\n            force=True,\n        )\n        self.level_manager.pop_last_metrics(level_name)\n        return magnitude\n\n    def _update_cms_chunk_fast(\n        self,\n        fast_state: BlockFastState,\n        level_name: str,\n        chunk_inputs: torch.Tensor,\n        chunk_teach: torch.Tensor,\n        chunk_active: torch.Tensor,\n        surprise_value: float | None,\n        *,\n        differentiable_updates: bool = False,\n    ) -> float:\n        if not self._is_level_allowed(level_name):\n            return 0.0\n        if not self._passes_surprise(surprise_value):\n            self._record_gate(level_name, hit=False)\n            return 0.0\n        mask_f = chunk_active.unsqueeze(-1).float()\n        base_params = fast_state.cms_params[level_name]\n        forward_params = params_with_deltas(self.cms.blocks[level_name], base_params)\n        params_req = require_grad_params(forward_params, detach=not differentiable_updates)\n        with torch.enable_grad():\n            prediction = call_with_params(self.cms.blocks[level_name], params_req, chunk_inputs)\n            loss = _chunk_loss(\n                prediction,\n                chunk_teach,\n                mask_f,\n                reduction=self.config.cms_chunk_reduction,\n                differentiable_target=differentiable_updates,\n            )\n        grads = torch.autograd.grad(\n            loss,\n            tuple(params_req.values()),\n            retain_graph=differentiable_updates,\n            create_graph=differentiable_updates,\n            allow_unused=True,\n        )\n        grads_dict = grads_to_dict(params_req, grads)\n        context_vec = chunk_inputs.mean(dim=(0, 1))\n        updated, magnitude = fast_state.level_manager.apply_grads(\n            level_name,\n            base_params,\n            grads_dict,\n            context=context_vec,\n            force=True,\n            differentiable=differentiable_updates,\n        )\n        fast_state.cms_params[level_name] = updated\n        
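# Pop and discard the level's pending metrics so they do not leak into later updates.\n        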
fast_state.level_manager.pop_last_metrics(level_name)\n        return magnitude\n\n\n@dataclass\nclass HOPESelfModBlockConfig:\n    dim: int\n    cms_levels: Sequence[LevelSpec]\n    cms_hidden_multiplier: int = 4\n    cms_use_layernorm: bool = True\n    activation: str = \"gelu\"\n    qk_l2_norm: bool = True\n    cms_flush_partial_at_end: bool = False\n    selfmod_adaptive_q: bool = False\n    selfmod_local_conv_window: int | None = 4\n    eta_scale: float = 1e-3\n    selfmod_chunk_size: int = 1\n    selfmod_chunk_size_memory: int | None = None\n    selfmod_objective: str = \"l2\"\n    selfmod_stopgrad_vhat: bool = True\n    selfmod_use_rank1_precond: bool = True\n    selfmod_use_alpha: bool = True\n    selfmod_use_skip: bool = True\n    selfmod_momentum: float = 0.0\n    selfmod_online_updates: bool = True\n    self_mod_lr: float = 1e-3\n    cms_chunk_reduction: str = \"sum\"\n    cms_online_updates: bool = True\n    optimizer_configs: Dict[str, dict] = field(default_factory=dict)\n\n\nclass HOPESelfModBlock(nn.Module):\n    \"\"\"\n    Paper-defined HOPE block (Eqs. 94–97): self-modifying Titans followed by CMS.\n\n    Fast-state is required for in-context self-mod updates.\n    \"\"\"\n\n    def __init__(self, config: HOPESelfModBlockConfig):\n        super().__init__()\n        self.config = config\n        self.last_update_stats: Dict[str, Dict[str, float]] = {}\n        self.surprise_threshold: float | None = None\n        self.surprise_metric: str = \"l2\"\n        self.allowed_levels: Set[str] | None = None\n        self.selfmod = SelfModifyingTitans(\n            SelfModifyingTitansConfig(\n                dim=config.dim,\n                eta_scale=config.eta_scale,\n                chunk_size_other=config.selfmod_chunk_size,\n                chunk_size_memory=config.selfmod_chunk_size_memory,\n                objective=config.selfmod_objective,\n                stopgrad_vhat=config.selfmod_stopgrad_vhat,\n                use_rank1_precond=config.selfmod_use_rank1_precond,\n                use_alpha=config.selfmod_use_alpha,\n                use_skip=config.selfmod_use_skip,\n                momentum=config.selfmod_momentum,\n                qk_l2_norm=config.qk_l2_norm,\n                adaptive_q=config.selfmod_adaptive_q,\n                local_conv_window=config.selfmod_local_conv_window,\n            )\n        )\n        self.cms = CMS(\n            dim=config.dim,\n            levels=config.cms_levels,\n            hidden_multiplier=config.cms_hidden_multiplier,\n            activation=config.activation,\n            use_layernorm=config.cms_use_layernorm,\n        )\n        level_config = LevelConfig(\n            specs=config.cms_levels,\n            optimizer_configs=config.optimizer_configs,\n            default_lr=config.self_mod_lr,\n        )\n        self.level_manager = LevelOptimizerManager(level_config)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        *,\n        teach_signal: torch.Tensor | None = None,\n        surprise_value: float | None = None,\n        fast_state: BlockFastState | None = None,\n        finalize_updates: bool = True,\n        attention_cache: AttentionKVCache | None = None,\n        return_attention_cache: bool = False,\n        differentiable_updates: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache | None]:\n        _ = (attention_cache, differentiable_updates)\n        if fast_state is None:\n            # Differentiable read path (used for the outer loss).\n            o = 
self.selfmod(x)\n            # Explicit update pass (typically called under `torch.no_grad()` after backward).\n            if teach_signal is not None and self.config.selfmod_online_updates:\n                self.selfmod.apply_updates_inplace(x)\n            if teach_signal is not None and self.config.cms_online_updates:\n                cms_out = self._cms_forward_online(\n                    o,\n                    teach_signal,\n                    surprise_value,\n                    finalize_updates=finalize_updates,\n                )\n            else:\n                cms_out, cms_inputs, cms_outputs = self.cms(o, return_intermediates=True)\n                if teach_signal is not None:\n                    self._update_cms(cms_inputs, cms_outputs, teach_signal, surprise_value)\n            self.level_manager.tick()\n            return cms_out\n\n        if fast_state.selfmod_state is None:\n            raise ValueError(\"fast_state.selfmod_state is required for hope_selfmod variant\")\n        if self.config.selfmod_online_updates and teach_signal is not None:\n            o, updated = self.selfmod.forward_with_updates(x, fast_state.selfmod_state)\n            fast_state.selfmod_state = updated\n        else:\n            o = self.selfmod.forward_with_state(x, fast_state.selfmod_state)\n        if teach_signal is not None and self.config.cms_online_updates:\n            cms_out = self._cms_forward_online_fast(\n                o,\n                fast_state,\n                teach_signal,\n                surprise_value,\n                finalize_updates=finalize_updates,\n            )\n        else:\n            cms_out, cms_inputs = self._cms_forward_fast(o, fast_state)\n            if teach_signal is not None:\n                self._update_cms_fast(fast_state, cms_inputs, teach_signal, surprise_value)\n        fast_state.level_manager.tick()\n        if return_attention_cache:\n            return cms_out, None\n        return cms_out\n\n    def set_surprise_threshold(self, threshold: float | None) -> None:\n        self.surprise_threshold = threshold\n\n    def set_surprise_metric(self, metric: str) -> None:\n        self.surprise_metric = str(metric).strip().lower()\n\n    def set_allowed_levels(self, allowed: Set[str] | None) -> None:\n        self.allowed_levels = allowed.copy() if allowed is not None else None\n\n    def pop_update_stats(self) -> Dict[str, Dict[str, float]]:\n        stats = self.last_update_stats\n        self.last_update_stats = {}\n        return stats\n\n    def _cms_forward_fast(\n        self,\n        x: torch.Tensor,\n        fast_state: BlockFastState,\n    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:\n        current = x\n        inputs: dict[str, torch.Tensor] = {}\n        for spec in self.config.cms_levels:\n            level_name = spec.name\n            inputs[level_name] = current\n            params = fast_state.cms_params[level_name]\n            current = call_with_deltas(self.cms.blocks[level_name], params, current)\n        return current, inputs\n\n    def _cms_forward_online(\n        self,\n        x: torch.Tensor,\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n        *,\n        finalize_updates: bool = True,\n    ) -> torch.Tensor:\n        seq_len = x.shape[1]\n        base_chunk = _min_update_period(self.config.cms_levels)\n        active_mask = teach_signal.detach().abs().sum(dim=-1) > 0\n        outputs: list[torch.Tensor] = []\n        stats: dict[str, Dict[str, float]] = {}\n        
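# One rolling buffer per CMS level; each level drains at its own update_period.\n        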
buffers: dict[str, _CmsBuffer] = {}\n        for spec in self.config.cms_levels:\n            buffers[spec.name] = _CmsBuffer(inputs=[], teach=[], active=[], count=0)\n            stats[spec.name] = {\n                \"grad_norm\": 0.0,\n                \"chunk_tokens\": 0.0,\n                \"gate_hit\": 0.0,\n                \"gate_hits\": 0.0,\n                \"updates_applied\": 0.0,\n                \"tokens_flushed\": 0.0,\n                \"pending_tokens\": 0.0,\n            }\n\n        for start in range(0, seq_len, base_chunk):\n            end = min(start + base_chunk, seq_len)\n            chunk_in = x[:, start:end, :]\n            chunk_teach = teach_signal[:, start:end, :]\n            chunk_active = active_mask[:, start:end]\n\n            current = chunk_in\n            level_inputs: dict[str, torch.Tensor] = {}\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                level_inputs[level_name] = current\n                current = self.cms.blocks[level_name](current)\n            outputs.append(current)\n\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                buffer = buffers[level_name]\n                buffer.inputs.append(level_inputs[level_name].detach())\n                buffer.teach.append(chunk_teach)\n                buffer.active.append(chunk_active)\n                buffer.count += end - start\n                update_period = int(spec.update_period)\n                while update_period > 0 and buffer.count >= update_period:\n                    # Keep the outer chunk_* slices intact for later levels.\n                    popped_inputs, popped_teach, popped_active = _pop_buffer_chunk(\n                        buffer, update_period\n                    )\n                    buffer.count -= update_period\n                    magnitude = self._update_cms_chunk(\n                        level_name,\n                        popped_inputs,\n                        popped_teach,\n                        popped_active,\n                        surprise_value,\n                    )\n                    if magnitude > 0:\n                        stats[level_name][\"grad_norm\"] += magnitude\n                        stats[level_name][\"chunk_tokens\"] += float(update_period)\n                        stats[level_name][\"gate_hit\"] += 1.0\n                        stats[level_name][\"gate_hits\"] += 1.0\n                        stats[level_name][\"updates_applied\"] += 1.0\n        if self.config.cms_flush_partial_at_end and finalize_updates:\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                buffer = buffers[level_name]\n                remaining = int(buffer.count)\n                if remaining <= 0:\n                    continue\n                popped_inputs, popped_teach, popped_active = _pop_buffer_chunk(buffer, remaining)\n                buffer.count -= remaining\n                if not bool(popped_active.any()):\n                    continue\n                magnitude = self._update_cms_chunk(\n                    level_name,\n                    popped_inputs,\n                    popped_teach,\n                    popped_active,\n                    surprise_value,\n                )\n                if magnitude > 0:\n                    stats[level_name][\"grad_norm\"] += magnitude\n                    stats[level_name][\"chunk_tokens\"] += float(remaining)\n                    stats[level_name][\"gate_hit\"] += 1.0\n                    stats[level_name][\"gate_hits\"] += 1.0\n                    
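# The two gate counters above are incremented together with each applied update.\n                    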
stats[level_name][\"updates_applied\"] += 1.0\n                    stats[level_name][\"tokens_flushed\"] += float(remaining)\n        for spec in self.config.cms_levels:\n            stats[spec.name][\"pending_tokens\"] = float(buffers[spec.name].count)\n        for level_name, payload in stats.items():\n            if (\n                payload[\"updates_applied\"] <= 0\n                and payload[\"pending_tokens\"] <= 0\n                and payload[\"tokens_flushed\"] <= 0\n            ):\n                continue\n            if surprise_value is not None:\n                payload[\"surprise_value\"] = surprise_value\n            self.last_update_stats[f\"cms.{level_name}\"] = payload\n        return torch.cat(outputs, dim=1)\n\n    def _cms_forward_online_fast(\n        self,\n        x: torch.Tensor,\n        fast_state: BlockFastState,\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n        *,\n        finalize_updates: bool = True,\n    ) -> torch.Tensor:\n        seq_len = x.shape[1]\n        base_chunk = _min_update_period(self.config.cms_levels)\n        active_mask = teach_signal.detach().abs().sum(dim=-1) > 0\n        outputs: list[torch.Tensor] = []\n        stats: dict[str, Dict[str, float]] = {}\n        buffers = _fast_state_buffers(fast_state, self.config.cms_levels)\n        for spec in self.config.cms_levels:\n            stats[spec.name] = {\n                \"grad_norm\": 0.0,\n                \"chunk_tokens\": 0.0,\n                \"gate_hit\": 0.0,\n                \"gate_hits\": 0.0,\n                \"updates_applied\": 0.0,\n                \"tokens_flushed\": 0.0,\n                \"pending_tokens\": 0.0,\n            }\n\n        for start in range(0, seq_len, base_chunk):\n            end = min(start + base_chunk, seq_len)\n            chunk_in = x[:, start:end, :]\n            chunk_teach = teach_signal[:, start:end, :]\n            chunk_active = active_mask[:, start:end]\n\n            current = chunk_in\n            level_inputs: dict[str, torch.Tensor] = {}\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                level_inputs[level_name] = current\n                params = fast_state.cms_params[level_name]\n                current = call_with_deltas(self.cms.blocks[level_name], params, current)\n            outputs.append(current)\n\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                buffer = buffers[level_name]\n                buffer.inputs.append(level_inputs[level_name].detach())\n                buffer.teach.append(chunk_teach)\n                buffer.active.append(chunk_active)\n                buffer.count += end - start\n                update_period = int(spec.update_period)\n                while update_period > 0 and buffer.count >= update_period:\n                    # Keep the outer chunk_* slices intact for later levels.\n                    popped_inputs, popped_teach, popped_active = _pop_buffer_chunk(\n                        buffer, update_period\n                    )\n                    buffer.count -= update_period\n                    magnitude = self._update_cms_chunk_fast(\n                        fast_state,\n                        level_name,\n                        popped_inputs,\n                        popped_teach,\n                        popped_active,\n                        surprise_value,\n                    )\n                    if magnitude > 0:\n                        stats[level_name][\"grad_norm\"] += magnitude\n                        
stats[level_name][\"chunk_tokens\"] += float(update_period)\n                        stats[level_name][\"gate_hit\"] += 1.0\n                        stats[level_name][\"gate_hits\"] += 1.0\n                        stats[level_name][\"updates_applied\"] += 1.0\n        if finalize_updates:\n            if self.config.cms_flush_partial_at_end:\n                for spec in self.config.cms_levels:\n                    level_name = spec.name\n                    buffer = buffers[level_name]\n                    remaining = int(buffer.count)\n                    if remaining <= 0:\n                        continue\n                    chunk_inputs, chunk_teach, chunk_active = _pop_buffer_chunk(buffer, remaining)\n                    buffer.count -= remaining\n                    if not bool(chunk_active.any()):\n                        continue\n                    magnitude = self._update_cms_chunk_fast(\n                        fast_state,\n                        level_name,\n                        chunk_inputs,\n                        chunk_teach,\n                        chunk_active,\n                        surprise_value,\n                    )\n                    if magnitude > 0:\n                        stats[level_name][\"grad_norm\"] += magnitude\n                        stats[level_name][\"chunk_tokens\"] += float(remaining)\n                        stats[level_name][\"gate_hit\"] += 1.0\n                        stats[level_name][\"gate_hits\"] += 1.0\n                        stats[level_name][\"updates_applied\"] += 1.0\n                        stats[level_name][\"tokens_flushed\"] += float(remaining)\n            for spec in self.config.cms_levels:\n                _clear_buffer(buffers[spec.name])\n        for spec in self.config.cms_levels:\n            stats[spec.name][\"pending_tokens\"] = float(buffers[spec.name].count)\n        for level_name, payload in stats.items():\n            if (\n                payload[\"updates_applied\"] <= 0\n                and payload[\"pending_tokens\"] <= 0\n                and payload[\"tokens_flushed\"] <= 0\n            ):\n                continue\n            if surprise_value is not None:\n                payload[\"surprise_value\"] = surprise_value\n            self.last_update_stats[f\"cms.{level_name}\"] = payload\n        return torch.cat(outputs, dim=1)\n\n    def _is_level_allowed(self, level_name: str) -> bool:\n        if self.allowed_levels is None:\n            return True\n        return level_name in self.allowed_levels\n\n    def _passes_surprise(self, surprise_value: float | None) -> bool:\n        if self.surprise_threshold is None:\n            return True\n        if surprise_value is None:\n            return False\n        return surprise_value >= self.surprise_threshold\n\n    def _record_gate(self, level_name: str, *, hit: bool) -> None:\n        stats_key = f\"gate.{level_name}\"\n        self.last_update_stats.setdefault(stats_key, {})\n        self.last_update_stats[stats_key][\"gate_hit\"] = 1.0 if hit else 0.0\n\n    def _update_cms(\n        self,\n        cms_inputs: dict[str, torch.Tensor],\n        cms_outputs: dict[str, torch.Tensor],\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n    ) -> None:\n        teach = teach_signal.detach()\n        active_mask = teach.abs().sum(dim=-1) > 0\n        for spec in self.config.cms_levels:\n            level_name = spec.name\n            if not self._is_level_allowed(level_name):\n                continue\n            if not 
self._passes_surprise(surprise_value):\n                self._record_gate(level_name, hit=False)\n                continue\n            inputs = cms_inputs[level_name]\n            seq_len = inputs.shape[1]\n            chunk_size = int(spec.update_period)\n            if chunk_size <= 0:\n                continue\n            total_norm = 0.0\n            update_events = 0\n            token_events = 0\n            for start in range(0, seq_len, chunk_size):\n                end = min(start + chunk_size, seq_len)\n                chunk_len = end - start\n                chunk_inputs = inputs[:, start:end, :].detach()\n                chunk_teach = teach[:, start:end, :]\n                chunk_active = active_mask[:, start:end]\n                if not bool(chunk_active.any()):\n                    continue\n                magnitude = self._update_cms_chunk(\n                    level_name,\n                    chunk_inputs,\n                    chunk_teach,\n                    chunk_active,\n                    surprise_value,\n                )\n                if magnitude <= 0:\n                    continue\n                total_norm += magnitude\n                token_events += chunk_len\n                update_events += 1\n            if update_events == 0:\n                continue\n            stats_payload: Dict[str, float] = {\n                \"grad_norm\": total_norm,\n                \"chunk_tokens\": float(token_events),\n                \"gate_hit\": float(update_events),\n            }\n            if surprise_value is not None:\n                stats_payload[\"surprise_value\"] = surprise_value\n            self.last_update_stats[f\"cms.{level_name}\"] = stats_payload\n\n    def _update_cms_fast(\n        self,\n        fast_state: BlockFastState,\n        cms_inputs: dict[str, torch.Tensor],\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n    ) -> None:\n        teach = teach_signal.detach()\n        active_mask = teach.abs().sum(dim=-1) > 0\n        for spec in self.config.cms_levels:\n            level_name = spec.name\n            if not self._is_level_allowed(level_name):\n                continue\n            if not self._passes_surprise(surprise_value):\n                self._record_gate(level_name, hit=False)\n                continue\n            inputs = cms_inputs[level_name]\n            seq_len = inputs.shape[1]\n            chunk_size = int(spec.update_period)\n            if chunk_size <= 0:\n                continue\n            total_norm = 0.0\n            update_events = 0\n            token_events = 0\n            for start in range(0, seq_len, chunk_size):\n                end = min(start + chunk_size, seq_len)\n                chunk_len = end - start\n                chunk_inputs = inputs[:, start:end, :].detach()\n                chunk_teach = teach[:, start:end, :]\n                chunk_active = active_mask[:, start:end]\n                if not bool(chunk_active.any()):\n                    continue\n                magnitude = self._update_cms_chunk_fast(\n                    fast_state,\n                    level_name,\n                    chunk_inputs,\n                    chunk_teach,\n                    chunk_active,\n                    surprise_value,\n                )\n                if magnitude <= 0:\n                    continue\n                total_norm += magnitude\n                token_events += chunk_len\n                update_events += 1\n            if update_events == 0:\n                
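# No chunk update fired for this level, so skip emitting a stats entry.\n                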
continue\n            stats_payload: Dict[str, float] = {\n                \"grad_norm\": total_norm,\n                \"chunk_tokens\": float(token_events),\n                \"gate_hit\": float(update_events),\n            }\n            if surprise_value is not None:\n                stats_payload[\"surprise_value\"] = surprise_value\n            self.last_update_stats[f\"cms.{level_name}\"] = stats_payload\n\n    def _update_cms_chunk(\n        self,\n        level_name: str,\n        chunk_inputs: torch.Tensor,\n        chunk_teach: torch.Tensor,\n        chunk_active: torch.Tensor,\n        surprise_value: float | None,\n    ) -> float:\n        if not self._is_level_allowed(level_name):\n            return 0.0\n        if not self._passes_surprise(surprise_value):\n            self._record_gate(level_name, hit=False)\n            return 0.0\n        mask_f = chunk_active.unsqueeze(-1).float()\n        with torch.enable_grad():\n            prediction = self.cms.blocks[level_name](chunk_inputs)\n            loss = _chunk_loss(\n                prediction,\n                chunk_teach,\n                mask_f,\n                reduction=self.config.cms_chunk_reduction,\n            )\n        context_vec = chunk_inputs.mean(dim=(0, 1))\n        magnitude = self.level_manager.optimize(\n            level_name,\n            self.cms.blocks[level_name],\n            loss,\n            context=context_vec,\n            force=True,\n        )\n        self.level_manager.pop_last_metrics(level_name)\n        return magnitude\n\n    def _update_cms_chunk_fast(\n        self,\n        fast_state: BlockFastState,\n        level_name: str,\n        chunk_inputs: torch.Tensor,\n        chunk_teach: torch.Tensor,\n        chunk_active: torch.Tensor,\n        surprise_value: float | None,\n    ) -> float:\n        if not self._is_level_allowed(level_name):\n            return 0.0\n        if not self._passes_surprise(surprise_value):\n            self._record_gate(level_name, hit=False)\n            return 0.0\n        mask_f = chunk_active.unsqueeze(-1).float()\n        base_params = fast_state.cms_params[level_name]\n        forward_params = params_with_deltas(self.cms.blocks[level_name], base_params)\n        params_req = require_grad_params(forward_params)\n        with torch.enable_grad():\n            prediction = call_with_params(self.cms.blocks[level_name], params_req, chunk_inputs)\n            loss = _chunk_loss(\n                prediction,\n                chunk_teach,\n                mask_f,\n                reduction=self.config.cms_chunk_reduction,\n            )\n        grads = torch.autograd.grad(\n            loss,\n            tuple(params_req.values()),\n            retain_graph=False,\n            allow_unused=True,\n        )\n        grads_dict = grads_to_dict(params_req, grads)\n        context_vec = chunk_inputs.mean(dim=(0, 1))\n        updated, magnitude = fast_state.level_manager.apply_grads(\n            level_name,\n            base_params,\n            grads_dict,\n            context=context_vec,\n            force=True,\n        )\n        fast_state.cms_params[level_name] = updated\n        fast_state.level_manager.pop_last_metrics(level_name)\n        return magnitude\n\n\nclass HOPEBlock(nn.Module):\n    def __init__(self, config: HOPEBlockConfig):\n        super().__init__()\n        self.config = config\n        self.last_update_stats: Dict[str, Dict[str, float]] = {}\n        self.surprise_threshold: float | None = None\n        self.surprise_metric: str = 
\"l2\"\n        self.allowed_levels: Set[str] | None = None\n        self.attn = SelfAttention(\n            AttentionConfig(\n                dim=config.dim,\n                heads=config.heads,\n                qk_l2_norm=config.qk_l2_norm,\n                local_conv_window=config.local_conv_window,\n            )\n        )\n        titan_config = TitanMemoryConfig(\n            dim=config.dim,\n            hidden_multiplier=config.titan_hidden_multiplier,\n            activation=config.activation,\n        )\n        self.titan_memory = TitanMemory(titan_config)\n        self.cms = CMS(\n            dim=config.dim,\n            levels=config.cms_levels,\n            hidden_multiplier=config.cms_hidden_multiplier,\n            activation=config.activation,\n            use_layernorm=config.cms_use_layernorm,\n        )\n        self.self_modifier = SelfModifier(config.dim, hidden_multiplier=config.self_mod_hidden)\n        self.dropout = nn.Dropout(0.0)\n        specs = [config.titan_level, *config.cms_levels]\n        level_config = LevelConfig(\n            specs=specs,\n            optimizer_configs=config.optimizer_configs,\n            default_lr=config.self_mod_lr,\n        )\n        self.level_manager = LevelOptimizerManager(level_config)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        *,\n        teach_signal: torch.Tensor | None = None,\n        surprise_value: float | None = None,\n        fast_state: BlockFastState | None = None,\n        finalize_updates: bool = True,\n        attention_cache: AttentionKVCache | None = None,\n        return_attention_cache: bool = False,\n        differentiable_updates: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache]:\n        _ = differentiable_updates\n        next_attn_cache: AttentionKVCache | None = None\n        if return_attention_cache:\n            attn_out, next_attn_cache = self.attn(\n                x,\n                kv_cache=attention_cache,\n                return_kv_cache=True,\n            )\n        else:\n            attn_out = self.attn(x, kv_cache=attention_cache)\n        if fast_state is None:\n            mem_out = self.titan_memory(attn_out)\n            combined = attn_out + mem_out\n            if teach_signal is not None and self.config.cms_online_updates:\n                cms_out = self._cms_forward_online(\n                    combined,\n                    teach_signal,\n                    surprise_value,\n                    finalize_updates=finalize_updates,\n                )\n                self._update_titan(attn_out, mem_out, teach_signal, surprise_value)\n            else:\n                cms_result = self.cms(combined, return_intermediates=True)\n                cms_out, cms_inputs, cms_outputs = cms_result\n                if teach_signal is not None:\n                    self._update_titan(attn_out, mem_out, teach_signal, surprise_value)\n                    self._update_cms(cms_inputs, cms_outputs, teach_signal, surprise_value)\n            self.level_manager.tick()\n            return cms_out\n\n        if fast_state.titan_params is None:\n            raise ValueError(\"fast_state.titan_params is required for HOPEBlock fast-state forward\")\n        mem_out = call_with_deltas(self.titan_memory, fast_state.titan_params, attn_out)\n        combined = attn_out + mem_out\n        if teach_signal is not None and self.config.cms_online_updates:\n            cms_out = self._cms_forward_online_fast(\n                combined,\n                
fast_state,\n                teach_signal,\n                surprise_value,\n                finalize_updates=finalize_updates,\n            )\n            self._update_titan_fast(fast_state, attn_out, mem_out, teach_signal, surprise_value)\n        else:\n            cms_out, cms_inputs = self._cms_forward_fast(combined, fast_state)\n            if teach_signal is not None:\n                self._update_titan_fast(fast_state, attn_out, mem_out, teach_signal, surprise_value)\n                self._update_cms_fast(fast_state, cms_inputs, teach_signal, surprise_value)\n        fast_state.level_manager.tick()\n        if return_attention_cache:\n            assert next_attn_cache is not None\n            return cms_out, next_attn_cache\n        return cms_out\n\n    def set_surprise_threshold(self, threshold: float | None) -> None:\n        self.surprise_threshold = threshold\n\n    def set_surprise_metric(self, metric: str) -> None:\n        self.surprise_metric = str(metric).strip().lower()\n\n    def set_allowed_levels(self, allowed: Set[str] | None) -> None:\n        self.allowed_levels = allowed.copy() if allowed is not None else None\n\n    def _cms_forward_fast(\n        self,\n        x: torch.Tensor,\n        fast_state: BlockFastState,\n    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:\n        current = x\n        inputs: dict[str, torch.Tensor] = {}\n        for spec in self.config.cms_levels:\n            level_name = spec.name\n            inputs[level_name] = current\n            params = fast_state.cms_params[level_name]\n            current = call_with_deltas(self.cms.blocks[level_name], params, current)\n        return current, inputs\n\n    def _cms_forward_online(\n        self,\n        x: torch.Tensor,\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n        *,\n        finalize_updates: bool = True,\n    ) -> torch.Tensor:\n        seq_len = x.shape[1]\n        base_chunk = _min_update_period(self.config.cms_levels)\n        active_mask = teach_signal.detach().abs().sum(dim=-1) > 0\n        outputs: list[torch.Tensor] = []\n        stats: dict[str, Dict[str, float]] = {}\n        buffers: dict[str, _CmsBuffer] = {}\n        for spec in self.config.cms_levels:\n            buffers[spec.name] = _CmsBuffer(inputs=[], teach=[], active=[], count=0)\n            stats[spec.name] = {\n                \"grad_norm\": 0.0,\n                \"chunk_tokens\": 0.0,\n                \"gate_hit\": 0.0,\n                \"gate_hits\": 0.0,\n                \"updates_applied\": 0.0,\n                \"tokens_flushed\": 0.0,\n                \"pending_tokens\": 0.0,\n            }\n\n        for start in range(0, seq_len, base_chunk):\n            end = min(start + base_chunk, seq_len)\n            chunk_in = x[:, start:end, :]\n            chunk_teach = teach_signal[:, start:end, :]\n            chunk_active = active_mask[:, start:end]\n\n            current = chunk_in\n            level_inputs: dict[str, torch.Tensor] = {}\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                level_inputs[level_name] = current\n                current = self.cms.blocks[level_name](current)\n            outputs.append(current)\n\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                buffer = buffers[level_name]\n                buffer.inputs.append(level_inputs[level_name].detach())\n                buffer.teach.append(chunk_teach)\n                
buffer.active.append(chunk_active)\n                buffer.count += end - start\n                update_period = int(spec.update_period)\n                while update_period > 0 and buffer.count >= update_period:\n                    # Keep the outer chunk_* slices intact for later levels.\n                    popped_inputs, popped_teach, popped_active = _pop_buffer_chunk(\n                        buffer, update_period\n                    )\n                    buffer.count -= update_period\n                    magnitude = self._update_cms_chunk(\n                        level_name,\n                        popped_inputs,\n                        popped_teach,\n                        popped_active,\n                        surprise_value,\n                    )\n                    if magnitude > 0:\n                        stats[level_name][\"grad_norm\"] += magnitude\n                        stats[level_name][\"chunk_tokens\"] += float(update_period)\n                        stats[level_name][\"gate_hit\"] += 1.0\n                        stats[level_name][\"gate_hits\"] += 1.0\n                        stats[level_name][\"updates_applied\"] += 1.0\n        if self.config.cms_flush_partial_at_end and finalize_updates:\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                buffer = buffers[level_name]\n                remaining = int(buffer.count)\n                if remaining <= 0:\n                    continue\n                popped_inputs, popped_teach, popped_active = _pop_buffer_chunk(buffer, remaining)\n                buffer.count -= remaining\n                if not bool(popped_active.any()):\n                    continue\n                magnitude = self._update_cms_chunk(\n                    level_name,\n                    popped_inputs,\n                    popped_teach,\n                    popped_active,\n                    surprise_value,\n                )\n                if magnitude > 0:\n                    stats[level_name][\"grad_norm\"] += magnitude\n                    stats[level_name][\"chunk_tokens\"] += float(remaining)\n                    stats[level_name][\"gate_hit\"] += 1.0\n                    stats[level_name][\"gate_hits\"] += 1.0\n                    stats[level_name][\"updates_applied\"] += 1.0\n                    stats[level_name][\"tokens_flushed\"] += float(remaining)\n        for spec in self.config.cms_levels:\n            stats[spec.name][\"pending_tokens\"] = float(buffers[spec.name].count)\n        for level_name, payload in stats.items():\n            if (\n                payload[\"updates_applied\"] <= 0\n                and payload[\"pending_tokens\"] <= 0\n                and payload[\"tokens_flushed\"] <= 0\n            ):\n                continue\n            if surprise_value is not None:\n                payload[\"surprise_value\"] = surprise_value\n            self.last_update_stats[f\"cms.{level_name}\"] = payload\n        return torch.cat(outputs, dim=1)\n\n    def _cms_forward_online_fast(\n        self,\n        x: torch.Tensor,\n        fast_state: BlockFastState,\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n        *,\n        finalize_updates: bool = True,\n    ) -> torch.Tensor:\n        seq_len = x.shape[1]\n        base_chunk = _min_update_period(self.config.cms_levels)\n        active_mask = teach_signal.detach().abs().sum(dim=-1) > 0\n        outputs: list[torch.Tensor] = []\n        stats: dict[str, Dict[str, float]] = {}\n        buffers = _fast_state_buffers(fast_state, self.config.cms_levels)\n        for spec in 
self.config.cms_levels:\n            stats[spec.name] = {\n                \"grad_norm\": 0.0,\n                \"chunk_tokens\": 0.0,\n                \"gate_hit\": 0.0,\n                \"gate_hits\": 0.0,\n                \"updates_applied\": 0.0,\n                \"tokens_flushed\": 0.0,\n                \"pending_tokens\": 0.0,\n            }\n\n        for start in range(0, seq_len, base_chunk):\n            end = min(start + base_chunk, seq_len)\n            chunk_in = x[:, start:end, :]\n            chunk_teach = teach_signal[:, start:end, :]\n            chunk_active = active_mask[:, start:end]\n\n            current = chunk_in\n            level_inputs: dict[str, torch.Tensor] = {}\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                level_inputs[level_name] = current\n                params = fast_state.cms_params[level_name]\n                current = call_with_deltas(self.cms.blocks[level_name], params, current)\n            outputs.append(current)\n\n            for spec in self.config.cms_levels:\n                level_name = spec.name\n                buffer = buffers[level_name]\n                buffer.inputs.append(level_inputs[level_name].detach())\n                buffer.teach.append(chunk_teach)\n                buffer.active.append(chunk_active)\n                buffer.count += end - start\n                update_period = int(spec.update_period)\n                while update_period > 0 and buffer.count >= update_period:\n                    # Keep the outer chunk_* slices intact for later levels.\n                    popped_inputs, popped_teach, popped_active = _pop_buffer_chunk(\n                        buffer, update_period\n                    )\n                    buffer.count -= update_period\n                    magnitude = self._update_cms_chunk_fast(\n                        fast_state,\n                        level_name,\n                        popped_inputs,\n                        popped_teach,\n                        popped_active,\n                        surprise_value,\n                    )\n                    if magnitude > 0:\n                        stats[level_name][\"grad_norm\"] += magnitude\n                        stats[level_name][\"chunk_tokens\"] += float(update_period)\n                        stats[level_name][\"gate_hit\"] += 1.0\n                        stats[level_name][\"gate_hits\"] += 1.0\n                        stats[level_name][\"updates_applied\"] += 1.0\n        if finalize_updates:\n            if self.config.cms_flush_partial_at_end:\n                for spec in self.config.cms_levels:\n                    level_name = spec.name\n                    buffer = buffers[level_name]\n                    remaining = int(buffer.count)\n                    if remaining <= 0:\n                        continue\n                    popped_inputs, popped_teach, popped_active = _pop_buffer_chunk(buffer, remaining)\n                    buffer.count -= remaining\n                    if not bool(popped_active.any()):\n                        continue\n                    magnitude = self._update_cms_chunk_fast(\n                        fast_state,\n                        level_name,\n                        popped_inputs,\n                        popped_teach,\n                        popped_active,\n                        surprise_value,\n                    )\n                    if magnitude > 0:\n                        stats[level_name][\"grad_norm\"] += magnitude\n                        stats[level_name][\"chunk_tokens\"] += float(remaining)\n                       
 stats[level_name][\"gate_hit\"] += 1.0\n                        stats[level_name][\"gate_hits\"] += 1.0\n                        stats[level_name][\"updates_applied\"] += 1.0\n                        stats[level_name][\"tokens_flushed\"] += float(remaining)\n            for spec in self.config.cms_levels:\n                _clear_buffer(buffers[spec.name])\n        for spec in self.config.cms_levels:\n            stats[spec.name][\"pending_tokens\"] = float(buffers[spec.name].count)\n        for level_name, payload in stats.items():\n            if (\n                payload[\"updates_applied\"] <= 0\n                and payload[\"pending_tokens\"] <= 0\n                and payload[\"tokens_flushed\"] <= 0\n            ):\n                continue\n            if surprise_value is not None:\n                payload[\"surprise_value\"] = surprise_value\n            self.last_update_stats[f\"cms.{level_name}\"] = payload\n        return torch.cat(outputs, dim=1)\n\n    def _update_titan(\n        self,\n        attn_out: torch.Tensor,\n        mem_out: torch.Tensor,\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n    ) -> None:\n        level_name = self.config.titan_level.name\n        if not self._is_level_allowed(\"titan\"):\n            return\n        if not self.level_manager.should_update(level_name):\n            return\n        if not self._passes_surprise(surprise_value):\n            self._record_gate(level_name, hit=False)\n            return\n        # Use full sequence for granular updates (Critique P1)\n        # Note: We intentionally do not pool over dim=1 (sequence) here.\n        # teach_signal is (B, T, D), attn_out is (B, T, D)\n        modifier = self.self_modifier(\n            key=attn_out.detach(),\n            value=mem_out.detach(),\n            error_signal=teach_signal.detach(),\n        )\n        context_vec = attn_out.detach().mean(dim=(0, 1))\n\n        with torch.enable_grad():\n            query = attn_out.detach()\n            target = (modifier - teach_signal.detach()).detach()\n            base_params = dict(self.titan_memory.named_parameters())\n            params_req = require_grad_params(base_params)\n            prediction = call_with_params(self.titan_memory, params_req, query)\n            loss_terms = F.mse_loss(prediction, target, reduction=\"none\")\n            active = teach_signal.detach().abs().sum(dim=-1, keepdim=True) > 0\n            mask = active.float()\n            if self.surprise_threshold is not None and self.surprise_metric == \"l2\":\n                norms = teach_signal.norm(dim=-1, keepdim=True)\n                mask = mask * (norms >= self.surprise_threshold).float()\n            loss = (loss_terms * mask).sum() / mask.sum().clamp(min=1.0)\n\n        grads = torch.autograd.grad(\n            loss,\n            tuple(params_req.values()),\n            retain_graph=False,\n            allow_unused=True,\n        )\n        grads_dict = grads_to_dict(params_req, grads)\n        magnitude = self.level_manager.apply_module_grads(\n            level_name,\n            self.titan_memory,\n            grads_dict,\n            context=context_vec,\n            force=True,\n        )\n        extra_metrics = self.level_manager.pop_last_metrics(level_name)\n        stats = {\"grad_norm\": magnitude, \"gate_hit\": 1.0}\n        if surprise_value is not None:\n            stats[\"surprise_value\"] = surprise_value\n        stats.update(extra_metrics)\n        
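# Keyed \"titan.<level>\" so trainers can distinguish these from \"cms.<level>\" stats.\n        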
self.last_update_stats[f\"titan.{level_name}\"] = stats\n\n    def _update_titan_fast(\n        self,\n        fast_state: BlockFastState,\n        attn_out: torch.Tensor,\n        mem_out: torch.Tensor,\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n    ) -> None:\n        level_name = self.config.titan_level.name\n        if not self._is_level_allowed(\"titan\"):\n            return\n        if not fast_state.level_manager.should_update(level_name):\n            return\n        if not self._passes_surprise(surprise_value):\n            self._record_gate(level_name, hit=False)\n            return\n        if fast_state.titan_params is None:\n            return\n        modifier = self.self_modifier(\n            key=attn_out.detach(),\n            value=mem_out.detach(),\n            error_signal=teach_signal.detach(),\n        )\n        context_vec = attn_out.detach().mean(dim=(0, 1))\n        base_params = fast_state.titan_params\n        forward_params = params_with_deltas(self.titan_memory, base_params)\n        params_req = require_grad_params(forward_params)\n        with torch.enable_grad():\n            query = attn_out.detach()\n            target = (modifier - teach_signal.detach()).detach()\n            prediction = call_with_params(self.titan_memory, params_req, query)\n            loss_terms = F.mse_loss(prediction, target, reduction=\"none\")\n            active = teach_signal.detach().abs().sum(dim=-1, keepdim=True) > 0\n            mask = active.float()\n            if self.surprise_threshold is not None and self.surprise_metric == \"l2\":\n                norms = teach_signal.norm(dim=-1, keepdim=True)\n                mask = mask * (norms >= self.surprise_threshold).float()\n            loss = (loss_terms * mask).sum() / mask.sum().clamp(min=1.0)\n        grads = torch.autograd.grad(\n            loss,\n            tuple(params_req.values()),\n            retain_graph=False,\n            allow_unused=True,\n        )\n        grads_dict = grads_to_dict(params_req, grads)\n        updated, magnitude = fast_state.level_manager.apply_grads(\n            level_name,\n            base_params,\n            grads_dict,\n            context=context_vec,\n            force=False,\n        )\n        fast_state.titan_params = updated\n        extra_metrics = fast_state.level_manager.pop_last_metrics(level_name)\n        stats = {\"grad_norm\": magnitude, \"gate_hit\": 1.0}\n        if surprise_value is not None:\n            stats[\"surprise_value\"] = surprise_value\n        stats.update(extra_metrics)\n        self.last_update_stats[f\"titan.{level_name}\"] = stats\n\n    def _update_cms(\n        self,\n        cms_inputs: dict[str, torch.Tensor],\n        cms_outputs: dict[str, torch.Tensor],\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n    ) -> None:\n        teach = teach_signal.detach()\n        active_mask = teach.abs().sum(dim=-1) > 0\n        for spec in self.config.cms_levels:\n            level_name = spec.name\n            if not self._is_level_allowed(level_name):\n                continue\n            if not self._passes_surprise(surprise_value):\n                self._record_gate(level_name, hit=False)\n                continue\n            inputs = cms_inputs[level_name]\n            seq_len = inputs.shape[1]\n            chunk_size = int(spec.update_period)\n            if chunk_size <= 0:\n                continue\n            total_norm = 0.0\n            update_events = 0\n            token_events 
= 0\n            for start in range(0, seq_len, chunk_size):\n                end = min(start + chunk_size, seq_len)\n                chunk_len = end - start\n                chunk_inputs = inputs[:, start:end, :].detach()\n                chunk_teach = teach[:, start:end, :]\n                chunk_active = active_mask[:, start:end]\n                if not bool(chunk_active.any()):\n                    continue\n                magnitude = self._update_cms_chunk(\n                    level_name,\n                    chunk_inputs,\n                    chunk_teach,\n                    chunk_active,\n                    surprise_value,\n                )\n                if magnitude <= 0:\n                    continue\n                total_norm += magnitude\n                token_events += chunk_len\n                update_events += 1\n            if update_events == 0:\n                continue\n            stats_payload: Dict[str, float] = {\n                \"grad_norm\": total_norm,\n                \"chunk_tokens\": float(token_events),\n                \"gate_hit\": float(update_events),\n            }\n            if surprise_value is not None:\n                stats_payload[\"surprise_value\"] = surprise_value\n            self.last_update_stats[f\"cms.{level_name}\"] = stats_payload\n\n    def _update_cms_fast(\n        self,\n        fast_state: BlockFastState,\n        cms_inputs: dict[str, torch.Tensor],\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n    ) -> None:\n        teach = teach_signal.detach()\n        active_mask = teach.abs().sum(dim=-1) > 0\n        for spec in self.config.cms_levels:\n            level_name = spec.name\n            if not self._is_level_allowed(level_name):\n                continue\n            if not self._passes_surprise(surprise_value):\n                self._record_gate(level_name, hit=False)\n                continue\n            inputs = cms_inputs[level_name]\n            seq_len = inputs.shape[1]\n            chunk_size = int(spec.update_period)\n            if chunk_size <= 0:\n                continue\n            total_norm = 0.0\n            update_events = 0\n            token_events = 0\n            for start in range(0, seq_len, chunk_size):\n                end = min(start + chunk_size, seq_len)\n                chunk_len = end - start\n                chunk_inputs = inputs[:, start:end, :].detach()\n                chunk_teach = teach[:, start:end, :]\n                chunk_active = active_mask[:, start:end]\n                if not bool(chunk_active.any()):\n                    continue\n                magnitude = self._update_cms_chunk_fast(\n                    fast_state,\n                    level_name,\n                    chunk_inputs,\n                    chunk_teach,\n                    chunk_active,\n                    surprise_value,\n                )\n                if magnitude <= 0:\n                    continue\n                total_norm += magnitude\n                token_events += chunk_len\n                update_events += 1\n            if update_events == 0:\n                continue\n            stats_payload: Dict[str, float] = {\n                \"grad_norm\": total_norm,\n                \"chunk_tokens\": float(token_events),\n                \"gate_hit\": float(update_events),\n            }\n            if surprise_value is not None:\n                stats_payload[\"surprise_value\"] = surprise_value\n            self.last_update_stats[f\"cms.{level_name}\"] = 
stats_payload\n\n    def _update_cms_chunk(\n        self,\n        level_name: str,\n        chunk_inputs: torch.Tensor,\n        chunk_teach: torch.Tensor,\n        chunk_active: torch.Tensor,\n        surprise_value: float | None,\n    ) -> float:\n        if not self._is_level_allowed(level_name):\n            return 0.0\n        if not self._passes_surprise(surprise_value):\n            self._record_gate(level_name, hit=False)\n            return 0.0\n        mask_f = chunk_active.unsqueeze(-1).float()\n        with torch.enable_grad():\n            prediction = self.cms.blocks[level_name](chunk_inputs)\n            loss = _chunk_loss(\n                prediction,\n                chunk_teach,\n                mask_f,\n                reduction=self.config.cms_chunk_reduction,\n            )\n        context_vec = chunk_inputs.mean(dim=(0, 1))\n        magnitude = self.level_manager.optimize(\n            level_name,\n            self.cms.blocks[level_name],\n            loss,\n            context=context_vec,\n            force=True,\n        )\n        self.level_manager.pop_last_metrics(level_name)\n        return magnitude\n\n    def _update_cms_chunk_fast(\n        self,\n        fast_state: BlockFastState,\n        level_name: str,\n        chunk_inputs: torch.Tensor,\n        chunk_teach: torch.Tensor,\n        chunk_active: torch.Tensor,\n        surprise_value: float | None,\n    ) -> float:\n        if not self._is_level_allowed(level_name):\n            return 0.0\n        if not self._passes_surprise(surprise_value):\n            self._record_gate(level_name, hit=False)\n            return 0.0\n        mask_f = chunk_active.unsqueeze(-1).float()\n        base_params = fast_state.cms_params[level_name]\n        forward_params = params_with_deltas(self.cms.blocks[level_name], base_params)\n        params_req = require_grad_params(forward_params)\n        with torch.enable_grad():\n            prediction = call_with_params(self.cms.blocks[level_name], params_req, chunk_inputs)\n            loss = _chunk_loss(\n                prediction,\n                chunk_teach,\n                mask_f,\n                reduction=self.config.cms_chunk_reduction,\n            )\n        grads = torch.autograd.grad(\n            loss,\n            tuple(params_req.values()),\n            retain_graph=False,\n            allow_unused=True,\n        )\n        grads_dict = grads_to_dict(params_req, grads)\n        context_vec = chunk_inputs.mean(dim=(0, 1))\n        updated, magnitude = fast_state.level_manager.apply_grads(\n            level_name,\n            base_params,\n            grads_dict,\n            context=context_vec,\n            force=True,\n        )\n        fast_state.cms_params[level_name] = updated\n        fast_state.level_manager.pop_last_metrics(level_name)\n        return magnitude\n\n    def pop_update_stats(self) -> Dict[str, Dict[str, float]]:\n        stats = self.last_update_stats\n        self.last_update_stats = {}\n        return stats\n\n    def _passes_surprise(self, surprise_value: float | None) -> bool:\n        if self.surprise_threshold is None:\n            return True\n        if surprise_value is None:\n            return False\n        return surprise_value >= self.surprise_threshold\n\n    def _is_level_allowed(self, level_name: str) -> bool:\n        if self.allowed_levels is None:\n            return True\n        return level_name in self.allowed_levels or (\n            level_name.startswith(\"titan\") and \"titan\" in self.allowed_levels\n     
   )\n\n    def _record_gate(self, level_name: str, *, hit: bool) -> None:\n        stats_key = f\"gate.{level_name}\"\n        self.last_update_stats.setdefault(stats_key, {})\n        self.last_update_stats[stats_key][\"gate_hit\"] = 1.0 if hit else 0.0\n"
  },
  {
    "path": "src/nested_learning/hope/self_mod.py",
    "content": "from __future__ import annotations\n\nimport torch\nimport torch.nn as nn\n\n\nclass SelfModifier(nn.Module):\n    \"\"\"\n    Learns parameter updates conditioned on key/value/error signals.\n\n    Note: In this implementation, we predict a 'target modification' (delta to the error signal)\n    rather than directly predicting weight deltas (Delta W). Mathematically, modifying the\n    target y to (y + delta) in the inner optimization step:\n        L = || f(x) - (y + delta) ||^2\n    results in a gradient update that is shifted by the gradient of delta.\n    This is functionally equivalent to a 'Learned Optimization Step' or 'Hypernetwork'\n    that modulates the update direction, but is more efficient to implement for\n    large memory modules than generating O(d^2) weight parameters directly.\n    \"\"\"\n\n    def __init__(self, dim: int, hidden_multiplier: int = 4):\n        super().__init__()\n        hidden = dim * hidden_multiplier\n        self.net = nn.Sequential(\n            nn.Linear(dim * 3, hidden),\n            nn.GELU(),\n            nn.Linear(hidden, hidden),\n            nn.GELU(),\n            nn.Linear(hidden, dim),\n        )\n\n    def forward(\n        self,\n        *,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        error_signal: torch.Tensor,\n    ) -> torch.Tensor:\n        concat = torch.cat([key, value, error_signal], dim=-1)\n        return self.net(concat)\n"
  },
  {
    "path": "src/nested_learning/instrumentation.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass, field\nfrom typing import Dict, List\n\n\n@dataclass\nclass UpdateEvent:\n    step: int\n    level: str\n    magnitude: float | None = None\n\n\n@dataclass\nclass UpdateLog:\n    \"\"\"Lightweight container for tracking update magnitudes per level.\"\"\"\n\n    events: List[UpdateEvent] = field(default_factory=list)\n\n    def record(self, *, step: int, level: str, magnitude: float | None = None) -> None:\n        self.events.append(UpdateEvent(step=step, level=level, magnitude=magnitude))\n\n    def summary(self) -> Dict[str, Dict[str, float]]:\n        counts: Dict[str, int] = {}\n        totals: Dict[str, float] = {}\n        for event in self.events:\n            counts[event.level] = counts.get(event.level, 0) + 1\n            if event.magnitude is not None:\n                totals[event.level] = totals.get(event.level, 0.0) + event.magnitude\n        return {\n            level: {\n                \"updates\": counts[level],\n                \"avg_magnitude\": (\n                    totals[level] / counts[level] if level in totals else float(\"nan\")\n                ),\n            }\n            for level in counts\n        }\n"
  },
  {
    "path": "src/nested_learning/levels.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Dict, Iterable, List, MutableMapping, Sequence\n\n\n@dataclass(frozen=True)\nclass LevelSpec:\n    \"\"\"Configuration for a nested-learning level.\"\"\"\n\n    name: str\n    update_period: int\n    warmup_steps: int = 0\n    jitter: int = 0\n    optimizer_key: str | None = None\n\n    def __post_init__(self) -> None:\n        if self.update_period <= 0:\n            msg = f\"update_period for level {self.name} must be positive\"\n            raise ValueError(msg)\n        if self.warmup_steps < 0:\n            msg = f\"warmup_steps for level {self.name} must be non-negative\"\n            raise ValueError(msg)\n        if self.jitter < 0:\n            msg = f\"jitter for level {self.name} must be non-negative\"\n            raise ValueError(msg)\n\n\n@dataclass\nclass LevelState:\n    last_step: int = -1\n    updates: int = 0\n\n\nclass LevelClock:\n    \"\"\"Deterministic scheduler for Nested Learning level updates.\"\"\"\n\n    def __init__(self, specs: Sequence[LevelSpec]):\n        self._specs: Dict[str, LevelSpec] = {spec.name: spec for spec in specs}\n        if len(self._specs) != len(specs):\n            raise ValueError(\"Duplicate level names provided to LevelClock\")\n        self._state: MutableMapping[str, LevelState] = {name: LevelState() for name in self._specs}\n        self._step: int = 0\n        self._timeline: List[dict] = []\n\n    @property\n    def step(self) -> int:\n        return self._step\n\n    def tick(self) -> None:\n        self._step += 1\n\n    def should_update(self, name: str) -> bool:\n        spec = self._specs[name]\n        state = self._state[name]\n        if self._step < spec.warmup_steps:\n            return False\n        delta = self._step - state.last_step\n        period = spec.update_period\n        if spec.jitter:\n            period = period + (self._step % (spec.jitter + 1))\n        return state.last_step < 0 or delta >= period\n\n    def record_update(self, name: str) -> None:\n        state = self._state[name]\n        state.last_step = self._step\n        state.updates += 1\n        self._timeline.append({\"step\": self._step, \"level\": name})\n\n    def levels_in_frequency_order(self) -> List[LevelSpec]:\n        return sorted(self._specs.values(), key=lambda spec: spec.update_period)\n\n    def stats(self) -> Dict[str, LevelState]:\n        return {\n            name: LevelState(state.last_step, state.updates) for name, state in self._state.items()\n        }\n\n    def timeline(self) -> List[dict]:\n        return list(self._timeline)\n\n\ndef ensure_level_specs(entries: Iterable[LevelSpec]) -> List[LevelSpec]:\n    \"\"\"Ensure deterministic ordering and validate duplicates.\"\"\"\n\n    specs = list(entries)\n    seen = set()\n    ordered: List[LevelSpec] = []\n    for spec in specs:\n        if spec.name in seen:\n            msg = f\"Duplicate level spec {spec.name}\"\n            raise ValueError(msg)\n        seen.add(spec.name)\n        ordered.append(spec)\n    return ordered\n"
  },
  {
    "path": "src/nested_learning/logging_utils.py",
    "content": "from __future__ import annotations\n\nimport json\nfrom pathlib import Path\nfrom typing import Any, Dict, cast\n\nfrom omegaconf import DictConfig, OmegaConf\n\n\nclass BaseLogger:\n    def log(self, metrics: Dict[str, Any], step: int) -> None:\n        raise NotImplementedError\n\n    def finish(self) -> None:\n        pass\n\n\nclass NullLogger(BaseLogger):\n    def log(self, metrics: Dict[str, Any], step: int) -> None:\n        return\n\n\nclass JSONLogger(BaseLogger):\n    def __init__(self, path: Path):\n        self.path = path\n        self.records: list[Dict[str, Any]] = []\n\n    def log(self, metrics: Dict[str, Any], step: int) -> None:\n        payload = {\"step\": step, **metrics}\n        self.records.append(payload)\n\n    def finish(self) -> None:\n        self.path.parent.mkdir(parents=True, exist_ok=True)\n        self.path.write_text(json.dumps(self.records, indent=2))\n\n\nclass WandbLogger(BaseLogger):\n    def __init__(self, cfg: DictConfig, full_cfg: DictConfig):\n        import wandb\n\n        project = cfg.get(\"project\", \"nested-learning\")\n        run_name = cfg.get(\"run_name\")\n        config_dict = cast(dict[str, Any], OmegaConf.to_container(full_cfg, resolve=True))\n        self.run = wandb.init(project=project, name=run_name, config=config_dict)\n\n    def log(self, metrics: Dict[str, Any], step: int) -> None:\n        if self.run is not None:\n            self.run.log(metrics, step=step)\n\n    def finish(self) -> None:\n        if self.run is not None:\n            self.run.finish()\n\n\ndef init_logger(logging_cfg: DictConfig | None, full_cfg: DictConfig) -> BaseLogger:\n    if logging_cfg is None or not logging_cfg.get(\"enabled\", False):\n        return NullLogger()\n    backend = logging_cfg.get(\"backend\", \"wandb\").lower()\n    if backend == \"wandb\":\n        return WandbLogger(logging_cfg, full_cfg)\n    if backend == \"json\":\n        path = Path(logging_cfg.get(\"path\", \"logs/train_metrics.json\"))\n        return JSONLogger(path)\n    return NullLogger()\n"
  },
  {
    "path": "src/nested_learning/memorize.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Dict\n\nimport torch\nimport torch.nn as nn\n\nfrom .tokenizer import SentencePieceTokenizer\nfrom .training import compute_teach_signal\n\n\n@dataclass\nclass MemorizeConfig:\n    enabled: bool = False\n    steps: int = 1\n    reset: bool = True\n    use_correct_answer: bool = False\n    use_fast_state: bool = True\n    surprise_threshold: float | None = None\n    paths: tuple[str, ...] | None = None\n    layers: tuple[int, ...] | None = None\n    online_chunk_size: int | None = None  # If set, use online chunked updates\n\n\ndef snapshot_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:\n    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n\n\ndef restore_state_dict(model: torch.nn.Module, state: Dict[str, torch.Tensor]) -> None:\n    model.load_state_dict(state, strict=False)\n\n\ndef _setup_memorization_context(model, cfg: MemorizeConfig):\n    \"\"\"Helper to setup model state for memorization.\"\"\"\n    prev_allowed = getattr(model, \"get_allowed_update_levels\", lambda: None)()\n    prev_threshold = getattr(model, \"get_surprise_threshold\", lambda: None)()\n    prev_layers = getattr(model, \"get_allowed_update_layers\", lambda: None)()\n\n    if hasattr(model, \"set_allowed_update_levels\"):\n        allowed = None\n        if cfg.paths is not None:\n            allowed = {path.strip() for path in cfg.paths if path.strip()}\n        getattr(model, \"set_allowed_update_levels\")(allowed)\n\n    if cfg.surprise_threshold is not None and hasattr(model, \"set_surprise_threshold\"):\n        getattr(model, \"set_surprise_threshold\")(cfg.surprise_threshold)\n\n    if hasattr(model, \"set_allowed_update_layers\"):\n        layers = None\n        if cfg.layers is not None:\n            layers = {int(idx) for idx in cfg.layers}\n        getattr(model, \"set_allowed_update_layers\")(layers)\n\n    return prev_allowed, prev_threshold, prev_layers\n\n\ndef _teardown_memorization_context(model, prev_allowed, prev_threshold, prev_layers):\n    \"\"\"Helper to restore model state after memorization.\"\"\"\n    if hasattr(model, \"set_allowed_update_levels\"):\n        getattr(model, \"set_allowed_update_levels\")(\n            prev_allowed if prev_allowed is None else set(prev_allowed)\n        )\n    if hasattr(model, \"set_surprise_threshold\"):\n        getattr(model, \"set_surprise_threshold\")(prev_threshold)\n    if hasattr(model, \"set_allowed_update_layers\"):\n        getattr(model, \"set_allowed_update_layers\")(\n            None if prev_layers is None else {int(idx) for idx in prev_layers}\n        )\n\n\ndef _collect_metrics(model, stats: dict[str, float]):\n    \"\"\"Helper to collect and aggregate update metrics.\"\"\"\n    if hasattr(model, \"pop_update_metrics\"):\n        metrics = model.pop_update_metrics()\n        titan_updates = sum(\n            value for key, value in metrics.items() if key.endswith(\"titan.titan.grad_norm\")\n        )\n        titan_hits = sum(\n            value for key, value in metrics.items() if key.endswith(\"titan.titan.gate_hit\")\n        )\n        stats[\"titan_mem_updates\"] += titan_updates\n        stats[\"titan_update_events\"] += titan_hits\n\n        # Aggregate CMS updates per level: keys look like \"layer{idx}.cms.<level>.<metric>\".\n        for key, value in metrics.items():\n            parts = key.split(\".\")\n            if len(parts) < 4:\n                continue\n            if 
parts[-3] != \"cms\":\n                continue\n            level = parts[-2]\n            metric = parts[-1]\n            if metric == \"grad_norm\":\n                stats_key = f\"{level}_updates\"\n                stats[stats_key] = stats.get(stats_key, 0.0) + float(value)\n            elif metric == \"gate_hit\":\n                stats_key = f\"{level}_update_events\"\n                stats[stats_key] = stats.get(stats_key, 0.0) + float(value)\n\n\ndef _layernorm_backward(\n    grad_out: torch.Tensor,\n    pre_norm: torch.Tensor,\n    norm: nn.LayerNorm,\n) -> torch.Tensor:\n    \"\"\"\n    Convert gradient w.r.t. LayerNorm output into gradient w.r.t. LayerNorm input.\n\n    This aligns the teach signal with the pre-norm hidden state that the blocks actually update.\n    \"\"\"\n    if grad_out.shape != pre_norm.shape:\n        raise ValueError(\"grad_out and pre_norm must have identical shapes\")\n    weight = norm.weight\n    if weight is None:\n        weight = torch.ones(pre_norm.shape[-1], device=pre_norm.device, dtype=pre_norm.dtype)\n    grad_hat = grad_out * weight.to(grad_out.dtype).view(1, 1, -1)\n    mean = pre_norm.mean(dim=-1, keepdim=True)\n    var = pre_norm.var(dim=-1, unbiased=False, keepdim=True)\n    inv_std = torch.rsqrt(var + norm.eps)\n    x_hat = (pre_norm - mean) * inv_std\n    grad_mean = grad_hat.mean(dim=-1, keepdim=True)\n    grad_proj = (grad_hat * x_hat).mean(dim=-1, keepdim=True)\n    return (grad_hat - grad_mean - x_hat * grad_proj) * inv_std\n\n\ndef _get_model_surprise_metric(model) -> str:\n    getter = getattr(model, \"get_surprise_metric\", None)\n    if callable(getter):\n        return str(getter()).strip().lower()\n    return \"l2\"\n\n\ndef _compute_surprise_value(\n    *,\n    model,\n    metric: str,\n    logits: torch.Tensor,\n    tokens: torch.Tensor,\n    teach_signal: torch.Tensor,\n) -> tuple[float, float | None]:\n    normalized = str(metric).strip().lower()\n    if normalized == \"l2\":\n        runtime_scale = float(getattr(model, \"_runtime_teach_scale\", 1.0))\n        runtime_clip = float(getattr(model, \"_runtime_teach_clip\", 0.0))\n        scaled = teach_signal * runtime_scale\n        if runtime_clip > 0:\n            norm = scaled.norm(dim=-1, keepdim=True)\n            scale = torch.clamp(norm / runtime_clip, min=1.0)\n            scaled = scaled / scale\n        value = float(scaled.norm(dim=-1).mean().item())\n        return value, None\n    if normalized == \"loss\":\n        loss = torch.nn.functional.cross_entropy(\n            logits[:, :-1].reshape(-1, logits.size(-1)),\n            tokens[:, 1:].reshape(-1),\n        )\n        value = float(loss.detach().item())\n        return value, value\n    if normalized == \"logit_entropy\":\n        logits_detached = logits[:, :-1].detach().float()\n        probs = torch.softmax(logits_detached, dim=-1)\n        entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1).mean()\n        value = float(entropy.item())\n        return value, value\n    raise ValueError(f\"Unsupported surprise_metric={metric!r}\")\n\n\ndef memorize_tokens(\n    model,\n    token_batch: torch.Tensor,\n    cfg: MemorizeConfig,\n    *,\n    fast_state=None,\n    teach_mask: torch.Tensor | None = None,\n) -> dict[str, float]:\n    if token_batch.size(1) < 2:\n        return {}\n\n    if cfg.use_fast_state and fast_state is None:\n        raise ValueError(\"cfg.use_fast_state=True requires passing fast_state\")\n\n    with torch.no_grad():\n        stats: dict[str, float] = {\n            
\"titan_mem_updates\": 0.0,\n            \"titan_update_events\": 0.0,\n            \"cms_fast_updates\": 0.0,\n            \"cms_fast_update_events\": 0.0,\n            \"cms_mid_updates\": 0.0,\n            \"cms_mid_update_events\": 0.0,\n            \"cms_slow_updates\": 0.0,\n            \"cms_slow_update_events\": 0.0,\n            \"cms_ultra_updates\": 0.0,\n            \"cms_ultra_update_events\": 0.0,\n        }\n        prev_allowed, prev_threshold, prev_layers = _setup_memorization_context(model, cfg)\n\n        if cfg.online_chunk_size and cfg.online_chunk_size > 0:\n            # Online / Chunked Learning Mode\n            seq_len = token_batch.size(1)\n            chunk_size = cfg.online_chunk_size\n\n            # We process the sequence in increasing windows\n            # But to avoid O(N^2) cost for very long sequences, this is an approximation\n            # where we re-process the history. For faithful online learning, this is necessary\n            # without external KV cache management.\n\n            # Note: compute_teach_signal computes gradients for predicting tokens[1:]\n            # token_batch: [t0, t1, t2, t3]\n            # logits: [p1, p2, p3, p4] (aligned with t0..t3 input)\n            # teach_signal index i corresponds to error on token[i+1]\n\n            # We iterate over target token indices (1..seq_len-1) in chunks.\n            # For targets up to index K (exclusive end), feed tokens[:, :K] as context.\n            target_start = 1\n            while target_start < seq_len:\n                target_end = min(target_start + chunk_size, seq_len)\n                # We want to learn targets [target_start ... target_end]\n                # (python slice style end index).\n                # Range: target_start until target_end.\n\n                # To compute error for target at index K, we need input 0..K.\n                # So we need input up to target_end-1? 
context_tokens = token_batch[:, :target_end]\n\n                pre_norm = None\n                if hasattr(model, \"forward_with_pre_norm\"):\n                    forward_fn = getattr(model, \"forward_with_pre_norm\")\n                    logits, pre_norm = (\n                        forward_fn(context_tokens, fast_state=fast_state)\n                        if cfg.use_fast_state\n                        else forward_fn(context_tokens)\n                    )\n                else:\n                    logits = (\n                        model(context_tokens, fast_state=fast_state)\n                        if cfg.use_fast_state\n                        else model(context_tokens)\n                    )\n                full_signal = compute_teach_signal(model, logits, context_tokens)\n                if pre_norm is not None:\n                    norm = getattr(model, \"norm\", None)\n                    if isinstance(norm, nn.LayerNorm):\n                        full_signal = _layernorm_backward(full_signal, pre_norm, norm)\n\n                # full_signal has length target_end; index k carries the error\n                # for target token k+1, so the final index has no target.\n                # We keep errors for targets [target_start, target_end), i.e.\n                # signal indices [target_start - 1, target_end - 1).\n                
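# Worked example (hypothetical sizes: chunk_size=2, seq_len=5):\n                #   pass 1: target_end=3 -> context tokens[:, :3], mask slice [0:2]\n                #   pass 2: target_end=5 -> context tokens[:, :5], mask slice [2:4]\n\n                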
mask = torch.zeros_like(full_signal)\n                mask_start = target_start - 1\n                mask_end = target_end - 1\n                mask[:, mask_start:mask_end, :] = 1.0\n\n                masked_signal = full_signal * mask\n                if teach_mask is not None:\n                    if teach_mask.ndim != 2:\n                        raise ValueError(\"teach_mask must have shape (B, T)\")\n                    if teach_mask.shape[0] != token_batch.shape[0]:\n                        raise ValueError(\"teach_mask batch size mismatch\")\n                    mask_slice = teach_mask[:, :target_end].to(masked_signal.device).float()\n                    masked_signal = masked_signal * mask_slice.unsqueeze(-1)\n                surprise_metric = _get_model_surprise_metric(model)\n                surprise_value, surprise_override = _compute_surprise_value(\n                    model=model,\n                    metric=surprise_metric,\n                    logits=logits,\n                    tokens=context_tokens,\n                    teach_signal=masked_signal,\n                )\n                if cfg.surprise_threshold is not None and surprise_value < cfg.surprise_threshold:\n                    target_start = target_end\n                    continue\n                if cfg.use_fast_state:\n                    model(\n                        context_tokens,\n                        teach_signal=masked_signal,\n                        surprise_value=surprise_override,\n                        fast_state=fast_state,\n                    )\n                else:\n                    model(\n                        context_tokens,\n                        teach_signal=masked_signal,\n                        surprise_value=surprise_override,\n                    )\n                _collect_metrics(model, stats)\n\n                target_start = target_end\n\n        else:\n            # Batch mode (default): full-sequence passes, repeated cfg.steps times.\n            for _ in range(cfg.steps):\n                pre_norm = None\n                if hasattr(model, \"forward_with_pre_norm\"):\n                    forward_fn = getattr(model, \"forward_with_pre_norm\")\n                    logits, pre_norm = (\n                        forward_fn(token_batch, fast_state=fast_state)\n                        if cfg.use_fast_state\n                        else forward_fn(token_batch)\n                    )\n                else:\n                    logits = (\n                        model(token_batch, fast_state=fast_state)\n                        if cfg.use_fast_state\n                        else model(token_batch)\n                    )\n                teach_signal = compute_teach_signal(model, logits, token_batch)\n                if pre_norm is not None:\n                    norm = getattr(model, \"norm\", None)\n                    if isinstance(norm, nn.LayerNorm):\n                        teach_signal = _layernorm_backward(teach_signal, pre_norm, norm)\n                if teach_mask is not None:\n                    if teach_mask.ndim != 2:\n                        raise ValueError(\"teach_mask must have shape (B, T)\")\n                    if teach_mask.shape[:2] != teach_signal.shape[:2]:\n                        raise ValueError(\"teach_mask shape mismatch\")\n                    mask_f = teach_mask.to(teach_signal.device).float().unsqueeze(-1)\n                    teach_signal = teach_signal * mask_f\n                surprise_metric = _get_model_surprise_metric(model)\n                surprise_value, 
surprise_override = _compute_surprise_value(\n                    model=model,\n                    metric=surprise_metric,\n                    logits=logits,\n                    tokens=token_batch,\n                    teach_signal=teach_signal,\n                )\n                if cfg.surprise_threshold is not None and surprise_value < cfg.surprise_threshold:\n                    continue\n                if cfg.use_fast_state:\n                    model(\n                        token_batch,\n                        teach_signal=teach_signal,\n                        surprise_value=surprise_override,\n                        fast_state=fast_state,\n                    )\n                else:\n                    model(token_batch, teach_signal=teach_signal, surprise_value=surprise_override)\n                _collect_metrics(model, stats)\n\n        _teardown_memorization_context(model, prev_allowed, prev_threshold, prev_layers)\n        return stats\n\n\ndef memorize_sequence(\n    model,\n    tokenizer: SentencePieceTokenizer,\n    text: str,\n    device: torch.device,\n    cfg: MemorizeConfig,\n    *,\n    fast_state=None,\n    teach_mask: torch.Tensor | None = None,\n) -> dict[str, float]:\n    if not text:\n        return {}\n    tokens = tokenizer.encode(text)\n    if tokens.size(0) < 2:\n        return {}\n    batch = tokens.to(device).unsqueeze(0)\n    return memorize_tokens(model, batch, cfg, fast_state=fast_state, teach_mask=teach_mask)\n"
  },
  {
    "path": "src/nested_learning/model.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Dict, Protocol, Sequence, cast\n\nimport torch\nimport torch.nn as nn\nfrom torch.utils.checkpoint import checkpoint\n\nfrom .fast_state import (\n    AttentionKVCache,\n    ModelAttentionCache,\n    ModelFastState,\n    build_block_fast_state,\n)\nfrom .hope.block import (\n    HOPEAttentionBlock,\n    HOPEAttentionBlockConfig,\n    HOPEBlock,\n    HOPEBlockConfig,\n    HOPESelfModBlock,\n    HOPESelfModBlockConfig,\n)\nfrom .levels import LevelSpec\nfrom .transformer import TransformerBlock, TransformerBlockConfig\n\n\n@dataclass\nclass ModelConfig:\n    vocab_size: int\n    dim: int\n    num_layers: int\n    heads: int\n    titan_level: LevelSpec\n    cms_levels: Sequence[LevelSpec]\n    cms_flush_partial_at_end: bool = False\n    cms_use_layernorm: bool = True\n    optimizers: Dict[str, dict] | None = None\n    teach_scale: float = 1.0\n    teach_clip: float = 0.0\n    teach_schedule: Dict[str, float] | None = None\n    gradient_checkpointing: bool = False\n    surprise_threshold: float | None = None\n    surprise_metric: str = \"l2\"\n    freeze_backbone: bool = False\n    qk_l2_norm: bool = False\n    local_conv_window: int | None = None\n    self_mod_lr: float = 1e-3\n    self_mod_hidden: int = 4\n    self_mod_chunk_size: int = 1\n    self_mod_chunk_size_memory: int | None = None\n    self_mod_objective: str = \"l2\"\n    self_mod_stopgrad_vhat: bool = True\n    self_mod_use_rank1_precond: bool = True\n    self_mod_use_alpha: bool = True\n    self_mod_use_skip: bool = True\n    self_mod_momentum: float = 0.0\n    self_mod_adaptive_q: bool = False\n    self_mod_local_conv_window: int | None = 4\n    transformer_mlp_hidden_multiplier: int = 4\n    transformer_activation: str = \"gelu\"\n    block_variant: str = \"hope_hybrid\"\n\n\nclass HOPEModel(nn.Module):\n    def __init__(self, config: ModelConfig):\n        super().__init__()\n        self.config = config\n        self.embed = nn.Embedding(config.vocab_size, config.dim)\n        self.base_teach_scale = config.teach_scale\n        self.base_teach_clip = config.teach_clip\n        self._runtime_teach_scale = config.teach_scale\n        self._runtime_teach_clip = config.teach_clip\n        self.gradient_checkpointing = config.gradient_checkpointing\n        self._surprise_threshold = config.surprise_threshold\n        self._surprise_metric = \"l2\"\n        self._allowed_update_levels: set[str] | None = None\n        self._allowed_update_layers: set[int] | None = None\n        variant = str(config.block_variant).strip().lower()\n        if variant == \"hope_attention\":\n            attn_block_config = HOPEAttentionBlockConfig(\n                dim=config.dim,\n                heads=config.heads,\n                cms_levels=config.cms_levels,\n                cms_flush_partial_at_end=config.cms_flush_partial_at_end,\n                cms_use_layernorm=config.cms_use_layernorm,\n                qk_l2_norm=config.qk_l2_norm,\n                local_conv_window=config.local_conv_window,\n                self_mod_lr=config.self_mod_lr,\n                optimizer_configs=config.optimizers or {},\n            )\n            self.blocks = nn.ModuleList(\n                [HOPEAttentionBlock(attn_block_config) for _ in range(config.num_layers)]\n            )\n        elif variant == \"hope_hybrid\":\n            hybrid_block_config = HOPEBlockConfig(\n                dim=config.dim,\n                heads=config.heads,\n                
titan_level=config.titan_level,\n                cms_levels=config.cms_levels,\n                cms_flush_partial_at_end=config.cms_flush_partial_at_end,\n                cms_use_layernorm=config.cms_use_layernorm,\n                qk_l2_norm=config.qk_l2_norm,\n                local_conv_window=config.local_conv_window,\n                self_mod_lr=config.self_mod_lr,\n                self_mod_hidden=config.self_mod_hidden,\n                optimizer_configs=config.optimizers or {},\n            )\n            self.blocks = nn.ModuleList(\n                [HOPEBlock(hybrid_block_config) for _ in range(config.num_layers)]\n            )\n        elif variant == \"hope_selfmod\":\n            selfmod_block_config = HOPESelfModBlockConfig(\n                dim=config.dim,\n                cms_levels=config.cms_levels,\n                cms_flush_partial_at_end=config.cms_flush_partial_at_end,\n                cms_use_layernorm=config.cms_use_layernorm,\n                qk_l2_norm=config.qk_l2_norm,\n                selfmod_adaptive_q=config.self_mod_adaptive_q,\n                selfmod_local_conv_window=config.self_mod_local_conv_window,\n                eta_scale=config.self_mod_lr,\n                selfmod_chunk_size=config.self_mod_chunk_size,\n                selfmod_chunk_size_memory=config.self_mod_chunk_size_memory,\n                selfmod_objective=config.self_mod_objective,\n                selfmod_stopgrad_vhat=config.self_mod_stopgrad_vhat,\n                selfmod_use_rank1_precond=config.self_mod_use_rank1_precond,\n                selfmod_use_alpha=config.self_mod_use_alpha,\n                selfmod_use_skip=config.self_mod_use_skip,\n                selfmod_momentum=config.self_mod_momentum,\n                self_mod_lr=config.self_mod_lr,\n                optimizer_configs=config.optimizers or {},\n            )\n            self.blocks = nn.ModuleList(\n                [HOPESelfModBlock(selfmod_block_config) for _ in range(config.num_layers)]\n            )\n        elif variant == \"transformer\":\n            transformer_block_config = TransformerBlockConfig(\n                dim=config.dim,\n                heads=config.heads,\n                mlp_hidden_multiplier=config.transformer_mlp_hidden_multiplier,\n                activation=config.transformer_activation,\n                qk_l2_norm=config.qk_l2_norm,\n                local_conv_window=config.local_conv_window,\n            )\n            self.blocks = nn.ModuleList(\n                [TransformerBlock(transformer_block_config) for _ in range(config.num_layers)]\n            )\n        else:\n            raise ValueError(\n                f\"Unsupported block_variant={config.block_variant!r}; expected one of \"\n                \"['hope_attention', 'hope_hybrid', 'hope_selfmod', 'transformer']\"\n            )\n        self.norm = nn.LayerNorm(config.dim)\n        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)\n        # Weight tying keeps the LM head gradient aligned with the embedding space.\n        self.lm_head.weight = self.embed.weight\n        self._latest_update_metrics: Dict[str, float] = {}\n        self.set_surprise_metric(config.surprise_metric)\n        self.set_surprise_threshold(self._surprise_threshold)\n        if config.freeze_backbone:\n            self.freeze_backbone()\n\n    def set_teach_runtime(self, *, scale: float | None = None, clip: float | None = None) -> None:\n        if scale is not None:\n            self._runtime_teach_scale = scale\n        if clip is not 
None:\n            self._runtime_teach_clip = clip\n\n    def set_surprise_threshold(self, threshold: float | None) -> None:\n        self._surprise_threshold = threshold\n        for block in self.blocks:\n            cast(_UpdateControlledBlock, block).set_surprise_threshold(threshold)\n\n    def get_surprise_threshold(self) -> float | None:\n        return self._surprise_threshold\n\n    def set_surprise_metric(self, metric: str) -> None:\n        normalized = str(metric).strip().lower()\n        allowed = {\"l2\", \"loss\", \"logit_entropy\"}\n        if normalized not in allowed:\n            raise ValueError(\n                f\"Unsupported surprise_metric={metric!r}; expected one of {sorted(allowed)}\"\n            )\n        self._surprise_metric = normalized\n        for block in self.blocks:\n            cast(_UpdateControlledBlock, block).set_surprise_metric(normalized)\n\n    def get_surprise_metric(self) -> str:\n        return self._surprise_metric\n\n    def set_allowed_update_levels(self, levels: set[str] | None) -> None:\n        self._allowed_update_levels = levels.copy() if levels is not None else None\n        for block in self.blocks:\n            cast(_UpdateControlledBlock, block).set_allowed_levels(self._allowed_update_levels)\n\n    def get_allowed_update_levels(self) -> set[str] | None:\n        return None if self._allowed_update_levels is None else self._allowed_update_levels.copy()\n\n    def set_allowed_update_layers(self, layers: set[int] | None) -> None:\n        if layers is None:\n            self._allowed_update_layers = None\n            return\n        normalized: set[int] = set()\n        total = len(self.blocks)\n        for idx in layers:\n            layer_idx = int(idx)\n            if layer_idx < 0:\n                layer_idx = total + layer_idx\n            if not (0 <= layer_idx < total):\n                raise ValueError(f\"Invalid layer index {idx} for model with {total} layers\")\n            normalized.add(layer_idx)\n        self._allowed_update_layers = normalized\n\n    def get_allowed_update_layers(self) -> set[int] | None:\n        return None if self._allowed_update_layers is None else self._allowed_update_layers.copy()\n\n    def forward(\n        self,\n        tokens: torch.Tensor,\n        *,\n        teach_signal: torch.Tensor | None = None,\n        teach_signals: list[torch.Tensor] | None = None,\n        fast_state: ModelFastState | None = None,\n        surprise_value: float | None = None,\n        finalize_updates: bool = True,\n        attention_cache: ModelAttentionCache | None = None,\n        return_attention_cache: bool = False,\n        differentiable_updates: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, ModelAttentionCache]:\n        if return_attention_cache:\n            logits, _pre_norm, next_attention_cache = cast(\n                tuple[torch.Tensor, torch.Tensor, ModelAttentionCache],\n                self.forward_with_pre_norm(\n                    tokens,\n                    teach_signal=teach_signal,\n                    teach_signals=teach_signals,\n                    fast_state=fast_state,\n                    surprise_value=surprise_value,\n                    finalize_updates=finalize_updates,\n                    attention_cache=attention_cache,\n                    return_attention_cache=True,\n                    differentiable_updates=differentiable_updates,\n                ),\n            )\n            return logits, next_attention_cache\n        logits, _pre_norm = cast(\n          
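  # typing.cast() only narrows the static union type returned by\n            # forward_with_pre_norm; it is a no-op at runtime.\n          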
  tuple[torch.Tensor, torch.Tensor],\n            self.forward_with_pre_norm(\n                tokens,\n                teach_signal=teach_signal,\n                teach_signals=teach_signals,\n                fast_state=fast_state,\n                surprise_value=surprise_value,\n                finalize_updates=finalize_updates,\n                attention_cache=attention_cache,\n                return_attention_cache=False,\n                differentiable_updates=differentiable_updates,\n            ),\n        )\n        return logits\n\n    def forward_with_pre_norm(\n        self,\n        tokens: torch.Tensor,\n        *,\n        teach_signal: torch.Tensor | None = None,\n        teach_signals: list[torch.Tensor] | None = None,\n        fast_state: ModelFastState | None = None,\n        surprise_value: float | None = None,\n        finalize_updates: bool = True,\n        attention_cache: ModelAttentionCache | None = None,\n        return_attention_cache: bool = False,\n        differentiable_updates: bool = False,\n    ) -> (\n        tuple[torch.Tensor, torch.Tensor]\n        | tuple[torch.Tensor, torch.Tensor, ModelAttentionCache]\n    ):\n        if return_attention_cache:\n            x, next_attention_caches = cast(\n                tuple[torch.Tensor, list[AttentionKVCache | None]],\n                self._run_blocks(\n                    tokens,\n                    teach_signal=teach_signal,\n                    teach_signals=teach_signals,\n                    fast_state=fast_state,\n                    surprise_value=surprise_value,\n                    finalize_updates=finalize_updates,\n                    attention_cache=attention_cache,\n                    return_attention_cache=True,\n                    differentiable_updates=differentiable_updates,\n                ),\n            )\n        else:\n            x = cast(\n                torch.Tensor,\n                self._run_blocks(\n                    tokens,\n                    teach_signal=teach_signal,\n                    teach_signals=teach_signals,\n                    fast_state=fast_state,\n                    surprise_value=surprise_value,\n                    finalize_updates=finalize_updates,\n                    attention_cache=attention_cache,\n                    return_attention_cache=False,\n                    differentiable_updates=differentiable_updates,\n                ),\n            )\n        pre_norm = cast(torch.Tensor, x)\n        x = self.norm(pre_norm)\n        logits = self.lm_head(x)\n        if teach_signal is not None or teach_signals is not None:\n            self._latest_update_metrics = self._gather_block_stats()\n        if return_attention_cache:\n            return logits, pre_norm, ModelAttentionCache(blocks=next_attention_caches)\n        return logits, pre_norm\n\n    def forward_with_block_outputs(\n        self,\n        tokens: torch.Tensor,\n        *,\n        teach_signal: torch.Tensor | None = None,\n        teach_signals: list[torch.Tensor] | None = None,\n        fast_state: ModelFastState | None = None,\n        surprise_value: float | None = None,\n        finalize_updates: bool = True,\n        attention_cache: ModelAttentionCache | None = None,\n        return_attention_cache: bool = False,\n        differentiable_updates: bool = False,\n    ) -> (\n        tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]\n        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], ModelAttentionCache]\n    ):\n        if return_attention_cache:\n            x, 
block_outputs, next_attention_caches = cast(\n                tuple[torch.Tensor, list[torch.Tensor], list[AttentionKVCache | None]],\n                self._run_blocks(\n                    tokens,\n                    teach_signal=teach_signal,\n                    teach_signals=teach_signals,\n                    fast_state=fast_state,\n                    surprise_value=surprise_value,\n                    finalize_updates=finalize_updates,\n                    attention_cache=attention_cache,\n                    return_attention_cache=True,\n                    collect_outputs=True,\n                    differentiable_updates=differentiable_updates,\n                ),\n            )\n        else:\n            x, block_outputs = cast(\n                tuple[torch.Tensor, list[torch.Tensor]],\n                self._run_blocks(\n                    tokens,\n                    teach_signal=teach_signal,\n                    teach_signals=teach_signals,\n                    fast_state=fast_state,\n                    surprise_value=surprise_value,\n                    finalize_updates=finalize_updates,\n                    attention_cache=attention_cache,\n                    return_attention_cache=False,\n                    collect_outputs=True,\n                    differentiable_updates=differentiable_updates,\n                ),\n            )\n        pre_norm = x\n        x = self.norm(x)\n        logits = self.lm_head(x)\n        if teach_signal is not None or teach_signals is not None:\n            self._latest_update_metrics = self._gather_block_stats()\n        if return_attention_cache:\n            return (\n                logits,\n                pre_norm,\n                block_outputs,\n                ModelAttentionCache(blocks=next_attention_caches),\n            )\n        return logits, pre_norm, block_outputs\n\n    def _run_blocks(\n        self,\n        tokens: torch.Tensor,\n        *,\n        teach_signal: torch.Tensor | None,\n        fast_state: ModelFastState | None,\n        teach_signals: list[torch.Tensor] | None = None,\n        surprise_value: float | None = None,\n        finalize_updates: bool = True,\n        attention_cache: ModelAttentionCache | None = None,\n        return_attention_cache: bool = False,\n        collect_outputs: bool = False,\n        differentiable_updates: bool = False,\n    ) -> (\n        torch.Tensor\n        | tuple[torch.Tensor, list[torch.Tensor]]\n        | tuple[torch.Tensor, list[AttentionKVCache | None]]\n        | tuple[torch.Tensor, list[torch.Tensor], list[AttentionKVCache | None]]\n    ):\n        x = self.embed(tokens)\n        block_outputs: list[torch.Tensor] = []\n        next_attention_caches: list[AttentionKVCache | None] = []\n        runtime_scale = self._runtime_teach_scale\n        runtime_clip = self._runtime_teach_clip\n        if teach_signals is not None:\n            if len(teach_signals) != len(self.blocks):\n                raise ValueError(\n                    f\"teach_signals length {len(teach_signals)} \"\n                    f\"does not match blocks {len(self.blocks)}\"\n                )\n            if teach_signal is not None:\n                raise ValueError(\"Provide either teach_signal or teach_signals, not both.\")\n        if fast_state is not None and len(fast_state.blocks) != len(self.blocks):\n            raise ValueError(\"fast_state.blocks length does not match model.blocks\")\n        if attention_cache is not None and len(attention_cache.blocks) != len(self.blocks):\n         
   raise ValueError(\"attention_cache.blocks length does not match model.blocks\")\n\n        require_external = self._surprise_metric in {\"loss\", \"logit_entropy\"}\n        if require_external and self._surprise_threshold is not None:\n            if (teach_signal is not None or teach_signals is not None) and surprise_value is None:\n                raise ValueError(\n                    f\"surprise_metric={self._surprise_metric} requires passing surprise_value \"\n                    \"when model.surprise_threshold is set.\"\n                )\n\n        base_surprise = surprise_value\n        scaled_global_signal: torch.Tensor | None = None\n        if base_surprise is None and teach_signal is not None and self._surprise_metric == \"l2\":\n            scaled_global_signal = teach_signal * runtime_scale\n            if runtime_clip > 0:\n                norm = scaled_global_signal.norm(dim=-1, keepdim=True)\n                scale = torch.clamp(norm / runtime_clip, min=1.0)\n                scaled_global_signal = scaled_global_signal / scale\n            base_surprise = float(scaled_global_signal.norm(dim=-1).mean().item())\n\n        for idx, block in enumerate(self.blocks):\n            block_state = None if fast_state is None else fast_state.blocks[idx]\n            block_attention_cache = None if attention_cache is None else attention_cache.blocks[idx]\n            scaled_signal = None\n            block_surprise = base_surprise\n            if teach_signal is not None:\n                if scaled_global_signal is None:\n                    scaled_signal = teach_signal * runtime_scale\n                    if runtime_clip > 0:\n                        norm = scaled_signal.norm(dim=-1, keepdim=True)\n                        scale = torch.clamp(norm / runtime_clip, min=1.0)\n                        scaled_signal = scaled_signal / scale\n                else:\n                    scaled_signal = scaled_global_signal\n                if (\n                    self._allowed_update_layers is not None\n                    and idx not in self._allowed_update_layers\n                ):\n                    scaled_signal = None\n            if teach_signals is not None:\n                scaled_signal = teach_signals[idx] * self._runtime_teach_scale\n                if self._surprise_metric == \"l2\" and base_surprise is None:\n                    block_surprise = float(scaled_signal.norm(dim=-1).mean().item())\n                if self._runtime_teach_clip > 0:\n                    norm = scaled_signal.norm(dim=-1, keepdim=True)\n                    scale = torch.clamp(norm / self._runtime_teach_clip, min=1.0)\n                    scaled_signal = scaled_signal / scale\n                if (\n                    self._allowed_update_layers is not None\n                    and idx not in self._allowed_update_layers\n                ):\n                    scaled_signal = None\n\n            def block_call(\n                hidden: torch.Tensor,\n                *,\n                blk=block,\n                sig=scaled_signal,\n                st=block_state,\n                sv=block_surprise,\n                fin=finalize_updates,\n                ac=block_attention_cache,\n                du=differentiable_updates,\n            ) -> torch.Tensor:\n                return blk(\n                    hidden,\n                    teach_signal=sig,\n                    surprise_value=sv,\n                    fast_state=st,\n                    finalize_updates=fin,\n                    
attention_cache=ac,\n                    differentiable_updates=du,\n                )\n\n            if return_attention_cache:\n                x, next_cache = block(  # type: ignore[assignment]\n                    x,\n                    teach_signal=scaled_signal,\n                    surprise_value=block_surprise,\n                    fast_state=block_state,\n                    finalize_updates=finalize_updates,\n                    attention_cache=block_attention_cache,\n                    return_attention_cache=True,\n                    differentiable_updates=differentiable_updates,\n                )\n                next_attention_caches.append(next_cache)\n            elif torch.is_grad_enabled() and self.training and self.gradient_checkpointing:\n                x = checkpoint(block_call, x, use_reentrant=False)\n            else:\n                x = block_call(x)\n            if collect_outputs:\n                block_outputs.append(x)\n        if collect_outputs and return_attention_cache:\n            return x, block_outputs, next_attention_caches\n        if collect_outputs:\n            return x, block_outputs\n        if return_attention_cache:\n            return x, next_attention_caches\n        return x\n\n    def _gather_block_stats(self) -> Dict[str, float]:\n        metrics: Dict[str, float] = {}\n        for idx, block in enumerate(self.blocks):\n            pop_fn = getattr(block, \"pop_update_stats\", None)\n            if callable(pop_fn):\n                stats = cast(Dict[str, Dict[str, float]], pop_fn())\n                for level_name, payload in stats.items():\n                    prefix = f\"layer{idx}.{level_name}\"\n                    for key, value in payload.items():\n                        metrics[f\"{prefix}.{key}\"] = value\n        return metrics\n\n    def pop_update_metrics(self) -> Dict[str, float]:\n        metrics = self._latest_update_metrics\n        self._latest_update_metrics = {}\n        return metrics\n\n    def init_fast_state(self) -> ModelFastState:\n        states = []\n        for block in self.blocks:\n            if isinstance(block, HOPEBlock):\n                specs = [block.config.titan_level, *block.config.cms_levels]\n                state = build_block_fast_state(\n                    titan_module=block.titan_memory,\n                    cms_blocks=dict(block.cms.blocks.items()),\n                    specs=specs,\n                    optimizer_configs=block.config.optimizer_configs,\n                    default_lr=block.config.self_mod_lr,\n                )\n                states.append(state)\n            elif isinstance(block, HOPEAttentionBlock):\n                specs = list(block.config.cms_levels)\n                state = build_block_fast_state(\n                    titan_module=None,\n                    cms_blocks=dict(block.cms.blocks.items()),\n                    specs=specs,\n                    optimizer_configs=block.config.optimizer_configs,\n                    default_lr=block.config.self_mod_lr,\n                )\n                states.append(state)\n            elif isinstance(block, HOPESelfModBlock):\n                specs = list(block.config.cms_levels)\n                state = build_block_fast_state(\n                    titan_module=None,\n                    cms_blocks=dict(block.cms.blocks.items()),\n                    selfmod_module=block.selfmod,\n                    specs=specs,\n                    optimizer_configs=block.config.optimizer_configs,\n                    
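# note: unlike the hybrid/attention branches above, this branch also\n                    # registers the block's selfmod module in the fast state.\n                    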
default_lr=block.config.self_mod_lr,\n                )\n                states.append(state)\n            elif isinstance(block, TransformerBlock):\n                state = build_block_fast_state(\n                    titan_module=None,\n                    cms_blocks={},\n                    specs=(),\n                    optimizer_configs={},\n                    default_lr=0.0,\n                )\n                states.append(state)\n            else:\n                raise TypeError(f\"Unsupported block type for fast state: {type(block)}\")\n        return ModelFastState(blocks=states)\n\n    def init_attention_cache(self) -> ModelAttentionCache:\n        return ModelAttentionCache(blocks=[None for _ in self.blocks])\n\n    def freeze_backbone(self) -> None:\n        \"\"\"\n        Freeze the shared transformer spine (embeddings, attention blocks, norm, LM head).\n        HOPE/TITAN/CMS memories remain trainable for adapter-style finetuning.\n        \"\"\"\n        for p in self.embed.parameters():\n            p.requires_grad = False\n        for p in self.norm.parameters():\n            p.requires_grad = False\n        for p in self.lm_head.parameters():\n            p.requires_grad = False\n        for block in self.blocks:\n            attn = getattr(block, \"attn\", None)\n            if isinstance(attn, nn.Module):\n                for p in attn.parameters():\n                    p.requires_grad = False\n\n\nclass _UpdateControlledBlock(Protocol):\n    def set_surprise_threshold(self, threshold: float | None) -> None: ...\n\n    def set_surprise_metric(self, metric: str) -> None: ...\n\n    def set_allowed_levels(self, allowed: set[str] | None) -> None: ...\n"
  },
  {
    "path": "src/nested_learning/optim/__init__.py",
    "content": ""
  },
  {
    "path": "src/nested_learning/optim/deep.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\n\n\n@dataclass\nclass DeepMomentumState:\n    grad_avg: Optional[torch.Tensor] = None\n    sq_avg: Optional[torch.Tensor] = None\n\n\nclass DeepMomentum(nn.Module):\n    \"\"\"Implements momentum variants described in the NL paper.\"\"\"\n\n    def __init__(\n        self,\n        *,\n        beta: float = 0.9,\n        beta2: float = 0.999,\n        eps: float = 1e-8,\n        variant: str = \"preconditioned\",\n    ) -> None:\n        super().__init__()\n        self.beta = beta\n        self.beta2 = beta2\n        self.eps = eps\n        self.variant = variant\n        self.state: dict[str, DeepMomentumState] = {}\n        self.nonlinearity = nn.Tanh() if variant in {\"dmgd\", \"muon\"} else nn.Identity()\n        self.last_metrics: dict[str, float] = {}\n\n    def reset_state(self) -> None:\n        self.state.clear()\n\n    def _precondition(self, grad: torch.Tensor, state: DeepMomentumState) -> torch.Tensor:\n        if state.sq_avg is None or state.sq_avg.shape != grad.shape:\n            state.sq_avg = torch.zeros_like(grad)\n        state.sq_avg.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)\n        denom = state.sq_avg.sqrt().add_(self.eps)\n        return grad / denom\n\n    def _nl_precondition(\n        self,\n        grad: torch.Tensor,\n        context: torch.Tensor | None,\n    ) -> tuple[torch.Tensor, dict[str, float]]:\n        metrics: dict[str, float] = {\n            \"ctx_norm\": 0.0,\n            \"proj_norm\": 0.0,\n            \"proj_skipped\": 0.0,\n        }\n        if context is None:\n            return grad, metrics\n        ctx = context\n        if ctx.ndim > 1:\n            ctx = ctx.reshape(-1, ctx.shape[-1]).mean(dim=0)\n        ctx_norm = torch.norm(ctx)\n        metrics[\"ctx_norm\"] = ctx_norm.item()\n\n        if ctx_norm > 0:\n            if grad.ndim == 0 or grad.shape[-1] != ctx.shape[-1]:\n                metrics[\"proj_skipped\"] = 1.0\n                return grad, metrics\n            unit = ctx / (ctx_norm + self.eps)\n            # Project grad orthogonal to context (rank-1 projector).\n            projection = (grad * unit).sum(dim=-1, keepdim=True) * unit\n            update = grad - projection\n            metrics[\"proj_norm\"] = torch.norm(update).item()\n            return update, metrics\n        return grad, metrics\n\n    def forward(  # type: ignore[override]\n        self,\n        grad: torch.Tensor,\n        *,\n        context: torch.Tensor | None = None,\n        param_key: str | None = None,\n    ) -> torch.Tensor:\n        key = param_key or \"__default__\"\n        state = self.state.get(key)\n        if state is None:\n            state = DeepMomentumState()\n            self.state[key] = state\n        if state.grad_avg is None or state.grad_avg.shape != grad.shape:\n            state.grad_avg = torch.zeros_like(grad)\n        self.last_metrics = {}\n        update = grad\n        if self.variant in {\"preconditioned\", \"muon\"}:\n            update = self._precondition(grad, state)\n        if self.variant == \"l2_objective\":\n            update = grad + 0.1 * torch.mean(grad, dim=-1, keepdim=True)\n        if self.variant == \"nl_l2_precond\":\n            update, metrics = self._nl_precondition(grad, context)\n            self.last_metrics.update(metrics)\n        if self.variant in {\"dmgd\", \"muon\"}:\n            update = 
self.nonlinearity(update)\n        state.grad_avg.mul_(self.beta).add_(update, alpha=1 - self.beta)\n        return state.grad_avg\n"
  },
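  {
    "path": "examples/optim_deep_momentum_demo.py",
    "content": "\"\"\"Illustrative usage sketch (hypothetical file, not referenced by the package).\n\nExercises only the surface defined in src/nested_learning/optim/deep.py:\nDeepMomentum(beta=..., variant=...) called as opt(grad, context=..., param_key=...).\nThe tensors, shapes, and the parameter key below are made up for demonstration.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom nested_learning.optim.deep import DeepMomentum\n\n\ndef main() -> None:\n    torch.manual_seed(0)\n    grad = torch.randn(4, 8)        # fake gradient for a (4, 8) parameter\n    context = torch.randn(2, 3, 8)  # fake activations; reduced internally to one (8,) vector\n\n    # \"nl_l2_precond\" projects the gradient orthogonal to the mean context direction\n    # and records ctx_norm / proj_norm / proj_skipped in last_metrics.\n    opt = DeepMomentum(beta=0.9, variant=\"nl_l2_precond\")\n    update = opt(grad, context=context, param_key=\"w_demo\")\n    print(\"update norm:\", float(update.norm()))\n    print(\"metrics:\", opt.last_metrics)\n\n    # Repeated calls with the same param_key accumulate momentum in opt.state[\"w_demo\"];\n    # reset_state() clears all per-parameter buffers.\n    opt(grad, context=context, param_key=\"w_demo\")\n    opt.reset_state()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },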
  {
    "path": "src/nested_learning/optim/factory.py",
    "content": "from __future__ import annotations\n\nfrom typing import Any, Dict\n\nfrom .deep import DeepMomentum\n\n\ndef build_optimizer(config: Dict[str, Any]) -> DeepMomentum:\n    opt_type = config.get(\"type\", \"deep_momentum\").lower()\n    if opt_type != \"deep_momentum\":\n        raise ValueError(f\"Unsupported optimizer type {opt_type}\")\n    params = config.get(\"params\", {})\n    return DeepMomentum(**params)\n"
  },
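  {
    "path": "examples/optim_factory_demo.py",
    "content": "\"\"\"Illustrative sketch (hypothetical file) of the optimizer factory.\n\nUses only build_optimizer from src/nested_learning/optim/factory.py; the only supported\ntype today is \"deep_momentum\", and \"params\" is forwarded verbatim to DeepMomentum's\nkeyword-only constructor. Config values below are arbitrary.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom nested_learning.optim.factory import build_optimizer\n\noptimizer = build_optimizer(\n    {\"type\": \"deep_momentum\", \"params\": {\"beta\": 0.95, \"variant\": \"preconditioned\"}}\n)\n# The returned DeepMomentum module is called directly on gradients, keyed per parameter.\nupdate = optimizer(torch.randn(16, 16), param_key=\"demo\")\nprint(update.shape)\n"
  },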
  {
    "path": "src/nested_learning/optim/m3.py",
    "content": "from __future__ import annotations\n\nfrom typing import Iterable\n\nimport torch\n\n\ndef _newton_schulz(matrix: torch.Tensor, steps: int, eps: float = 1e-6) -> torch.Tensor:\n    if matrix.ndim != 2:\n        raise ValueError(\"Newton-Schulz expects a 2D matrix\")\n    dtype = matrix.dtype\n    device = matrix.device\n    m, n = matrix.shape\n    x = matrix\n    norm = torch.linalg.norm(x)\n    x = x / (norm + eps)\n    eye = torch.eye(n, device=device, dtype=dtype)\n    for _ in range(steps):\n        x = 0.5 * x @ (3.0 * eye - x.T @ x)\n    return x\n\n\ndef _orthogonalize(tensor: torch.Tensor, steps: int, eps: float) -> torch.Tensor:\n    if tensor.ndim < 2:\n        return tensor\n    mat = tensor.reshape(tensor.shape[0], -1)\n    ortho = _newton_schulz(mat, steps=steps, eps=eps)\n    return ortho.reshape_as(tensor)\n\n\nclass M3(torch.optim.Optimizer):\n    \"\"\"\n    Multi-scale Momentum Muon (M3) optimizer (Nested Learning paper, Algorithm 1).\n\n    This is a paper-faithful implementation for 2D weight tensors:\n      - M1: fast momentum\n      - M2: slow momentum (updated every `slow_chunk` steps)\n      - V: second moment\n      - O1/O2: Newton-Schulz orthogonalized momenta\n    \"\"\"\n\n    def __init__(\n        self,\n        params: Iterable[torch.nn.Parameter],\n        *,\n        lr: float = 1e-3,\n        beta1: float = 0.9,\n        beta2: float = 0.999,\n        beta3: float = 0.9,\n        alpha: float = 1.0,\n        eps: float = 1e-8,\n        ns_steps: int = 3,\n        slow_chunk: int = 100,\n        weight_decay: float = 0.0,\n    ) -> None:\n        defaults = dict(\n            lr=lr,\n            beta1=beta1,\n            beta2=beta2,\n            beta3=beta3,\n            alpha=alpha,\n            eps=eps,\n            ns_steps=ns_steps,\n            slow_chunk=slow_chunk,\n            weight_decay=weight_decay,\n        )\n        super().__init__(params, defaults)\n\n    @torch.no_grad()\n    def step(self, closure=None):  # type: ignore[override]\n        loss = None\n        if closure is not None:\n            with torch.enable_grad():\n                loss = closure()\n        for group in self.param_groups:\n            lr = group[\"lr\"]\n            beta1 = group[\"beta1\"]\n            beta2 = group[\"beta2\"]\n            beta3 = group[\"beta3\"]\n            alpha = group[\"alpha\"]\n            eps = group[\"eps\"]\n            ns_steps = group[\"ns_steps\"]\n            slow_chunk = group[\"slow_chunk\"]\n            weight_decay = group[\"weight_decay\"]\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                grad = p.grad\n                if weight_decay != 0.0:\n                    grad = grad.add(p, alpha=weight_decay)\n                state = self.state[p]\n                if not state:\n                    state[\"step\"] = 0\n                    state[\"m1\"] = torch.zeros_like(p)\n                    state[\"m2\"] = torch.zeros_like(p)\n                    state[\"v\"] = torch.zeros_like(p)\n                    state[\"slow_buffer\"] = torch.zeros_like(p)\n                    state[\"o2\"] = torch.zeros_like(p)\n                state[\"step\"] += 1\n                m1 = state[\"m1\"]\n                m2 = state[\"m2\"]\n                v = state[\"v\"]\n                slow_buffer = state[\"slow_buffer\"]\n\n                m1.add_(grad, alpha=beta1)\n                v.addcmul_(grad, grad, value=beta2)\n                slow_buffer.add_(grad)\n\n        
        o1 = _orthogonalize(m1, steps=ns_steps, eps=eps)\n                o2 = state[\"o2\"]\n                denom = v.sqrt().add_(eps)\n                update = (o1 + alpha * o2) / denom\n                p.add_(update, alpha=-lr)\n\n                if slow_chunk > 0 and state[\"step\"] % slow_chunk == 0:\n                    # Paper Algorithm 1 uses the updated slow momentum term in the *next* chunk.\n                    # Compute it after applying the current step update to avoid off-by-one usage.\n                    m2.add_(slow_buffer, alpha=beta3)\n                    slow_buffer.zero_()\n                    state[\"o2\"] = _orthogonalize(m2, steps=ns_steps, eps=eps)\n        return loss\n"
  },
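  {
    "path": "examples/optim_m3_demo.py",
    "content": "\"\"\"Illustrative sketch (hypothetical file): driving the M3 optimizer on a toy regression.\n\nAssumes only the M3 class from src/nested_learning/optim/m3.py; the model, data, and\nhyperparameters here are arbitrary.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch import nn\n\nfrom nested_learning.optim.m3 import M3\n\n\ndef main() -> None:\n    torch.manual_seed(0)\n    model = nn.Linear(32, 32, bias=False)  # 2D weight, the case M3 targets\n    optimizer = M3(model.parameters(), lr=1e-2, slow_chunk=10)\n\n    x = torch.randn(64, 32)\n    target = torch.randn(64, 32)\n    for step in range(30):\n        optimizer.zero_grad()\n        loss = nn.functional.mse_loss(model(x), target)\n        loss.backward()\n        # Every step uses the orthogonalized fast momentum (M1/O1); every `slow_chunk`\n        # steps the slow momentum (M2/O2) is refreshed from the accumulated gradients.\n        optimizer.step()\n        if step % 10 == 0:\n            print(f\"step={step} loss={loss.item():.4f}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },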
  {
    "path": "src/nested_learning/optim/manager.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Dict, Sequence, Tuple\n\nimport torch\nfrom torch import nn\n\nfrom ..levels import LevelClock, LevelSpec\nfrom .factory import build_optimizer\n\n\n@dataclass\nclass LevelConfig:\n    specs: Sequence[LevelSpec]\n    optimizer_configs: Dict[str, dict]\n    default_lr: float\n\n\nclass LevelOptimizerManager:\n    def __init__(self, config: LevelConfig):\n        self.clock = LevelClock(config.specs)\n        self.learning_rates: Dict[str, float] = {}\n        self.optimizers = {}\n        self._last_metrics: Dict[str, Dict[str, float]] = {}\n        for spec in config.specs:\n            key = spec.optimizer_key or \"default\"\n            optim_cfg = config.optimizer_configs.get(key, {\"type\": \"deep_momentum\", \"params\": {}})\n            lr = optim_cfg.get(\"lr\", config.default_lr)\n            params_cfg = optim_cfg.get(\"params\", {})\n            optimizer = build_optimizer(\n                {\"type\": optim_cfg.get(\"type\", \"deep_momentum\"), \"params\": params_cfg}\n            )\n            self.optimizers[spec.name] = optimizer\n            self.learning_rates[spec.name] = lr\n\n    def should_update(self, level: str) -> bool:\n        return self.clock.should_update(level)\n\n    def optimize(\n        self,\n        level: str,\n        module: nn.Module,\n        loss: torch.Tensor,\n        *,\n        context: torch.Tensor | None = None,\n        force: bool = False,\n    ) -> float:\n        if (not force) and (not self.should_update(level)):\n            return 0.0\n        named_params: Tuple[Tuple[str, torch.nn.Parameter], ...] = tuple(\n            (name, param) for name, param in module.named_parameters() if param.requires_grad\n        )\n        if not named_params:\n            return 0.0\n        params = tuple(param for _, param in named_params)\n        grads = torch.autograd.grad(loss, params, retain_graph=False, allow_unused=True)\n        grads_dict: Dict[str, torch.Tensor] = {}\n        for (name, _), grad in zip(named_params, grads, strict=True):\n            if grad is None:\n                continue\n            grads_dict[name] = grad\n        return self.apply_module_grads(\n            level,\n            module,\n            grads_dict,\n            context=context,\n            force=True,\n        )\n\n    def apply_module_grads(\n        self,\n        level: str,\n        module: nn.Module,\n        grads: Dict[str, torch.Tensor],\n        *,\n        context: torch.Tensor | None = None,\n        force: bool = False,\n    ) -> float:\n        if (not force) and (not self.should_update(level)):\n            return 0.0\n        optimizer = self.optimizers[level]\n        lr = self.learning_rates[level]\n        total_norm = 0.0\n        with torch.no_grad():\n            for name, param in module.named_parameters():\n                if not param.requires_grad:\n                    continue\n                grad = grads.get(name)\n                if grad is None:\n                    continue\n                update = optimizer(grad, context=context, param_key=name)\n                param.add_(update, alpha=-lr)\n                total_norm += grad.norm().item()\n        self.clock.record_update(level)\n        metrics = getattr(optimizer, \"last_metrics\", None)\n        if metrics:\n            self._last_metrics[level] = dict(metrics)\n        else:\n            self._last_metrics[level] = {}\n        return total_norm\n\n    def 
tick(self) -> None:\n        self.clock.tick()\n\n    def pop_last_metrics(self, level: str) -> Dict[str, float]:\n        return self._last_metrics.pop(level, {})\n\n    def apply_grads(\n        self,\n        level: str,\n        params: Dict[str, torch.Tensor],\n        grads: Dict[str, torch.Tensor],\n        *,\n        context: torch.Tensor | None = None,\n        force: bool = False,\n        differentiable: bool = False,\n    ) -> tuple[Dict[str, torch.Tensor], float]:\n        if (not force) and (not self.should_update(level)):\n            return params, 0.0\n        optimizer = self.optimizers[level]\n        lr = self.learning_rates[level]\n        updated: Dict[str, torch.Tensor] = {}\n        total_norm = 0.0\n        if differentiable:\n            for name, param in params.items():\n                grad = grads.get(name)\n                if grad is None:\n                    updated[name] = param\n                    continue\n                updated[name] = param - lr * grad\n                total_norm += float(grad.detach().norm().item())\n            self.clock.record_update(level)\n            self._last_metrics[level] = {\"differentiable_updates\": 1.0}\n            return updated, total_norm\n        with torch.no_grad():\n            for name, param in params.items():\n                grad = grads.get(name)\n                if grad is None:\n                    updated[name] = param\n                    continue\n                update = optimizer(grad, context=context, param_key=name)\n                updated[name] = (param - lr * update).detach()\n                total_norm += grad.norm().item()\n        self.clock.record_update(level)\n        metrics = getattr(optimizer, \"last_metrics\", None)\n        if metrics:\n            self._last_metrics[level] = dict(metrics)\n        else:\n            self._last_metrics[level] = {}\n        return updated, total_norm\n"
  },
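  {
    "path": "examples/optim_level_manager_sketch.py",
    "content": "\"\"\"Illustrative sketch (hypothetical file): one clocked update via LevelOptimizerManager.\n\nAssumes an already-constructed manager (its LevelConfig needs LevelSpec objects, whose\nconstructor is not shown in this file). Only methods defined in\nsrc/nested_learning/optim/manager.py are used: optimize, tick, and pop_last_metrics.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport torch\nfrom torch import nn\n\nfrom nested_learning.optim.manager import LevelOptimizerManager\n\n\ndef local_update(\n    manager: LevelOptimizerManager,\n    level: str,\n    module: nn.Module,\n    loss: torch.Tensor,\n    context: torch.Tensor | None = None,\n) -> dict[str, float]:\n    \"\"\"Run one clocked update for `level` and return the optimizer's metrics.\n\n    optimize() is a no-op (returns 0.0) unless the level's clock says an update is due;\n    tick() then advances the clock for the next call.\n    \"\"\"\n    grad_norm = manager.optimize(level, module, loss, context=context)\n    manager.tick()\n    metrics = manager.pop_last_metrics(level)\n    metrics[\"grad_norm\"] = grad_norm\n    return metrics\n"
  },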
  {
    "path": "src/nested_learning/titan/__init__.py",
    "content": ""
  },
  {
    "path": "src/nested_learning/titan/memory.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Dict\n\nimport torch\nimport torch.nn as nn\n\nfrom ..assoc_memory import AssocMemory\n\n\n@dataclass\nclass TitanMemoryConfig:\n    dim: int\n    hidden_multiplier: int = 4\n    layers: int = 2\n    activation: str = \"gelu\"\n\n\ndef _activation(name: str) -> nn.Module:\n    if name.lower() == \"relu\":\n        return nn.ReLU()\n    if name.lower() == \"gelu\":\n        return nn.GELU()\n    if name.lower() == \"silu\":\n        return nn.SiLU()\n    msg = f\"Unsupported activation {name}\"\n    raise ValueError(msg)\n\n\nclass TitanMemory(AssocMemory):\n    \"\"\"Simplified TITAN-style associative memory.\"\"\"\n\n    def __init__(self, config: TitanMemoryConfig):\n        super().__init__()\n        self.config = config\n        hidden = config.dim * config.hidden_multiplier\n        blocks = []\n        activation = _activation(config.activation)\n        for layer_idx in range(config.layers - 1):\n            blocks.extend([nn.Linear(config.dim if layer_idx == 0 else hidden, hidden), activation])\n        blocks.append(nn.Linear(hidden if config.layers > 1 else config.dim, config.dim))\n        self.net = nn.Sequential(*blocks)\n        self.norm = nn.LayerNorm(config.dim)\n        self.grad_clip = 1.0\n\n    def forward(self, query: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        attn = self.net(query)\n        if self.training and self.grad_clip > 0:\n            with torch.no_grad():\n                norm = attn.norm(dim=-1, keepdim=True)\n                scale = torch.clamp(norm / self.grad_clip, min=1.0)\n            attn = attn / scale\n        return self.norm(attn)\n\n    def surprise(self, residual: torch.Tensor) -> torch.Tensor:\n        return residual.norm(dim=-1, keepdim=True)\n\n    @torch.no_grad()\n    def update(\n        self,\n        *,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        error_signal: torch.Tensor | None = None,\n        lr: float = 1e-3,\n    ) -> None:\n        with torch.enable_grad():\n            key_detached = key.detach().requires_grad_(True)\n            prediction = self.forward(key_detached)\n            target = value.detach()\n            if error_signal is None:\n                loss = torch.mean((prediction - target) ** 2)\n            else:\n                loss = torch.mean(error_signal * prediction)\n        grads = torch.autograd.grad(loss, list(self.net.parameters()), retain_graph=False)\n        for param, grad in zip(self.net.parameters(), grads, strict=False):\n            if grad is None:\n                continue\n            param.add_(grad, alpha=-lr)\n\n    @torch.no_grad()\n    def apply_deltas(self, deltas: Dict[str, torch.Tensor], scale: float = 1.0) -> None:\n        for name, tensor in deltas.items():\n            target = dict(self.named_parameters()).get(name)\n            if target is None:\n                continue\n            target.add_(tensor, alpha=scale)\n"
  },
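  {
    "path": "examples/titan_memory_demo.py",
    "content": "\"\"\"Illustrative sketch (hypothetical file): writing into the simplified TITAN memory.\n\nAssumes only TitanMemory / TitanMemoryConfig from src/nested_learning/titan/memory.py;\nshapes, step count, and learning rate are arbitrary.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport torch\nimport torch.nn.functional as F\n\nfrom nested_learning.titan.memory import TitanMemory, TitanMemoryConfig\n\n\ndef main() -> None:\n    torch.manual_seed(0)\n    memory = TitanMemory(TitanMemoryConfig(dim=16, hidden_multiplier=2, layers=2))\n    memory.eval()  # skip the training-time output-norm clamp for a cleaner read-out\n\n    key = torch.randn(4, 8, 16)    # (batch, seq, dim) query/key tokens\n    value = torch.randn(4, 8, 16)  # target associations to store\n\n    before = F.mse_loss(memory(key), value).item()\n    for _ in range(50):\n        # Each call takes one gradient step on the memory MLP toward mapping key -> value.\n        memory.update(key=key, value=value, lr=1e-2)\n    after = F.mse_loss(memory(key), value).item()\n    print(f\"recall mse before={before:.4f} after={after:.4f}\")\n\n    # surprise() is just the per-token L2 norm of a residual.\n    residual = value - memory(key)\n    print(\"mean surprise:\", memory.surprise(residual).mean().item())\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },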
  {
    "path": "src/nested_learning/titan/model.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Dict, cast\n\nimport torch\nimport torch.nn as nn\n\nfrom ..backbones import AttentionConfig, SelfAttention\nfrom ..fast_state import (\n    AttentionKVCache,\n    BlockFastState,\n    ModelAttentionCache,\n    ModelFastState,\n    build_block_fast_state,\n)\nfrom ..functional import (\n    call_with_deltas,\n    call_with_params,\n    grads_to_dict,\n    params_with_deltas,\n    require_grad_params,\n)\nfrom ..hope.self_mod import SelfModifier\nfrom ..levels import LevelSpec\nfrom ..optim.manager import LevelConfig, LevelOptimizerManager\nfrom ..titan.memory import TitanMemory, TitanMemoryConfig\n\n\n@dataclass\nclass TitanOnlyModelConfig:\n    vocab_size: int\n    dim: int\n    num_layers: int\n    heads: int\n    titan_level: LevelSpec\n    optimizers: Dict[str, dict] | None = None\n    teach_scale: float = 1.0\n    teach_clip: float = 0.0\n    teach_schedule: Dict[str, float] | None = None\n    qk_l2_norm: bool = False\n    local_conv_window: int | None = None\n    titan_hidden_multiplier: int = 4\n    activation: str = \"gelu\"\n    self_mod_hidden: int = 4\n    self_mod_lr: float = 1e-3\n    surprise_threshold: float | None = None\n    surprise_metric: str = \"l2\"\n    freeze_backbone: bool = False\n\n\nclass TitanOnlyBlock(nn.Module):\n    def __init__(self, config: TitanOnlyModelConfig):\n        super().__init__()\n        self.config = config\n        self.surprise_threshold: float | None = None\n        self.surprise_metric: str = \"l2\"\n        self.enabled: bool = True\n        self.attn = SelfAttention(\n            AttentionConfig(\n                dim=config.dim,\n                heads=config.heads,\n                qk_l2_norm=config.qk_l2_norm,\n                local_conv_window=config.local_conv_window,\n            )\n        )\n        titan_config = TitanMemoryConfig(\n            dim=config.dim,\n            hidden_multiplier=config.titan_hidden_multiplier,\n            activation=config.activation,\n        )\n        self.titan_memory = TitanMemory(titan_config)\n        self.self_modifier = SelfModifier(config.dim, hidden_multiplier=config.self_mod_hidden)\n        self.dropout = nn.Dropout(0.0)\n        self.norm = nn.LayerNorm(config.dim)\n        level_config = LevelConfig(\n            specs=[config.titan_level],\n            optimizer_configs=config.optimizers or {},\n            default_lr=config.self_mod_lr,\n        )\n        self.level_manager = LevelOptimizerManager(level_config)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        *,\n        teach_signal: torch.Tensor | None = None,\n        surprise_value: float | None = None,\n        fast_state: BlockFastState | None = None,\n        attention_cache: AttentionKVCache | None = None,\n        return_attention_cache: bool = False,\n        differentiable_updates: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache]:\n        _ = differentiable_updates\n        next_attn_cache: AttentionKVCache | None = None\n        if return_attention_cache:\n            attn_out, next_attn_cache = self.attn(\n                x,\n                kv_cache=attention_cache,\n                return_kv_cache=True,\n            )\n        else:\n            attn_out = self.attn(x, kv_cache=attention_cache)\n        if fast_state is None:\n            mem_out = self.titan_memory(attn_out)\n        else:\n            if fast_state.titan_params is None:\n                raise 
ValueError(\n                    \"fast_state.titan_params is required for TitanOnlyBlock fast-state forward\"\n                )\n            mem_out = call_with_deltas(self.titan_memory, fast_state.titan_params, attn_out)\n        combined = attn_out + mem_out\n        if teach_signal is not None:\n            if fast_state is None:\n                self._update_titan(attn_out, mem_out, teach_signal, surprise_value)\n            else:\n                self._update_titan_fast(fast_state, attn_out, mem_out, teach_signal, surprise_value)\n        if fast_state is None:\n            self.level_manager.tick()\n        else:\n            fast_state.level_manager.tick()\n        out = self.norm(combined)\n        if return_attention_cache:\n            assert next_attn_cache is not None\n            return out, next_attn_cache\n        return out\n\n    def set_surprise_threshold(self, threshold: float | None) -> None:\n        self.surprise_threshold = threshold\n\n    def set_surprise_metric(self, metric: str) -> None:\n        self.surprise_metric = str(metric).strip().lower()\n\n    def set_enabled(self, enabled: bool) -> None:\n        self.enabled = enabled\n\n    def _passes_surprise(self, surprise_value: float | None) -> bool:\n        if self.surprise_threshold is None:\n            return True\n        if surprise_value is None:\n            return False\n        return surprise_value >= self.surprise_threshold\n\n    def _update_titan(\n        self,\n        attn_out: torch.Tensor,\n        mem_out: torch.Tensor,\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n    ) -> None:\n        level_name = self.config.titan_level.name\n        if not self.enabled:\n            return\n        if not self.level_manager.should_update(level_name):\n            return\n        if not self._passes_surprise(surprise_value):\n            return\n        # Use full sequence for granular updates (Critique P1)\n        # Note: We intentionally do not pool over dim=1 (sequence) here.\n        modifier = self.self_modifier(\n            key=attn_out.detach(),\n            value=mem_out.detach(),\n            error_signal=teach_signal.detach(),\n        )\n        context_vec = attn_out.detach().mean(dim=(0, 1))\n        with torch.enable_grad():\n            query = attn_out.detach()\n            target = (teach_signal.detach() + modifier).detach()\n            base_params = {name: param for name, param in self.titan_memory.named_parameters()}\n            params_req = require_grad_params(base_params)\n            prediction = call_with_params(self.titan_memory, params_req, query)\n            loss_terms = nn.functional.mse_loss(prediction, target, reduction=\"none\")\n            active = teach_signal.detach().abs().sum(dim=-1, keepdim=True) > 0\n            mask = active.float()\n            if self.surprise_threshold is not None and self.surprise_metric == \"l2\":\n                norms = teach_signal.norm(dim=-1, keepdim=True)\n                mask = mask * (norms >= self.surprise_threshold).float()\n            loss = (loss_terms * mask).sum() / mask.sum().clamp(min=1.0)\n\n        grads = torch.autograd.grad(\n            loss,\n            tuple(params_req.values()),\n            retain_graph=False,\n            allow_unused=True,\n        )\n        grads_dict = grads_to_dict(params_req, grads)\n        self.level_manager.apply_module_grads(\n            level_name,\n            self.titan_memory,\n            grads_dict,\n            context=context_vec,\n            
force=True,\n        )\n        # Pop metrics to avoid stale entries even if we do not log them yet.\n        self.level_manager.pop_last_metrics(level_name)\n\n    def _update_titan_fast(\n        self,\n        fast_state: BlockFastState,\n        attn_out: torch.Tensor,\n        mem_out: torch.Tensor,\n        teach_signal: torch.Tensor,\n        surprise_value: float | None,\n    ) -> None:\n        level_name = self.config.titan_level.name\n        if not self.enabled:\n            return\n        if not fast_state.level_manager.should_update(level_name):\n            return\n        if not self._passes_surprise(surprise_value):\n            return\n        if fast_state.titan_params is None:\n            return\n        modifier = self.self_modifier(\n            key=attn_out.detach(),\n            value=mem_out.detach(),\n            error_signal=teach_signal.detach(),\n        )\n        context_vec = attn_out.detach().mean(dim=(0, 1))\n        base_params = fast_state.titan_params\n        forward_params = params_with_deltas(self.titan_memory, base_params)\n        params_req = require_grad_params(forward_params)\n        with torch.enable_grad():\n            query = attn_out.detach()\n            target = (teach_signal.detach() + modifier).detach()\n            prediction = call_with_params(self.titan_memory, params_req, query)\n            loss_terms = nn.functional.mse_loss(prediction, target, reduction=\"none\")\n            active = teach_signal.detach().abs().sum(dim=-1, keepdim=True) > 0\n            mask = active.float()\n            if self.surprise_threshold is not None and self.surprise_metric == \"l2\":\n                norms = teach_signal.norm(dim=-1, keepdim=True)\n                mask = mask * (norms >= self.surprise_threshold).float()\n            loss = (loss_terms * mask).sum() / mask.sum().clamp(min=1.0)\n        grads = torch.autograd.grad(\n            loss,\n            tuple(params_req.values()),\n            retain_graph=False,\n            allow_unused=True,\n        )\n        grads_dict = grads_to_dict(params_req, grads)\n        updated, _magnitude = fast_state.level_manager.apply_grads(\n            level_name,\n            base_params,\n            grads_dict,\n            context=context_vec,\n            force=False,\n        )\n        fast_state.titan_params = updated\n        fast_state.level_manager.pop_last_metrics(level_name)\n\n\nclass TitanOnlyModel(nn.Module):\n    def __init__(self, config: TitanOnlyModelConfig):\n        super().__init__()\n        self.config = config\n        self.embed = nn.Embedding(config.vocab_size, config.dim)\n        self.blocks = nn.ModuleList([TitanOnlyBlock(config) for _ in range(config.num_layers)])\n        self.norm = nn.LayerNorm(config.dim)\n        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)\n        self.lm_head.weight = self.embed.weight\n        self._runtime_teach_scale = config.teach_scale\n        self._runtime_teach_clip = config.teach_clip\n        self._surprise_threshold: float | None = None\n        self._surprise_metric = \"l2\"\n        self._updates_enabled: bool = True\n        self.set_surprise_metric(config.surprise_metric)\n        self.set_surprise_threshold(config.surprise_threshold)\n        if config.freeze_backbone:\n            self.freeze_backbone()\n\n    def set_teach_runtime(self, *, scale: float | None = None, clip: float | None = None) -> None:\n        if scale is not None:\n            self._runtime_teach_scale = scale\n        if clip is not 
None:\n            self._runtime_teach_clip = clip\n\n    def set_surprise_threshold(self, threshold: float | None) -> None:\n        self._surprise_threshold = threshold\n        for block in self.blocks:\n            cast(TitanOnlyBlock, block).set_surprise_threshold(threshold)\n\n    def get_surprise_threshold(self) -> float | None:\n        return self._surprise_threshold\n\n    def set_surprise_metric(self, metric: str) -> None:\n        normalized = str(metric).strip().lower()\n        allowed = {\"l2\", \"loss\", \"logit_entropy\"}\n        if normalized not in allowed:\n            raise ValueError(\n                f\"Unsupported surprise_metric={metric!r}; expected one of {sorted(allowed)}\"\n            )\n        self._surprise_metric = normalized\n        for block in self.blocks:\n            cast(TitanOnlyBlock, block).set_surprise_metric(normalized)\n\n    def get_surprise_metric(self) -> str:\n        return self._surprise_metric\n\n    def set_allowed_update_levels(self, levels: set[str] | None) -> None:\n        enabled = True\n        if levels is not None and \"titan\" not in levels and len(levels) > 0:\n            enabled = False\n        self._updates_enabled = enabled\n        for block in self.blocks:\n            cast(TitanOnlyBlock, block).set_enabled(enabled)\n\n    def get_allowed_update_levels(self) -> set[str] | None:\n        if self._updates_enabled:\n            return {\"titan\"}\n        return set()\n\n    def forward(\n        self,\n        tokens: torch.Tensor,\n        *,\n        teach_signal: torch.Tensor | None = None,\n        fast_state: ModelFastState | None = None,\n        surprise_value: float | None = None,\n        attention_cache: ModelAttentionCache | None = None,\n        return_attention_cache: bool = False,\n        differentiable_updates: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, ModelAttentionCache]:\n        require_external = self._surprise_metric in {\"loss\", \"logit_entropy\"}\n        if require_external and self._surprise_threshold is not None:\n            if teach_signal is not None and surprise_value is None:\n                raise ValueError(\n                    f\"surprise_metric={self._surprise_metric} requires passing surprise_value \"\n                    \"when model.surprise_threshold is set.\"\n                )\n        x = self.embed(tokens)\n        if fast_state is not None and len(fast_state.blocks) != len(self.blocks):\n            raise ValueError(\"fast_state.blocks length does not match model.blocks\")\n        if attention_cache is not None and len(attention_cache.blocks) != len(self.blocks):\n            raise ValueError(\"attention_cache.blocks length does not match model.blocks\")\n        base_surprise = surprise_value\n        next_caches: list[AttentionKVCache | None] = []\n        for idx, block in enumerate(self.blocks):\n            scaled_signal = None\n            if teach_signal is not None:\n                scaled_signal = teach_signal * self._runtime_teach_scale\n                if self._runtime_teach_clip > 0:\n                    with torch.no_grad():\n                        norm = scaled_signal.norm(dim=-1, keepdim=True)\n                        scale = torch.clamp(norm / self._runtime_teach_clip, min=1.0)\n                    scaled_signal = scaled_signal / scale\n            block_surprise = base_surprise\n            if (\n                scaled_signal is not None\n                and base_surprise is None\n                and self._surprise_metric == \"l2\"\n      
      ):\n                block_surprise = float(scaled_signal.norm(dim=-1).mean().item())\n            block_state = None if fast_state is None else fast_state.blocks[idx]\n            block_cache = None if attention_cache is None else attention_cache.blocks[idx]\n            if return_attention_cache:\n                x, next_cache = block(  # type: ignore[arg-type]\n                    x,\n                    teach_signal=scaled_signal,\n                    surprise_value=block_surprise,\n                    fast_state=block_state,\n                    attention_cache=block_cache,\n                    return_attention_cache=True,\n                    differentiable_updates=differentiable_updates,\n                )\n                next_caches.append(next_cache)\n            else:\n                x = block(  # type: ignore[arg-type]\n                    x,\n                    teach_signal=scaled_signal,\n                    surprise_value=block_surprise,\n                    fast_state=block_state,\n                    attention_cache=block_cache,\n                    differentiable_updates=differentiable_updates,\n                )\n        x = self.norm(x)\n        logits = self.lm_head(x)\n        if return_attention_cache:\n            return logits, ModelAttentionCache(blocks=next_caches)\n        return logits\n\n    def freeze_backbone(self) -> None:\n        \"\"\"\n        Freeze shared transformer components; leave TITAN memory/trainable paths active.\n        \"\"\"\n        for p in self.embed.parameters():\n            p.requires_grad = False\n        for p in self.norm.parameters():\n            p.requires_grad = False\n        for p in self.lm_head.parameters():\n            p.requires_grad = False\n        for block in self.blocks:\n            typed_block = cast(TitanOnlyBlock, block)\n            for p in typed_block.attn.parameters():\n                p.requires_grad = False\n\n    def init_fast_state(self) -> ModelFastState:\n        states = []\n        for block in self.blocks:\n            typed_block = cast(TitanOnlyBlock, block)\n            specs = [typed_block.config.titan_level]\n            state = build_block_fast_state(\n                titan_module=typed_block.titan_memory,\n                cms_blocks={},\n                specs=specs,\n                optimizer_configs=typed_block.config.optimizers or {},\n                default_lr=typed_block.config.self_mod_lr,\n            )\n            states.append(state)\n        return ModelFastState(blocks=states)\n\n    def init_attention_cache(self) -> ModelAttentionCache:\n        return ModelAttentionCache(blocks=[None for _ in self.blocks])\n"
  },
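  {
    "path": "examples/titan_only_model_sketch.py",
    "content": "\"\"\"Illustrative sketch (hypothetical file): cache- and fast-state-based use of TitanOnlyModel.\n\nThese helpers assume a TitanOnlyModel instance that was built elsewhere (its config needs\na LevelSpec, whose constructor is not shown in this file) and exercise only methods defined\nin src/nested_learning/titan/model.py: init_attention_cache, init_fast_state, forward with\nattention_cache / fast_state, and freeze_backbone. Incremental decoding additionally assumes\nthe attention KV cache supports single-token queries.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport torch\n\nfrom nested_learning.fast_state import ModelFastState\nfrom nested_learning.titan.model import TitanOnlyModel\n\n\n@torch.no_grad()\ndef decode_greedy(model: TitanOnlyModel, prompt: torch.Tensor, new_tokens: int) -> torch.Tensor:\n    \"\"\"Greedy decoding that feeds only the newest token back in via the per-block KV cache.\"\"\"\n    model.eval()\n    cache = model.init_attention_cache()\n    tokens = prompt\n    step_input = prompt\n    for _ in range(new_tokens):\n        logits, cache = model(step_input, attention_cache=cache, return_attention_cache=True)\n        next_token = logits[:, -1:].argmax(dim=-1)\n        tokens = torch.cat([tokens, next_token], dim=1)\n        step_input = next_token\n    return tokens\n\n\ndef adapt_in_context(\n    model: TitanOnlyModel,\n    tokens: torch.Tensor,\n    teach_signal: torch.Tensor,\n) -> tuple[torch.Tensor, ModelFastState]:\n    \"\"\"Adapter-style update: the frozen backbone stays fixed while TITAN memories adapt\n    inside a per-sequence fast state rather than in the module parameters.\"\"\"\n    model.freeze_backbone()\n    fast_state = model.init_fast_state()\n    logits = model(tokens, teach_signal=teach_signal, fast_state=fast_state)\n    return logits, fast_state\n"
  },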
  {
    "path": "src/nested_learning/titan/self_modifying.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Callable\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.func import grad, vmap\n\n\n@dataclass(frozen=True)\nclass SelfModifyingTitansConfig:\n    dim: int\n    eta_scale: float = 1e-3\n    chunk_size_other: int = 1\n    chunk_size_memory: int | None = None\n    objective: str = \"l2\"\n    stopgrad_vhat: bool = True\n    use_rank1_precond: bool = True\n    use_alpha: bool = True\n    momentum: float = 0.0\n    qk_l2_norm: bool = True\n    adaptive_q: bool = False\n    use_skip: bool = True\n    local_conv_window: int | None = 4\n    eps: float = 1e-6\n\n    def __post_init__(self) -> None:\n        if self.dim <= 0:\n            raise ValueError(\"dim must be positive\")\n        if self.eta_scale <= 0:\n            raise ValueError(\"eta_scale must be positive\")\n        if self.chunk_size_other <= 0:\n            raise ValueError(\"chunk_size_other must be positive\")\n        if self.chunk_size_memory is not None and self.chunk_size_memory <= 0:\n            raise ValueError(\"chunk_size_memory must be positive\")\n        if self.objective not in {\"l2\", \"dot\"}:\n            raise ValueError(\"objective must be one of {'l2', 'dot'}\")\n        if not (0.0 <= self.momentum < 1.0):\n            raise ValueError(\"momentum must be in [0, 1)\")\n        if self.local_conv_window is not None and int(self.local_conv_window) <= 0:\n            raise ValueError(\"local_conv_window must be positive\")\n        if self.chunk_size_memory is None:\n            object.__setattr__(self, \"chunk_size_memory\", int(self.chunk_size_other))\n\n\n@dataclass\nclass ResidualMLPMemoryState:\n    w1: torch.Tensor\n    w2: torch.Tensor\n    w_skip: torch.Tensor | None = None\n    m_w1: torch.Tensor | None = None\n    m_w2: torch.Tensor | None = None\n    m_w_skip: torch.Tensor | None = None\n\n    def clone(self) -> \"ResidualMLPMemoryState\":\n        return ResidualMLPMemoryState(\n            w1=self.w1.detach().clone(),\n            w2=self.w2.detach().clone(),\n            w_skip=None if self.w_skip is None else self.w_skip.detach().clone(),\n            m_w1=None if self.m_w1 is None else self.m_w1.detach().clone(),\n            m_w2=None if self.m_w2 is None else self.m_w2.detach().clone(),\n            m_w_skip=None if self.m_w_skip is None else self.m_w_skip.detach().clone(),\n        )\n\n\n@dataclass\nclass SelfModifyingTitansState:\n    \"\"\"\n    Fast state for self-modifying Titans.\n\n    Each memory M_□ is a residual MLP (Eq. 
91) whose initial parameters are meta-learned\n    (stored in the module) and cloned into this fast state per context.\n    \"\"\"\n\n    k: ResidualMLPMemoryState\n    v: ResidualMLPMemoryState\n    q: ResidualMLPMemoryState\n    eta: ResidualMLPMemoryState\n    alpha: ResidualMLPMemoryState\n    memory: ResidualMLPMemoryState\n\n    def clone(self) -> \"SelfModifyingTitansState\":\n        return SelfModifyingTitansState(\n            k=self.k.clone(),\n            v=self.v.clone(),\n            q=self.q.clone(),\n            eta=self.eta.clone(),\n            alpha=self.alpha.clone(),\n            memory=self.memory.clone(),\n        )\n\n\nclass ResidualMLPMemory(nn.Module):\n    def __init__(\n        self,\n        *,\n        in_dim: int,\n        out_dim: int,\n        hidden_dim: int,\n        activation: Callable[[torch.Tensor], torch.Tensor],\n        use_skip: bool = True,\n    ) -> None:\n        super().__init__()\n        if in_dim <= 0 or out_dim <= 0 or hidden_dim <= 0:\n            raise ValueError(\"in_dim/out_dim/hidden_dim must be positive\")\n        self.in_dim = int(in_dim)\n        self.out_dim = int(out_dim)\n        self.hidden_dim = int(hidden_dim)\n        self.activation = activation\n        self.use_skip = bool(use_skip)\n        self.w2 = nn.Linear(self.in_dim, self.hidden_dim, bias=False)\n        self.w1 = nn.Linear(self.hidden_dim, self.out_dim, bias=False)\n        self.w_skip: nn.Linear | None = None\n        if self.use_skip and self.in_dim != self.out_dim:\n            self.w_skip = nn.Linear(self.in_dim, self.out_dim, bias=False)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        hidden = self.activation(self.w2(x))\n        out = self.w1(hidden)\n        if self.w_skip is not None:\n            return self.w_skip(x) + out\n        if out.shape[-1] == x.shape[-1]:\n            return x + out\n        return out\n\n\nclass SelfModifyingTitans(nn.Module):\n    \"\"\"\n    Self-modifying Titans (Nested Learning paper, Eqs. 83–93), correctness-first.\n\n    - Multiple memories: M_k, M_v, M_q, M_eta, M_alpha, M_memory.\n    - Each memory is a 2-layer residual MLP (Eq. 91).\n    - Updates are performed on fast state using chunked DGD-like rule (Eq. 
90/93).\n\n    Note: This implementation prioritizes semantic fidelity and testability over speed.\n    \"\"\"\n\n    def __init__(self, config: SelfModifyingTitansConfig):\n        super().__init__()\n        self.config = config\n        dim = config.dim\n        hidden = dim\n        act = F.gelu\n        self.local_conv: nn.Conv1d | None = None\n        if config.local_conv_window is not None:\n            window = int(config.local_conv_window)\n            self.local_conv = nn.Conv1d(\n                dim,\n                dim,\n                kernel_size=window,\n                groups=dim,\n                padding=0,\n                bias=False,\n            )\n        self.w_q = nn.Linear(dim, dim, bias=False)\n        self.m_k = ResidualMLPMemory(\n            in_dim=dim, out_dim=dim, hidden_dim=hidden, activation=act, use_skip=config.use_skip\n        )\n        self.m_v = ResidualMLPMemory(\n            in_dim=dim, out_dim=dim, hidden_dim=hidden, activation=act, use_skip=config.use_skip\n        )\n        self.m_q = ResidualMLPMemory(\n            in_dim=dim, out_dim=dim, hidden_dim=hidden, activation=act, use_skip=config.use_skip\n        )\n        self.m_eta = ResidualMLPMemory(\n            in_dim=dim, out_dim=1, hidden_dim=hidden, activation=act, use_skip=config.use_skip\n        )\n        self.m_alpha = ResidualMLPMemory(\n            in_dim=dim, out_dim=1, hidden_dim=hidden, activation=act, use_skip=config.use_skip\n        )\n        self.m_memory = ResidualMLPMemory(\n            in_dim=dim, out_dim=dim, hidden_dim=hidden, activation=act, use_skip=config.use_skip\n        )\n\n    def init_fast_state(self) -> SelfModifyingTitansState:\n        return SelfModifyingTitansState(\n            k=self._init_memory_state(self.m_k),\n            v=self._init_memory_state(self.m_v),\n            q=self._init_memory_state(self.m_q),\n            eta=self._init_memory_state(self.m_eta),\n            alpha=self._init_memory_state(self.m_alpha),\n            memory=self._init_memory_state(self.m_memory),\n        )\n\n    def apply_updates_inplace(\n        self,\n        x: torch.Tensor,\n        *,\n        chunk_size_other: int | None = None,\n        chunk_size_memory: int | None = None,\n    ) -> None:\n        \"\"\"\n        Apply the self-modifying update rule to the *module parameters* in-place.\n\n        This is intended to be called in an explicit \"update pass\" under `torch.no_grad()`\n        (e.g., after an outer backward), so we avoid mixing differentiable reads with\n        in-place writes during the same autograd graph.\n        \"\"\"\n        state = self.init_fast_state()\n        _out, updated = self.forward_with_updates(\n            x,\n            state,\n            chunk_size_other=chunk_size_other,\n            chunk_size_memory=chunk_size_memory,\n        )\n        self._load_state_mean_(updated)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        x = self._apply_local_conv(x)\n        q = self.m_q(x) if self.config.adaptive_q else self.w_q(x)\n        if self.config.qk_l2_norm:\n            q = F.normalize(q, dim=-1, eps=self.config.eps)\n        return self.m_memory(q)\n\n    def forward_with_state(\n        self,\n        x: torch.Tensor,\n        state: SelfModifyingTitansState,\n    ) -> torch.Tensor:\n        if x.ndim != 3:\n            raise ValueError(\"Expected x to have shape (B, T, D)\")\n        batch, _seq_len, dim = x.shape\n        if dim != self.config.dim:\n            raise 
ValueError(f\"Expected dim={self.config.dim}, got {dim}\")\n        state = self._ensure_batched_state(state, batch)\n        x = self._apply_local_conv(x)\n        q = (\n            self._memory_forward(x, state.q, meta=self.m_q)\n            if self.config.adaptive_q\n            else self.w_q(x)\n        )\n        if self.config.qk_l2_norm:\n            q = F.normalize(q, dim=-1, eps=self.config.eps)\n        return self._memory_forward(q, state.memory, meta=self.m_memory)\n\n    def forward_with_updates(\n        self,\n        x: torch.Tensor,\n        state: SelfModifyingTitansState,\n        *,\n        chunk_size_other: int | None = None,\n        chunk_size_memory: int | None = None,\n    ) -> tuple[torch.Tensor, SelfModifyingTitansState]:\n        if x.ndim != 3:\n            raise ValueError(\"Expected x to have shape (B, T, D)\")\n        batch, seq_len, dim = x.shape\n        if dim != self.config.dim:\n            raise ValueError(f\"Expected dim={self.config.dim}, got {dim}\")\n        state = self._ensure_batched_state(state, batch)\n        x = self._apply_local_conv(x)\n        other_chunk = int(\n            self.config.chunk_size_other if chunk_size_other is None else chunk_size_other\n        )\n        memory_chunk_cfg = self.config.chunk_size_memory\n        if memory_chunk_cfg is None:\n            memory_chunk_cfg = self.config.chunk_size_other\n        memory_chunk = int(memory_chunk_cfg if chunk_size_memory is None else chunk_size_memory)\n        if other_chunk <= 0 or memory_chunk <= 0:\n            raise ValueError(\"chunk sizes must be positive\")\n\n        outputs: list[torch.Tensor] = []\n        other_k: list[torch.Tensor] = []\n        other_v: list[torch.Tensor] = []\n        other_eta: list[torch.Tensor] = []\n        other_alpha: list[torch.Tensor] = []\n        memory_k: list[torch.Tensor] = []\n        memory_v: list[torch.Tensor] = []\n        memory_eta: list[torch.Tensor] = []\n        memory_alpha: list[torch.Tensor] = []\n\n        def _next_boundary(idx: int, *, chunk_size: int) -> int:\n            if chunk_size <= 0:\n                raise ValueError(\"chunk_size must be positive\")\n            return min(((idx // chunk_size) + 1) * chunk_size, seq_len)\n\n        with torch.no_grad():\n            idx = 0\n            while idx < seq_len:\n                next_other = _next_boundary(idx, chunk_size=other_chunk)\n                next_memory = _next_boundary(idx, chunk_size=memory_chunk)\n                end = min(next_other, next_memory, seq_len)\n                x_chunk = x[:, idx:end, :]\n\n                k_chunk = self._memory_forward(x_chunk, state.k)\n                v_chunk = self._memory_forward(x_chunk, state.v)\n                q_chunk = (\n                    self._memory_forward(x_chunk, state.q)\n                    if self.config.adaptive_q\n                    else self.w_q(x_chunk)\n                )\n                if self.config.qk_l2_norm:\n                    k_chunk = F.normalize(k_chunk, dim=-1, eps=self.config.eps)\n                    q_chunk = F.normalize(q_chunk, dim=-1, eps=self.config.eps)\n                eta_chunk = self._memory_forward(x_chunk, state.eta).squeeze(-1)\n                eta_chunk = F.softplus(eta_chunk) * self.config.eta_scale\n                if self.config.use_alpha:\n                    alpha_chunk = self._memory_forward(x_chunk, state.alpha).squeeze(-1)\n                    alpha_chunk = torch.sigmoid(alpha_chunk)\n                else:\n                    alpha_chunk = 
torch.ones_like(eta_chunk)\n                o_chunk = self._memory_forward(q_chunk, state.memory)\n                outputs.append(o_chunk)\n\n                other_k.append(k_chunk)\n                other_v.append(v_chunk)\n                other_eta.append(eta_chunk)\n                other_alpha.append(alpha_chunk)\n                memory_k.append(k_chunk)\n                memory_v.append(v_chunk)\n                memory_eta.append(eta_chunk)\n                memory_alpha.append(alpha_chunk)\n\n                idx = end\n\n                if idx == next_other and other_k:\n                    other_memories: tuple[str, ...] = (\"k\", \"v\", \"eta\")\n                    if self.config.adaptive_q:\n                        other_memories = (*other_memories, \"q\")\n                    if self.config.use_alpha:\n                        other_memories = (*other_memories, \"alpha\")\n                    self._apply_chunk_update_seq(\n                        state,\n                        k_seq=torch.cat(other_k, dim=1),\n                        v_seq=torch.cat(other_v, dim=1),\n                        eta_seq=torch.cat(other_eta, dim=1),\n                        alpha_seq=torch.cat(other_alpha, dim=1),\n                        memories=other_memories,\n                    )\n                    other_k.clear()\n                    other_v.clear()\n                    other_eta.clear()\n                    other_alpha.clear()\n\n                if idx == next_memory and memory_k:\n                    self._apply_chunk_update_seq(\n                        state,\n                        k_seq=torch.cat(memory_k, dim=1),\n                        v_seq=torch.cat(memory_v, dim=1),\n                        eta_seq=torch.cat(memory_eta, dim=1),\n                        alpha_seq=torch.cat(memory_alpha, dim=1),\n                        memories=(\"memory\",),\n                    )\n                    memory_k.clear()\n                    memory_v.clear()\n                    memory_eta.clear()\n                    memory_alpha.clear()\n\n            if other_k:\n                other_memories = (\"k\", \"v\", \"eta\")\n                if self.config.adaptive_q:\n                    other_memories = (*other_memories, \"q\")\n                if self.config.use_alpha:\n                    other_memories = (*other_memories, \"alpha\")\n                self._apply_chunk_update_seq(\n                    state,\n                    k_seq=torch.cat(other_k, dim=1),\n                    v_seq=torch.cat(other_v, dim=1),\n                    eta_seq=torch.cat(other_eta, dim=1),\n                    alpha_seq=torch.cat(other_alpha, dim=1),\n                    memories=other_memories,\n                )\n            if memory_k:\n                self._apply_chunk_update_seq(\n                    state,\n                    k_seq=torch.cat(memory_k, dim=1),\n                    v_seq=torch.cat(memory_v, dim=1),\n                    eta_seq=torch.cat(memory_eta, dim=1),\n                    alpha_seq=torch.cat(memory_alpha, dim=1),\n                    memories=(\"memory\",),\n                )\n\n        return torch.cat(outputs, dim=1), state\n\n    def _apply_local_conv(self, x: torch.Tensor) -> torch.Tensor:\n        if self.local_conv is None:\n            return x\n        if x.ndim != 3:\n            raise ValueError(\"Expected x to have shape (B, T, D)\")\n        kernel = int(self.local_conv.kernel_size[0])\n        # Causal depthwise conv: only attends to past tokens.\n        x_t = x.transpose(1, 2)\n  
      x_t = F.pad(x_t, (kernel - 1, 0))\n        x_t = self.local_conv(x_t)\n        return x_t.transpose(1, 2)\n\n    def _load_state_mean_(self, state: SelfModifyingTitansState) -> None:\n        def _mean_weight(weight: torch.Tensor) -> torch.Tensor:\n            return weight.mean(dim=0) if weight.ndim == 3 else weight\n\n        def _copy(module: ResidualMLPMemory, mem: ResidualMLPMemoryState) -> None:\n            module.w1.weight.copy_(_mean_weight(mem.w1))\n            module.w2.weight.copy_(_mean_weight(mem.w2))\n            if module.w_skip is None:\n                return\n            if mem.w_skip is None:\n                raise RuntimeError(\"Expected w_skip state for projected residual memory\")\n            module.w_skip.weight.copy_(_mean_weight(mem.w_skip))\n\n        with torch.no_grad():\n            _copy(self.m_k, state.k)\n            _copy(self.m_v, state.v)\n            _copy(self.m_eta, state.eta)\n            if self.config.use_alpha:\n                _copy(self.m_alpha, state.alpha)\n            _copy(self.m_memory, state.memory)\n            if self.config.adaptive_q:\n                _copy(self.m_q, state.q)\n\n    def _apply_chunk_update(\n        self,\n        state: SelfModifyingTitansState,\n        buffer: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],\n        *,\n        memories: tuple[str, ...],\n    ) -> None:\n        if not buffer:\n            return\n        k_seq = torch.stack([item[0] for item in buffer], dim=1)\n        v_seq = torch.stack([item[1] for item in buffer], dim=1)\n        eta_seq = torch.stack([item[2] for item in buffer], dim=1)\n        alpha_seq = torch.stack([item[3] for item in buffer], dim=1)\n        self._apply_chunk_update_seq(\n            state,\n            k_seq=k_seq,\n            v_seq=v_seq,\n            eta_seq=eta_seq,\n            alpha_seq=alpha_seq,\n            memories=memories,\n        )\n\n    def _apply_chunk_update_seq(\n        self,\n        state: SelfModifyingTitansState,\n        *,\n        k_seq: torch.Tensor,\n        v_seq: torch.Tensor,\n        eta_seq: torch.Tensor,\n        alpha_seq: torch.Tensor,\n        memories: tuple[str, ...],\n    ) -> None:\n        steps = k_seq.size(1)\n        dim = self.config.dim\n        eye = (\n            torch.eye(dim, device=k_seq.device, dtype=k_seq.dtype)\n            .unsqueeze(0)\n            .expand(k_seq.size(0), -1, -1)\n        )\n\n        boundary: dict[str, ResidualMLPMemoryState] = {\n            name: getattr(state, name).clone() for name in memories\n        }\n        grads = {name: self._memory_grads_chunk(boundary[name], k_seq, v_seq) for name in memories}\n\n        for t in range(steps):\n            k_t = k_seq[:, t, :]\n            eta_t = eta_seq[:, t]\n            alpha_t = alpha_seq[:, t]\n            kk = torch.einsum(\"bi,bj->bij\", k_t, k_t)\n            precond = alpha_t[:, None, None] * eye - eta_t[:, None, None] * kk\n            for name in memories:\n                fast = getattr(state, name)\n                g1, g2, gskip = grads[name]\n                self._apply_param_update(\n                    fast,\n                    (\n                        g1[:, t, ...],\n                        g2[:, t, ...],\n                        None if gskip is None else gskip[:, t, ...],\n                    ),\n                    eta_t,\n                    alpha_t,\n                    precond,\n                )\n\n    def _memory_grads(\n        self,\n        frozen: ResidualMLPMemoryState,\n        k_t: 
torch.Tensor,\n        v_t: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:\n        with torch.enable_grad():\n            w1 = frozen.w1.detach().requires_grad_(True)\n            w2 = frozen.w2.detach().requires_grad_(True)\n            w_skip = None\n            if frozen.w_skip is not None:\n                w_skip = frozen.w_skip.detach().requires_grad_(True)\n\n            pred = self._memory_forward(k_t, ResidualMLPMemoryState(w1=w1, w2=w2, w_skip=w_skip))\n            vhat = self._memory_forward(v_t, ResidualMLPMemoryState(w1=w1, w2=w2, w_skip=w_skip))\n            if self.config.stopgrad_vhat:\n                vhat = vhat.detach()\n\n            if self.config.objective == \"dot\":\n                loss = -(pred * vhat).sum(dim=-1)\n            else:\n                loss = (pred - vhat).pow(2).sum(dim=-1)\n            loss_scalar = loss.sum()\n\n            grads = torch.autograd.grad(\n                loss_scalar,\n                (w1, w2, w_skip) if w_skip is not None else (w1, w2),\n                retain_graph=False,\n                create_graph=False,\n                allow_unused=False,\n            )\n        if w_skip is None:\n            g1, g2 = grads\n            return g1, g2, None\n        g1, g2, gskip = grads\n        return g1, g2, gskip\n\n    def _memory_grads_chunk(\n        self,\n        frozen: ResidualMLPMemoryState,\n        k_seq: torch.Tensor,\n        v_seq: torch.Tensor,\n    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:\n        \"\"\"\n        Compute per-token gradients for an entire chunk in parallel (paper §8.2).\n\n        Returns gradients with leading shape (B, T, ...).\n        \"\"\"\n        w1 = frozen.w1.detach()\n        w2 = frozen.w2.detach()\n        w_skip = None if frozen.w_skip is None else frozen.w_skip.detach()\n\n        k_tokens = k_seq.transpose(0, 1)\n        v_tokens = v_seq.transpose(0, 1)\n\n        if w_skip is None:\n\n            def loss_fn_noskip(\n                w1_t: torch.Tensor,\n                w2_t: torch.Tensor,\n                k_t: torch.Tensor,\n                v_t: torch.Tensor,\n            ) -> torch.Tensor:\n                mem = ResidualMLPMemoryState(w1=w1_t, w2=w2_t)\n                pred = self._memory_forward(k_t, mem)\n                vhat = self._memory_forward(v_t, mem)\n                if self.config.stopgrad_vhat:\n                    vhat = vhat.detach()\n                if self.config.objective == \"dot\":\n                    loss = -(pred * vhat).sum(dim=-1)\n                else:\n                    loss = (pred - vhat).pow(2).sum(dim=-1)\n                return loss.sum()\n\n            grad_fn = grad(loss_fn_noskip, argnums=(0, 1))\n            g1_tokens, g2_tokens = vmap(grad_fn, in_dims=(None, None, 0, 0))(\n                w1,\n                w2,\n                k_tokens,\n                v_tokens,\n            )\n            return g1_tokens.transpose(0, 1), g2_tokens.transpose(0, 1), None\n\n        def loss_fn_skip(\n            w1_t: torch.Tensor,\n            w2_t: torch.Tensor,\n            w_skip_t: torch.Tensor,\n            k_t: torch.Tensor,\n            v_t: torch.Tensor,\n        ) -> torch.Tensor:\n            mem = ResidualMLPMemoryState(w1=w1_t, w2=w2_t, w_skip=w_skip_t)\n            pred = self._memory_forward(k_t, mem)\n            vhat = self._memory_forward(v_t, mem)\n            if self.config.stopgrad_vhat:\n                vhat = vhat.detach()\n            if self.config.objective == \"dot\":\n                
loss = -(pred * vhat).sum(dim=-1)\n            else:\n                loss = (pred - vhat).pow(2).sum(dim=-1)\n            return loss.sum()\n\n        grad_fn = grad(loss_fn_skip, argnums=(0, 1, 2))\n        g1_tokens, g2_tokens, gskip_tokens = vmap(\n            grad_fn,\n            in_dims=(None, None, None, 0, 0),\n        )(w1, w2, w_skip, k_tokens, v_tokens)\n        return (\n            g1_tokens.transpose(0, 1),\n            g2_tokens.transpose(0, 1),\n            gskip_tokens.transpose(0, 1),\n        )\n\n    def _apply_param_update(\n        self,\n        fast: ResidualMLPMemoryState,\n        grads: tuple[torch.Tensor, torch.Tensor, torch.Tensor | None],\n        eta_t: torch.Tensor,\n        alpha_t: torch.Tensor,\n        precond: torch.Tensor,\n    ) -> None:\n        g1, g2, gskip = grads\n        g1 = self._apply_momentum(fast, \"m_w1\", g1)\n        g2 = self._apply_momentum(fast, \"m_w2\", g2)\n        if self.config.use_rank1_precond:\n            fast.w2 = torch.matmul(fast.w2, precond) - eta_t[:, None, None] * g2\n        else:\n            fast.w2 = alpha_t[:, None, None] * fast.w2 - eta_t[:, None, None] * g2\n        fast.w1 = alpha_t[:, None, None] * fast.w1 - eta_t[:, None, None] * g1\n\n        if fast.w_skip is None:\n            return\n        if gskip is None:\n            raise RuntimeError(\"Expected w_skip grad to be present\")\n        gskip = self._apply_momentum(fast, \"m_w_skip\", gskip)\n        if self.config.use_rank1_precond:\n            fast.w_skip = torch.matmul(fast.w_skip, precond) - eta_t[:, None, None] * gskip\n        else:\n            fast.w_skip = alpha_t[:, None, None] * fast.w_skip - eta_t[:, None, None] * gskip\n\n    def _apply_momentum(\n        self,\n        fast: ResidualMLPMemoryState,\n        attr_name: str,\n        grad: torch.Tensor,\n    ) -> torch.Tensor:\n        beta = float(self.config.momentum)\n        if beta <= 0.0:\n            return grad\n        buf = getattr(fast, attr_name)\n        if buf is None:\n            buf = torch.zeros_like(grad)\n        buf = beta * buf + grad\n        setattr(fast, attr_name, buf)\n        return buf\n\n    def _init_memory_state(self, module: ResidualMLPMemory) -> ResidualMLPMemoryState:\n        skip = None if module.w_skip is None else module.w_skip.weight.detach().clone()\n        return ResidualMLPMemoryState(\n            w1=module.w1.weight.detach().clone(),\n            w2=module.w2.weight.detach().clone(),\n            w_skip=skip,\n        )\n\n    def _ensure_batched_state(\n        self, state: SelfModifyingTitansState, batch: int\n    ) -> SelfModifyingTitansState:\n        if state.k.w1.ndim == 2:\n            return SelfModifyingTitansState(\n                k=self._expand_memory_state(state.k, batch),\n                v=self._expand_memory_state(state.v, batch),\n                q=self._expand_memory_state(state.q, batch),\n                eta=self._expand_memory_state(state.eta, batch),\n                alpha=self._expand_memory_state(state.alpha, batch),\n                memory=self._expand_memory_state(state.memory, batch),\n            )\n        if state.k.w1.ndim != 3:\n            raise ValueError(\"SelfModifyingTitansState weights must be 2D or 3D tensors\")\n        if state.k.w1.size(0) != batch:\n            raise ValueError(\n                f\"State batch mismatch: expected batch={batch}, got {state.k.w1.size(0)}\"\n            )\n        return state\n\n    def _expand_memory_state(\n        self, mem: ResidualMLPMemoryState, batch: int\n    ) -> 
ResidualMLPMemoryState:\n        def _expand(t: torch.Tensor) -> torch.Tensor:\n            return t.detach().clone().unsqueeze(0).repeat(batch, 1, 1)\n\n        def _expand_opt(t: torch.Tensor | None) -> torch.Tensor | None:\n            return None if t is None else _expand(t)\n\n        return ResidualMLPMemoryState(\n            w1=_expand(mem.w1),\n            w2=_expand(mem.w2),\n            w_skip=_expand_opt(mem.w_skip),\n            m_w1=_expand_opt(mem.m_w1),\n            m_w2=_expand_opt(mem.m_w2),\n            m_w_skip=_expand_opt(mem.m_w_skip),\n        )\n\n    def _memory_forward(\n        self,\n        x: torch.Tensor,\n        mem: ResidualMLPMemoryState,\n        *,\n        meta: ResidualMLPMemory | None = None,\n    ) -> torch.Tensor:\n        if meta is None:\n            w2 = mem.w2\n            w1 = mem.w1\n            w_skip = mem.w_skip\n        else:\n            w2 = self._straight_through_meta(mem.w2, meta.w2.weight)\n            w1 = self._straight_through_meta(mem.w1, meta.w1.weight)\n            w_skip = None\n            if mem.w_skip is not None:\n                if meta.w_skip is None:\n                    raise RuntimeError(\"Expected meta w_skip for projected residual memory\")\n                w_skip = self._straight_through_meta(mem.w_skip, meta.w_skip.weight)\n        if x.ndim == 2:\n            x_seq = x.unsqueeze(1)\n            squeeze = True\n        else:\n            x_seq = x\n            squeeze = False\n        w2_t = w2.transpose(-1, -2)\n        hidden = torch.matmul(x_seq, w2_t)\n        hidden = F.gelu(hidden)\n        w1_t = w1.transpose(-1, -2)\n        out = torch.matmul(hidden, w1_t)\n        if w_skip is not None:\n            w_skip_t = w_skip.transpose(-1, -2)\n            out = out + torch.matmul(x_seq, w_skip_t)\n        elif out.size(-1) == x_seq.size(-1):\n            out = out + x_seq\n        if squeeze:\n            return out.squeeze(1)\n        return out\n\n    @staticmethod\n    def _straight_through_meta(fast: torch.Tensor, meta: torch.Tensor) -> torch.Tensor:\n        if meta.ndim > fast.ndim:\n            raise ValueError(\"meta tensor must have <= fast tensor rank\")\n        expanded = meta\n        while expanded.ndim < fast.ndim:\n            expanded = expanded.unsqueeze(0)\n        return fast + (expanded - expanded.detach())\n"
  },
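  {
    "path": "examples/chunk_parallel_grads_sketch.py",
    "content": "\"\"\"Illustrative sketch (hypothetical example file, not part of the package).\n\nA minimal toy mirror of two mechanisms from the self-modifying Titans code: (1) chunk-parallel per-token gradients via torch.func vmap+grad with fast weights frozen at the chunk boundary, as in `_memory_grads_chunk` (paper §8.2), and (2) the per-token rank-1 preconditioned write W <- W(alpha*I - eta*k k^T) - eta*g, as in `_apply_chunk_update_seq` / `_apply_param_update`. All shapes, names, and constants here are toy assumptions, not the real configuration.\n\"\"\"\n\nimport torch\nfrom torch.func import grad, vmap\n\n\ndef mse_loss(w: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:\n    # Toy linear memory in place of the residual MLP: predict v from k.\n    pred = torch.einsum(\"bi,boi->bo\", k, w)\n    return (pred - v).pow(2).sum()\n\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    batch, steps, dim = 2, 4, 8\n    w = torch.randn(batch, dim, dim)  # boundary-frozen fast weights\n    k_seq = torch.randn(batch, steps, dim)\n    v_seq = torch.randn(batch, steps, dim)\n    eta = torch.full((batch,), 0.05)  # inner step size\n    alpha = torch.full((batch,), 0.99)  # retention gate\n\n    # (1) Per-token grads for the whole chunk in parallel: vmap over the\n    # token axis while the weights stay frozen -> shape (T, B, dim, dim).\n    grad_fn = grad(mse_loss, argnums=0)\n    g_tokens = vmap(grad_fn, in_dims=(None, 0, 0))(\n        w, k_seq.transpose(0, 1), v_seq.transpose(0, 1)\n    )\n    g_seq = g_tokens.transpose(0, 1)  # (B, T, dim, dim)\n\n    # (2) Sequential rank-1 preconditioned writes using the frozen grads.\n    eye = torch.eye(dim).unsqueeze(0).expand(batch, -1, -1)\n    fast_w = w.clone()\n    for t in range(steps):\n        k_t = k_seq[:, t]\n        kk = torch.einsum(\"bi,bj->bij\", k_t, k_t)\n        precond = alpha[:, None, None] * eye - eta[:, None, None] * kk\n        fast_w = torch.matmul(fast_w, precond) - eta[:, None, None] * g_seq[:, t]\n    print(fast_w.shape)  # torch.Size([2, 8, 8])\n"
  },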
  {
    "path": "src/nested_learning/tokenizer.py",
    "content": "from __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import Sequence\n\nimport sentencepiece as spm\nimport torch\n\n\nclass SentencePieceTokenizer:\n    def __init__(self, model_path: str | Path):\n        self.processor = spm.SentencePieceProcessor(model_file=str(model_path))\n\n    @property\n    def vocab_size(self) -> int:\n        return self.processor.vocab_size()\n\n    def encode(self, text: str, add_bos: bool = False, add_eos: bool = True) -> torch.Tensor:\n        tokens: list[int] = []\n        if add_bos:\n            tokens.append(self.processor.bos_id())\n        tokens.extend(self.processor.encode(text))\n        if add_eos:\n            tokens.append(self.processor.eos_id())\n        return torch.tensor(tokens, dtype=torch.long)\n\n    def batch_encode(self, texts: Sequence[str]) -> list[torch.Tensor]:\n        return [self.encode(text) for text in texts]\n"
  },
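  {
    "path": "examples/tokenizer_usage_sketch.py",
    "content": "\"\"\"Illustrative usage sketch (hypothetical example file, not part of the package).\n\nShows how `SentencePieceTokenizer` wraps a trained SentencePiece model: `encode` returns a 1-D LongTensor, optionally framed by the model's BOS/EOS ids. The model path below is a placeholder; point it at a real .model file before running.\n\"\"\"\n\nfrom nested_learning.tokenizer import SentencePieceTokenizer\n\nif __name__ == \"__main__\":\n    tok = SentencePieceTokenizer(\"artifacts/tokenizer/spm.model\")  # placeholder path\n    ids = tok.encode(\"nested learning\", add_bos=False, add_eos=True)\n    print(tok.vocab_size, ids.tolist())\n    print([t.shape for t in tok.batch_encode([\"hello\", \"hope titans\"])])\n"
  },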
  {
    "path": "src/nested_learning/tokenizer_coverage.py",
    "content": "from __future__ import annotations\n\nfrom collections import Counter\nfrom pathlib import Path\nfrom typing import Dict\n\nfrom .tokenizer import SentencePieceTokenizer\n\n\ndef compute_tokenizer_coverage_stats(\n    tokenizer_path: Path,\n    sample_file: Path,\n    max_lines: int = 10_000,\n) -> Dict[str, object]:\n    \"\"\"\n    Compute tokenizer coverage statistics on a representative text sample.\n\n    Returns a JSON-serialisable dictionary; shared by both the coverage CLI and\n    the regression guard so they cannot drift apart silently.\n    \"\"\"\n\n    tokenizer = SentencePieceTokenizer(tokenizer_path)\n    total_words = 0\n    total_tokens = 0\n    total_chars = 0\n    processed_lines = 0\n    word_token_lengths: list[int] = []\n    piece_lengths: Counter[int] = Counter()\n\n    with sample_file.open(\"r\", encoding=\"utf-8\") as handle:\n        for idx, line in enumerate(handle):\n            if idx >= max_lines:\n                break\n            stripped = line.strip()\n            if not stripped:\n                continue\n            processed_lines += 1\n            total_chars += len(stripped)\n            words = stripped.split()\n            if not words:\n                continue\n            total_words += len(words)\n            encoded = tokenizer.encode(stripped, add_bos=False, add_eos=False)\n            ids = encoded.tolist()\n            total_tokens += len(ids)\n            for word in words:\n                word_tokens = tokenizer.encode(word, add_bos=False, add_eos=False).tolist()\n                if not word_tokens:\n                    continue\n                word_token_lengths.append(len(word_tokens))\n            for token_id in ids:\n                piece = tokenizer.processor.id_to_piece(token_id)\n                piece_lengths[len(piece)] += 1\n\n    if total_words == 0 or not word_token_lengths:\n        raise ValueError(\"Sample produced no words; double-check the sample_file path.\")\n\n    avg_tokens_per_word = total_tokens / total_words if total_words else 0.0\n    pct_single_token = sum(1 for length in word_token_lengths if length == 1) / len(\n        word_token_lengths\n    )\n    pct_two_or_less = sum(1 for length in word_token_lengths if length <= 2) / len(\n        word_token_lengths\n    )\n\n    return {\n        \"tokenizer\": str(tokenizer_path),\n        \"sample_file\": str(sample_file),\n        \"lines_processed\": processed_lines,\n        \"total_words\": total_words,\n        \"total_tokens\": total_tokens,\n        \"avg_tokens_per_word\": avg_tokens_per_word,\n        \"pct_single_token_words\": pct_single_token,\n        \"pct_two_or_less_tokens_words\": pct_two_or_less,\n        \"avg_chars_per_word\": total_chars / total_words,\n        \"piece_length_histogram\": dict(piece_lengths.most_common(20)),\n    }\n"
  },
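  {
    "path": "examples/teach_signal_identity_sketch.py",
    "content": "\"\"\"Illustrative sketch (hypothetical example file, not part of the package).\n\nNumerically checks the identity behind `compute_teach_signal` in src/nested_learning/training.py: for mean next-token cross-entropy through a linear LM head, dL/dh at the pre-head hidden states equals (softmax(logits) - onehot(target)) @ W / n_targets, with the final (unsupervised) position zeroed. All shapes below are toy assumptions.\n\"\"\"\n\nimport torch\n\nif __name__ == \"__main__\":\n    torch.manual_seed(0)\n    batch, seq, dim, vocab = 2, 5, 8, 11\n    head = torch.nn.Linear(dim, vocab, bias=False)\n    h = torch.randn(batch, seq, dim, requires_grad=True)\n    tokens = torch.randint(0, vocab, (batch, seq))\n\n    logits = head(h)\n    loss = torch.nn.functional.cross_entropy(\n        logits[:, :-1].reshape(-1, vocab), tokens[:, 1:].reshape(-1)\n    )\n    (autograd_signal,) = torch.autograd.grad(loss, h)\n\n    # Closed form: softmax residual on supervised positions, zero on the last.\n    residual = torch.softmax(logits.detach(), dim=-1)\n    residual[:, -1] = 0.0\n    residual[:, :-1].scatter_add_(\n        -1, tokens[:, 1:].unsqueeze(-1), -torch.ones(batch, seq - 1, 1)\n    )\n    closed_form = residual @ head.weight.detach() / (batch * (seq - 1))\n    print(torch.allclose(autograd_signal, closed_form, atol=1e-6))  # True\n"
  },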
  {
    "path": "src/nested_learning/training.py",
    "content": "from __future__ import annotations\n\nimport base64\nimport json\nimport os\nimport pickle\nimport random\nfrom contextlib import nullcontext\nfrom dataclasses import dataclass\nfrom hashlib import sha256\nfrom pathlib import Path\nfrom typing import Dict, Iterator, Protocol, Tuple, cast\n\nimport numpy as np\nimport torch\nfrom omegaconf import DictConfig, OmegaConf\nfrom torch.utils.data import DataLoader, DistributedSampler, IterableDataset\n\nfrom .data import (\n    MixtureShardDataset,\n    ShardSourceConfig,\n    SyntheticTextConfig,\n    SyntheticTextDataset,\n    TokenShardDataset,\n    collate_batch,\n)\nfrom .levels import LevelSpec\nfrom .logging_utils import BaseLogger, NullLogger, init_logger\nfrom .model import HOPEModel, ModelConfig\nfrom .optim.m3 import M3\nfrom .titan.model import TitanOnlyModel, TitanOnlyModelConfig\n\n\n@dataclass\nclass DistributedContext:\n    rank: int\n    world_size: int\n    device: torch.device\n\n\ndef unwrap_config(cfg: DictConfig) -> DictConfig:\n    \"\"\"Hydra can wrap grouped configs (e.g., hope/pilot) under the group name.\"\"\"\n    if \"model\" in cfg:\n        return cfg\n    if \"hope\" in cfg:\n        return cast(DictConfig, cfg.hope)\n    if \"ablations\" in cfg:\n        return cast(DictConfig, cfg.ablations)\n    return cfg\n\n\ndef build_model_from_cfg(model_cfg: DictConfig) -> torch.nn.Module:\n    model_type = model_cfg.get(\"type\", \"hope\")\n    optimizer_cfg: Dict[str, dict] = {}\n    if \"optimizers\" in model_cfg:\n        optimizer_cfg = cast(\n            Dict[str, dict],\n            OmegaConf.to_container(model_cfg.optimizers, resolve=True),\n        )\n    teach_scale = model_cfg.get(\"teach_scale\", 1.0)\n    teach_clip = model_cfg.get(\"teach_clip\", 0.0)\n    teach_schedule: Dict[str, float] = {}\n    if \"teach_schedule\" in model_cfg:\n        teach_schedule = cast(\n            Dict[str, float],\n            OmegaConf.to_container(model_cfg.teach_schedule, resolve=True),\n        )\n    qk_l2_norm = bool(model_cfg.get(\"qk_l2_norm\", False))\n    local_conv_window_raw = model_cfg.get(\"local_conv_window\")\n    local_conv_window = None if local_conv_window_raw is None else int(local_conv_window_raw)\n    surprise_threshold_raw = model_cfg.get(\"surprise_threshold\")\n    surprise_threshold = (\n        None if surprise_threshold_raw is None else float(surprise_threshold_raw)\n    )\n    surprise_metric = str(model_cfg.get(\"surprise_metric\", \"l2\"))\n    cms_use_layernorm = bool(model_cfg.get(\"cms_use_layernorm\", True))\n    if model_type == \"titan\":\n        titan_spec = LevelSpec(**model_cfg.titan_level)\n        titan_cfg = TitanOnlyModelConfig(\n            vocab_size=model_cfg.vocab_size,\n            dim=model_cfg.dim,\n            num_layers=model_cfg.num_layers,\n            heads=model_cfg.heads,\n            titan_level=titan_spec,\n            optimizers=optimizer_cfg,\n            teach_scale=teach_scale,\n            teach_clip=teach_clip,\n            teach_schedule=teach_schedule,\n            qk_l2_norm=qk_l2_norm,\n            local_conv_window=local_conv_window,\n            surprise_threshold=surprise_threshold,\n            surprise_metric=surprise_metric,\n            freeze_backbone=model_cfg.get(\"freeze_backbone\", False),\n            self_mod_lr=float(model_cfg.get(\"self_mod_lr\", 1e-3)),\n            self_mod_hidden=int(model_cfg.get(\"self_mod_hidden\", 4)),\n        )\n        return TitanOnlyModel(titan_cfg)\n    titan_spec = 
LevelSpec(**model_cfg.titan_level)\n    cms_specs = [LevelSpec(**entry) for entry in model_cfg.cms_levels]\n    self_mod_chunk_size_memory_raw = model_cfg.get(\"self_mod_chunk_size_memory\")\n    self_mod_chunk_size_memory = (\n        None if self_mod_chunk_size_memory_raw is None else int(self_mod_chunk_size_memory_raw)\n    )\n    self_mod_local_conv_window_raw = model_cfg.get(\"self_mod_local_conv_window\", 4)\n    self_mod_local_conv_window = (\n        None if self_mod_local_conv_window_raw is None else int(self_mod_local_conv_window_raw)\n    )\n    hope_cfg = ModelConfig(\n        vocab_size=model_cfg.vocab_size,\n        dim=model_cfg.dim,\n        num_layers=model_cfg.num_layers,\n        heads=model_cfg.heads,\n        titan_level=titan_spec,\n        cms_levels=cms_specs,\n        cms_flush_partial_at_end=bool(model_cfg.get(\"cms_flush_partial_at_end\", False)),\n        cms_use_layernorm=cms_use_layernorm,\n        optimizers=optimizer_cfg,\n        teach_scale=teach_scale,\n        teach_clip=teach_clip,\n        teach_schedule=teach_schedule,\n        gradient_checkpointing=model_cfg.get(\"gradient_checkpointing\", False),\n        surprise_threshold=surprise_threshold,\n        surprise_metric=surprise_metric,\n        freeze_backbone=model_cfg.get(\"freeze_backbone\", False),\n        qk_l2_norm=qk_l2_norm,\n        local_conv_window=local_conv_window,\n        self_mod_lr=float(model_cfg.get(\"self_mod_lr\", 1e-3)),\n        self_mod_hidden=int(model_cfg.get(\"self_mod_hidden\", 4)),\n        self_mod_chunk_size=int(model_cfg.get(\"self_mod_chunk_size\", 1)),\n        self_mod_chunk_size_memory=self_mod_chunk_size_memory,\n        self_mod_objective=str(model_cfg.get(\"self_mod_objective\", \"l2\")),\n        self_mod_stopgrad_vhat=bool(model_cfg.get(\"self_mod_stopgrad_vhat\", True)),\n        self_mod_use_rank1_precond=bool(model_cfg.get(\"self_mod_use_rank1_precond\", True)),\n        self_mod_use_alpha=bool(model_cfg.get(\"self_mod_use_alpha\", True)),\n        self_mod_use_skip=bool(model_cfg.get(\"self_mod_use_skip\", True)),\n        self_mod_momentum=float(model_cfg.get(\"self_mod_momentum\", 0.0)),\n        self_mod_adaptive_q=bool(model_cfg.get(\"self_mod_adaptive_q\", False)),\n        self_mod_local_conv_window=self_mod_local_conv_window,\n        transformer_mlp_hidden_multiplier=int(\n            model_cfg.get(\"transformer_mlp_hidden_multiplier\", 4)\n        ),\n        transformer_activation=str(model_cfg.get(\"transformer_activation\", \"gelu\")),\n        block_variant=str(model_cfg.get(\"block_variant\", \"hope_hybrid\")),\n    )\n    return HOPEModel(hope_cfg)\n\n\ndef build_dataloader(\n    data_cfg: DictConfig,\n    *,\n    distributed: bool,\n    dist_ctx: DistributedContext | None,\n    seed: int | None = None,\n) -> Tuple[DataLoader, DistributedSampler | None]:\n    dataset = _build_dataset(data_cfg)\n    use_sampler = distributed and not isinstance(dataset, IterableDataset)\n    if use_sampler:\n        assert dist_ctx is not None\n        sampler: DistributedSampler | None = DistributedSampler(\n            dataset,\n            num_replicas=dist_ctx.world_size,\n            rank=dist_ctx.rank,\n            shuffle=True,\n            drop_last=False,\n        )\n        shuffle = False\n    else:\n        sampler = None\n        shuffle = True\n    if isinstance(dataset, IterableDataset):\n        shuffle = False\n    generator = None\n    worker_init_fn = None\n    if seed is not None:\n        generator = torch.Generator()\n        
generator.manual_seed(seed)\n        worker_init_fn = _make_worker_init_fn(seed)\n    dataloader = DataLoader(\n        dataset,\n        batch_size=data_cfg.batch_size,\n        shuffle=shuffle,\n        sampler=sampler,\n        collate_fn=collate_batch,\n        num_workers=data_cfg.get(\"num_workers\", 0),\n        pin_memory=True,\n        worker_init_fn=worker_init_fn,\n        generator=generator,\n    )\n    return dataloader, sampler\n\n\ndef _build_dataset(data_cfg: DictConfig):\n    source = data_cfg.source\n    if source == \"synthetic\":\n        synth_cfg = SyntheticTextConfig(\n            vocab_size=data_cfg.vocab_size,\n            seq_len=data_cfg.seq_len,\n            dataset_size=data_cfg.dataset_size,\n        )\n        return SyntheticTextDataset(synth_cfg)\n    if source == \"shards\":\n        shard_dir = data_cfg.shards_dir\n        return TokenShardDataset(shard_dir)\n    if source == \"mixture\":\n        mixture_cfg = data_cfg.mixture\n        sources = [\n            ShardSourceConfig(\n                name=entry.name,\n                shards_dir=entry.shards_dir,\n                weight=entry.weight,\n            )\n            for entry in mixture_cfg.sources\n        ]\n        samples_per_epoch = mixture_cfg.samples_per_epoch\n        seed = mixture_cfg.get(\"seed\", 0)\n        return MixtureShardDataset(\n            sources,\n            samples_per_epoch=samples_per_epoch,\n            seed=seed,\n        )\n    msg = f\"Unsupported data source {source}\"\n    raise ValueError(msg)\n\n\ndef compute_teach_signal(\n    model: \"_HasLMHead\",\n    logits: torch.Tensor,\n    tokens: torch.Tensor,\n    *,\n    next_tokens: torch.Tensor | None = None,\n    ignore_index: int | None = None,\n) -> torch.Tensor:\n    \"\"\"\n    Approximate dL/dh where h is the hidden state before the LM head.\n\n    This matches the gradient of mean next-token CE.\n\n    By default this corresponds to CE(logits[:, :-1], tokens[:, 1:]).\n    If `next_tokens` is provided, the final logit position is also supervised\n    against that boundary target (used for chunked streaming boundaries).\n\n    If ignore_index is provided, targets equal to ignore_index are masked out and\n    the mean reduction denominator becomes the number of active targets (matching\n    PyTorch CE semantics).\n    \"\"\"\n    logits_detached = logits.detach()\n    probs = torch.softmax(logits_detached, dim=-1)\n    residual = probs.clone()\n    batch_size, seq_len, _ = residual.shape\n\n    targets = torch.zeros(\n        batch_size,\n        seq_len,\n        device=tokens.device,\n        dtype=tokens.dtype,\n    )\n    active = torch.zeros(\n        batch_size,\n        seq_len,\n        device=tokens.device,\n        dtype=torch.bool,\n    )\n    if seq_len > 1:\n        targets[:, :-1] = tokens[:, 1:]\n        active[:, :-1] = True\n    if next_tokens is not None:\n        if next_tokens.ndim == 2 and next_tokens.size(1) == 1:\n            next_targets = next_tokens[:, 0]\n        elif next_tokens.ndim == 1:\n            next_targets = next_tokens\n        else:\n            raise ValueError(\"next_tokens must have shape [B] or [B, 1]\")\n        if next_targets.size(0) != batch_size:\n            raise ValueError(\"next_tokens batch dimension must match tokens batch dimension\")\n        targets[:, -1] = next_targets.to(device=tokens.device, dtype=tokens.dtype)\n        active[:, -1] = True\n    if ignore_index is not None:\n        active = active & (targets != ignore_index)\n\n    active_f = 
active.to(dtype=residual.dtype)\n    residual.mul_(active_f.unsqueeze(-1))\n    safe_targets = torch.where(active, targets, torch.zeros_like(targets))\n    src = -active_f.unsqueeze(-1)\n    residual.scatter_add_(-1, safe_targets.unsqueeze(-1), src)\n    denom: torch.Tensor = active_f.sum().clamp(min=1.0)\n    residual = residual / denom\n\n    head_weight = model.lm_head.weight.detach()\n    if head_weight.dtype != residual.dtype:\n        head_weight = head_weight.to(dtype=residual.dtype)\n    grad = residual @ head_weight\n    return grad\n\n\ndef _compute_layer_teach_signals(\n    loss: torch.Tensor,\n    block_outputs: list[torch.Tensor],\n    *,\n    detach: bool = True,\n    create_graph: bool = False,\n) -> list[torch.Tensor]:\n    grads = torch.autograd.grad(\n        loss,\n        block_outputs,\n        retain_graph=True,\n        create_graph=create_graph,\n        allow_unused=False,\n    )\n    if detach:\n        return [g.detach() for g in grads]\n    return list(grads)\n\n\ndef _compute_surprise_override(\n    metric: str,\n    *,\n    logits: torch.Tensor,\n    tokens: torch.Tensor,\n    loss: torch.Tensor,\n    next_tokens: torch.Tensor | None = None,\n) -> float | None:\n    normalized = str(metric).strip().lower()\n    if normalized == \"loss\":\n        return float(loss.detach().item())\n    if normalized == \"logit_entropy\":\n        supervised_steps = int(tokens.size(1) - 1 + (0 if next_tokens is None else 1))\n        if supervised_steps <= 0:\n            return None\n        logits_detached = logits[:, :supervised_steps].detach().float()\n        probs = torch.softmax(logits_detached, dim=-1)\n        entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1).mean()\n        return float(entropy.item())\n    return None\n\n\ndef _infer_online_chunk_size(model: HOPEModel) -> int | None:\n    min_period: int | None = None\n    blocks = getattr(model, \"blocks\", [])\n    for block in blocks:\n        cfg = getattr(block, \"config\", None)\n        levels = getattr(cfg, \"cms_levels\", None)\n        if not levels:\n            continue\n        for spec in levels:\n            period = int(spec.update_period)\n            if period <= 0:\n                continue\n            min_period = period if min_period is None else min(min_period, period)\n    return min_period\n\n\ndef _iter_online_token_chunks(\n    tokens: torch.Tensor, *, chunk_size: int\n) -> Iterator[tuple[torch.Tensor, bool]]:\n    if chunk_size < 1:\n        raise ValueError(\"chunk_size must be >= 1\")\n    seq_len = tokens.size(1)\n    for core_start in range(0, seq_len, chunk_size):\n        core_end = min(core_start + chunk_size, seq_len)\n        if core_end <= core_start:\n            continue\n        # Carry one-token overlap so chunk boundaries still include next-token supervision.\n        chunk_start = core_start - 1 if core_start > 0 else core_start\n        chunk_tokens = tokens[:, chunk_start:core_end]\n        finalize_updates = core_end >= seq_len\n        yield chunk_tokens, finalize_updates\n\n\ndef _iter_online_boundary_chunks(\n    tokens: torch.Tensor, *, chunk_size: int\n) -> Iterator[tuple[torch.Tensor, torch.Tensor | None, bool]]:\n    \"\"\"\n    Yield non-overlapping chunks plus the boundary target token for chunk end.\n\n    This enables exact boundary supervision without one-token overlap.\n    \"\"\"\n    if chunk_size < 1:\n        raise ValueError(\"chunk_size must be >= 1\")\n    seq_len = tokens.size(1)\n    for start in range(0, seq_len, chunk_size):\n      
  end = min(start + chunk_size, seq_len)\n        if end <= start:\n            continue\n        next_tokens = None\n        if end < seq_len:\n            next_tokens = tokens[:, end]\n        finalize_updates = end >= seq_len\n        yield tokens[:, start:end], next_tokens, finalize_updates\n\n\nclass _HasLMHead(Protocol):\n    lm_head: torch.nn.Linear\n\n\ndef _checksum_path(path: str | None) -> str | None:\n    if not path:\n        return None\n    candidate = Path(path)\n    if not candidate.exists() or not candidate.is_file():\n        return None\n    digest = sha256()\n    with candidate.open(\"rb\") as handle:\n        for chunk in iter(lambda: handle.read(1 << 20), b\"\"):\n            digest.update(chunk)\n    return digest.hexdigest()\n\n\ndef maybe_save_checkpoint(\n    cfg: DictConfig,\n    model: torch.nn.Module,\n    optimizer: torch.optim.Optimizer,\n    *,\n    step: int,\n    total_steps: int,\n    distributed: bool,\n    dist_ctx: DistributedContext | None,\n    step_offset: int = 0,\n) -> None:\n    ckpt_cfg = cfg.train.get(\"checkpoint\")\n    if not ckpt_cfg or not ckpt_cfg.get(\"enable\", False):\n        return\n    if distributed and dist_ctx is not None and dist_ctx.rank != 0:\n        return\n    save_interval = ckpt_cfg.get(\"save_interval\", total_steps)\n    save_last = ckpt_cfg.get(\"save_last\", True)\n    is_last_step = (step + 1) >= total_steps\n    should_save = ((step + 1) % max(1, save_interval) == 0) or (save_last and is_last_step)\n    if not should_save:\n        return\n    ckpt_dir = Path(ckpt_cfg.get(\"dir\", \"checkpoints/default\"))\n    ckpt_dir.mkdir(parents=True, exist_ok=True)\n    global_step = step + 1 + int(step_offset)\n    ckpt_path = ckpt_dir / f\"step_{global_step:06d}.pt\"\n    tmp_path = ckpt_path.with_suffix(\".tmp\")\n    resolved_cfg = OmegaConf.to_container(cfg, resolve=True)\n    state = {\n        \"model\": model.state_dict(),\n        \"optimizer\": optimizer.state_dict(),\n        \"step\": step + 1,\n        \"config\": resolved_cfg,\n    }\n    torch.save(state, tmp_path)\n    os.replace(tmp_path, ckpt_path)\n    write_checkpoint_metadata(cfg, ckpt_path, global_step)\n    prefix = \"[checkpoint]\"\n    if distributed and dist_ctx is not None:\n        prefix = f\"[checkpoint rank={dist_ctx.rank}]\"\n    print(f\"{prefix} saved {ckpt_path} (global_step={global_step})\")\n\n\ndef _validate_distributed_config(cfg: DictConfig, distributed: bool) -> None:\n    if not distributed:\n        return\n    strict = bool(cfg.train.get(\"strict_streaming_contract\", False))\n    fail_if_faithful_disabled = bool(cfg.train.get(\"fail_if_paper_faithful_disabled\", False))\n    fail_hard = strict or fail_if_faithful_disabled\n    if not fail_hard:\n        return\n    if bool(cfg.train.get(\"per_layer_teach_signal\", False)):\n        raise RuntimeError(\n            \"train.per_layer_teach_signal=true is not supported under DDP in this repo. \"\n            \"Set train.strict_streaming_contract=false and \"\n            \"train.fail_if_paper_faithful_disabled=false to allow the fallback, \"\n            \"or run single-process training.\"\n        )\n    if bool(cfg.train.get(\"online_updates\", False)):\n        raise RuntimeError(\n            \"train.online_updates=true is not supported under DDP in this repo. 
\"\n            \"Set train.strict_streaming_contract=false and \"\n            \"train.fail_if_paper_faithful_disabled=false to allow the fallback, \"\n            \"or run single-process training.\"\n        )\n    if bool(cfg.train.get(\"online_boundary_targets\", False)):\n        raise RuntimeError(\n            \"train.online_boundary_targets=true is not supported under DDP in this repo. \"\n            \"Set train.strict_streaming_contract=false and \"\n            \"train.fail_if_paper_faithful_disabled=false to allow the fallback, \"\n            \"or run single-process training.\"\n        )\n    if bool(cfg.train.get(\"online_carry_attention_cache\", False)):\n        raise RuntimeError(\n            \"train.online_carry_attention_cache=true is not supported under DDP in this repo. \"\n            \"Set train.strict_streaming_contract=false and \"\n            \"train.fail_if_paper_faithful_disabled=false to allow the fallback, \"\n            \"or run single-process training.\"\n        )\n\n\ndef _emit_streaming_warning(\n    *,\n    code: str,\n    message: str,\n    details: dict[str, object] | None = None,\n) -> None:\n    payload: dict[str, object] = {\"warning_code\": code, \"message\": message}\n    if details:\n        payload[\"details\"] = details\n    print(f\"[train.warn] {json.dumps(payload, sort_keys=True)}\")\n\n\ndef _validate_paper_auditing_variant(cfg: DictConfig) -> None:\n    strict = bool(cfg.train.get(\"strict_streaming_contract\", False))\n    block_variant = str(cfg.model.get(\"block_variant\", \"\")).strip().lower()\n    if not block_variant:\n        return\n    allowed = {\"hope_attention\", \"hope_selfmod\"}\n    if block_variant in allowed:\n        return\n    msg = (\n        \"strict streaming contract expects a paper-defined HOPE variant \"\n        f\"({sorted(allowed)}), got model.block_variant={block_variant!r}\"\n    )\n    if strict:\n        raise RuntimeError(msg)\n    _emit_streaming_warning(\n        code=\"non_paper_variant\",\n        message=msg,\n        details={\"block_variant\": block_variant},\n    )\n\n\ndef _validate_tied_lm_head_for_paper_auditing(\n    cfg: DictConfig,\n    model: torch.nn.Module,\n) -> None:\n    strict = bool(cfg.train.get(\"strict_streaming_contract\", False))\n    fail_if_faithful_disabled = bool(cfg.train.get(\"fail_if_paper_faithful_disabled\", False))\n    if not (strict or fail_if_faithful_disabled):\n        return\n    lm_head = getattr(model, \"lm_head\", None)\n    embed = getattr(model, \"embed\", None)\n    if lm_head is None or embed is None:\n        return\n    lm_weight = getattr(lm_head, \"weight\", None)\n    emb_weight = getattr(embed, \"weight\", None)\n    if lm_weight is None or emb_weight is None:\n        return\n    if lm_weight.data_ptr() == emb_weight.data_ptr():\n        return\n    raise RuntimeError(\n        \"paper-auditing mode requires tied LM head and embedding weights \"\n        \"(lm_head.weight must alias embed.weight).\"\n    )\n\n\ndef _validate_fast_state_batch_semantics(cfg: DictConfig) -> None:\n    if not bool(cfg.train.get(\"use_fast_state\", False)):\n        return\n    data_cfg = cfg.get(\"data\")\n    if data_cfg is None:\n        return\n    batch_size_raw = data_cfg.get(\"batch_size\", 1)\n    try:\n        batch_size = int(batch_size_raw)\n    except (TypeError, ValueError):\n        return\n    if batch_size <= 1:\n        return\n    msg = (\n        \"train.use_fast_state=true currently shares CMS/TITAN fast state across the batch. 
\"\n        \"For strict per-context semantics, set data.batch_size=1.\"\n    )\n    strict = bool(cfg.train.get(\"strict_streaming_contract\", False))\n    fail_if_faithful_disabled = bool(cfg.train.get(\"fail_if_paper_faithful_disabled\", False))\n    if strict or fail_if_faithful_disabled:\n        raise RuntimeError(msg)\n    _emit_streaming_warning(\n        code=\"shared_fast_state_batch\",\n        message=msg,\n        details={\"batch_size\": batch_size},\n    )\n\n\ndef _validate_online_update_fast_state_semantics(cfg: DictConfig) -> None:\n    train_cfg = cfg.get(\"train\")\n    if train_cfg is None:\n        return\n    online_updates = bool(train_cfg.get(\"online_updates\", False))\n    use_fast_state = bool(train_cfg.get(\"use_fast_state\", False))\n    if not online_updates or use_fast_state:\n        return\n    msg = (\n        \"train.online_updates=true with train.use_fast_state=false applies online writes \"\n        \"directly to base parameters within each step. This can make gradients across chunks \"\n        \"harder to interpret. Use train.use_fast_state=true for paper-faithful runs.\"\n    )\n    strict = bool(train_cfg.get(\"strict_streaming_contract\", False))\n    fail_if_faithful_disabled = bool(train_cfg.get(\"fail_if_paper_faithful_disabled\", False))\n    if strict or fail_if_faithful_disabled:\n        raise RuntimeError(msg)\n    _emit_streaming_warning(\n        code=\"online_updates_without_fast_state\",\n        message=msg,\n        details={\"online_updates\": True, \"use_fast_state\": False},\n    )\n\n\ndef _resolve_algorithm_mode(cfg: DictConfig) -> str:\n    mode = str(cfg.train.get(\"algorithm_mode\", \"two_pass_stopgrad_updates\")).strip()\n    allowed = {\"two_pass_stopgrad_updates\", \"boundary_state_grad_through_write\"}\n    if mode not in allowed:\n        raise RuntimeError(f\"Unsupported train.algorithm_mode={mode!r}; allowed={sorted(allowed)}\")\n    return mode\n\n\ndef _validate_algorithm_mode_constraints(\n    cfg: DictConfig,\n    *,\n    algorithm_mode: str,\n    distributed: bool,\n) -> None:\n    if algorithm_mode != \"boundary_state_grad_through_write\":\n        return\n    if distributed:\n        raise RuntimeError(\n            \"train.algorithm_mode='boundary_state_grad_through_write' is not supported in DDP.\"\n        )\n    if not bool(cfg.train.get(\"online_updates\", False)):\n        raise RuntimeError(\n            \"train.algorithm_mode='boundary_state_grad_through_write' requires \"\n            \"train.online_updates=true.\"\n        )\n    if not bool(cfg.train.get(\"per_layer_teach_signal\", False)):\n        raise RuntimeError(\n            \"train.algorithm_mode='boundary_state_grad_through_write' requires \"\n            \"train.per_layer_teach_signal=true.\"\n        )\n    if not bool(cfg.train.get(\"use_fast_state\", False)):\n        raise RuntimeError(\n            \"train.algorithm_mode='boundary_state_grad_through_write' requires \"\n            \"train.use_fast_state=true.\"\n        )\n    if bool(cfg.train.get(\"online_carry_attention_cache\", False)) and not bool(\n        cfg.train.get(\"online_boundary_targets\", False)\n    ):\n        raise RuntimeError(\n            \"online_carry_attention_cache=true requires train.online_boundary_targets=true \"\n            \"(non-overlap chunking).\"\n        )\n    _emit_streaming_warning(\n        code=\"experimental_boundary_state_mode\",\n        message=(\n            \"train.algorithm_mode='boundary_state_grad_through_write' is an experimental \"\n  
          \"single-process path for mechanism probing and may use more memory.\"\n        ),\n        details={\"algorithm_mode\": algorithm_mode},\n    )\n\n\ndef _validate_online_chunking_constraints(cfg: DictConfig) -> None:\n    online_updates = bool(cfg.train.get(\"online_updates\", False))\n    online_boundary_targets = bool(cfg.train.get(\"online_boundary_targets\", False))\n    online_carry_attention_cache = bool(cfg.train.get(\"online_carry_attention_cache\", False))\n    if online_carry_attention_cache and not online_updates:\n        raise RuntimeError(\"online_carry_attention_cache=true requires train.online_updates=true\")\n    if online_carry_attention_cache and not online_boundary_targets:\n        raise RuntimeError(\n            \"online_carry_attention_cache=true requires train.online_boundary_targets=true \"\n            \"(non-overlap chunking).\"\n        )\n\n\ndef _check_online_supervised_pairs(\n    *,\n    strict: bool,\n    observed_pairs: int,\n    seq_len: int,\n) -> None:\n    expected_pairs = max(int(seq_len) - 1, 0)\n    if observed_pairs == expected_pairs:\n        return\n    msg = (\n        \"online chunk supervision mismatch: observed pair coverage does not match sequence length \"\n        f\"(observed_pairs={observed_pairs}, expected_pairs={expected_pairs})\"\n    )\n    if strict:\n        raise RuntimeError(msg)\n    _emit_streaming_warning(\n        code=\"online_supervision_mismatch\",\n        message=msg,\n        details={\"observed_pairs\": observed_pairs, \"expected_pairs\": expected_pairs},\n    )\n\n\ndef run_training_loop(\n    cfg: DictConfig,\n    *,\n    device: torch.device,\n    distributed: bool = False,\n    dist_ctx: DistributedContext | None = None,\n) -> Dict[str, float]:\n    algorithm_mode = _resolve_algorithm_mode(cfg)\n    _validate_algorithm_mode_constraints(\n        cfg,\n        algorithm_mode=algorithm_mode,\n        distributed=distributed,\n    )\n    _validate_online_chunking_constraints(cfg)\n    _validate_distributed_config(cfg, distributed)\n    _validate_paper_auditing_variant(cfg)\n    _validate_fast_state_batch_semantics(cfg)\n    _validate_online_update_fast_state_semantics(cfg)\n    model = build_model_from_cfg(cfg.model).to(device)\n    train_seed = cfg.train.get(\"seed\")\n    deterministic = cfg.train.get(\"deterministic\", False)\n    if train_seed is not None:\n        _seed_everything(int(train_seed), deterministic=bool(deterministic))\n    model = _maybe_compile_model(model, cfg.train.get(\"compile\"))\n    if distributed:\n        assert dist_ctx is not None\n        if device.type == \"cuda\":\n            idx = device.index if device.index is not None else 0\n            model = torch.nn.parallel.DistributedDataParallel(\n                model,\n                device_ids=[idx],\n                output_device=idx,\n                find_unused_parameters=True,\n            )\n        else:\n            model = torch.nn.parallel.DistributedDataParallel(\n                model,\n                find_unused_parameters=True,\n            )\n        base_model = model.module\n    else:\n        base_model = model\n\n    _validate_tied_lm_head_for_paper_auditing(cfg, base_model)\n\n    seed_offset = 0\n    if train_seed is not None and dist_ctx is not None:\n        seed_offset = dist_ctx.rank\n    dataloader_seed = None if train_seed is None else int(train_seed) + seed_offset\n    dataloader, sampler = build_dataloader(\n        cfg.data,\n        distributed=distributed,\n        dist_ctx=dist_ctx,\n       
 seed=dataloader_seed,\n    )\n    optimizer = _build_optimizer(base_model, cfg, device=device)\n    autocast_factory = _make_autocast_factory(device, cfg.train.get(\"mixed_precision\"))\n    logger = init_logger(getattr(cfg, \"logging\", None), cfg)\n    if distributed and dist_ctx is not None and dist_ctx.rank != 0:\n        logger = NullLogger()\n    _log_run_features(logger, base_model, cfg, optimizer, device)\n    steps = cfg.train.steps\n    log_interval = cfg.train.get(\"log_interval\", 1)\n    per_layer_teach = bool(cfg.train.get(\"per_layer_teach_signal\", False))\n    online_updates = bool(cfg.train.get(\"online_updates\", False))\n    online_chunk_size = int(cfg.train.get(\"online_chunk_size\", 0) or 0)\n    online_boundary_targets = bool(cfg.train.get(\"online_boundary_targets\", False))\n    online_carry_attention_cache = bool(cfg.train.get(\"online_carry_attention_cache\", False))\n    use_fast_state = bool(cfg.train.get(\"use_fast_state\", False))\n    fail_if_faithful_disabled = bool(cfg.train.get(\"fail_if_paper_faithful_disabled\", False))\n    strict_streaming = bool(cfg.train.get(\"strict_streaming_contract\", False))\n    if distributed and per_layer_teach:\n        msg = \"per_layer_teach_signal disabled under DDP (uses base model methods)\"\n        if fail_if_faithful_disabled or strict_streaming:\n            raise RuntimeError(\n                f\"{msg}. Set train.strict_streaming_contract=false and \"\n                \"train.fail_if_paper_faithful_disabled=false to allow the fallback, \"\n                \"or run single-process training.\"\n            )\n        _emit_streaming_warning(\n            code=\"ddp_disables_per_layer_teach\",\n            message=msg,\n            details={\"distributed\": True},\n        )\n        per_layer_teach = False\n    if distributed and online_updates:\n        msg = \"online_updates disabled under DDP (uses base model methods)\"\n        if fail_if_faithful_disabled or strict_streaming:\n            raise RuntimeError(\n                f\"{msg}. 
Set train.strict_streaming_contract=false and \"\n                \"train.fail_if_paper_faithful_disabled=false to allow the fallback, \"\n                \"or run single-process training.\"\n            )\n        _emit_streaming_warning(\n            code=\"ddp_disables_online_updates\",\n            message=msg,\n            details={\"distributed\": True},\n        )\n        online_updates = False\n    if online_boundary_targets and not online_updates:\n        msg = \"online_boundary_targets=true requires train.online_updates=true\"\n        if fail_if_faithful_disabled or strict_streaming:\n            raise RuntimeError(msg)\n        _emit_streaming_warning(\n            code=\"boundary_targets_without_online_updates\",\n            message=msg,\n        )\n        online_boundary_targets = False\n    if online_carry_attention_cache and not online_updates:\n        raise RuntimeError(\"online_carry_attention_cache=true requires train.online_updates=true\")\n    if online_carry_attention_cache and not online_boundary_targets:\n        raise RuntimeError(\n            \"online_carry_attention_cache=true requires train.online_boundary_targets=true \"\n            \"(non-overlap chunking).\"\n        )\n    step_iter = iter(dataloader)\n    epoch = 0\n    metrics: Dict[str, float] = {}\n    surprise_metric_getter = getattr(base_model, \"get_surprise_metric\", None)\n    surprise_metric = (\n        str(surprise_metric_getter()).strip().lower()\n        if callable(surprise_metric_getter)\n        else str(cfg.model.get(\"surprise_metric\", \"l2\")).strip().lower()\n    )\n    for step in range(steps):\n        if sampler is not None and step % len(dataloader) == 0:\n            sampler.set_epoch(epoch)\n            epoch += 1\n        try:\n            batch = next(step_iter)\n        except StopIteration:\n            step_iter = iter(dataloader)\n            batch = next(step_iter)\n        tokens = batch.to(device)\n        fast_state = None\n        if use_fast_state:\n            init_fn = getattr(base_model, \"init_fast_state\", None)\n            if not callable(init_fn):\n                raise ValueError(\"train.use_fast_state=true requires model.init_fast_state()\")\n            fast_state = init_fn()\n        _apply_teach_schedule(base_model, cfg, step)\n        update_metrics: Dict[str, float] = {}\n        if online_updates and hasattr(base_model, \"forward_with_block_outputs\"):\n            total_loss = 0.0\n            total_tokens = 0\n            teach_signal_norm = 0.0\n            optimizer.zero_grad()\n            chunk_size = online_chunk_size\n            if chunk_size <= 0:\n                inferred = _infer_online_chunk_size(base_model)\n                chunk_size = inferred if inferred is not None else tokens.size(1)\n            if chunk_size < 1:\n                print(f\"[train] online_chunk_size={chunk_size} is too small; clamping to 1\")\n                chunk_size = 1\n            attention_cache = None\n            if online_carry_attention_cache:\n                init_attention_cache = getattr(base_model, \"init_attention_cache\", None)\n                if not callable(init_attention_cache):\n                    raise RuntimeError(\n                        \"online_carry_attention_cache=true requires model.init_attention_cache()\"\n                    )\n                attention_cache = init_attention_cache()\n\n            chunk_iter: Iterator[tuple[torch.Tensor, torch.Tensor | None, bool]]\n            if online_boundary_targets:\n                
chunk_iter = _iter_online_boundary_chunks(tokens, chunk_size=chunk_size)\n            else:\n                chunk_iter = (\n                    (chunk, None, finalize_updates)\n                    for chunk, finalize_updates in _iter_online_token_chunks(\n                        tokens, chunk_size=chunk_size\n                    )\n                )\n            for chunk_tokens, next_tokens, finalize_updates in chunk_iter:\n                target_count = chunk_tokens.size(1) - 1 + (0 if next_tokens is None else 1)\n                if target_count <= 0:\n                    continue\n                chunk_attention_cache = attention_cache\n                with autocast_factory():\n                    if attention_cache is not None:\n                        logits, _pre, block_outputs, attention_cache = (\n                            base_model.forward_with_block_outputs(\n                                chunk_tokens,\n                                fast_state=fast_state,\n                                attention_cache=chunk_attention_cache,\n                                return_attention_cache=True,\n                            )\n                        )\n                    else:\n                        logits, _pre, block_outputs = (\n                            base_model.forward_with_block_outputs(\n                                chunk_tokens,\n                                fast_state=fast_state,\n                            )\n                            if fast_state is not None\n                            else base_model.forward_with_block_outputs(chunk_tokens)\n                        )\n                    if next_tokens is None:\n                        loss = torch.nn.functional.cross_entropy(\n                            logits[:, :-1].reshape(-1, logits.size(-1)),\n                            chunk_tokens[:, 1:].reshape(-1),\n                        )\n                    else:\n                        boundary_targets = torch.cat(\n                            [chunk_tokens[:, 1:], next_tokens.unsqueeze(1)],\n                            dim=1,\n                        )\n                        loss = torch.nn.functional.cross_entropy(\n                            logits[:, : boundary_targets.size(1), :].reshape(-1, logits.size(-1)),\n                            boundary_targets.reshape(-1),\n                        )\n                surprise_override = _compute_surprise_override(\n                    surprise_metric,\n                    logits=logits,\n                    tokens=chunk_tokens,\n                    loss=loss,\n                    next_tokens=next_tokens,\n                )\n                if per_layer_teach:\n                    differentiable_updates = algorithm_mode == \"boundary_state_grad_through_write\"\n                    teach_signals = _compute_layer_teach_signals(\n                        loss,\n                        block_outputs,\n                        detach=not differentiable_updates,\n                        create_graph=differentiable_updates,\n                    )\n                    mean_teach_norm = torch.stack(\n                        [sig.detach().norm(dim=-1).mean() for sig in teach_signals]\n                    ).mean()\n                    teach_signal_norm += float(\n                        mean_teach_norm\n                    ) * target_count\n                else:\n                    teach_signal = compute_teach_signal(\n                        base_model,\n                        logits,\n                       
 chunk_tokens,\n                        next_tokens=next_tokens,\n                    )\n                    teach_signal_norm += teach_signal.norm(dim=-1).mean().item() * target_count\n                differentiable_updates = algorithm_mode == \"boundary_state_grad_through_write\"\n                # Boundary-state mode keeps a cross-chunk differentiable write path.\n                # Retain the graph so later chunks can backprop through earlier writes.\n                loss.backward(retain_graph=differentiable_updates)\n                if differentiable_updates:\n                    if per_layer_teach:\n                        base_model(\n                            chunk_tokens,\n                            teach_signals=teach_signals,\n                            surprise_value=surprise_override,\n                            fast_state=fast_state,\n                            finalize_updates=finalize_updates,\n                            attention_cache=chunk_attention_cache,\n                            differentiable_updates=True,\n                        )\n                    else:\n                        base_model(\n                            chunk_tokens,\n                            teach_signal=teach_signal,\n                            surprise_value=surprise_override,\n                            fast_state=fast_state,\n                            finalize_updates=finalize_updates,\n                            attention_cache=chunk_attention_cache,\n                            differentiable_updates=True,\n                        )\n                    if hasattr(base_model, \"pop_update_metrics\"):\n                        update_metrics = base_model.pop_update_metrics()\n                else:\n                    with torch.no_grad():\n                        if per_layer_teach:\n                            base_model(\n                                chunk_tokens,\n                                teach_signals=teach_signals,\n                                surprise_value=surprise_override,\n                                fast_state=fast_state,\n                                finalize_updates=finalize_updates,\n                                attention_cache=chunk_attention_cache,\n                                differentiable_updates=False,\n                            )\n                        else:\n                            base_model(\n                                chunk_tokens,\n                                teach_signal=teach_signal,\n                                surprise_value=surprise_override,\n                                fast_state=fast_state,\n                                finalize_updates=finalize_updates,\n                                attention_cache=chunk_attention_cache,\n                                differentiable_updates=False,\n                            )\n                        if hasattr(base_model, \"pop_update_metrics\"):\n                            update_metrics = base_model.pop_update_metrics()\n                total_loss += loss.item() * target_count\n                total_tokens += target_count\n            _check_online_supervised_pairs(\n                strict=strict_streaming,\n                observed_pairs=total_tokens,\n                seq_len=int(tokens.size(1)),\n            )\n            torch.nn.utils.clip_grad_norm_(base_model.parameters(), max_norm=1.0)\n            optimizer.step()\n            loss = torch.tensor(total_loss / max(total_tokens, 1), device=device)\n            teach_signal_norm = 
teach_signal_norm / max(total_tokens, 1)\n        else:\n            with autocast_factory():\n                if per_layer_teach and hasattr(base_model, \"forward_with_block_outputs\"):\n                    logits, _pre, block_outputs = (\n                        base_model.forward_with_block_outputs(tokens, fast_state=fast_state)\n                        if fast_state is not None\n                        else base_model.forward_with_block_outputs(tokens)\n                    )\n                    loss = torch.nn.functional.cross_entropy(\n                        logits[:, :-1].reshape(-1, logits.size(-1)),\n                        tokens[:, 1:].reshape(-1),\n                    )\n                else:\n                    if fast_state is not None:\n                        logits = model(tokens, fast_state=fast_state)\n                    else:\n                        logits = model(tokens)\n                    loss = torch.nn.functional.cross_entropy(\n                        logits[:, :-1].reshape(-1, logits.size(-1)),\n                        tokens[:, 1:].reshape(-1),\n                    )\n            surprise_override = _compute_surprise_override(\n                surprise_metric,\n                logits=logits,\n                tokens=tokens,\n                loss=loss,\n                next_tokens=None,\n            )\n            optimizer.zero_grad()\n            if per_layer_teach and hasattr(base_model, \"forward_with_block_outputs\"):\n                teach_signals = _compute_layer_teach_signals(loss, block_outputs)\n            loss.backward()\n            torch.nn.utils.clip_grad_norm_(base_model.parameters(), max_norm=1.0)\n            optimizer.step()\n            with torch.no_grad():\n                if per_layer_teach and hasattr(base_model, \"forward_with_block_outputs\"):\n                    teach_signal_norm = float(\n                        torch.stack([sig.norm(dim=-1).mean() for sig in teach_signals]).mean()\n                    )\n                    base_model(\n                        tokens,\n                        teach_signals=teach_signals,\n                        surprise_value=surprise_override,\n                        fast_state=fast_state,\n                    )\n                else:\n                    teach_signal = compute_teach_signal(base_model, logits, tokens)\n                    teach_signal_norm = teach_signal.norm(dim=-1).mean().item()\n                    base_model(\n                        tokens,\n                        teach_signal=teach_signal,\n                        surprise_value=surprise_override,\n                        fast_state=fast_state,\n                    )\n                if hasattr(base_model, \"pop_update_metrics\"):\n                    update_metrics = base_model.pop_update_metrics()\n        if step % log_interval == 0:\n            ppl = torch.exp(loss.detach()).item()\n            metrics_payload = {\n                \"loss\": loss.item(),\n                \"ppl\": ppl,\n                \"teach_signal_norm\": teach_signal_norm,\n            }\n            metrics_payload.update(update_metrics)\n            logger.log(metrics_payload, step=step)\n            if (not distributed) or (dist_ctx and dist_ctx.rank == 0):\n                print(\n                    f\"[train] step={step} loss={loss.item():.4f} \"\n                    f\"ppl={ppl:.2f} teach_norm={teach_signal_norm:.4f}\"\n                )\n            metrics = metrics_payload\n        maybe_save_checkpoint(\n            cfg,\n            
base_model,\n            optimizer,\n            step=step,\n            total_steps=steps,\n            distributed=distributed,\n            dist_ctx=dist_ctx,\n            step_offset=int(cfg.train.get(\"step_offset\", 0) or 0),\n        )\n    logger.finish()\n    return metrics\n\n\ndef _apply_teach_schedule(model: HOPEModel, cfg: DictConfig, step: int) -> None:\n    schedule = cfg.model.get(\"teach_schedule\")\n    base_scale = cfg.model.get(\"teach_scale\", 1.0)\n    scale = base_scale\n    if schedule:\n        warmup = schedule.get(\"warmup_steps\", 0)\n        if warmup and warmup > 0:\n            scale *= min(1.0, (step + 1) / warmup)\n        decay_start = schedule.get(\"decay_start\")\n        decay_duration = schedule.get(\"decay_duration\")\n        if (\n            decay_start is not None\n            and decay_duration\n            and decay_duration > 0\n            and (step + 1) > decay_start\n        ):\n            progress = min(1.0, (step + 1 - decay_start) / decay_duration)\n            scale *= max(0.0, 1.0 - progress)\n    model.set_teach_runtime(scale=scale)\n\n\ndef _maybe_compile_model(model: torch.nn.Module, compile_cfg: dict | None) -> torch.nn.Module:\n    if not compile_cfg or not compile_cfg.get(\"enable\", False):\n        return model\n    kwargs = {}\n    if \"mode\" in compile_cfg:\n        kwargs[\"mode\"] = compile_cfg[\"mode\"]\n    if \"backend\" in compile_cfg:\n        kwargs[\"backend\"] = compile_cfg[\"backend\"]\n    try:\n        return cast(torch.nn.Module, torch.compile(model, **kwargs))  # type: ignore[attr-defined]\n    except Exception as err:  # pragma: no cover - compile is optional\n        if compile_cfg.get(\"strict\", False):\n            raise\n        print(f\"[compile] fallback to eager due to: {err}\")\n        return model\n\n\ndef _make_autocast_factory(device: torch.device, mp_cfg: dict | None):\n    if not mp_cfg or not mp_cfg.get(\"enabled\", False):\n        return lambda: nullcontext()\n    dtype = _resolve_autocast_dtype(mp_cfg.get(\"dtype\", \"bf16\"))\n    device_type = device.type\n    if device_type not in {\"cuda\", \"cpu\", \"mps\"}:\n        device_type = \"cpu\"\n\n    def factory():\n        try:\n            return torch.autocast(device_type=device_type, dtype=dtype)\n        except Exception as err:  # pragma: no cover - device/dtype support varies by backend\n            print(f\"[autocast] disabled for device_type={device_type} dtype={dtype}: {err}\")\n            return nullcontext()\n\n    return factory\n\n\ndef _resolve_autocast_dtype(name: str) -> torch.dtype:\n    normalized = str(name).lower()\n    if normalized in {\"bf16\", \"bfloat16\"}:\n        return torch.bfloat16\n    if normalized in {\"fp16\", \"float16\", \"half\"}:\n        return torch.float16\n    msg = f\"Unsupported autocast dtype {name}\"\n    raise ValueError(msg)\n\n\ndef _build_optimizer(\n    model: torch.nn.Module, cfg: DictConfig, *, device: torch.device\n) -> torch.optim.Optimizer:\n    optimizer_cfg_raw = cfg.get(\"optim\")\n    if isinstance(optimizer_cfg_raw, DictConfig):\n        optimizer_cfg = optimizer_cfg_raw\n    else:\n        optimizer_cfg = cast(DictConfig, OmegaConf.create(optimizer_cfg_raw or {}))\n    param_policy_raw = optimizer_cfg.get(\"param_policy\")\n    if param_policy_raw is None:\n        outer_updates_memory_modules = optimizer_cfg.get(\"outer_updates_memory_modules\")\n        if outer_updates_memory_modules is None:\n            param_policy = \"all\"\n        else:\n            param_policy = 
\"all\" if bool(outer_updates_memory_modules) else \"exclude_memory\"\n    else:\n        param_policy = str(param_policy_raw).strip().lower()\n    named_params = _select_outer_named_parameters(model, param_policy)\n    if not named_params:\n        raise ValueError(\n            f\"No trainable parameters selected for optim.param_policy={param_policy!r}. \"\n            \"Check freeze_backbone, requires_grad flags, or adjust the policy.\"\n        )\n    optim_type = str(optimizer_cfg.get(\"type\", \"adamw\")).lower()\n    if optim_type == \"muon\":\n        return _build_muon_optimizer(\n            model,\n            optimizer_cfg,\n            device=device,\n            named_params=named_params,\n            param_policy=param_policy,\n        )\n    if optim_type == \"m3\":\n        return _build_m3_optimizer(\n            model,\n            optimizer_cfg,\n            device=device,\n            named_params=named_params,\n            param_policy=param_policy,\n        )\n    lr = optimizer_cfg.get(\"lr\", 1e-3)\n    betas = optimizer_cfg.get(\"betas\", (0.9, 0.999))\n    weight_decay = optimizer_cfg.get(\"weight_decay\", 0.0)\n    fused_cfg = optimizer_cfg.get(\"fused\", \"auto\")\n    fused = False\n    if fused_cfg == \"auto\":\n        fused = device.type == \"cuda\" and torch.cuda.is_available()\n    else:\n        fused = bool(fused_cfg)\n    kwargs = {\"lr\": lr, \"betas\": betas, \"weight_decay\": weight_decay}\n    if fused:\n        kwargs[\"fused\"] = True\n    params = [param for _, param in named_params]\n    return torch.optim.AdamW(params, **kwargs)\n\n\ndef _build_muon_optimizer(\n    model: torch.nn.Module,\n    optimizer_cfg: DictConfig,\n    *,\n    device: torch.device,\n    named_params: list[tuple[str, torch.nn.Parameter]] | None = None,\n    param_policy: str | None = None,\n):\n    if not hasattr(torch.optim, \"Muon\"):\n        raise RuntimeError(\"torch.optim.Muon is not available in this PyTorch build\")\n    lr = optimizer_cfg.get(\"lr\", 1e-3)\n    weight_decay = optimizer_cfg.get(\"weight_decay\", 0.01)\n    momentum = optimizer_cfg.get(\"momentum\", 0.95)\n    ns_coefficients = optimizer_cfg.get(\"ns_coefficients\")\n    ns_steps = optimizer_cfg.get(\"ns_steps\")\n    eps = optimizer_cfg.get(\"eps\", 1e-7)\n    fused_cfg = optimizer_cfg.get(\"fused\", \"auto\")\n    fused = False\n    if fused_cfg == \"auto\":\n        fused = device.type == \"cuda\" and torch.cuda.is_available()\n    else:\n        fused = bool(fused_cfg)\n    muon_params: list[torch.nn.Parameter] = []\n    adamw_params: list[torch.nn.Parameter] = []\n    source = named_params if named_params is not None else model.named_parameters()\n    for name, param in source:\n        if not param.requires_grad:\n            continue\n        if _is_muon_candidate(name, param):\n            muon_params.append(param)\n        else:\n            adamw_params.append(param)\n    muon_kwargs = {\n        \"lr\": lr,\n        \"weight_decay\": weight_decay,\n        \"momentum\": momentum,\n        \"eps\": eps,\n    }\n    if ns_coefficients is not None:\n        muon_kwargs[\"ns_coefficients\"] = tuple(ns_coefficients)\n    if ns_steps is not None:\n        muon_kwargs[\"ns_steps\"] = int(ns_steps)\n    muon_opt = torch.optim.Muon(muon_params, **muon_kwargs) if muon_params else None  # type: ignore[attr-defined]\n    adamw_kwargs = {\n        \"lr\": lr,\n        \"betas\": optimizer_cfg.get(\"betas\", (0.9, 0.999)),\n        \"weight_decay\": weight_decay,\n    }\n    if fused:\n        
adamw_kwargs[\"fused\"] = True\n    adamw_opt = torch.optim.AdamW(adamw_params, **adamw_kwargs) if adamw_params else None\n    muon_elems = int(sum(p.numel() for p in muon_params))\n    adamw_elems = int(sum(p.numel() for p in adamw_params))\n    return _HybridOptimizer(\n        muon_opt,\n        adamw_opt,\n        muon_elems,\n        adamw_elems,\n        primary_name=\"muon\",\n        param_policy=param_policy,\n    )\n\n\ndef _build_m3_optimizer(\n    model: torch.nn.Module,\n    optimizer_cfg: DictConfig,\n    *,\n    device: torch.device,\n    named_params: list[tuple[str, torch.nn.Parameter]] | None = None,\n    param_policy: str | None = None,\n):\n    lr = optimizer_cfg.get(\"lr\", 1e-3)\n    weight_decay = optimizer_cfg.get(\"weight_decay\", 0.01)\n    beta1 = optimizer_cfg.get(\"beta1\", 0.9)\n    beta2 = optimizer_cfg.get(\"beta2\", 0.999)\n    beta3 = optimizer_cfg.get(\"beta3\", 0.9)\n    alpha = optimizer_cfg.get(\"alpha\", 1.0)\n    ns_steps = int(optimizer_cfg.get(\"ns_steps\", 3))\n    slow_chunk = int(optimizer_cfg.get(\"slow_chunk\", 100))\n    eps = optimizer_cfg.get(\"eps\", 1e-8)\n    fused_cfg = optimizer_cfg.get(\"fused\", \"auto\")\n    fused = False\n    if fused_cfg == \"auto\":\n        fused = device.type == \"cuda\" and torch.cuda.is_available()\n    else:\n        fused = bool(fused_cfg)\n\n    m3_params: list[torch.nn.Parameter] = []\n    adamw_params: list[torch.nn.Parameter] = []\n    source = named_params if named_params is not None else model.named_parameters()\n    for name, param in source:\n        if not param.requires_grad:\n            continue\n        if _is_muon_candidate(name, param):\n            m3_params.append(param)\n        else:\n            adamw_params.append(param)\n    m3_opt = (\n        M3(\n            m3_params,\n            lr=lr,\n            beta1=beta1,\n            beta2=beta2,\n            beta3=beta3,\n            alpha=alpha,\n            eps=eps,\n            ns_steps=ns_steps,\n            slow_chunk=slow_chunk,\n            weight_decay=weight_decay,\n        )\n        if m3_params\n        else None\n    )\n    adamw_kwargs = {\n        \"lr\": lr,\n        \"betas\": optimizer_cfg.get(\"betas\", (0.9, 0.999)),\n        \"weight_decay\": weight_decay,\n    }\n    if fused:\n        adamw_kwargs[\"fused\"] = True\n    adamw_opt = torch.optim.AdamW(adamw_params, **adamw_kwargs) if adamw_params else None\n    m3_elems = int(sum(p.numel() for p in m3_params))\n    adamw_elems = int(sum(p.numel() for p in adamw_params))\n    return _HybridOptimizer(\n        m3_opt,\n        adamw_opt,\n        m3_elems,\n        adamw_elems,\n        primary_name=\"m3\",\n        param_policy=param_policy,\n    )\n\n\ndef _select_outer_named_parameters(\n    model: torch.nn.Module, param_policy: str\n) -> list[tuple[str, torch.nn.Parameter]]:\n    policy = str(param_policy).strip().lower()\n    trainable: list[tuple[str, torch.nn.Parameter]] = [\n        (name, param) for name, param in model.named_parameters() if param.requires_grad\n    ]\n    if policy in {\"all\", \"full\"}:\n        return trainable\n    if policy in {\"exclude_memory\", \"no_memory\"}:\n        return [(name, param) for name, param in trainable if not _is_memory_param_name(name)]\n    if policy in {\"only_memory\", \"memory_only\"}:\n        return [(name, param) for name, param in trainable if _is_memory_param_name(name)]\n    raise ValueError(\n        f\"Unsupported optim.param_policy={param_policy!r}. 
\"\n        \"Expected one of ['all', 'exclude_memory', 'only_memory'].\"\n    )\n\n\ndef _is_memory_param_name(name: str) -> bool:\n    lowered = name.lower()\n    return any(token in lowered for token in (\".cms.\", \".titan_memory.\", \".selfmod.\"))\n\n\ndef _is_muon_candidate(name: str, param: torch.nn.Parameter) -> bool:\n    if param.ndim < 2:\n        return False\n    lowered = name.lower()\n    if \"norm\" in lowered or \"embed\" in lowered:\n        return False\n    return True\n\n\nclass _HybridOptimizer:\n    def __init__(\n        self,\n        primary_opt: torch.optim.Optimizer | None,\n        secondary_opt: torch.optim.Optimizer | None,\n        primary_param_elems: int,\n        secondary_param_elems: int,\n        *,\n        primary_name: str = \"muon\",\n        param_policy: str | None = None,\n    ):\n        self.primary_opt = primary_opt\n        self.secondary_opt = secondary_opt\n        self.primary_param_elems = primary_param_elems\n        self.secondary_param_elems = secondary_param_elems\n        self.primary_name = primary_name\n        self.param_policy = param_policy\n\n    def zero_grad(self) -> None:\n        if self.primary_opt:\n            self.primary_opt.zero_grad()\n        if self.secondary_opt:\n            self.secondary_opt.zero_grad()\n\n    def step(self) -> None:\n        if self.primary_opt:\n            self.primary_opt.step()\n        if self.secondary_opt:\n            self.secondary_opt.step()\n\n    def state_dict(self) -> dict:\n        return {\n            self.primary_name: self.primary_opt.state_dict() if self.primary_opt else None,\n            \"adamw\": self.secondary_opt.state_dict() if self.secondary_opt else None,\n        }\n\n    def load_state_dict(self, state: dict) -> None:\n        if self.primary_opt and state.get(self.primary_name) is not None:\n            self.primary_opt.load_state_dict(state[self.primary_name])\n        if self.secondary_opt and state.get(\"adamw\") is not None:\n            self.secondary_opt.load_state_dict(state[\"adamw\"])\n\n    @property\n    def param_groups(self):\n        groups = []\n        if self.primary_opt:\n            groups.extend(self.primary_opt.param_groups)\n        if self.secondary_opt:\n            groups.extend(self.secondary_opt.param_groups)\n        return groups\n\n    def get_param_split(self) -> dict[str, int]:\n        return {\n            self.primary_name: self.primary_param_elems,\n            \"adamw\": self.secondary_param_elems,\n        }\n\n\ndef _log_run_features(\n    logger: BaseLogger,\n    model: torch.nn.Module,\n    cfg: DictConfig,\n    optimizer: torch.optim.Optimizer,\n    device: torch.device,\n) -> None:\n    mp_cfg = cfg.train.get(\"mixed_precision\", {})\n    compile_cfg = cfg.train.get(\"compile\", {})\n    algorithm_mode = str(cfg.train.get(\"algorithm_mode\", \"two_pass_stopgrad_updates\"))\n    features: dict[str, object] = {\n        \"train.mixed_precision_enabled\": bool(mp_cfg.get(\"enabled\", False)),\n        \"train.mixed_precision_dtype\": str(mp_cfg.get(\"dtype\", \"bf16\")),\n        \"train.compile_enabled\": bool(compile_cfg.get(\"enable\", False)),\n        \"train.compile_mode\": str(compile_cfg.get(\"mode\", \"default\")) if compile_cfg else \"default\",\n        \"train.strict_streaming_contract\": bool(cfg.train.get(\"strict_streaming_contract\", False)),\n        \"train.online_updates\": bool(cfg.train.get(\"online_updates\", False)),\n        \"train.online_boundary_targets\": 
bool(cfg.train.get(\"online_boundary_targets\", False)),\n        \"train.online_carry_attention_cache\": bool(\n            cfg.train.get(\"online_carry_attention_cache\", False)\n        ),\n        \"train.use_fast_state\": bool(cfg.train.get(\"use_fast_state\", False)),\n        \"train.algorithm_mode\": algorithm_mode,\n        \"train.backprop_through_online_writes\": algorithm_mode\n        == \"boundary_state_grad_through_write\",\n        \"attention.flash_enabled\": _detect_flash_attention(model),\n        \"device\": device.type,\n    }\n    optimizer_cfg_raw = cfg.get(\"optim\")\n    if isinstance(optimizer_cfg_raw, DictConfig):\n        optimizer_cfg = optimizer_cfg_raw\n    else:\n        optimizer_cfg = cast(DictConfig, OmegaConf.create(optimizer_cfg_raw or {}))\n    param_policy_raw = optimizer_cfg.get(\"param_policy\")\n    if param_policy_raw is None:\n        outer_updates_memory_modules = optimizer_cfg.get(\"outer_updates_memory_modules\")\n        if outer_updates_memory_modules is None:\n            param_policy = \"all\"\n        else:\n            param_policy = \"all\" if bool(outer_updates_memory_modules) else \"exclude_memory\"\n    else:\n        param_policy = str(param_policy_raw).strip().lower()\n    try:\n        selected = _select_outer_named_parameters(model, param_policy)\n        total_elems = int(sum(param.numel() for _, param in selected))\n        memory_elems = int(\n            sum(param.numel() for name, param in selected if _is_memory_param_name(name))\n        )\n        features[\"optim.param_policy\"] = param_policy\n        features[\"optim.param_policy_param_elems\"] = total_elems\n        features[\"optim.param_policy_memory_param_elems\"] = memory_elems\n        features[\"optim.param_policy_non_memory_param_elems\"] = total_elems - memory_elems\n    except Exception as err:  # pragma: no cover - purely diagnostic\n        features[\"optim.param_policy\"] = param_policy\n        features[\"optim.param_policy_error\"] = str(err)\n    split_fn = getattr(optimizer, \"get_param_split\", None)\n    if callable(split_fn):\n        split = split_fn()\n        for key, value in split.items():\n            features[f\"optim.{key}_param_elems\"] = int(value)\n    logger.log(features, step=-1)\n    print(f\"[train] run_features {features}\")\n\n\ndef _detect_flash_attention(model: torch.nn.Module) -> bool:\n    blocks = getattr(model, \"blocks\", [])\n    for block in blocks:\n        attn = getattr(block, \"attn\", None)\n        config = getattr(attn, \"config\", None)\n        if config is not None and hasattr(config, \"use_flash\"):\n            return bool(config.use_flash)\n    return False\n\n\ndef write_checkpoint_metadata(cfg: DictConfig, ckpt_path: Path, step: int) -> None:\n    config_yaml = OmegaConf.to_yaml(cfg)\n    config_path = ckpt_path.with_suffix(\".yaml\")\n    config_path.write_text(config_yaml)\n    config_hash = sha256(config_yaml.encode(\"utf-8\")).hexdigest()\n    ckpt_hash = _checksum_path(str(ckpt_path))\n    sha_path = ckpt_path.with_suffix(\".sha256\")\n    if ckpt_hash:\n        sha_path.write_text(f\"{ckpt_hash}  {ckpt_path.name}\\n\")\n    tokenizer_path = cfg.data.get(\"tokenizer_path\") if hasattr(cfg, \"data\") else None\n    metadata = {\n        \"step\": step,\n        \"checkpoint_sha256\": ckpt_hash,\n        \"config_sha256\": config_hash,\n        \"tokenizer_hash\": _checksum_path(tokenizer_path) if tokenizer_path else None,\n        \"config_path\": str(config_path),\n        \"algorithm_mode\": 
str(cfg.train.get(\"algorithm_mode\", \"two_pass_stopgrad_updates\")),\n        \"online_updates\": bool(cfg.train.get(\"online_updates\", False)),\n        \"online_boundary_targets\": bool(cfg.train.get(\"online_boundary_targets\", False)),\n        \"online_carry_attention_cache\": bool(\n            cfg.train.get(\"online_carry_attention_cache\", False)\n        ),\n        \"use_fast_state\": bool(cfg.train.get(\"use_fast_state\", False)),\n        \"rng_states\": _capture_rng_states(),\n    }\n    ckpt_path.with_suffix(\".meta.json\").write_text(json.dumps(metadata, indent=2))\n\n\ndef verify_checkpoint_integrity(ckpt_path: Path) -> Dict[str, object]:\n    if not ckpt_path.exists():\n        raise FileNotFoundError(f\"Checkpoint {ckpt_path} not found\")\n    meta_path = ckpt_path.with_suffix(\".meta.json\")\n    if not meta_path.exists():\n        raise FileNotFoundError(f\"Metadata file {meta_path} missing\")\n    metadata = json.loads(meta_path.read_text())\n    computed_sha = _checksum_path(str(ckpt_path))\n    recorded_sha = metadata.get(\"checkpoint_sha256\")\n    if recorded_sha and computed_sha and recorded_sha != computed_sha:\n        raise ValueError(\n            f\"Checkpoint SHA mismatch: recorded {recorded_sha} vs computed {computed_sha}\"\n        )\n    sha_file = ckpt_path.with_suffix(\".sha256\")\n    if sha_file.exists() and computed_sha:\n        recorded_line = sha_file.read_text().strip().split()\n        if recorded_line:\n            recorded = recorded_line[0]\n            if recorded != computed_sha:\n                raise ValueError(f\".sha256 mismatch: {recorded} vs {computed_sha}\")\n    config_path = ckpt_path.with_suffix(\".yaml\")\n    if not config_path.exists():\n        raise FileNotFoundError(f\"Config file {config_path} missing\")\n    config_hash = sha256(config_path.read_text().encode(\"utf-8\")).hexdigest()\n    recorded_cfg_hash = metadata.get(\"config_sha256\")\n    if recorded_cfg_hash and recorded_cfg_hash != config_hash:\n        raise ValueError(\n            f\"Config SHA mismatch: recorded {recorded_cfg_hash} vs computed {config_hash}\"\n        )\n    if \"rng_states\" not in metadata:\n        raise ValueError(\"Metadata missing rng_states\")\n    return metadata\n\n\ndef _capture_rng_states() -> Dict[str, object]:\n    payload: Dict[str, object] = {\n        \"python\": _encode_pickle(random.getstate()),\n        \"numpy\": _encode_pickle(np.random.get_state()),\n        \"torch\": _tensor_state_to_hex(torch.random.get_rng_state()),\n    }\n    if torch.cuda.is_available():\n        payload[\"torch_cuda\"] = [\n            _tensor_state_to_hex(state) for state in torch.cuda.get_rng_state_all()\n        ]  # type: ignore[attr-defined]\n    return payload\n\n\ndef _encode_pickle(obj: object) -> str:\n    return base64.b64encode(pickle.dumps(obj)).decode(\"ascii\")\n\n\ndef _tensor_state_to_hex(state: torch.Tensor) -> str:\n    return state.cpu().numpy().tobytes().hex()\n\n\ndef _seed_everything(seed: int, *, deterministic: bool = False) -> None:\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n    if deterministic:\n        torch.use_deterministic_algorithms(True, warn_only=True)\n        if hasattr(torch.backends, \"cudnn\"):\n            torch.backends.cudnn.benchmark = False  # type: ignore[attr-defined]\n            torch.backends.cudnn.deterministic = True  # type: ignore[attr-defined]\n    else:\n        
torch.use_deterministic_algorithms(False)\n        if hasattr(torch.backends, \"cudnn\"):\n            torch.backends.cudnn.benchmark = True  # type: ignore[attr-defined]\n            torch.backends.cudnn.deterministic = False  # type: ignore[attr-defined]\n\n\ndef _make_worker_init_fn(base_seed: int):\n    def _init_fn(worker_id: int) -> None:\n        worker_seed = base_seed + worker_id\n        np.random.seed(worker_seed)\n        random.seed(worker_seed)\n        torch.manual_seed(worker_seed)\n\n    return _init_fn\n"
  },
  {
    "path": "src/nested_learning/transformer.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\n\nimport torch\nfrom torch import nn\n\nfrom .backbones import AttentionConfig, SelfAttention\nfrom .fast_state import AttentionKVCache\n\n\n@dataclass\nclass TransformerBlockConfig:\n    dim: int\n    heads: int\n    mlp_hidden_multiplier: int = 4\n    activation: str = \"gelu\"\n    qk_l2_norm: bool = False\n    local_conv_window: int | None = None\n\n\nclass FeedForward(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        *,\n        hidden_multiplier: int = 4,\n        activation: str = \"gelu\",\n    ) -> None:\n        super().__init__()\n        hidden = dim * hidden_multiplier\n        if activation == \"relu\":\n            act: nn.Module = nn.ReLU()\n        elif activation == \"silu\":\n            act = nn.SiLU()\n        else:\n            act = nn.GELU()\n        self.norm = nn.LayerNorm(dim)\n        self.net = nn.Sequential(\n            nn.Linear(dim, hidden, bias=False),\n            act,\n            nn.Linear(hidden, dim, bias=False),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]\n        residual = x\n        x = self.norm(x)\n        return residual + self.net(x)\n\n\nclass TransformerBlock(nn.Module):\n    \"\"\"\n    Baseline Transformer block: Attention -> MLP (no TITAN/CMS learning updates).\n\n    This is used for Phase 2 comparisons (HOPE-Attention vs standard Transformer).\n    \"\"\"\n\n    def __init__(self, config: TransformerBlockConfig) -> None:\n        super().__init__()\n        self.config = config\n        self.attn = SelfAttention(\n            AttentionConfig(\n                dim=config.dim,\n                heads=config.heads,\n                qk_l2_norm=config.qk_l2_norm,\n                local_conv_window=config.local_conv_window,\n            )\n        )\n        self.mlp = FeedForward(\n            config.dim,\n            hidden_multiplier=config.mlp_hidden_multiplier,\n            activation=config.activation,\n        )\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        *,\n        teach_signal: torch.Tensor | None = None,\n        surprise_value: float | None = None,\n        fast_state=None,\n        finalize_updates: bool = True,\n        attention_cache: AttentionKVCache | None = None,\n        return_attention_cache: bool = False,\n        differentiable_updates: bool = False,\n    ) -> torch.Tensor | tuple[torch.Tensor, AttentionKVCache]:\n        _ = (teach_signal, surprise_value, fast_state, finalize_updates, differentiable_updates)\n        if return_attention_cache:\n            attn_out, next_cache = self.attn(\n                x,\n                kv_cache=attention_cache,\n                return_kv_cache=True,\n            )\n            return self.mlp(attn_out), next_cache\n        return self.mlp(self.attn(x, kv_cache=attention_cache))\n\n    def set_surprise_threshold(self, threshold: float | None) -> None:\n        _ = threshold\n\n    def set_surprise_metric(self, metric: str) -> None:\n        _ = metric\n\n    def set_allowed_levels(self, allowed) -> None:\n        _ = allowed\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "import sys\nfrom pathlib import Path\n\nROOT = Path(__file__).resolve().parents[1]\nSRC = ROOT / \"src\"\nif str(SRC) not in sys.path:\n    sys.path.insert(0, str(SRC))\n"
  },
  {
    "path": "tests/data/passkey_corpus.txt",
    "content": "this is a tiny corpus for tokenizer smoke tests.\nanother sentence for the tokenizer.\npasskey models should tokenize prompts about nested learning hope titan cms.\n"
  },
  {
    "path": "tests/data/tiny_tokenizer.vocab",
    "content": "<unk>\t0\n<s>\t0\n</s>\t0\n▁t\t-0\nen\t-1\nok\t-2\ner\t-3\niz\t-4\nor\t-5\n▁a\t-6\n▁s\t-7\neniz\t-8\n▁tok\t-9\n▁tokeniz\t-10\nes\t-11\nho\t-12\nin\t-13\nis\t-14\nts\t-15\n▁c\t-16\n▁f\t-17\n▁p\t-18\n▁th\t-19\n▁for\t-20\n▁tokenizer\t-21\nan\t-22\nar\t-23\nas\t-24\nbo\t-25\nce\t-26\nde\t-27\ned\t-28\ney\t-29\nit\t-30\nld\t-31\nle\t-32\nls\t-33\nmo\t-34\nmp\t-35\nms\t-36\nno\t-37\npe\t-38\npu\t-39\nro\t-40\nsk\t-41\nth\t-42\nut\t-43\n▁n\t-44\narn\t-45\nent\t-46\nest\t-47\nhou\t-48\ning\t-49\niny\t-50\nmok\t-51\npus\t-52\n▁ho\t-53\n▁is\t-54\n▁le\t-55\n▁mo\t-56\nassk\t-57\nbout\t-58\ndels\t-59\nence\t-60\nests\t-61\nitan\t-62\nmoke\t-63\nmpts\t-64\nnoth\t-65\n▁cms\t-66\n▁cor\t-67\n▁pro\t-68\n▁the\t-69\nested\t-70\nhould\t-71\n▁hope\t-72\n▁sent\t-73\n▁this\t-74\n▁tiny\t-75\narning\t-76\nasskey\t-77\nnother\t-78\n▁about\t-79\n▁smoke\t-80\n▁tests\t-81\n▁titan\t-82\n▁corpus\t-83\n▁models\t-84\n▁nested\t-85\n▁should\t-86\n▁another\t-87\n▁passkey\t-88\n▁prompts\t-89\n▁learning\t-90\n▁sentence\t-91\n▁tokenize\t-92\nab\t-93\ncm\t-94\nco\t-95\nea\t-96\nel\t-97\nfo\t-98\nhe\t-99\nhi\t-100\nke\t-101\n▁\t-102\ne\t-103\nt\t-104\ns\t-105\no\t-106\nn\t-107\ni\t-108\nr\t-109\na\t-110\nh\t-111\nk\t-112\np\t-113\nm\t-114\n.\t-115\nc\t-116\nd\t-117\nl\t-118\nu\t-119\nz\t-120\nf\t-121\ny\t-122\nb\t-123\ng\t-124\n"
  },
  {
    "path": "tests/test_algorithm_mode_grad.py",
    "content": "import torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.optim.manager import LevelConfig, LevelOptimizerManager\n\n\ndef _manager() -> LevelOptimizerManager:\n    spec = LevelSpec(name=\"cms_fast\", update_period=1, optimizer_key=\"default\")\n    cfg = LevelConfig(\n        specs=(spec,),\n        optimizer_configs={\"default\": {\"type\": \"deep_momentum\", \"params\": {\"variant\": \"basic\"}}},\n        default_lr=0.1,\n    )\n    return LevelOptimizerManager(cfg)\n\n\ndef test_apply_grads_differentiable_preserves_gradient_path() -> None:\n    mgr = _manager()\n    base = torch.randn(4, requires_grad=True)\n    loss = (base**2).sum()\n    (grad,) = torch.autograd.grad(loss, (base,), create_graph=True)\n    updated, _ = mgr.apply_grads(\n        \"cms_fast\",\n        {\"w\": base},\n        {\"w\": grad},\n        force=True,\n        differentiable=True,\n    )\n    downstream = (updated[\"w\"] ** 2).sum()\n    downstream.backward()\n    assert base.grad is not None\n    assert float(base.grad.abs().sum().item()) > 0.0\n\n\ndef test_apply_grads_nondifferentiable_breaks_gradient_path() -> None:\n    mgr = _manager()\n    base = torch.randn(4, requires_grad=True)\n    loss = (base**2).sum()\n    (grad,) = torch.autograd.grad(loss, (base,), create_graph=True)\n    updated, _ = mgr.apply_grads(\n        \"cms_fast\",\n        {\"w\": base},\n        {\"w\": grad},\n        force=True,\n        differentiable=False,\n    )\n    assert updated[\"w\"].requires_grad is False\n"
  },
  {
    "path": "tests/test_attention_cache.py",
    "content": "import pytest\nimport torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.transformer import TransformerBlock, TransformerBlockConfig\n\n\ndef _build_transformer_model() -> HOPEModel:\n    cfg = ModelConfig(\n        vocab_size=64,\n        dim=16,\n        num_layers=2,\n        heads=4,\n        titan_level=LevelSpec(name=\"titan\", update_period=2),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=2),),\n        block_variant=\"transformer\",\n        qk_l2_norm=False,\n    )\n    model = HOPEModel(cfg).eval()\n    return model\n\n\ndef test_attention_cache_chunked_logits_match_full_logits() -> None:\n    torch.manual_seed(0)\n    model = _build_transformer_model()\n    tokens = torch.randint(0, 64, (1, 11))\n    with torch.no_grad():\n        full = model(tokens)\n        cache = model.init_attention_cache()\n        pieces = []\n        for start, end in ((0, 3), (3, 7), (7, 11)):\n            chunk = tokens[:, start:end]\n            logits_chunk, cache = model(\n                chunk,\n                attention_cache=cache,\n                return_attention_cache=True,\n            )\n            pieces.append(logits_chunk)\n        stitched = torch.cat(pieces, dim=1)\n    assert torch.allclose(full, stitched, atol=1e-5, rtol=1e-5)\n\n\ndef test_attention_cache_reset_changes_continuation_state() -> None:\n    torch.manual_seed(1)\n    model = _build_transformer_model()\n    tokens = torch.randint(0, 64, (1, 8))\n    prefix = tokens[:, :4]\n    suffix = tokens[:, 4:]\n    with torch.no_grad():\n        cache = model.init_attention_cache()\n        _, cache = model(prefix, attention_cache=cache, return_attention_cache=True)\n        carried, _ = model(suffix, attention_cache=cache, return_attention_cache=True)\n        fresh = model(suffix)\n    # Carrying cache should generally differ from a fresh-only suffix pass.\n    assert not torch.allclose(carried, fresh)\n\n\ndef test_transformer_block_rejects_kv_cache_with_local_conv() -> None:\n    block = TransformerBlock(\n        TransformerBlockConfig(\n            dim=16,\n            heads=4,\n            local_conv_window=4,\n        )\n    )\n    x = torch.randn(1, 3, 16)\n    # Build a minimal cache tensor matching [B, H, T, D].\n    k = torch.randn(1, 4, 2, 4)\n    v = torch.randn(1, 4, 2, 4)\n    from nested_learning.fast_state import AttentionKVCache\n\n    with pytest.raises(RuntimeError, match=\"local_conv_window\"):\n        block(x, attention_cache=AttentionKVCache(key=k, value=v))\n"
  },
  {
    "path": "tests/test_attention_features.py",
    "content": "import torch\n\nfrom nested_learning.backbones import AttentionConfig, SelfAttention\n\n\ndef test_self_attention_qk_l2_norm_unit_vectors() -> None:\n    attn = SelfAttention(AttentionConfig(dim=16, heads=4, qk_l2_norm=True, use_flash=False))\n    x = torch.randn(2, 5, 16)\n    q, k, _v = attn._compute_qkv(x)\n    q_norm = q.norm(dim=-1)\n    k_norm = k.norm(dim=-1)\n    assert torch.allclose(q_norm, torch.ones_like(q_norm), atol=1e-4, rtol=1e-4)\n    assert torch.allclose(k_norm, torch.ones_like(k_norm), atol=1e-4, rtol=1e-4)\n\n\ndef test_self_attention_local_conv_window_preserves_shape() -> None:\n    attn = SelfAttention(AttentionConfig(dim=16, heads=4, local_conv_window=4, use_flash=False))\n    assert attn.local_conv is not None\n    assert attn.local_conv.kernel_size == (4,)\n    x = torch.randn(2, 8, 16)\n    out = attn(x)\n    assert out.shape == x.shape\n\n\ndef test_self_attention_local_conv_is_causal() -> None:\n    torch.manual_seed(0)\n    dim = 4\n    attn = SelfAttention(\n        AttentionConfig(dim=dim, heads=2, local_conv_window=4, use_flash=False, dropout=0.0)\n    ).eval()\n    assert attn.local_conv is not None\n    with torch.no_grad():\n        attn.local_conv.weight.fill_(1.0)\n        eye = torch.eye(dim)\n        attn.qkv.weight.zero_()\n        attn.qkv.weight[:dim].copy_(eye)\n        attn.qkv.weight[dim : 2 * dim].copy_(eye)\n        attn.qkv.weight[2 * dim :].copy_(eye)\n        attn.out_proj.weight.copy_(eye)\n    x1 = torch.randn(1, 8, dim)\n    x2 = x1.clone()\n    x2[:, 4:, :] = torch.randn_like(x2[:, 4:, :])\n    out1 = attn(x1)\n    out2 = attn(x2)\n    assert torch.allclose(out1[:, :4, :], out2[:, :4, :], atol=1e-5, rtol=1e-5)\n"
  },
  {
    "path": "tests/test_boundary_state_mode.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.training import _compute_layer_teach_signals\n\n\ndef _build_attention_model() -> HOPEModel:\n    cfg = ModelConfig(\n        vocab_size=64,\n        dim=16,\n        num_layers=1,\n        heads=4,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=1),),\n        block_variant=\"hope_attention\",\n        cms_flush_partial_at_end=True,\n    )\n    return HOPEModel(cfg).train()\n\n\ndef _two_chunk_grad_norm(*, differentiable_updates: bool) -> float:\n    torch.manual_seed(0)\n    model = _build_attention_model()\n    state = model.init_fast_state()\n    tokens = torch.randint(0, 64, (1, 6))\n    chunk1 = tokens[:, :3]\n    chunk2 = tokens[:, 3:]\n\n    logits1, _pre, block_outputs = model.forward_with_block_outputs(\n        chunk1,\n        fast_state=state,\n    )\n    targets1 = torch.cat([chunk1[:, 1:], chunk2[:, :1]], dim=1)\n    loss1 = F.cross_entropy(\n        logits1[:, : targets1.size(1), :].reshape(-1, logits1.size(-1)),\n        targets1.reshape(-1),\n    )\n    teach_signals = _compute_layer_teach_signals(\n        loss1,\n        block_outputs,\n        detach=not differentiable_updates,\n        create_graph=differentiable_updates,\n    )\n\n    _ = model(\n        chunk1,\n        teach_signals=teach_signals,\n        fast_state=state,\n        finalize_updates=False,\n        differentiable_updates=differentiable_updates,\n    )\n\n    logits2 = model(chunk2, fast_state=state)\n    loss2 = F.cross_entropy(\n        logits2[:, :-1].reshape(-1, logits2.size(-1)),\n        chunk2[:, 1:].reshape(-1),\n    )\n    grad = torch.autograd.grad(loss2, block_outputs[0], allow_unused=True)[0]\n    if grad is None:\n        return 0.0\n    return float(grad.detach().norm().item())\n\n\ndef test_boundary_state_grad_mode_propagates_across_write_path() -> None:\n    assert _two_chunk_grad_norm(differentiable_updates=True) > 0.0\n\n\ndef test_stopgrad_mode_blocks_boundary_state_grad_path() -> None:\n    assert _two_chunk_grad_norm(differentiable_updates=False) == 0.0\n"
  },
  {
    "path": "tests/test_boundary_state_training_loop.py",
    "content": "from __future__ import annotations\n\nimport torch\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.training import run_training_loop\n\n\ndef _tiny_boundary_state_cfg():\n    return OmegaConf.create(\n        {\n            \"model\": {\n                \"type\": \"hope\",\n                \"vocab_size\": 64,\n                \"dim\": 16,\n                \"num_layers\": 1,\n                \"heads\": 2,\n                \"block_variant\": \"hope_attention\",\n                \"titan_level\": {\"name\": \"titan\", \"update_period\": 1},\n                \"cms_levels\": [{\"name\": \"cms_fast\", \"update_period\": 1}],\n                \"surprise_threshold\": None,\n            },\n            \"data\": {\n                \"source\": \"synthetic\",\n                \"vocab_size\": 64,\n                \"seq_len\": 8,\n                \"dataset_size\": 8,\n                \"batch_size\": 1,\n                \"num_workers\": 0,\n            },\n            \"optim\": {\n                \"type\": \"adamw\",\n                \"lr\": 1e-3,\n                \"weight_decay\": 0.0,\n                \"fused\": False,\n                \"param_policy\": \"all\",\n            },\n            \"train\": {\n                \"steps\": 2,\n                \"log_interval\": 1,\n                \"seed\": 0,\n                \"deterministic\": True,\n                \"algorithm_mode\": \"boundary_state_grad_through_write\",\n                \"per_layer_teach_signal\": True,\n                \"online_updates\": True,\n                \"online_chunk_size\": 2,\n                \"online_boundary_targets\": True,\n                \"online_carry_attention_cache\": True,\n                \"use_fast_state\": True,\n                \"strict_streaming_contract\": False,\n                \"fail_if_paper_faithful_disabled\": False,\n                \"mixed_precision\": {\"enabled\": False, \"dtype\": \"bf16\"},\n                \"compile\": {\"enable\": False},\n                \"checkpoint\": {\"enable\": False},\n            },\n            \"logging\": {\"enabled\": False},\n        }\n    )\n\n\ndef test_boundary_state_mode_runs_in_training_loop() -> None:\n    cfg = _tiny_boundary_state_cfg()\n    metrics = run_training_loop(cfg, device=torch.device(\"cpu\"), distributed=False)\n    assert \"loss\" in metrics\n    assert \"teach_signal_norm\" in metrics\n    assert metrics[\"loss\"] == metrics[\"loss\"]  # NaN check\n"
  },
  {
    "path": "tests/test_build_model_from_cfg_selfmod.py",
    "content": "from omegaconf import OmegaConf\n\nfrom nested_learning.hope.block import HOPESelfModBlock\nfrom nested_learning.training import build_model_from_cfg\n\n\ndef test_build_model_from_cfg_plumbs_selfmod_fields() -> None:\n    model_cfg = OmegaConf.create(\n        {\n            \"type\": \"hope\",\n            \"vocab_size\": 32,\n            \"dim\": 16,\n            \"num_layers\": 1,\n            \"heads\": 4,\n            \"titan_level\": {\"name\": \"titan\", \"update_period\": 1, \"optimizer_key\": \"titan_opt\"},\n            \"cms_levels\": [{\"name\": \"cms_fast\", \"update_period\": 1, \"optimizer_key\": \"cms_opt\"}],\n            \"block_variant\": \"hope_selfmod\",\n            \"self_mod_chunk_size\": 3,\n            \"self_mod_chunk_size_memory\": 7,\n            \"self_mod_objective\": \"dot\",\n            \"self_mod_stopgrad_vhat\": False,\n            \"self_mod_use_rank1_precond\": False,\n            \"self_mod_use_alpha\": False,\n            \"self_mod_momentum\": 0.5,\n        }\n    )\n    model = build_model_from_cfg(model_cfg)\n    assert model.config.self_mod_chunk_size == 3\n    assert model.config.self_mod_chunk_size_memory == 7\n    assert model.config.self_mod_objective == \"dot\"\n    assert model.config.self_mod_stopgrad_vhat is False\n    assert model.config.self_mod_use_rank1_precond is False\n    assert model.config.self_mod_use_alpha is False\n    assert abs(model.config.self_mod_momentum - 0.5) < 1e-9\n\n    block = model.blocks[0]\n    assert isinstance(block, HOPESelfModBlock)\n    assert block.selfmod.config.chunk_size_other == 3\n    assert block.selfmod.config.chunk_size_memory == 7\n    assert block.selfmod.config.objective == \"dot\"\n    assert block.selfmod.config.stopgrad_vhat is False\n    assert block.selfmod.config.use_rank1_precond is False\n    assert block.selfmod.config.use_alpha is False\n    assert abs(block.selfmod.config.momentum - 0.5) < 1e-9\n"
  },
  {
    "path": "tests/test_checkpoint_metadata_and_eval_loaders.py",
    "content": "from __future__ import annotations\n\nimport importlib.util\nimport json\nfrom pathlib import Path\n\nimport torch\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.training import (\n    build_model_from_cfg,\n    verify_checkpoint_integrity,\n    write_checkpoint_metadata,\n)\n\n\ndef _tiny_cfg():\n    return OmegaConf.create(\n        {\n            \"model\": {\n                \"type\": \"hope\",\n                \"vocab_size\": 64,\n                \"dim\": 16,\n                \"num_layers\": 1,\n                \"heads\": 2,\n                \"block_variant\": \"hope_attention\",\n                \"titan_level\": {\"name\": \"titan\", \"update_period\": 1},\n                \"cms_levels\": [{\"name\": \"cms_fast\", \"update_period\": 1}],\n            },\n            \"train\": {\n                \"algorithm_mode\": \"boundary_state_grad_through_write\",\n                \"online_updates\": True,\n                \"online_boundary_targets\": True,\n                \"online_carry_attention_cache\": True,\n                \"use_fast_state\": True,\n            },\n            \"data\": {\n                \"tokenizer_path\": None,\n            },\n        }\n    )\n\n\ndef _load_script_module(script_path: Path):\n    spec = importlib.util.spec_from_file_location(script_path.stem, script_path)\n    assert spec is not None and spec.loader is not None\n    module = importlib.util.module_from_spec(spec)\n    spec.loader.exec_module(module)\n    return module\n\n\ndef test_checkpoint_metadata_includes_algorithm_and_online_flags(tmp_path: Path) -> None:\n    cfg = _tiny_cfg()\n    ckpt_path = tmp_path / \"step_000001.pt\"\n    torch.save({\"model\": {}}, ckpt_path)\n    write_checkpoint_metadata(cfg, ckpt_path, step=1)\n\n    metadata_path = ckpt_path.with_suffix(\".meta.json\")\n    metadata = json.loads(metadata_path.read_text())\n    assert metadata[\"algorithm_mode\"] == \"boundary_state_grad_through_write\"\n    assert metadata[\"online_updates\"] is True\n    assert metadata[\"online_boundary_targets\"] is True\n    assert metadata[\"online_carry_attention_cache\"] is True\n    assert metadata[\"use_fast_state\"] is True\n    verified = verify_checkpoint_integrity(ckpt_path)\n    assert verified[\"algorithm_mode\"] == \"boundary_state_grad_through_write\"\n\n\ndef test_eval_loaders_accept_boundary_state_checkpoint(tmp_path: Path) -> None:\n    cfg = _tiny_cfg()\n    model = build_model_from_cfg(cfg.model)\n    ckpt_path = tmp_path / \"model.pt\"\n    torch.save({\"model\": model.state_dict()}, ckpt_path)\n    config_path = tmp_path / \"config.yaml\"\n    config_path.write_text(OmegaConf.to_yaml(cfg))\n\n    root = Path(__file__).resolve().parents[1]\n    script_paths = (\n        root / \"scripts/eval/zeroshot.py\",\n        root / \"scripts/eval/niah.py\",\n        root / \"scripts/eval/passkey.py\",\n        root / \"scripts/eval/pg19_perplexity.py\",\n    )\n    for script_path in script_paths:\n        module = _load_script_module(script_path)\n        loaded = module.load_model(config_path, ckpt_path, torch.device(\"cpu\"))\n        tokens = torch.randint(0, 64, (1, 6))\n        logits = loaded(tokens)\n        assert logits.shape == (1, 6, 64)\n"
  },
  {
    "path": "tests/test_cli_tooling.py",
    "content": "from __future__ import annotations\n\nimport json\nimport os\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nfrom typer.testing import CliRunner\n\nfrom nested_learning.cli import app\nfrom nested_learning.config_utils import compose_config\n\n\ndef test_doctor_json_output() -> None:\n    runner = CliRunner()\n    result = runner.invoke(app, [\"doctor\", \"--json\"])\n    assert result.exit_code == 0, result.stdout\n    payload = json.loads(result.stdout)\n    assert \"python_version\" in payload\n    assert \"torch_version\" in payload\n    assert \"default_device\" in payload\n\n\ndef test_smoke_cpu_passes() -> None:\n    runner = CliRunner()\n    result = runner.invoke(\n        app,\n        [\n            \"smoke\",\n            \"--config-name\",\n            \"pilot_smoke\",\n            \"--device\",\n            \"cpu\",\n            \"--batch-size\",\n            \"1\",\n            \"--seq-len\",\n            \"8\",\n        ],\n    )\n    assert result.exit_code == 0, result.stdout\n    payload = json.loads(result.stdout)\n    assert payload[\"status\"] == \"ok\"\n    assert payload[\"device\"] == \"cpu\"\n    assert payload[\"logits_shape\"][0] == 1\n    assert payload[\"logits_shape\"][1] == 8\n\n\ndef test_smoke_auto_passes() -> None:\n    runner = CliRunner()\n    result = runner.invoke(\n        app,\n        [\n            \"smoke\",\n            \"--config-name\",\n            \"pilot_smoke\",\n            \"--device\",\n            \"auto\",\n            \"--batch-size\",\n            \"1\",\n            \"--seq-len\",\n            \"8\",\n        ],\n    )\n    assert result.exit_code == 0, result.stdout\n    payload = json.loads(result.stdout)\n    assert payload[\"status\"] == \"ok\"\n\n\ndef test_audit_reports_tied_weights() -> None:\n    runner = CliRunner()\n    result = runner.invoke(app, [\"audit\", \"--config-name\", \"pilot_paper_faithful\"])\n    assert result.exit_code == 0, result.stdout\n    payload = json.loads(result.stdout)\n    assert payload[\"status\"] == \"ok\"\n    assert payload[\"has_embed\"] is True\n    assert payload[\"has_lm_head\"] is True\n    assert payload[\"lm_tied_to_embedding\"] is True\n\n\ndef test_train_dry_run_prints_config() -> None:\n    runner = CliRunner()\n    result = runner.invoke(\n        app,\n        [\n            \"train\",\n            \"--config-name\",\n            \"pilot_smoke\",\n            \"--dry-run\",\n            \"--device\",\n            \"cpu\",\n        ],\n    )\n    assert result.exit_code == 0, result.stdout\n    assert \"model:\" in result.stdout\n    assert \"train:\" in result.stdout\n\n\ndef test_compose_config_with_explicit_config_dir(tmp_path: Path) -> None:\n    cfg_dir = tmp_path / \"configs\"\n    cfg_dir.mkdir(parents=True)\n    (cfg_dir / \"mini.yaml\").write_text(\n        \"model:\\n\"\n        \"  type: hope\\n\"\n        \"  vocab_size: 128\\n\"\n        \"  dim: 32\\n\"\n        \"  num_layers: 1\\n\"\n        \"  heads: 2\\n\"\n        \"  titan_level: {name: titan, update_period: 8, optimizer_key: titan_opt}\\n\"\n        \"  cms_levels: []\\n\"\n        \"  optimizers: {}\\n\"\n        \"data: {source: synthetic, vocab_size: 128, seq_len: 8, dataset_size: 8, batch_size: 1}\\n\"\n        \"train: {device: cpu, steps: 1, log_interval: 1}\\n\",\n        encoding=\"utf-8\",\n    )\n    cfg = compose_config(\"mini\", config_dir=cfg_dir)\n    assert cfg.model.vocab_size == 128\n    assert cfg.train.device == \"cpu\"\n\n\ndef 
test_python_module_entrypoint_help() -> None:\n    repo_root = Path(__file__).resolve().parents[1]\n    env = dict(os.environ)\n    existing = env.get(\"PYTHONPATH\", \"\")\n    env[\"PYTHONPATH\"] = str(repo_root / \"src\") + (os.pathsep + existing if existing else \"\")\n    result = subprocess.run(\n        [sys.executable, \"-m\", \"nested_learning\", \"--help\"],\n        cwd=repo_root,\n        env=env,\n        check=False,\n        capture_output=True,\n        text=True,\n    )\n    assert result.returncode == 0, result.stderr\n    assert \"Nested Learning CLI\" in result.stdout\n"
  },
  {
    "path": "tests/test_cms.py",
    "content": "import torch\n\nfrom nested_learning.cms import CMS\nfrom nested_learning.hope.block import HOPEAttentionBlock, HOPEAttentionBlockConfig\nfrom nested_learning.levels import LevelSpec\n\n\ndef test_cms_forward_preserves_shape() -> None:\n    cms = CMS(\n        dim=16,\n        levels=[LevelSpec(name=\"fast\", update_period=2), LevelSpec(name=\"slow\", update_period=4)],\n    )\n    x = torch.randn(2, 9, 16)\n    out, inputs, outputs = cms(x, return_intermediates=True)\n    assert out.shape == x.shape\n    assert set(inputs.keys()) == {\"fast\", \"slow\"}\n    assert set(outputs.keys()) == {\"fast\", \"slow\"}\n\n\ndef test_cms_can_disable_layernorm() -> None:\n    cms = CMS(\n        dim=16,\n        levels=[LevelSpec(name=\"fast\", update_period=2)],\n        use_layernorm=False,\n    )\n    assert not any(\"net.0\" in name for name, _ in cms.named_parameters())\n    x = torch.randn(2, 9, 16)\n    out = cms(x)\n    assert isinstance(out, torch.Tensor)\n    assert out.shape == x.shape\n\n\ndef test_cms_updates_respect_update_period_tokens() -> None:\n    cfg = HOPEAttentionBlockConfig(\n        dim=16,\n        heads=4,\n        cms_levels=[\n            LevelSpec(name=\"fast\", update_period=2),\n            LevelSpec(name=\"slow\", update_period=4),\n        ],\n        optimizer_configs={},\n    )\n    block = HOPEAttentionBlock(cfg)\n    x = torch.randn(1, 9, 16)\n    teach = torch.randn_like(x)\n    _ = block(x, teach_signal=teach)\n    stats = block.pop_update_stats()\n    assert stats[\"cms.fast\"][\"gate_hit\"] == 4.0\n    assert stats[\"cms.fast\"][\"chunk_tokens\"] == 8.0\n    assert stats[\"cms.slow\"][\"gate_hit\"] == 2.0\n    assert stats[\"cms.slow\"][\"chunk_tokens\"] == 8.0\n\n\ndef test_cms_updates_skip_when_no_signal() -> None:\n    cfg = HOPEAttentionBlockConfig(\n        dim=16,\n        heads=4,\n        cms_levels=[LevelSpec(name=\"fast\", update_period=2)],\n        optimizer_configs={},\n    )\n    block = HOPEAttentionBlock(cfg)\n    x = torch.randn(1, 8, 16)\n    teach = torch.zeros_like(x)\n    _ = block(x, teach_signal=teach)\n    stats = block.pop_update_stats()\n    assert stats == {}\n\n\ndef test_cms_online_updates_affect_later_tokens() -> None:\n    torch.manual_seed(0)\n    cfg_online = HOPEAttentionBlockConfig(\n        dim=16,\n        heads=4,\n        cms_levels=[LevelSpec(name=\"fast\", update_period=2)],\n        optimizer_configs={},\n        cms_online_updates=True,\n    )\n    cfg_offline = HOPEAttentionBlockConfig(\n        dim=16,\n        heads=4,\n        cms_levels=[LevelSpec(name=\"fast\", update_period=2)],\n        optimizer_configs={},\n        cms_online_updates=False,\n    )\n    block_online = HOPEAttentionBlock(cfg_online)\n    block_offline = HOPEAttentionBlock(cfg_offline)\n    x = torch.randn(1, 6, 16)\n    teach = torch.randn_like(x)\n    out_online = block_online(x, teach_signal=teach)\n    out_offline = block_offline(x, teach_signal=teach)\n    assert not torch.allclose(out_online[:, 2:], out_offline[:, 2:])\n"
  },
  {
    "path": "tests/test_cms_cross_call.py",
    "content": "import pytest\nimport torch\n\nfrom nested_learning.fast_state import build_block_fast_state\nfrom nested_learning.hope.block import (\n    HOPEAttentionBlock,\n    HOPEAttentionBlockConfig,\n    HOPEBlock,\n    HOPEBlockConfig,\n    HOPESelfModBlock,\n    HOPESelfModBlockConfig,\n)\nfrom nested_learning.levels import LevelSpec\n\n\ndef _build_variant(variant: str, *, flush_partial: bool):\n    cms_levels = (LevelSpec(name=\"fast\", update_period=4),)\n    if variant == \"attention\":\n        cfg = HOPEAttentionBlockConfig(\n            dim=8,\n            heads=1,\n            cms_levels=cms_levels,\n            cms_online_updates=True,\n            cms_flush_partial_at_end=flush_partial,\n            cms_chunk_reduction=\"sum\",\n        )\n        block = HOPEAttentionBlock(cfg)\n        state = build_block_fast_state(\n            titan_module=None,\n            cms_blocks=dict(block.cms.blocks.items()),\n            specs=cfg.cms_levels,\n            optimizer_configs=cfg.optimizer_configs,\n            default_lr=cfg.self_mod_lr,\n        )\n        return block, state\n    if variant == \"selfmod\":\n        cfg = HOPESelfModBlockConfig(\n            dim=8,\n            cms_levels=cms_levels,\n            cms_online_updates=True,\n            cms_flush_partial_at_end=flush_partial,\n            cms_chunk_reduction=\"sum\",\n            selfmod_chunk_size=1,\n            selfmod_chunk_size_memory=4,\n        )\n        block = HOPESelfModBlock(cfg)\n        state = build_block_fast_state(\n            titan_module=None,\n            cms_blocks=dict(block.cms.blocks.items()),\n            selfmod_module=block.selfmod,\n            specs=cfg.cms_levels,\n            optimizer_configs=cfg.optimizer_configs,\n            default_lr=cfg.self_mod_lr,\n        )\n        return block, state\n    if variant == \"hybrid\":\n        cfg = HOPEBlockConfig(\n            dim=8,\n            heads=1,\n            titan_level=LevelSpec(name=\"titan\", update_period=1),\n            cms_levels=cms_levels,\n            cms_online_updates=True,\n            cms_flush_partial_at_end=flush_partial,\n            cms_chunk_reduction=\"sum\",\n        )\n        block = HOPEBlock(cfg)\n        state = build_block_fast_state(\n            titan_module=block.titan_memory,\n            cms_blocks=dict(block.cms.blocks.items()),\n            specs=(cfg.titan_level, *cfg.cms_levels),\n            optimizer_configs=cfg.optimizer_configs,\n            default_lr=cfg.self_mod_lr,\n        )\n        return block, state\n    raise ValueError(f\"unknown variant {variant}\")\n\n\n@pytest.mark.parametrize(\"variant\", [\"attention\", \"selfmod\", \"hybrid\"])\ndef test_cms_fast_state_buffers_persist_across_calls(variant: str) -> None:\n    torch.manual_seed(0)\n    block, state = _build_variant(variant, flush_partial=False)\n    x = torch.randn(1, 2, 8)\n    teach = torch.ones(1, 2, 8)\n\n    _ = block(x, teach_signal=teach, fast_state=state, finalize_updates=False)\n    first_stats = block.pop_update_stats()\n    assert first_stats[\"cms.fast\"][\"updates_applied\"] == 0.0\n    assert first_stats[\"cms.fast\"][\"pending_tokens\"] == 2.0\n\n    _ = block(x, teach_signal=teach, fast_state=state, finalize_updates=False)\n    payload = block.pop_update_stats()[\"cms.fast\"]\n    assert payload[\"gate_hit\"] == 1.0\n    assert payload[\"chunk_tokens\"] == 4.0\n\n\n@pytest.mark.parametrize(\"variant\", [\"attention\", \"selfmod\", \"hybrid\"])\ndef test_cms_fast_state_flushes_only_on_finalize(variant: str) 
-> None:\n    torch.manual_seed(0)\n    block, state = _build_variant(variant, flush_partial=True)\n    x3 = torch.randn(1, 3, 8)\n    x1 = torch.randn(1, 1, 8)\n    teach3 = torch.ones(1, 3, 8)\n    teach1 = torch.ones(1, 1, 8)\n\n    _ = block(x3, teach_signal=teach3, fast_state=state, finalize_updates=False)\n    first_stats = block.pop_update_stats()\n    assert first_stats[\"cms.fast\"][\"updates_applied\"] == 0.0\n    assert first_stats[\"cms.fast\"][\"pending_tokens\"] == 3.0\n\n    _ = block(x3, teach_signal=teach3, fast_state=state, finalize_updates=False)\n    payload_mid = block.pop_update_stats()[\"cms.fast\"]\n    assert payload_mid[\"gate_hit\"] == 1.0\n    assert payload_mid[\"chunk_tokens\"] == 4.0\n\n    _ = block(x1, teach_signal=teach1, fast_state=state, finalize_updates=True)\n    payload_final = block.pop_update_stats()[\"cms.fast\"]\n    assert payload_final[\"gate_hit\"] == 1.0\n    assert payload_final[\"chunk_tokens\"] == 3.0\n"
  },
  {
    "path": "tests/test_cms_delta_rule.py",
    "content": "import torch\n\nfrom nested_learning.hope.block import _chunk_loss\n\n\ndef test_cms_target_shift_loss_grad_is_proportional_to_delta() -> None:\n    torch.manual_seed(0)\n    prediction = torch.randn(2, 5, 7, requires_grad=True)\n    delta = torch.randn(2, 5, 7)\n    active = torch.tensor(\n        [\n            [1, 1, 0, 1, 1],\n            [0, 1, 1, 1, 0],\n        ],\n        dtype=torch.float32,\n    )\n    mask_f = active.unsqueeze(-1)\n    loss = _chunk_loss(prediction, delta, mask_f, reduction=\"sum\")\n    loss.backward()\n    assert prediction.grad is not None\n    expected = 2.0 * delta * mask_f\n    assert torch.allclose(prediction.grad, expected, atol=1e-6, rtol=1e-6)\n\n\ndef test_cms_chunk_loss_sum_scales_relative_to_mean() -> None:\n    torch.manual_seed(0)\n    prediction = torch.randn(1, 4, 5, requires_grad=True)\n    delta = torch.randn(1, 4, 5)\n    mask_f = torch.ones(1, 4, 1)\n\n    loss_sum = _chunk_loss(prediction, delta, mask_f, reduction=\"sum\")\n    loss_sum.backward()\n    assert prediction.grad is not None\n    grad_sum = prediction.grad.detach().clone()\n\n    prediction.grad.zero_()\n    loss_mean = _chunk_loss(prediction, delta, mask_f, reduction=\"mean\")\n    loss_mean.backward()\n    assert prediction.grad is not None\n    grad_mean = prediction.grad.detach().clone()\n\n    scale = float(mask_f.sum().item())\n    assert torch.allclose(grad_sum, grad_mean * scale, atol=1e-6, rtol=1e-6)\n"
  },
  {
    "path": "tests/test_cms_flush_partial.py",
    "content": "import torch\n\nfrom nested_learning.fast_state import build_block_fast_state\nfrom nested_learning.hope.block import HOPEAttentionBlock, HOPEAttentionBlockConfig\nfrom nested_learning.levels import LevelSpec\n\n\ndef _run_block(*, flush_partial: bool, use_fast_state: bool) -> dict[str, float]:\n    torch.manual_seed(0)\n    cfg = HOPEAttentionBlockConfig(\n        dim=8,\n        heads=1,\n        cms_levels=(LevelSpec(name=\"fast\", update_period=4),),\n        cms_flush_partial_at_end=flush_partial,\n        cms_online_updates=True,\n        cms_chunk_reduction=\"sum\",\n    )\n    block = HOPEAttentionBlock(cfg)\n    x = torch.randn(1, 6, 8)\n    teach = torch.ones(1, 6, 8)\n    fast_state = None\n    if use_fast_state:\n        fast_state = build_block_fast_state(\n            titan_module=None,\n            cms_blocks=dict(block.cms.blocks.items()),\n            specs=cfg.cms_levels,\n            optimizer_configs=cfg.optimizer_configs,\n            default_lr=cfg.self_mod_lr,\n        )\n    _out = block(x, teach_signal=teach, fast_state=fast_state)\n    stats = block.pop_update_stats()\n    return stats[\"cms.fast\"]\n\n\ndef test_cms_flush_partial_disabled_leaves_remainder_unupdated() -> None:\n    for use_fast_state in (False, True):\n        payload = _run_block(flush_partial=False, use_fast_state=use_fast_state)\n        assert payload[\"gate_hit\"] == 1.0\n        assert payload[\"chunk_tokens\"] == 4.0\n\n\ndef test_cms_flush_partial_enabled_updates_final_remainder() -> None:\n    for use_fast_state in (False, True):\n        payload = _run_block(flush_partial=True, use_fast_state=use_fast_state)\n        assert payload[\"gate_hit\"] == 2.0\n        assert payload[\"chunk_tokens\"] == 6.0\n\n"
  },
  {
    "path": "tests/test_compare_variants_cli.py",
    "content": "import json\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nimport sentencepiece as spm\nimport torch\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.training import build_model_from_cfg\n\n\ndef _train_tiny_sentencepiece(tmp_path: Path, *, vocab_size: int) -> Path:\n    corpus_path = tmp_path / \"corpus.txt\"\n    corpus_path.write_text(\n        \"\\n\".join(\n            [\n                \"This is a tiny corpus for sentencepiece.\",\n                \"Remember that the secret key is KEY-1234.\",\n                \"Question: What is the passkey? Answer: PASSKEY-1234.\",\n            ]\n        )\n    )\n    model_prefix = tmp_path / \"spm_test\"\n    spm.SentencePieceTrainer.Train(\n        input=str(corpus_path),\n        model_prefix=str(model_prefix),\n        vocab_size=vocab_size,\n        hard_vocab_limit=False,\n        model_type=\"unigram\",\n        bos_id=1,\n        eos_id=2,\n        pad_id=0,\n        unk_id=3,\n        character_coverage=1.0,\n    )\n    return model_prefix.with_suffix(\".model\")\n\n\ndef _write_minimal_model_config(path: Path, *, vocab_size: int, block_variant: str) -> None:\n    payload = {\n        \"model\": {\n            \"vocab_size\": vocab_size,\n            \"dim\": 16,\n            \"num_layers\": 1,\n            \"heads\": 4,\n            \"block_variant\": block_variant,\n            \"titan_level\": {\"name\": \"titan\", \"update_period\": 1},\n            \"cms_levels\": [{\"name\": \"cms_fast\", \"update_period\": 1}],\n        }\n    }\n    path.write_text(OmegaConf.to_yaml(OmegaConf.create(payload)))\n\n\ndef _write_checkpoint(path: Path, config_path: Path) -> None:\n    cfg = OmegaConf.load(config_path)\n    model = build_model_from_cfg(cfg.model)\n    torch.save({\"model\": model.state_dict()}, path)\n\n\ndef test_compare_variants_cli_smoke(tmp_path: Path) -> None:\n    vocab_size = 64\n    spm_model = _train_tiny_sentencepiece(tmp_path, vocab_size=vocab_size)\n\n    config_a = tmp_path / \"a.yaml\"\n    config_b = tmp_path / \"b.yaml\"\n    _write_minimal_model_config(config_a, vocab_size=vocab_size, block_variant=\"transformer\")\n    _write_minimal_model_config(config_b, vocab_size=vocab_size, block_variant=\"transformer\")\n\n    ckpt_a = tmp_path / \"a.pt\"\n    ckpt_b = tmp_path / \"b.pt\"\n    _write_checkpoint(ckpt_a, config_a)\n    _write_checkpoint(ckpt_b, config_b)\n\n    out_path = tmp_path / \"out.json\"\n    cmd = [\n        sys.executable,\n        \"scripts/eval/compare_variants.py\",\n        \"--a-config\",\n        str(config_a),\n        \"--a-checkpoint\",\n        str(ckpt_a),\n        \"--b-config\",\n        str(config_b),\n        \"--b-checkpoint\",\n        str(ckpt_b),\n        \"--tokenizer-path\",\n        str(spm_model),\n        \"--device\",\n        \"cpu\",\n        \"--smoke\",\n        \"--output\",\n        str(out_path),\n    ]\n    completed = subprocess.run(cmd, check=True, capture_output=True, text=True)\n    assert completed.returncode == 0\n    data = json.loads(out_path.read_text())\n    assert \"a\" in data and \"b\" in data\n    assert \"passkey\" in data[\"a\"] and \"niah\" in data[\"a\"]\n    assert \"accuracy_base\" in data[\"a\"][\"passkey\"]\n"
  },
  {
    "path": "tests/test_compile_toggle.py",
    "content": "from __future__ import annotations\n\nimport torch\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.training import run_training_loop\n\n\ndef _tiny_compile_cfg():\n    return OmegaConf.create(\n        {\n            \"model\": {\n                \"type\": \"hope\",\n                \"vocab_size\": 64,\n                \"dim\": 16,\n                \"num_layers\": 1,\n                \"heads\": 2,\n                \"block_variant\": \"hope_attention\",\n                \"titan_level\": {\"name\": \"titan\", \"update_period\": 1},\n                \"cms_levels\": [{\"name\": \"cms_fast\", \"update_period\": 1}],\n            },\n            \"data\": {\n                \"source\": \"synthetic\",\n                \"vocab_size\": 64,\n                \"seq_len\": 8,\n                \"dataset_size\": 8,\n                \"batch_size\": 1,\n                \"num_workers\": 0,\n            },\n            \"optim\": {\n                \"type\": \"adamw\",\n                \"lr\": 1e-3,\n                \"weight_decay\": 0.0,\n                \"fused\": False,\n                \"param_policy\": \"all\",\n            },\n            \"train\": {\n                \"steps\": 1,\n                \"log_interval\": 1,\n                \"seed\": 0,\n                \"deterministic\": True,\n                \"algorithm_mode\": \"two_pass_stopgrad_updates\",\n                \"per_layer_teach_signal\": True,\n                \"online_updates\": True,\n                \"online_chunk_size\": 2,\n                \"online_boundary_targets\": True,\n                \"online_carry_attention_cache\": True,\n                \"use_fast_state\": True,\n                \"strict_streaming_contract\": False,\n                \"fail_if_paper_faithful_disabled\": False,\n                \"mixed_precision\": {\"enabled\": False, \"dtype\": \"bf16\"},\n                \"compile\": {\"enable\": True, \"strict\": False},\n                \"checkpoint\": {\"enable\": False},\n            },\n            \"logging\": {\"enabled\": False},\n        }\n    )\n\n\ndef test_compile_toggle_smoke_does_not_crash() -> None:\n    cfg = _tiny_compile_cfg()\n    metrics = run_training_loop(cfg, device=torch.device(\"cpu\"), distributed=False)\n    assert \"loss\" in metrics\n"
  },
  {
    "path": "tests/test_compliance_report.py",
    "content": "from __future__ import annotations\n\nimport json\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nfrom omegaconf import OmegaConf\n\n\ndef _run_report(config_path: Path, output_path: Path, repo_root: Path) -> dict:\n    subprocess.run(\n        [\n            sys.executable,\n            str(repo_root / \"scripts/checks/compliance_report.py\"),\n            \"--config\",\n            str(config_path),\n            \"--output\",\n            str(output_path),\n        ],\n        cwd=repo_root,\n        check=True,\n        stdout=subprocess.PIPE,\n        stderr=subprocess.PIPE,\n        text=True,\n    )\n    return json.loads(output_path.read_text())\n\n\ndef test_compliance_report_includes_algorithm_mode_checks(tmp_path: Path) -> None:\n    repo_root = Path(__file__).resolve().parents[1]\n    report = _run_report(\n        config_path=repo_root / \"configs/pilot.yaml\",\n        output_path=tmp_path / \"report.json\",\n        repo_root=repo_root,\n    )\n    checks = {item[\"name\"] for item in report[\"checks\"]}\n    assert \"algorithm_mode_supported\" in checks\n\n\ndef test_compliance_report_validates_boundary_mode_constraints(tmp_path: Path) -> None:\n    repo_root = Path(__file__).resolve().parents[1]\n    cfg = OmegaConf.load(repo_root / \"configs/pilot.yaml\")\n    cfg.train.algorithm_mode = \"boundary_state_grad_through_write\"\n    cfg.train.per_layer_teach_signal = True\n    cfg.train.online_updates = True\n    cfg.train.online_boundary_targets = True\n    cfg.train.online_carry_attention_cache = True\n    cfg.train.use_fast_state = True\n    cfg.data.batch_size = 1\n    tmp_cfg = tmp_path / \"boundary_config.yaml\"\n    tmp_cfg.write_text(OmegaConf.to_yaml(cfg), encoding=\"utf-8\")\n    report = _run_report(\n        config_path=tmp_cfg,\n        output_path=tmp_path / \"boundary_report.json\",\n        repo_root=repo_root,\n    )\n    by_name = {item[\"name\"]: item for item in report[\"checks\"]}\n    assert by_name[\"algorithm_mode_supported\"][\"ok\"] is True\n    assert by_name[\"boundary_algorithm_mode_constraints\"][\"ok\"] is True\n"
  },
  {
    "path": "tests/test_continual_classification.py",
    "content": "from pathlib import Path\n\nimport sentencepiece as spm\nimport torch\n\nfrom nested_learning.continual_classification import ClassificationExample\nfrom nested_learning.continual_streaming import (\n    ContinualEvalConfig,\n    build_streaming_tasks,\n    evaluate_continual_classification,\n)\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.memorize import MemorizeConfig\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.tokenizer import SentencePieceTokenizer\n\n\ndef _train_tiny_sentencepiece(tmp_path: Path, *, vocab_size: int) -> Path:\n    corpus_path = tmp_path / \"corpus.txt\"\n    corpus_path.write_text(\n        \"\\n\".join(\n            [\n                \"Text: hello world Label: A\",\n                \"Text: goodbye world Label: B\",\n                \"Text: foo bar Label: C\",\n                \"Text: baz qux Label: D\",\n            ]\n        )\n    )\n    model_prefix = tmp_path / \"spm_continual\"\n    spm.SentencePieceTrainer.Train(\n        input=str(corpus_path),\n        model_prefix=str(model_prefix),\n        vocab_size=vocab_size,\n        hard_vocab_limit=False,\n        model_type=\"unigram\",\n        bos_id=1,\n        eos_id=2,\n        pad_id=0,\n        unk_id=3,\n        character_coverage=1.0,\n    )\n    return model_prefix.with_suffix(\".model\")\n\n\ndef _tiny_transformer_model(vocab_size: int) -> HOPEModel:\n    cfg = ModelConfig(\n        vocab_size=vocab_size,\n        dim=16,\n        num_layers=1,\n        heads=4,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=[LevelSpec(name=\"cms_fast\", update_period=1)],\n        block_variant=\"transformer\",\n    )\n    return HOPEModel(cfg).eval()\n\n\ndef _toy_examples() -> list[ClassificationExample]:\n    examples = []\n    for label in [\"A\", \"B\", \"C\", \"D\"]:\n        for idx in range(3):\n            examples.append(ClassificationExample(text=f\"example {idx} for {label}\", label=label))\n    return examples\n\n\ndef test_build_streaming_tasks_balanced_split() -> None:\n    cfg = ContinualEvalConfig(task_size=2, seed=0, train_per_label=2, eval_per_label=1)\n    tasks = build_streaming_tasks(_toy_examples(), cfg=cfg)\n    assert len(tasks) == 2\n    for task in tasks:\n        assert len(task.labels) == 2\n        assert len(task.train) == 4\n        assert len(task.eval) == 2\n\n\ndef test_evaluate_continual_classification_runs(tmp_path: Path) -> None:\n    vocab_size = 64\n    spm_model = _train_tiny_sentencepiece(tmp_path, vocab_size=vocab_size)\n    tokenizer = SentencePieceTokenizer(spm_model)\n    model = _tiny_transformer_model(vocab_size)\n\n    eval_cfg = ContinualEvalConfig(task_size=2, seed=0, train_per_label=2, eval_per_label=1)\n    tasks = build_streaming_tasks(_toy_examples(), cfg=eval_cfg)\n\n    memorize_cfg = MemorizeConfig(enabled=False)\n    result, meta = evaluate_continual_classification(\n        model,\n        tokenizer,\n        tasks,\n        torch.device(\"cpu\"),\n        cfg=eval_cfg,\n        memorize_cfg=memorize_cfg,\n    )\n    assert len(result.task_accuracy_matrix) == len(tasks)\n    assert len(result.task_accuracy_matrix[0]) == len(tasks)\n    assert 0.0 <= result.avg_accuracy_final <= 1.0\n    assert \"task_size\" in meta\n\n\ndef test_evaluate_continual_classification_with_memorize_fast_state(tmp_path: Path) -> None:\n    vocab_size = 64\n    spm_model = _train_tiny_sentencepiece(tmp_path, vocab_size=vocab_size)\n    tokenizer = 
SentencePieceTokenizer(spm_model)\n    model = _tiny_transformer_model(vocab_size)\n\n    eval_cfg = ContinualEvalConfig(task_size=2, seed=0, train_per_label=2, eval_per_label=1)\n    tasks = build_streaming_tasks(_toy_examples(), cfg=eval_cfg)\n\n    memorize_cfg = MemorizeConfig(enabled=True, steps=1, reset=False, use_fast_state=True)\n    result, _meta = evaluate_continual_classification(\n        model,\n        tokenizer,\n        tasks,\n        torch.device(\"cpu\"),\n        cfg=eval_cfg,\n        memorize_cfg=memorize_cfg,\n    )\n    assert len(result.per_task_forgetting) == len(tasks)\n"
  },
  {
    "path": "tests/test_continual_eval_state_mode.py",
    "content": "import importlib.util\nfrom pathlib import Path\n\nimport torch\nfrom torch.utils.data import DataLoader, Dataset\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.memorize import MemorizeConfig\nfrom nested_learning.model import HOPEModel, ModelConfig\n\n\ndef _load_evaluate_segment():\n    root = Path(__file__).resolve().parents[1]\n    script_path = root / \"scripts\" / \"eval\" / \"continual.py\"\n    spec = importlib.util.spec_from_file_location(\"tests.continual_eval_script\", script_path)\n    if spec is None or spec.loader is None:\n        raise RuntimeError(f\"Failed to load {script_path}\")\n    module = importlib.util.module_from_spec(spec)\n    spec.loader.exec_module(module)\n    return module.evaluate_segment\n\n\nevaluate_segment = _load_evaluate_segment()\n\n\nclass _TokenDataset(Dataset):\n    def __init__(self) -> None:\n        self.samples = [torch.randint(0, 32, (12,)) for _ in range(6)]\n\n    def __len__(self) -> int:\n        return len(self.samples)\n\n    def __getitem__(self, idx: int) -> torch.Tensor:\n        return self.samples[idx]\n\n\ndef _build_model() -> HOPEModel:\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=8,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=2),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=2),),\n        block_variant=\"transformer\",\n    )\n    return HOPEModel(cfg)\n\n\ndef test_continual_eval_state_modes_run_without_errors() -> None:\n    torch.manual_seed(0)\n    model = _build_model()\n    dataset = _TokenDataset()\n    loader = DataLoader(dataset, batch_size=2, shuffle=False)\n    mem_cfg = MemorizeConfig(enabled=False)\n\n    base_reset, mem_reset, _stats_reset = evaluate_segment(\n        model,\n        loader,\n        torch.device(\"cpu\"),\n        max_batches=2,\n        memorize_cfg=mem_cfg,\n        eval_state_mode=\"reset_per_sample\",\n        eval_use_fast_state=False,\n        eval_use_attention_cache=True,\n    )\n    base_carry, mem_carry, _stats_carry = evaluate_segment(\n        model,\n        loader,\n        torch.device(\"cpu\"),\n        max_batches=2,\n        memorize_cfg=mem_cfg,\n        eval_state_mode=\"carry_across_samples\",\n        eval_use_fast_state=False,\n        eval_use_attention_cache=True,\n    )\n\n    assert torch.isfinite(torch.tensor(base_reset))\n    assert torch.isfinite(torch.tensor(mem_reset))\n    assert torch.isfinite(torch.tensor(base_carry))\n    assert torch.isfinite(torch.tensor(mem_carry))\n"
  },
  {
    "path": "tests/test_data_scripts_help.py",
    "content": "from __future__ import annotations\n\nimport subprocess\nfrom pathlib import Path\n\n\ndef test_data_scripts_help_smoke() -> None:\n    repo_root = Path(__file__).resolve().parents[1]\n    subprocess.run(\n        [\"bash\", \"scripts/checks/check_data_script_help.sh\"],\n        cwd=repo_root,\n        check=True,\n    )\n"
  },
  {
    "path": "tests/test_data_split_fallbacks.py",
    "content": "from __future__ import annotations\n\nimport io\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nsys.path.append(str(Path(__file__).resolve().parents[1]))\n\nfrom scripts.data import filter_corpus, shard_corpus, train_tokenizer\n\n\ndef test_train_tokenizer_manifest_supports_text_data_files(tmp_path: Path) -> None:\n    corpus = tmp_path / \"corpus.txt\"\n    corpus.write_text(\"hello world\\nthis is a test\\nanother line\\n\", encoding=\"utf-8\")\n    manifest = tmp_path / \"manifest.yaml\"\n    manifest.write_text(\n        \"\\n\".join(\n            [\n                \"datasets:\",\n                \"  - name: local\",\n                \"    dataset: text\",\n                \"    split: train\",\n                \"    text_column: text\",\n                f\"    data_files: {corpus}\",\n                \"    sample_limit: 10\",\n                \"\",\n            ]\n        ),\n        encoding=\"utf-8\",\n    )\n    specs = train_tokenizer._load_specs_from_manifest(manifest)  # noqa: SLF001\n    assert len(specs) == 1\n    assert specs[0].dataset == \"text\"\n    assert specs[0].split == \"train\"\n    assert specs[0].data_files == str(corpus)\n    buf = io.StringIO()\n    count = train_tokenizer._write_samples(specs[0], buf)  # noqa: SLF001\n    assert count == 3\n\n\ndef test_shard_corpus_accepts_text_data_files_with_train_split(tmp_path: Path) -> None:\n    corpus = tmp_path / \"corpus.txt\"\n    corpus.write_text((\"hello world \" * 100).strip() + \"\\n\", encoding=\"utf-8\")\n    out_dir = tmp_path / \"shards\"\n    cfg = shard_corpus.ShardConfig(\n        name=\"local\",\n        dataset=\"text\",\n        split=\"train\",\n        subset=None,\n        text_column=\"text\",\n        tokenizer_path=Path(\"tests/data/tiny_tokenizer.model\"),\n        seq_len=4,\n        sequences_per_shard=2,\n        output_dir=out_dir,\n        eos_id=-1,\n        max_records=10,\n        data_files=str(corpus),\n    )\n    stats = shard_corpus.shard_dataset(cfg)\n    assert stats[\"records\"] > 0\n    assert stats[\"sequences\"] > 0\n    assert stats[\"shards\"] > 0\n    assert list(out_dir.glob(\"shard_*.npy\"))\n\n\ndef test_train_tokenizer_allows_small_corpus_with_no_hard_vocab_limit(tmp_path: Path) -> None:\n    corpus = tmp_path / \"corpus.txt\"\n    corpus.write_text((\"hello world\\n\" * 20).strip() + \"\\n\", encoding=\"utf-8\")\n    manifest = tmp_path / \"manifest.yaml\"\n    manifest.write_text(\n        \"\\n\".join(\n            [\n                \"datasets:\",\n                \"  - name: local\",\n                \"    dataset: text\",\n                \"    split: train\",\n                \"    text_column: text\",\n                f\"    data_files: {corpus}\",\n                \"    sample_limit: 50\",\n                \"\",\n            ]\n        ),\n        encoding=\"utf-8\",\n    )\n    out_dir = tmp_path / \"tokenizer\"\n    log_file = tmp_path / \"tokenizer_log.json\"\n    repo_root = Path(__file__).resolve().parents[1]\n    subprocess.run(\n        [\n            sys.executable,\n            str(repo_root / \"scripts/data/train_tokenizer.py\"),\n            \"--manifest\",\n            str(manifest),\n            \"--vocab-size\",\n            \"1000\",\n            \"--model-type\",\n            \"unigram\",\n            \"--output-dir\",\n            str(out_dir),\n            \"--log-file\",\n            str(log_file),\n            \"--no-hard-vocab-limit\",\n        ],\n        check=True,\n        cwd=repo_root,\n        
stdout=subprocess.PIPE,\n        stderr=subprocess.PIPE,\n        text=True,\n    )\n    assert (out_dir / \"spm_1000_unigram.model\").exists()\n    assert log_file.exists()\n\n\ndef test_split_fallback_prefers_validation_then_test() -> None:\n    available = [\"test\", \"validation\"]\n    assert train_tokenizer._select_fallback_split(available) == \"validation\"  # noqa: SLF001\n    assert shard_corpus._select_fallback_split(available) == \"validation\"  # noqa: SLF001\n    assert filter_corpus._select_fallback_split(available) == \"validation\"  # noqa: SLF001\n\n\ndef test_split_fallback_uses_first_when_no_standard_split() -> None:\n    available = [\"dev\", \"holdout\"]\n    assert train_tokenizer._select_fallback_split(available) == \"dev\"  # noqa: SLF001\n    assert shard_corpus._select_fallback_split(available) == \"dev\"  # noqa: SLF001\n    assert filter_corpus._select_fallback_split(available) == \"dev\"  # noqa: SLF001\n"
  },
  {
    "path": "tests/test_determinism_seed.py",
    "content": "import random\n\nimport numpy as np\nimport torch\n\nfrom nested_learning.training import _seed_everything\n\n\ndef test_seed_everything_reproducible_python_numpy_torch() -> None:\n    _seed_everything(1234, deterministic=False)\n    a = (\n        random.random(),\n        float(np.random.rand()),\n        float(torch.rand(1).item()),\n    )\n    _seed_everything(1234, deterministic=False)\n    b = (\n        random.random(),\n        float(np.random.rand()),\n        float(torch.rand(1).item()),\n    )\n    assert a == b\n\n\ndef test_seed_everything_toggles_deterministic_algorithms() -> None:\n    prev_flag = torch.are_deterministic_algorithms_enabled()\n    has_cudnn = hasattr(torch.backends, \"cudnn\")\n    prev_benchmark = None\n    prev_deterministic = None\n    if has_cudnn:\n        prev_benchmark = bool(torch.backends.cudnn.benchmark)  # type: ignore[attr-defined]\n        prev_deterministic = bool(torch.backends.cudnn.deterministic)  # type: ignore[attr-defined]\n    try:\n        _seed_everything(1, deterministic=True)\n        assert torch.are_deterministic_algorithms_enabled()\n        _seed_everything(1, deterministic=False)\n        assert not torch.are_deterministic_algorithms_enabled()\n    finally:\n        torch.use_deterministic_algorithms(prev_flag)\n        if has_cudnn and prev_benchmark is not None and prev_deterministic is not None:\n            torch.backends.cudnn.benchmark = prev_benchmark  # type: ignore[attr-defined]\n            torch.backends.cudnn.deterministic = prev_deterministic  # type: ignore[attr-defined]\n"
  },
  {
    "path": "tests/test_device_resolution.py",
    "content": "import torch\n\nfrom nested_learning.device import resolve_device\n\n\ndef test_resolve_device_mps_falls_back_when_unavailable() -> None:\n    device = resolve_device(\"mps\")\n    mps_available = hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available()\n    assert device.type == (\"mps\" if mps_available else \"cpu\")\n\n"
  },
  {
    "path": "tests/test_distributed_fail_fast.py",
    "content": "import pytest\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.training import _validate_distributed_config\n\n\ndef test_fail_if_paper_faithful_disabled_blocks_ddp_per_layer_teach() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"fail_if_paper_faithful_disabled\": True,\n                \"per_layer_teach_signal\": True,\n                \"online_updates\": False,\n            }\n        }\n    )\n    with pytest.raises(RuntimeError, match=\"per_layer_teach_signal\"):\n        _validate_distributed_config(cfg, distributed=True)\n\n\ndef test_fail_if_paper_faithful_disabled_blocks_ddp_online_updates() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"fail_if_paper_faithful_disabled\": True,\n                \"per_layer_teach_signal\": False,\n                \"online_updates\": True,\n            }\n        }\n    )\n    with pytest.raises(RuntimeError, match=\"online_updates\"):\n        _validate_distributed_config(cfg, distributed=True)\n\n\ndef test_fail_if_paper_faithful_disabled_allows_single_process() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"fail_if_paper_faithful_disabled\": True,\n                \"per_layer_teach_signal\": True,\n                \"online_updates\": True,\n            }\n        }\n    )\n    _validate_distributed_config(cfg, distributed=False)\n\n\ndef test_strict_streaming_contract_blocks_ddp_online_features() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"strict_streaming_contract\": True,\n                \"fail_if_paper_faithful_disabled\": False,\n                \"per_layer_teach_signal\": True,\n                \"online_updates\": False,\n            }\n        }\n    )\n    with pytest.raises(RuntimeError, match=\"strict_streaming_contract\"):\n        _validate_distributed_config(cfg, distributed=True)\n\n\ndef test_fail_if_paper_faithful_disabled_blocks_ddp_boundary_targets() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"fail_if_paper_faithful_disabled\": True,\n                \"per_layer_teach_signal\": False,\n                \"online_updates\": False,\n                \"online_boundary_targets\": True,\n            }\n        }\n    )\n    with pytest.raises(RuntimeError, match=\"online_boundary_targets\"):\n        _validate_distributed_config(cfg, distributed=True)\n\n\ndef test_fail_if_paper_faithful_disabled_blocks_ddp_attention_cache_carry() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"fail_if_paper_faithful_disabled\": True,\n                \"per_layer_teach_signal\": False,\n                \"online_updates\": False,\n                \"online_carry_attention_cache\": True,\n            }\n        }\n    )\n    with pytest.raises(RuntimeError, match=\"online_carry_attention_cache\"):\n        _validate_distributed_config(cfg, distributed=True)\n"
  },
  {
    "path": "tests/test_eval_builders.py",
    "content": "import sys\nfrom pathlib import Path\n\nsys.path.append(str(Path(__file__).resolve().parents[1]))\n\nfrom scripts.eval import zeroshot\n\n\ndef test_commonsenseqa_builder() -> None:\n    sample = {\n        \"question\": \"Where would you most likely find a revolving door?\",\n        \"choices\": {\n            \"label\": [\"A\", \"B\", \"C\"],\n            \"text\": [\"bank\", \"library\", \"garden\"],\n        },\n        \"answerKey\": \"B\",\n    }\n    _, texts, target = zeroshot.build_commonsenseqa_texts(sample)\n    assert len(texts) == 3\n    assert target == 1\n    assert \"library\" in texts[target]\n\n\ndef test_openbookqa_builder() -> None:\n    sample = {\n        \"question_stem\": \"Plants need what to make food?\",\n        \"choices\": {\n            \"label\": [\"A\", \"B\", \"C\", \"D\"],\n            \"text\": [\"sunlight\", \"soil\", \"wind\", \"music\"],\n        },\n        \"answerKey\": \"A\",\n    }\n    _, texts, target = zeroshot.build_openbookqa_texts(sample)\n    assert len(texts) == 4\n    assert target == 0\n    assert \"sunlight\" in texts[target]\n"
  },
  {
    "path": "tests/test_eval_state.py",
    "content": "import torch\n\nfrom nested_learning.eval_state import (\n    forward_with_eval_state,\n    init_eval_streaming_state,\n    parse_eval_state_mode,\n)\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\n\n\ndef _transformer_model() -> HOPEModel:\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=8,\n        num_layers=2,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=2),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=2),),\n        block_variant=\"transformer\",\n    )\n    return HOPEModel(cfg).eval()\n\n\ndef test_parse_eval_state_mode_variants() -> None:\n    assert parse_eval_state_mode(\"reset_per_sample\") is False\n    assert parse_eval_state_mode(\"isolated\") is False\n    assert parse_eval_state_mode(\"carry_across_samples\") is True\n    assert parse_eval_state_mode(\"carry\") is True\n\n\ndef test_forward_with_eval_state_attention_cache_continuity() -> None:\n    torch.manual_seed(0)\n    model = _transformer_model()\n    tokens = torch.randint(0, 32, (1, 7))\n    with torch.no_grad():\n        full = model(tokens)\n        state = init_eval_streaming_state(\n            model,\n            use_fast_state=False,\n            use_attention_cache=True,\n        )\n        logits_a, state = forward_with_eval_state(model, tokens[:, :3], state=state)\n        logits_b, state = forward_with_eval_state(model, tokens[:, 3:], state=state)\n        stitched = torch.cat([logits_a, logits_b], dim=1)\n    assert state is not None\n    assert state.attention_cache is not None\n    assert torch.allclose(full, stitched, atol=1e-5, rtol=1e-5)\n\n\ndef test_forward_with_eval_state_none_state_passthrough() -> None:\n    model = _transformer_model()\n    tokens = torch.randint(0, 32, (1, 4))\n    with torch.no_grad():\n        logits, state = forward_with_eval_state(model, tokens, state=None)\n        expected = model(tokens)\n    assert state is None\n    assert torch.allclose(logits, expected)\n"
  },
  {
    "path": "tests/test_eval_state_cli.py",
    "content": "import importlib.util\nfrom pathlib import Path\n\nfrom typer.testing import CliRunner\n\n\ndef _load_eval_script(name: str):\n    root = Path(__file__).resolve().parents[1]\n    script_path = root / \"scripts\" / \"eval\" / f\"{name}.py\"\n    spec = importlib.util.spec_from_file_location(f\"tests.{name}\", script_path)\n    if spec is None or spec.loader is None:\n        raise RuntimeError(f\"Failed to load {script_path}\")\n    module = importlib.util.module_from_spec(spec)\n    spec.loader.exec_module(module)\n    return module\n\n\nniah = _load_eval_script(\"niah\")\nzeroshot = _load_eval_script(\"zeroshot\")\n\n\ndef test_zeroshot_rejects_carry_eval_state_mode() -> None:\n    runner = CliRunner()\n    result = runner.invoke(\n        zeroshot.app,\n        [\n            \"--config\",\n            \"dummy.yaml\",\n            \"--checkpoint\",\n            \"dummy.pt\",\n            \"--tokenizer-path\",\n            \"dummy.model\",\n            \"--list-tasks\",\n            \"--eval-state-mode\",\n            \"carry_across_samples\",\n        ],\n    )\n    assert result.exit_code != 0\n    assert \"reset_per_sample\" in result.stdout + result.stderr\n\n\ndef test_zeroshot_allows_reset_eval_state_mode_for_task_listing() -> None:\n    runner = CliRunner()\n    result = runner.invoke(\n        zeroshot.app,\n        [\n            \"--config\",\n            \"dummy.yaml\",\n            \"--checkpoint\",\n            \"dummy.pt\",\n            \"--tokenizer-path\",\n            \"dummy.model\",\n            \"--list-tasks\",\n            \"--eval-state-mode\",\n            \"reset_per_sample\",\n        ],\n    )\n    assert result.exit_code == 0\n    assert \"Available tasks:\" in result.stdout\n\n\ndef test_niah_rejects_carry_eval_state_mode_before_loading_inputs(tmp_path: Path) -> None:\n    runner = CliRunner()\n    config = tmp_path / \"cfg.yaml\"\n    ckpt = tmp_path / \"ckpt.pt\"\n    tok = tmp_path / \"tok.model\"\n    config.write_text(\"model: {}\\ntrain: {}\\ndata: {}\\n\")\n    ckpt.write_bytes(b\"\")\n    tok.write_text(\"\")\n    result = runner.invoke(\n        niah.app,\n        [\n            \"--config\",\n            str(config),\n            \"--checkpoint\",\n            str(ckpt),\n            \"--tokenizer-path\",\n            str(tok),\n            \"--context-lengths\",\n            \"32\",\n            \"--samples-per-length\",\n            \"1\",\n            \"--device\",\n            \"cpu\",\n            \"--eval-state-mode\",\n            \"carry_across_samples\",\n        ],\n    )\n    assert result.exit_code != 0\n    assert \"reset_per_sample\" in result.stdout + result.stderr\n"
  },
  {
    "path": "tests/test_faithfulness_harness.py",
    "content": "import torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.training import compute_teach_signal\n\n\ndef _cms_delta_l1(state, level_name: str) -> float:\n    params = state.blocks[0].cms_params[level_name]\n    return float(sum(delta.abs().sum().item() for delta in params.values()))\n\n\ndef test_e2e_update_paths_and_surprise_gate() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=4,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=2),),\n        block_variant=\"hope_selfmod\",\n    )\n    model = HOPEModel(cfg).eval()\n    tokens = torch.randint(0, cfg.vocab_size, (1, 8))\n\n    # Baseline: both selfmod and CMS should update in fast-state mode.\n    state = model.init_fast_state()\n    assert state.blocks[0].selfmod_state is not None\n    cms_before = _cms_delta_l1(state, \"cms_fast\")\n    assert cms_before == 0.0\n    selfmod_before = state.blocks[0].selfmod_state.memory.w2.detach().clone()\n    with torch.no_grad():\n        logits_before = model(tokens, fast_state=state)\n        teach = compute_teach_signal(model, logits_before, tokens)\n        _ = model(tokens, teach_signal=teach, fast_state=state)\n        logits_after = model(tokens, fast_state=state)\n    cms_after = _cms_delta_l1(state, \"cms_fast\")\n    selfmod_after = state.blocks[0].selfmod_state.memory.w2.detach().clone()\n    assert cms_after > 0.0\n    assert not torch.allclose(selfmod_before, selfmod_after)\n    assert not torch.allclose(logits_before, logits_after)\n\n    # Surprise gate: CMS updates should be blocked when threshold exceeds the computed surprise.\n    gated_state = model.init_fast_state()\n    with torch.no_grad():\n        gated_logits = model(tokens, fast_state=gated_state)\n        gated_teach = compute_teach_signal(model, gated_logits, tokens)\n    surprise = float(gated_teach.norm(dim=-1).mean().item())\n    model.set_surprise_threshold(surprise + 1.0)\n    try:\n        with torch.no_grad():\n            _ = model(tokens, teach_signal=gated_teach, fast_state=gated_state)\n        assert _cms_delta_l1(gated_state, \"cms_fast\") == 0.0\n    finally:\n        model.set_surprise_threshold(None)\n\n"
  },
  {
    "path": "tests/test_fast_state_batch_semantics.py",
    "content": "import pytest\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.training import _validate_fast_state_batch_semantics\n\n\ndef test_fast_state_batch_semantics_raises_when_strict() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\"use_fast_state\": True, \"fail_if_paper_faithful_disabled\": True},\n            \"data\": {\"batch_size\": 2},\n        }\n    )\n    with pytest.raises(RuntimeError, match=\"fast state\"):\n        _validate_fast_state_batch_semantics(cfg)\n\n\ndef test_fast_state_batch_semantics_allows_batch1() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\"use_fast_state\": True, \"fail_if_paper_faithful_disabled\": True},\n            \"data\": {\"batch_size\": 1},\n        }\n    )\n    _validate_fast_state_batch_semantics(cfg)\n\n\ndef test_fast_state_batch_semantics_warns_with_structured_payload_when_not_strict(\n    capsys: pytest.CaptureFixture[str],\n) -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\"use_fast_state\": True, \"strict_streaming_contract\": False},\n            \"data\": {\"batch_size\": 3},\n        }\n    )\n    _validate_fast_state_batch_semantics(cfg)\n    captured = capsys.readouterr()\n    assert \"warning_code\" in captured.out\n    assert \"shared_fast_state_batch\" in captured.out\n"
  },
  {
    "path": "tests/test_fast_state_forward_equivalence.py",
    "content": "import torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\n\n\ndef test_fast_state_zero_deltas_matches_meta_forward() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=64,\n        dim=16,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=1),),\n        block_variant=\"hope_hybrid\",\n    )\n    model = HOPEModel(cfg).eval()\n    tokens = torch.randint(0, cfg.vocab_size, (1, 8))\n    fast_state = model.init_fast_state()\n    with torch.no_grad():\n        logits_meta = model(tokens)\n        logits_fast = model(tokens, fast_state=fast_state)\n    assert torch.allclose(logits_meta, logits_fast, atol=1e-6)\n\n"
  },
  {
    "path": "tests/test_fast_state_meta_grads.py",
    "content": "import torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.training import _is_memory_param_name\n\n\ndef test_fast_state_preserves_outer_grads_for_memory_meta_params() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=64,\n        dim=16,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=1),),\n        block_variant=\"hope_hybrid\",\n    )\n    model = HOPEModel(cfg)\n    fast_state = model.init_fast_state()\n    tokens = torch.randint(0, cfg.vocab_size, (1, 8))\n    logits = model(tokens, fast_state=fast_state)\n    loss = torch.nn.functional.cross_entropy(\n        logits[:, :-1].reshape(-1, logits.size(-1)),\n        tokens[:, 1:].reshape(-1),\n    )\n    loss.backward()\n\n    memory_param_names = [\n        name\n        for name, param in model.named_parameters()\n        if param.requires_grad and _is_memory_param_name(name)\n    ]\n    assert memory_param_names, \"Test expected at least one memory parameter\"\n    assert any(\n        model.get_parameter(name).grad is not None for name in memory_param_names\n    ), \"Expected at least one memory parameter grad in fast_state mode\"\n"
  },
  {
    "path": "tests/test_fast_state_selfmod_meta_grads.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\n\n\ndef test_hope_selfmod_fast_state_preserves_meta_forward_at_init() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=64,\n        dim=16,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(),\n        block_variant=\"hope_selfmod\",\n    )\n    model = HOPEModel(cfg).eval()\n    tokens = torch.randint(0, cfg.vocab_size, (1, 8))\n    fast_state = model.init_fast_state()\n    with torch.no_grad():\n        logits_meta = model(tokens)\n        logits_fast = model(tokens, fast_state=fast_state)\n    assert torch.allclose(logits_meta, logits_fast, atol=1e-6)\n\n\ndef test_hope_selfmod_fast_state_preserves_outer_grads_for_meta_memory_init() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=64,\n        dim=16,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(),\n        block_variant=\"hope_selfmod\",\n    )\n    model = HOPEModel(cfg)\n    tokens = torch.randint(0, cfg.vocab_size, (1, 8))\n    fast_state = model.init_fast_state()\n    logits = model(tokens, fast_state=fast_state)\n    loss = F.cross_entropy(\n        logits[:, :-1].reshape(-1, logits.size(-1)),\n        tokens[:, 1:].reshape(-1),\n    )\n    loss.backward()\n    block = model.blocks[0]\n    selfmod = getattr(block, \"selfmod\", None)\n    assert selfmod is not None\n    grad = selfmod.m_memory.w1.weight.grad\n    assert grad is not None\n    assert grad.abs().sum().item() > 0.0\n\n"
  },
  {
    "path": "tests/test_git_tracked_sizes_check.py",
    "content": "from __future__ import annotations\n\nimport subprocess\nfrom pathlib import Path\n\n\ndef test_git_tracked_sizes_check_passes_repo_defaults() -> None:\n    repo_root = Path(__file__).resolve().parents[1]\n    subprocess.run(\n        [\"bash\", \"scripts/checks/check_git_tracked_sizes.sh\"],\n        cwd=repo_root,\n        check=True,\n    )\n"
  },
  {
    "path": "tests/test_hope_block.py",
    "content": "import torch\n\nfrom nested_learning.hope.block import HOPEBlock, HOPEBlockConfig\nfrom nested_learning.levels import LevelSpec\n\n\ndef make_block() -> HOPEBlock:\n    config = HOPEBlockConfig(\n        dim=32,\n        heads=4,\n        titan_level=LevelSpec(name=\"titan\", update_period=2),\n        cms_levels=[LevelSpec(name=\"fast\", update_period=1)],\n    )\n    return HOPEBlock(config)\n\n\ndef test_hope_block_forward() -> None:\n    block = make_block()\n    tokens = torch.randn(2, 8, 32)\n    out = block(tokens)\n    assert out.shape == tokens.shape\n\n\ndef test_hope_block_self_mod() -> None:\n    block = make_block()\n    tokens = torch.randn(2, 8, 32)\n    teach = torch.randn_like(tokens)\n    out = block(tokens, teach_signal=teach)\n    assert out.shape == tokens.shape\n"
  },
  {
    "path": "tests/test_hope_selfmod_fast_state_meta_unchanged.py",
    "content": "import torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.memorize import snapshot_state_dict\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.training import compute_teach_signal\n\n\ndef test_hope_selfmod_fast_state_updates_do_not_mutate_meta_params() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=4,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(),\n        block_variant=\"hope_selfmod\",\n        self_mod_lr=1.0,\n    )\n    model = HOPEModel(cfg)\n    baseline = snapshot_state_dict(model)\n    fast_state = model.init_fast_state()\n    tokens = torch.randint(0, cfg.vocab_size, (1, 8))\n    with torch.no_grad():\n        logits = model(tokens, fast_state=fast_state)\n        teach = compute_teach_signal(model, logits, tokens)\n        _ = model(tokens, teach_signal=teach, fast_state=fast_state)\n    for name, value in model.state_dict().items():\n        assert torch.allclose(baseline[name], value.cpu(), atol=1e-6)\n\n"
  },
  {
    "path": "tests/test_hope_selfmod_integration.py",
    "content": "import torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.training import compute_teach_signal\n\n\ndef test_hope_selfmod_variant_updates_selfmod_state_in_fast_mode() -> None:\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=4,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=[LevelSpec(name=\"cms_fast\", update_period=2)],\n        block_variant=\"hope_selfmod\",\n    )\n    model = HOPEModel(cfg)\n    state = model.init_fast_state()\n    assert state.blocks[0].selfmod_state is not None\n    before = state.blocks[0].selfmod_state.memory.w2.detach().clone()\n\n    tokens = torch.randint(0, cfg.vocab_size, (1, 6))\n    with torch.no_grad():\n        logits = model(tokens, fast_state=state)\n        teach = compute_teach_signal(model, logits, tokens)\n        _ = model(tokens, teach_signal=teach, fast_state=state)\n\n    after = state.blocks[0].selfmod_state.memory.w2.detach().clone()\n    assert not torch.allclose(before.unsqueeze(0), after)\n"
  },
  {
    "path": "tests/test_hope_selfmod_update_pass.py",
    "content": "import torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.training import compute_teach_signal\n\n\ndef test_hope_selfmod_updates_module_params_only_in_update_pass() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=4,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(),\n        block_variant=\"hope_selfmod\",\n        self_mod_lr=1.0,\n    )\n    model = HOPEModel(cfg)\n    tokens = torch.randint(0, cfg.vocab_size, (1, 8))\n    before = model.blocks[0].selfmod.m_memory.w2.weight.detach().clone()\n\n    _ = model(tokens)\n    after_forward = model.blocks[0].selfmod.m_memory.w2.weight.detach().clone()\n    assert torch.allclose(before, after_forward, atol=1e-6, rtol=1e-6)\n\n    with torch.no_grad():\n        logits = model(tokens)\n        teach = compute_teach_signal(model, logits, tokens)\n        _ = model(tokens, teach_signal=teach)\n    after_update = model.blocks[0].selfmod.m_memory.w2.weight.detach().clone()\n    assert not torch.allclose(after_forward, after_update)\n\n"
  },
  {
    "path": "tests/test_levels.py",
    "content": "from nested_learning.levels import LevelClock, LevelSpec\n\n\ndef test_level_clock_updates_on_schedule() -> None:\n    specs = [LevelSpec(name=\"fast\", update_period=1), LevelSpec(name=\"slow\", update_period=3)]\n    clock = LevelClock(specs)\n    updates = []\n    for step in range(5):\n        for spec in specs:\n            if clock.should_update(spec.name):\n                updates.append((step, spec.name))\n                clock.record_update(spec.name)\n        clock.tick()\n    assert updates[0] == (0, \"fast\")\n    assert any(level == \"slow\" for _, level in updates)\n"
  },
  {
    "path": "tests/test_m3.py",
    "content": "import torch\n\nfrom nested_learning.optim.m3 import M3\n\n\ndef test_m3_updates_and_slow_momentum() -> None:\n    torch.manual_seed(0)\n    param = torch.nn.Parameter(torch.ones(2, 2))\n    opt = M3(\n        [param],\n        lr=0.1,\n        beta1=0.9,\n        beta2=0.9,\n        beta3=0.5,\n        alpha=1.0,\n        ns_steps=1,\n        slow_chunk=2,\n        eps=1e-6,\n    )\n    param.grad = torch.ones_like(param)\n    opt.step()\n    first = param.detach().clone()\n    param.grad = torch.ones_like(param)\n    opt.step()\n    state = opt.state[param]\n    assert not torch.allclose(first, param)\n    assert torch.any(state[\"o2\"] != 0)\n\n\ndef test_m3_step_matches_reference_denominator_for_first_update() -> None:\n    # With ns_steps=0 and 1D params, orthogonalization is identity.\n    # This pins the exact first-step denominator/scaling behavior.\n    param = torch.nn.Parameter(torch.tensor([2.0]))\n    grad = torch.tensor([3.0])\n    opt = M3(\n        [param],\n        lr=0.1,\n        beta1=0.5,\n        beta2=0.25,\n        beta3=0.0,\n        alpha=0.0,\n        ns_steps=0,\n        slow_chunk=100,\n        eps=1e-6,\n        weight_decay=0.0,\n    )\n    param.grad = grad.clone()\n    opt.step()\n\n    m1 = 0.5 * grad\n    v = 0.25 * grad * grad\n    expected_update = m1 / (torch.sqrt(v) + 1e-6)\n    expected_param = torch.tensor([2.0]) - 0.1 * expected_update\n    assert torch.allclose(param.detach(), expected_param, atol=1e-6, rtol=1e-6)\n\n\ndef test_m3_two_steps_match_closed_form_without_slow_momentum() -> None:\n    # 1D + ns_steps=0 makes orthogonalization identity, so we can pin exact numerics.\n    param = torch.nn.Parameter(torch.tensor([1.5]))\n    grad = torch.tensor([2.0])\n    lr = 0.05\n    beta1 = 0.2\n    beta2 = 0.3\n    eps = 1e-6\n    opt = M3(\n        [param],\n        lr=lr,\n        beta1=beta1,\n        beta2=beta2,\n        beta3=0.0,\n        alpha=0.0,\n        ns_steps=0,\n        slow_chunk=1000,\n        eps=eps,\n        weight_decay=0.0,\n    )\n    param.grad = grad.clone()\n    opt.step()\n    param.grad = grad.clone()\n    opt.step()\n\n    g = grad\n    m1_1 = beta1 * g\n    v_1 = beta2 * g * g\n    p1 = torch.tensor([1.5]) - lr * (m1_1 / (torch.sqrt(v_1) + eps))\n\n    m1_2 = m1_1 + beta1 * g\n    v_2 = v_1 + beta2 * g * g\n    p2 = p1 - lr * (m1_2 / (torch.sqrt(v_2) + eps))\n    assert torch.allclose(param.detach(), p2, atol=1e-6, rtol=1e-6)\n\n\ndef test_m3_weight_decay_is_included_in_reference_step() -> None:\n    param = torch.nn.Parameter(torch.tensor([2.0]))\n    grad = torch.tensor([3.0])\n    lr = 0.1\n    wd = 0.4\n    eps = 1e-6\n    opt = M3(\n        [param],\n        lr=lr,\n        beta1=0.5,\n        beta2=0.25,\n        beta3=0.0,\n        alpha=0.0,\n        ns_steps=0,\n        slow_chunk=100,\n        eps=eps,\n        weight_decay=wd,\n    )\n    param.grad = grad.clone()\n    opt.step()\n\n    g_eff = grad + wd * torch.tensor([2.0])\n    m1 = 0.5 * g_eff\n    v = 0.25 * g_eff * g_eff\n    expected = torch.tensor([2.0]) - lr * (m1 / (torch.sqrt(v) + eps))\n    assert torch.allclose(param.detach(), expected, atol=1e-6, rtol=1e-6)\n\n\ndef test_m3_slow_buffer_resets_and_o2_updates_on_chunk_boundary() -> None:\n    param = torch.nn.Parameter(torch.ones(2))\n    opt = M3(\n        [param],\n        lr=0.01,\n        beta1=0.0,\n        beta2=0.0,\n        beta3=0.5,\n        alpha=1.0,\n        ns_steps=0,\n        slow_chunk=2,\n        eps=1e-6,\n    )\n    param.grad = torch.ones_like(param)\n    
opt.step()\n    state = opt.state[param]\n    assert torch.allclose(state[\"slow_buffer\"], torch.ones_like(param))\n    assert torch.all(state[\"o2\"] == 0)\n\n    param.grad = torch.ones_like(param)\n    opt.step()\n    state = opt.state[param]\n    # Boundary step consumes accumulated slow buffer into m2/o2 and clears it.\n    assert torch.allclose(state[\"slow_buffer\"], torch.zeros_like(param))\n    assert torch.any(state[\"o2\"] != 0)\n"
  },
  {
    "path": "tests/test_m3_slow_timing.py",
    "content": "import torch\n\nfrom nested_learning.optim.m3 import M3\n\n\ndef test_m3_slow_momentum_applies_next_chunk_not_boundary_step() -> None:\n    param = torch.nn.Parameter(torch.tensor([0.0]))\n    opt = M3(\n        [param],\n        lr=1.0,\n        beta1=1.0,\n        beta2=0.0,\n        beta3=1.0,\n        alpha=1.0,\n        eps=1.0,\n        ns_steps=0,\n        slow_chunk=2,\n        weight_decay=0.0,\n    )\n    param.grad = torch.tensor([1.0])\n    opt.step()\n    param.grad = torch.tensor([1.0])\n    opt.step()\n    # With correct timing, the slow momentum (o2) is updated after step 2 and therefore\n    # does not affect the step-2 update itself.\n    assert torch.allclose(param.detach(), torch.tensor([-3.0]))\n\n"
  },
  {
    "path": "tests/test_memorization.py",
    "content": "import torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.memorize import MemorizeConfig, memorize_tokens, snapshot_state_dict\nfrom nested_learning.model import HOPEModel, ModelConfig\n\n\ndef _tiny_model() -> HOPEModel:\n    titan = LevelSpec(name=\"titan\", update_period=2, optimizer_key=\"titan_opt\")\n    cms = [\n        LevelSpec(name=\"cms_fast\", update_period=1, optimizer_key=\"cms_opt\"),\n        LevelSpec(name=\"cms_mid\", update_period=2, optimizer_key=\"cms_opt\"),\n    ]\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=4,\n        titan_level=titan,\n        cms_levels=cms,\n        optimizers=None,\n        teach_scale=0.1,\n    )\n    return HOPEModel(cfg)\n\n\ndef _tiny_model_update_every_call() -> HOPEModel:\n    titan = LevelSpec(name=\"titan\", update_period=1, optimizer_key=\"titan_opt\")\n    cms = [\n        LevelSpec(name=\"cms_fast\", update_period=1, optimizer_key=\"cms_opt\"),\n        LevelSpec(name=\"cms_mid\", update_period=1, optimizer_key=\"cms_opt\"),\n    ]\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=4,\n        titan_level=titan,\n        cms_levels=cms,\n        optimizers=None,\n        teach_scale=0.1,\n    )\n    return HOPEModel(cfg)\n\n\ndef _tiny_model_with_self_mod_lr(lr: float) -> HOPEModel:\n    titan = LevelSpec(name=\"titan\", update_period=1, optimizer_key=\"titan_opt\")\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=4,\n        titan_level=titan,\n        cms_levels=(),\n        optimizers=None,\n        teach_scale=0.1,\n        self_mod_lr=lr,\n    )\n    return HOPEModel(cfg)\n\n\ndef _fast_titan_delta_norm(fast_state, before: dict[str, torch.Tensor]) -> float:\n    block_state = fast_state.blocks[0]\n    if block_state.titan_params is None:\n        return 0.0\n    total = 0.0\n    for name, value in block_state.titan_params.items():\n        total += (value.cpu() - before[name]).norm().item()\n    return total\n\n\ndef test_memorize_fast_state_does_not_mutate_meta_params() -> None:\n    torch.manual_seed(0)\n    model = _tiny_model()\n    tokens = torch.randint(0, model.config.vocab_size, (1, 8))\n    baseline = snapshot_state_dict(model)\n    fast_state = model.init_fast_state()\n    cfg = MemorizeConfig(enabled=True, steps=2, use_fast_state=True)\n    memorize_tokens(model, tokens, cfg, fast_state=fast_state)\n    for name, param in model.state_dict().items():\n        assert torch.allclose(baseline[name], param.cpu(), atol=1e-6)\n\n\ndef test_memorize_fast_state_changes_outputs_and_resets() -> None:\n    torch.manual_seed(0)\n    model = _tiny_model()\n    tokens = torch.randint(0, model.config.vocab_size, (1, 8))\n    fast_state = model.init_fast_state()\n    with torch.no_grad():\n        logits_before = model(tokens, fast_state=fast_state).detach().clone()\n\n    cfg = MemorizeConfig(enabled=True, steps=1, use_fast_state=True)\n    memorize_tokens(model, tokens, cfg, fast_state=fast_state)\n    with torch.no_grad():\n        logits_after = model(tokens, fast_state=fast_state).detach().clone()\n\n    assert not torch.allclose(logits_before, logits_after)\n\n    reset_state = model.init_fast_state()\n    with torch.no_grad():\n        logits_reset = model(tokens, fast_state=reset_state).detach().clone()\n    assert torch.allclose(logits_before, logits_reset, atol=1e-6)\n\n\ndef test_memorize_respects_surprise_threshold() -> 
None:\n    model = _tiny_model()\n    tokens = torch.randint(0, model.config.vocab_size, (1, 8))\n    fast_state = model.init_fast_state()\n    block_state = fast_state.blocks[0]\n    titan_before = {k: v.cpu().clone() for k, v in block_state.titan_params.items()}  # type: ignore[union-attr]\n    cfg = MemorizeConfig(enabled=True, steps=1, surprise_threshold=1e6, use_fast_state=True)\n    memorize_tokens(model, tokens, cfg, fast_state=fast_state)\n    assert _fast_titan_delta_norm(fast_state, titan_before) == 0.0\n\n\ndef test_memorize_paths_filter_blocks_updates() -> None:\n    model = _tiny_model()\n    tokens = torch.randint(0, model.config.vocab_size, (1, 8))\n    fast_state = model.init_fast_state()\n    block_state = fast_state.blocks[0]\n    titan_before = {k: v.cpu().clone() for k, v in block_state.titan_params.items()}  # type: ignore[union-attr]\n    cfg = MemorizeConfig(enabled=True, steps=1, paths=(), use_fast_state=True)\n    memorize_tokens(model, tokens, cfg, fast_state=fast_state)\n    assert _fast_titan_delta_norm(fast_state, titan_before) == 0.0\n\n\ndef test_memorize_online_chunking_updates_once_per_target() -> None:\n    model = _tiny_model_update_every_call()\n    tokens = torch.randint(0, model.config.vocab_size, (1, 8))\n    fast_state = model.init_fast_state()\n    cfg = MemorizeConfig(enabled=True, online_chunk_size=1, use_fast_state=True)\n    stats = memorize_tokens(model, tokens, cfg, fast_state=fast_state)\n    assert stats[\"titan_update_events\"] == float(tokens.size(1) - 1)\n\n\ndef test_teach_mask_restricts_memorization_updates() -> None:\n    torch.manual_seed(0)\n    model = _tiny_model_update_every_call()\n    tokens = torch.randint(0, model.config.vocab_size, (1, 8))\n    cfg = MemorizeConfig(enabled=True, steps=1, use_fast_state=True, paths=(\"cms_fast\",))\n\n    fast_state_masked = model.init_fast_state()\n    zero_mask = torch.zeros((tokens.size(0), tokens.size(1)))\n    stats_masked = memorize_tokens(\n        model, tokens, cfg, fast_state=fast_state_masked, teach_mask=zero_mask\n    )\n    assert stats_masked[\"cms_fast_update_events\"] == 0.0\n\n    fast_state_full = model.init_fast_state()\n    one_mask = torch.ones((tokens.size(0), tokens.size(1)))\n    stats_full = memorize_tokens(\n        model,\n        tokens,\n        cfg,\n        fast_state=fast_state_full,\n        teach_mask=one_mask,\n    )\n    assert stats_full[\"cms_fast_update_events\"] > 0.0\n\n\ndef test_self_mod_lr_scales_fast_state_update_magnitude() -> None:\n    torch.manual_seed(0)\n    model_hi = _tiny_model_with_self_mod_lr(1e-3)\n    torch.manual_seed(0)\n    model_lo = _tiny_model_with_self_mod_lr(1e-4)\n    tokens = torch.randint(0, model_hi.config.vocab_size, (1, 8))\n\n    state_hi = model_hi.init_fast_state()\n    state_lo = model_lo.init_fast_state()\n    titan_hi_before = {k: v.cpu().clone() for k, v in state_hi.blocks[0].titan_params.items()}  # type: ignore[union-attr]\n    titan_lo_before = {k: v.cpu().clone() for k, v in state_lo.blocks[0].titan_params.items()}  # type: ignore[union-attr]\n\n    cfg = MemorizeConfig(enabled=True, steps=1, paths=(\"titan\",), use_fast_state=True)\n    memorize_tokens(model_hi, tokens, cfg, fast_state=state_hi)\n    memorize_tokens(model_lo, tokens, cfg, fast_state=state_lo)\n\n    hi = _fast_titan_delta_norm(state_hi, titan_hi_before)\n    lo = _fast_titan_delta_norm(state_lo, titan_lo_before)\n    assert hi > lo * 5.0\n"
  },
  {
    "path": "tests/test_model.py",
    "content": "import torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\n\n\ndef test_model_forward() -> None:\n    config = ModelConfig(\n        vocab_size=100,\n        dim=32,\n        num_layers=1,\n        heads=4,\n        titan_level=LevelSpec(name=\"titan\", update_period=2),\n        cms_levels=[LevelSpec(name=\"fast\", update_period=1)],\n    )\n    model = HOPEModel(config)\n    tokens = torch.randint(0, 100, (2, 10))\n    logits = model(tokens)\n    assert logits.shape == (2, 10, 100)\n"
  },
  {
    "path": "tests/test_model_streaming_cadence.py",
    "content": "import pytest\nimport torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\n\n\ndef _metric(metrics: dict[str, float], key: str) -> float:\n    return float(metrics.get(key, 0.0))\n\n\ndef _build_attention_model(*, flush_partial: bool) -> HOPEModel:\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=8,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"fast\", update_period=4),),\n        block_variant=\"hope_attention\",\n        cms_flush_partial_at_end=flush_partial,\n    )\n    return HOPEModel(cfg).eval()\n\n\ndef _build_attention_model_with_period(*, update_period: int) -> HOPEModel:\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=8,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"fast\", update_period=update_period),),\n        block_variant=\"hope_attention\",\n        cms_flush_partial_at_end=False,\n    )\n    return HOPEModel(cfg).eval()\n\n\ndef test_model_streaming_cadence_matches_single_call_counts() -> None:\n    torch.manual_seed(0)\n    model = _build_attention_model(flush_partial=False)\n    key_prefix = \"layer0.cms.fast\"\n\n    full_state = model.init_fast_state()\n    with torch.no_grad():\n        _ = model(\n            torch.randint(0, 32, (1, 4)),\n            teach_signal=torch.ones(1, 4, 8),\n            fast_state=full_state,\n            finalize_updates=False,\n        )\n    full_metrics = model.pop_update_metrics()\n\n    split_state = model.init_fast_state()\n    with torch.no_grad():\n        _ = model(\n            torch.randint(0, 32, (1, 2)),\n            teach_signal=torch.ones(1, 2, 8),\n            fast_state=split_state,\n            finalize_updates=False,\n        )\n    _ = model.pop_update_metrics()\n    with torch.no_grad():\n        _ = model(\n            torch.randint(0, 32, (1, 2)),\n            teach_signal=torch.ones(1, 2, 8),\n            fast_state=split_state,\n            finalize_updates=False,\n        )\n    split_metrics = model.pop_update_metrics()\n\n    assert _metric(full_metrics, f\"{key_prefix}.updates_applied\") == 1.0\n    assert _metric(split_metrics, f\"{key_prefix}.updates_applied\") == 1.0\n    assert _metric(full_metrics, f\"{key_prefix}.chunk_tokens\") == 4.0\n    assert _metric(split_metrics, f\"{key_prefix}.chunk_tokens\") == 4.0\n\n\n@pytest.mark.parametrize(\"update_period\", [2, 4, 8])\ndef test_model_streaming_cadence_matches_for_multiple_periods(update_period: int) -> None:\n    torch.manual_seed(0)\n    model = _build_attention_model_with_period(update_period=update_period)\n    key_prefix = \"layer0.cms.fast\"\n    total_tokens = update_period * 2\n    full_tokens = torch.randint(0, 32, (1, total_tokens))\n    full_teach = torch.ones(1, total_tokens, 8)\n\n    full_state = model.init_fast_state()\n    with torch.no_grad():\n        _ = model(\n            full_tokens,\n            teach_signal=full_teach,\n            fast_state=full_state,\n            finalize_updates=False,\n        )\n    full_metrics = model.pop_update_metrics()\n\n    split_state = model.init_fast_state()\n    for _ in range(2):\n        with torch.no_grad():\n            _ = model(\n                torch.randint(0, 32, (1, update_period)),\n                teach_signal=torch.ones(1, update_period, 8),\n                fast_state=split_state,\n          
      finalize_updates=False,\n            )\n    split_metrics = model.pop_update_metrics()\n\n    assert _metric(full_metrics, f\"{key_prefix}.updates_applied\") == 2.0\n    assert _metric(split_metrics, f\"{key_prefix}.updates_applied\") == 1.0\n    assert _metric(full_metrics, f\"{key_prefix}.chunk_tokens\") == float(total_tokens)\n    assert _metric(split_metrics, f\"{key_prefix}.chunk_tokens\") == float(update_period)\n\n\ndef test_model_finalize_flushes_partial_once() -> None:\n    torch.manual_seed(0)\n    model = _build_attention_model(flush_partial=True)\n    state = model.init_fast_state()\n    key_prefix = \"layer0.cms.fast\"\n\n    with torch.no_grad():\n        _ = model(\n            torch.randint(0, 32, (1, 3)),\n            teach_signal=torch.ones(1, 3, 8),\n            fast_state=state,\n            finalize_updates=False,\n        )\n    first = model.pop_update_metrics()\n    assert _metric(first, f\"{key_prefix}.updates_applied\") == 0.0\n    assert _metric(first, f\"{key_prefix}.pending_tokens\") == 3.0\n\n    with torch.no_grad():\n        _ = model(\n            torch.randint(0, 32, (1, 3)),\n            teach_signal=torch.ones(1, 3, 8),\n            fast_state=state,\n            finalize_updates=False,\n        )\n    second = model.pop_update_metrics()\n    assert _metric(second, f\"{key_prefix}.updates_applied\") == 1.0\n    assert _metric(second, f\"{key_prefix}.chunk_tokens\") == 4.0\n    assert _metric(second, f\"{key_prefix}.pending_tokens\") == 2.0\n\n    with torch.no_grad():\n        _ = model(\n            torch.randint(0, 32, (1, 1)),\n            teach_signal=torch.ones(1, 1, 8),\n            fast_state=state,\n            finalize_updates=True,\n        )\n    final = model.pop_update_metrics()\n    assert _metric(final, f\"{key_prefix}.updates_applied\") == 1.0\n    assert _metric(final, f\"{key_prefix}.tokens_flushed\") == 3.0\n    assert _metric(final, f\"{key_prefix}.pending_tokens\") == 0.0\n\n\ndef test_slow_cms_level_does_not_starve_under_segmented_calls() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=8,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(\n            LevelSpec(name=\"fast\", update_period=2),\n            LevelSpec(name=\"slow\", update_period=8),\n        ),\n        block_variant=\"hope_attention\",\n        cms_flush_partial_at_end=False,\n    )\n    model = HOPEModel(cfg).eval()\n    state = model.init_fast_state()\n    saw_slow_update = False\n    for _ in range(4):\n        with torch.no_grad():\n            _ = model(\n                torch.randint(0, 32, (1, 2)),\n                teach_signal=torch.ones(1, 2, 8),\n                fast_state=state,\n                finalize_updates=False,\n            )\n        metrics = model.pop_update_metrics()\n        if _metric(metrics, \"layer0.cms.slow.updates_applied\") > 0:\n            saw_slow_update = True\n    assert saw_slow_update\n"
  },
  {
    "path": "tests/test_online_chunking.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.training import (\n    _compute_layer_teach_signals,\n    _iter_online_boundary_chunks,\n    _iter_online_token_chunks,\n)\n\n\ndef test_online_chunking_carries_boundary_overlap_and_token_pairs() -> None:\n    tokens = torch.arange(10).view(1, 10)\n    chunks = list(_iter_online_token_chunks(tokens, chunk_size=4))\n\n    lengths = [chunk.size(1) for chunk, _ in chunks]\n    finals = [is_final for _, is_final in chunks]\n    pairs = sum(chunk.size(1) - 1 for chunk, _ in chunks)\n\n    assert lengths == [4, 5, 3]\n    assert finals == [False, False, True]\n    assert pairs == tokens.size(1) - 1\n\n\ndef test_online_chunking_supports_chunk_size_one() -> None:\n    tokens = torch.arange(5).view(1, 5)\n    chunks = list(_iter_online_token_chunks(tokens, chunk_size=1))\n    # First chunk has length 1 (no CE pairs), remaining chunks have overlap length 2.\n    lengths = [chunk.size(1) for chunk, _ in chunks]\n    assert lengths[0] == 1\n    assert all(length == 2 for length in lengths[1:])\n    assert sum(chunk.size(1) - 1 for chunk, _ in chunks) == tokens.size(1) - 1\n\n\ndef test_online_chunking_chunk_size_one_preserves_train_loop_token_accounting() -> None:\n    tokens = torch.arange(9).view(1, 9)\n    total_pairs = 0\n    for chunk, _finalize in _iter_online_token_chunks(tokens, chunk_size=1):\n        if chunk.size(1) <= 1:\n            continue\n        total_pairs += chunk.size(1) - 1\n    assert total_pairs == tokens.size(1) - 1\n\n\ndef test_online_boundary_chunks_emit_next_tokens_and_exact_target_count() -> None:\n    tokens = torch.arange(10).view(1, 10)\n    chunks = list(_iter_online_boundary_chunks(tokens, chunk_size=4))\n    lengths = [chunk.size(1) for chunk, _next, _final in chunks]\n    next_tokens = [None if nxt is None else int(nxt[0].item()) for _chunk, nxt, _final in chunks]\n    finals = [is_final for _chunk, _next, is_final in chunks]\n    target_count = sum(chunk.size(1) - 1 + (0 if nxt is None else 1) for chunk, nxt, _ in chunks)\n    assert lengths == [4, 4, 2]\n    assert next_tokens == [4, 8, None]\n    assert finals == [False, False, True]\n    assert target_count == tokens.size(1) - 1\n\n\ndef _supervised_targets_overlap(tokens: torch.Tensor, chunk_size: int) -> list[int]:\n    targets: list[int] = []\n    for chunk, _ in _iter_online_token_chunks(tokens, chunk_size=chunk_size):\n        if chunk.size(1) <= 1:\n            continue\n        targets.extend(chunk[:, 1:].reshape(-1).tolist())\n    return targets\n\n\ndef _supervised_targets_boundary(tokens: torch.Tensor, chunk_size: int) -> list[int]:\n    targets: list[int] = []\n    for chunk, next_tokens, _ in _iter_online_boundary_chunks(tokens, chunk_size=chunk_size):\n        if chunk.size(1) > 1:\n            targets.extend(chunk[:, 1:].reshape(-1).tolist())\n        if next_tokens is not None:\n            targets.extend(next_tokens.reshape(-1).tolist())\n    return targets\n\n\ndef _build_transformer_model() -> HOPEModel:\n    cfg = ModelConfig(\n        vocab_size=64,\n        dim=16,\n        num_layers=2,\n        heads=4,\n        titan_level=LevelSpec(name=\"titan\", update_period=2),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=2),),\n        block_variant=\"transformer\",\n    )\n    return HOPEModel(cfg).eval()\n\n\ndef test_online_boundary_chunked_loss_matches_monolithic_with_attention_cache() -> 
None:\n    torch.manual_seed(0)\n    model = _build_transformer_model()\n    tokens = torch.randint(0, 64, (1, 11))\n    with torch.no_grad():\n        full_logits = model(tokens)\n        full_loss = F.cross_entropy(\n            full_logits[:, :-1].reshape(-1, full_logits.size(-1)),\n            tokens[:, 1:].reshape(-1),\n        )\n\n        cache = model.init_attention_cache()\n        total_pairs = 0\n        total_loss = 0.0\n        for chunk, next_tokens, _ in _iter_online_boundary_chunks(tokens, chunk_size=4):\n            logits, cache = model(\n                chunk,\n                attention_cache=cache,\n                return_attention_cache=True,\n            )\n            if next_tokens is None:\n                targets = chunk[:, 1:]\n            else:\n                targets = torch.cat([chunk[:, 1:], next_tokens.unsqueeze(1)], dim=1)\n            if targets.numel() == 0:\n                continue\n            loss = F.cross_entropy(\n                logits[:, : targets.size(1), :].reshape(-1, logits.size(-1)),\n                targets.reshape(-1),\n            )\n            total_pairs += targets.size(1)\n            total_loss += float(loss.item()) * targets.size(1)\n        chunked_loss = total_loss / max(total_pairs, 1)\n    assert total_pairs == tokens.size(1) - 1\n    assert torch.isclose(torch.tensor(chunked_loss), full_loss, atol=1e-6, rtol=1e-6)\n\n\ndef test_online_target_coverage_property_randomized() -> None:\n    torch.manual_seed(0)\n    for seq_len in range(2, 25):\n        tokens = torch.arange(seq_len).view(1, seq_len)\n        expected = list(range(1, seq_len))\n        for chunk_size in range(1, seq_len + 1):\n            overlap_targets = _supervised_targets_overlap(tokens, chunk_size)\n            boundary_targets = _supervised_targets_boundary(tokens, chunk_size)\n            assert sorted(overlap_targets) == expected\n            assert sorted(boundary_targets) == expected\n\n\ndef test_chunk_schedule_permutations_preserve_supervision_set() -> None:\n    tokens = torch.arange(17).view(1, 17)\n    baseline_overlap = sorted(_supervised_targets_overlap(tokens, chunk_size=3))\n    baseline_boundary = sorted(_supervised_targets_boundary(tokens, chunk_size=3))\n    for chunk_size in (1, 2, 4, 5, 8, 16):\n        assert (\n            sorted(_supervised_targets_overlap(tokens, chunk_size=chunk_size))\n            == baseline_overlap\n        )\n        assert (\n            sorted(_supervised_targets_boundary(tokens, chunk_size=chunk_size))\n            == baseline_boundary\n        )\n\n\ndef test_per_layer_teach_with_boundary_chunks_runs_update_path() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=2),),\n        block_variant=\"hope_attention\",\n    )\n    model = HOPEModel(cfg).eval()\n    tokens = torch.randint(0, cfg.vocab_size, (1, 8))\n    state = model.init_fast_state()\n    cache = model.init_attention_cache()\n\n    chunk, next_tokens, finalize_updates = next(_iter_online_boundary_chunks(tokens, chunk_size=4))\n    logits, _pre, block_outputs, cache = model.forward_with_block_outputs(\n        chunk,\n        fast_state=state,\n        attention_cache=cache,\n        return_attention_cache=True,\n    )\n    assert next_tokens is not None\n    targets = torch.cat([chunk[:, 1:], next_tokens.unsqueeze(1)], dim=1)\n    loss = 
F.cross_entropy(\n        logits[:, : targets.size(1), :].reshape(-1, logits.size(-1)),\n        targets.reshape(-1),\n    )\n    teach_signals = _compute_layer_teach_signals(loss, block_outputs)\n    loss.backward()\n    with torch.no_grad():\n        _ = model(\n            chunk,\n            teach_signals=teach_signals,\n            fast_state=state,\n            attention_cache=cache,\n            finalize_updates=finalize_updates,\n        )\n    metrics = model.pop_update_metrics()\n    assert \"layer0.cms.cms_fast.gate_hit\" in metrics\n"
  },
  {
    "path": "tests/test_optim.py",
    "content": "import torch\n\nfrom nested_learning.optim.deep import DeepMomentum\n\n\ndef test_deep_momentum_nl_preconditioner_projects_grad() -> None:\n    grad = torch.randn(4, 6)\n    context = torch.randn(6)\n    optimizer = DeepMomentum(beta=0.0, beta2=0.0, variant=\"nl_l2_precond\")\n    update = optimizer(grad, context=context)\n    unit = context / context.norm()\n    expected = grad - (grad * unit).sum(dim=-1, keepdim=True) * unit\n    assert torch.allclose(update, expected, atol=1e-5, rtol=1e-4)\n    assert optimizer.last_metrics[\"ctx_norm\"] > 0\n    assert optimizer.last_metrics[\"proj_norm\"] >= 0\n\n\ndef test_deep_momentum_nl_preconditioner_reduces_simple_objective() -> None:\n    torch.manual_seed(0)\n    context = torch.randn(6)\n    weights = torch.randn(6)\n    grad = torch.dot(weights, context) * context\n    optimizer = DeepMomentum(beta=0.0, beta2=0.0, variant=\"nl_l2_precond\")\n    update = optimizer(grad, context=context)\n    with torch.no_grad():\n        old_obj = 0.5 * torch.dot(weights, context) ** 2\n        new_weights = weights - 0.1 * update\n        new_obj = 0.5 * torch.dot(new_weights, context) ** 2\n    assert new_obj < old_obj\n\n\ndef test_deep_momentum_keeps_state_per_param_key() -> None:\n    optimizer = DeepMomentum(beta=0.5, beta2=0.0, variant=\"preconditioned\")\n    grad_a = torch.ones(2, 3)\n    grad_b = torch.ones(5)\n    out_a1 = optimizer(grad_a, param_key=\"a\").detach().clone()\n    _ = optimizer(grad_b, param_key=\"b\")\n    out_a2 = optimizer(grad_a, param_key=\"a\").detach().clone()\n    assert out_a2.shape == out_a1.shape\n    assert torch.all(out_a2 > out_a1)\n    assert set(optimizer.state.keys()) == {\"a\", \"b\"}\n\n\ndef test_deep_momentum_nl_preconditioner_skips_mismatched_shapes() -> None:\n    optimizer = DeepMomentum(beta=0.0, beta2=0.0, variant=\"nl_l2_precond\")\n    context = torch.randn(512)\n    grad_bias = torch.randn(2048)\n    out = optimizer(grad_bias, context=context)\n    assert torch.allclose(out, grad_bias)\n    assert optimizer.last_metrics[\"proj_skipped\"] == 1.0\n\n\ndef test_deep_momentum_nl_preconditioner_outputs_orthogonal_update() -> None:\n    torch.manual_seed(0)\n    grad = torch.randn(3, 5)\n    context = torch.randn(5)\n    optimizer = DeepMomentum(beta=0.0, beta2=0.0, variant=\"nl_l2_precond\")\n    update = optimizer(grad, context=context)\n    unit = context / context.norm()\n    # Each row update should be orthogonal to the context direction.\n    proj = (update * unit).sum(dim=-1).abs()\n    assert torch.all(proj < 1e-5)\n"
  },
  {
    "path": "tests/test_optimizer_param_policy.py",
    "content": "import torch\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.training import _build_optimizer, _is_memory_param_name\n\n\ndef _make_small_hope_model() -> HOPEModel:\n    return HOPEModel(\n        ModelConfig(\n            vocab_size=128,\n            dim=16,\n            num_layers=2,\n            heads=2,\n            titan_level=LevelSpec(name=\"titan\", update_period=2),\n            cms_levels=(LevelSpec(name=\"cms_fast\", update_period=1),),\n            block_variant=\"hope_hybrid\",\n        )\n    )\n\n\ndef _optimizer_param_set(optimizer: torch.optim.Optimizer) -> set[torch.nn.Parameter]:\n    params: set[torch.nn.Parameter] = set()\n    for group in optimizer.param_groups:\n        for param in group[\"params\"]:\n            params.add(param)\n    return params\n\n\ndef test_param_policy_all_includes_all_trainable_params() -> None:\n    model = _make_small_hope_model()\n    cfg = OmegaConf.create({\"optim\": {\"type\": \"adamw\", \"lr\": 1e-3, \"param_policy\": \"all\"}})\n    optimizer = _build_optimizer(model, cfg, device=torch.device(\"cpu\"))\n    opt_params = _optimizer_param_set(optimizer)\n    expected = {p for _n, p in model.named_parameters() if p.requires_grad}\n    assert opt_params == expected\n    has_memory = any(\n        _is_memory_param_name(name) for name, p in model.named_parameters() if p.requires_grad\n    )\n    assert has_memory\n\n\ndef test_param_policy_exclude_memory_drops_memory_params() -> None:\n    model = _make_small_hope_model()\n    cfg = OmegaConf.create(\n        {\"optim\": {\"type\": \"adamw\", \"lr\": 1e-3, \"param_policy\": \"exclude_memory\"}}\n    )\n    optimizer = _build_optimizer(model, cfg, device=torch.device(\"cpu\"))\n    opt_params = _optimizer_param_set(optimizer)\n    expected = {\n        p\n        for name, p in model.named_parameters()\n        if p.requires_grad and not _is_memory_param_name(name)\n    }\n    assert opt_params == expected\n    contains_memory = any(\n        _is_memory_param_name(name) for name, p in model.named_parameters() if p in opt_params\n    )\n    assert not contains_memory\n\n\ndef test_param_policy_only_memory_keeps_only_memory_params() -> None:\n    model = _make_small_hope_model()\n    cfg = OmegaConf.create(\n        {\"optim\": {\"type\": \"adamw\", \"lr\": 1e-3, \"param_policy\": \"only_memory\"}}\n    )\n    optimizer = _build_optimizer(model, cfg, device=torch.device(\"cpu\"))\n    opt_params = _optimizer_param_set(optimizer)\n    expected = {\n        p\n        for name, p in model.named_parameters()\n        if p.requires_grad and _is_memory_param_name(name)\n    }\n    assert expected\n    assert opt_params == expected\n"
  },
  {
    "path": "tests/test_package_release_script.py",
    "content": "from __future__ import annotations\n\nimport json\nimport subprocess\nfrom pathlib import Path\n\n\ndef test_package_script_includes_train_flags_and_excludes_raw_data(tmp_path: Path) -> None:\n    repo_root = Path(__file__).resolve().parents[1]\n    script_src = repo_root / \"scripts/package_pilot_release.sh\"\n    script_dst = tmp_path / \"scripts/package_pilot_release.sh\"\n    script_dst.parent.mkdir(parents=True)\n    script_dst.write_text(script_src.read_text(), encoding=\"utf-8\")\n\n    # Minimal repo layout expected by package script.\n    (tmp_path / \"artifacts/checkpoints/pilot\").mkdir(parents=True)\n    (tmp_path / \"configs\").mkdir(parents=True)\n    (tmp_path / \"data/raw\").mkdir(parents=True)\n    (tmp_path / \"data/raw/secret.txt\").write_text(\"do-not-copy\", encoding=\"utf-8\")\n    (tmp_path / \"configs/pilot.yaml\").write_text(\"model: {}\\ntrain: {}\\n\", encoding=\"utf-8\")\n\n    ckpt = tmp_path / \"artifacts/checkpoints/pilot/step_000001.pt\"\n    ckpt.write_bytes(b\"dummy-checkpoint\")\n    (tmp_path / \"artifacts/checkpoints/pilot/step_000001.meta.json\").write_text(\n        json.dumps(\n            {\n                \"algorithm_mode\": \"two_pass_stopgrad_updates\",\n                \"online_updates\": True,\n                \"online_boundary_targets\": False,\n                \"online_carry_attention_cache\": False,\n                \"use_fast_state\": False,\n            }\n        ),\n        encoding=\"utf-8\",\n    )\n    (tmp_path / \"artifacts/checkpoints/pilot/step_000001.yaml\").write_text(\n        \"model: {}\\ntrain: {}\\n\",\n        encoding=\"utf-8\",\n    )\n    (tmp_path / \"artifacts/checkpoints/pilot/step_000001.sha256\").write_text(\n        \"abc  step_000001.pt\\n\",\n        encoding=\"utf-8\",\n    )\n\n    subprocess.run([\"bash\", str(script_dst)], cwd=tmp_path, check=True)\n\n    manifest = (tmp_path / \"artifacts/pilot_release/MANIFEST.txt\").read_text(encoding=\"utf-8\")\n    assert \"HOPE Train Flags:\" in manifest\n    assert \"algorithm_mode='two_pass_stopgrad_updates'\" in manifest\n    assert not (tmp_path / \"artifacts/pilot_release/secret.txt\").exists()\n"
  },
  {
    "path": "tests/test_paper_faithful_configs.py",
    "content": "from pathlib import Path\n\nfrom hydra import compose, initialize_config_dir\nfrom hydra.core.global_hydra import GlobalHydra\n\nfrom nested_learning.training import build_model_from_cfg\n\n\ndef _compose_config(name: str, overrides: list[str] | None = None):\n    config_dir = Path(__file__).resolve().parents[1] / \"configs\"\n    GlobalHydra.instance().clear()\n    with initialize_config_dir(version_base=None, config_dir=str(config_dir)):\n        return compose(config_name=name, overrides=overrides or [])\n\n\ndef test_pilot_paper_faithful_config_composes() -> None:\n    cfg = _compose_config(\"pilot_paper_faithful\")\n    assert cfg.model.block_variant == \"hope_attention\"\n    assert cfg.model.cms_flush_partial_at_end is True\n    assert cfg.model.surprise_threshold is None\n    assert cfg.data.batch_size == 1\n    assert cfg.train.use_fast_state is True\n    assert cfg.train.strict_streaming_contract is True\n    assert cfg.train.online_updates is True\n    assert cfg.train.online_boundary_targets is True\n    assert cfg.train.online_carry_attention_cache is True\n    assert cfg.train.fail_if_paper_faithful_disabled is True\n    assert cfg.train.algorithm_mode == \"two_pass_stopgrad_updates\"\n    assert cfg.optim.param_policy == \"all\"\n    build_model_from_cfg(cfg.model)\n\n\ndef test_pilot_selfmod_paper_faithful_config_composes() -> None:\n    cfg = _compose_config(\"pilot_selfmod_paper_faithful\")\n    assert cfg.model.block_variant == \"hope_selfmod\"\n    assert cfg.model.cms_flush_partial_at_end is True\n    assert cfg.model.surprise_threshold is None\n    assert cfg.model.self_mod_use_skip is False\n    assert cfg.data.batch_size == 1\n    assert cfg.train.use_fast_state is True\n    assert cfg.train.strict_streaming_contract is True\n    assert cfg.train.online_updates is True\n    assert cfg.train.online_boundary_targets is True\n    assert cfg.train.online_carry_attention_cache is True\n    assert cfg.train.fail_if_paper_faithful_disabled is True\n    assert cfg.optim.param_policy == \"all\"\n    build_model_from_cfg(cfg.model)\n\n\ndef test_paper_faithful_variants_are_explicitly_paper_defined() -> None:\n    attention_cfg = _compose_config(\"pilot_paper_faithful\")\n    selfmod_cfg = _compose_config(\"pilot_selfmod_paper_faithful\")\n    allowed = {\"hope_attention\", \"hope_selfmod\"}\n    assert attention_cfg.model.block_variant in allowed\n    assert selfmod_cfg.model.block_variant in allowed\n\n\ndef test_pilot_paper_faithful_override_to_boundary_state_mode_applies() -> None:\n    cfg = _compose_config(\n        \"pilot_paper_faithful\",\n        overrides=[\"train.algorithm_mode=boundary_state_grad_through_write\"],\n    )\n    assert cfg.train.algorithm_mode == \"boundary_state_grad_through_write\"\n\n\ndef test_pilot_paper_faithful_never_implicitly_falls_back_to_stopgrad() -> None:\n    cfg = _compose_config(\n        \"pilot_paper_faithful\",\n        overrides=[\"train.algorithm_mode=boundary_state_grad_through_write\"],\n    )\n    assert cfg.train.algorithm_mode != \"two_pass_stopgrad_updates\"\n"
  },
  {
    "path": "tests/test_phase2_memorization_delta.py",
    "content": "import torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.memorize import MemorizeConfig, memorize_tokens\nfrom nested_learning.model import HOPEModel, ModelConfig\n\n\ndef _tiny_variant(variant: str) -> HOPEModel:\n    titan = LevelSpec(name=\"titan\", update_period=1, optimizer_key=\"titan_opt\")\n    cms = (LevelSpec(name=\"cms_fast\", update_period=1, optimizer_key=\"cms_opt\"),)\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=4,\n        titan_level=titan,\n        cms_levels=cms,\n        optimizers=None,\n        teach_scale=0.1,\n        block_variant=variant,\n    )\n    return HOPEModel(cfg).eval()\n\n\ndef test_hope_attention_adapts_transformer_does_not() -> None:\n    tokens = torch.randint(0, 32, (1, 16), generator=torch.Generator().manual_seed(1337))\n    cfg = MemorizeConfig(enabled=True, steps=1, use_fast_state=True, paths=(\"cms_fast\",))\n\n    torch.manual_seed(0)\n    hope = _tiny_variant(\"hope_attention\")\n    state = hope.init_fast_state()\n    with torch.no_grad():\n        before = hope(tokens, fast_state=state).detach().clone()\n    stats = memorize_tokens(hope, tokens, cfg, fast_state=state)\n    with torch.no_grad():\n        after = hope(tokens, fast_state=state).detach().clone()\n    assert not torch.allclose(before, after)\n    assert stats[\"cms_fast_update_events\"] > 0.0\n\n    torch.manual_seed(0)\n    transformer = _tiny_variant(\"transformer\")\n    state = transformer.init_fast_state()\n    with torch.no_grad():\n        before = transformer(tokens, fast_state=state).detach().clone()\n    stats = memorize_tokens(transformer, tokens, cfg, fast_state=state)\n    with torch.no_grad():\n        after = transformer(tokens, fast_state=state).detach().clone()\n    assert torch.allclose(before, after, atol=0.0, rtol=0.0)\n    assert stats[\"cms_fast_update_events\"] == 0.0\n\n"
  },
  {
    "path": "tests/test_residual_mlp_memory.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom nested_learning.titan.self_modifying import ResidualMLPMemory\n\n\ndef test_residual_mlp_memory_matches_eq91_when_dims_match() -> None:\n    torch.manual_seed(0)\n    mem = ResidualMLPMemory(in_dim=8, out_dim=8, hidden_dim=8, activation=F.gelu, use_skip=False)\n    assert mem.w_skip is None\n    x = torch.randn(2, 5, 8)\n    with torch.no_grad():\n        expected = x + mem.w1(mem.activation(mem.w2(x)))\n        actual = mem(x)\n    assert torch.allclose(actual, expected, atol=1e-6, rtol=1e-6)\n\n\ndef test_residual_mlp_memory_uses_projection_skip_when_dims_differ() -> None:\n    mem = ResidualMLPMemory(in_dim=8, out_dim=1, hidden_dim=8, activation=F.gelu, use_skip=True)\n    assert mem.w_skip is not None\n\n\ndef test_residual_mlp_memory_disables_projection_skip_in_faithful_mode() -> None:\n    torch.manual_seed(0)\n    mem = ResidualMLPMemory(in_dim=8, out_dim=1, hidden_dim=8, activation=F.gelu, use_skip=False)\n    assert mem.w_skip is None\n    x = torch.randn(2, 5, 8)\n    with torch.no_grad():\n        expected = mem.w1(mem.activation(mem.w2(x)))\n        actual = mem(x)\n    assert torch.allclose(actual, expected, atol=1e-6, rtol=1e-6)\n"
  },
  {
    "path": "tests/test_run_features.py",
    "content": "from __future__ import annotations\n\nimport torch\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.training import _log_run_features\n\n\nclass _CaptureLogger:\n    def __init__(self) -> None:\n        self.entries: list[tuple[dict[str, object], int]] = []\n\n    def log(self, data: dict[str, object], step: int) -> None:\n        self.entries.append((data, step))\n\n\ndef _tiny_model() -> HOPEModel:\n    cfg = ModelConfig(\n        vocab_size=64,\n        dim=16,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=1),),\n        block_variant=\"hope_attention\",\n    )\n    return HOPEModel(cfg)\n\n\ndef _tiny_cfg(algorithm_mode: str) -> object:\n    return OmegaConf.create(\n        {\n            \"train\": {\n                \"mixed_precision\": {\"enabled\": False, \"dtype\": \"bf16\"},\n                \"compile\": {\"enable\": False, \"mode\": \"default\"},\n                \"strict_streaming_contract\": True,\n                \"online_updates\": True,\n                \"online_boundary_targets\": True,\n                \"online_carry_attention_cache\": True,\n                \"use_fast_state\": True,\n                \"algorithm_mode\": algorithm_mode,\n            },\n            \"optim\": {\"param_policy\": \"all\"},\n        }\n    )\n\n\ndef test_run_features_reports_stopgrad_mode_flag() -> None:\n    model = _tiny_model()\n    cfg = _tiny_cfg(\"two_pass_stopgrad_updates\")\n    logger = _CaptureLogger()\n    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)\n    _log_run_features(logger, model, cfg, optimizer, torch.device(\"cpu\"))\n    payload, step = logger.entries[-1]\n    assert step == -1\n    assert payload[\"train.algorithm_mode\"] == \"two_pass_stopgrad_updates\"\n    assert payload[\"train.backprop_through_online_writes\"] is False\n\n\ndef test_run_features_reports_boundary_state_mode_flag() -> None:\n    model = _tiny_model()\n    cfg = _tiny_cfg(\"boundary_state_grad_through_write\")\n    logger = _CaptureLogger()\n    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)\n    _log_run_features(logger, model, cfg, optimizer, torch.device(\"cpu\"))\n    payload, _ = logger.entries[-1]\n    assert payload[\"train.algorithm_mode\"] == \"boundary_state_grad_through_write\"\n    assert payload[\"train.backprop_through_online_writes\"] is True\n"
  },
  {
    "path": "tests/test_self_modifying_titans.py",
    "content": "import torch\n\nfrom nested_learning.titan.self_modifying import SelfModifyingTitans, SelfModifyingTitansConfig\n\n\ndef test_self_modifying_titans_forward_shape() -> None:\n    model = SelfModifyingTitans(SelfModifyingTitansConfig(dim=8))\n    x = torch.randn(2, 5, 8)\n    out = model(x)\n    assert out.shape == x.shape\n\n\ndef test_self_modifying_titans_updates_fast_state() -> None:\n    torch.manual_seed(0)\n    model = SelfModifyingTitans(SelfModifyingTitansConfig(dim=8, eta_scale=1.0))\n    x = torch.randn(1, 6, 8)\n    state = model.init_fast_state()\n    before = state.memory.w2.detach().clone()\n    out, updated = model.forward_with_updates(x, state)\n    assert out.shape == (1, 6, 8)\n    assert not torch.allclose(before.unsqueeze(0), updated.memory.w2)\n\n\ndef test_self_modifying_titans_supports_batch_fast_state_updates() -> None:\n    torch.manual_seed(0)\n    model = SelfModifyingTitans(SelfModifyingTitansConfig(dim=8, eta_scale=1.0))\n    x = torch.randn(2, 6, 8)\n    state = model.init_fast_state()\n    out, updated = model.forward_with_updates(x, state)\n    assert out.shape == (2, 6, 8)\n    assert updated.memory.w2.shape == (2, 8, 8)\n    assert not torch.allclose(updated.memory.w2[0], updated.memory.w2[1])\n\n\ndef test_self_modifying_titans_chunked_outputs_match_no_update_with_single_chunk() -> None:\n    torch.manual_seed(0)\n    seq_len = 6\n    model = SelfModifyingTitans(\n        SelfModifyingTitansConfig(\n            dim=8,\n            eta_scale=1.0,\n            chunk_size_other=seq_len,\n            chunk_size_memory=seq_len,\n        )\n    )\n    x = torch.randn(1, seq_len, 8)\n    state = model.init_fast_state()\n    before = state.memory.w2.detach().clone()\n\n    out_no_update = model.forward_with_state(x, state)\n    out_chunked, updated = model.forward_with_updates(x, state)\n\n    assert torch.allclose(out_chunked, out_no_update, atol=1e-6)\n    assert not torch.allclose(before.unsqueeze(0), updated.memory.w2)\n\n\ndef test_self_modifying_titans_flushes_partial_chunks_for_memory_updates() -> None:\n    torch.manual_seed(0)\n    model = SelfModifyingTitans(\n        SelfModifyingTitansConfig(\n            dim=8,\n            eta_scale=1.0,\n            chunk_size_other=1,\n            chunk_size_memory=4,\n        )\n    )\n    state = model.init_fast_state()\n    x = torch.randn(1, 3, 8)\n    before_other = state.k.w2.detach().clone()\n    before_memory = state.memory.w2.detach().clone()\n\n    _out, updated = model.forward_with_updates(x, state)\n\n    assert not torch.allclose(before_other.unsqueeze(0), updated.k.w2)\n    assert not torch.allclose(before_memory.unsqueeze(0), updated.memory.w2)\n"
  },
  {
    "path": "tests/test_selfmod_adaptive_q.py",
    "content": "import torch\n\nfrom nested_learning.titan.self_modifying import SelfModifyingTitans, SelfModifyingTitansConfig\n\n\ndef test_selfmod_fixed_q_does_not_update_q_memory() -> None:\n    torch.manual_seed(0)\n    model = SelfModifyingTitans(SelfModifyingTitansConfig(dim=8, eta_scale=1.0, adaptive_q=False))\n    x = torch.randn(1, 6, 8)\n    state = model.init_fast_state()\n    before = state.q.w2.detach().clone()\n    _out, updated = model.forward_with_updates(x, state)\n    assert torch.allclose(before.unsqueeze(0), updated.q.w2, atol=1e-6, rtol=1e-6)\n\n\ndef test_selfmod_adaptive_q_updates_q_memory() -> None:\n    torch.manual_seed(0)\n    model = SelfModifyingTitans(SelfModifyingTitansConfig(dim=8, eta_scale=1.0, adaptive_q=True))\n    x = torch.randn(1, 6, 8)\n    state = model.init_fast_state()\n    before = state.q.w2.detach().clone()\n    _out, updated = model.forward_with_updates(x, state)\n    assert not torch.allclose(before.unsqueeze(0), updated.q.w2)\n\n"
  },
  {
    "path": "tests/test_selfmod_dgd_linear.py",
    "content": "import torch\n\nfrom nested_learning.titan.self_modifying import (\n    ResidualMLPMemoryState,\n    SelfModifyingTitans,\n    SelfModifyingTitansConfig,\n)\n\n\ndef test_selfmod_linear_memory_l2_grad_matches_analytic() -> None:\n    torch.manual_seed(0)\n    model = SelfModifyingTitans(\n        SelfModifyingTitansConfig(\n            dim=4,\n            eta_scale=1.0,\n            objective=\"l2\",\n            stopgrad_vhat=True,\n            use_rank1_precond=False,\n            use_alpha=False,\n            local_conv_window=None,\n        )\n    )\n    w_skip = torch.randn(4, 4)\n    frozen = ResidualMLPMemoryState(\n        w1=torch.zeros_like(model.m_memory.w1.weight),\n        w2=torch.zeros_like(model.m_memory.w2.weight),\n        w_skip=w_skip.clone(),\n    )\n    k = torch.randn(3, 4)\n    v = torch.randn(3, 4)\n    g1, g2, gskip = model._memory_grads(frozen, k, v)\n    assert gskip is not None\n\n    with torch.no_grad():\n        pred = k @ w_skip.t()\n        vhat = v @ w_skip.t()\n        diff = pred - vhat\n        expected = 2.0 * torch.einsum(\"bi,bj->bij\", diff, k).sum(dim=0)\n\n    assert torch.allclose(gskip, expected, atol=1e-6, rtol=1e-6)\n    assert torch.allclose(g1, torch.zeros_like(g1), atol=1e-6, rtol=1e-6)\n    assert torch.allclose(g2, torch.zeros_like(g2), atol=1e-6, rtol=1e-6)\n\n"
  },
  {
    "path": "tests/test_selfmod_grad_flow.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\n\n\ndef test_hope_selfmod_forward_allows_outer_gradients() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=4,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(),\n        block_variant=\"hope_selfmod\",\n    )\n    model = HOPEModel(cfg)\n    tokens = torch.randint(0, cfg.vocab_size, (2, 6))\n    logits = model(tokens)\n    loss = F.cross_entropy(\n        logits[:, :-1].reshape(-1, cfg.vocab_size),\n        tokens[:, 1:].reshape(-1),\n    )\n    loss.backward()\n    block = model.blocks[0]\n    selfmod = getattr(block, \"selfmod\", None)\n    assert selfmod is not None\n    grad = selfmod.m_memory.w1.weight.grad\n    assert grad is not None\n    assert grad.abs().sum().item() > 0.0\n\n"
  },
  {
    "path": "tests/test_selfmod_local_conv.py",
    "content": "import torch\n\nfrom nested_learning.titan.self_modifying import SelfModifyingTitans, SelfModifyingTitansConfig\n\n\ndef test_selfmod_local_conv_is_causal() -> None:\n    torch.manual_seed(0)\n    model = SelfModifyingTitans(SelfModifyingTitansConfig(dim=4, local_conv_window=4))\n    assert model.local_conv is not None\n    with torch.no_grad():\n        model.local_conv.weight.fill_(1.0)\n    x = torch.zeros(1, 6, 4)\n    x[0, 4, 0] = 1.0\n    y = model._apply_local_conv(x)\n    assert torch.allclose(y[0, :4, 0], torch.zeros(4))\n    assert y[0, 4, 0].item() != 0.0\n\n"
  },
  {
    "path": "tests/test_selfmod_online.py",
    "content": "import torch\n\nfrom nested_learning.fast_state import build_block_fast_state\nfrom nested_learning.hope.block import HOPESelfModBlock, HOPESelfModBlockConfig\nfrom nested_learning.levels import LevelSpec\n\n\ndef test_selfmod_updates_on_update_pass_even_with_zero_teach_signal() -> None:\n    cfg = HOPESelfModBlockConfig(\n        dim=8,\n        cms_levels=[LevelSpec(name=\"fast\", update_period=1)],\n        optimizer_configs={},\n        selfmod_online_updates=True,\n    )\n    block = HOPESelfModBlock(cfg)\n    fast_state = build_block_fast_state(\n        titan_module=None,\n        cms_blocks=block.cms.blocks,\n        selfmod_module=block.selfmod,\n        specs=cfg.cms_levels,\n        optimizer_configs={},\n        default_lr=cfg.self_mod_lr,\n    )\n    assert fast_state.selfmod_state is not None\n    x = torch.randn(1, 4, 8)\n    teach = torch.zeros_like(x)\n    before = fast_state.selfmod_state.memory.w1.clone()\n    _ = block(x, teach_signal=teach, fast_state=fast_state)\n    after = fast_state.selfmod_state.memory.w1\n    assert not torch.allclose(before, after)\n"
  },
  {
    "path": "tests/test_strict_streaming_contract.py",
    "content": "import pytest\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.training import (\n    _check_online_supervised_pairs,\n    _resolve_algorithm_mode,\n    _validate_algorithm_mode_constraints,\n    _validate_online_chunking_constraints,\n    _validate_online_update_fast_state_semantics,\n    _validate_paper_auditing_variant,\n)\n\n\ndef test_strict_streaming_contract_rejects_non_paper_variant() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\"strict_streaming_contract\": True},\n            \"model\": {\"block_variant\": \"hope_hybrid\"},\n        }\n    )\n    with pytest.raises(RuntimeError, match=\"paper-defined HOPE variant\"):\n        _validate_paper_auditing_variant(cfg)\n\n\ndef test_non_strict_streaming_contract_warns_for_non_paper_variant(\n    capsys: pytest.CaptureFixture[str],\n) -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\"strict_streaming_contract\": False},\n            \"model\": {\"block_variant\": \"hope_hybrid\"},\n        }\n    )\n    _validate_paper_auditing_variant(cfg)\n    out = capsys.readouterr().out\n    assert \"warning_code\" in out\n    assert \"non_paper_variant\" in out\n\n\ndef test_strict_streaming_contract_allows_paper_defined_variants() -> None:\n    for variant in (\"hope_attention\", \"hope_selfmod\"):\n        cfg = OmegaConf.create(\n            {\n                \"train\": {\"strict_streaming_contract\": True},\n                \"model\": {\"block_variant\": variant},\n            }\n        )\n        _validate_paper_auditing_variant(cfg)\n\n\ndef test_online_updates_without_fast_state_warns_when_not_strict(\n    capsys: pytest.CaptureFixture[str],\n) -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"online_updates\": True,\n                \"use_fast_state\": False,\n                \"strict_streaming_contract\": False,\n                \"fail_if_paper_faithful_disabled\": False,\n            }\n        }\n    )\n    _validate_online_update_fast_state_semantics(cfg)\n    out = capsys.readouterr().out\n    assert \"warning_code\" in out\n    assert \"online_updates_without_fast_state\" in out\n\n\ndef test_online_updates_without_fast_state_fails_in_strict_mode() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"online_updates\": True,\n                \"use_fast_state\": False,\n                \"strict_streaming_contract\": True,\n                \"fail_if_paper_faithful_disabled\": False,\n            }\n        }\n    )\n    with pytest.raises(RuntimeError, match=\"train.online_updates=true\"):\n        _validate_online_update_fast_state_semantics(cfg)\n\n\ndef test_online_supervised_pairs_mismatch_warns_when_not_strict(\n    capsys: pytest.CaptureFixture[str],\n) -> None:\n    _check_online_supervised_pairs(strict=False, observed_pairs=3, seq_len=8)\n    out = capsys.readouterr().out\n    assert \"online_supervision_mismatch\" in out\n\n\ndef test_online_supervised_pairs_mismatch_fails_in_strict_mode() -> None:\n    with pytest.raises(RuntimeError, match=\"online chunk supervision mismatch\"):\n        _check_online_supervised_pairs(strict=True, observed_pairs=3, seq_len=8)\n\n\ndef test_algorithm_mode_defaults_to_two_pass_stopgrad_updates() -> None:\n    cfg = OmegaConf.create({\"train\": {}})\n    assert _resolve_algorithm_mode(cfg) == \"two_pass_stopgrad_updates\"\n\n\ndef test_algorithm_mode_rejects_unknown_values() -> None:\n    cfg = OmegaConf.create({\"train\": 
{\"algorithm_mode\": \"unknown\"}})\n    with pytest.raises(RuntimeError, match=\"Unsupported train.algorithm_mode\"):\n        _resolve_algorithm_mode(cfg)\n\n\ndef test_algorithm_mode_accepts_boundary_state_mode_name() -> None:\n    cfg = OmegaConf.create({\"train\": {\"algorithm_mode\": \"boundary_state_grad_through_write\"}})\n    assert _resolve_algorithm_mode(cfg) == \"boundary_state_grad_through_write\"\n\n\ndef test_boundary_state_mode_requires_online_per_layer_and_fast_state() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"algorithm_mode\": \"boundary_state_grad_through_write\",\n                \"online_updates\": False,\n                \"per_layer_teach_signal\": False,\n                \"use_fast_state\": False,\n            }\n        }\n    )\n    with pytest.raises(RuntimeError, match=\"online_updates=true\"):\n        _validate_algorithm_mode_constraints(\n            cfg,\n            algorithm_mode=\"boundary_state_grad_through_write\",\n            distributed=False,\n        )\n\n\ndef test_boundary_state_mode_rejects_distributed() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"algorithm_mode\": \"boundary_state_grad_through_write\",\n                \"online_updates\": True,\n                \"per_layer_teach_signal\": True,\n                \"use_fast_state\": True,\n            }\n        }\n    )\n    with pytest.raises(RuntimeError, match=\"not supported in DDP\"):\n        _validate_algorithm_mode_constraints(\n            cfg,\n            algorithm_mode=\"boundary_state_grad_through_write\",\n            distributed=True,\n        )\n\n\ndef test_boundary_state_mode_emits_experimental_warning(\n    capsys: pytest.CaptureFixture[str],\n) -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"algorithm_mode\": \"boundary_state_grad_through_write\",\n                \"online_updates\": True,\n                \"per_layer_teach_signal\": True,\n                \"use_fast_state\": True,\n                \"online_boundary_targets\": True,\n                \"online_carry_attention_cache\": True,\n            }\n        }\n    )\n    _validate_algorithm_mode_constraints(\n        cfg,\n        algorithm_mode=\"boundary_state_grad_through_write\",\n        distributed=False,\n    )\n    out = capsys.readouterr().out\n    assert \"experimental_boundary_state_mode\" in out\n\n\ndef test_online_cache_requires_boundary_targets() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"online_updates\": True,\n                \"online_boundary_targets\": False,\n                \"online_carry_attention_cache\": True,\n            }\n        }\n    )\n    with pytest.raises(RuntimeError, match=\"online_boundary_targets=true\"):\n        _validate_online_chunking_constraints(cfg)\n\n\ndef test_online_cache_requires_online_updates() -> None:\n    cfg = OmegaConf.create(\n        {\n            \"train\": {\n                \"online_updates\": False,\n                \"online_boundary_targets\": True,\n                \"online_carry_attention_cache\": True,\n            }\n        }\n    )\n    with pytest.raises(RuntimeError, match=\"online_updates=true\"):\n        _validate_online_chunking_constraints(cfg)\n"
  },
  {
    "path": "tests/test_surprise_metric.py",
    "content": "import torch\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.training import compute_teach_signal\n\n\ndef _cms_delta_l1(state, level_name: str) -> float:\n    params = state.blocks[0].cms_params[level_name]\n    return float(sum(delta.abs().sum().item() for delta in params.values()))\n\n\ndef _logit_entropy(logits: torch.Tensor) -> float:\n    logits_detached = logits[:, :-1].detach().float()\n    probs = torch.softmax(logits_detached, dim=-1)\n    entropy = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1).mean()\n    return float(entropy.item())\n\n\ndef _next_token_loss(logits: torch.Tensor, tokens: torch.Tensor) -> float:\n    loss = torch.nn.functional.cross_entropy(\n        logits[:, :-1].reshape(-1, logits.size(-1)),\n        tokens[:, 1:].reshape(-1),\n    )\n    return float(loss.detach().item())\n\n\ndef test_surprise_metric_loss_gates_updates_when_threshold_set() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=2),),\n        block_variant=\"hope_attention\",\n        surprise_metric=\"loss\",\n        surprise_threshold=0.0,\n    )\n    model = HOPEModel(cfg).eval()\n    state = model.init_fast_state()\n    tokens = torch.randint(0, cfg.vocab_size, (1, 8))\n\n    with torch.no_grad():\n        logits = model(tokens, fast_state=state)\n        teach = compute_teach_signal(model, logits, tokens)\n        loss_value = _next_token_loss(logits, tokens)\n        model.set_surprise_threshold(loss_value + 1.0)\n        _ = model(\n            tokens,\n            teach_signal=teach,\n            surprise_value=loss_value,\n            fast_state=state,\n        )\n    assert _cms_delta_l1(state, \"cms_fast\") == 0.0\n\n    with torch.no_grad():\n        model.set_surprise_threshold(loss_value - 1.0)\n        _ = model(\n            tokens,\n            teach_signal=teach,\n            surprise_value=loss_value,\n            fast_state=state,\n        )\n    assert _cms_delta_l1(state, \"cms_fast\") > 0.0\n\n\ndef test_surprise_metric_entropy_gates_updates_when_threshold_set() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=2),),\n        block_variant=\"hope_attention\",\n        surprise_metric=\"logit_entropy\",\n        surprise_threshold=0.0,\n    )\n    model = HOPEModel(cfg).eval()\n    state = model.init_fast_state()\n    tokens = torch.randint(0, cfg.vocab_size, (1, 8))\n\n    with torch.no_grad():\n        logits = model(tokens, fast_state=state)\n        teach = compute_teach_signal(model, logits, tokens)\n        entropy_value = _logit_entropy(logits)\n        model.set_surprise_threshold(entropy_value + 1.0)\n        _ = model(\n            tokens,\n            teach_signal=teach,\n            surprise_value=entropy_value,\n            fast_state=state,\n        )\n    assert _cms_delta_l1(state, \"cms_fast\") == 0.0\n\n    with torch.no_grad():\n        model.set_surprise_threshold(entropy_value - 1.0)\n        _ = model(\n            tokens,\n            teach_signal=teach,\n            surprise_value=entropy_value,\n            
fast_state=state,\n        )\n    assert _cms_delta_l1(state, \"cms_fast\") > 0.0\n\n\ndef test_surprise_metric_requires_external_value_when_threshold_set() -> None:\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=2),),\n        block_variant=\"hope_attention\",\n        surprise_metric=\"loss\",\n        surprise_threshold=0.1,\n    )\n    model = HOPEModel(cfg).eval()\n    tokens = torch.randint(0, cfg.vocab_size, (1, 8))\n    state = model.init_fast_state()\n    with torch.no_grad():\n        logits = model(tokens, fast_state=state)\n        teach = compute_teach_signal(model, logits, tokens)\n        try:\n            _ = model(tokens, teach_signal=teach, fast_state=state)\n        except ValueError as err:\n            assert \"requires passing surprise_value\" in str(err)\n        else:\n            raise AssertionError(\"Expected ValueError when surprise_value is omitted\")\n\n\ndef test_surprise_metric_l2_uses_chunk_gate_then_token_mask() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=1),),\n        block_variant=\"hope_attention\",\n        surprise_metric=\"l2\",\n        surprise_threshold=4.0,\n    )\n    model = HOPEModel(cfg).eval()\n    tokens = torch.randint(0, cfg.vocab_size, (1, 8))\n\n    with torch.no_grad():\n        state_a = model.init_fast_state()\n        logits = model(tokens, fast_state=state_a)\n        teach = compute_teach_signal(model, logits, tokens)\n        sparse_teach = torch.zeros_like(teach)\n        sparse_teach[:, 0, :] = 6.0\n        # Mean L2 surprise stays below threshold, so chunk-level gate suppresses updates.\n        _ = model(tokens, teach_signal=sparse_teach, fast_state=state_a)\n        assert _cms_delta_l1(state_a, \"cms_fast\") == 0.0\n\n        state_b = model.init_fast_state()\n        # Force chunk-level gate open; token-level mask still keeps only high-norm positions.\n        _ = model(\n            tokens,\n            teach_signal=sparse_teach,\n            surprise_value=10.0,\n            fast_state=state_b,\n        )\n        assert _cms_delta_l1(state_b, \"cms_fast\") > 0.0\n"
  },
  {
    "path": "tests/test_surprise_override.py",
    "content": "import torch\n\nfrom nested_learning.training import _compute_surprise_override\n\n\ndef _entropy(logits: torch.Tensor) -> float:\n    probs = torch.softmax(logits.float(), dim=-1)\n    ent = -(probs * torch.log(probs.clamp(min=1e-9))).sum(dim=-1).mean()\n    return float(ent.item())\n\n\ndef test_logit_entropy_surprise_uses_boundary_target_step_when_present() -> None:\n    torch.manual_seed(0)\n    logits = torch.randn(1, 4, 13)\n    tokens = torch.randint(0, 13, (1, 4))\n    value = _compute_surprise_override(\n        \"logit_entropy\",\n        logits=logits,\n        tokens=tokens,\n        loss=torch.tensor(1.0),\n        next_tokens=torch.randint(0, 13, (1,)),\n    )\n    assert value is not None\n    assert abs(value - _entropy(logits[:, :4])) < 1e-8\n\n\ndef test_logit_entropy_surprise_default_excludes_last_unsupervised_step() -> None:\n    torch.manual_seed(1)\n    logits = torch.randn(1, 5, 11)\n    tokens = torch.randint(0, 11, (1, 5))\n    value = _compute_surprise_override(\n        \"logit_entropy\",\n        logits=logits,\n        tokens=tokens,\n        loss=torch.tensor(1.0),\n    )\n    assert value is not None\n    assert abs(value - _entropy(logits[:, :4])) < 1e-8\n\n\ndef test_logit_entropy_surprise_returns_none_when_no_supervised_steps() -> None:\n    logits = torch.randn(1, 1, 7)\n    tokens = torch.randint(0, 7, (1, 1))\n    value = _compute_surprise_override(\n        \"logit_entropy\",\n        logits=logits,\n        tokens=tokens,\n        loss=torch.tensor(1.0),\n    )\n    assert value is None\n"
  },
  {
    "path": "tests/test_teach_signal.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.titan.model import TitanOnlyModel, TitanOnlyModelConfig\nfrom nested_learning.training import _compute_layer_teach_signals, compute_teach_signal\n\n\ndef _tiny_config() -> ModelConfig:\n    titan = LevelSpec(name=\"titan\", update_period=2, optimizer_key=\"titan_opt\")\n    cms = [\n        LevelSpec(name=\"cms_fast\", update_period=1, optimizer_key=\"cms_opt\"),\n        LevelSpec(name=\"cms_mid\", update_period=4, optimizer_key=\"cms_opt\"),\n    ]\n    return ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=2,\n        heads=4,\n        titan_level=titan,\n        cms_levels=cms,\n        optimizers=None,\n        teach_scale=0.1,\n    )\n\n\ndef _tiny_titan_config() -> TitanOnlyModelConfig:\n    titan = LevelSpec(name=\"titan\", update_period=2, optimizer_key=\"titan_opt\")\n    return TitanOnlyModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=2,\n        heads=4,\n        titan_level=titan,\n        optimizers=None,\n        teach_scale=0.1,\n    )\n\n\ndef test_teach_signal_matches_gradient() -> None:\n    torch.manual_seed(0)\n    cfg = _tiny_config()\n    model = HOPEModel(cfg)\n    tokens = torch.randint(0, cfg.vocab_size, (2, 6))\n\n    hidden_cache: dict[str, torch.Tensor] = {}\n\n    def hook(_, __, output: torch.Tensor) -> None:\n        output.retain_grad()\n        hidden_cache[\"hidden\"] = output\n\n    handle = model.norm.register_forward_hook(hook)\n    logits = model(tokens)\n    teach_signal = compute_teach_signal(model, logits, tokens)\n    loss = F.cross_entropy(\n        logits[:, :-1].reshape(-1, cfg.vocab_size),\n        tokens[:, 1:].reshape(-1),\n    )\n    loss.backward()\n    handle.remove()\n\n    hidden = hidden_cache[\"hidden\"]\n    assert hidden.grad is not None\n    grad = hidden.grad\n    assert torch.allclose(teach_signal, grad, atol=1e-5, rtol=1e-4)\n\n\ndef test_teach_signal_matches_gradient_titan() -> None:\n    torch.manual_seed(0)\n    cfg = _tiny_titan_config()\n    model = TitanOnlyModel(cfg)\n    tokens = torch.randint(0, cfg.vocab_size, (2, 6))\n\n    hidden_cache: dict[str, torch.Tensor] = {}\n\n    def hook(_, __, output: torch.Tensor) -> None:\n        output.retain_grad()\n        hidden_cache[\"hidden\"] = output\n\n    handle = model.norm.register_forward_hook(hook)\n    logits = model(tokens)\n    teach_signal = compute_teach_signal(model, logits, tokens)\n    loss = F.cross_entropy(\n        logits[:, :-1].reshape(-1, cfg.vocab_size),\n        tokens[:, 1:].reshape(-1),\n    )\n    loss.backward()\n    handle.remove()\n\n    hidden = hidden_cache[\"hidden\"]\n    assert hidden.grad is not None\n    grad = hidden.grad\n    assert torch.allclose(teach_signal, grad, atol=1e-5, rtol=1e-4)\n\n\ndef test_per_layer_teach_signal_shapes() -> None:\n    torch.manual_seed(0)\n    cfg = _tiny_config()\n    model = HOPEModel(cfg)\n    tokens = torch.randint(0, cfg.vocab_size, (2, 6))\n    logits, _pre, block_outputs = model.forward_with_block_outputs(tokens)\n    loss = F.cross_entropy(\n        logits[:, :-1].reshape(-1, cfg.vocab_size),\n        tokens[:, 1:].reshape(-1),\n    )\n    teach_signals = _compute_layer_teach_signals(loss, block_outputs)\n    assert len(teach_signals) == cfg.num_layers\n    for signal, output in zip(teach_signals, block_outputs):\n        assert signal.shape == output.shape\n\n\ndef 
test_per_layer_teach_signal_matches_autograd_grads() -> None:\n    torch.manual_seed(0)\n    cfg = _tiny_config()\n    model = HOPEModel(cfg)\n    tokens = torch.randint(0, cfg.vocab_size, (2, 6))\n    logits, _pre, block_outputs = model.forward_with_block_outputs(tokens)\n    loss = F.cross_entropy(\n        logits[:, :-1].reshape(-1, cfg.vocab_size),\n        tokens[:, 1:].reshape(-1),\n    )\n    expected = torch.autograd.grad(loss, block_outputs, retain_graph=True, allow_unused=False)\n    teach_signals = _compute_layer_teach_signals(loss, block_outputs)\n    for exp, actual in zip(expected, teach_signals):\n        assert torch.allclose(actual, exp.detach(), atol=1e-6, rtol=1e-6)\n\n\ndef test_teach_signal_matches_gradient_with_ignore_index() -> None:\n    torch.manual_seed(0)\n    cfg = _tiny_config()\n    model = HOPEModel(cfg)\n    tokens = torch.randint(1, cfg.vocab_size, (2, 6))\n    tokens[0, 2] = 0\n    tokens[1, 4] = 0\n\n    hidden_cache: dict[str, torch.Tensor] = {}\n\n    def hook(_, __, output: torch.Tensor) -> None:\n        output.retain_grad()\n        hidden_cache[\"hidden\"] = output\n\n    handle = model.norm.register_forward_hook(hook)\n    logits = model(tokens)\n    teach_signal = compute_teach_signal(model, logits, tokens, ignore_index=0)\n    loss = F.cross_entropy(\n        logits[:, :-1].reshape(-1, cfg.vocab_size),\n        tokens[:, 1:].reshape(-1),\n        ignore_index=0,\n    )\n    loss.backward()\n    handle.remove()\n\n    hidden = hidden_cache[\"hidden\"]\n    assert hidden.grad is not None\n    grad = hidden.grad\n    assert torch.allclose(teach_signal, grad, atol=1e-5, rtol=1e-4)\n\n    ignored = tokens[:, 1:] == 0\n    masked = teach_signal[:, :-1][ignored]\n    assert torch.allclose(masked, torch.zeros_like(masked))\n\n\ndef test_teach_signal_matches_gradient_with_boundary_target() -> None:\n    torch.manual_seed(0)\n    cfg = _tiny_config()\n    model = HOPEModel(cfg)\n    tokens = torch.randint(0, cfg.vocab_size, (2, 5))\n    next_tokens = torch.randint(0, cfg.vocab_size, (2,))\n\n    hidden_cache: dict[str, torch.Tensor] = {}\n\n    def hook(_, __, output: torch.Tensor) -> None:\n        output.retain_grad()\n        hidden_cache[\"hidden\"] = output\n\n    handle = model.norm.register_forward_hook(hook)\n    logits = model(tokens)\n    teach_signal = compute_teach_signal(model, logits, tokens, next_tokens=next_tokens)\n    targets = torch.cat([tokens[:, 1:], next_tokens.unsqueeze(1)], dim=1)\n    loss = F.cross_entropy(\n        logits.reshape(-1, cfg.vocab_size),\n        targets.reshape(-1),\n    )\n    loss.backward()\n    handle.remove()\n\n    hidden = hidden_cache[\"hidden\"]\n    assert hidden.grad is not None\n    grad = hidden.grad\n    assert torch.allclose(teach_signal, grad, atol=1e-5, rtol=1e-4)\n\n\ndef test_teach_signal_matches_gradient_with_boundary_target_and_ignore_index() -> None:\n    torch.manual_seed(0)\n    cfg = _tiny_config()\n    model = HOPEModel(cfg)\n    tokens = torch.randint(1, cfg.vocab_size, (2, 5))\n    next_tokens = torch.randint(1, cfg.vocab_size, (2,))\n    # Mark one in-sequence target and one boundary target as ignored.\n    tokens[0, 2] = 0\n    next_tokens[1] = 0\n\n    hidden_cache: dict[str, torch.Tensor] = {}\n\n    def hook(_, __, output: torch.Tensor) -> None:\n        output.retain_grad()\n        hidden_cache[\"hidden\"] = output\n\n    handle = model.norm.register_forward_hook(hook)\n    logits = model(tokens)\n    teach_signal = compute_teach_signal(\n        model,\n        logits,\n        
tokens,\n        next_tokens=next_tokens,\n        ignore_index=0,\n    )\n    targets = torch.cat([tokens[:, 1:], next_tokens.unsqueeze(1)], dim=1)\n    loss = F.cross_entropy(\n        logits.reshape(-1, cfg.vocab_size),\n        targets.reshape(-1),\n        ignore_index=0,\n    )\n    loss.backward()\n    handle.remove()\n\n    hidden = hidden_cache[\"hidden\"]\n    assert hidden.grad is not None\n    grad = hidden.grad\n    assert torch.allclose(teach_signal, grad, atol=1e-5, rtol=1e-4)\n"
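  },
  {
    "path": "examples/teach_signal_sketch.py",
    "content": "\"\"\"Illustrative sketch (not part of the original source tree).\n\nDemonstrates the property the teach-signal tests assert: for next-token\ncross-entropy, ``compute_teach_signal`` reproduces d(loss)/d(hidden) at the\nfinal norm output without requiring ``loss.backward()``. Only APIs already\nexercised by ``tests/test_teach_signal.py`` are used; this file path is\nhypothetical.\n\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.training import compute_teach_signal\n\n\ndef main() -> None:\n    torch.manual_seed(0)\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=2,\n        heads=4,\n        titan_level=LevelSpec(name=\"titan\", update_period=2, optimizer_key=\"titan_opt\"),\n        cms_levels=[LevelSpec(name=\"cms_fast\", update_period=1, optimizer_key=\"cms_opt\")],\n        optimizers=None,\n        teach_scale=0.1,\n    )\n    model = HOPEModel(cfg)\n    tokens = torch.randint(0, cfg.vocab_size, (2, 6))\n\n    cache: dict[str, torch.Tensor] = {}\n\n    def hook(_, __, output: torch.Tensor) -> None:\n        # Keep the gradient on the norm output so we can compare against it.\n        output.retain_grad()\n        cache[\"hidden\"] = output\n\n    handle = model.norm.register_forward_hook(hook)\n    logits = model(tokens)\n    teach_signal = compute_teach_signal(model, logits, tokens)\n    loss = F.cross_entropy(\n        logits[:, :-1].reshape(-1, cfg.vocab_size),\n        tokens[:, 1:].reshape(-1),\n    )\n    loss.backward()\n    handle.remove()\n\n    grad = cache[\"hidden\"].grad\n    assert grad is not None\n    print(\"max |teach - grad| =\", (teach_signal - grad).abs().max().item())\n\n\nif __name__ == \"__main__\":\n    main()\n"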
  },
  {
    "path": "tests/test_tied_weight_guard.py",
    "content": "import pytest\nimport torch\nfrom omegaconf import OmegaConf\n\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.training import _validate_tied_lm_head_for_paper_auditing\n\n\ndef _tiny_hope_model() -> HOPEModel:\n    cfg = ModelConfig(\n        vocab_size=32,\n        dim=8,\n        num_layers=1,\n        heads=2,\n        titan_level=LevelSpec(name=\"titan\", update_period=1),\n        cms_levels=(LevelSpec(name=\"cms_fast\", update_period=1),),\n        block_variant=\"hope_attention\",\n    )\n    return HOPEModel(cfg)\n\n\ndef test_paper_auditing_guard_accepts_tied_weights() -> None:\n    model = _tiny_hope_model()\n    cfg = OmegaConf.create({\"train\": {\"strict_streaming_contract\": True}})\n    _validate_tied_lm_head_for_paper_auditing(cfg, model)\n\n\ndef test_paper_auditing_guard_rejects_untied_weights() -> None:\n    model = _tiny_hope_model()\n    model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight.detach().clone())\n    cfg = OmegaConf.create({\"train\": {\"strict_streaming_contract\": True}})\n    with pytest.raises(RuntimeError, match=\"requires tied LM head\"):\n        _validate_tied_lm_head_for_paper_auditing(cfg, model)\n"
  },
  {
    "path": "tests/test_variants.py",
    "content": "import torch\n\nfrom nested_learning.hope.block import HOPEAttentionBlock, HOPEBlock, HOPESelfModBlock\nfrom nested_learning.levels import LevelSpec\nfrom nested_learning.model import HOPEModel, ModelConfig\nfrom nested_learning.titan.memory import TitanMemory\nfrom nested_learning.transformer import TransformerBlock\n\n\ndef _base_cfg(*, block_variant: str) -> ModelConfig:\n    titan = LevelSpec(name=\"titan\", update_period=1, optimizer_key=\"titan_opt\")\n    cms = [LevelSpec(name=\"cms_fast\", update_period=1, optimizer_key=\"cms_opt\")]\n    return ModelConfig(\n        vocab_size=32,\n        dim=16,\n        num_layers=1,\n        heads=4,\n        titan_level=titan,\n        cms_levels=cms,\n        block_variant=block_variant,\n        optimizers=None,\n    )\n\n\ndef test_hope_hybrid_variant_contains_titan_memory() -> None:\n    model = HOPEModel(_base_cfg(block_variant=\"hope_hybrid\"))\n    block = model.blocks[0]\n    assert isinstance(block, HOPEBlock)\n    assert isinstance(block.titan_memory, TitanMemory)\n\n\ndef test_hope_attention_variant_excludes_titan_memory() -> None:\n    model = HOPEModel(_base_cfg(block_variant=\"hope_attention\"))\n    block = model.blocks[0]\n    assert isinstance(block, HOPEAttentionBlock)\n    assert not hasattr(block, \"titan_memory\")\n\n    tokens = torch.randint(0, model.config.vocab_size, (2, 5))\n    logits = model(tokens)\n    assert logits.shape == (2, 5, model.config.vocab_size)\n\n\ndef test_hope_selfmod_variant_excludes_titan_memory() -> None:\n    model = HOPEModel(_base_cfg(block_variant=\"hope_selfmod\"))\n    block = model.blocks[0]\n    assert isinstance(block, HOPESelfModBlock)\n    assert not hasattr(block, \"titan_memory\")\n    assert hasattr(block, \"selfmod\")\n\n    fast_state = model.init_fast_state()\n    tokens = torch.randint(0, model.config.vocab_size, (1, 5))\n    logits = model(tokens, fast_state=fast_state)\n    assert logits.shape == (1, 5, model.config.vocab_size)\n\n\ndef test_transformer_variant_runs_with_and_without_fast_state() -> None:\n    model = HOPEModel(_base_cfg(block_variant=\"transformer\"))\n    block = model.blocks[0]\n    assert isinstance(block, TransformerBlock)\n\n    tokens = torch.randint(0, model.config.vocab_size, (2, 5))\n    logits = model(tokens)\n    assert logits.shape == (2, 5, model.config.vocab_size)\n\n    fast_state = model.init_fast_state()\n    logits_fast = model(tokens, fast_state=fast_state)\n    assert logits_fast.shape == (2, 5, model.config.vocab_size)\n"
  },
  {
    "path": "tests/test_verify_docs_refs.py",
    "content": "import importlib.util\nfrom pathlib import Path\n\n\ndef _load_verify_docs_refs():\n    script_path = Path(__file__).resolve().parents[1] / \"scripts\" / \"checks\" / \"verify_docs_refs.py\"\n    spec = importlib.util.spec_from_file_location(\"verify_docs_refs\", script_path)\n    if spec is None or spec.loader is None:\n        raise RuntimeError(\"failed to load verify_docs_refs script\")\n    module = importlib.util.module_from_spec(spec)\n    spec.loader.exec_module(module)  # type: ignore[union-attr]\n    return module\n\n\ndef test_parse_referenced_paths_extracts_repo_paths() -> None:\n    module = _load_verify_docs_refs()\n    text = \"\"\"\nUse `scripts/checks/run_fidelity_ci_subset.sh` and `uv run pytest`.\nAlso see [status](docs/IMPLEMENTATION_STATUS.md) and `README.md`.\nCode pointer `src/nested_learning/training.py:225`.\nIgnore `https://example.com` and `--flag`.\n\"\"\"\n    refs = module.parse_referenced_paths(text)\n    assert \"scripts/checks/run_fidelity_ci_subset.sh\" in refs\n    assert \"docs/IMPLEMENTATION_STATUS.md\" in refs\n    assert \"README.md\" in refs\n    assert \"src/nested_learning/training.py\" in refs\n    assert \"https://example.com\" not in refs\n\n\ndef test_verify_docs_refs_reports_missing_paths(tmp_path: Path) -> None:\n    module = _load_verify_docs_refs()\n    (tmp_path / \"scripts\" / \"checks\").mkdir(parents=True)\n    existing = tmp_path / \"scripts\" / \"checks\" / \"ok.sh\"\n    existing.write_text(\"#!/usr/bin/env bash\\n\", encoding=\"utf-8\")\n    doc = tmp_path / \"doc.md\"\n    doc.write_text(\n        \"`scripts/checks/ok.sh` `scripts/checks/missing.sh`\",\n        encoding=\"utf-8\",\n    )\n    missing, missing_anchors = module.verify_docs_refs(repo_root=tmp_path, docs=[doc])\n    assert str(doc) in missing\n    assert missing[str(doc)] == [\"scripts/checks/missing.sh\"]\n    assert missing_anchors == {}\n\n\ndef test_verify_docs_refs_validates_markdown_anchors(tmp_path: Path) -> None:\n    module = _load_verify_docs_refs()\n    target = tmp_path / \"docs\" / \"guide.md\"\n    target.parent.mkdir(parents=True)\n    target.write_text(\n        \"# Overview\\n\\n## Streaming Contract\\n\",\n        encoding=\"utf-8\",\n    )\n    doc = tmp_path / \"doc.md\"\n    doc.write_text(\n        \"[ok](docs/guide.md#overview) [bad](docs/guide.md#missing-anchor)\",\n        encoding=\"utf-8\",\n    )\n    missing, missing_anchors = module.verify_docs_refs(repo_root=tmp_path, docs=[doc])\n    assert missing == {}\n    assert str(doc) in missing_anchors\n    assert missing_anchors[str(doc)] == [\"docs/guide.md#missing-anchor\"]\n"
  },
  {
    "path": "tests/test_verify_update_cadence.py",
    "content": "import importlib.util\nimport json\nfrom pathlib import Path\n\n\ndef _load_verify_cadence():\n    script_path = (\n        Path(__file__).resolve().parents[1] / \"scripts\" / \"checks\" / \"verify_update_cadence.py\"\n    )\n    spec = importlib.util.spec_from_file_location(\"verify_update_cadence\", script_path)\n    if spec is None or spec.loader is None:\n        raise RuntimeError(\"failed to load verify_update_cadence script\")\n    module = importlib.util.module_from_spec(spec)\n    spec.loader.exec_module(module)  # type: ignore[union-attr]\n    return module.verify_cadence\n\n\ndef _write_log(path: Path, payload: dict[str, float]) -> None:\n    path.write_text(json.dumps([{\"step\": 0, **payload}], indent=2))\n\n\ndef test_verify_update_cadence_no_flush(tmp_path: Path) -> None:\n    verify_cadence = _load_verify_cadence()\n    log_path = tmp_path / \"metrics.json\"\n    _write_log(\n        log_path,\n        {\n            \"layer0.cms.cms_fast.updates_applied\": 2.0,\n            \"layer0.cms.cms_fast.chunk_tokens\": 8.0,\n            \"layer0.cms.cms_fast.tokens_flushed\": 0.0,\n            \"layer0.cms.cms_fast.pending_tokens\": 2.0,\n        },\n    )\n    report = verify_cadence(\n        log_path=log_path,\n        metric_prefix=\"layer0.cms.cms_fast\",\n        total_tokens=10,\n        update_period=4,\n        flush_partial=False,\n    )\n    assert report[\"ok\"] is True\n\n\ndef test_verify_update_cadence_with_flush(tmp_path: Path) -> None:\n    verify_cadence = _load_verify_cadence()\n    log_path = tmp_path / \"metrics.json\"\n    _write_log(\n        log_path,\n        {\n            \"layer0.cms.cms_fast.updates_applied\": 3.0,\n            \"layer0.cms.cms_fast.chunk_tokens\": 10.0,\n            \"layer0.cms.cms_fast.tokens_flushed\": 2.0,\n            \"layer0.cms.cms_fast.pending_tokens\": 0.0,\n        },\n    )\n    report = verify_cadence(\n        log_path=log_path,\n        metric_prefix=\"layer0.cms.cms_fast\",\n        total_tokens=10,\n        update_period=4,\n        flush_partial=True,\n    )\n    assert report[\"ok\"] is True\n\n\ndef test_verify_update_cadence_detects_mismatch(tmp_path: Path) -> None:\n    verify_cadence = _load_verify_cadence()\n    log_path = tmp_path / \"metrics.json\"\n    _write_log(\n        log_path,\n        {\n            \"layer0.cms.cms_fast.updates_applied\": 1.0,\n            \"layer0.cms.cms_fast.chunk_tokens\": 4.0,\n            \"layer0.cms.cms_fast.tokens_flushed\": 0.0,\n            \"layer0.cms.cms_fast.pending_tokens\": 0.0,\n        },\n    )\n    report = verify_cadence(\n        log_path=log_path,\n        metric_prefix=\"layer0.cms.cms_fast\",\n        total_tokens=10,\n        update_period=4,\n        flush_partial=False,\n    )\n    assert report[\"ok\"] is False\n\n\ndef test_verify_update_cadence_report_schema_is_non_empty(tmp_path: Path) -> None:\n    verify_cadence = _load_verify_cadence()\n    log_path = tmp_path / \"metrics.json\"\n    _write_log(\n        log_path,\n        {\n            \"layer0.cms.cms_fast.updates_applied\": 2.0,\n            \"layer0.cms.cms_fast.chunk_tokens\": 8.0,\n            \"layer0.cms.cms_fast.tokens_flushed\": 0.0,\n            \"layer0.cms.cms_fast.pending_tokens\": 2.0,\n        },\n    )\n    report = verify_cadence(\n        log_path=log_path,\n        metric_prefix=\"layer0.cms.cms_fast\",\n        total_tokens=10,\n        update_period=4,\n        flush_partial=False,\n    )\n    for key in (\n        \"ok\",\n        \"metric_prefix\",\n        
\"log_path\",\n        \"expected\",\n        \"observed\",\n        \"checks\",\n    ):\n        assert key in report\n    assert report[\"metric_prefix\"] == \"layer0.cms.cms_fast\"\n    assert isinstance(report[\"checks\"], dict)\n    assert report[\"checks\"]\n"
  },
  {
    "path": "train.py",
    "content": "from __future__ import annotations\n\nimport hydra\nfrom omegaconf import DictConfig\n\nfrom nested_learning.device import resolve_device\nfrom nested_learning.training import run_training_loop, unwrap_config\n\n\n@hydra.main(config_path=\"configs\", config_name=\"pilot\", version_base=None)\ndef main(cfg: DictConfig) -> None:\n    cfg = unwrap_config(cfg)\n    device = resolve_device(cfg.train.device)\n    run_training_loop(cfg, device=device, distributed=False)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "train_deepspeed.py",
    "content": "from __future__ import annotations\n\nimport json\nimport os\nfrom pathlib import Path\n\nimport hydra\nimport torch\nimport torch.distributed as dist\nfrom omegaconf import DictConfig\n\nfrom nested_learning.logging_utils import NullLogger, init_logger\nfrom nested_learning.training import (\n    DistributedContext,\n    _seed_everything,\n    build_dataloader,\n    build_model_from_cfg,\n    compute_teach_signal,\n    unwrap_config,\n)\n\ntry:\n    import deepspeed\nexcept ImportError as exc:  # pragma: no cover - optional dependency\n    raise RuntimeError(\n        \"DeepSpeed is not installed. Install it in this environment to use train_deepspeed.py.\"\n    ) from exc\n\n\ndef setup_distributed() -> DistributedContext:\n    deepspeed.init_distributed()\n    local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n    torch.cuda.set_device(local_rank)\n    rank = dist.get_rank()\n    world_size = dist.get_world_size()\n    device = torch.device(f\"cuda:{local_rank}\")\n    return DistributedContext(rank=rank, world_size=world_size, device=device)\n\n\ndef load_ds_config(path: str | Path) -> dict:\n    with open(path, \"r\", encoding=\"utf-8\") as handle:\n        return json.load(handle)\n\n\n@hydra.main(config_path=\"configs\", config_name=\"hope/target\", version_base=None)\ndef main(cfg: DictConfig) -> None:\n    cfg = unwrap_config(cfg)\n    dist_ctx = setup_distributed()\n    train_seed = cfg.train.get(\"seed\")\n    deterministic = cfg.train.get(\"deterministic\", False)\n    if train_seed is not None:\n        _seed_everything(int(train_seed), deterministic=bool(deterministic))\n    model = build_model_from_cfg(cfg.model)\n    ds_config = load_ds_config(cfg.deepspeed.config)\n    engine, optimizer, _, _ = deepspeed.initialize(\n        model=model,\n        model_parameters=[p for p in model.parameters() if p.requires_grad],\n        config=ds_config,\n    )\n\n    train_seed = cfg.train.get(\"seed\")\n    loader_seed = None if train_seed is None else int(train_seed) + dist_ctx.rank\n    dataloader, sampler = build_dataloader(\n        cfg.data,\n        distributed=True,\n        dist_ctx=dist_ctx,\n        seed=loader_seed,\n    )\n    logger = (\n        init_logger(getattr(cfg, \"logging\", None), cfg) if engine.global_rank == 0 else NullLogger()\n    )\n    steps = cfg.train.steps\n    log_interval = cfg.train.get(\"log_interval\", 10)\n    checkpoint_cfg = cfg.train.get(\"checkpoint\", {})\n    ckpt_dir = Path(checkpoint_cfg.get(\"dir\", \"checkpoints/deepspeed\"))\n\n    if checkpoint_cfg.get(\"resume_tag\"):\n        tag = checkpoint_cfg[\"resume_tag\"]\n        engine.load_checkpoint(str(ckpt_dir), tag=tag)\n        if engine.global_rank == 0:\n            print(f\"[DeepSpeed] Resumed from {ckpt_dir} tag={tag}\")\n\n    step_iter = iter(dataloader)\n    epoch = 0\n    for step in range(steps):\n        if sampler is not None and step % len(dataloader) == 0:\n            sampler.set_epoch(epoch)\n            epoch += 1\n        try:\n            batch = next(step_iter)\n        except StopIteration:\n            step_iter = iter(dataloader)\n            batch = next(step_iter)\n        tokens = batch.to(dist_ctx.device)\n        logits = engine(tokens)\n        loss = torch.nn.functional.cross_entropy(\n            logits[:, :-1].reshape(-1, logits.size(-1)), tokens[:, 1:].reshape(-1)\n        )\n        engine.backward(loss)\n        engine.step()\n        with torch.no_grad():\n            teach_signal = compute_teach_signal(engine.module, logits, tokens)\n 
           engine.module(tokens, teach_signal=teach_signal)\n        if step % log_interval == 0 and engine.global_rank == 0:\n            ppl = torch.exp(loss.detach()).item()\n            logger.log({\"loss\": loss.item(), \"ppl\": ppl}, step=step)\n            print(f\"[DeepSpeed] step={step} loss={loss.item():.4f} ppl={ppl:.2f}\")\n        if (\n            checkpoint_cfg.get(\"enable\", False)\n            and step % checkpoint_cfg.get(\"save_interval\", 100) == 0\n            and engine.global_rank == 0\n        ):\n            ckpt_dir.mkdir(parents=True, exist_ok=True)\n            engine.save_checkpoint(str(ckpt_dir), tag=f\"step_{step:06d}\")\n\n    logger.finish()\n    dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "train_dist.py",
    "content": "from __future__ import annotations\n\nimport os\n\nimport hydra\nimport torch\nimport torch.distributed as dist\nfrom omegaconf import DictConfig\n\nfrom nested_learning.training import DistributedContext, run_training_loop, unwrap_config\n\n\ndef setup_distributed(backend: str | None = None) -> DistributedContext:\n    if backend is None:\n        backend = \"nccl\" if torch.cuda.is_available() else \"gloo\"\n    dist.init_process_group(backend=backend)\n    local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n    world_size = dist.get_world_size()\n    if backend == \"nccl\":\n        torch.cuda.set_device(local_rank)\n        device = torch.device(f\"cuda:{local_rank}\")\n    else:\n        device = torch.device(\"cpu\")\n    return DistributedContext(rank=dist.get_rank(), world_size=world_size, device=device)\n\n\n@hydra.main(config_path=\"configs\", config_name=\"hope/mid\", version_base=None)\ndef main(cfg: DictConfig) -> None:\n    cfg = unwrap_config(cfg)\n    dist_ctx = setup_distributed()\n    run_training_loop(cfg, device=dist_ctx.device, distributed=True, dist_ctx=dist_ctx)\n    dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "train_fsdp.py",
    "content": "from __future__ import annotations\n\nimport os\nfrom functools import partial\nfrom pathlib import Path\n\nimport hydra\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nfrom omegaconf import DictConfig\nfrom torch.distributed.fsdp import (\n    CPUOffload,\n    FullStateDictConfig,\n    StateDictType,\n)\nfrom torch.distributed.fsdp import (\n    FullyShardedDataParallel as FSDP,\n)\nfrom torch.distributed.fsdp.wrap import size_based_auto_wrap_policy\n\nfrom nested_learning.logging_utils import NullLogger, init_logger\nfrom nested_learning.training import (\n    DistributedContext,\n    _build_optimizer,\n    _make_autocast_factory,\n    _maybe_compile_model,\n    _seed_everything,\n    build_dataloader,\n    build_model_from_cfg,\n    compute_teach_signal,\n    unwrap_config,\n    verify_checkpoint_integrity,\n    write_checkpoint_metadata,\n)\n\n\ndef setup_distributed() -> DistributedContext:\n    dist.init_process_group(backend=\"nccl\")\n    local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n    torch.cuda.set_device(local_rank)\n    rank = dist.get_rank()\n    world_size = dist.get_world_size()\n    device = torch.device(f\"cuda:{local_rank}\")\n    return DistributedContext(rank=rank, world_size=world_size, device=device)\n\n\ndef build_fsdp_model(cfg: DictConfig, device: torch.device) -> tuple[FSDP, torch.nn.Module]:\n    base_model = build_model_from_cfg(cfg.model).to(device)\n    base_model = _maybe_compile_model(base_model, cfg.train.get(\"compile\"))\n    fsdp_cfg = cfg.train.get(\"fsdp\", {})\n    min_params = fsdp_cfg.get(\"auto_wrap_min_params\", 2_000_000)\n    # Avoid wrapping tied-weight modules (embed/lm_head) into separate FSDP instances.\n    # Parameter sharing across FSDP wrappers can produce shape mismatches at runtime.\n    exclude = {nn.Embedding, nn.Linear}\n    auto_wrap_policy = partial(\n        size_based_auto_wrap_policy,\n        min_num_params=min_params,\n        exclude_wrap_modules=exclude,\n    )\n    cpu_offload = CPUOffload(offload_params=fsdp_cfg.get(\"cpu_offload\", False))\n    model = FSDP(\n        base_model,\n        device_id=device.index,\n        auto_wrap_policy=auto_wrap_policy,\n        cpu_offload=cpu_offload,\n        use_orig_params=True,  # Required for custom inner optimizers / in-place updates\n    )\n    return model, base_model\n\n\ndef unwrap_model(module: torch.nn.Module) -> torch.nn.Module:\n    if hasattr(module, \"_fsdp_wrapped_module\"):\n        return module._fsdp_wrapped_module  # type: ignore[attr-defined]\n    if hasattr(module, \"module\"):\n        return module.module  # type: ignore[attr-defined]\n    return module\n\n\ndef save_checkpoint(\n    cfg: DictConfig,\n    model: FSDP,\n    optimizer: torch.optim.Optimizer,\n    step: int,\n    rank: int,\n    step_offset: int = 0,\n) -> None:\n    ckpt_cfg = cfg.train.get(\"checkpoint\")\n    if not ckpt_cfg or not ckpt_cfg.get(\"enable\", False):\n        return\n    save_interval = ckpt_cfg.get(\"save_interval\", 1000)\n    save_last = ckpt_cfg.get(\"save_last\", True)\n    total_steps = cfg.train.get(\"steps\", step + 1)\n    next_step = step + 1\n    should_save = (next_step % save_interval == 0) or (save_last and next_step >= total_steps)\n    if not should_save:\n        return\n    ckpt_dir = Path(ckpt_cfg.get(\"dir\", \"checkpoints/fsdp\"))\n    global_step = next_step + int(step_offset)\n    ckpt_path = ckpt_dir / f\"step_{global_step:06d}.pt\"\n    if rank == 0:\n        ckpt_dir.mkdir(parents=True, exist_ok=True)\n    
dist.barrier()\n    full_state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)\n    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_cfg):\n        model_state = model.state_dict()\n    if rank != 0:\n        return\n    state = {\"model\": model_state, \"optimizer\": optimizer.state_dict(), \"step\": global_step}\n    torch.save(state, ckpt_path)\n    write_checkpoint_metadata(cfg, ckpt_path, global_step)\n\n\ndef maybe_resume(cfg: DictConfig, model: FSDP, optimizer: torch.optim.Optimizer, rank: int) -> int:\n    ckpt_cfg = cfg.train.get(\"checkpoint\")\n    if not ckpt_cfg:\n        return 0\n    resume_path = ckpt_cfg.get(\"resume_path\")\n    if not resume_path:\n        return 0\n    if not Path(resume_path).exists():\n        raise FileNotFoundError(f\"Resume checkpoint {resume_path} not found\")\n    map_location = \"cpu\"\n    verify_checkpoint_integrity(Path(resume_path))\n    state = torch.load(resume_path, map_location=map_location)\n    full_state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)\n    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_cfg):\n        model.load_state_dict(state[\"model\"])\n    optimizer.load_state_dict(state[\"optimizer\"])\n    if rank == 0:\n        print(f\"[FSDP] Resumed from {resume_path} at step {state.get('step', 0)}\")\n    return state.get(\"step\", 0)\n\n\n@hydra.main(config_path=\"configs\", config_name=\"hope/mid\", version_base=None)\ndef main(cfg: DictConfig) -> None:\n    cfg = unwrap_config(cfg)\n    dist_ctx = setup_distributed()\n    train_seed = cfg.train.get(\"seed\")\n    deterministic = cfg.train.get(\"deterministic\", False)\n    if train_seed is not None:\n        _seed_everything(int(train_seed), deterministic=bool(deterministic))\n    model, _ = build_fsdp_model(cfg, dist_ctx.device)\n    loader_seed = None if train_seed is None else int(train_seed) + dist_ctx.rank\n    dataloader, sampler = build_dataloader(\n        cfg.data,\n        distributed=True,\n        dist_ctx=dist_ctx,\n        seed=loader_seed,\n    )\n    optimizer = _build_optimizer(model, cfg, device=dist_ctx.device)\n    start_step = maybe_resume(cfg, model, optimizer, dist_ctx.rank)\n    logger = init_logger(getattr(cfg, \"logging\", None), cfg) if dist_ctx.rank == 0 else NullLogger()\n    autocast_factory = _make_autocast_factory(dist_ctx.device, cfg.train.get(\"mixed_precision\"))\n\n    steps = cfg.train.steps\n    log_interval = cfg.train.get(\"log_interval\", 10)\n    step_iter = iter(dataloader)\n    epoch = 0\n    for step in range(start_step, steps):\n        if sampler is not None and step % len(dataloader) == 0:\n            sampler.set_epoch(epoch)\n            epoch += 1\n        try:\n            batch = next(step_iter)\n        except StopIteration:\n            step_iter = iter(dataloader)\n            batch = next(step_iter)\n        tokens = batch.to(dist_ctx.device)\n        with autocast_factory():\n            logits = model(tokens)\n            loss = torch.nn.functional.cross_entropy(\n                logits[:, :-1].reshape(-1, logits.size(-1)), tokens[:, 1:].reshape(-1)\n            )\n        optimizer.zero_grad()\n        loss.backward()\n        # Use FSDP's own clipping: it computes the global grad norm across shards,\n        # whereas torch.nn.utils.clip_grad_norm_ would only see the local shards.\n        model.clip_grad_norm_(max_norm=1.0)\n        optimizer.step()\n        with torch.no_grad():\n            # FSDP shards parameters at rest; compute_teach_signal needs access to the full\n            # (tied) LM head weight matrix. Summon the root parameters (embed/lm_head) to\n            # materialize the full 2D weight before computing the teach signal.\n            with FSDP.summon_full_params(model, recurse=False):\n                inner_full = unwrap_model(model)\n                teach_signal = compute_teach_signal(inner_full, logits, tokens)\n            _ = model(tokens, teach_signal=teach_signal)\n        if step % log_interval == 0 and dist_ctx.rank == 0:\n            ppl = torch.exp(loss.detach()).item()\n            logger.log({\"loss\": loss.item(), \"ppl\": ppl}, step=step)\n            print(f\"[FSDP] step={step} loss={loss.item():.4f} ppl={ppl:.2f}\")\n        save_checkpoint(cfg, model, optimizer, step, dist_ctx.rank)\n\n    logger.finish()\n    dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    main()\n"
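  },
  {
    "path": "examples/checkpoint_cadence_sketch.py",
    "content": "\"\"\"Illustrative sketch (not part of the original source tree).\n\nTraces the save cadence implemented by ``save_checkpoint`` in\n``train_fsdp.py``: a checkpoint lands whenever the 1-based step count hits\n``save_interval``, plus a final ``save_last`` checkpoint at the end of the\nrun. ``saved_steps`` is a hypothetical helper for illustration.\n\"\"\"\n\n\ndef saved_steps(total_steps: int, save_interval: int, save_last: bool = True) -> list[int]:\n    saves: list[int] = []\n    for step in range(total_steps):\n        next_step = step + 1\n        if next_step % save_interval == 0 or (save_last and next_step >= total_steps):\n            saves.append(next_step)\n    return saves\n\n\nif __name__ == \"__main__\":\n    # steps=10, save_interval=4 -> checkpoints at global steps 4, 8, and 10.\n    assert saved_steps(10, 4) == [4, 8, 10]\n"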
  }
]