[
  {
    "path": ".devcontainer/Dockerfile",
    "content": "# See here for image contents: https://github.com/devcontainers/images/blob/main/src/python/.devcontainer/Dockerfile\n\n# [Choice] Python version (use -bookworm or -bullseye variants on local arm64/Apple Silicon): 3, 3.12, 3.11, 3.10, 3.9, 3.8, 3-bookworm, 3.12-bookworm, 3.11-bookworm, 3.10-bookworm, 3.9-bookworm, 3.8-bookworm, 3-bullseye, 3.12-bullseye, 3.11-bullseye, 3.10-bullseye, 3.9-bullseye, 3.8-bullseye, 3-buster, 3.12-buster, 3.11-buster, 3.10-buster, 3.9-buster, 3.8-buster\nARG VARIANT=3-bookworm\nFROM mcr.microsoft.com/devcontainers/python:1-${VARIANT}\n\n# Temporary: Upgrade python packages due to https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-40897\n# They are installed by the base image (python) which does not have the patch.\nRUN python3 -m pip install --upgrade pip setuptools\n"
  },
  {
    "path": ".devcontainer/devcontainer.json",
    "content": "// For format details, see https://aka.ms/devcontainer.json. For config options, see the README at:\n// https://github.com/microsoft/vscode-dev-containers/tree/v0.194.0/containers/python-3\n{\n  \"name\": \"Python 3 (litgpt)\",\n  \"build\": {\n    \"dockerfile\": \"Dockerfile\",\n    \"context\": \"..\",\n    \"args\": {\n      \"VARIANT\": \"3.11-bookworm\"\n    }\n  },\n  \"runArgs\": [\n    // Enable GPU passthrough, requires WSL2 on Windows\n    //\"--gpus=all\",\n    // One of the following options is required for torch multiprocessing\n    //\"--ipc=host\",\n    //\"--shm-size=4gb\",\n  ],\n  // Features to add to the dev container. More info: https://containers.dev/features.\n  \"features\": {\n    \"ghcr.io/devcontainers/features/git:1\": {},\n    \"ghcr.io/devcontainers/features/git-lfs:1\": {},\n    //\"ghcr.io/devcontainers/features/nvidia-cuda:1\": {},\n    \"ghcr.io/devcontainers-extra/features/actionlint:1\": {},\n    \"ghcr.io/devcontainers-extra/features/pre-commit:2\": {},\n    \"ghcr.io/dhoeric/features/act:1\": {},\n    \"ghcr.io/devcontainers/features/docker-in-docker:2\": {\n      \"version\": \"latest\",\n      \"moby\": true\n    }\n  },\n  // Set *default* container specific settings.json values on container create.\n  \"customizations\": {\n    \"vscode\": {\n      \"settings\": {\n        \"editor.tabSize\": 4,\n        \"editor.renderWhitespace\": \"all\",\n        \"editor.formatOnSave\": true,\n        \"editor.rulers\": [120],\n        \"files.exclude\": {\n          \"**/__pycache__\": true\n        },\n        \"python.pythonPath\": \"/usr/local/bin/python\",\n        \"python.defaultInterpreterPath\": \"/usr/local/bin/python\",\n        \"python.languageServer\": \"Pylance\",\n        \"python.analysis.autoImportCompletions\": true,\n        \"python.analysis.completeFunctionParens\": true,\n        \"python.analysis.autoSearchPaths\": true,\n        \"python.testing.pytestArgs\": [\"tests\"],\n        
\"python.testing.unittestEnabled\": false,\n        \"python.testing.pytestEnabled\": true,\n        \"code-eol.highlightNonDefault\": true,\n        \"code-eol.highlightExtraWhitespace\": true,\n        \"autoDocstring.docstringFormat\": \"google-notypes\",\n        \"autoDocstring.guessTypes\": true,\n        \"autoDocstring.generateDocstringOnEnter\": true,\n        \"autoDocstring.startOnNewLine\": true,\n        \"telemetry.telemetryLevel\": \"off\",\n        \"[python]\": {\n          \"editor.formatOnSave\": true,\n          \"editor.defaultFormatter\": \"charliermarsh.ruff\",\n          \"editor.codeActionsOnSave\": {\n            \"source.organizeImports\": \"always\",\n            \"source.fixAll\": \"always\"\n          }\n        }\n      },\n      // Add the IDs of extensions you want installed when the container is created.\n      \"extensions\": [\n        \"ms-python.python\",\n        \"ms-python.vscode-pylance\",\n        \"ms-toolsai.jupyter\",\n        \"GitHub.copilot\",\n        \"GitHub.copilot-chat\",\n        \"github.vscode-github-actions\",\n        \"SanjulaGanepola.github-local-actions\",\n        \"charliermarsh.ruff\",\n        \"esbenp.prettier-vscode\",\n        \"ms-vscode.test-adapter-converter\",\n        \"njqdev.vscode-python-typehint\",\n        \"KevinRose.vsc-python-indent\",\n        \"medo64.render-crlf\",\n        \"shardulm94.trailing-spaces\",\n        \"nhoizey.gremlins\",\n        \"wayou.vscode-todo-highlight\",\n        \"Gruntfuggly.todo-tree\",\n        \"njpwerner.autodocstring\",\n        \"rodolphebarbanneau.python-docstring-highlighter\",\n        \"mechatroner.rainbow-csv\",\n        \"uctakeoff.vscode-counter\",\n        \"bierner.github-markdown-preview\",\n        \"yahyabatulu.vscode-markdown-alert\",\n        \"ms-vscode-remote.vscode-remote-extensionpack\",\n        \"ms-azuretools.vscode-docker\",\n        \"redhat.vscode-yaml\"\n      ]\n    }\n  },\n  // Use 'forwardPorts' to make a list of ports 
inside the container available locally.\n  // \"forwardPorts\": [],\n  // Use 'postCreateCommand' to run commands after the container is created.\n  \"postCreateCommand\": \"pre-commit install && pip install '.[extra,compiler,test]' -U\",\n  // Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root.\n  \"remoteUser\": \"vscode\"\n}\n"
  },
  {
    "path": ".github/CODEOWNERS",
    "content": "* @lantiga @t-vi @lianakoleva @KaelanDt @k223kim @andyland\n/README.md                           @williamfalcon @lantiga @lianakoleva\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/ask-a-question.md",
    "content": "---\nname: Ask a Question\nabout: Ask and answer questions related to LitGPT\ntitle: ''\nlabels: question\n\n---\n\nPlease describe your question here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug-report.yaml",
    "content": "name: Bug Report\ndescription: Report errors related to LitGPT\ntitle: \"Description\"\nlabels: bug\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        Thank you for taking the time to report an issue. Please fill out the details below to help us resolve it.\n\n  - type: textarea\n    id: bug_description\n    attributes:\n      label: Bug description\n      description: A description of the issue.\n      placeholder: |\n        Please provide a description of what the bug or issue is.\n    validations:\n      required: true\n\n  - type: input\n    attributes:\n      label: Reproduced in studio\n      description: >\n        Create a new Lightning Studio with code that reproduces the issue and share the link.\n        Also include all the relevant files and data required to reproduce the issue.\n        In case the code does not crash, please add assert statements to show the actual and expected output.\n        A simple guide on how to create such a studio can be found [here](https://www.youtube.com/watch?v=YcW-2Zt_bFg&ab_channel=LightningAI).\n      placeholder: https://lightning.ai/...\n    validations:\n      required: false\n\n  - type: dropdown\n    id: operating_system\n    attributes:\n      label: What operating system are you using?\n      description: If applicable, please select the operating system where you experienced this issue.\n      options:\n        - \"Unknown\"\n        - \"macOS\"\n        - \"Linux\"\n        - \"Windows\"\n    validations:\n      required: true\n\n  - type: textarea\n    id: version\n    attributes:\n      label: LitGPT Version\n      description: |\n        Please provide details about your LitGPT version by running the following code in your terminal:\n        ```\n        pip show litgpt | grep Version:\n        ```\n    validations:\n      required: false\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature-request.md",
    "content": "---\nname: Suggest a Feature\nabout: Propose a new feature or enhancement\ntitle: ''\nlabels: enhancement\n\n---\n\nPlease describe the feature or enhancement along with the intended use case.\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "# Basic dependabot.yml file with\n# minimum configuration for two package managers\n\nversion: 2\nupdates:\n  # Enable version updates for python\n  - package-ecosystem: \"pip\"\n    # Look for a `requirements` in the `root` directory\n    directory: \"/\"\n    # Check for updates once a month\n    schedule:\n      interval: \"monthly\"\n    # Labels on pull requests for version updates only\n    labels:\n      - \"dependencies\"\n    pull-request-branch-name:\n      # Separate sections of the branch name with a hyphen\n      # for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1`\n      separator: \"-\"\n    # Allow up to 3 open pull requests for pip dependencies\n    open-pull-requests-limit: 3\n\n  # Enable version updates for GitHub Actions\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    # Check for updates once a week\n    schedule:\n      interval: \"weekly\"\n    # Labels on pull requests for version updates only\n    labels:\n      - \"CI / actions\"\n    pull-request-branch-name:\n      # Separate sections of the branch name with a hyphen\n      # for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1`\n      separator: \"-\"\n    # Allow up to 1 open pull request for GitHub Actions\n    open-pull-requests-limit: 1\n    groups:\n      GHA-updates:\n        patterns:\n          - \"*\"\n"
  },
  {
    "path": ".github/workflows/check-links.yml",
    "content": "name: Check hyperlinks\n\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\n\njobs:\n  test:\n    runs-on: ubuntu-latest\n\n    steps:\n      - uses: actions/checkout@v6\n\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.10\"\n\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install \"mistune<3.1\"  # a newer version is incompatible with nbconvert\n          pip install pytest pytest-check-links\n\n      - name: Check links\n        run: |\n          pytest --check-links README.md --check-links-ignore \"http*\"\n          pytest --check-links tutorials --check-links-ignore \"http*\"\n"
  },
  {
    "path": ".github/workflows/cpu-tests.yml",
    "content": "name: CPU tests\n\non:\n  push:\n    branches: [main]\n  pull_request_target:\n    branches: [main]\n    types: [opened, reopened, ready_for_review, labeled, synchronize]\n  pull_request: {} # todo\n  workflow_dispatch: {}\n\n# lock down all permissions by default\npermissions:\n  contents: read # needed to check out code\n  checks: write # needed for test results\n  pull-requests: read # needed for PR metadata\n  actions: read # needed to use actions\n  security-events: none\n  statuses: write # needed to update commit status\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}\n  cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}\n\ndefaults:\n  run:\n    shell: bash\n\nenv:\n  HF_HOME: .cache-HF # Define HF_HOME for caching\n  TRANSFORMERS_CACHE: .cache-HF/transformers\n  DATASETS_CACHE: .cache-HF/datasets\n  HF_DATASETS_CACHE: .cache-HF/datasets\n  TORCH_URL: \"https://download.pytorch.org/whl/cpu/\"\n\njobs:\n  testing-imports:\n    runs-on: ${{ matrix.os }}\n    if: github.event_name != 'pull_request_target'\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [\"ubuntu-22.04\", \"ubuntu-24.04\", \"macOS-14\", \"windows-2022\"]\n        python-version: [\"3.10\"]\n    timeout-minutes: 10\n    steps:\n      - name: Checkout generic\n        uses: actions/checkout@v6\n      - uses: actions/setup-python@v6\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Install minimal dependencies\n        run: |\n          pip install . 
-U --extra-index-url=\"${TORCH_URL}\"\n          pip list\n\n      - name: Testing package imports\n        # make sure all modules are still importable with only the minimal dependencies available\n        run: |\n          modules=$(\n            find litgpt -type f -name \"*.py\" | \\\n            sed 's/\\.py$//' | sed 's/\\//./g' | \\\n            sed 's/.__init__//g' | xargs -I {} echo \"import {};\"\n          )\n          echo \"$modules\"\n          python -c \"$modules\"\n\n  pytester:\n    # Route PRs based on contributor type to avoid duplicate runs:\n    # - Collaborators: use pull_request (tests workflow changes from PR)\n    # - External forks: use pull_request_target (uses trusted workflow from main)\n    # - Always run for push to main and workflow_dispatch\n    if: |\n      (github.event_name == 'pull_request' && contains('OWNER,MEMBER,COLLABORATOR', github.event.pull_request.author_association)) ||\n      (github.event_name == 'pull_request_target' && !contains('OWNER,MEMBER,COLLABORATOR', github.event.pull_request.author_association)) ||\n      (github.event_name != 'pull_request' && github.event_name != 'pull_request_target')\n    runs-on: ${{ matrix.os }}\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [\"ubuntu-22.04\"]\n        python-version: [\"3.10\", \"3.11\", \"3.12\", \"3.13\"]\n        requires: [\"latest\"]\n        include:\n          - { os: \"ubuntu-22.04\", python-version: \"3.10\", requires: \"oldest\" }\n          - { os: \"windows-2022\", python-version: \"3.10\", requires: \"latest\" }\n          - { os: \"macOS-14\", python-version: \"3.10\", requires: \"latest\" }\n    timeout-minutes: 35\n    steps:\n      - name: Checkout generic\n        uses: actions/checkout@v6\n        if: github.event_name != 'pull_request_target'\n      - name: Checkout for `pull_request_target`\n        uses: actions/checkout@v6\n        if: github.event_name == 'pull_request_target'\n        with:\n          ref: ${{ 
github.event.pull_request.head.sha }}\n      - uses: actions/setup-python@v6\n        with:\n          python-version: ${{ matrix.python-version }}\n          cache-dependency-path: pyproject.toml\n          cache: \"pip\"\n\n      # Add caching for HF models and tokenizers\n      - name: HF cache\n        uses: actions/cache@v5\n        continue-on-error: true\n        with:\n          path: .cache-HF\n          key: hf-cache_${{ runner.os }}-py${{ matrix.python-version }}\n          restore-keys: |\n            hf-cache_${{ runner.os }}-py${{ matrix.python-version }}\n            hf-cache_${{ runner.os }}-\n            hf-cache_\n\n      - name: Set min. dependencies\n        if: matrix.requires == 'oldest'\n        run: |\n          pip install 'lightning-utilities[cli]>=0.15.1'\n          python -m lightning_utilities.cli requirements set-oldest --req_files=pyproject.toml\n      - name: Install dependencies\n        run: |\n          pip install '.[extra,compiler,test]' -U --upgrade-strategy eager --extra-index-url=\"${TORCH_URL}\"\n          pip list\n\n      - name: Run tests\n        env:\n          HF_TOKEN: ${{ secrets.HF_TOKEN }}\n        run: pytest -v litgpt/ tests/ --timeout=180 --durations=100\n\n      - name: Show cache\n        run: |\n          pip install -q py-tree\n          python -m py_tree -d 1 .cache-HF\n\n  testing-guardian:\n    runs-on: ubuntu-latest\n    needs: [pytester, testing-imports]\n    if: |\n      (github.event_name == 'pull_request_target' && !contains('OWNER,MEMBER,COLLABORATOR', github.event.pull_request.author_association)) ||\n      (github.event_name == 'pull_request' && contains('OWNER,MEMBER,COLLABORATOR', github.event.pull_request.author_association))\n    steps:\n      - run: echo \"${{ needs.pytester.result }}\"\n      - name: failing...\n        if: needs.pytester.result == 'failure'\n        run: exit 1\n      - name: cancelled or skipped...\n        if: contains(fromJSON('[\"cancelled\", \"skipped\"]'), 
needs.pytester.result)\n        timeout-minutes: 1\n        run: sleep 90\n"
  },
  {
    "path": ".github/workflows/mkdocs-deploy.yml",
    "content": "name: Deploy MkDocs\n\non:\n  push:\n    branches: [main]\n\npermissions:\n  contents: write\n\njobs:\n  deploy:\n    runs-on: ubuntu-24.04\n    steps:\n      # Step 1: Checkout the repository\n      - uses: actions/checkout@v6\n\n      # Step 2: Set up Python\n      - uses: actions/setup-python@v6\n        with:\n          python-version: \"3.x\"\n          cache: \"pip\"\n\n      # Step 3: Install MkDocs and dependencies\n      - run: pip install mkdocs mkdocs-material mkdocs-pagetree-plugin\n      # Step 4: Deploy to GitHub Pages\n      - run: |\n          mkdir -p gh-pages/docs\n          cp -r tutorials/* gh-pages/docs\n          cd gh-pages\n          mv docs/mkdocs.yml mkdocs.yml\n          echo \"{{ pagetree }}\" > docs/index.md\n          mkdocs gh-deploy --force\n"
  },
  {
    "path": ".github/workflows/publish-pkg.yml",
    "content": "# To create a release, create a tag and push it to GitHub:\n#git tag -a \"v0.0.1-beta\" -m \"beta version testing\"\n#git push --tags\n# https://dev.to/iamtekson/publish-package-to-pypi-and-release-new-version-using-github-actions-108k\nname: Publish LitGPT to PyPI\n\non:\n  push:\n    tags:\n      - \"v*\"\njobs:\n  build-n-publish:\n    name: Build and publish to PyPI\n    runs-on: ubuntu-latest\n    environment:\n      name: pypi\n      url: https://pypi.org/p/litgpt\n    permissions:\n      id-token: write\n\n    steps:\n      - name: Checkout source\n        uses: actions/checkout@v6\n\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.x\"\n          cache: \"pip\"\n\n      - name: Build source and wheel distributions\n        run: |\n          python -m pip install --upgrade build twine\n          pip install importlib_metadata==7.2.1\n          python -m build\n          twine check --strict dist/*\n      - name: Publish distribution to PyPI\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          user: __token__\n          password: ${{ secrets.PYPI_API_TOKEN }}\n"
  },
  {
    "path": ".gitignore",
    "content": ".ipynb_checkpoints/\n__pycache__\n.idea\n.DS_Store\n*.egg-info\nbuild\ndist\n.venv\n.venv/\n.vscode\nuv.lock\n\n# data\ndata\ndatasets\n!litgpt/data\n!tests/data\ncheckpoints\nout\nwandb\nevents.out.tfevents*\n\n# test artifacts from tests/test_readme.py\n**/custom_finetuning_dataset.json\nclient.py\n**/custom_texts/\n"
  },
  {
    "path": ".lightning/workflows/tests.yaml",
    "content": "trigger:\n  push:\n    branches: [\"main\"]\n  pull_request:\n    branches: [\"main\"]\n\nimage: \"pytorchlightning/lightning-thunder:ubuntu24.04-cuda12.8.1-cudnn-fe1.15.0-py3.12-pt_2.8.0-dev\"\nmachine: \"L4_X_2\"\ninterruptible: \"true\"\ntimeout: \"45\" # minutes\nparametrize:\n  matrix:\n    dependency: [\"\", \"compiler\"]\n  include: []\n  exclude: []\n\nenv:\n  SKIP_WITH_CI: \"1\" # skip single tests with CI\n  NCCL_DEBUG: \"INFO\"\n  CUBLAS_WORKSPACE_CONFIG: \":4096:8\"\n  NCCL_IGNORE_DISABLED_P2P: \"1\"\n  TORCH_VERSION: \"2.8.0\"\n  RUN_ONLY_CUDA_TESTS: \"1\" # run CUDA tests only\n\nrun: |\n  whereis nvidia\n  nvidia-smi\n  python --version\n  pip --version\n  pip list\n  set -ex\n\n  echo \"Install uv and create virtual environment\"\n  curl -LsSf https://astral.sh/uv/install.sh | sh\n  [ -f \"$HOME/.local/bin/env\" ] && . \"$HOME/.local/bin/env\"\n  export PATH=\"$HOME/.local/bin:$PATH\"\n  uv venv .venv --system-site-packages\n  . .venv/bin/activate\n  hash -r\n\n  uv pip install -q '.[extra,test]' \"torch==${TORCH_VERSION}\" cffi -U\n\n  if [ \"${dependency}\" == \"compiler\" ]; then\n    uv pip uninstall torchvision torchaudio\n    uv pip install -q '.[compiler,extra,test]' \"torch==${TORCH_VERSION}\"\n    python -c \"from thunder.executors import nvfuser_available ; assert nvfuser_available(), 'nvFuser is missing!'\"\n    python -c \"from thunder.executors.triton_utils import triton_version ; assert triton_version() is not None, 'triton is missing!'\"\n  fi\n\n  uv pip list\n  python -c \"import torch ; gpus = torch.cuda.device_count() ; assert gpus >= 2, f'GPU: {gpus}'\"\n  python -c \"from torch import __version__ as ver ; assert str(ver).split('+')[0] == '${TORCH_VERSION}', f'PyTorch: installed {ver} but expected ${TORCH_VERSION}'\"\n\n  pytest -v --durations=100\n\n  wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh\n  PL_RUN_STANDALONE_TESTS=1 bash run_standalone_tests.sh 
\"tests\"\n\n  if [ \"${dependency}\" == \"compiler\" ]; then\n    uv pip uninstall lightning-thunder transformers\n    # install thunder from source, so that, thunder.tests will be available\n    uv pip install -U \"lightning-thunder[test] @ git+https://github.com/Lightning-AI/lightning-thunder.git\" \"torch==${TORCH_VERSION}\"\n    # Pin transformers to match thunder's test_networks.py requirements\n    # See: https://github.com/Lightning-AI/lightning-thunder/blob/main/requirements/test.txt\n    # Get transformers version from thunder requirements\n    TRANSFORMERS_VERSION=$(curl -fsSL https://raw.githubusercontent.com/Lightning-AI/lightning-thunder/main/requirements/test.txt \\\n      | grep '^transformers==' \\\n      | cut -d'=' -f3 \\\n      | cut -d'#' -f1 \\\n      | xargs)\n    if [ -z \"${TRANSFORMERS_VERSION}\" ]; then\n      echo \"Error: Could not determine transformers version from lightning-thunder requirements\"\n      exit 1\n    fi\n    uv pip install transformers==${TRANSFORMERS_VERSION}\n    # without env var, it filters out all tests\n    RUN_ONLY_CUDA_TESTS=0 pytest tests/ext_thunder/test_thunder_networks.py -v\n  fi\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "# Copyright The Lightning team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\ndefault_language_version:\n  python: python3\n\nci:\n  autofix_prs: true\n  autoupdate_commit_msg: \"[pre-commit.ci] pre-commit suggestions\"\n  autoupdate_schedule: quarterly\n  # submodules: true\n\nrepos:\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v6.0.0\n    hooks:\n      - id: end-of-file-fixer\n      - id: trailing-whitespace\n        exclude: README.md\n      - id: check-yaml\n      - id: check-toml\n      #- id: check-docstring-first\n      #- id: check-executables-have-shebangs\n      - id: check-case-conflict\n      - id: check-added-large-files\n        args: [\"--maxkb=250\", \"--enforce-all\"]\n      - id: detect-private-key\n\n  - repo: https://github.com/codespell-project/codespell\n    rev: v2.4.1\n    hooks:\n      - id: codespell\n        additional_dependencies: [tomli]\n        args: [\"--write-changes\"]\n        exclude: pyproject.toml\n\n  #- repo: https://github.com/crate-ci/typos\n  #  rev: dictgen-v0.3.1\n  #  hooks:\n  #    - id: typos\n  #      args: [] # empty to do not write fixes\n  #      exclude: pyproject.toml\n\n  #- repo: https://github.com/executablebooks/mdformat\n  #  rev: 0.7.21\n  #  hooks:\n  #    - id: mdformat\n  #      args: [\"--number\"]\n  #      additional_dependencies:\n  #        - mdformat-gfm\n  #        - mdformat-black\n  #        - mdformat_frontmatter\n\n  - repo: 
https://github.com/pre-commit/mirrors-prettier\n    rev: v3.1.0\n    hooks:\n      - id: prettier\n        files: \\.(json|yml|yaml|toml)\n        # https://prettier.io/docs/en/options.html#print-width\n        args: [\"--print-width=140\"]\n\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.14.10\n    hooks:\n      - id: ruff\n        args: [\"--fix\"]\n      - id: ruff-format\n      - id: ruff\n\n  - repo: https://github.com/tox-dev/pyproject-fmt\n    rev: v2.11.1\n    hooks:\n      - id: pyproject-fmt\n        additional_dependencies: [tox]\n  - repo: https://github.com/abravalheri/validate-pyproject\n    rev: v0.24.1\n    hooks:\n      - id: validate-pyproject\n"
  },
  {
    "path": "CITATION.cff",
    "content": "cff-version: 1.2.0\nmessage: \"If you use this software, you can cite it as shown below.\"\ntitle: \"LitGPT\"\nabstract: \"20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale.\"\ndate-released: 2023-03-22\nauthors:\n  - name: \"The Lightning AI team\"\nlicense: \"Apache-2.0\"\nurl: \"https://github.com/Lightning-AI/litgpt\"\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [2023] Lightning AI\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n\n# ⚡ LitGPT\n\n**20+ high-performance LLMs with recipes to pretrain, finetune, and deploy at scale.**\n\n<pre>\n✅ From scratch implementations      ✅ No abstractions         ✅ Beginner friendly\n   ✅ Flash attention                   ✅ FSDP                    ✅ LoRA, QLoRA, Adapter\n✅ Reduce GPU memory (fp4/8/16/32)   ✅ 1-1000+ GPUs/TPUs       ✅ 20+ LLMs         \n</pre>\n\n\n---\n\n\n![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pytorch-lightning)\n![cpu-tests](https://github.com/lightning-AI/lit-stablelm/actions/workflows/cpu-tests.yml/badge.svg) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lit-stablelm/blob/master/LICENSE) [![Discord](https://img.shields.io/discord/1077906959069626439)](https://discord.gg/VptPCZkGNa)\n\n<p align=\"center\">\n  <a href=\"#quick-start\">Quick start</a> •\n  <a href=\"#choose-from-20-llms\">Models</a> •\n  <a href=\"#finetune-an-llm\">Finetune</a> •\n  <a href=\"#deploy-an-llm\">Deploy</a> •\n  <a href=\"#all-workflows\">All workflows</a> •\n  <a href=\"#state-of-the-art-features\">Features</a> •\n  <a href=\"#training-recipes\">Recipes (YAML)</a> •\n  <a href=\"https://lightning.ai/\">Lightning AI</a> •\n    <a href=\"#tutorials\">Tutorials</a>\n</p>\n\n&nbsp;\n\n<a target=\"_blank\" href=\"https://lightning.ai/lightning-ai/studios/litgpt-quick-start\">\n  <img src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/get-started-badge.svg\" height=\"36px\" alt=\"Get started\"/>\n</a>\n\n&nbsp;\n\n</div>\n\n# Looking for GPUs?\nOver 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme) - purpose-built for PyTorch and PyTorch Lightning. \n- [GPUs](https://lightning.ai/pricing?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme) from $0.19.   
\n- [Clusters](https://lightning.ai/clusters?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): frontier-grade training/inference clusters.   \n- [AI Studio (vibe train)](https://lightning.ai/studios?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): workspaces where AI helps you debug, tune and vibe train.\n- [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): workspaces where AI helps you optimize, and deploy models.     \n- [Notebooks](https://lightning.ai/notebooks?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): Persistent GPU workspaces where AI helps you code and analyze.\n- [Inference](https://lightning.ai/deploy?utm_source=litgpt_readme&utm_medium=referral&utm_campaign=litgpt_readme): Deploy models as inference APIs.\n\n# Finetune, pretrain, and inference LLMs Lightning fast ⚡⚡\nEvery LLM is implemented from scratch with **no abstractions** and **full control**, making them blazing fast, minimal, and performant at enterprise scale.\n\n✅ **Enterprise ready -** Apache 2.0 for unlimited enterprise use.</br>\n✅ **Developer friendly -** Easy debugging with no abstraction layers and single file implementations.</br>\n✅ **Optimized performance -** Models designed to maximize performance, reduce costs, and speed up training.</br>\n✅ **Proven recipes -** Highly-optimized training/finetuning recipes tested at enterprise scale.</br>\n\n&nbsp;\n\n# Quick start\nInstall LitGPT\n```\npip install 'litgpt[extra]'\n```\n\nLoad and use any of the [20+ LLMs](#choose-from-20-llms):\n```python\nfrom litgpt import LLM\n\nllm = LLM.load(\"microsoft/phi-2\")\ntext = llm.generate(\"Fix the spelling: Every fall, the family goes to the mountains.\")\nprint(text)\n# Corrected Sentence: Every fall, the family goes to the mountains.\n```\n\n&nbsp;\n\n✅ Optimized for fast inference</br>\n✅ Quantization</br>\n✅ Runs on low-memory GPUs</br>\n✅ 
No layers of internal abstractions</br>\n✅ Optimized for production scale</br>\n\n<details>\n  <summary>Advanced install options</summary>\n\nInstall from source:\n\n```bash\ngit clone https://github.com/Lightning-AI/litgpt\ncd litgpt\n# if using uv\nuv sync --all-extras\n# if using pip\npip install -e \".[extra,compiler,test]\"\n```\n</details>\n\n[Explore the full Python API docs](tutorials/python-api.md).\n\n&nbsp;\n\n---\n# Choose from 20+ LLMs\nEvery model is written from scratch to maximize performance and remove layers of abstraction:\n\n| Model | Model size | Author | Reference |\n|----|----|----|----|\n| Llama 3, 3.1, 3.2, 3.3 | 1B, 3B, 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3)                                           |\n| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950)                                       |\n| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma)                                     |\n| Gemma 2 | 2B, 9B, 27B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf)  |\n| Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905)                                                                            |\n| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/)                                               |\n| Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186)                                          |\n| R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf)                                                                                 |\n| ... | ... | ... | ...   
|\n\n<details>\n  <summary>See full list of 20+ LLMs</summary>\n\n&nbsp;\n\n#### All models\n\n| Model | Model size | Author | Reference |\n|----|----|----|----|\n| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma)                                                                 |\n| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950)                                                                   |\n| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae)                                                                                              |\n| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3)                                                                                              |\n| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models)                 |\n| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2)                                  |\n| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf)                                       |\n| Gemma 2 | 2B, 9B, 27B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf)                                  |\n| Gemma 3 | 1B, 4B, 12B, 27B | Google | [Google Team, Google Deepmind](https://arxiv.org/pdf/2503.19786)                                  |\n| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 
2023](https://arxiv.org/abs/2307.09288)                                                                           |\n| Llama 3.1 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3)                                                                                 |\n| Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/)                                           |\n| Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct)                                                                                 |\n| Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/)                                                                                  |\n| MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama)                                                                             |\n| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/)                                                                     |\n| Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/)                                                                  |\n| Mixtral MoE | 8x22B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mixtral-8x22b/)                                                                         |\n| OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/)    |\n| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama)                                                         |\n| Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research  | [Li et al. 
2023](https://arxiv.org/abs/2309.05463)                                                                  |\n| Phi 3 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219)                                                                            |\n| Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905)                                                                            |\n| Phi 4 Mini Instruct | 3.8B | Microsoft Research | [Microsoft 2025](https://arxiv.org/abs/2503.01743)                                           |\n| Phi 4 Mini Reasoning | 3.8B | Microsoft Research | [Xu, Peng et al. 2025](https://arxiv.org/abs/2504.21233)                                           |\n| Phi 4 Reasoning | 14B | Microsoft Research | [Abdin et al. 2025](https://arxiv.org/abs/2504.21318)                                           |\n| Phi 4 Reasoning Plus | 14B | Microsoft Research | [Abdin et al. 2025](https://arxiv.org/abs/2504.21318)                                           |\n| Platypus | 7B, 13B, 70B |  Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317)                                                               |\n| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373)                                            |\n| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/)                                               |\n| Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186)                                          |\n| Qwen2.5 1M (Long Context) | 7B, 14B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwen2.5-1m/)                                          |\n| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 
2024](https://arxiv.org/abs/2409.12122)                                          |\n| QwQ | 32B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwq-32b/)                                                                         |\n| QwQ-Preview | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/)                                                                         |\n| Qwen3 | 0.6B, 1.7B, 4B{Hybrid, Thinking-2507, Instruct-2507}, 8B, 14B, 32B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/)                                                                         |\n| Qwen3 MoE | 30B{Hybrid, Thinking-2507, Instruct-2507}, 235B{Hybrid, Thinking-2507, Instruct-2507} | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/)                                                                         |\n| R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf)                                                                                 |\n| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm)                                                               |\n| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra)                                                                         |\n| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding)                                                  |\n| StableLM  | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM)                                                                    |\n| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding)                                             |\n| TinyLlama | 1.1B | 
Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama)                                                                         |\n\n\n**Tip**: You can list all available models by running the `litgpt download list` command.\n\n\n</details>\n\n&nbsp;\n\n---\n\n# Workflows\n\n<p align=\"center\">\n  <a href=\"#finetune-an-llm\">Finetune</a> •\n  <a href=\"#pretrain-an-llm\">Pretrain</a> •\n  <a href=\"#continue-pretraining-an-llm\">Continued pretraining</a> •\n    <a href=\"#evaluate-an-llm\">Evaluate</a> •\n    <a href=\"#deploy-an-llm\">Deploy</a> •\n    <a href=\"#test-an-llm\">Test</a>\n</p>\n\n&nbsp;\n\nUse the command line interface to run advanced workflows such as pretraining or finetuning on your own data.\n\n\n## All workflows\nAfter installing LitGPT, select the model and workflow to run (finetune, pretrain, evaluate, deploy, etc...):\n\n```bash\n# litgpt [action] [model]\nlitgpt  serve     meta-llama/Llama-3.2-3B-Instruct\nlitgpt  finetune  meta-llama/Llama-3.2-3B-Instruct\nlitgpt  pretrain  meta-llama/Llama-3.2-3B-Instruct\nlitgpt  chat      meta-llama/Llama-3.2-3B-Instruct\nlitgpt  evaluate  meta-llama/Llama-3.2-3B-Instruct\n```\n\n&nbsp;\n\n----\n\n## Finetune an LLM\n\n<div align=\"center\">\n<a target=\"_blank\" href=\"https://lightning.ai/lightning-ai/studios/litgpt-finetune\">\n  <img src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/run-on-studio.svg\" height=\"36px\" alt=\"Run on Studios\"/>\n</a>\n</div>\n\n&nbsp;\n\nFinetuning is the process of taking a pretrained AI model and further training it on a smaller, specialized dataset tailored to a specific task or application.\n\n\n&nbsp;\n\n```bash\n# 0) setup your dataset\ncurl -L https://huggingface.co/datasets/ksaw008/finance_alpaca/resolve/main/finance_alpaca.json -o my_custom_dataset.json\n\n# 1) Finetune a model (auto downloads weights)\nlitgpt finetune microsoft/phi-2 \\\n  --data JSON \\\n  --data.json_path my_custom_dataset.json \\\n  
--data.val_split_fraction 0.1 \\\n  --out_dir out/custom-model\n\n# 2) Test the model\nlitgpt chat out/custom-model/final\n\n# 3) Deploy the model\nlitgpt serve out/custom-model/final\n```\n\n[Read the full finetuning docs](tutorials/finetune.md)\n\n&nbsp;\n\n----\n\n## Deploy an LLM\n\n<div align=\"center\">\n<a target=\"_blank\" href=\"https://lightning.ai/lightning-ai/studios/litgpt-serve\">\n  <img src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/deploy-on-studios.svg\" height=\"36px\" alt=\"Deploy on Studios\"/>\n</a>\n</div>\n\n&nbsp;\n\nDeploy a pretrained or finetuned LLM to use it in real-world applications. Deploying automatically sets up a web server that can be accessed by a website or app.\n\n```bash\n# deploy an out-of-the-box LLM\nlitgpt serve microsoft/phi-2\n\n# deploy your own trained model\nlitgpt serve path/to/microsoft/phi-2/checkpoint\n```\n\n<details>\n  <summary>Show code to query server:</summary>\n\n&nbsp;\n\nTest the server in a separate terminal and integrate the model API into your AI product:\n```python\n# 3) Use the server (in a separate Python session)\nimport requests, json\nresponse = requests.post(\n    \"http://127.0.0.1:8000/predict\",\n    json={\"prompt\": \"Fix typos in the following sentence: Example input\"}\n)\nprint(response.json()[\"output\"])\n```\n</details>\n\n[Read the full deploy docs](tutorials/deploy.md).\n\n&nbsp;\n\n----\n\n## Evaluate an LLM\nEvaluate an LLM to test its performance on various tasks to see how well it understands and generates text. Simply put, we can evaluate things like how well it would do in college-level chemistry, coding, etc... 
(MMLU, Truthful QA, etc...)\n\n```bash\nlitgpt evaluate microsoft/phi-2 --tasks 'truthfulqa_mc2,mmlu'\n```\n\n[Read the full evaluation docs](tutorials/evaluation.md).\n\n&nbsp;\n\n----\n\n##  Test an LLM\n\n<div align=\"center\">\n<a target=\"_blank\" href=\"https://lightning.ai/lightning-ai/studios/litgpt-chat\">\n  <img src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/run-on-studio.svg\" height=\"36px\" alt=\"Run on Studios\"/>\n</a>\n</div>\n\n&nbsp;\n\nTest how well the model works via an interactive chat. Use the `chat` command to chat, extract embeddings, etc...\n\nHere's an example showing how to use the Phi-2 LLM:\n```bash\nlitgpt chat microsoft/phi-2\n\n>> Prompt: What do Llamas eat?\n```\n\n<details>\n  <summary>Full code:</summary>\n\n&nbsp;\n\n```bash\n# 1) List all supported LLMs\nlitgpt download list\n\n# 2) Use a model (auto downloads weights)\nlitgpt chat microsoft/phi-2\n\n>> Prompt: What do Llamas eat?\n```\n\nThe download of certain models requires an additional access token. 
You can read more about this in the [download](tutorials/download_model_weights.md#specific-models-and-access-tokens) documentation.\n\n</details>\n\n[Read the full chat docs](tutorials/inference.md).\n\n&nbsp;\n\n----\n\n## Pretrain an LLM\n\n<div align=\"center\">\n<a target=\"_blank\" href=\"https://lightning.ai/lightning-ai/studios/litgpt-pretrain\">\n  <img src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/run-on-studio.svg\" height=\"36px\" alt=\"Run on Studios\"/>\n</a>\n</div>\n\n&nbsp;\n\nPretraining is the process of teaching an AI model by exposing it to a large amount of data before it is fine-tuned for specific tasks.\n\n<details>\n  <summary>Show code:</summary>\n\n&nbsp;\n\n```bash\nmkdir -p custom_texts\ncurl https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt\ncurl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output custom_texts/book2.txt\n\n# 1) Download a tokenizer\nlitgpt download EleutherAI/pythia-160m \\\n  --tokenizer_only True\n\n# 2) Pretrain the model\nlitgpt pretrain EleutherAI/pythia-160m \\\n  --tokenizer_dir EleutherAI/pythia-160m \\\n  --data TextFiles \\\n  --data.train_data_path \"custom_texts/\" \\\n  --train.max_tokens 10_000_000 \\\n  --out_dir out/custom-model\n\n# 3) Test the model\nlitgpt chat out/custom-model/final\n```\n</details>\n\n[Read the full pretraining docs](tutorials/pretrain.md)\n\n&nbsp;\n\n----\n\n## Continue pretraining an LLM\n\n<div align=\"center\">\n<a target=\"_blank\" href=\"https://lightning.ai/lightning-ai/studios/litgpt-continue-pretraining\">\n  <img src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/run-on-studio.svg\" height=\"36px\" alt=\"Run on Studios\"/>\n</a>\n</div>\n\n&nbsp;\n\nContinued pretraining is another way of finetuning that specializes an already pretrained model by training on custom data:\n\n<details>\n  <summary>Show code:</summary>\n\n&nbsp;\n\n```bash\nmkdir -p custom_texts\ncurl 
https://www.gutenberg.org/cache/epub/24440/pg24440.txt --output custom_texts/book1.txt\ncurl https://www.gutenberg.org/cache/epub/26393/pg26393.txt --output custom_texts/book2.txt\n\n# 1) Continue pretraining a model (auto downloads weights)\nlitgpt pretrain EleutherAI/pythia-160m \\\n  --tokenizer_dir EleutherAI/pythia-160m \\\n  --initial_checkpoint_dir EleutherAI/pythia-160m \\\n  --data TextFiles \\\n  --data.train_data_path \"custom_texts/\" \\\n  --train.max_tokens 10_000_000 \\\n  --out_dir out/custom-model\n\n# 2) Test the model\nlitgpt chat out/custom-model/final\n```\n\n</details>\n\n[Read the full continued pretraining docs](tutorials/pretrain.md#continued-pretraining-on-custom-data)\n\n&nbsp;\n\n----\n\n# State-of-the-art features\n\n✅ State-of-the-art optimizations: Flash Attention v2, multi-GPU support via fully-sharded data parallelism, [optional CPU offloading](tutorials/oom.md#do-sharding-across-multiple-gpus), and [TPU and XLA support](extensions/xla).</br>\n✅ [Pretrain](tutorials/pretrain.md), [finetune](tutorials/finetune.md), and [deploy](tutorials/inference.md)</br>\n✅ Reduce compute requirements with low-precision settings: FP16, BF16, and FP16/FP32 mixed.</br>\n✅ Lower memory requirements with [quantization](tutorials/quantize.md): 4-bit floats, 8-bit integers, and double quantization.</br>\n✅ [Configuration files](config_hub) for great out-of-the-box performance.</br>\n✅ Parameter-efficient finetuning: [LoRA](tutorials/finetune_lora.md), [QLoRA](tutorials/finetune_lora.md), [Adapter](tutorials/finetune_adapter.md), and [Adapter v2](tutorials/finetune_adapter.md).</br>\n✅ [Exporting](tutorials/convert_lit_models.md) to other popular model weight formats.</br>\n✅ Many popular datasets for [pretraining](tutorials/pretrain.md) and [finetuning](tutorials/prepare_dataset.md), and [support for custom datasets](tutorials/prepare_dataset.md#preparing-custom-datasets-for-instruction-finetuning).</br>\n✅ Readable and easy-to-modify code to experiment 
with the latest research ideas.</br>\n\n&nbsp;\n\n---\n\n# Training recipes\n\nLitGPT comes with validated recipes (YAML configs) to train models under different conditions.  We've generated these recipes based on the parameters we found to perform the best for different training conditions.\n\nBrowse all training recipes [here](config_hub).\n\n### Example\n\n```bash\nlitgpt finetune \\\n  --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml\n```\n<details>\n  <summary>✅ Use configs to customize training</summary>\n\nConfigs let you customize training for all granular parameters like:\n\n```yaml\n# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-llama2-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n...\n```\n</details>\n\n<details>\n  <summary>✅ Example: LoRA finetuning config</summary>\n\n&nbsp;\n\n```yaml\n# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-llama2-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. 
(type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.05\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n    download_dir: data/alpaca2k\n\n# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details\ntrain:\n\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 4\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: float, default: 0.0003)\n  learning_rate: 0.0002\n\n  #   (type: float, default: 0.02)\n  weight_decay: 0.0\n\n  #   (type: float, default: 0.9)\n  beta1: 0.9\n\n  #   (type: float, default: 0.95)\n  beta2: 0.95\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details\neval:\n\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. 
(type: int, default: 1337)\nseed: 1337\n```\n</details>\n\n<details>\n  <summary>✅ Override any parameter in the CLI:</summary>\n\n```bash\nlitgpt finetune \\\n  --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml \\\n  --lora_r 4\n```\n</details>\n\n&nbsp;\n\n----\n\n# Project highlights\n\nLitGPT powers many great AI projects, initiatives, challenges and of course enterprises. Please submit a pull request to be considered for a feature.\n\n<details>\n  <summary>📊 SAMBA: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling</summary>\n\nThe [Samba](https://github.com/microsoft/Samba) project by researchers at Microsoft is built on top of the LitGPT code base and combines state space models with sliding window attention, which outperforms pure state space models.\n\n</details>\n\n<details>\n  <summary>🏆 NeurIPS 2023 Large Language Model Efficiency Challenge: 1 LLM + 1 GPU + 1 Day</summary>\n\nThe LitGPT repository was the official starter kit for the [NeurIPS 2023 LLM Efficiency Challenge](https://llm-efficiency-challenge.github.io), which is a competition focused on finetuning an existing non-instruction tuned LLM for 24 hours on a single GPU.\n\n</details>\n\n<details>\n  <summary>🦙 TinyLlama: An Open-Source Small Language Model</summary>\n\n\nLitGPT powered the [TinyLlama project](https://github.com/jzhang38/TinyLlama) and [TinyLlama: An Open-Source Small Language Model](https://arxiv.org/abs/2401.02385) research paper.\n\n</details>\n\n<details>\n  <summary>🍪 MicroLlama: MicroLlama-300M</summary>\n\n[MicroLlama](https://github.com/keeeeenw/MicroLlama) is a 300M Llama model pretrained on 50B tokens powered by TinyLlama and LitGPT.\n</details>\n\n<details>\n  <summary>🔬 Pre-training Small Base LMs with Fewer Tokens</summary>\n\nThe research paper [\"Pre-training Small Base LMs with Fewer Tokens\"](https://arxiv.org/abs/2404.08634), which utilizes LitGPT, develops smaller base 
language models by inheriting a few transformer blocks from larger models and training on a tiny fraction of the data used by the larger models. It demonstrates that these smaller models can perform comparably to larger models despite using significantly less training data and resources.\n\n</details>\n\n&nbsp;\n\n----\n\n# Community\n\nWe welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment.\n\n- [Request a feature](https://github.com/Lightning-AI/litgpt/issues)\n- [Submit your first contribution](https://lightning.ai/pages/community/tutorial/how-to-contribute-to-litgpt/)\n- [Join our Discord](https://discord.gg/VptPCZkGNa)\n\n&nbsp;\n\n# Tutorials\n\n🚀 [Get started](tutorials/0_to_litgpt.md)</br>\n⚡️ [Finetuning, incl. LoRA, QLoRA, and Adapters](tutorials/finetune.md)</br>\n🤖 [Pretraining](tutorials/pretrain.md)</br>\n💬 [Model evaluation](tutorials/evaluation.md)</br>\n📘 [Supported and custom datasets](tutorials/prepare_dataset.md)</br>\n🧹 [Quantization](tutorials/quantize.md)</br>\n🤯 [Tips for dealing with out-of-memory (OOM) errors](tutorials/oom.md)</br>\n🧑🏽‍💻 [Using cloud TPUs](extensions/xla)</br>\n\n&nbsp;\n\n----\n\n### Acknowledgments\n\nThis implementation extends on [Lit-LLaMA](https://github.com/lightning-AI/lit-llama) and [nanoGPT](https://github.com/karpathy/nanoGPT), and it's **powered by [Lightning Fabric](https://lightning.ai/docs/fabric/stable/) ⚡**.\n\n- [@karpathy](https://github.com/karpathy) for [nanoGPT](https://github.com/karpathy/nanoGPT)\n- [@EleutherAI](https://github.com/EleutherAI) for [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) and the [Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness)\n- [@TimDettmers](https://github.com/TimDettmers) for [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)\n- 
[@Microsoft](https://github.com/microsoft) for [LoRA](https://github.com/microsoft/LoRA)\n- [@tridao](https://github.com/tridao) for [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)\n\n### License\n\nLitGPT is released under the [Apache 2.0](https://github.com/Lightning-AI/litgpt/blob/main/LICENSE) license.\n\n### Citation\n\nIf you use LitGPT in your research, please cite the following work:\n\n```bibtex\n@misc{litgpt-2023,\n  author       = {Lightning AI},\n  title        = {LitGPT},\n  howpublished = {\\url{https://github.com/Lightning-AI/litgpt}},\n  year         = {2023},\n}\n```\n\n&nbsp;\n"
  },
  {
    "path": "config_hub/finetune/README.md",
    "content": "## Config files\n\nThe table below lists the performances you can expect from the provided config files. Note that you can achieve lower memory consumption by lowering the micro batch size as needed. In addition, you can lower the rank (`lora_r`) in the LoRA configuration files and disable LoRA for certain layers (for example, setting `lora_projection` and other LoRA layer-specific parameters to `false`).\nFor more information, see the [Dealing with out-of-memory (OOM) errors](../../tutorials/oom.md) on lowering the memory requirements.\nThe \"Cost\" column refers to the on-demand compute cost on [Lightning AI Studios where these benchmarks were executed](https://lightning.ai/lightning-ai/studios/automated-benchmarks-for-litgpt).\nAll experiments were conducted using bfloat-16 precision on the Alpaca2k dataset. The \"Multitask score\" refers to [MMLU](https://arxiv.org/abs/2009.03300).\n\n&nbsp;\n\n| Config                            | Model                  | Epochs | Max seq length | Micro batch size | Machine | Training runtime | Cost | Peak memory | Validation loss | Validation perplexity | Multitask score (MMLU) |\n| --------------------------------- | ---------------------- | ------ | -------------- | ---------------- | ------- | ---------------- | ---- | ----------- | --------------- | --------------------- | --------------- |\n| falcon-7b/lora.yaml               | falcon-7b              | 4      | 512            | 1                | 1xA10G  | 24.84 min        | $0.7 | 16.69 GB    | 0.945           | 2.573                 | 26.2%           |\n| falcon-7b/lora.yaml               | falcon-7b              | 4      | 512            | 1                | 4xA10G  | 24.94 min        | $2.0 | 16.69 GB    | 0.945           | 2.573                 | 26.4%           |\n| falcon-7b/qlora.yaml              | falcon-7b              | 4      | 512            | 1                | 1xA10G  | 50.85 min        | $1.5 | 9.44 GB     | 0.993           | 2.699        
         | 26.3%           |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| gemma-2b/full.yaml                | gemma-2b               | 1      | 512            | 1                | 4xA10G  | 14.06 min        | $1.1 | 17.43 GB    | 1.021           | 2.777                 | 32.4%           |\n| gemma-2b/lora.yaml                | gemma-2b               | 2      | 512            | 2                | 1xA10G  | 9.41 min         | $0.3 | 12.62 GB    | 0.981           | 2.666                 | 34.4%           |\n| gemma-2b/lora.yaml                | gemma-2b               | 2      | 512            | 2                | 4xA10G  | 9.41 min         | $0.8 | 12.62 GB    | 0.981           | 2.667                 | 34.0%           |\n| gemma-2b/qlora.yaml               | gemma-2b               | 2      | 512            | 2                | 1xA10G  | 12.91 min        | $0.4 | 11.58 GB    | 1.085           | 2.959                 | 36.4%           |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| gemma-7b/lora.yaml                | gemma-7b               | 2      | 512            | 1                | 1xA10G  | OOM              | OOM  | OOM         | OOM             | OOM                   |                 |\n| gemma-7b/lora.yaml                | gemma-7b               | 2      | 512            | 1                | 4xA10G  | OOM              | OOM  | OOM         | OOM             | OOM                   |                 |\n| gemma-7b/qlora.yaml               | gemma-7b               | 2      | 512            | 1                | 1xA10G  | 43.58 min        | $1.3 | 17.18 GB    | 0.973           | 2.646                 | 62.45%          
|\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| gemma2-2b/lora.yaml               | gemma-2-2b             | 2      | 512            | 2                | 1xA10G  | 11.96 min        | $0.4 | 14.31 GB    | 0.951           | 2.589                 | 23.84%          |\n| gemma2-2b/qlora.yaml              | gemma-2-2b             | 2      | 512            | 2                | 1xA10G  | 16.06 min        | $0.5 | 13.52 GB    | 0.983           | 2.673                 | 24.12%          |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| gemma2-9b/lora.yaml               | gemma-2-9b             | 2      | 512            | 1                | 1xA10G  | OOM              | OOM  | OOM         | OOM             | OOM                   |                 |\n| gemma2-9b/lora.yaml               | gemma-2-9b             | 2      | 512            | 1                | 4xA10G  | OOM              | OOM  | OOM         | OOM             | OOM                   |                 |\n| gemma2-9b/qlora.yaml              | gemma-2-9b             | 2      | 512            | 1                | 1xA10G  | 50.01 min        | $4.0 | 20.92 GB    | 0.852           | 2.345                 | 24.2%           |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| llama-2-7b/full.yaml              | llama-2-7b             | 1      | 512            | 4                | 4xA10G  | OOM              | OOM  | OOM         | OOM             | OOM                   |                 |\n| llama-2-7b/lora.yaml    
          | llama-2-7b             | 4      | 512            | 2                | 1xA10G  | 32.82 min        | $1.0 | 19.77 GB    | 0.802           | 2.230                 | 40.3%           |\n| llama-2-7b/lora.yaml              | llama-2-7b             | 4      | 512            | 2                | 4xA10G  | 32.83 min        | $2.6 | 19.77 GB    | 0.802           | 2.229                 | 40.2%           |\n| llama-2-7b/qlora.yaml             | llama-2-7b             | 4      | 512            | 2                | 1xA10G  | 45.67 min        | $1.4 | 13.68 GB    | 0.814           | 2.258                 | 38.6%           |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| llama-3-8b/full.yaml              | llama-3-8b             | 1      | 512            | 4                | 4xA10G  | OOM              | OOM  | OOM         | OOM             | OOM                   |                 |\n| llama-3-8b/lora.yaml              | llama-3-8b             | 2      | 512            | 1                | 1xA10G  | 14.79 min        | $0.4 | 19.73 GB    | 0.888           | 2.431                 | 62.4%           |\n| llama-3-8b/lora.yaml              | llama-3-8b             | 2      | 512            | 1                | 4xA10G  | 14.88 min        | $1.2 | 19.73 GB    | 0.889           | 2.432                 | 62.5%           |\n| llama-3-8b/qlora.yaml             | llama-3-8b             | 2      | 512            | 2                | 1xA10G  | 22.24 min        | $0.7 | 17.41 GB    | 0.939           | 2.558                 | 62.2%           |\n|                                   |                        |        |                |                  |         |                  |      |            |                 |                        |                 |\n| llama-3.1-8b/full.yaml            | llama-3.1-8b     
      | 1      | 512            | 4                | 1xA10G  | OOM              | OOM  | OOM         | OOM             | OOM                   | OOM             |\n| llama-3.1-8b/lora.yaml            | llama-3.1-8b           | 2      | 512            | 1                | 1xA10G  | 13.36 min        | $1.1 | 19.73 GB    | 0.878           | 2.406                 | xx.xx           |\n| llama-3.1-8b/qlora.yaml           | llama-3.1-8b           | 2      | 512            | 2                | 1xA10G  | 21.81 min        | $0.7 | 17.41 GB    | 0.928           | 2.529                 | xx.xx           |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| llama-3.2-1b/full.yaml            | llama-3.2-1b           | 1      | 512            | 4                | 1xA10G  |  2.01 min        | $0.1 |  8.70 GB    | 1.442           | 4.229                 | 38.21%          |\n| llama-3.2-1b/lora.yaml            | llama-3.2-1b           | 2      | 512            | 1                | 1xA10G  |  4.17 min        | $0.4 |  4.49 GB    | 1.114           | 3.046                 | 36.87%          |\n| llama-3.2-1b/qlora.yaml           | llama-3.2-1b           | 2      | 512            | 2                | 1xA10G  |  6.20 min        | $0.6 |  5.53 GB    | 1.201           | 3.322                 | 36.49%          |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| llama-3.2-3b/full.yaml            | llama-3.2-3b           | 1      | 512            | 4                | 1xA10G  |  4.71 min        | $0.4 | 16.51 GB    | 1.255           | 3.509                 | 54.69%          |\n| llama-3.2-3b/lora.yaml            | llama-3.2-3b           | 2      | 512         
   | 1                | 1xA10G  |  8.31 min        | $0.8 |  9.67 GB    | 0.973           | 2.647                 | 54.77%          |\n| llama-3.2-3b/qlora.yaml           | llama-3.2-3b           | 2      | 512            | 2                | 1xA10G  | 14.89 min        | $1.4 | 10.30 GB    | 1.031           | 2.804                 | 55.08%          |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| mistral-7b-v0.2/lora.yaml         | mistral-7b-v0.2        | 4      | 512            | 2                | 1xA10G  | 31.00 min        | $0.9 | 20.66 GB    | 0.801           | 2.228                 | 55.7%           |\n| mistral-7b-v0.2/lora.yaml         | mistral-7b-v0.2        | 4      | 512            | 2                | 4xA10G  | 31.00 min        | $2.5 | 20.66 GB    | 0.802           | 2.229                 | 55.5%           |\n| mistral-7b-v0.2/qlora.yaml        | mistral-7b-v0.2        | 4      | 512            | 2                | 1xA10G  | 44.75 min        | $1.3 | 14.29 GB    | 0.813           | 2.255                 | 56.5%           |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| mistral-7b/lora.yaml              | mistral-7b             | 4      | 512            | 2                | 1xA10G  | 31.01 min        | $0.9 | 20.66 GB    | 0.794           | 2.211                 | 57.9%           |\n| mistral-7b/lora.yaml              | mistral-7b             | 4      | 512            | 2                | 4xA10G  | 31.03 min        | $2.5 | 20.66 GB    | 0.796           | 2.218                 | 57.9%           |\n| mistral-7b/qlora.yaml             | mistral-7b             | 4      | 512            | 2                | 
1xA10G  | 44.75 min        | $1.3 | 14.29 GB    | 0.803           | 2.231                 | 57.9%           |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| phi-2/full.yaml                   | phi-2                  | 1      | 512            | 4                | 4xA10G  | 11.87 min        | $1.0 | 14.44 GB    | 1.305           | 3.688                 | 38.4%           |\n| phi-2/lora.yaml                   | phi-2                  | 1      | 512            | 4                | 1xA10G  | 3.78 min         | $0.1 | 13.98 GB    | 0.819           | 2.269                 | 53.0%           |\n| phi-2/lora.yaml                   | phi-2                  | 1      | 512            | 4                | 4xA10G  | 3.78 min         | $0.3 | 13.98 GB    | 0.820           | 2.271                 | 52.4%           |\n| phi-2/qlora.yaml                  | phi-2                  | 1      | 512            | 4                | 1xA10G  | 4.51 min         | $0.1 | 14.27 GB    | 0.837           | 2.310                 | 52.3%           |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| phi-3/full.yaml                   | Phi-3-mini-4k-instruct | 1      | 512            | 4                | 1xA10G  | 6.93 min         | $0.2 | 17.01 GB    | 0.714           | 2.043                 | 69.81%          |\n| phi-3/lora.yaml                   | Phi-3-mini-4k-instruct | 1      | 512            | 4                | 1xA10G  | 6.46 min         | $0.2 | 19.75 GB    | 0.707           | 2.028                 | 69.70%          |\n| phi-3/qlora.yaml                  | Phi-3-mini-4k-instruct | 1      | 512            | 4                | 1xA10G  | 7.47 min         | 
$0.2 | 19.13 GB    | 0.729           | 2.074                 | 68.96%          |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| stablelm-base-alpha-3b/full.yaml  | stablelm-base-alpha-3b | 1      | 512            | 1                | 4xA10G  | 70.13 min        | $5.6 | 21.23 GB    | 1.513           | 4.540                 | 23.2%           |\n| stablelm-base-alpha-3b/lora.yaml  | stablelm-base-alpha-3b | 4      | 512            | 1                | 1xA10G  | 13.07 min        | $0.4 | 8.58 GB     | 1.361           | 3.900                 | 25.9%           |\n| stablelm-base-alpha-3b/lora.yaml  | stablelm-base-alpha-3b | 4      | 512            | 1                | 4xA10G  | 13.16 min        | $1.1 | 8.58 GB     | 1.362           | 3.906                 | 25.9%           |\n| stablelm-base-alpha-3b/qlora.yaml | stablelm-base-alpha-3b | 4      | 512            | 1                | 1xA10G  | 25.86 min        | $0.8 | 5.24 GB     | 1.388           | 4.009                 | 26.1%           |\n|                                   |                        |        |                |                  |         |                  |      |             |                 |                       |                 |\n| tiny-llama/full.yaml              | tiny-llama             | 1      | 512            | 4                | 1xA10G  | 2.58 min         | $0.1 | 14.10 GB    | 1.088           | 2.968                 | 24.6%           |\n| tiny-llama/full.yaml              | tiny-llama             | 1      | 512            | 4                | 4xA10G  | 2.57 min         | $0.2 | 14.10 GB    | 1.088           | 2.968                 | 24.5%           |\n| tiny-llama/lora.yaml              | tiny-llama             | 3      | 512            | 8                | 1xA10G  | 8.09 min         | $0.2 | 13.50 GB    | 1.039   
        | 2.826                 | 25.5%           |\n| tiny-llama/qlora.yaml             | tiny-llama             | 3      | 512            | 8                | 1xA10G  | 8.70 min         | $0.3 | 16.24 GB    | 1.056           | 2.874                 | 25.3%           |\n\n*OOM = Out of memory\n\n\n&nbsp;\n## Extending the context length\n\nIf you require a longer sequence length than the one used in a given config file, you can either edit the `max_seq_length` in the config file or pass an additional argument when running the finetuning command, for example, `--max_seq_length 4096` to override the sequence length provided in the config file.\n\n&nbsp;\n## Training on GPUs without bfloat16 support\n\nIf you are training on GPUs without bfloat16 support, you need to change the `precision` option to `16-true` (16-bit floating point precision) or `16-mixed` (16/32-bit mixed precision):\n\n```bash\nlitgpt finetune lora \\\n  --config config_hub/finetune/phi-2/lora.yaml \\\n  --precision 16-true\n```\nor\n\n```bash\nlitgpt finetune lora \\\n  --config config_hub/finetune/phi-2/lora.yaml \\\n  --precision 16-mixed\n```\n\nNote that `16-true` is more compute- and memory-efficient, but it can sometimes lead to training convergence issues. In this case, it's recommended to use `16-mixed`.\n\n&nbsp;\n## Multi-GPU experiments\n\nAll runs are single-GPU experiments; use `--devices 4` to utilize more than one GPU:\n\n\n```bash\nlitgpt finetune lora \\\n  --config config_hub/finetune/phi-2/lora.yaml \\\n  --devices 4\n```\n"
  },
  {
    "path": "config_hub/finetune/falcon-7b/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/tiiuae/falcon-7b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-falcon-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 4\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/falcon-7b/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/tiiuae/falcon-7b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-falcon-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.05\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n    download_dir: data/alpaca2k\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 4\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/gemma-2b/full.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/google/gemma-2b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/full-gemma-2b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 4\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 16\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 100\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps: 50\n\n  # Limits the length of samples. 
Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/gemma-2b/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/google/gemma-2b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-gemma-2b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 8\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.1\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 6\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 200\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/gemma-2b/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/google/gemma-2b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-gemma-2b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 16\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.1\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 6\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 200\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/gemma-7b/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/google/gemma-7b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-gemma-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 16\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.1\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 6\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 200\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/gemma-7b/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/google/gemma-7b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-gemma-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 16\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.1\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 6\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 200\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/gemma2-2b/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/google/gemma-2-2b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-gemma-2-2b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 8\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.1\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 6\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 200\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/gemma2-2b/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/google/gemma-2-2b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-gemma-2-2b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 16\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.1\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 6\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 200\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/gemma2-9b/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/google/gemma-2-9b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-gemma-2-9b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 16\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.1\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 6\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 200\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/gemma2-9b/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/google/gemma-2-9b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-gemma-2-9b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 16\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.1\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 6\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 200\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-2-7b/full.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/finetune/full)\nout_dir: out/finetune/full-llama2-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# How many devices/GPUs to use (type: Union[int, str], default: 1)\ndevices: 4\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume\n# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing\n# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.\n# (type: Union[bool, Literal[\"auto\"], Path], default: False)\nresume: false\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)\n  global_batch_size: 64\n\n  # Number of samples per data-parallel rank (type: int, default: 1)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 25\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 600)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. 
Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-2-7b/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-llama2-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 4\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-2-7b/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Llama-2-7b-hf\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-llama2-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.05\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n    download_dir: data/alpaca2k\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 4\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-3-8b/full.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Meta-Llama-3-8B\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/finetune/full)\nout_dir: out/finetune/full-llama-3-8b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# How many devices/GPUs to use (type: Union[int, str], default: 1)\ndevices: 4\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume\n# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing\n# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.\n# (type: Union[bool, Literal[\"auto\"], Path], default: False)\nresume: false\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)\n  global_batch_size: 64\n\n  # Number of samples per data-parallel rank (type: int, default: 1)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 25\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 600)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. 
Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.1\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-3-8b/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Meta-Llama-3-8B\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-llama-3-8b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-3-8b/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Meta-Llama-3-8B\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-llama3-8b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.05\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n    download_dir: data/alpaca2k\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-3.1-8b/full.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Meta-Llama-3.1-8B\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/finetune/full)\nout_dir: out/finetune/full-llama-3.1-8b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# How many devices/GPUs to use (type: Union[int, str], default: 1)\ndevices: 4\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume\n# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing\n# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.\n# (type: Union[bool, Literal[\"auto\"], Path], default: False)\nresume: false\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)\n  global_batch_size: 64\n\n  # Number of samples per data-parallel rank (type: int, default: 1)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 25\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 600)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. 
Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.1\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-3.1-8b/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Meta-Llama-3.1-8B\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-llama-3.1-8b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-3.1-8b/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Meta-Llama-3.1-8B\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-llama3.1-8b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.05\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n    download_dir: data/alpaca2k\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-3.2-1B/full.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Llama-3.2-1B\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/finetune/full)\nout_dir: out/finetune/full-llama-3.2-1B\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# How many devices/GPUs to use (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume\n# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing\n# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.\n# (type: Union[bool, Literal[\"auto\"], Path], default: False)\n# resume: false\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)\n  global_batch_size: 64\n\n  # Number of samples per data-parallel rank (type: int, default: 1)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 25\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 600)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. 
Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.1\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-3.2-1B/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Llama-3.2-1B\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-llama-3.2-1B\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-3.2-1B/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Llama-3.2-1B\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-llama3.2-1b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.05\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n    download_dir: data/alpaca2k\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-3.2-3B/full.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Llama-3.2-3B\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/finetune/full)\nout_dir: out/finetune/full-llama-3.2-3B\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# How many devices/GPUs to use (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume\n# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing\n# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.\n# (type: Union[bool, Literal[\"auto\"], Path], default: False)\n# resume: false\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)\n  global_batch_size: 64\n\n  # Number of samples per data-parallel rank (type: int, default: 1)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 25\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 600)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. 
Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.1\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-3.2-3B/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Llama-3.2-3B\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-llama-3.2-3B\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/llama-3.2-3B/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/meta-llama/Llama-3.2-3B\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-llama3.2-3b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.05\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n    download_dir: data/alpaca2k\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 2\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/mistral-7b/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/mistralai/Mistral-7B-v0.1\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-mistral-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 4\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/mistral-7b/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/mistralai/Mistral-7B-v0.1\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-mistral-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.05\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n    download_dir: data/alpaca2k\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 4\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/mistral-7b-v0.2/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/unsloth/Mistral-7B-v0.2\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-mistral-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 4\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/mistral-7b-v0.2/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/unsloth/Mistral-7B-v0.2\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-mistral-7b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.05\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n    download_dir: data/alpaca2k\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 2\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 4\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/phi-2/full.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/microsoft/phi-2\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/finetune/full)\nout_dir: out/finetune/full-phi-2\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# How many devices/GPUs to use (type: Union[int, str], default: 1)\ndevices: 2\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 1)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 200\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps: 100\n\n  # Limits the length of samples. 
Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 600)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.1\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/phi-2/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/microsoft/phi-2\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-phi-2\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 8\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/phi-2/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/microsoft/phi-2\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-phi-2\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 8\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/phi-3/full.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/microsoft/Phi-3-mini-4k-instruct\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/finetune/full)\nout_dir: out/finetune/full-phi-3\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# How many devices/GPUs to use (type: Union[int, str], default: 1)\ndevices: 1\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 64)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 1)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 200\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. 
(type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 600)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.1\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/phi-3/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/microsoft/Phi-3-mini-4k-instruct\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-phi-3\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 8\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. 
Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/phi-3/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/microsoft/Phi-3-mini-4k-instruct\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-phi-3\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 8\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/stablelm-base-alpha-3b/full.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/stabilityai/stablelm-base-alpha-3b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/full-stablelm-base-alpha-3b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 2\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 1000\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. 
Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.1\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/stablelm-base-alpha-3b/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/stabilityai/stablelm-base-alpha-3b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-stablelm-base-alpha-3b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 4\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/stablelm-base-alpha-3b/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/stabilityai/stablelm-base-alpha-3b\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-stablelm-base-alpha-3b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: false\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: false\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: false\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: false\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.05\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n    download_dir: data/alpaca2k\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 200\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 1\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 4\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/tiny-llama/full.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/full-tiny-llama-1.1b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 32\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 1000\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 1\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. 
Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 25\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/tiny-llama/lora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/lora-tiny-llama-1.1b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize:\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 8\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 3\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/finetune/tiny-llama/qlora.yaml",
    "content": "# The path to the base model's checkpoint directory to load for finetuning. (type: <class 'Path'>, default: checkpoints/stabilityai/stablelm-base-alpha-3b)\ncheckpoint_dir: checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n\n# Directory in which to save checkpoints and logs. (type: <class 'Path'>, default: out/lora)\nout_dir: out/finetune/qlora-tiny-llama-1.1b\n\n# The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-true\n\n# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null)\nquantize: bnb.nf4\n\n# How many devices/GPUs to use. (type: Union[int, str], default: 1)\ndevices: 1\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# The LoRA rank. (type: int, default: 8)\nlora_r: 32\n\n# The LoRA alpha. (type: int, default: 16)\nlora_alpha: 16\n\n# The LoRA dropout value. (type: float, default: 0.05)\nlora_dropout: 0.05\n\n# Whether to apply LoRA to the query weights in attention. (type: bool, default: True)\nlora_query: true\n\n# Whether to apply LoRA to the key weights in attention. (type: bool, default: False)\nlora_key: true\n\n# Whether to apply LoRA to the value weights in attention. (type: bool, default: True)\nlora_value: true\n\n# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False)\nlora_projection: true\n\n# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False)\nlora_mlp: true\n\n# Whether to apply LoRA to output head in GPT. (type: bool, default: False)\nlora_head: true\n\n# Data-related arguments. 
If not provided, the default is ``litgpt.data.Alpaca``.\ndata:\n  class_path: litgpt.data.Alpaca2k\n  init_args:\n    mask_prompt: false\n    val_split_fraction: 0.03847\n    prompt_style: alpaca\n    ignore_index: -100\n    seed: 42\n    num_workers: 4\n\n# Training-related arguments. See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 800\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128)\n  global_batch_size: 8\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 8\n\n  # Number of iterations with learning rate warmup active (type: int, default: 100)\n  lr_warmup_steps: 10\n\n  # Number of epochs to train on (type: Optional[int], default: 5)\n  epochs: 3\n\n  # Total number of tokens to train on (type: Optional[int], default: null)\n  max_tokens:\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 512\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: null)\n  max_norm:\n\n  #   (type: float, default: 6e-05)\n  min_lr: 6.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 100)\n  interval: 100\n\n  # Number of tokens to generate (type: Optional[int], default: 100)\n  max_new_tokens: 100\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: true\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: csv)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 1337)\nseed: 1337\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0002\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.0\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n"
  },
  {
    "path": "config_hub/pretrain/debug.yaml",
    "content": "# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with\n# ``model_config``. (type: Optional[str], default: null)\nmodel_name: pythia-14m\n\n# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with\n# ``model_config``. (type: Optional[Config], default: null)\nmodel_config:\n\n# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in\n# /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)\nout_dir: out/pretrain/debug\n\n# The precision to use for pretraining. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-mixed\n\n# Optional path to a checkpoint directory to initialize the model from.\n# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)\ninitial_checkpoint_dir:\n\n# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume\n# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing\n# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.\n# (type: Union[bool, Literal[\"auto\"], Path], default: False)\nresume: false\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.\ndata: TinyStories\n\n# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 1000\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512)\n  global_batch_size: 125\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 5\n\n  # Number of iterations with learning rate warmup active (type: int, default: 2000)\n  lr_warmup_steps: 100\n\n  # Number of epochs to train on (type: Optional[int], default: null)\n  epochs:\n\n  # Total number of tokens to train on (type: Optional[int], default: 3000000000000)\n  max_tokens: 100000000\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length:\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: 1.0)\n  max_norm: 1.0\n\n  #   (type: float, default: 4e-05)\n  min_lr: 6e-5\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 1000)\n  interval: 1000\n\n  # Number of tokens to generate (type: Optional[int], default: null)\n  max_new_tokens:\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: false\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 6e-4\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.1\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n\n# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)\ndevices: auto\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data\n# modules require this. (type: Optional[Path], default: null)\ntokenizer_dir: checkpoints/EleutherAI/pythia-14m\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: tensorboard)\nlogger_name: tensorboard\n\n# The random seed to use for reproducibility. (type: int, default: 42)\nseed: 42\n"
  },
  {
    "path": "config_hub/pretrain/microllama.yaml",
    "content": "# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with\n# ``model_config``. (type: Optional[str], default: null)\nmodel_name: micro-llama-300M\n\n# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with\n# ``model_config``. (type: Optional[Config], default: null)\nmodel_config:\n\n# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in\n# /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)\nout_dir: out/pretrain/micro-llama\n\n# The precision to use for pretraining. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-mixed\n\n# Optional path to a checkpoint directory to initialize the model from.\n# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)\ninitial_checkpoint_dir:\n\n# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume\n# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing\n# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.\n# (type: Union[bool, Literal[\"auto\"], Path], default: False)\nresume: false\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.\ndata: MicroLlama\n\n# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 1000\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 48)\n  # Scale this number according to the number of GPU and memory size per GPU\n  # For example, we used 48 for 4 x 24G 4090\n  global_batch_size: 48\n\n  # Number of samples per data-parallel rank (type: int, default: 12)\n  # Scale this number according to the memory size per GPU\n  # For example, we used 12 for 24G 4090\n  micro_batch_size: 12\n\n  # Number of iterations with learning rate warmup active (type: int, default: 2000)\n  lr_warmup_steps: 2000\n\n  # Number of epochs to train on (type: Optional[int], default: null)\n  epochs:\n\n  # Total number of tokens to train on (type: Optional[int], default: 3000000000000)\n  max_tokens: 3000000000000\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 2048\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: 1.0)\n  max_norm: 1.0\n\n  #   (type: float, default: 4e-05)\n  min_lr: 4.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 1000)\n  interval: 1000\n\n  # Number of tokens to generate (type: Optional[int], default: null)\n  max_new_tokens:\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 4e-4\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.1\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n\n# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)\ndevices: auto\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data\n# modules require this. (type: Optional[Path], default: null)\ntokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: tensorboard)\nlogger_name: tensorboard\n\n# The random seed to use for reproducibility. (type: int, default: 42)\nseed: 42\n"
  },
  {
    "path": "config_hub/pretrain/tinyllama.yaml",
    "content": "# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with\n# ``model_config``. (type: Optional[str], default: null)\nmodel_name: tiny-llama-1.1b\n\n# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with\n# ``model_config``. (type: Optional[Config], default: null)\nmodel_config:\n\n# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in\n# /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)\nout_dir: out/pretrain/tiny-llama\n\n# The precision to use for pretraining. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-mixed\n\n# Optional path to a checkpoint directory to initialize the model from.\n# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)\ninitial_checkpoint_dir:\n\n# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume\n# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing\n# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.\n# (type: Union[bool, Literal[\"auto\"], Path], default: False)\nresume: false\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.\ndata: TinyLlama\n\n# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 1000\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512)\n  global_batch_size: 512\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 4\n\n  # Number of iterations with learning rate warmup active (type: int, default: 2000)\n  lr_warmup_steps: 2000\n\n  # Number of epochs to train on (type: Optional[int], default: null)\n  epochs:\n\n  # Total number of tokens to train on (type: Optional[int], default: 3000000000000)\n  max_tokens: 3000000000000\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 2048\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)\n  tie_embeddings:\n\n  #   (type: Optional[float], default: 1.0)\n  max_norm: 1.0\n\n  #   (type: float, default: 4e-05)\n  min_lr: 4.0e-05\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 1000)\n  interval: 1000\n\n  # Number of tokens to generate (type: Optional[int], default: null)\n  max_new_tokens:\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: false\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 4e-4\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.1\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n\n# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)\ndevices: auto\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data\n# modules require this. (type: Optional[Path], default: null)\ntokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: tensorboard)\nlogger_name: tensorboard\n\n# The random seed to use for reproducibility. (type: int, default: 42)\nseed: 42\n"
  },
  {
    "path": "config_hub/pretrain/tinystories.yaml",
    "content": "# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with\n# ``model_config``. (type: Optional[str], default: null)\nmodel_name: stories15M\n\n# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with\n# ``model_config``. (type: Optional[Config], default: null)\nmodel_config:\n  name: stories15M\n  hf_config: {}\n  scale_embeddings: false\n  block_size: 256\n  padded_vocab_size: 32000\n  n_layer: 6\n  n_head: 6\n  n_query_groups: 6\n  n_embd: 288\n  head_size: 48\n  rotary_percentage: 1.0\n  parallel_residual: false\n  bias: false\n  norm_class_name: RMSNorm\n  mlp_class_name: LLaMAMLP\n  intermediate_size: 768\n\n# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in\n# /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)\nout_dir: out/pretrain/stories15M\n\n# The precision to use for pretraining. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\". (type: Optional[str], default: null)\nprecision: bf16-mixed\n\n# Optional path to a checkpoint directory to initialize the model from.\n# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)\ninitial_checkpoint_dir:\n\n# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume\n# from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing\n# ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.\n# (type: Union[bool, Literal[\"auto\"], Path], default: False)\nresume: false\n\n# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.\ndata: TinyStories\n\n# Training-related arguments. 
See ``litgpt.args.TrainArgs`` for details\ntrain:\n  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)\n  save_interval: 1000\n\n  # Number of iterations between logging calls (type: int, default: 1)\n  log_interval: 1\n\n  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512)\n  global_batch_size: 512\n\n  # Number of samples per data-parallel rank (type: int, default: 4)\n  micro_batch_size: 128\n\n  # Number of iterations with learning rate warmup active (type: int, default: 2000)\n  lr_warmup_steps: 1000\n\n  # Number of epochs to train on (type: Optional[int], default: null)\n  epochs:\n\n  # Total number of tokens to train on (type: Optional[int], default: 3000000000000)\n  max_tokens: 9700000000 # original did 298,000 iters\n\n  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)\n  max_steps:\n\n  # Limits the length of samples. Off by default (type: Optional[int], default: null)\n  max_seq_length: 256\n\n  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)\n  tie_embeddings: true\n\n  #   (type: Optional[float], default: 1.0)\n  max_norm: 1.0\n\n  #   (type: float, default: 4e-05)\n  min_lr: 0.0\n\n# Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details\neval:\n  # Number of optimizer steps between evaluation calls (type: int, default: 1000)\n  interval: 2000\n\n  # Number of tokens to generate (type: Optional[int], default: null)\n  max_new_tokens:\n\n  # Number of iterations (type: int, default: 100)\n  max_iters: 100\n\n  # Whether to evaluate on the validation set at the beginning of the training\n  initial_validation: false\n\n  # Whether to evaluate on the validation set at the end of the training\n  final_validation: false\n\n# Optimizer-related arguments\noptimizer:\n  class_path: torch.optim.AdamW\n\n  init_args:\n    #   (type: float, default: 0.001)\n    lr: 0.0005\n\n    #   (type: float, default: 0.01)\n    weight_decay: 0.1\n\n    #   (type: tuple, default: (0.9,0.999))\n    betas:\n      - 0.9\n      - 0.95\n\n# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)\ndevices: auto\n\n# How many nodes to use. (type: int, default: 1)\nnum_nodes: 1\n\n# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data\n# modules require this. (type: Optional[Path], default: null)\ntokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf\n\n# The name of the logger to send metrics to. (type: LoggerChoice, i.e. Literal['wandb', 'tensorboard', 'csv', 'mlflow', 'litlogger'], default: tensorboard)\nlogger_name: csv\n\n# The random seed to use for reproducibility. (type: int, default: 42)\nseed: 42\n"
  },
  {
    "path": "extensions/thunder/README.md",
    "content": "# Lightning Thunder: a source-to-source compiler for PyTorch\n\n[Lightning Thunder](https://github.com/Lightning-AI/lightning-thunder) makes PyTorch programs faster both on single accelerators or in distributed settings.\n\nThunder aims to be usable, understandable, and extensible and can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best in class executors.\n\nThis extension directory shows how Thunder can be used with LitGPT.\n\n> [!WARNING]\n> This document is an early-access development version that is currently only for internal use. We recommend users checking out the [Lightning Thunder](https://github.com/Lightning-AI/lightning-thunder) project directly, which provides more up-to-date usage information.\n\n\n&nbsp;\n## Thunder 👉👈 LitGPT: a short showcase\n\nTo try Lightning Thunder with your model simply `thunder.jit()` it.\n\n```python\nfrom litgpt import GPT\nimport thunder\nimport torch\n\n# Use only two layers to keep the traces shorter for the demonstration\nmodel = GPT.from_name(\"Llama-2-7b-hf\", n_layer=2).cuda()\nmodel = thunder.jit(model)\nx = torch.randint(model.max_seq_length, (2, 5), device=\"cuda\")\ny = model(x)  # forward, this may take a bit\n```\n\nThis will require some compilation time on the first forward call.\n\n### Traces\n\nThe JIT is will acquire a Python program (what we call a \"trace\") from the Python program (`GPT`, a `torch.nn.Module` in this example) that was given.\nThis process targets PyTorch operators (like `Tensor.view()`, `+`, `torch.nn.functional.scaled_dot_product_atttention()`) and optionally custom operators (more about that later).\n\nWe can visualize the thunder trace generated under the hood:\n\n```python\nforward_trace = thunder.last_traces(model)[-1].python()\nprint(forward_trace)\n```\n\n```python\n@torch.no_grad()\n@no_autocast()\ndef augmented_forward_fn(*args):\n  # args: \"Collection\"\n  t0, t1, t2, t3, t4, 
t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, \\\n  t18, t19, = args\n  del args\n  t24 = torch.nn.functional.embedding(t0, t19, None, None, 2.0, False, False)  # t24: \"cuda:0 f32[2, 5, 4096]\"\n  t20 = torch_slice_prim_impl(t1, [0, 0], [5, 128], [1, 1])  # t20: \"cuda:0 f32[5, 128]\"\n  t21 = torch_slice_prim_impl(t2, [0, 0], [5, 128], [1, 1])  # t21: \"cuda:0 f32[5, 128]\"\n  t200 = torch.unsqueeze(t11, 0)  # t200: \"cuda:0 f32[1, 4096]\"\n  t201 = torch.unsqueeze(t200, 1)  # t201: \"cuda:0 f32[1, 1, 4096]\"\n  del t200\n  t33 = Tensor.expand(t201, (2, 5, 4096))  # t33: \"cuda:0 f32[2, 5, 4096]\"\n  del t201\n  t229 = torch.unsqueeze(t13, 0)  # t229: \"cuda:0 f32[1, 4096]\"\n  t230 = torch.unsqueeze(t229, 1)  # t230: \"cuda:0 f32[1, 1, 4096]\"\n  del t229\n  t84 = Tensor.expand(t230, (2, 5, 4096))  # t84: \"cuda:0 f32[2, 5, 4096]\"\n  del t230\n  t232 = torch.unsqueeze(t12, 0)  # t232: \"cuda:0 f32[1, 4096]\"\n  t233 = torch.unsqueeze(t232, 1)  # t233: \"cuda:0 f32[1, 1, 4096]\"\n  del t232\n  t104 = Tensor.expand(t233, (2, 5, 4096))  # t104: \"cuda:0 f32[2, 5, 4096]\"\n  del t233\n  t253 = torch.unsqueeze(t14, 0)  # t253: \"cuda:0 f32[1, 4096]\"\n  t254 = torch.unsqueeze(t253, 1)  # t254: \"cuda:0 f32[1, 1, 4096]\"\n  del t253\n  t155 = Tensor.expand(t254, (2, 5, 4096))  # t155: \"cuda:0 f32[2, 5, 4096]\"\n  del t254\n  t256 = torch.unsqueeze(t10, 0)  # t256: \"cuda:0 f32[1, 4096]\"\n  t257 = torch.unsqueeze(t256, 1)  # t257: \"cuda:0 f32[1, 1, 4096]\"\n  del t256\n  t175 = Tensor.expand(t257, (2, 5, 4096))  # t175: \"cuda:0 f32[2, 5, 4096]\"\n  del t257\n  t221 = torch.unsqueeze(t20, 0)  # t221: \"cuda:0 f32[1, 5, 128]\"\n  del t20\n  t222 = torch.unsqueeze(t221, 1)  # t222: \"cuda:0 f32[1, 1, 5, 128]\"\n  del t221\n  t49 = Tensor.expand(t222, (2, 32, 5, 128))  # t49: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t222\n  t224 = torch.unsqueeze(t21, 0)  # t224: \"cuda:0 f32[1, 5, 128]\"\n  del t21\n  t225 = torch.unsqueeze(t224, 1)  # t225: \"cuda:0 f32[1, 
1, 5, 128]\"\n  del t224\n  t51 = Tensor.expand(t225, (2, 32, 5, 128))  # t51: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t225\n  [t30, t34] = nvFusion0(t24, t33)\n  t35 = torch.nn.functional.linear(t34, t3, None)  # t35: \"cuda:0 f32[2, 5, 12288]\"\n  t36 = torch.reshape(t35, (2, 5, 32, 3, 128))  # t36: \"cuda:0 f32[2, 5, 32, 3, 128]\"\n  del t35\n  t37 = torch.permute(t36, (0, 2, 3, 1, 4))  # t37: \"cuda:0 f32[2, 32, 3, 5, 128]\"\n  del t36\n  (t38, t39, t40) = torch.split(t37, (1, 1, 1), 2)\n  del t37\n  t41 = torch.reshape(t38, (2, 32, 5, 128))  # t41: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t38\n  t42 = torch.reshape(t39, (2, 32, 5, 128))  # t42: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t39\n  t43 = torch.reshape(t40, (2, 32, 5, 128))  # t43: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t40\n  t44 = torch_slice_prim_impl(t41, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1])  # t44: \"cuda:0 f32[2, 32, 5, 128]\"\n  t54 = torch_slice_prim_impl(t42, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1])  # t54: \"cuda:0 f32[2, 32, 5, 128]\"\n  t64 = torch_slice_prim_impl(t41, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1])  # t64: \"cuda:0 f32[2, 32, 5, 0]\"\n  del t41\n  t66 = torch_slice_prim_impl(t42, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1])  # t66: \"cuda:0 f32[2, 32, 5, 0]\"\n  del t42\n  t46 = torch_slice_prim_impl(t44, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1])  # t46: \"cuda:0 f32[2, 32, 5, 64]\"\n  t45 = torch_slice_prim_impl(t44, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1])  # t45: \"cuda:0 f32[2, 32, 5, 64]\"\n  t55 = torch_slice_prim_impl(t54, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1])  # t55: \"cuda:0 f32[2, 32, 5, 64]\"\n  t56 = torch_slice_prim_impl(t54, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1])  # t56: \"cuda:0 f32[2, 32, 5, 64]\"\n  [t47, t57] = nvFusion1(t46, t56)\n  del t46, t56\n  t48 = torch.cat((t47, t45), -1)  # t48: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t47, t45\n  t58 = torch.cat((t57, t55), -1)  # t58: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t57, t55\n  [t53, t63] = 
nvFusion2(t44, t48, t49, t51, t54, t58)\n  del t44, t48, t54, t58\n  t65 = torch.cat((t53, t64), -1)  # t65: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t53, t64\n  t67 = torch.cat((t63, t66), -1)  # t67: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t63, t66\n  (t68, t69, t70, t71) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t65, t67, t43, None, 0.0, True, 0.08838834764831843)\n  t72 = torch.permute(t68, (0, 2, 1, 3))  # t72: \"cuda:0 f32[2, 5, 32, 128]\"\n  t73 = torch.reshape(t72, (2, 5, 4096))  # t73: \"cuda:0 f32[2, 5, 4096]\"\n  del t72\n  t74 = torch.nn.functional.linear(t73, t15, None)  # t74: \"cuda:0 f32[2, 5, 4096]\"\n  [t75, t81, t85] = nvFusion3(t24, t74, t84)\n  del t74\n  t86 = torch.nn.functional.linear(t85, t5, None)  # t86: \"cuda:0 f32[2, 5, 11008]\"\n  t87 = torch.nn.functional.linear(t85, t7, None)  # t87: \"cuda:0 f32[2, 5, 11008]\"\n  [t93] = nvFusion4(t86, t87)\n  t94 = torch.nn.functional.linear(t93, t16, None)  # t94: \"cuda:0 f32[2, 5, 4096]\"\n  [t101, t105, t95] = nvFusion5(t104, t75, t94)\n  del t94\n  t106 = torch.nn.functional.linear(t105, t4, None)  # t106: \"cuda:0 f32[2, 5, 12288]\"\n  t107 = torch.reshape(t106, (2, 5, 32, 3, 128))  # t107: \"cuda:0 f32[2, 5, 32, 3, 128]\"\n  del t106\n  t108 = torch.permute(t107, (0, 2, 3, 1, 4))  # t108: \"cuda:0 f32[2, 32, 3, 5, 128]\"\n  del t107\n  (t109, t110, t111) = torch.split(t108, (1, 1, 1), 2)\n  del t108\n  t112 = torch.reshape(t109, (2, 32, 5, 128))  # t112: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t109\n  t113 = torch.reshape(t110, (2, 32, 5, 128))  # t113: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t110\n  t114 = torch.reshape(t111, (2, 32, 5, 128))  # t114: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t111\n  t135 = torch_slice_prim_impl(t112, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1])  # t135: \"cuda:0 f32[2, 32, 5, 0]\"\n  t137 = torch_slice_prim_impl(t113, [0, 0, 0, 0], [2, 32, 5, 0], [1, 1, 1, 1])  # t137: \"cuda:0 f32[2, 32, 5, 0]\"\n  t115 = torch_slice_prim_impl(t112, [0, 0, 0, 0], [2, 32, 
5, 128], [1, 1, 1, 1])  # t115: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t112\n  t125 = torch_slice_prim_impl(t113, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1])  # t125: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t113\n  t116 = torch_slice_prim_impl(t115, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1])  # t116: \"cuda:0 f32[2, 32, 5, 64]\"\n  t117 = torch_slice_prim_impl(t115, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1])  # t117: \"cuda:0 f32[2, 32, 5, 64]\"\n  t127 = torch_slice_prim_impl(t125, [0, 0, 0, 64], [2, 32, 5, 128], [1, 1, 1, 1])  # t127: \"cuda:0 f32[2, 32, 5, 64]\"\n  t126 = torch_slice_prim_impl(t125, [0, 0, 0, 0], [2, 32, 5, 64], [1, 1, 1, 1])  # t126: \"cuda:0 f32[2, 32, 5, 64]\"\n  [t118, t128] = nvFusion6(t117, t127)\n  del t117, t127\n  t129 = torch.cat((t128, t126), -1)  # t129: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t128, t126\n  t119 = torch.cat((t118, t116), -1)  # t119: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t118, t116\n  [t124, t134] = nvFusion7(t115, t119, t125, t129, t49, t51)\n  del t115, t119, t125, t129\n  t136 = torch.cat((t124, t135), -1)  # t136: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t124, t135\n  t138 = torch.cat((t134, t137), -1)  # t138: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t134, t137\n  (t139, t140, t141, t142) = sdpaex_grad_forward_scaled_dot_product_efficient_attention(t136, t138, t114, None, 0.0, True, 0.08838834764831843)\n  t143 = torch.permute(t139, (0, 2, 1, 3))  # t143: \"cuda:0 f32[2, 5, 32, 128]\"\n  t144 = torch.reshape(t143, (2, 5, 4096))  # t144: \"cuda:0 f32[2, 5, 4096]\"\n  del t143\n  t145 = torch.nn.functional.linear(t144, t17, None)  # t145: \"cuda:0 f32[2, 5, 4096]\"\n  [t146, t152, t156] = nvFusion8(t145, t155, t95)\n  del t145\n  t158 = torch.nn.functional.linear(t156, t8, None)  # t158: \"cuda:0 f32[2, 5, 11008]\"\n  t157 = torch.nn.functional.linear(t156, t6, None)  # t157: \"cuda:0 f32[2, 5, 11008]\"\n  [t164] = nvFusion9(t157, t158)\n  t165 = torch.nn.functional.linear(t164, t18, None)  # t165: \"cuda:0 f32[2, 5, 
4096]\"\n  [t166, t172, t176] = nvFusion10(t146, t165, t175)\n  del t165\n  t177 = torch.nn.functional.linear(t176, t9, None)  # t177: \"cuda:0 f32[2, 5, 32000]\"\n  return {'output': t177, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19], 'flat_output': (t177,)}, ((t0, t101, t104, t105, t114, t136, t138, t139, t140, t141, t142, t144, t146, t15, t152, t155, t156, t157, t158, t16, t164, t166, t17, t172, t175, t176, t18, t24, t3, t30, t33, t34, t4, t43, t49, t5, t51, t6, t65, t67, t68, t69, t7, t70, t71, t73, t75, t8, t81, t84, t85, t86, t87, t9, t93, t95), (False, False, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2))\n```\n\nThis is a straight-lined version of `GPT.forward` that has been optimized. Since it's running on CUDA, the [NvFuser](https://github.com/NVIDIA/Fuser) executor has created regions (look for \"nvFusion\") that fuse multiple operators together.\n\nOperator fusion is very desirable with modern hardware and helps out in overhead-bound or device-bound settings by:\n- Launching fewer kernels, thus reducing the kernel launch overhead.\n- Reducing the number of memory accesses performed by reusing them in a fused operation\n- Minimizing host-device communications\n\nThunder also uses a multi-level intermediate representation. If we let it print all levels\n\n```python\nforward_trace = thunder.last_traces(model)[-1]\nprint(forward_trace)\n```\n\nWe can see as comments the primitives that compose the fusion regions. 
For instance, this is the region associated to [the `RMSNorm` implementation](https://github.com/Lightning-AI/litgpt/blob/9b6475dabf90c7acee506a026bd9fa86251835bf/litgpt/model.py#L409-L420)\n\n```python\n  [t146, t152, t156] = nvFusion8(t145, t155, t95)\n    # t146 = prims.add(t145, t95)  # t146: \"cuda:0 f32[2, 5, 4096]\"\n    # t147 = prims.mul(t146, t146)  # t147: \"cuda:0 f32[2, 5, 4096]\"\n    # t148 = prims.sum(t147, (2,))  # t148: \"cuda:0 f32[2, 5]\"\n    # t149 = prims.broadcast_in_dim(t148, [2, 5, 1], [0, 1])  # t149: \"cuda:0 f32[2, 5, 1]\"\n    # t150 = prims.div(t149, 4096.0)  # t150: \"cuda:0 f32[2, 5, 1]\"\n    # t151 = prims.add(t150, 1e-05)  # t151: \"cuda:0 f32[2, 5, 1]\"\n    # t152 = prims.rsqrt(t151)  # t152: \"cuda:0 f32[2, 5, 1]\"\n    # t153 = prims.broadcast_in_dim(t152, (2, 5, 4096), (0, 1, 2))  # t153: \"cuda:0 f32[2, 5, 4096]\"\n    # t154 = prims.mul(t146, t153)  # t154: \"cuda:0 f32[2, 5, 4096]\"\n    # t156 = prims.mul(t154, t155)  # t156: \"cuda:0 f32[2, 5, 4096]\"\n```\n\nSimilarly, we can visualize the backward trace:\n\n```python\nbackward_trace = thunder.last_backward_traces(model)[-1].python()\nprint(backward_trace)\n```\n\n```python\n@torch.no_grad()\n@no_autocast()\ndef backward_fn(saved_for_backward, cotangents):\n  # saved_for_backward: \"Collection\"\n  # cotangents: \"Collection\"\n  C0, C1, = saved_for_backward\n  clear_collection(saved_for_backward)\n  del saved_for_backward\n  t178, = cotangents\n  clear_collection(cotangents)\n  del cotangents\n  t0, t101, t104, t105, t114, t136, t138, t139, t140, t141, t142, t144, t146, \\\n  t15, t152, t155, t156, t157, t158, t16, t164, t166, t17, t172, t175, t176, t18, \\\n  t24, t3, t30, t33, t34, t4, t43, t49, t5, t51, t6, t65, t67, t68, t69, t7, t70, \\\n  t71, t73, t75, t8, t81, t84, t85, t86, t87, t9, t93, t95, = C0\n  clear_collection(C0)\n  del C0\n  b1, b2, b41, b91, f101, f106, f40, f42, f51, f56, f6, f90, f92, i0, i23, i73, \\\n  = C1\n  clear_collection(C1)\n  del C1\n  
t639 = torch.reshape(t178, (-1, 32000))  # t639: \"cuda:0 f32[10, 32000]\"\n  del t178\n  t643 = torch.permute(t639, (1, 0))  # t643: \"cuda:0 f32[32000, 10]\"\n  t644 = torch.reshape(t176, (-1, 4096))  # t644: \"cuda:0 f32[10, 4096]\"\n  del t176\n  t669 = torch.reshape(t164, (-1, 11008))  # t669: \"cuda:0 f32[10, 11008]\"\n  del t164\n  t686 = torch.reshape(t156, (-1, 4096))  # t686: \"cuda:0 f32[10, 4096]\"\n  del t156\n  t720 = torch.reshape(t144, (-1, 4096))  # t720: \"cuda:0 f32[10, 4096]\"\n  del t144\n  t776 = torch.reshape(t105, (-1, 4096))  # t776: \"cuda:0 f32[10, 4096]\"\n  del t105\n  t802 = torch.reshape(t93, (-1, 11008))  # t802: \"cuda:0 f32[10, 11008]\"\n  del t93\n  t819 = torch.reshape(t85, (-1, 4096))  # t819: \"cuda:0 f32[10, 4096]\"\n  del t85\n  t853 = torch.reshape(t73, (-1, 4096))  # t853: \"cuda:0 f32[10, 4096]\"\n  del t73\n  t911 = torch.reshape(t34, (-1, 4096))  # t911: \"cuda:0 f32[10, 4096]\"\n  del t34\n  t640 = torch.matmul(t639, t9)  # t640: \"cuda:0 f32[10, 4096]\"\n  del t639, t9\n  t645 = torch.matmul(t643, t644)  # t645: \"cuda:0 f32[32000, 4096]\"\n  del t643, t644\n  t641 = torch.reshape(t640, (2, 5, 4096))  # t641: \"cuda:0 f32[2, 5, 4096]\"\n  del t640\n  [t648, t663] = nvFusion0(f106, t166, t172, t175, t641)\n  del f106, t166, t172, t175, t641\n  t664 = torch.reshape(t663, (-1, 4096))  # t664: \"cuda:0 f32[10, 4096]\"\n  t668 = torch.permute(t664, (1, 0))  # t668: \"cuda:0 f32[4096, 10]\"\n  t665 = torch.matmul(t664, t18)  # t665: \"cuda:0 f32[10, 11008]\"\n  del t664, t18\n  t670 = torch.matmul(t668, t669)  # t670: \"cuda:0 f32[4096, 11008]\"\n  del t668, t669\n  t666 = torch.reshape(t665, (2, 5, 11008))  # t666: \"cuda:0 f32[2, 5, 11008]\"\n  del t665\n  [t672, t680] = nvFusion1(t157, t158, t666)\n  del t157, t158, t666\n  t681 = torch.reshape(t672, (-1, 11008))  # t681: \"cuda:0 f32[10, 11008]\"\n  del t672\n  t685 = torch.permute(t681, (1, 0))  # t685: \"cuda:0 f32[11008, 10]\"\n  t688 = torch.reshape(t680, (-1, 
11008))  # t688: \"cuda:0 f32[10, 11008]\"\n  del t680\n  t692 = torch.permute(t688, (1, 0))  # t692: \"cuda:0 f32[11008, 10]\"\n  t689 = torch.matmul(t688, t6)  # t689: \"cuda:0 f32[10, 4096]\"\n  del t688, t6\n  t682 = torch.matmul(t681, t8)  # t682: \"cuda:0 f32[10, 4096]\"\n  del t681, t8\n  t694 = torch.matmul(t692, t686)  # t694: \"cuda:0 f32[11008, 4096]\"\n  del t692\n  t687 = torch.matmul(t685, t686)  # t687: \"cuda:0 f32[11008, 4096]\"\n  del t685, t686\n  t683 = torch.reshape(t682, (2, 5, 4096))  # t683: \"cuda:0 f32[2, 5, 4096]\"\n  del t682\n  t690 = torch.reshape(t689, (2, 5, 4096))  # t690: \"cuda:0 f32[2, 5, 4096]\"\n  del t689\n  [t698, t714] = nvFusion2(f101, t146, t152, t155, t663, t683, t690)\n  del f101, t146, t152, t155, t663, t683, t690\n  t715 = torch.reshape(t714, (-1, 4096))  # t715: \"cuda:0 f32[10, 4096]\"\n  t719 = torch.permute(t715, (1, 0))  # t719: \"cuda:0 f32[4096, 10]\"\n  t716 = torch.matmul(t715, t17)  # t716: \"cuda:0 f32[10, 4096]\"\n  del t715, t17\n  t721 = torch.matmul(t719, t720)  # t721: \"cuda:0 f32[4096, 4096]\"\n  del t719, t720\n  t717 = torch.reshape(t716, (2, 5, 4096))  # t717: \"cuda:0 f32[2, 5, 4096]\"\n  del t716\n  t722 = torch.reshape(t717, (2, 5, 32, 128))  # t722: \"cuda:0 f32[2, 5, 32, 128]\"\n  del t717\n  t723 = torch.permute(t722, (0, 2, 1, 3))  # t723: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t722\n  (t724, t725, t726, _) = sdpaex_scaled_dot_product_efficient_attention_backward(t723, t136, t138, t114, None, t139, t140, t141, t142, f90, b91, scale=f92)\n  del t723, t136, t138, t114, t139, t140, t141, t142, f90, b91, f92\n  t765 = torch.reshape(t726, (2, 32, 1, 5, 128))  # t765: \"cuda:0 f32[2, 32, 1, 5, 128]\"\n  del t726\n  t727 = torch_slice_prim_impl(t725, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1])  # t727: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t725\n  t730 = torch_slice_prim_impl(t724, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1])  # t730: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t724\n  [t747, t764] = 
nvFusion3(t49, t51, t727, t730)\n  del t727, t730\n  t766 = torch.reshape(t747, (2, 32, 1, 5, 128))  # t766: \"cuda:0 f32[2, 32, 1, 5, 128]\"\n  del t747\n  t767 = torch.reshape(t764, (2, 32, 1, 5, 128))  # t767: \"cuda:0 f32[2, 32, 1, 5, 128]\"\n  del t764\n  t768 = torch.cat((t767, t766, t765), i73)  # t768: \"cuda:0 f32[2, 32, 3, 5, 128]\"\n  del t767, t766, t765, i73\n  t769 = torch.permute(t768, (0, 3, 1, 2, 4))  # t769: \"cuda:0 f32[2, 5, 32, 3, 128]\"\n  del t768\n  t770 = torch.reshape(t769, (2, 5, 12288))  # t770: \"cuda:0 f32[2, 5, 12288]\"\n  del t769\n  t771 = torch.reshape(t770, (-1, 12288))  # t771: \"cuda:0 f32[10, 12288]\"\n  del t770\n  t775 = torch.permute(t771, (1, 0))  # t775: \"cuda:0 f32[12288, 10]\"\n  t777 = torch.matmul(t775, t776)  # t777: \"cuda:0 f32[12288, 4096]\"\n  del t775, t776\n  t772 = torch.matmul(t771, t4)  # t772: \"cuda:0 f32[10, 4096]\"\n  del t771, t4\n  t773 = torch.reshape(t772, (2, 5, 4096))  # t773: \"cuda:0 f32[2, 5, 4096]\"\n  del t772\n  [t780, t796] = nvFusion4(f56, t101, t104, t714, t773, t95)\n  del f56, t101, t104, t714, t773, t95\n  t797 = torch.reshape(t796, (-1, 4096))  # t797: \"cuda:0 f32[10, 4096]\"\n  t801 = torch.permute(t797, (1, 0))  # t801: \"cuda:0 f32[4096, 10]\"\n  t798 = torch.matmul(t797, t16)  # t798: \"cuda:0 f32[10, 11008]\"\n  del t797, t16\n  t803 = torch.matmul(t801, t802)  # t803: \"cuda:0 f32[4096, 11008]\"\n  del t801, t802\n  t799 = torch.reshape(t798, (2, 5, 11008))  # t799: \"cuda:0 f32[2, 5, 11008]\"\n  del t798\n  [t805, t813] = nvFusion5(t799, t86, t87)\n  del t799, t86, t87\n  t814 = torch.reshape(t805, (-1, 11008))  # t814: \"cuda:0 f32[10, 11008]\"\n  del t805\n  t818 = torch.permute(t814, (1, 0))  # t818: \"cuda:0 f32[11008, 10]\"\n  t821 = torch.reshape(t813, (-1, 11008))  # t821: \"cuda:0 f32[10, 11008]\"\n  del t813\n  t825 = torch.permute(t821, (1, 0))  # t825: \"cuda:0 f32[11008, 10]\"\n  t822 = torch.matmul(t821, t5)  # t822: \"cuda:0 f32[10, 4096]\"\n  del t821, t5\n  t815 
= torch.matmul(t814, t7)  # t815: \"cuda:0 f32[10, 4096]\"\n  del t814, t7\n  t827 = torch.matmul(t825, t819)  # t827: \"cuda:0 f32[11008, 4096]\"\n  del t825\n  t820 = torch.matmul(t818, t819)  # t820: \"cuda:0 f32[11008, 4096]\"\n  del t818, t819\n  t816 = torch.reshape(t815, (2, 5, 4096))  # t816: \"cuda:0 f32[2, 5, 4096]\"\n  del t815\n  t823 = torch.reshape(t822, (2, 5, 4096))  # t823: \"cuda:0 f32[2, 5, 4096]\"\n  del t822\n  [t831, t847] = nvFusion6(f51, t75, t796, t81, t816, t823, t84)\n  del f51, t75, t796, t81, t816, t823, t84\n  t848 = torch.reshape(t847, (-1, 4096))  # t848: \"cuda:0 f32[10, 4096]\"\n  t852 = torch.permute(t848, (1, 0))  # t852: \"cuda:0 f32[4096, 10]\"\n  t849 = torch.matmul(t848, t15)  # t849: \"cuda:0 f32[10, 4096]\"\n  del t848, t15\n  t854 = torch.matmul(t852, t853)  # t854: \"cuda:0 f32[4096, 4096]\"\n  del t852, t853\n  t850 = torch.reshape(t849, (2, 5, 4096))  # t850: \"cuda:0 f32[2, 5, 4096]\"\n  del t849\n  t855 = torch.reshape(t850, (2, 5, 32, 128))  # t855: \"cuda:0 f32[2, 5, 32, 128]\"\n  del t850\n  t856 = torch.permute(t855, (0, 2, 1, 3))  # t856: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t855\n  (t857, t858, t859, _) = sdpaex_scaled_dot_product_efficient_attention_backward(t856, t65, t67, t43, None, t68, t69, t70, t71, f40, b41, scale=f42)\n  del t856, t65, t67, t43, t68, t69, t70, t71, f40, b41, f42\n  t900 = torch.reshape(t859, (2, 32, 1, 5, 128))  # t900: \"cuda:0 f32[2, 32, 1, 5, 128]\"\n  del t859\n  t863 = torch_slice_prim_impl(t857, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1])  # t863: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t857\n  t860 = torch_slice_prim_impl(t858, [0, 0, 0, 0], [2, 32, 5, 128], [1, 1, 1, 1])  # t860: \"cuda:0 f32[2, 32, 5, 128]\"\n  del t858\n  [t882, t899] = nvFusion7(t49, t51, t860, t863)\n  del t49, t51, t860, t863\n  t902 = torch.reshape(t899, (2, 32, 1, 5, 128))  # t902: \"cuda:0 f32[2, 32, 1, 5, 128]\"\n  del t899\n  t901 = torch.reshape(t882, (2, 32, 1, 5, 128))  # t901: \"cuda:0 f32[2, 32, 1, 
5, 128]\"\n  del t882\n  t903 = torch.cat((t902, t901, t900), i23)  # t903: \"cuda:0 f32[2, 32, 3, 5, 128]\"\n  del t902, t901, t900, i23\n  t904 = torch.permute(t903, (0, 3, 1, 2, 4))  # t904: \"cuda:0 f32[2, 5, 32, 3, 128]\"\n  del t903\n  t905 = torch.reshape(t904, (2, 5, 12288))  # t905: \"cuda:0 f32[2, 5, 12288]\"\n  del t904\n  t906 = torch.reshape(t905, (-1, 12288))  # t906: \"cuda:0 f32[10, 12288]\"\n  del t905\n  t910 = torch.permute(t906, (1, 0))  # t910: \"cuda:0 f32[12288, 10]\"\n  t907 = torch.matmul(t906, t3)  # t907: \"cuda:0 f32[10, 4096]\"\n  del t906, t3\n  t912 = torch.matmul(t910, t911)  # t912: \"cuda:0 f32[12288, 4096]\"\n  del t910, t911\n  t908 = torch.reshape(t907, (2, 5, 4096))  # t908: \"cuda:0 f32[2, 5, 4096]\"\n  del t907\n  [t915, t931] = nvFusion8(f6, t24, t30, t33, t847, t908)\n  del f6, t24, t30, t33, t847, t908\n  t932 = torch.torch.ops.aten.embedding_backward(t931, t0, i0, -1, b1, b2)  # t932: \"cuda:0 f32[32000, 4096]\"\n  del t931, t0, i0, b1, b2\n  return (None, None, None, t912, t777, t827, t694, t820, t687, t645, t648, t915, t780, t831, t698, t854, t803, t721, t670, t932)\n```\n\nThese traces are long, and require some familiarity with the model implementation to follow them, but they allow you to:\n- Inspect exactly what operations are run including their decompositions.\n- Inspect the sizes of tensors, their device, data type and conversions.\n- Apply transformations to the traces since the computations are completely decoupled from the data.\n- Inspect the backward operations generated for each forward operation to understand what autograd is doing.\n\n### Transforms\n\nTransforms are one of the core features of Thunder. For example, they enable easy data parallel distribution. That is replicated data parallelism (DDP) and fully-sharded data parallelism (FSDP).\n\nWe provide ready-to-use Fabric strategies that integrate Thunder DDP|FSDP. 
Under the hood, the code is quite straightforward:\n\n```python\nmodel = thunder.distributed.ddp(model)\n# or\n# model = thunder.distributed.fsdp(model)\n\nmodel = thunder.jit(model)\n```\n\nAfter applying the DDP transformation, the backward trace will include the expected all-reduce collectives:\n\n```python\n  p1022 = torch_all_reduce_prim_impl(t1021, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True, False)  # p1022: \"FUTURE cuda:0 f32[16797696]\"\n  ...\n  t1059 = torch_wait_prim_impl(p1025)  # t1059: \"cuda:0 f32[131072000]\"\n```\n\nWith `L.Fabric`, this is how to use them:\n\n```python\nfrom extensions.thunder.strategies import ThunderFSDPStrategy, ThunderDDPStrategy\n\n# fully-sharded data parallel\nstrategy = ThunderFSDPStrategy(\n    sharding_strategy=\"ZERO3\",\n    bucketing_strategy=\"BLOCK\",\n    executors=(\"sdpa\", \"torchcompile_cat\", \"nvfuser\", \"torch\"),\n    state_dict_type=\"full\",\n)\n\n# replicated data parallel\nstrategy = ThunderDDPStrategy(executors=(\"sdpa\", \"torchcompile_cat\", \"nvfuser\", \"torch\"))\n\nfabric = L.Fabric(devices=devices, strategy=strategy)\nfabric.launch()\nmodel = fabric.setup(model)  # JIT is called here\n```\n\nIn the case of FSDP, the traces will instead include all-gathers in forward and reduce-scatters in backward,\nmeaning that Thunder automatically introduces the necessary collective operations to support data parallelism.\n\n### Executors\n\nThunder allows you to define a priority list of executors that can map operators:\n\n```python\nimport thunder\n\nmodel = thunder.jit(\n    model,\n    executors=[\"sdpa\", \"torchcompile_cat\", \"nvfuser\", \"torch\"]\n)\n```\n\nNotice how `torch.compile` is a valid executor. 
This executor registers a few operators with improved performance so that you can utilize the fastest set of operator implementations possible.\n\n### Custom executors\n\nLightning Thunder provides extension points to integrate fast kernels for operators in your model without having to modify your implementation.\n\nFor instance, the [Unsloth project](https://github.com/unslothai/unsloth/) provides several Triton kernels that can be used with LitGPT:\n- Cross entropy loss\n- SwiGLU (part of `LLaMAMLP`)\n- RoPE\n\nThe [`unsloth` directory](unsloth) contains a [custom executor](unsloth/executor.py) that registers these operators for LitGPT.\nWe can enable this executor by passing it to the list of executors available. The order matters because we want to run its custom operators before\n`NvFuser` creates its fusion regions.\n\n```python\nimport thunder\n\nmodel = thunder.jit(\n    model,\n    executors=[\"sdpa\", \"unsloth\", \"torchcompile_cat\", \"nvfuser\", \"torch\"]\n)\n```\n\nDoing this, the model trace now includes the Unsloth kernel calls:\n\n```python\ndef augmented_forward_fn(*args):\n    ...\n    (t121, _, _, _, _, _) = unsloth_apply_rope(t120, t21, t22)\n    ...\n    (t189, t190) = unsloth_cross_entropy(t187, t188)\n    ...\n\ndef backward_fn(saved_for_backward, cotangents):\n    ...\n    t652 = unsloth_cross_entropy_backward(t651, t187, t188, t190)  # t652: \"cuda:0 f32[6, 320]\"\n    ...\n    t763 = unsloth_apply_rope_backward(t757, t21, t22, 1, 8, 4)  # t763: \"cuda:0 f32[2, 4, 3, 16]\"\n```\n\nWe provide a specific [pre-training script copy](pretrain.py) that uses this executor.\nGiven the Unsloth results below, these hand-written kernels do not seem to be worth it, showcasing the power of automated fusion compilers like [NvFuser](https://github.com/NVIDIA/Fuser).\n\n## Examples and benchmarks\n\n> [!WARNING]\n> Lightning Thunder is alpha and not ready for production runs. 
Feel free to try it out, expect a few bumps along the way.\n> We expect speed and memory usage to improve as we continue to develop it.\n\nWe provide a version of the main pre-training script [that integrates Thunder](pretrain.py) that uses TinyLlama, a 1.1B parameter LLM.\n\n| Setting              | Compiler | Executors                              | Devices | ms/iter @ step 10 | Memory (GB)   |\n|----------------------|----------|----------------------------------------|---------|-------------------|---------------|\n| Fully-sharded ZeRO 3 | Eager    | -                                      | 8       | 456.57            | 22.13         |\n| Fully-sharded ZeRO 3 | torch    | -                                      | 8       | Not supported     | Not supported |\n| Fully-sharded ZeRO 3 | Thunder  | sdpa, torchcompile                     | 8       | Not supported     | Not supported |\n| Fully-sharded ZeRO 3 | Thunder  | sdpa, torchcompile_cat, nvfuser, torch | 8       | 333.56            | 21.40         |\n|                      |          |                                        |         |                   |               |\n| Replicated           | Eager    | -                                      | 8       | 569.46            | 32.04         |\n| Replicated           | torch    | -                                      | 8       | Not supported     | Not supported |\n| Replicated           | Thunder  | sdpa, torchcompile                     | 8       | 426.44            | 22.19         |\n| Replicated           | Thunder  | sdpa, torchcompile_cat, nvfuser, torch | 8       | 356.01            | 27.42         |\n|                      |          |                                        |         |                   |               |\n| -                    | Eager    | -                                      | 1       | 447.65            | 29.84         |\n| -                    | torch    | -                                      | 1       | Not supported     | Not 
supported |\n| -                    | Thunder  | sdpa, torchcompile                     | 1       | 373.37            | 22.19         |\n| -                    | Thunder  | sdpa, torchcompile_cat, nvfuser, torch | 1       | 322.25            | 27.42         |\n|                      |          |                                        |         |                   |               |\n| Unsloth              | Thunder  | sdpa, torchcompile_cat, nvfuser, torch | 1       | 331.92            | 25.19         |\n\n<details>\n<summary>Reproduction details</summary>\n\nConfig:\n\n```yaml\nout_dir: out/pretrain-thunder\ndata: TinyStories\ntokenizer_dir: checkpoints/TinyLlama/TinyLlama-1.1B-Chat-v1.0\nlogger_name: csv\n```\n\nCommands:\n\n```bash\nlitgpt download --repo_id TinyLlama/TinyLlama-1.1B-Chat-v1.0 --tokenizer_only true\n\npython extensions/thunder/pretrain.py --config config.yaml --compiler null --train.global_batch_size 32\npython extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --train.global_batch_size 32\npython extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --train.global_batch_size 32\n\npython extensions/thunder/pretrain.py --config config.yaml --compiler null --strategy ddp\npython extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --strategy ddp\npython extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --strategy ddp\n\npython extensions/thunder/pretrain.py --config config.yaml --compiler null --devices 1\npython extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile]' --devices 1\npython extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, torchcompile_cat, nvfuser, torch]' --devices 1\n\npython extensions/thunder/pretrain.py --config config.yaml --executors '[sdpa, unsloth, torchcompile_cat, nvfuser, torch]' --devices 
1\n```\n\n`--compiler torch` (`torch.compile` without `thunder`) is not included because it does not support compiling the `_FabricModule` due to this issue: https://github.com/pytorch/pytorch/issues/112787#issuecomment-1986827601\n\nThe CUDA devices are all NVIDIA A100-SXM4-40GB.\n\n```text\nPython version: 3.10.12 [GCC 11.4.0] (64-bit runtime)\nIs debug build: False\nCUDA used to build PyTorch: 12.1\nCUDA runtime version: 12.3.107\nNvidia driver version: 545.23.08\npytorch-triton==3.0.0+45fff310c8\ntorch==2.4.0.dev20240427+cu121\nlightning==2.3.0.dev20240328\nlightning-thunder==0.2.0.dev20240505\nnvfuser_cu121==0.2.3.dev20240428\n```\n\n</details>\n"
  },
  {
    "path": "extensions/thunder/__init__.py",
    "content": "import sys\nfrom pathlib import Path\n\n# support running without installing as a package, adding extensions to the Python path\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n"
  },
  {
    "path": "extensions/thunder/pretrain.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport math\nimport os\nimport pprint\nimport sys\nimport time\nfrom dataclasses import asdict\nfrom datetime import timedelta\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport lightning as L\nimport torch\nimport torch.nn as nn\nfrom lightning.fabric.strategies import FSDPStrategy\nfrom lightning.fabric.utilities.throughput import ThroughputMonitor, measure_flops\nfrom torch.utils.data import DataLoader\nfrom torchmetrics.aggregation import RunningMean\nfrom typing_extensions import Literal\n\nfrom litgpt import Tokenizer\nfrom litgpt.args import EvalArgs, LogArgs, TrainArgs\nfrom litgpt.data import DataModule, TinyLlama\nfrom litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP, MultiheadLatentAttention\nfrom litgpt.parser_config import save_hyperparameters\nfrom litgpt.types import LoggerChoice\nfrom litgpt.utils import (\n    CLI,\n    CycleIterator,\n    capture_hparams,\n    choose_logger,\n    chunked_cross_entropy,\n    copy_config_files,\n    find_resume_path,\n    instantiate_torch_optimizer,\n    num_parameters,\n    parse_devices,\n    reset_parameters,\n    save_config,\n)\n\n# support running without installing as a package\nwd = Path(__file__).parent.resolve()\nsys.path.append(str(wd))\n\n\ndef forward_and_loss(model: nn.Module, input_ids: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:\n    logits = model(input_ids)\n    # disable chunk_size to enable the unsloth cross entropy kernel\n    loss = chunked_cross_entropy(logits, targets, chunk_size=0)\n    return loss\n\n\ndef setup(\n    model_name: Optional[str] = None,\n    model_config: Optional[Config] = None,\n    out_dir: Path = Path(\"out/pretrain\"),\n    initial_checkpoint_dir: Optional[Path] = None,\n    resume: Union[bool, Literal[\"auto\"], Path] = False,\n    data: 
Optional[DataModule] = None,\n    train: TrainArgs = TrainArgs(\n        save_interval=1000,\n        log_interval=1,\n        global_batch_size=512,\n        micro_batch_size=4,\n        max_tokens=int(3e12),  # 3 trillion\n        max_norm=1.0,\n        min_lr=4e-5,\n        lr_warmup_steps=2000,\n        tie_embeddings=False,\n    ),\n    eval: EvalArgs = EvalArgs(interval=1000, max_iters=100),\n    log: LogArgs = LogArgs(),\n    optimizer: Union[str, Dict] = \"AdamW\",\n    devices: Union[int, str] = \"auto\",\n    num_nodes: int = 1,\n    tokenizer_dir: Optional[Path] = None,\n    logger_name: LoggerChoice = \"tensorboard\",\n    seed: int = 42,\n    compiler: Optional[Literal[\"thunder\", \"torch\"]] = \"thunder\",\n    executors: Optional[List[str]] = (\"sdpa\", \"torchcompile\", \"nvfuser\", \"torch\"),\n    strategy: Literal[\"auto\", \"ddp\", \"fsdp\"] = \"fsdp\",\n):\n    \"\"\"Pretrain a model.\n\n    Arguments:\n        model_name: The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with\n            ``model_config``.\n        model_config: A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with\n            ``model_name``.\n        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in\n            /teamspace/jobs/<job-name>/share.\n        initial_checkpoint_dir: Optional path to a checkpoint directory to initialize the model from.\n            Useful for continued pretraining. Mutually exclusive with ``resume``.\n        resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume\n            from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing\n            ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.\n        data: Data-related arguments. 
If not provided, the default is ``litgpt.data.TinyLlama``.\n        train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.\n        eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.\n        optimizer: An optimizer name (such as \"AdamW\") or config.\n        devices: How many devices/GPUs to use. Uses all GPUs by default.\n        num_nodes: How many nodes the code is being run on.\n        tokenizer_dir: Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data\n            module require this.\n        logger_name: The name of the logger to send metrics to.\n        seed: The random seed to use for reproducibility.\n        compiler: If desired, the compiler/JIT to use.\n        executors: If using Thunder, the executors to enable.\n        strategy: If desired, the strategy to use.\n    \"\"\"\n    hparams = capture_hparams()\n    data = TinyLlama() if data is None else data\n    if model_config is not None and model_name is not None:\n        raise ValueError(\"Only one of `model_name` or `model_config` can be set.\")\n    elif model_config is None and model_name is None:\n        model_name = \"tiny-llama-1.1b\"\n    config = Config.from_name(model_name) if model_config is None else model_config\n    devices = parse_devices(devices)\n    out_dir = init_out_dir(out_dir)\n    # in case the dataset requires the Tokenizer\n    tokenizer = Tokenizer(tokenizer_dir) if tokenizer_dir is not None else None\n\n    logger = choose_logger(\n        logger_name,\n        out_dir,\n        name=f\"pretrain-{config.name}\",\n        resume=bool(resume),\n        log_interval=train.log_interval,\n        log_args=asdict(log),\n    )\n\n    if devices * num_nodes > 1:\n        if compiler == \"thunder\":\n            if strategy == \"fsdp\":\n                from extensions.thunder.strategies import ThunderFSDPStrategy\n\n                strategy = ThunderFSDPStrategy(\n                    
sharding_strategy=\"ZERO3\",\n                    bucketing_strategy=\"BLOCK\",\n                    state_dict_type=\"full\",\n                    jit=False,\n                )\n            elif strategy == \"ddp\":\n                from extensions.thunder.strategies import ThunderDDPStrategy\n\n                strategy = ThunderDDPStrategy(jit=False)\n        else:\n            if strategy == \"fsdp\":\n                strategy = FSDPStrategy(\n                    auto_wrap_policy={Block}, state_dict_type=\"full\", sharding_strategy=\"FULL_SHARD\"\n                )\n    else:\n        strategy = \"auto\"\n    fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision=\"bf16-true\", loggers=[logger])\n    fabric.launch()\n\n    if compiler is not None:\n        global forward_and_loss\n        forward_and_loss = (\n            jit(forward_and_loss, executors) if compiler == \"thunder\" else torch.compile(forward_and_loss)\n        )\n\n    fabric.print(pprint.pformat(hparams))\n    if logger_name in (\"tensorboard\", \"wandb\", \"mlflow\"):\n        fabric.logger.log_hyperparams(hparams)\n\n    main(\n        fabric=fabric,\n        devices=devices,\n        num_nodes=num_nodes,\n        seed=seed,\n        initial_checkpoint_dir=initial_checkpoint_dir,\n        resume=resume,\n        config=config,\n        data=data,\n        out_dir=out_dir,\n        tokenizer_dir=tokenizer_dir,\n        tokenizer=tokenizer,\n        train=train,\n        eval=eval,\n        optimizer=optimizer,\n        compiler=compiler,\n    )\n\n\ndef main(\n    fabric: L.Fabric,\n    devices: int,\n    seed: int,\n    initial_checkpoint_dir: Optional[Path],\n    resume: Union[bool, Literal[\"auto\"], Path],\n    config: Config,\n    data: DataModule,\n    out_dir: Path,\n    tokenizer_dir: Optional[Path],\n    tokenizer: Optional[Tokenizer],\n    train: TrainArgs,\n    eval: EvalArgs,\n    optimizer: Union[str, Dict],\n    compiler: Optional[Literal[\"thunder\", 
\"torch\"]],\n    num_nodes: int = 1,\n) -> None:\n    validate_args(train, eval, initial_checkpoint_dir, resume)\n\n    if fabric.global_rank == 0:\n        out_dir.mkdir(parents=True, exist_ok=True)\n\n    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)\n\n    t0 = time.perf_counter()\n    with fabric.init_module(empty_init=True):\n        model = GPT(config)\n\n    initialize_weights(fabric, model, n_layer=config.n_layer, n_embd=config.n_embd)\n\n    if train.tie_embeddings:\n        model.transformer.wte.weight = model.lm_head.weight\n    if train.max_seq_length:\n        model.max_seq_length = train.max_seq_length\n\n    fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\")\n    fabric.print(f\"Total parameters: {num_parameters(model):,}\")\n\n    model = fabric.setup(model)\n    if compiler == \"thunder\":\n        # avoid `Tensor.register_hook` which is unsupported\n        model._register_backward_hook = lambda *_: None\n    optimizer = instantiate_torch_optimizer(optimizer, model.parameters())\n    optimizer = fabric.setup_optimizers(optimizer)\n\n    train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train, model.max_seq_length)\n    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)\n\n    if initial_checkpoint_dir:\n        fabric.load_raw(initial_checkpoint_dir / \"lit_model.pth\", model)\n\n    state = {\n        \"model\": model,\n        \"optimizer\": optimizer,\n        \"train_dataloader\": train_dataloader,\n        \"iter_num\": 0,\n        \"step_count\": 0,\n    }\n\n    resume = find_resume_path(resume, out_dir)\n    if resume:\n        fabric.print(f\"Resuming training from {resume}\")\n        fabric.load(resume, state)\n\n    train_time = time.perf_counter()\n    fit(\n        fabric=fabric,\n        devices=devices,\n        num_nodes=num_nodes,\n        state=state,\n        
train_dataloader=train_dataloader,\n        val_dataloader=val_dataloader,\n        out_dir=out_dir,\n        tokenizer_dir=tokenizer_dir,\n        train=train,\n        eval=eval,\n        optimizer=optimizer,\n    )\n    fabric.print(f\"Training time: {(time.perf_counter() - train_time):.2f}s\")\n\n    # Save final checkpoint\n    save_checkpoint(fabric, state, tokenizer_dir, out_dir / \"final\" / \"lit_model.pth\")\n\n    if fabric.device.type == \"cuda\":\n        fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\")\n\n\ndef fit(\n    fabric: L.Fabric,\n    devices: int,\n    state: dict,\n    train_dataloader: DataLoader,\n    val_dataloader: DataLoader,\n    out_dir: Path,\n    tokenizer_dir: Optional[Path],\n    train: TrainArgs,\n    eval: EvalArgs,\n    optimizer: Union[str, Dict],\n    num_nodes: int = 1,\n) -> None:\n    model = state[\"model\"]\n    optimizer = state[\"optimizer\"]\n\n    validate(fabric, model, val_dataloader, max_iters=2)  # sanity check\n    throughput = ThroughputMonitor(fabric, window_size=5)\n\n    with torch.device(\"meta\"):\n        meta_model = GPT(model.config)\n        x = torch.randint(0, 1, (train.micro_batch_size, meta_model.max_seq_length))\n        model_fwd = lambda: meta_model(x)  # noqa: F821\n        model_loss = lambda y: chunked_cross_entropy(y, x, chunk_size=0)  # noqa: F821\n        measured_flops = measure_flops(meta_model, model_fwd, model_loss)\n        fabric.print(f\"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}\")\n        del meta_model, x\n\n    max_tokens_per_device = train.max_tokens // fabric.world_size\n    tokens_per_iter = train.micro_batch_size * model.max_seq_length\n    max_iters = max_tokens_per_device // tokens_per_iter\n    log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices, num_nodes)\n    initial_iter = state[\"iter_num\"]\n    train_iterator = CycleIterator(train_dataloader)\n\n    running_loss = 
RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(\n        fabric.device\n    )\n    fabric.barrier()\n    total_t0 = time.perf_counter()\n    val_loss = \"n/a\"\n\n    warmup_iters = train.warmup_iters(devices, num_nodes, max_iters, train_dataloader)\n\n    for train_data in train_iterator:\n        if state[\"iter_num\"] >= max_iters:\n            break\n\n        # determine and set the learning rate for this iteration\n        lr = get_lr(optimizer.defaults[\"lr\"], state[\"iter_num\"], warmup_iters, max_iters, train.min_lr)\n        for param_group in optimizer.param_groups:\n            param_group[\"lr\"] = lr\n\n        state[\"iter_num\"] += 1\n        iter_t0 = time.perf_counter()\n\n        input_ids = train_data[:, 0 : model.max_seq_length].contiguous().long()\n        targets = train_data[:, 1 : (model.max_seq_length + 1)].contiguous().long()\n\n        is_accumulating = state[\"iter_num\"] % train.gradient_accumulation_iters(devices, num_nodes) != 0\n        with fabric.no_backward_sync(model, enabled=is_accumulating):\n            loss = forward_and_loss(model, input_ids, targets)\n            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))\n\n        running_loss.update(loss.detach())\n\n        if not is_accumulating:\n            # THUNDER unsupported: https://github.com/Lightning-AI/lightning-thunder/issues/2357\n            # fabric.clip_gradients(model, optimizer, max_norm=train.max_norm)\n            optimizer.step()\n            optimizer.zero_grad()\n            state[\"step_count\"] += 1\n\n        if state[\"iter_num\"] % log_iter_interval == 0:\n            loss = running_loss.compute().item()  # expensive device-to-host synchronization\n            t1 = time.perf_counter()\n            throughput.update(\n                time=(t1 - total_t0),\n                flops=(measured_flops * log_iter_interval),\n                batches=state[\"iter_num\"],\n          
      samples=(state[\"iter_num\"] * train.micro_batch_size),\n                lengths=(state[\"iter_num\"] * train.micro_batch_size * model.max_seq_length),\n            )\n            metrics = {\n                \"loss\": loss,\n                \"iter\": state[\"iter_num\"],\n                \"step\": state[\"step_count\"],\n                \"epoch\": train_iterator.epoch,\n                \"iter_time\": t1 - iter_t0,\n                \"remaining_time\": (\n                    (t1 - total_t0) / (state[\"iter_num\"] - initial_iter) * (max_iters - state[\"iter_num\"])\n                ),\n                \"tokens\": state[\"iter_num\"] * train.micro_batch_size * model.max_seq_length,\n                \"total_tokens\": (state[\"iter_num\"] * train.micro_batch_size * model.max_seq_length * fabric.world_size),\n                \"learning_rate\": lr,\n            }\n            if isinstance(val_loss, float):\n                val_loss = f\"{val_loss:.3f}\"\n            fabric.print(\n                f\"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |\"\n                f\" loss train: {metrics['loss']:.3f},\"\n                f\" val: {val_loss} |\"\n                f\" iter time: {metrics['iter_time'] * 1000:.2f} ms\"\n                f\"{' (step)' if not is_accumulating else ''}\"\n                f\" remaining time: {timedelta(seconds=int(metrics['remaining_time']))!s}\"\n            )\n\n            throughput_metrics = throughput.compute()\n            metrics.update(throughput_metrics)\n            fabric.log_dict(metrics, step=state[\"iter_num\"] - 1)\n\n        if val_dataloader is not None and not is_accumulating and state[\"step_count\"] % eval.interval == 0:\n            t0 = time.perf_counter()\n            val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters)\n            val_loss = val_loss.item()\n            td = time.perf_counter() - t0\n\n            fabric.print(f\"iter {state['iter_num']}: val 
loss {val_loss:.4f}, val time: {td * 1000:.2f} ms\")\n            metrics = {\"val_loss\": val_loss, \"val_ppl\": math.exp(val_loss)}\n            fabric.log_dict(metrics, step=state[\"iter_num\"] - 1)\n            fabric.barrier()\n\n        if train.save_interval is not None and not is_accumulating and state[\"step_count\"] % train.save_interval == 0:\n            save_checkpoint(fabric, state, tokenizer_dir, out_dir / f\"step-{state['step_count']:08d}\" / \"lit_model.pth\")\n\n\n@torch.no_grad()\ndef validate(fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max_iters: int) -> torch.Tensor:\n    fabric.barrier()\n    fabric.print(\"Validating ...\")\n    model.eval()\n\n    losses = []\n    for k, batch in enumerate(val_dataloader):\n        if k >= max_iters:\n            break\n        input_ids = batch[:, 0 : model.max_seq_length].contiguous().long()\n        targets = batch[:, 1 : (model.max_seq_length + 1)].contiguous().long()\n        loss = forward_and_loss(model, input_ids, targets)\n        losses.append(loss)\n\n    val_loss = torch.stack(losses).mean()\n    model.train()\n    fabric.barrier()\n    return val_loss\n\n\ndef get_dataloaders(\n    fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs, block_size: int\n) -> Tuple[DataLoader, DataLoader]:\n    data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=block_size)\n    with fabric.rank_zero_first():\n        data.prepare_data()\n    data.setup()\n    train_dataloader = data.train_dataloader()\n    val_dataloader = data.val_dataloader()\n    return train_dataloader, val_dataloader\n\n\n# learning rate decay scheduler (cosine with linear warmup)\ndef get_lr(learning_rate: float, it: int, warmup_iters: int, max_iters: int, min_lr: float) -> float:\n    # 1) linear warmup for warmup_iters steps\n    if it < warmup_iters:\n        return learning_rate * it / warmup_iters\n    # 2) if it > max_iters, return min learning rate\n    if it > 
max_iters:\n        return min_lr\n    # 3) in between, use cosine decay down to min learning rate\n    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)\n    assert 0 <= decay_ratio <= 1\n    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1\n    return min_lr + coeff * (learning_rate - min_lr)\n\n\ndef initialize_weights(fabric: L.Fabric, model: GPT, n_layer: int, n_embd: int) -> None:\n    \"\"\"GPT-NeoX weight initialization (https://arxiv.org/abs/2204.06745).\"\"\"\n    # Adapted from https://github.com/jzhang38/TinyLlama\n\n    def init_weights(module, std):\n        nn.init.normal_(module.weight, mean=0.0, std=std)\n        if getattr(module, \"bias\", None) is not None:\n            nn.init.zeros_(module.bias)\n\n    for mod in model.modules():\n        if isinstance(mod, (nn.Embedding, nn.Linear)):\n            mod.reset_parameters = partial(init_weights, mod, std=math.sqrt(2.0 / 5 / n_embd))\n\n    # need a separate loop because `mod.proj` below is a `nn.Linear` too\n    for mod in model.modules():\n        if isinstance(mod, (LLaMAMLP, CausalSelfAttention, MultiheadLatentAttention)):\n            mod.proj.reset_parameters = partial(init_weights, mod.proj, std=(1 / math.sqrt(n_embd) / n_layer))\n\n    if not isinstance(fabric.strategy, FSDPStrategy):\n        reset_parameters(model)\n\n\ndef init_out_dir(out_dir: Path) -> Path:\n    if not out_dir.is_absolute() and \"LIGHTNING_ARTIFACTS_DIR\" in os.environ:\n        return Path(os.getenv(\"LIGHTNING_ARTIFACTS_DIR\")) / out_dir\n    return out_dir\n\n\ndef save_checkpoint(fabric, state, tokenizer_dir, checkpoint_file):\n    model = state[\"model\"]\n    checkpoint_file.parent.mkdir(parents=True, exist_ok=True)\n    fabric.print(f\"Saving checkpoint to {str(checkpoint_file)!r}\")\n    fabric.save(checkpoint_file, state)\n    if fabric.global_rank == 0:\n        save_hyperparameters(setup, checkpoint_file.parent)\n        if tokenizer_dir is not None:\n            
copy_config_files(tokenizer_dir, checkpoint_file.parent)\n        save_config(model.config, checkpoint_file.parent)\n\n\ndef validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resume) -> None:\n    issues = []\n    unsupported = [(train, [\"max_steps\", \"epochs\"]), (eval, [\"max_new_tokens\"])]\n    for args, names in unsupported:\n        for name in names:\n            if getattr(args, name) is not None:\n                issues.append(f\"{__file__} doesn't support the {name!r} argument. This is set in {args}\")\n    required = [(train, [\"max_tokens\", \"max_norm\"])]\n    for args, names in required:\n        for name in names:\n            if getattr(args, name) is None:\n                issues.append(f\"{__file__} requires the {name!r} argument. This is set in {args}\")\n    if initial_checkpoint_dir and resume:\n        issues.append(\"Can't provide both `--resume` and `--initial_checkpoint_dir`. Choose one.\")\n    if issues:\n        raise ValueError(\"\\n\".join(issues))\n\n\ndef jit(fn: Callable, executors: List[str]) -> Any:\n    assert executors is not None\n    from unsloth.executor import unsloth_ex  # import for registration  # noqa: F401\n\n    import thunder\n\n    return thunder.jit(fn, executors=executors)\n\n\nif __name__ == \"__main__\":\n    torch.set_float32_matmul_precision(\"high\")\n\n    CLI(setup)\n"
  },
  {
    "path": "extensions/thunder/strategies/__init__.py",
    "content": "from .thunder_ddp import ThunderDDPStrategy  # noqa: F401\nfrom .thunder_fsdp import ThunderFSDPStrategy  # noqa: F401\n"
  },
  {
    "path": "extensions/thunder/strategies/thunder_ddp.py",
    "content": "\"\"\"Fabric Strategy to support Thunder DDP: To be upstreamed into Fabric eventually.\"\"\"\n\nfrom contextlib import nullcontext\nfrom datetime import timedelta\nfrom typing import TYPE_CHECKING, Any, ContextManager, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.distributed\nfrom lightning.fabric.accelerators.accelerator import Accelerator\nfrom lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout\nfrom lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment\nfrom lightning.fabric.plugins.io.checkpoint_io import CheckpointIO\nfrom lightning.fabric.plugins.precision import Precision\nfrom lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher\nfrom lightning.fabric.strategies.parallel import ParallelStrategy\nfrom lightning.fabric.strategies.strategy import TBroadcast, _BackwardSyncControl\nfrom lightning.fabric.utilities.distributed import (\n    ReduceOp,\n    _distributed_is_initialized,\n    _get_default_process_group_backend_for_device,\n    _init_dist_connection,\n    _sync_ddp_if_available,\n)\nfrom lightning.fabric.utilities.rank_zero import rank_zero_only\nfrom lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only\nfrom torch import Tensor\nfrom torch.nn import Module\nfrom typing_extensions import override\n\nfrom litgpt.constants import _THUNDER_AVAILABLE\n\nif TYPE_CHECKING:\n    from thunder import Executor\n\n\nclass ThunderDDPStrategy(ParallelStrategy):\n    def __init__(\n        self,\n        accelerator: Optional[Accelerator] = None,\n        parallel_devices: Optional[List[torch.device]] = None,\n        cluster_environment: Optional[ClusterEnvironment] = None,\n        checkpoint_io: Optional[CheckpointIO] = None,\n        precision: Optional[Precision] = None,\n        jit: bool = True,\n        executors: Optional[Tuple[Union[\"Executor\", str], ...]] = None,\n        process_group_backend: 
Optional[str] = None,\n        timeout: Optional[timedelta] = default_pg_timeout,\n        **kwargs: Any,\n    ):\n        r\"\"\"Strategy for Replicated Data Parallel provided by Lightning Thunder.\n\n        .. warning::  This is an :ref:`experimental <versioning:Experimental API>` feature.\n\n        Arguments:\n            jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually\n                jitting a function that includes the model.\n\n            executors: The list of Thunder executors to enable. They can be either string aliases for the executors\n                or the actual executor instances.\n\n            \\**kwargs: See available parameters in :func:`thunder.distributed.ddp`.\n\n        \"\"\"\n        if not _THUNDER_AVAILABLE:\n            raise ModuleNotFoundError(str(_THUNDER_AVAILABLE))\n        super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision)\n        self.parallel_devices = parallel_devices\n        self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment\n\n        if not jit and executors is not None:\n            raise ValueError(f\"Passing executors={executors} doesn't have an effect with `jit={jit}`\")\n        self.jit = jit\n        self.executors = executors\n        self._num_nodes = 1\n        self._process_group_backend: Optional[str] = process_group_backend\n        self._timeout: Optional[timedelta] = timeout\n        self._backward_sync_control = _ThunderDataParalellBackwardSyncControl()\n        self._ddp_kwargs = kwargs\n\n    @property\n    @override\n    def root_device(self) -> torch.device:\n        assert self.parallel_devices is not None\n        return self.parallel_devices[self.local_rank]\n\n    @property\n    def num_nodes(self) -> int:\n        return self._num_nodes\n\n    @num_nodes.setter\n    def num_nodes(self, num_nodes: int) -> None:\n        # note that world ranks is related to num_nodes, when 
resetting it, need to reset world ranks\n        self._num_nodes = num_nodes\n\n    @property\n    def num_processes(self) -> int:\n        return len(self.parallel_devices) if self.parallel_devices is not None else 0\n\n    @property\n    @override\n    def distributed_sampler_kwargs(self) -> Dict[str, Any]:\n        return {\"num_replicas\": self.num_nodes * self.num_processes, \"rank\": self.global_rank}\n\n    @override\n    def _configure_launcher(self) -> None:\n        assert self.cluster_environment is not None\n        if not self.cluster_environment.creates_processes_externally:\n            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)\n\n    @property\n    def process_group_backend(self) -> Optional[str]:\n        return self._process_group_backend\n\n    @override\n    def _configure_launcher(self) -> None:\n        assert self.cluster_environment is not None\n        self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)\n\n    @override\n    def setup_environment(self) -> None:\n        super().setup_environment()\n        self._setup_distributed()\n\n    @override\n    def setup_module(self, module: Module) -> Module:\n        import thunder\n\n        if (cd := thunder.compile_data(module)) is not None:\n            # the module was already jitted\n            if thunder.compile_stats(module).last_traces is not None:\n                raise RuntimeError(\n                    \"You already called `thunder.jit()` and generated an execution trace. It's too late to apply the\"\n                    \" DDP transform. 
Remove the `forward` call before `fabric.setup()`\"\n                )\n            assert cd.is_module  # sanity check\n            ddp_module = thunder.distributed.ddp(cd.fn, **self._ddp_kwargs)\n            # update the compile data state\n            cd.fn = ddp_module\n            cd.process_group_for_ddp = ddp_module.process_group_for_ddp\n            return module\n        else:\n            module = thunder.distributed.ddp(module, **self._ddp_kwargs)\n        if not self.jit:\n            return module\n        return thunder.jit(module, executors=self.executors)\n\n    @override\n    def module_to_device(self, module: Module) -> None:\n        module.to(self.root_device)\n\n    @override\n    def all_reduce(\n        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = \"mean\"\n    ) -> Tensor:\n        if isinstance(tensor, Tensor):\n            return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)\n        return tensor\n\n    @override\n    def barrier(self, *args: Any, **kwargs: Any) -> None:\n        if not _distributed_is_initialized():\n            return\n        if torch.distributed.get_backend() == \"nccl\":\n            torch.distributed.barrier(device_ids=[self.root_device.index])\n        else:\n            torch.distributed.barrier()\n\n    @override\n    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:\n        if not _distributed_is_initialized():\n            return obj\n\n        obj = [obj]\n        torch.distributed.broadcast_object_list(obj, src)\n        return obj[0]\n\n    def _setup_distributed(self) -> None:\n        self._set_world_ranks()\n        self._process_group_backend = self._get_process_group_backend()\n        assert self.cluster_environment is not None\n        _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)\n\n    def _get_process_group_backend(self) -> str:\n        return 
self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)\n\n    def _set_world_ranks(self) -> None:\n        if self.cluster_environment is not None:\n            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)\n            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)\n        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail\n        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter\n        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank\n\n\nclass _ThunderDataParalellBackwardSyncControl(_BackwardSyncControl):\n    def __init__(self):\n        self._enabled = False\n\n    @override\n    def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:\n        \"\"\"\n        In Thunder, we cannot use ``module.no_sync()`` because reduction happens at the end of the context manager.\n        It assumes that the user will reuse it across all gradient accumulation iterations:\n\n        .. code-block:: python\n\n            with model.no_sync():\n                for _ in range(len(gradient_accumulation_iters)):\n                    fwd()\n                    bwd()  # uses no-sync-backward trace\n                fwd()\n                bwd()  # uses regular-backward trace\n\n        However, Fabric is designed to enter the context manager every iteration:\n\n        .. 
code-block:: python\n\n            for i in range(iters):\n                is_accumulating = (i + 1) % gradient_accumulation_iters != 0\n                ctx = model.no_sync() if is_accumulating else nullcontext()\n                with ctx:\n                    fwd()\n                    bwd()\n\n        So we need to be smart about when to sync grads based on the ``enabled`` value.\n\n        More info in https://github.com/Lightning-AI/lit-thunder-LEGACY/issues/2085\n        \"\"\"\n        if not getattr(module, \"use_ddp\", False) and not getattr(module, \"use_fsdp\", False):\n            raise TypeError(\n                \"Blocking backward sync is only possible if the module passed to\"\n                f\" `{self.__class__.__name__}.no_backward_sync` is applied DDP or FSDP.\"\n                f\" Got: {module.__class__.__name__}.\"\n            )\n\n        from thunder.distributed import skip_data_parallel_grad_sync\n\n        previous, self._enabled = self._enabled, enabled\n        if enabled:\n            return skip_data_parallel_grad_sync()\n        if not enabled and previous:\n            return _SyncGradsContextManager(module)\n        return nullcontext()\n\n\nclass _SyncGradsContextManager:\n    def __init__(self, module: Module) -> None:\n        self._module = module\n\n    @override\n    def __enter__(self) -> None:\n        from thunder.distributed import _sync_grads\n\n        _sync_grads(self._module)\n\n    @override\n    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:\n        pass\n"
  },
  {
    "path": "extensions/thunder/strategies/thunder_fsdp.py",
    "content": "\"\"\"Fabric Strategy to support Thunder FSDP: To be upstreamed into Fabric eventually.\"\"\"\n\nimport shutil\nfrom contextlib import ExitStack, nullcontext\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Literal, Optional, Tuple, Union\n\nimport torch\nfrom lightning.fabric.accelerators.accelerator import Accelerator\nfrom lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment\nfrom lightning.fabric.plugins.io.checkpoint_io import CheckpointIO\nfrom lightning.fabric.plugins.precision import Precision\nfrom lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher\nfrom lightning.fabric.strategies.parallel import ParallelStrategy\nfrom lightning.fabric.strategies.strategy import TBroadcast, _apply_filter, _Sharded, _validate_keys_for_strict_loading\nfrom lightning.fabric.utilities.distributed import (\n    ReduceOp,\n    _distributed_is_initialized,\n    _get_default_process_group_backend_for_device,\n    _init_dist_connection,\n    _sync_ddp_if_available,\n)\nfrom lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2\nfrom lightning.fabric.utilities.load import _METADATA_FILENAME, _move_state_into\nfrom lightning.fabric.utilities.rank_zero import rank_zero_only\nfrom lightning.fabric.utilities.seed import reset_seed\nfrom lightning.fabric.utilities.types import _PATH, _Stateful\nfrom lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only\nfrom torch import Tensor\nfrom torch.nn import Module\nfrom torch.optim import Optimizer\nfrom typing_extensions import override\n\nfrom extensions.thunder.strategies.thunder_ddp import _ThunderDataParalellBackwardSyncControl\nfrom litgpt.constants import _THUNDER_AVAILABLE\n\nif TYPE_CHECKING:\n    from thunder import Executor\n    from thunder.distributed import FSDPBucketingStrategy, FSDPType\n    from thunder.distributed.checkpoint import 
StateDictOptions\n\n    _FSDP_TYPE = Union[FSDPType, Literal[\"ZERO2\", \"ZERO3\"]]\n    _BUCKETING_STRATEGY = Union[FSDPBucketingStrategy, Literal[\"NONE\", \"LAYER\", \"BLOCK\"]]\n\n\nclass ThunderFSDPStrategy(ParallelStrategy, _Sharded):\n    def __init__(\n        self,\n        accelerator: Optional[Accelerator] = None,\n        parallel_devices: Optional[List[torch.device]] = None,\n        cluster_environment: Optional[ClusterEnvironment] = None,\n        checkpoint_io: Optional[CheckpointIO] = None,\n        precision: Optional[Precision] = None,\n        jit: bool = True,\n        executors: Optional[Tuple[Union[\"Executor\", str], ...]] = None,\n        sharding_strategy: \"_FSDP_TYPE\" = \"ZERO3\",\n        bucketing_strategy: \"_BUCKETING_STRATEGY\" = \"NONE\",\n        state_dict_type: Literal[\"full\", \"sharded\"] = \"sharded\",\n        **kwargs: Any,\n    ):\n        r\"\"\"Strategy for Fully Sharded Data Parallel provided by Lightning Thunder.\n\n        .. warning::  This is an :ref:`experimental <versioning:Experimental API>` feature.\n\n        Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model\n        size, whilst using efficient communication to reduce overhead. In practice, this means we can remain\n        at parity with PyTorch DDP, whilst scaling our model sizes dramatically.\n\n        Arguments:\n            jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually\n                jitting a function that includes the model.\n\n            executors: The list of Thunder executors to enable. 
They can be either string aliases for the executors\n                or the actual executor instances.\n\n            sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination\n                of them:\n\n                - ``\"ZERO3\"``: Shards model parameters, gradients, and optimizer states (default).\n                - ``\"ZERO2\"``: Shards gradients and optimizer states only. Model parameters get replicated.\n\n                Also accepts a :class:`thunder.distributed.FSDPType` enum value.\n\n            bucketing_strategy: Enables combining the collective operations for sets of layers.\n\n                - ``\"NONE\"``: No bucketing (default).\n                - ``\"LAYER\"``: Create buckets per layer class.\n                - ``\"BLOCK\"``: Create buckets per layer block.\n\n                Also accepts a :class:`thunder.distributed.FSDPBucketingStrategy` enum value.\n\n            state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.\n                Defaults to ``\"sharded\"``.\n\n                - ``\"full\"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.\n                - ``\"sharded\"``: Each rank saves its shard of weights and optimizer states to a file. 
The checkpoint is\n                  a folder with as many files as the world size.\n\n            \\**kwargs: See available parameters in :func:`thunder.distributed.fsdp`.\n\n        \"\"\"\n        if not _TORCH_GREATER_EQUAL_2_2:\n            raise ImportError(\"Thunder's FSDP strategy requires PyTorch 2.2 or higher.\")\n        if not _THUNDER_AVAILABLE:\n            raise ModuleNotFoundError(str(_THUNDER_AVAILABLE))\n        super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision)\n        self.parallel_devices = parallel_devices\n        self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment\n        from thunder.distributed import FSDPBucketingStrategy, FSDPType\n\n        self.sharding_strategy = (\n            FSDPType[sharding_strategy.upper()] if isinstance(sharding_strategy, str) else sharding_strategy\n        )\n        self.bucketing_strategy = (\n            FSDPBucketingStrategy[bucketing_strategy.upper()]\n            if isinstance(bucketing_strategy, str)\n            else bucketing_strategy\n        )\n        if not jit and executors is not None:\n            raise ValueError(f\"Passing executors={executors} doesn't have an effect with `jit={jit}`\")\n        self.jit = jit\n        self.executors = executors\n        self._state_dict_type = state_dict_type\n        self._backward_sync_control = _ThunderDataParalellBackwardSyncControl()\n        self._fsdp_kwargs = kwargs\n\n    @property\n    @override\n    def root_device(self) -> torch.device:\n        assert self.parallel_devices is not None\n        return self.parallel_devices[self.local_rank]\n\n    @property\n    def num_nodes(self) -> int:\n        return 1\n\n    @property\n    def num_processes(self) -> int:\n        return len(self.parallel_devices) if self.parallel_devices is not None else 0\n\n    @property\n    @override\n    def distributed_sampler_kwargs(self) -> Dict[str, Any]:\n        return {\"num_replicas\": 
self.num_nodes * self.num_processes, \"rank\": self.global_rank}\n\n    @override\n    def _configure_launcher(self) -> None:\n        assert self.cluster_environment is not None\n        if not self.cluster_environment.creates_processes_externally:\n            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)\n\n    @override\n    def setup_environment(self) -> None:\n        super().setup_environment()\n        self._setup_distributed()\n\n    @override\n    def setup_module(self, module: Module) -> Module:\n        import thunder\n\n        if (cd := thunder.compile_data(module)) is not None:\n            # the module was already jitted\n            if thunder.compile_stats(module).last_traces is not None:\n                raise RuntimeError(\n                    \"You already called `thunder.jit()` and generated an execution trace. It's too late to apply the\"\n                    \" FSDP transform. Remove the `forward` call before `fabric.setup()`\"\n                )\n            assert cd.is_module  # sanity check\n            fsdp_module = thunder.distributed.fsdp(\n                cd.fn,\n                device=self.root_device,\n                sharding_strategy=self.sharding_strategy,\n                bucketing_strategy=self.bucketing_strategy,\n                **self._fsdp_kwargs,\n            )\n            # update the compile data state\n            cd.fn = fsdp_module\n            cd.process_group_for_ddp = fsdp_module.process_group_for_ddp\n            return module\n        else:\n            module = thunder.distributed.fsdp(\n                module,\n                device=self.root_device,\n                sharding_strategy=self.sharding_strategy,\n                bucketing_strategy=self.bucketing_strategy,\n                **self._fsdp_kwargs,\n            )\n        if not self.jit:\n            return module\n        return thunder.jit(module, executors=self.executors)\n\n    
@override\n    def module_to_device(self, module: Module) -> None:\n        pass\n\n    @override\n    def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:\n        precision_init_ctx = self.precision.module_init_context()\n        module_sharded_ctx = self.module_sharded_context()\n        stack = ExitStack()\n        if empty_init:\n            # Materialization happens in `setup`. When modules get wrapped by FSDP\n            stack.enter_context(torch.device(\"meta\"))\n        stack.enter_context(precision_init_ctx)\n        stack.enter_context(module_sharded_ctx)\n        return stack\n\n    @override\n    def module_sharded_context(self) -> ContextManager:\n        return nullcontext()\n\n    @override\n    def all_reduce(\n        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = \"mean\"\n    ) -> Tensor:\n        if isinstance(tensor, Tensor):\n            return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)\n        return tensor\n\n    @override\n    def barrier(self, *args: Any, **kwargs: Any) -> None:\n        if not _distributed_is_initialized():\n            return\n        if torch.distributed.get_backend() == \"nccl\":\n            torch.distributed.barrier(device_ids=[self.root_device.index])\n        else:\n            torch.distributed.barrier()\n\n    @override\n    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:\n        if not _distributed_is_initialized():\n            return obj\n\n        obj = [obj]\n        torch.distributed.broadcast_object_list(obj, src)\n        return obj[0]\n\n    @override\n    def clip_gradients_norm(\n        self,\n        module: Module,\n        optimizer: Optimizer,\n        max_norm: Union[float, int],\n        norm_type: Union[float, int] = 2.0,\n        error_if_nonfinite: bool = True,\n    ) -> Tensor:\n        raise NotImplementedError\n\n    @override\n    def save_checkpoint(\n        self,\n      
  path: _PATH,\n        state: Dict[str, Union[Module, Optimizer, Any]],\n        storage_options: Optional[Any] = None,\n        filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None,\n    ) -> None:\n        if storage_options is not None:\n            raise TypeError(\n                \"`FSDPStrategy.save_checkpoint(..., storage_options=...)` is not supported because\"\n                \" `FSDPStrategy` does not use the `CheckpointIO`.\"\n            )\n        if filter is not None:\n            raise NotImplementedError(\"Filtering checkpoint paths is not implemented\")\n\n        # broadcast the path from rank 0 to ensure all the states are saved in a common path\n        path = Path(self.broadcast(path))\n        if path.is_dir() and self._state_dict_type == \"full\" and not _is_sharded_checkpoint(path):\n            raise IsADirectoryError(f\"The checkpoint path exists and is a directory: {path}\")\n\n        from thunder.distributed.checkpoint import StateDictOptions, has_fsdp_modules, save\n\n        modules = [module for module in state.values() if has_fsdp_modules(module)]\n        if len(modules) == 0:\n            raise ValueError(\n                \"Could not find a FSDP model in the provided checkpoint state. Please provide the model as\"\n                \" part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure\"\n                \" you set up the model (and optimizers if any) through the strategy before saving the checkpoint.\"\n            )\n        if len(modules) > 1:\n            raise ValueError(\n                \"Found multiple FSDP models in the given state. Saving checkpoints with FSDP is\"\n                \" currently limited to a single model per checkpoint. 
To save multiple models, call the\"\n                \" save method for each model separately with a different path.\"\n            )\n\n        if self._state_dict_type == \"sharded\":\n            if _is_full_checkpoint(path):\n                path.unlink()\n            path.mkdir(parents=True, exist_ok=True)\n\n            options = StateDictOptions(full_state_dict=False, cpu_offload=True, rank0_only=False)\n            converted_state, metadata = _get_state_dict(state, filter, options, self.local_rank)\n            save(converted_state, path)\n            if self.global_rank == 0:\n                torch.save(metadata, path / _METADATA_FILENAME)\n\n        elif self._state_dict_type == \"full\":\n            if _is_sharded_checkpoint(path):\n                shutil.rmtree(path)\n\n            options = StateDictOptions(full_state_dict=True, cpu_offload=True, rank0_only=True)\n            converted_state, metadata = _get_state_dict(state, filter, options, self.local_rank)\n            converted_state.update(metadata)\n            if self.global_rank == 0:\n                torch.save(converted_state, path)\n        else:\n            raise ValueError(f\"Unknown state_dict_type: {self._state_dict_type}\")\n\n    @override\n    def load_checkpoint(\n        self,\n        path: _PATH,\n        state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None,\n        strict: bool = True,\n    ) -> Dict[str, Any]:\n        if not state:\n            raise ValueError(\n                f\"Got `FSDPStrategy.load_checkpoint(..., state={state!r})` but a state with at least\"\n                \" a model instance to reload is required. 
Pass it in like so:\"\n                \" `FSDPStrategy.load_checkpoint(..., state={'model': model, ...})`\"\n            )\n        # broadcast the path from rank 0 to ensure all the states are loaded from a common path\n        path = Path(self.broadcast(path))\n\n        from thunder.distributed.checkpoint import StateDictOptions, has_fsdp_modules, load, load_model_state_dict\n\n        if isinstance(state, Module):\n            if not _is_full_checkpoint(path):\n                raise ValueError(\n                    \"Failed to load checkpoint directly into the model. The given path must be a single file\"\n                    f\" containing the full state dict: {path}\"\n                )\n            state_dict = torch.load(str(path), mmap=True, map_location=\"cpu\")\n            options = StateDictOptions(full_state_dict=True, cpu_offload=True, strict=strict, rank0_only=False)\n            load_model_state_dict(state_dict, _unwrap_tom(state), options, self.local_rank)\n            return {}\n\n        if isinstance(state, Optimizer):\n            raise NotImplementedError(\n                \"Loading a single optimizer object from a checkpoint is not supported yet with the FSDP strategy.\"\n            )\n\n        modules = {key: module for key, module in state.items() if has_fsdp_modules(module)}\n        if len(modules) == 0:\n            raise ValueError(\n                \"Could not find a FSDP model in the provided checkpoint state. Please provide the model as\"\n                \" part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure\"\n                \" you set up the model (and optimizers if any) through the strategy before loading the checkpoint.\"\n            )\n        if len(modules) > 1:\n            raise ValueError(\n                \"Found multiple FSDP models in the given state. Loading checkpoints with FSDP is\"\n                \" currently limited to a single model per checkpoint. 
To load multiple models, call the\"\n                \" load method for each model separately with a different path.\"\n            )\n        optimizers = {key: optim for key, optim in state.items() if isinstance(optim, Optimizer)}\n        module_key, module = list(modules.items())[0]\n        module = _unwrap_tom(module)\n\n        if _is_sharded_checkpoint(path):\n            options = StateDictOptions(full_state_dict=False, cpu_offload=True, strict=strict, rank0_only=False)\n            # Load the DCP state dict, which requires a holder state dict\n            converted_state, _ = _get_state_dict(state, None, options, self.local_rank)\n            load(converted_state, path)\n            load_model_state_dict(converted_state[module_key], module, options, self.local_rank)\n\n            # Load metadata (anything not a module or optimizer)\n            metadata = torch.load(path / _METADATA_FILENAME)\n            requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()\n            _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict)\n            for key in requested_metadata_keys:\n                if key not in metadata:\n                    continue\n                state[key] = metadata.pop(key)\n            # return the remaining metadata that wasn't requested as part of `state`\n            return metadata\n\n        if _is_full_checkpoint(path):\n            options = StateDictOptions(full_state_dict=True, cpu_offload=True, strict=strict, rank0_only=False)\n            if not options.rank0_only or self.local_rank == 0:\n                map_location = \"cpu\" if options.cpu_offload else None\n                checkpoint = torch.load(str(path), mmap=True, map_location=map_location)\n                load_model_state_dict(checkpoint[module_key], module, options, self.local_rank)\n            else:\n                checkpoint = {}\n\n            requested_metadata_keys = state.keys() - modules.keys() - 
optimizers.keys()\n            _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)\n            # Load metadata (anything not a module or optimizer)\n            _move_state_into(source=checkpoint, destination=state, keys=requested_metadata_keys)\n            # return the remaining metadata that wasn't requested as part of `state`\n            return checkpoint\n\n        raise ValueError(\n            f\"The path {str(path)!r} does not point to a valid checkpoint. Make sure the path points to either a\"\n            \" directory with FSDP checkpoint shards, or a single file with a full checkpoint.\"\n        )\n\n    def _setup_distributed(self) -> None:\n        reset_seed()\n        self._set_world_ranks()\n        process_group_backend = _get_default_process_group_backend_for_device(self.root_device)\n        assert self.cluster_environment is not None\n        _init_dist_connection(self.cluster_environment, process_group_backend)\n\n    def _set_world_ranks(self) -> None:\n        if self.cluster_environment is not None:\n            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)\n            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)\n        # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail\n        # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter\n        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank\n\n\ndef _is_sharded_checkpoint(path: Path) -> bool:\n    \"\"\"A heuristic check to determine whether the path points to a directory with checkpoint shards.\"\"\"\n    return path.is_dir() and (path / _METADATA_FILENAME).is_file()\n\n\ndef _is_full_checkpoint(path: Path) -> bool:\n    return path.is_file()\n\n\ndef _get_state_dict(\n    state: Dict[str, Any],\n    filter: Optional[Dict[str, Callable[[str, Any], 
bool]]],\n    options: \"StateDictOptions\",\n    rank: int,\n) -> Tuple[Dict[str, Any], Dict[str, Any]]:\n    from thunder.distributed.checkpoint import get_model_state_dict\n\n    # replace the modules and optimizer objects in the state with their local state dict\n    # and separate the user's metadata\n    converted_state: Dict[str, Any] = {}\n    metadata: Dict[str, Any] = {}\n    for key, obj in state.items():\n        converted: Any\n        if isinstance(obj, Module):\n            converted = get_model_state_dict(_unwrap_tom(obj), options, rank)\n            target_dict = converted_state\n        elif isinstance(obj, Optimizer):\n            # TODO: optimizer support\n            converted = obj.state_dict()\n            target_dict = converted_state\n        else:  # everything not a module or optimizer is considered metadata\n            converted = obj.state_dict() if isinstance(obj, _Stateful) else obj\n            target_dict = metadata\n        _apply_filter(key, filter or {}, converted, target_dict)\n\n    return converted_state, metadata\n\n\ndef _unwrap_tom(obj: object) -> object:\n    # TODO: this unwrap won't be required when Fabric's `_unwrap_objects` supports Thunder\n    from thunder import ThunderModule\n\n    if isinstance(obj, ThunderModule):\n        return obj._model\n    return obj\n"
  },
  {
    "path": "extensions/thunder/unsloth/__init__.py",
    "content": ""
  },
  {
    "path": "extensions/thunder/unsloth/executor.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport sys\nfrom pathlib import Path\nfrom typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nimport litgpt.model\nfrom litgpt.constants import _THUNDER_AVAILABLE\nfrom litgpt.model import LLaMAMLP as OriginalLLaMAMLP\nfrom thunder.core.proxies import TensorProxy\nfrom thunder.core.transforms import get_grad, mean_backward, put_grads\nfrom thunder.extend import OperatorExecutor, register_executor\nfrom thunder.torch import ne, sum, true_divide\n\nif _THUNDER_AVAILABLE:\n    import thunder\n    import thunder.torch as ltorch\n\nsys.path.append(str(Path(__file__).parent))\n\nimport kernels\n\nunsloth_ex = OperatorExecutor(\"unsloth\", version=\"0.1\")\nregister_executor(unsloth_ex)\n\n\n\"\"\"\n====================\n Cross Entropy Loss\n====================\n\"\"\"\n\n\ndef unsloth_cross_entropy_meta(logits: TensorProxy, labels: TensorProxy) -> Tuple[TensorProxy, TensorProxy]:\n    return (\n        TensorProxy(\n            shape=(logits.shape[0],),\n            # the cross entropy kernel only supports float32\n            dtype=thunder.dtypes.float32,\n            device=logits.device,\n            requires_grad=logits.requires_grad,\n        ),\n        TensorProxy(shape=(logits.shape[0],), dtype=thunder.dtypes.float32, device=logits.device, requires_grad=False),\n    )\n\n\nunsloth_cross_entropy = unsloth_ex.register_operator(\n    \"unsloth_cross_entropy\", meta=unsloth_cross_entropy_meta, fn=kernels.cross_entropy_loss._cross_entropy_forward_impl\n)\n\n\ndef unsloth_cross_entropy_backward_impl(dlosses: Tensor, logits: Tensor, labels: Tensor, logsumexp: Tensor) -> Tensor:\n    # clone() because the kernel writes the grads in the logits\n    return kernels.cross_entropy_loss._cross_entropy_backward_impl(dlosses, logits.clone(), logsumexp, labels)\n\n\ndef unsloth_cross_entropy_backward_meta(\n    dlosses: TensorProxy, logits: TensorProxy, 
logsumexp: TensorProxy, labels: TensorProxy\n) -> TensorProxy:\n    return thunder.TensorProxy(like=logits)\n\n\nunsloth_cross_entropy_backward = unsloth_ex.register_operator(\n    \"unsloth_cross_entropy_backward\", meta=unsloth_cross_entropy_backward_meta, fn=unsloth_cross_entropy_backward_impl\n)\n\n\ndef unsloth_cross_entropy_checker(\n    logits: TensorProxy,\n    labels: TensorProxy,\n    weight: Optional[TensorProxy] = None,\n    size_average: Optional[bool] = None,\n    ignore_index: int = -100,\n    reduce: Optional[bool] = None,\n    reduction: str = \"mean\",\n    label_smoothing: float = 0.0,\n) -> bool:\n    return (\n        weight is None\n        and size_average is None\n        and reduce is None\n        and reduction in (\"none\", \"mean\")\n        and ignore_index == -100\n        and label_smoothing == 0.0\n        and logits.device.type == \"cuda\"\n        and labels.device.type == \"cuda\"\n    )\n\n\ndef cross_entropy_to_unsloth(\n    logits: TensorProxy,\n    labels: TensorProxy,\n    weight: Optional[TensorProxy] = None,\n    size_average: Optional[bool] = None,\n    ignore_index: int = -100,\n    reduce: Optional[bool] = None,\n    reduction: str = \"mean\",\n    label_smoothing: float = 0.0,\n) -> Tuple[TensorProxy, TensorProxy]:\n    loss, logsumexp = unsloth_cross_entropy(logits, labels)\n    if reduction == \"mean\":\n        # \"mean\" reduction is not part of the kernel\n        # TODO: this doesn't consider that all elements could be masked, causing a division by 0\n        n_items = sum(ne(labels, -100))\n        loss = true_divide(sum(loss), n_items)\n    elif reduction != \"none\":\n        raise NotImplementedError(reduction)\n    return loss, logsumexp\n\n\ndef unsloth_cross_entropy_grad(\n    logits: TensorProxy,\n    labels: TensorProxy,\n    weight: Optional[TensorProxy] = None,\n    size_average: Optional[bool] = None,\n    ignore_index: int = -100,\n    reduce: Optional[bool] = None,\n    reduction: str = \"mean\",\n   
 label_smoothing: float = 0.0,\n) -> TensorProxy:\n    loss, logsumexp = cross_entropy_to_unsloth(**locals())\n    grad = get_grad(loss)\n    if reduction == \"mean\":\n        grad = mean_backward(logsumexp.ndim, logsumexp.shape, (0,), grad)\n    logits_grad = unsloth_cross_entropy_backward(grad, logits, labels, logsumexp)\n    put_grads((logits,), (logits_grad,))\n    return loss\n\n\n# registers as cross entropy implementation, including the execution transform and now a grad transform\nunsloth_ex.register_implementation(\n    ltorch.cross_entropy,\n    checker=unsloth_cross_entropy_checker,\n    execution_transform=lambda *args: cross_entropy_to_unsloth(*args)[0],\n    grad_transform=unsloth_cross_entropy_grad,\n)\n\n\n\"\"\"\n=========\n RMSNorm\n=========\n\nThe RMSNorm kernel is not integrated because it's not numerically equal and it doesn't compute the gradient for the\nweight, just for the input.\n\"\"\"\n\n\n\"\"\"\n========\n SwiGLU\n========\n\"\"\"\n\n\ndef swiglu(e: torch.Tensor, g: torch.Tensor) -> torch.Tensor:\n    return torch.nn.functional.silu(e) * g\n\n\nclass ThunderLLaMAMLP(OriginalLLaMAMLP):\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x_fc_1 = self.fc_1(x)\n        x_fc_2 = self.fc_2(x)\n        x = swiglu(x_fc_1, x_fc_2)\n        return self.proj(x)\n\n\nlitgpt.model.LLaMAMLP = ThunderLLaMAMLP\n\n\ndef swiglu_forward_meta(e: TensorProxy, g: TensorProxy) -> TensorProxy:\n    return TensorProxy(like=e)\n\n\nlitgpt_swiglu = unsloth_ex.register_operator(\"litgpt_swiglu\", meta=swiglu_forward_meta, fn=swiglu, replaces=swiglu)\n\n\nunsloth_swiglu_forward = unsloth_ex.register_operator(\n    \"unsloth_swiglu_forward\", meta=swiglu_forward_meta, fn=lambda *args: kernels.swiglu_fg_kernel(*args)\n)\n\n\ndef unsloth_swiglu_backward_meta(DW: TensorProxy, e: TensorProxy, g: TensorProxy) -> Tuple[TensorProxy, TensorProxy]:\n    return TensorProxy(like=g), TensorProxy(like=e)\n\n\ndef unsloth_swiglu_backward_fn(DW: Tensor, e: 
Tensor, g: Tensor) -> Tuple[Tensor, Tuple]:\n    B, T, n_embd = e.shape\n    e = e.view(-1, n_embd)\n    g = g.view(-1, n_embd)\n    DW, e, g = kernels.swiglu_DWf_DW_dfg_kernel(DW, e, g)\n    e = e.view(B, T, n_embd)\n    g = g.view(B, T, n_embd)\n    return g, e\n\n\nunsloth_swiglu_backward = unsloth_ex.register_operator(\n    \"unsloth_swiglu_backward\", meta=unsloth_swiglu_backward_meta, fn=unsloth_swiglu_backward_fn\n)\n\n\ndef swiglu_to_unsloth_checker(e: TensorProxy, g: TensorProxy) -> bool:\n    return e.device.type == \"cuda\" and g.device.type == \"cuda\"\n\n\ndef unsloth_swiglu_grad(e: TensorProxy, g: TensorProxy) -> TensorProxy:\n    h = unsloth_swiglu_forward(**locals())\n    grad = get_grad(h)\n    e_grad, g_grad = unsloth_swiglu_backward(grad, e, g)\n    put_grads((e, g), (e_grad, g_grad))\n    return h\n\n\nunsloth_ex.register_implementation(\n    litgpt_swiglu,\n    checker=swiglu_to_unsloth_checker,\n    execution_transform=unsloth_swiglu_forward,\n    grad_transform=unsloth_swiglu_grad,\n)\n\n\n\"\"\"\n======\n RoPE\n======\n\"\"\"\n\n\ndef apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n    return TensorProxy(like=x)\n\n\napply_rope = unsloth_ex.register_operator(\n    \"litgpt_apply_rope\", like=apply_rope_meta, fn=litgpt.model.apply_rope, replaces=litgpt.model.apply_rope\n)\n\n\ndef unsloth_apply_rope_meta(\n    Q: TensorProxy, cos: TensorProxy, sin: TensorProxy\n) -> Tuple[TensorProxy, TensorProxy, TensorProxy, int, int, int]:\n    batch, n_heads, seq_len, head_dim = Q.shape\n    assert seq_len <= cos.shape[-2]\n    BLOCK_SIZE, num_warps = kernels.calculate_settings(head_dim // 2)\n    div, mod = divmod(n_heads, kernels.rope_embedding.ROPE_GROUP_SIZE)\n    n_groups = div + (mod != 0)\n    return TensorProxy(like=Q), cos, sin, n_groups, BLOCK_SIZE, num_warps\n\n\nunsloth_apply_rope = unsloth_ex.register_operator(\n    \"unsloth_apply_rope\", meta=unsloth_apply_rope_meta, 
fn=kernels._rope_embedding_forward_impl\n)\n\n\ndef unsloth_apply_rope_backward_meta(\n    dY: TensorProxy, cos: TensorProxy, sin: TensorProxy, n_groups: int, BLOCK_SIZE: int, num_warps: int\n) -> TensorProxy:\n    return TensorProxy(like=dY)\n\n\nunsloth_apply_rope_backward = unsloth_ex.register_operator(\n    \"unsloth_apply_rope_backward\", meta=unsloth_apply_rope_backward_meta, fn=kernels._rope_embedding_backward_impl\n)\n\n\ndef apply_rope_to_unsloth_checker(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> bool:\n    return len(x.shape) == 4 and x.device.type == \"cuda\" and cos.device.type == \"cuda\" and sin.device.type == \"cuda\"\n\n\ndef unsloth_apply_rope_grad(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n    Q, cos, sin, n_groups, BLOCK_SIZE, num_warps = unsloth_apply_rope(x, cos, sin)\n    dY = get_grad(Q)\n    dX = unsloth_apply_rope_backward(dY, cos, sin, n_groups, BLOCK_SIZE, num_warps)\n    put_grads((x,), (dX,))\n    return Q\n\n\nunsloth_ex.register_implementation(\n    apply_rope,\n    checker=apply_rope_to_unsloth_checker,\n    execution_transform=lambda *args: unsloth_apply_rope(*args)[0],\n    grad_transform=unsloth_apply_rope_grad,\n)\n"
  },
  {
    "path": "extensions/thunder/unsloth/kernels/__init__.py",
    "content": "from .cross_entropy_loss import _cross_entropy_backward_impl, _cross_entropy_forward_impl  # noqa: F401\nfrom .rope_embedding import ROPE_GROUP_SIZE, _rope_embedding_backward_impl, _rope_embedding_forward_impl  # noqa: F401\nfrom .swiglu import swiglu_DWf_DW_dfg_kernel, swiglu_fg_kernel  # noqa: F401\nfrom .utils import calculate_settings  # noqa: F401\n"
  },
  {
    "path": "extensions/thunder/unsloth/kernels/cross_entropy_loss.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\nfrom litgpt.constants import _TRITON_AVAILABLE\n\nfrom .utils import MAX_FUSED_SIZE, calculate_settings\n\nif _TRITON_AVAILABLE:\n    import triton\n    import triton.language as tl\n\n\n@triton.jit\ndef _cross_entropy_forward(\n    logits_ptr,\n    logits_row_stride,\n    loss_ptr,\n    logsumexp_ptr,\n    labels_ptr,\n    VOCAB_SIZE: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]\n    Pi = exp(xi) / sum(exp(xi))\n    CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]\n         = -y [ x - log[sum(exp(x))] ]\n         = y * (log[sum(exp(x))] - x)\n    If y == 0: CE_i = 0\n    If y == 1: CE_i = logsumexp - x\n\n    logsumexp is also stable\n    Take    y =         log[sum(exp(x))]\n       exp(y) =             sum(exp(x))\n       exp(y) =             sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x\n       exp(y) =      exp(c)*sum(exp(x - c))\n           y  = log(exp(c)*sum(exp(x - c)))\n           y  = c + log[sum(exp(x - c))]\n    This means we can set c = max(x) to make sure\n    exp(x - c) always is exp(x - max(x)).\n    This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.\n    \"\"\"\n    row_idx = tl.program_id(0)\n    logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n    loss_ptr += row_idx\n    logsumexp_ptr += row_idx\n  
  labels_ptr += row_idx\n\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < VOCAB_SIZE\n\n    label_idx = tl.load(labels_ptr).to(tl.int32)\n    logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float(\"inf\")).to(tl.float32)\n    c = tl.max(logits, 0)\n    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n    if label_idx != -100:\n        x = tl.load(logits_ptr + label_idx).to(tl.float32)\n        loss = logsumexp - x\n    else:\n        loss = 0.0\n    tl.store(logsumexp_ptr, logsumexp)\n    tl.store(loss_ptr, loss)\n\n\npass\n\n\n@triton.jit\ndef _chunked_cross_entropy_forward(\n    logits_ptr,\n    logits_row_stride,\n    loss_ptr,\n    logsumexp_ptr,\n    labels_ptr,\n    VOCAB_SIZE: tl.constexpr,\n    N_CHUNKS: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    256K vocab divided in 4 chunks\n\n    |-65536-| |-65536-| |-65536-| |-65536-|\n    |-------| |-------| |-------| |-------|\n    |-------| |-------| |-------| |-------|\n\n    If y == 0: CE_i = 0\n    If y == 1: CE_i = logsumexp - x\n\n    Notice we can do logsumexp for each chunk and then\n    logsumexp[chunk_sum(logsumexp)] == logsumexp\n\n    chunk_sum = log[chunk_sum(logsumexp)]\n              = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]\n              = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]\n              = log[sum(exp(a)) + ... 
+ sum(exp(z))]\n              = logsumexp(x)\n\n    This means we can perform a logsumexp for each chunk, then do a\n    final logsumexp reduction!\n\n    Ie do: logsumexp(chunked_logsumexp) - x\n    \"\"\"\n    row_idx = tl.program_id(0)\n    chunk_idx = tl.program_id(1)\n    logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n    loss_ptr += row_idx\n    logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx\n    labels_ptr += row_idx\n\n    col_offsets = chunk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < VOCAB_SIZE\n\n    label_idx = tl.load(labels_ptr).to(tl.int32)\n    logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float(\"inf\")).to(tl.float32)\n    c = tl.max(logits, 0)\n    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n    if chunk_idx == 0:\n        # logsumexp(chunked_logsumexp) - x\n        # Do the -x separately\n        if label_idx != -100:\n            x = tl.load(logits_ptr + label_idx).to(tl.float32)\n            loss = -1.0 * x\n        else:\n            loss = 0.0\n        tl.store(loss_ptr, loss)\n    pass\n    tl.store(logsumexp_ptr, logsumexp)\n\n\npass\n\n\n@triton.jit\ndef _cross_entropy_backward(\n    logits_ptr,\n    logits_row_stride,\n    dloss_ptr,\n    dloss_row_stride,\n    logsumexp_ptr,\n    labels_ptr,\n    VOCAB_SIZE: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    CE_i = -y log(P) = y * (log[sum(exp(x))] - x)\n    dC/dx = d/dx (y * log[sum(exp(x))] - x * y)\n\n    From https://en.wikipedia.org/wiki/LogSumExp\n    d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)\n\n    dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)\n    dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick\n    dC/dx = y * exp[x - logsumexp] - d/dx (x * y)\n\n    If y == 0: dC/dx = 0\n    If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1\n    If y == 1 and x != label: dC/dx     = exp[x - logsumexp]\n    \"\"\"\n    row_idx = tl.program_id(0)\n    block_idx = 
tl.program_id(1)\n\n    logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n    dloss_ptr += row_idx * dloss_row_stride\n    col_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < VOCAB_SIZE\n    label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)\n\n    if label_idx != -100:\n        dloss = tl.load(dloss_ptr)\n    else:\n        dloss = 0.0\n\n    x = tl.load(logits_ptr + col_offsets, mask=mask, other=-float(\"inf\")).to(tl.float32)\n    logsumexp = tl.load(logsumexp_ptr + row_idx)\n    y = tl.exp(x - logsumexp)\n    y = tl.where(\n        col_offsets == label_idx,\n        y - 1.0,  # exp(x - logsumexp) - 1\n        y,  # exp(x - logsumexp)\n    )\n\n    # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.\n    tl.store(logits_ptr + col_offsets, dloss * y, mask=mask)\n\n\npass\n\n\ndef _cross_entropy_forward_impl(logits, labels):\n    n_rows, vocab_size = logits.shape\n\n    div, mod = divmod(vocab_size, MAX_FUSED_SIZE)\n    n_chunks = div + (mod != 0)\n    losses = torch.empty(n_rows, dtype=torch.float32, device=\"cuda\")\n\n    if n_chunks == 1:\n        # For small vocabs <= 65536 like Llama, Mistral\n        BLOCK_SIZE, num_warps = calculate_settings(vocab_size)\n        logsumexp = torch.empty(n_rows, dtype=torch.float32, device=\"cuda\")\n\n        _cross_entropy_forward[(n_rows,)](\n            logits,\n            logits.stride(0),\n            losses,\n            logsumexp,\n            labels,\n            VOCAB_SIZE=vocab_size,\n            BLOCK_SIZE=BLOCK_SIZE,\n            num_warps=num_warps,\n        )\n    else:\n        # For large vocabs > 65536 like Gemma 256K\n        logsumexp = torch.empty(\n            (\n                n_rows,\n                n_chunks,\n            ),\n            dtype=torch.float32,\n            device=\"cuda\",\n        )\n\n        _chunked_cross_entropy_forward[\n            (\n                n_rows,\n                n_chunks,\n            )\n        
](\n            logits,\n            logits.stride(0),\n            losses,\n            logsumexp,\n            labels,\n            VOCAB_SIZE=vocab_size,\n            N_CHUNKS=n_chunks,\n            BLOCK_SIZE=MAX_FUSED_SIZE,\n            num_warps=32,\n        )\n        # logsumexp(chunked_logsumexp) - x\n        # Do the -x separately\n        logsumexp = torch.logsumexp(logsumexp, dim=1)  # Row sum\n        losses += logsumexp\n        losses.masked_fill_(labels == -100, 0)  # Don't forget to mask padding out!\n\n    return losses, logsumexp\n\n\ndef _cross_entropy_backward_impl(dlosses, logits, logsumexp, labels):\n    n_rows, vocab_size = logits.shape\n\n    BLOCK_SIZE = 4096\n    div, mod = divmod(vocab_size, BLOCK_SIZE)\n    n_blocks = div + (mod != 0)\n\n    _cross_entropy_backward[\n        (\n            n_rows,\n            n_blocks,\n        )\n    ](\n        logits,\n        logits.stride(0),\n        dlosses,\n        dlosses.stride(0),\n        logsumexp,\n        labels,\n        VOCAB_SIZE=vocab_size,\n        BLOCK_SIZE=BLOCK_SIZE,\n        num_warps=8,\n    )\n    return logits\n"
  },
  {
    "path": "extensions/thunder/unsloth/kernels/rope_embedding.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom litgpt.constants import _TRITON_AVAILABLE\n\nfrom .utils import calculate_settings\n\nif _TRITON_AVAILABLE:\n    import triton\n    import triton.language as tl\n\nROPE_GROUP_SIZE = 4\n\n\n@triton.heuristics(\n    {\n        \"BACKWARD_PASS\": lambda args: args[\"BACKWARD_PASS\"],\n    }\n)\n@triton.jit\ndef _rope_embedding(\n    Q,\n    Q_row_stride,\n    cos,\n    cos_row_stride,\n    sin,\n    sin_row_stride,\n    seqlen,\n    head_dim: tl.constexpr,\n    n_heads: tl.constexpr,\n    BACKWARD_PASS: tl.constexpr,\n    BLOCK_SIZE: tl.constexpr,\n    ROPE_GROUP_SIZE: tl.constexpr = 4,\n):\n    \"\"\"\n    Calculates the RoPE Embedding quickly\n    RoPE is Q * cos + rotate_half(Q) * sin\n    See our blog post for more info\n    \"\"\"\n    row_position = tl.program_id(0)\n    group_head_position = tl.program_id(1)\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    half_head_dim = head_dim // 2\n    mask = col_offsets < half_head_dim\n\n    sin1 = tl.load(sin + (row_position % seqlen) * sin_row_stride + half_head_dim * 0 + col_offsets, mask=mask, other=0)\n    cos1 = tl.load(cos + (row_position % seqlen) * cos_row_stride + half_head_dim * 0 + col_offsets, mask=mask, other=0)\n\n    if BACKWARD_PASS:\n        # See our blog post for more info.\n        sin1 = -sin1\n    pass\n\n    # [TODO] Autotune ROPE_GROUP_SIZE 
to be 1, 2, 4, 8\n    head_start = group_head_position * ROPE_GROUP_SIZE\n    head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)\n\n    # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)\n    for k in range(head_start, head_end):\n        offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets\n        offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim\n\n        # For Gemma - sometimes RoPE must be done in float32 and not bfloat16\n        Q1 = tl.load(Q + offs_q1, mask=mask, other=0).to(sin1.dtype)\n        Q2 = tl.load(Q + offs_q2, mask=mask, other=0).to(sin1.dtype)\n\n        tl.store(Q + offs_q1, Q1 * cos1 - Q2 * sin1, mask=mask)\n        tl.store(Q + offs_q2, Q2 * cos1 + Q1 * sin1, mask=mask)\n    pass\n\n\npass\n\n\ndef _rope_embedding_forward_impl(Q, cos, sin):\n    Q = Q.transpose(1, 2).clone()\n    cos, sin = cos.squeeze(), sin.squeeze()\n    batch, seq_len, n_heads, head_dim = Q.shape\n    Q = Q.reshape(batch * seq_len, n_heads * head_dim)\n    n_rows, n_cols = Q.shape\n    assert seq_len <= cos.shape[0]\n\n    # [TODO] Changing blocksize to head_dim//2 seems to have\n    # some concurrency / un-deterministic issues.\n    BLOCK_SIZE, num_warps = calculate_settings(head_dim // 2)  # (head_dim//2)\n\n    # group_size = 4 # 4 or 8, too large group_size can hurt performance.\n    div, mod = divmod(n_heads, ROPE_GROUP_SIZE)\n    n_groups = div + (mod != 0)\n\n    _rope_embedding[\n        (\n            n_rows,\n            n_groups,\n        )\n    ](\n        Q,\n        Q.stride(0),\n        cos,\n        cos.stride(0),\n        sin,\n        sin.stride(0),\n        seq_len,\n        head_dim,\n        n_heads,\n        BACKWARD_PASS=False,\n        BLOCK_SIZE=BLOCK_SIZE,\n        num_warps=num_warps,\n    )\n    Q = Q.view(batch, seq_len, n_heads, head_dim)\n    Q = Q.transpose(1, 2)\n    return Q, cos, sin, n_groups, BLOCK_SIZE, num_warps\n\n\ndef 
_rope_embedding_backward_impl(dY, cos, sin, n_groups, BLOCK_SIZE, num_warps):\n    dY = dY.transpose(1, 2)\n    batch, seq_len, n_heads, head_dim = dY.shape\n    dY = dY.reshape(batch * seq_len, n_heads * head_dim)\n    # Must be reshape not view\n    n_rows, n_cols = dY.shape\n\n    _rope_embedding[\n        (\n            n_rows,\n            n_groups,\n        )\n    ](\n        dY,\n        dY.stride(0),\n        cos,\n        cos.stride(0),\n        sin,\n        sin.stride(0),\n        seq_len,\n        head_dim,\n        n_heads,\n        BACKWARD_PASS=True,\n        BLOCK_SIZE=BLOCK_SIZE,\n        num_warps=num_warps,\n    )\n    dY = dY.view(batch, seq_len, n_heads, head_dim)\n    dY = dY.transpose(1, 2)\n    return dY\n"
  },
  {
    "path": "extensions/thunder/unsloth/kernels/swiglu.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\nfrom litgpt.constants import _TRITON_AVAILABLE\n\nif _TRITON_AVAILABLE:\n    import triton\n    import triton.language as tl\n\n\n@triton.jit\ndef _fg_kernel(\n    e,\n    g,\n    h,\n    n_elements,\n    BLOCK_SIZE: tl.constexpr,\n):\n    block_idx = tl.program_id(0)\n    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n\n    e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n    g_row = tl.load(g + offsets, mask=mask, other=0)  # .to(tl.float32)\n\n    # f = e * sigmoid(e)\n    f_row = e_row * tl.sigmoid(e_row)  # e_row / (1 + tl.exp(-e_row))\n    f_row = f_row.to(g_row.dtype)  # Exact copy from HF\n    # h = f * g\n    h_row = f_row * g_row\n\n    # Store h\n    tl.store(h + offsets, h_row, mask=mask)\n\n\npass\n\n\ndef swiglu_fg_kernel(e, g):\n    batch, seq_len, hd = e.shape\n    n_elements = e.numel()\n    h = torch.empty((batch, seq_len, hd), dtype=e.dtype, device=\"cuda\")\n    grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n    _fg_kernel[grid](\n        e,\n        g,\n        h,\n        n_elements,\n        BLOCK_SIZE=1024,\n    )\n    return h\n\n\npass\n\n\n@triton.jit\ndef _DWf_DW_dfg_kernel(\n    DW,\n    e,\n    g,\n    n_elements,\n    BLOCK_SIZE: tl.constexpr,\n):\n    \"\"\"\n    e = e.float()\n    
se = 1.0 / (1.0 + torch.exp(-e))\n    f = (se * e).to(dtype)\n    h = f * g\n    df = DW * f\n    dg = DW * g\n    de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)\n    \"\"\"\n    block_idx = tl.program_id(0)\n    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n\n    DW_row = tl.load(DW + offsets, mask=mask, other=0)  # .to(tl.float32)\n    e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)\n    g_row = tl.load(g + offsets, mask=mask, other=0)  # .to(tl.float32)\n\n    # e = e.float()\n    # se = 1.0 / (1.0 + torch.exp(-e))\n    se_row = tl.sigmoid(e_row)  # 1.0 / (1.0 + tl.exp(-e_row))\n    # f = (se * e).to(dtype)\n    f_row = se_row * e_row\n    f_row = f_row.to(DW_row.dtype)\n    # h = f * g\n    h_row = f_row * g_row\n    # df = DW * f\n    df_row = DW_row * f_row\n    # dg = DW * g\n    dg_row = DW_row * g_row\n    # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)\n    de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))\n    de_row = de_row.to(DW_row.dtype)\n\n    # Store derivatives in buffers\n    tl.store(DW + offsets, h_row, mask=mask)  # h  = f * g\n    tl.store(e + offsets, df_row, mask=mask)  # df = DW * f\n    tl.store(g + offsets, de_row, mask=mask)  # de\n\n\npass\n\n\ndef swiglu_DWf_DW_dfg_kernel(DW, e, g):\n    batch_seq_len, hd = e.shape\n    n_elements = e.numel()\n    grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n    _DWf_DW_dfg_kernel[grid](\n        DW,\n        e,\n        g,\n        n_elements,\n        BLOCK_SIZE=1024,\n    )\n    return DW, e, g\n\n\npass\n"
  },
  {
    "path": "extensions/thunder/unsloth/kernels/utils.py",
    "content": "# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom litgpt.constants import _TRITON_AVAILABLE\n\nif _TRITON_AVAILABLE:\n    import triton\n\nMAX_FUSED_SIZE = 65536  # 2**16\nnext_power_of_2 = triton.next_power_of_2\n\n\ndef calculate_settings(n):\n    BLOCK_SIZE = next_power_of_2(n)\n    if BLOCK_SIZE > MAX_FUSED_SIZE:\n        raise RuntimeError(\n            f\"Cannot launch Triton kernel since n = {n} exceeds the maximum CUDA blocksize = {MAX_FUSED_SIZE}.\"\n        )\n    num_warps = 4\n    if BLOCK_SIZE >= 32768:\n        num_warps = 32\n    elif BLOCK_SIZE >= 8192:\n        num_warps = 16\n    elif BLOCK_SIZE >= 2048:\n        num_warps = 8\n    return BLOCK_SIZE, num_warps\n\n\npass\n"
  },
  {
    "path": "extensions/xla/README.md",
    "content": "# TPU support\n\nThis project utilizes [`Fabric`](https://lightning.ai/docs/fabric/stable), which supports TPUs via [PyTorch XLA](https://github.com/pytorch/xla).\n\n> [!NOTE]\n> This guide assumes that you have already set-up your [Google Cloud environment](https://cloud.google.com/run/docs/setup).\n\nTo set up a Google Cloud instance with a TPU v4 VM, run the following commands:\n\n```shell\ngcloud compute tpus tpu-vm create litgpt --version=tpu-vm-v4-base --accelerator-type=v4-8 --zone=us-central2-b\ngcloud compute tpus tpu-vm ssh litgpt --zone=us-central2-b\n```\n\nYou can also choose a different TPU type. To do so, change the `version`, `accelerator-type`, and `zone` arguments. Find all regions and zones [here](https://cloud.google.com/tpu/docs/regions-zones).\n\n<details>\n<summary>Multihost caveats</summary>\n\nTPU v4-8 uses a single host. SSH'ing into the machine and running commands manually will only work when using a single host (1 slice in the TPU pod).\nIn multi-host environments, such as larger TPU pod slices, it's necessary to launch all commands on all hosts simultaneously to avoid hangs.\nFor local development, it is advisable to upload a zip file containing all your current changes and execute it inside the VM from your personal computer:\n\n```shell\n# Zip the local directory, excluding large directories from the zip. You may want to keep them.\nzip -r local_changes.zip . 
-x  \".git/*\" \"checkpoints/*\" \"data/*\" \"out/*\"\n# Copy the .zip file to the TPU VM\ngcloud compute tpus tpu-vm scp --worker=all local_changes.zip \"litgpt:~\"\n# Unzip on each host\ngcloud compute tpus tpu-vm ssh litgpt --worker=all --command=\"cd ~; unzip -q -o local_changes.zip\"\n\n# Example of a typical workflow\ngcloud compute tpus tpu-vm ssh litgpt --worker=all --command=\"cd ~; bash install_dependencies.sh\"\ngcloud compute tpus tpu-vm ssh litgpt --worker=all --command=\"cd ~; bash prepare_checkpoints.sh\"\ngcloud compute tpus tpu-vm ssh litgpt --worker=all --command=\"cd ~; bash run_desired_script.sh\"\n\n# This will allow you to kill all python processes on all workers\ngcloud compute tpus tpu-vm ssh litgpt --worker=all --command=\"pkill -e python\"\n```\n\nNotice how the commands to install the environment and prepare checkpoints need to be run on all workers, since the filesystem\nfor each worker (host) is not shared.\n\nFor the rest of this tutorial, it will be assumed that it is being run on a single host for simplicity.\n\n</details>\n\nOnce inside the machine, clone the repository and install the dependencies:\n\n```shell\ngit clone https://github.com/Lightning-AI/litgpt\ncd litgpt\npip install .\n```\n\nInstall Optimized BLAS:\n\n```shell\nsudo apt update\nsudo apt install libopenblas-dev\n```\n\nSince LitGPT requires a torch version newer than torch 2.0.0, manually install nightly builds of torch and torch_xla:\n\n```shell\npip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl\npip install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl\n```\n\nWhile computations will run by default using the new PjRT runtime, it is recommended to set the following environment variables:\n\n```shell\nexport ALLOW_MULTIPLE_LIBTPU_LOAD=1\nexport PJRT_DEVICE=TPU\n```\n\n> [!NOTE]\n> An extensive guide on setup and available options can be found 
[here](https://cloud.google.com/tpu/docs/v4-users-guide).\n\nSince a new machine was created, you may need to download pretrained weights.\nThey can be copied to the machine using `gcloud compute tpus tpu-vm scp`, or you can follow the steps described in our [downloading guide](../../tutorials/download_model_weights.md).\n\nIt is also recommended to set up a persistent disk from which to load checkpoints.\nFollow [this guide](https://cloud.google.com/tpu/docs/setup-persistent-disk#setting_up_a_tpu_vm_and_a_persistent_disk) to do so.\nRead-write disks are not supported in multihost VM setups, so persistent disks cannot be used to save checkpoints in that case.\nPersistent disks can still be useful in read-only mode to load pretrained weights before finetuning or inference.\nIn multihost settings, FSDP will save checkpoint shards per host and consolidate them into a single checkpoint.\nFor safekeeping, it is recommended to upload the consolidated checkpoints to a Google Cloud bucket.\nAlternatively, you can use the `scp` command to transfer these checkpoints from the TPU VM periodically, although this is not implemented in our scripts.\n\n## Inference\n\nThis project provides custom versions of the regular recipes to run with XLA in the `xla` directory.\nTo generate text, use the following command:\n\n```shell\npython3 xla/generate/base.py --prompt \"Hello, my name is\" --num_samples 3\n```\n\nFor the first generation, this command will take around 17 seconds as XLA needs to compile the graph.\nSubsequent generations will take around 2 seconds.\n\n## Fine-tuning\n\nTo get started fine-tuning Falcon 7B with adapter, run the following command:\n\n```shell\npython3 xla/scripts/prepare_alpaca.py --checkpoint_dir checkpoints/tiiuae/falcon-7b\n\npython3 xla/finetune/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true\n```\n\n<details>\n<summary>Multihost caveats</summary>\n\nThis script is configured to save \"full\" checkpoints, which isn't 
possible on multihost TPU VMs.\nHere's how you can consolidate them together into a single one after training with `state_dict_type=\"sharded\"`:\n\n```shell\npath_to_shards=\"out/adapter/alpaca/lit_model_adapter_finetuned\"\nmkdir -p $path_to_shards\nworkers=4  # 4 hosts\nfor ((i = 0; i < workers; i++)); do\n  # aggregate all shards locally\n  gcloud compute tpus tpu-vm scp --worker=$i \"litgpt:${path_to_shards}/*\" \"${path_to_shards}/\" --zone us-central2-b\ndone\n# copy all shards to all workers\ngcloud compute tpus tpu-vm scp --worker=all ${path_to_shards}/* \"litgpt:${path_to_shards}/\" --zone us-central2-b\n# consolidate the shards in each worker\ngcloud compute tpus tpu-vm ssh litgpt --worker=all --command=\"python -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts --ckpt_prefix ${path_to_shards}/checkpoint --ckpt_suffix '_rank-*-of-*.pth' --save_path ${path_to_shards}.pth\" --zone us-central2-b\n```\n\n</details>\n\nSince the TPU VM host RAM is limited (200 GB), we implement a technique to sequentially load and shard the checkpoint that can be enabled by\nsetting `reduce_cpu_memory_usage_during_load = True`. 
This is necessary to load falcon-40b.\n\nTo generate text with the adapter fine-tuned model weights, use the following command:\n\n```shell\npython3 xla/generate/adapter.py --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true --adapter_path out/adapter/alpaca/lit_model_adapter_finetuned.pth\n```\n\n> [!WARNING]\n> Remember to delete your instance when you are done.\n>\n> ```shell\n> gcloud compute tpus tpu-vm delete litgpt --zone=us-central2-b\n> ```\n\n## Computational Performance\n\nUsing the [adapter finetuning script](finetune/adapter.py) and XLA's FSDP implementation, a 49.57% MFU was achieved with Falcon 7B on a v4-32 (micro batch size 7), and a 39.67% MFU was achieved with Falcon 40B on a v4-512 (micro batch size 3) at a fixed maximum sequence length of 1034.\n\nSince the TPU VM host has limited system memory (RAM) compared to device memory (HBM), specific techniques were implemented to limit peak RAM usage when loading the model and pretrained weights before sharding, as well as when saving sharded checkpoints.\nA v4 chip has 32 GiB HBM, so a host with 4 devices has 4 * 32 = 128 GiB HBM, while each host has 188 GiB RAM, which is shared across the devices.\nTherefore, any per-device RAM allocation over 188/4 = 47 GiB would exceed the host's RAM capacity.\nA ~24B parameter model on CPU (with half precision) would be the largest possible model under this setup without the techniques used in our scripts.\n"
  },
  {
    "path": "extensions/xla/__init__",
    "content": "import sys\nfrom pathlib import Path\n\n# support running without installing as a package, adding extensions to the Python path\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n"
  },
  {
    "path": "extensions/xla/finetune/__init__",
    "content": ""
  },
  {
    "path": "extensions/xla/finetune/adapter.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Dict, List, Tuple\n\nimport lightning as L\nimport torch\nimport torch_xla.core.xla_model as xm\nfrom lightning.fabric.accelerators import XLAAccelerator\nfrom lightning.fabric.loggers import CSVLogger\nfrom lightning.fabric.strategies import XLAFSDPStrategy\nfrom lightning.fabric.utilities import ThroughputMonitor, measure_flops\n\nfrom litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.utils import check_valid_checkpoint_dir, chunked_cross_entropy, estimate_flops, lazy_load, num_parameters\n\n# support running without installing as a package\nwd = Path(__file__).parents[3].resolve()\nsys.path.append(str(wd))\n\nfrom xla.generate.base import generate  # noqa: E402\nfrom xla.scripts.prepare_alpaca import generate_prompt  # noqa: E402\nfrom xla.utils import rank_print, sequential_load_and_fsdp_wrap  # noqa: E402\n\neval_interval = 200\nsave_interval = 200\neval_iters = 100\neval_max_new_tokens = 100\nlog_interval = 1\ndevices = XLAAccelerator.auto_device_count()\n# the state of very large models will not fit on the system RAM, this flag can alleviate it by loading it on each rank\n# sequentially\nreduce_cpu_memory_usage_during_load = False\n\n# Hyperparameters\nlearning_rate = 3e-3\nbatch_size = 4\nmicro_batch_size = batch_size\ngradient_accumulation_iters = batch_size // micro_batch_size\nassert gradient_accumulation_iters > 0\nepoch_size = 50000  # train dataset size\nnum_epochs = 5\nmax_iters = num_epochs * (epoch_size // micro_batch_size) // devices\nweight_decay = 0.02\nwarmup_steps = 2 * (epoch_size // micro_batch_size) // devices // gradient_accumulation_iters  # 2 epochs\n\nhparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not 
k.startswith(\"_\")}\n\n\ndef setup(\n    *,\n    data_dir: Path = Path(\"data/alpaca\"),\n    checkpoint_dir: Path = Path(\"checkpoints/tiiuae/falcon-7b\"),\n    out_dir: Path = Path(\"out/adapter/alpaca\"),\n    precision: str = \"bf16-true\",\n) -> None:\n    if devices > 1:\n        strategy = XLAFSDPStrategy(\n            auto_wrap_policy={Block},\n            activation_checkpointing_policy={Block},\n            state_dict_type=\"full\",  # change to \"sharded\" in multi-host environments where the filesystem is not shared\n            sequential_save=True,\n        )\n    else:\n        strategy = \"auto\"\n    logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)\n    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger)\n    rank_print(fabric, hparams)\n    fabric.launch(main, data_dir, checkpoint_dir, out_dir)\n\n\ndef main(fabric: L.Fabric, data_dir: Path, checkpoint_dir: Path, out_dir: Path) -> None:\n    check_valid_checkpoint_dir(checkpoint_dir)\n\n    fabric.seed_everything(1337)  # same seed for every process to init model (FSDP)\n\n    if fabric.global_rank == 0:\n        os.makedirs(out_dir, exist_ok=True)\n\n    train_data = torch.load(data_dir / \"train.pt\")\n    val_data = torch.load(data_dir / \"test.pt\")\n\n    config = Config.from_name(name=checkpoint_dir.name, adapter_start_layer=0)\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n    rank_print(fabric, f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\")\n\n    if reduce_cpu_memory_usage_during_load:\n        model = sequential_load_and_fsdp_wrap(fabric, lambda: GPT(config), checkpoint_path)\n    else:\n        with fabric.init_module(empty_init=False):\n            model = GPT(config)\n        checkpoint = lazy_load(checkpoint_path)\n        # strict=False because missing keys due to adapter weights not contained in state dict\n        model.load_state_dict(checkpoint, strict=False)\n\n    
model = fabric.setup_module(model)\n    # mark as trainable only after sharding due to https://github.com/pytorch/xla/pull/5484\n    mark_only_adapter_as_trainable(model)\n    # these are not correct in the sharding case\n    rank_print(fabric, f\"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}\")\n    rank_print(fabric, f\"Number of non-trainable parameters: {num_parameters(model, requires_grad=False):,}\")\n\n    trainable_params = [p for p in model.parameters() if p.requires_grad]\n    optimizer = torch.optim.SGD(trainable_params, lr=learning_rate)\n    optimizer = fabric.setup_optimizers(optimizer)\n\n    fabric.seed_everything(1337 + fabric.global_rank)\n\n    train_time = time.perf_counter()\n    train(fabric, model, optimizer, train_data, val_data, checkpoint_dir, out_dir)\n    rank_print(fabric, f\"Training time: {(time.perf_counter() - train_time):.2f}s\")\n\n    # Save the final checkpoint at the end of training\n    save_path = out_dir / \"lit_model_adapter_finetuned.pth\"\n    save_adapter_checkpoint(fabric, model, save_path)\n\n\ndef train(\n    fabric: L.Fabric,\n    model: GPT,\n    optimizer: torch.optim.Optimizer,\n    train_data: List[Dict],\n    val_data: List[Dict],\n    checkpoint_dir: Path,\n    out_dir: Path,\n) -> None:\n    tokenizer = Tokenizer(checkpoint_dir)\n    longest_seq_length = get_longest_seq_length(train_data)\n    model.max_seq_length = longest_seq_length\n    # to avoid recompilation, this script is configured to pad batches to the `longest_seq_length`\n    fabric.print(\n        f\"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is\"\n        f\" {model.max_seq_length} and context length is {model.config.block_size}\"\n    )\n\n    with torch.device(\"meta\"):\n        meta_model = GPT(model.config)\n        mark_only_adapter_as_trainable(meta_model)\n        # \"estimated\" is not as precise as \"measured\". 
Estimated is optimistic but widely used in the wild.\n        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,\n        # consider passing `flops_per_batch=estimated_flops` instead\n        estimated_flops = estimate_flops(meta_model, training=True) * micro_batch_size\n        rank_print(fabric, f\"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}\")\n        # this assumes that all samples have a fixed length equal to the longest sequence length\n        # which is most likely false during finetuning\n        x = torch.randint(0, 1, (micro_batch_size, longest_seq_length))\n        forward_fn = lambda: meta_model(x)  # noqa: F821\n        loss_fn = lambda y: chunked_cross_entropy(y, x, chunk_size=0)  # noqa: F821\n        measured_flops = measure_flops(meta_model, forward_fn, loss_fn)\n        rank_print(fabric, f\"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}\")\n        del meta_model, x\n\n    throughput = ThroughputMonitor(fabric, window_size=50)\n    step_count = 0\n    total_t0 = time.perf_counter()\n\n    xm.mark_step()\n    for iter_num in range(1, max_iters + 1):\n        if step_count <= warmup_steps:\n            # linear warmup\n            lr = learning_rate * step_count / warmup_steps\n            for param_group in optimizer.param_groups:\n                param_group[\"lr\"] = lr\n\n        iter_t0 = time.perf_counter()\n\n        input_ids, targets = get_batch(fabric, train_data, longest_seq_length)\n\n        is_accumulating = iter_num % gradient_accumulation_iters != 0\n        with fabric.no_backward_sync(model, enabled=is_accumulating):\n            logits = model(input_ids, lm_head_chunk_size=128)\n            xm.mark_step()\n            # shift the targets such that output n predicts token n+1\n            logits[-1] = logits[-1][..., :-1, :]\n            loss = chunked_cross_entropy(logits, targets[..., 1:])\n            fabric.backward(loss / gradient_accumulation_iters)\n 
       xm.mark_step()\n\n        if not is_accumulating:\n            optimizer.step()\n            optimizer.zero_grad()\n            step_count += 1\n        else:\n            xm.mark_step()\n\n        if iter_num % log_interval == 0:\n            t1 = time.perf_counter()\n            throughput.update(\n                time=t1 - total_t0,\n                batches=iter_num,\n                samples=iter_num * micro_batch_size,\n                lengths=iter_num * micro_batch_size * longest_seq_length,\n                flops=measured_flops * log_interval,\n            )\n            throughput.compute_and_log(step=iter_num)\n            rank_print(\n                fabric,\n                f\"iter {iter_num} step {step_count}:\"\n                # uncomment to print the loss. this will considerably slow down the iteration times\n                # + f\" loss {loss.item():.4f},\"\n                + f\" iter time: {(t1 - iter_t0) * 1000:.2f}ms\"\n                + (\" (optimizer.step)\" if not is_accumulating else \"\"),\n            )\n\n        if not is_accumulating and step_count % eval_interval == 0:\n            t0 = time.perf_counter()\n            val_loss = validate(fabric, model, val_data, tokenizer, longest_seq_length)\n            t1 = time.perf_counter() - t0\n            rank_print(fabric, f\"step {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f}ms\")\n            fabric.barrier()\n        if not is_accumulating and step_count % save_interval == 0:\n            checkpoint_path = out_dir / f\"iter-{iter_num:06d}-ckpt.pth\"\n            save_adapter_checkpoint(fabric, model, checkpoint_path)\n\n\n# xla does not support `inference_mode`: RuntimeError: Cannot set version_counter for inference tensor\n@torch.no_grad()\ndef validate(\n    fabric: L.Fabric, model: GPT, val_data: List[Dict], tokenizer: Tokenizer, longest_seq_length: int\n) -> torch.Tensor:\n    rank_print(fabric, \"Validating ...\")\n    model.eval()\n    losses = 
torch.zeros(eval_iters)\n    xm.mark_step()\n    for k in range(eval_iters):\n        input_ids, targets = get_batch(fabric, val_data, longest_seq_length)\n        logits = model(input_ids)\n        xm.mark_step()\n        losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)\n    val_loss = losses.mean()\n\n    # produce an example:\n    instruction = \"Recommend a movie for me to watch during the weekend and explain the reason.\"\n    rank_print(fabric, instruction)\n    sample = {\"instruction\": instruction, \"input\": \"\"}\n    prompt = generate_prompt(sample)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    with fabric.init_tensor():\n        # do not set `max_seq_length=max_returned_token` because memory is not a concern here\n        model.set_kv_cache(batch_size=1)\n    output = generate(model, encoded, max_returned_tokens=len(encoded) + eval_max_new_tokens, temperature=0.8)\n    model.clear_kv_cache()\n    output = tokenizer.decode(output)\n    rank_print(fabric, output)\n\n    model.train()\n    return val_loss\n\n\ndef get_batch(fabric: L.Fabric, data: List[Dict], longest_seq_length: int) -> Tuple[torch.Tensor, torch.Tensor]:\n    ix = torch.randint(len(data), (micro_batch_size,))\n\n    input_ids = [data[i][\"input_ids\"].type(torch.int64) for i in ix]\n    labels = [data[i][\"labels\"].type(torch.int64) for i in ix]\n\n    def pad_right(x, pad_id):\n        # pad right using a fixed longest sequence length to avoid recompilation\n        n = longest_seq_length - len(x)\n        return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))\n\n    x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])\n    y = torch.stack([pad_right(x, pad_id=-1) for x in labels])\n\n    x, y = fabric.to_device((x, y))\n    return x, y\n\n\ndef get_longest_seq_length(data: List[Dict]) -> int:\n    # find out the minimum max_seq_length required during fine-tuning (saves memory!)\n    return 
max(len(d[\"input_ids\"]) for d in data)\n\n\ndef save_adapter_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:\n    rank_print(fabric, f\"Saving adapter weights to {str(file_path)!r}\")\n    fabric.save(file_path, {\"model\": model}, filter={\"model\": adapter_filter})\n\n\nif __name__ == \"__main__\":\n    from jsonargparse import CLI\n\n    CLI(setup)\n"
  },
  {
    "path": "extensions/xla/generate/__init__",
    "content": ""
  },
  {
    "path": "extensions/xla/generate/adapter.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Optional\n\nimport lightning as L\nfrom lightning.fabric.accelerators import XLAAccelerator\nfrom lightning.fabric.strategies import XLAFSDPStrategy\n\nfrom litgpt import Tokenizer\nfrom litgpt.adapter import GPT, Block, Config\nfrom litgpt.prompts import Alpaca\nfrom litgpt.utils import check_valid_checkpoint_dir, lazy_load\n\n# support running without installing as a package\nwd = Path(__file__).parents[3].resolve()\nsys.path.append(str(wd))\n\nfrom xla.generate.base import generate  # noqa: E402\nfrom xla.utils import rank_print  # noqa: E402\n\n\ndef setup(\n    prompt: str = \"What food do llamas eat?\",\n    *,\n    input: str = \"\",\n    sys_prompt: Optional[str] = None,\n    adapter_path: Path = Path(\"out/adapter/alpaca/lit_model_adapter_finetuned.pth\"),\n    checkpoint_dir: Path = Path(\"checkpoints/tiiuae/falcon-7b\"),\n    max_new_tokens: int = 100,\n    top_k: Optional[int] = 50,\n    temperature: float = 0.8,\n    precision: str = \"bf16-true\",\n) -> None:\n    \"\"\"Generates a response based on a given instruction and an optional input.\n    This script will only work with checkpoints from the instruction-tuned Adapter model.\n    See `xla/finetune/adapter.py`.\n\n    Args:\n        prompt: The prompt/instruction (Alpaca style).\n        input: Optional input (Alpaca style).\n        sys_prompt: Optional system prompt.\n        adapter_path: Path to the checkpoint with trained adapter weights, which are the output of\n            `xla/finetune/adapter.py`.\n        checkpoint_dir: The path to the checkpoint folder with pretrained model weights.\n        max_new_tokens: The number of generation steps to take.\n        top_k: The number of top most probable tokens to consider in the sampling process.\n        temperature: A value controlling the randomness of the sampling 
process. Higher values result in more random\n            samples.\n        precision: Indicates the Fabric precision setting to use.\n    \"\"\"\n    devices = XLAAccelerator.auto_device_count()\n    strategy = XLAFSDPStrategy(auto_wrap_policy={Block}) if devices > 1 else \"auto\"\n    fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)\n    fabric.launch(main, prompt, input, sys_prompt, adapter_path, checkpoint_dir, max_new_tokens, top_k, temperature)\n\n\ndef main(\n    fabric: L.Fabric,\n    prompt: str,\n    input: str,\n    sys_prompt: Optional[str],\n    adapter_path: Path,\n    checkpoint_dir: Path,\n    max_new_tokens: int,\n    top_k: Optional[int],\n    temperature: float,\n) -> None:\n    check_valid_checkpoint_dir(checkpoint_dir)\n\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\", adapter_start_layer=0)\n\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n\n    rank_print(fabric, f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n    t0 = time.perf_counter()\n    with fabric.init_module(empty_init=True):\n        model = GPT(config)\n    rank_print(fabric, f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    t0 = time.perf_counter()\n    checkpoint = lazy_load(checkpoint_path)\n    adapter_checkpoint = lazy_load(adapter_path)\n    checkpoint.update(adapter_checkpoint.get(\"model\", adapter_checkpoint))\n    model.load_state_dict(checkpoint)\n    rank_print(fabric, f\"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    model.eval()\n    model = fabric.setup_module(model)\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    # TODO: Load prompt style from checkpoint and apply it here\n    prompt_style = Alpaca()\n    prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt, input=input)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    prompt_length = 
encoded.size(0)\n    max_returned_tokens = prompt_length + max_new_tokens\n\n    with fabric.init_tensor():\n        # set the max_seq_length to limit the memory usage to what we need\n        model.max_seq_length = max_returned_tokens\n        # enable the kv cache\n        model.set_kv_cache(batch_size=1)\n\n    t0 = time.perf_counter()\n    y = generate(\n        model,\n        encoded,\n        max_returned_tokens,\n        max_seq_length=max_returned_tokens,\n        temperature=temperature,\n        top_k=top_k,\n        eos_id=tokenizer.eos_id,\n    )\n    t = time.perf_counter() - t0\n\n    output = tokenizer.decode(y)\n    output = output.split(\"### Response:\")[1] if \"### Response:\" in output else output\n    output = output.strip()\n    fabric.print(output)\n\n    tokens_generated = y.size(0) - prompt_length\n    rank_print(\n        fabric, f\"\\n\\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n    )\n\n\nif __name__ == \"__main__\":\n    from jsonargparse import CLI\n\n    CLI(setup)\n"
  },
  {
    "path": "extensions/xla/generate/base.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Optional\n\nimport lightning as L\nimport torch\nimport torch_xla.core.xla_model as xm\nfrom lightning.fabric.accelerators import XLAAccelerator\nfrom lightning.fabric.strategies import XLAFSDPStrategy\n\nfrom litgpt import GPT, Config, Tokenizer\nfrom litgpt.model import Block\nfrom litgpt.utils import check_valid_checkpoint_dir, lazy_load\n\n# support running without installing as a package\nwd = Path(__file__).parents[3].resolve()\nsys.path.append(str(wd))\n\nfrom xla.utils import rank_print  # noqa: E402\n\n\n# xla does not support `inference_mode`: RuntimeError: Cannot set version_counter for inference tensor\n@torch.no_grad()\ndef generate(\n    model: GPT,\n    idx: torch.Tensor,\n    max_returned_tokens: int,\n    *,\n    temperature: float = 1.0,\n    top_k: Optional[int] = None,\n    eos_id: Optional[int] = None,\n) -> torch.Tensor:\n    \"\"\"Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.\n\n    The implementation of this function is modified from A. Karpathy's nanoGPT.\n\n    Args:\n        model: The model to use.\n        idx: Tensor of shape (T) with indices of the prompt sequence.\n        max_returned_tokens: The maximum number of tokens to return (given plus generated).\n        temperature: Scales the predicted logits by 1 / temperature.\n        top_k: If specified, only sample among the tokens with the k highest probabilities.\n        eos_id: If specified, stop generating any more token once the <eos> token is triggered.\n    \"\"\"\n    T = idx.size(0)\n    assert max_returned_tokens > T\n    if model.max_seq_length < max_returned_tokens - 1:\n        # rolling the kv cache based on the `input_pos` value would be necessary. 
However, doing so would introduce a\n        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do\n        # not support it to avoid negatively impacting the overall speed\n        raise NotImplementedError(f\"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}\")\n\n    device, dtype = idx.device, idx.dtype\n    # create an empty tensor of the expected final shape and fill in the current tokens\n    empty = torch.empty(max_returned_tokens, dtype=dtype, device=device)\n    empty[:T] = idx\n    idx = empty\n    # TODO: FSDP has an internal broadcasting issue, so we are forced to have this be of length 1 until it's fixed\n    input_pos = torch.tensor([0], device=device)\n\n    xm.mark_step()\n\n    # generate up to a fixed number of tokens\n    for _ in range(max_returned_tokens):\n        x = idx.index_select(0, input_pos).view(1, -1)\n\n        # forward\n        logits = model(x, input_pos)\n        logits = logits[0, -1] / temperature\n\n        # optionally crop the logits to only the top k options\n        if top_k is not None:\n            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n            logits = torch.where(logits < v[[-1]], -float(\"Inf\"), logits)\n\n        probs = torch.nn.functional.softmax(logits, dim=-1)\n        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)\n\n        # advance\n        input_pos = input_pos[-1:] + 1\n\n        xm.mark_step()\n\n        # concatenate the new generation\n        idx = idx.index_copy(0, input_pos, idx_next)\n\n        # if <eos> token is triggered, return the output (stop generation)\n        if idx_next == eos_id:\n            return idx[:input_pos]  # include the EOS token\n\n    return idx\n\n\ndef setup(\n    prompt: str = \"What food do llamas eat?\",\n    *,\n    num_samples: int = 1,\n    max_new_tokens: int = 100,\n    top_k: Optional[int] = 50,\n    temperature: float = 0.8,\n    
checkpoint_dir: Path = Path(\"checkpoints/tiiuae/falcon-7b\"),\n    precision: str = \"bf16-true\",\n) -> None:\n    \"\"\"Generates text samples based on a pre-trained model and tokenizer.\n\n    Args:\n        prompt: The prompt string to use for generating the samples.\n        num_samples: The number of text samples to generate.\n        max_new_tokens: The number of generation steps to take.\n        top_k: The number of top most probable tokens to consider in the sampling process.\n        temperature: A value controlling the randomness of the sampling process. Higher values result in more random\n            samples.\n        checkpoint_dir: The checkpoint directory to load.\n        precision: Indicates the Fabric precision setting to use.\n    \"\"\"\n    devices = XLAAccelerator.auto_device_count()\n    strategy = XLAFSDPStrategy(auto_wrap_policy={Block}) if devices > 1 else \"auto\"\n    fabric = L.Fabric(devices=devices, precision=precision, strategy=strategy)\n    fabric.launch(main, prompt, num_samples, max_new_tokens, top_k, temperature, checkpoint_dir)\n\n\ndef main(\n    fabric: L.Fabric,\n    prompt: str,\n    num_samples: int,\n    max_new_tokens: int,\n    top_k: Optional[int],\n    temperature: float,\n    checkpoint_dir: Path,\n) -> None:\n    check_valid_checkpoint_dir(checkpoint_dir)\n\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n\n    rank_print(fabric, f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n    t0 = time.perf_counter()\n    with fabric.init_module(empty_init=True):\n        model = GPT(config)\n    rank_print(fabric, f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    t0 = time.perf_counter()\n    checkpoint = lazy_load(checkpoint_path)\n    model.load_state_dict(checkpoint.get(\"model\", checkpoint))\n    rank_print(fabric, f\"Time to load the model weights: 
{time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    model.eval()\n    model = fabric.setup_module(model)\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    prompt_length = encoded.size(0)\n    max_returned_tokens = prompt_length + max_new_tokens\n\n    with fabric.init_tensor():\n        # set the max_seq_length to limit the memory usage to what we need\n        model.max_seq_length = max_returned_tokens\n\n    L.seed_everything(1234)\n    for i in range(num_samples):\n        with fabric.init_tensor():\n            # enable the kv cache\n            model.set_kv_cache(batch_size=1)\n\n        t0 = time.perf_counter()\n        y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)\n        t = time.perf_counter() - t0\n\n        fabric.print(tokenizer.decode(y))\n        tokens_generated = y.size(0) - prompt_length\n        rank_print(\n            fabric,\n            f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\",\n            file=sys.stderr,\n        )\n\n\nif __name__ == \"__main__\":\n    from jsonargparse import CLI\n\n    CLI(setup)\n"
  },
  {
    "path": "extensions/xla/scripts/__init__",
    "content": ""
  },
  {
    "path": "extensions/xla/scripts/prepare_alpaca.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\n\"\"\"Implementation derived from https://github.com/tloen/alpaca-lora\"\"\"\n\nimport json\nfrom pathlib import Path\nfrom typing import Optional\n\nimport torch\nimport yaml\nfrom lightning_utilities.core.imports import RequirementCache\nfrom torch.utils.data import random_split\nfrom tqdm import tqdm\n\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.utils import CLI\n\n\ndef prepare(\n    destination_path: Path = Path(\"data/alpaca\"),\n    checkpoint_dir: Path = Path(\"checkpoints/stabilityai/stablelm-base-alpha-3b\"),\n    val_split_fraction: float = 0.03865,  # to get exactly 2000 validation samples,\n    seed: int = 42,\n    mask_inputs: bool = False,  # as in alpaca-lora\n    data_file_name: str = \"alpaca_data_cleaned_archive.json\",\n    data_file_url: str = \"https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json\",\n    ignore_index: int = -100,\n    max_seq_length: Optional[int] = None,\n) -> None:\n    \"\"\"Prepare the Alpaca dataset for instruction tuning.\n\n    The output is a training and test dataset saved as `train.pt` and `test.pt`,\n    which stores the preprocessed and tokenized prompts and labels.\n    \"\"\"\n    if max_seq_length is None:\n        with open(checkpoint_dir / \"model_config.yaml\", encoding=\"utf-8\") as file:\n            config = yaml.safe_load(file)\n            max_seq_length = config[\"block_size\"]\n\n    destination_path.mkdir(parents=True, exist_ok=True)\n    data_file_path = destination_path / data_file_name\n    print(\"Loading data file...\")\n    download_if_missing(data_file_path, data_file_url)\n    with open(data_file_path, encoding=\"utf-8\") as file:\n        data = json.load(file)\n\n    print(\"Loading tokenizer...\")\n    tokenizer = Tokenizer(checkpoint_dir)\n\n    # Partition the dataset into train and test\n    train_set, test_set = random_split(\n        data, 
[1.0 - val_split_fraction, val_split_fraction], generator=torch.Generator().manual_seed(seed)\n    )\n    train_set, test_set = list(train_set), list(test_set)\n\n    print(f\"train has {len(train_set):,} samples\")\n    print(f\"test has {len(test_set):,} samples\")\n\n    print(\"Processing train split ...\")\n    train_set = [\n        prepare_sample(\n            example=sample,\n            tokenizer=tokenizer,\n            max_length=max_seq_length,\n            mask_inputs=mask_inputs,\n            ignore_index=ignore_index,\n        )\n        for sample in tqdm(train_set)\n    ]\n    torch.save(train_set, destination_path / \"train.pt\")\n\n    print(\"Processing test split ...\")\n    test_set = [\n        prepare_sample(\n            example=sample,\n            tokenizer=tokenizer,\n            max_length=max_seq_length,\n            mask_inputs=mask_inputs,\n            ignore_index=ignore_index,\n        )\n        for sample in tqdm(test_set)\n    ]\n    torch.save(test_set, destination_path / \"test.pt\")\n\n\ndef download_if_missing(file_path: Path, file_url: str) -> None:\n    \"\"\"Downloads the raw json data file and saves it in the given destination.\"\"\"\n    if file_path.exists() and file_path.stat().st_size > 0:\n        return\n    requests_available = RequirementCache(\"requests\")\n    if not requests_available:\n        raise ModuleNotFoundError(str(requests_available))\n    import requests\n\n    with open(file_path, \"w\", encoding=\"utf-8\") as f:\n        f.write(requests.get(file_url).text)\n\n\ndef prepare_sample(example: dict, tokenizer: Tokenizer, max_length: int, mask_inputs: bool, ignore_index: int) -> dict:\n    \"\"\"Processes a single sample.\n\n    Each sample in the dataset consists of:\n    - instruction: A string describing the task\n    - input: A string holding a special input value for the instruction.\n        This only applies to some samples, and in others this is empty.\n    - output: The response string\n\n    
This function processes this data to produce a prompt text and a label for\n    supervised training. The prompt text is formed as a single message including both\n    the instruction and the input. The label/target is the same message but with the\n    response attached.\n\n    Finally, both the prompt and the label get tokenized. If desired, all tokens\n    in the label that correspond to the original input prompt get masked out (only when `mask_inputs` is true).\n    \"\"\"\n    full_prompt = generate_prompt(example)\n    full_prompt_and_response = full_prompt + example[\"output\"]\n    encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length)\n    encoded_full_prompt_and_response = tokenizer.encode(full_prompt_and_response, eos=True, max_length=max_length)\n\n    # The labels are the full prompt with response, but with the prompt masked out\n    labels = encoded_full_prompt_and_response.clone()\n    if mask_inputs:\n        labels[: len(encoded_full_prompt)] = ignore_index\n\n    return {**example, \"input_ids\": encoded_full_prompt_and_response, \"labels\": labels}\n\n\ndef generate_prompt(example: dict) -> str:\n    \"\"\"Generates a standardized message to prompt the model with an instruction, optional input and a\n    'response' field.\"\"\"\n\n    if example[\"input\"]:\n        return (\n            \"Below is an instruction that describes a task, paired with an input that provides further context. \"\n            \"Write a response that appropriately completes the request.\\n\\n\"\n            f\"### Instruction:\\n{example['instruction']}\\n\\n### Input:\\n{example['input']}\\n\\n### Response:\"\n        )\n    return (\n        \"Below is an instruction that describes a task. \"\n        \"Write a response that appropriately completes the request.\\n\\n\"\n        f\"### Instruction:\\n{example['instruction']}\\n\\n### Response:\"\n    )\n\n\nif __name__ == \"__main__\":\n    CLI(prepare)\n"
  },
  {
    "path": "extensions/xla/utils.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport itertools\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Any, Callable\n\nimport lightning as L\nimport torch\nfrom lightning.fabric.strategies.xla_fsdp import XLAFSDPStrategy, _activation_checkpointing_auto_wrapper\nfrom lightning_utilities.core.rank_zero import rank_prefixed_message\n\nfrom litgpt import GPT\n\n\ndef rank_print(fabric: L.Fabric, message: object, *, flush: bool = True, **kwargs: Any) -> None:\n    if fabric.local_rank == 0:\n        message = str(message)\n        # let each host print, but only on rank 0\n        message = rank_prefixed_message(message, fabric.global_rank)\n        # TPU VM will only print when the script finishes if `flush=False`\n        print(message, flush=flush, **kwargs)\n\n\ndef materialize_parameters(module: torch.nn.Module, device: torch.device) -> None:\n    for module_name, module in module.named_modules():\n        if any(\n            param.is_meta for param in itertools.chain(module.parameters(recurse=False), module.buffers(recurse=False))\n        ):\n            module.to_empty(device=device, recurse=False)\n            module.reset_parameters()\n\n\ndef sequential_load_and_fsdp_wrap(\n    fabric: L.Fabric, get_model: Callable[[], GPT], checkpoint_path: Path\n) -> torch.nn.Module:\n    assert fabric._launched\n    # similar logic could be implemented for regular FSDP, but this implementation is specific to XLAFSDP\n    assert isinstance(fabric.strategy, XLAFSDPStrategy)\n\n    with fabric.init_module(empty_init=False), torch.device(\"meta\"):\n        model = get_model()\n\n    # TODO: this could be made faster by broadcasting in separate process groups for each host\n    if fabric.local_rank == 0:\n        # load the full checkpoint on a single rank to limit the system memory usage\n        state_dict = torch.load(checkpoint_path, map_location=\"cpu\", mmap=False)  # mmap=True 
hangs\n    else:\n        # XLA cannot broadcast different number of tensors or different shapes in each rank. To get around this\n        # limitation, we need to load the checkpoint on meta device to get the correct number of tensors and materialize\n        # them as necessary\n        state_dict = torch.load(checkpoint_path, map_location=\"meta\", mmap=False)\n\n    fsdp_kwargs = fabric.strategy._parse_fsdp_kwargs()\n    if \"auto_wrapper_callable\" in fsdp_kwargs:\n        # includes activation checkpointing if configured\n        wrap = fsdp_kwargs.pop(\"auto_wrapper_callable\")\n    else:\n        wrap = partial(_activation_checkpointing_auto_wrapper, set())\n    fsdp_kwargs.pop(\"auto_wrap_policy\", None)  # this needs to be removed or else root wrapping would error\n\n    for i, block in enumerate(model.transformer.h):\n        rank_print(fabric, f\"Broadcasting transformer block {i}\")\n        # get the relevant piece of the state dict\n        to_load = {}\n        for param_name, _ in block.named_parameters():\n            if (key := f\"transformer.h.{i}.{param_name}\") not in state_dict:\n                continue\n            param = state_dict.pop(key)\n            if not param.is_meta:\n                to_load[param_name] = param\n            else:\n                # materialize this parameter for broadcast to work\n                to_load[param_name] = torch.empty_like(param, device=\"cpu\")\n\n        to_load = fabric.broadcast(to_load)\n\n        rank_print(fabric, f\"Loading transformer block {i}\")\n        keys = block.load_state_dict(to_load, strict=False, assign=True)\n        assert not keys.unexpected_keys\n\n        # materialize any leftover meta parameters, regular FSDP does it automatically\n        materialize_parameters(block, torch.device(\"cpu\"))  # init on CPU, FSDP will shard and move it\n\n        # XLA FSDP only supports fp32 parameters. 
If the checkpoint had a different dtype, this needs to be converted\n        # since we are loading with assign=True\n        block = block.to(torch.float32)\n\n        # shard the block\n        rank_print(fabric, f\"Wrapping transformer block {i}\")\n        wrapped_block = wrap(block, **fsdp_kwargs)\n        model.transformer.h[i] = wrapped_block\n\n    # load the rest of the state_dict, this assumes that all keys need to be loaded\n    # an alternative technique would be to load the rest of the state dict at once, but we want to materialize\n    # and move the params to the xla device to reduce the system memory usage\n    for key in list(state_dict):\n        rank_print(fabric, f\"Loading {key}\")\n        param = state_dict.pop(key)\n        if param.is_meta:\n            # materialize this parameter for broadcast to work\n            param = torch.empty_like(param, device=\"cpu\")\n        param = fabric.broadcast(param)\n        param = param.to(device=fabric.device, dtype=torch.float32)\n        keys = model.load_state_dict({key: param}, strict=False, assign=True)\n        assert not keys.unexpected_keys\n    assert not state_dict\n\n    # materialize any leftover meta parameters, regular FSDP does it automatically\n    rank_print(fabric, \"Materializing leftover parameters\")\n    materialize_parameters(model, fabric.device)\n\n    return model\n"
  },
  {
    "path": "litgpt/__init__.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport logging\nimport re\n\nfrom litgpt.api import LLM\nfrom litgpt.config import Config\nfrom litgpt.model import GPT  # needs to be imported before config\nfrom litgpt.prompts import PromptStyle\nfrom litgpt.tokenizer import Tokenizer\n\n# Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632\npattern = re.compile(\".*Profiler function .* will be ignored\")\nlogging.getLogger(\"torch._dynamo.variables.torch\").addFilter(lambda record: not pattern.search(record.getMessage()))\n\n# Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint\nlogging.getLogger(\"torch.distributed.fsdp._optim_utils\").disabled = True\nlogging.getLogger(\"torch.distributed.fsdp._debug_utils\").disabled = True\n\n__all__ = [\"LLM\", \"GPT\", \"Config\", \"PromptStyle\", \"Tokenizer\"]\n"
  },
  {
    "path": "litgpt/__main__.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport warnings\n\nimport torch\nfrom jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options\n\nfrom litgpt.chat.base import main as chat_fn\nfrom litgpt.deploy.serve import run_server as serve_fn\nfrom litgpt.eval.evaluate import convert_and_evaluate as evaluate_fn\nfrom litgpt.finetune.adapter import setup as finetune_adapter_fn\nfrom litgpt.finetune.adapter_v2 import setup as finetune_adapter_v2_fn\nfrom litgpt.finetune.full import setup as finetune_full_fn\nfrom litgpt.finetune.lora import setup as finetune_lora_fn\nfrom litgpt.generate.adapter import main as generate_adapter_fn\nfrom litgpt.generate.adapter_v2 import main as generate_adapter_v2_fn\nfrom litgpt.generate.base import main as generate_base_fn\nfrom litgpt.generate.full import main as generate_full_fn\nfrom litgpt.generate.sequentially import main as generate_sequentially_fn\nfrom litgpt.generate.speculative_decoding import main as generate_speculatively_fn\nfrom litgpt.generate.tp import main as generate_tp_fn\nfrom litgpt.parser_config import parser_commands\nfrom litgpt.pretrain import setup as pretrain_fn\nfrom litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint as convert_hf_checkpoint_fn\nfrom litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint as convert_lit_checkpoint_fn\nfrom litgpt.scripts.convert_pretrained_checkpoint import (\n    convert_pretrained_checkpoint as convert_pretrained_checkpoint_fn,\n)\nfrom litgpt.scripts.download import download_from_hub as download_fn\nfrom litgpt.scripts.merge_lora import merge_lora as merge_lora_fn\n\nPARSER_DATA = {\n    \"download\": download_fn,\n    \"chat\": chat_fn,\n    \"finetune\": finetune_lora_fn,\n    \"finetune_lora\": finetune_lora_fn,\n    \"finetune_full\": finetune_full_fn,\n    \"finetune_adapter\": finetune_adapter_fn,\n    \"finetune_adapter_v2\": finetune_adapter_v2_fn,\n    
\"pretrain\": pretrain_fn,\n    \"generate\": generate_base_fn,\n    \"generate_full\": generate_full_fn,\n    \"generate_adapter\": generate_adapter_fn,\n    \"generate_adapter_v2\": generate_adapter_v2_fn,\n    \"generate_sequentially\": generate_sequentially_fn,\n    \"generate_speculatively\": generate_speculatively_fn,\n    \"generate_tp\": generate_tp_fn,\n    \"convert_to_litgpt\": convert_hf_checkpoint_fn,\n    \"convert_from_litgpt\": convert_lit_checkpoint_fn,\n    \"convert_pretrained_checkpoint\": convert_pretrained_checkpoint_fn,\n    \"merge_lora\": merge_lora_fn,\n    \"evaluate\": evaluate_fn,\n    \"serve\": serve_fn,\n}\n\n\ndef _check_commands():\n    assert set(parser_commands()) == set(PARSER_DATA.keys()), (\n        \"PARSER_DATA has to be kept in sync with litgpt.parser_config.parser_commands()\"\n    )\n\n\ndef main() -> None:\n    _check_commands()\n\n    set_docstring_parse_options(attribute_docstrings=True)\n    set_config_read_mode(urls_enabled=True)\n\n    # PyTorch bug that raises a false-positive warning\n    # More info: https://github.com/Lightning-AI/litgpt/issues/1561\n    warning_message = r\"The epoch parameter in `scheduler.step\\(\\)` was not necessary and is being deprecated.*\"\n\n    warnings.filterwarnings(\n        action=\"ignore\", message=warning_message, category=UserWarning, module=r\".*torch\\.optim\\.lr_scheduler.*\"\n    )\n\n    torch.set_float32_matmul_precision(\"high\")\n    CLI(PARSER_DATA)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "litgpt/adapter.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\n\"\"\"Implementation of the paper:\n\nLLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention\nhttps://arxiv.org/abs/2303.16199\n\nPort for LitGPT\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, Optional, Tuple\n\nimport torch\nimport torch.nn as nn\nfrom typing_extensions import Self\n\nfrom litgpt.config import Config as BaseConfig\nfrom litgpt.model import GPT as BaseModel\nfrom litgpt.model import Block as BaseBlock\nfrom litgpt.model import CausalSelfAttention as BaseCausalSelfAttention\n\n\n@dataclass\nclass Config(BaseConfig):\n    adapter_prompt_length: int = 10\n    adapter_start_layer: int = 2\n\n\nclass GPT(BaseModel):\n    # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        assert config.padded_vocab_size is not None\n        self.config = config\n\n        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)\n        self.transformer = nn.ModuleDict(\n            dict(\n                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),\n                h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),\n                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),\n            )\n        )\n        self.mask_cache: Optional[torch.Tensor] = None\n        self.max_seq_length = self.config.block_size\n\n    @classmethod\n    def from_name(cls, name: str, **kwargs: Any) -> Self:\n        return cls(Config.from_name(name, **kwargs))\n\n    def _init_weights(self, module: nn.Module) -> None:\n        \"\"\"Meant to be used with `gpt.apply(gpt._init_weights)`. 
Unused method left for completeness.\"\"\"\n        super()._init_weights(module)\n        if isinstance(module, CausalSelfAttention):\n            module.reset_parameters()\n\n\nclass Block(BaseBlock):\n    def __init__(self, config: Config, block_idx: int) -> None:\n        super().__init__(config, block_idx)\n        self.attn = CausalSelfAttention(config, block_idx)\n\n\nclass CausalSelfAttention(BaseCausalSelfAttention):\n    \"\"\"A modification of `litgpt.model.CausalSelfAttention` that adds the attention\n    over the adaption prompt.\"\"\"\n\n    def __init__(self, config: Config, block_idx: int) -> None:\n        super().__init__(config, block_idx)\n        if block_idx >= config.adapter_start_layer:\n            # adapter embedding layer\n            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)\n            # gate for adaption\n            self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))\n            # kv cache for inference\n            self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None\n\n    def scaled_dot_product_attention(\n        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None\n    ) -> torch.Tensor:\n        y = super().scaled_dot_product_attention(q, k, v, mask)\n        if self.block_idx < self.config.adapter_start_layer:\n            return y\n\n        aT = self.config.adapter_prompt_length\n        if self.adapter_kv_cache is not None:\n            # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av\n            # are the same every call\n            ak, av = self.adapter_kv_cache\n        else:\n            prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)\n            aqkv = self.qkv(prefix)\n            q_per_kv = self.config.n_head // self.config.n_query_groups\n            aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, 
self.config.head_size)\n            aqkv = aqkv.permute(0, 2, 3, 1, 4)\n            _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)\n            if self.config.n_query_groups != 1:\n                # for MHA this is a no-op\n                ak = ak.repeat_interleave(q_per_kv, dim=2)\n                av = av.repeat_interleave(q_per_kv, dim=2)\n            ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)\n            av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)\n            self.adapter_kv_cache = (ak, av)\n\n        T = q.size(2)\n        amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)\n        ay = super().scaled_dot_product_attention(q, ak, av, amask)\n        return y + self.gating_factor * ay\n\n    def reset_parameters(self) -> None:\n        if hasattr(self, \"gating_factor\"):\n            torch.nn.init.zeros_(self.gating_factor)\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with older checkpoints.\"\"\"\n        if (key := prefix + \"gating_factor\") in state_dict and state_dict[key].size(1) == self.config.n_head:\n            state_dict[key] = state_dict[key].permute(0, 2, 1, 3)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\ndef mark_only_adapter_as_trainable(model: GPT) -> None:\n    \"\"\"Sets `requires_grad=False` for all non-adapter weights.\"\"\"\n    for name, param in model.named_parameters():\n        param.requires_grad = adapter_filter(name, param)\n\n\ndef adapter_filter(key: str, value: Any) -> bool:\n    return \"adapter_wte\" in key or \"gating_factor\" in key\n"
  },
  {
    "path": "litgpt/adapter_v2.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\n\"\"\"Implementation of the paper:\n\nLLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model\nhttps://arxiv.org/abs/2304.15010\n\nPort for LitGPT\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, Optional, Type\n\nimport torch\nimport torch.nn as nn\nfrom typing_extensions import Self\n\nimport litgpt\nfrom litgpt.adapter import GPT as BaseModel\nfrom litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention\nfrom litgpt.adapter import Config as BaseConfig\nfrom litgpt.model import Block as BaseBlock\nfrom litgpt.scripts.convert_hf_checkpoint import qkv_reassemble\nfrom litgpt.utils import map_old_state_dict_weights\n\n\n@dataclass\nclass Config(BaseConfig):\n    @property\n    def mlp_class(self) -> Type:\n        return getattr(litgpt.adapter_v2, self.mlp_class_name)\n\n\ndef adapter_filter(key: str, value: Any) -> bool:\n    adapter_substrings = (\n        # regular adapter v1 parameters\n        \"adapter_wte\",\n        \"gating_factor\",\n        # adapter v2: new bias and scale used in Linear\n        \"adapter_scale\",\n        \"adapter_bias\",\n        # adapter v2: Norm parameters are now trainable\n        \"norm_1\",\n        \"norm_2\",\n        \"ln_f\",\n    )\n    return any(s in key for s in adapter_substrings)\n\n\nclass AdapterV2Linear(torch.nn.Module):\n    def __init__(self, in_features: int, out_features: int, **kwargs) -> None:\n        super().__init__()\n        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)\n        self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False)\n        self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.adapter_scale * (self.linear(x) + self.adapter_bias)\n\n    def reset_parameters(self) -> None:\n      
  nn.init.zeros_(self.adapter_bias)\n        nn.init.ones_(self.adapter_scale)\n\n\nclass GPT(BaseModel):\n    # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        assert config.padded_vocab_size is not None\n        self.config = config\n\n        self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)\n        self.transformer = nn.ModuleDict(\n            dict(\n                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),\n                h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),\n                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),\n            )\n        )\n        self.mask_cache: Optional[torch.Tensor] = None\n        self.max_seq_length = self.config.block_size\n\n    @classmethod\n    def from_name(cls, name: str, **kwargs: Any) -> Self:\n        return cls(Config.from_name(name, **kwargs))\n\n    def _init_weights(self, module: nn.Module) -> None:\n        \"\"\"Meant to be used with `gpt.apply(gpt._init_weights)`. 
Unused method left for completeness.\"\"\"\n        super()._init_weights(module)\n        if isinstance(module, AdapterV2Linear):\n            module.reset_parameters()\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\"lm_head.weight\": \"lm_head.linear.weight\", \"lm_head.bias\": \"lm_head.linear.bias\"}\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass Block(BaseBlock):\n    def __init__(self, config: Config, block_idx: int) -> None:\n        super().__init__(config, block_idx)\n        self.attn = CausalSelfAttention(config, block_idx)\n        self.mlp = config.mlp_class(config)\n\n\nclass CausalSelfAttention(BaseCausalSelfAttention):\n    \"\"\"A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class\"\"\"\n\n    # Copy&paste from :class:`model.CausalSelfAttention`\n    def __init__(self, config: Config, block_idx: int) -> None:\n        super().__init__(config, block_idx)\n        # key, query, value projections for all heads, but in a batch\n        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size\n        self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)\n        # output projection\n        self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base and/or legacy checkpoints.\"\"\"\n        mapping = {\n            \"qkv.weight\": \"qkv.linear.weight\",\n            \"qkv.bias\": \"qkv.linear.bias\",\n            \"proj.weight\": \"proj.linear.weight\",\n            \"proj.bias\": \"proj.linear.bias\",\n    
    }\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        # For compatibility with older checkpoints\n        if (key := prefix + \"gating_factor\") in state_dict and state_dict[key].size(1) == self.config.n_head:\n            state_dict[key] = state_dict[key].permute(0, 2, 1, 3)\n\n        for attr in (\"weight\", \"bias\"):\n            legacy_key = f\"{prefix}attn.linear.{attr}\"\n            current_key = f\"{prefix}qkv.linear.{attr}\"\n            if legacy_key in state_dict:\n                state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config)\n\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass GptNeoxMLP(litgpt.model.GptNeoxMLP):\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias)\n        self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias)\n        self.config = config\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\n            \"fc.weight\": \"fc.linear.weight\",\n            \"fc.bias\": \"fc.linear.bias\",\n            \"proj.weight\": \"proj.linear.weight\",\n            \"proj.bias\": \"proj.linear.bias\",\n        }\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass LLaMAMLP(litgpt.model.LLaMAMLP):\n    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:\n        nn.Module.__init__(self)\n        self.intermediate_size = intermediate_size or config.intermediate_size\n        self.fc_1 = AdapterV2Linear(config.n_embd, self.intermediate_size, bias=config.bias)\n        self.fc_2 = AdapterV2Linear(config.n_embd, 
self.intermediate_size, bias=config.bias)\n        self.proj = AdapterV2Linear(self.intermediate_size, config.n_embd, bias=config.bias)\n        self.config = config\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\n            \"fc_1.weight\": \"fc_1.linear.weight\",\n            \"fc_1.bias\": \"fc_1.linear.bias\",\n            \"fc_2.weight\": \"fc_2.linear.weight\",\n            \"fc_2.bias\": \"fc_2.linear.bias\",\n            \"proj.weight\": \"proj.linear.weight\",\n            \"proj.bias\": \"proj.linear.bias\",\n        }\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass GemmaMLP(LLaMAMLP):\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x_fc_1 = self.fc_1(x)\n        x_fc_2 = self.fc_2(x)\n        x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2\n        return self.proj(x)\n\n\nclass LLaMAMoE(litgpt.model.LLaMAMoE):\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        self.gate = AdapterV2Linear(config.n_embd, config.n_expert, bias=False)\n        self.experts = nn.ModuleList(\n            LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)\n        )\n        self.config = config\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\"gate.weight\": \"gate.linear.weight\"}\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\ndef mark_only_adapter_v2_as_trainable(model: GPT) -> None:\n    \"\"\"Sets requires_grad=False for all 
non-adapter weights\"\"\"\n    for name, param in model.named_parameters():\n        param.requires_grad = adapter_filter(name, param)\n"
  },
  {
    "path": "litgpt/api.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n#\n# This file implements the LitGPT Python API\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Any, Callable, List, Literal, Optional, Tuple, Union\n\nimport lightning as L\nimport numpy as np\nimport torch\nfrom lightning.fabric.accelerators import CUDAAccelerator\nfrom lightning.fabric.plugins import BitsandbytesPrecision\nfrom tqdm import tqdm\n\nfrom litgpt.chat.base import generate as stream_generate_fn\nfrom litgpt.config import Config, name_to_config\nfrom litgpt.generate.base import generate as generate_fn\nfrom litgpt.generate.sequentially import sequential\nfrom litgpt.generate.tp import tensor_parallel\nfrom litgpt.model import GPT\nfrom litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style, save_prompt_style\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.utils import (\n    auto_download_checkpoint,\n    check_file_size_on_cpu_and_warn,\n    check_nvlink_connectivity,\n    chunked_cross_entropy,\n    copy_config_files,\n    extend_checkpoint_dir,\n    get_default_supported_precision,\n    load_checkpoint,\n    save_config,\n)\n\n\nclass LLM(torch.nn.Module):\n    def __init__(\n        self,\n        model: GPT,\n        preprocessor=None,\n        prompt_style: PromptStyle = None,\n        devices: Union[int, List[int]] = None,\n        config: Config = None,\n        checkpoint_dir: Path = None,\n        fabric: L.Fabric = None,\n        generate_strategy: Optional[Literal[\"sequential\", \"tensor_parallel\"]] = None,\n        kv_cache_initialized: bool = False,\n        fixed_kv_cache_size: Union[int, Literal[\"max_model_supported\"], None] = None,\n    ) -> None:\n        super().__init__()\n        self.model = model\n        self.preprocessor = preprocessor\n        self.devices = devices\n        self.prompt_style = prompt_style\n        self.config = config\n        self.checkpoint_dir = 
checkpoint_dir\n        self.fabric = fabric\n        self.generate_strategy = generate_strategy\n        self.kv_cache_initialized = kv_cache_initialized\n        self.fixed_kv_cache_size = fixed_kv_cache_size\n        self.prev_generated_seq_length = 0\n\n    \"\"\"\n    LLM model class for inference, pretraining, and finetuning.\n\n    Example:\n        from litgpt.api import LLM\n\n        llm = LLM.load(\"microsoft/phi-2\")\n        text = llm.generate(\"What do Llamas eat?\", top_k=1)\n        print(text)\n    \"\"\"\n\n    @property\n    def tokenizer(self):\n        return self.preprocessor.tokenizer\n\n    def state_dict(self, destination=None, prefix=\"\", keep_vars=False):\n        return self.model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)\n\n    def load_state_dict(self, state_dict, strict=True):\n        return self.model.load_state_dict(state_dict, strict=strict)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        target_ids: Optional[torch.Tensor] = None,\n        loss_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        logits = self.model(input_ids)\n        if target_ids is not None:\n            if loss_fn is None:\n                loss_fn = chunked_cross_entropy\n            loss = loss_fn(logits[..., :-1, :], target_ids[..., 1:])\n            return logits, loss\n        else:\n            return logits\n\n    def trainer_setup(self, trainer_ckpt: Optional[Path] = None) -> None:\n        \"\"\"Initializes the model checkpoint for PyTorch Lightning Trainer contexts\"\"\"\n        self.model = GPT(self.config)\n\n        if trainer_ckpt is not None:\n            # strip the object name key from the state_dict\n            state_dict = torch.load(trainer_ckpt, weights_only=True)[\"state_dict\"]\n            first_key = next(iter(state_dict))\n            prefix = first_key.split(\".\")[0] + 
\".\"\n            keys_to_modify = [key for key in state_dict if key.startswith(prefix)]\n            for key in keys_to_modify:\n                new_key = key.replace(prefix, \"\", 1)\n                state_dict[new_key] = state_dict.pop(key)\n\n            self.load_state_dict(state_dict, strict=True)\n\n        elif self.checkpoint_dir is not None:\n            state_dict = torch.load(self.checkpoint_dir / \"lit_model.pth\", weights_only=False)\n            self.load_state_dict(state_dict, strict=False)\n\n        else:\n            raise ValueError(\n                \"No checkpoint found. Either provide a valid path via `trainer_ckpt` \"\n                \"or ensure that `self.checkpoint_dir` points to a folder containing a `lit_model.pth` weight file.\"\n            )\n\n    def save(self, out_dir: Optional[Path] = None, prompt_style: Optional[PromptStyle] = None) -> None:\n        out_dir = Path(out_dir)\n        save_path = out_dir / \"lit_model.pth\"\n        save_path.parent.mkdir(parents=True, exist_ok=True)\n\n        if prompt_style is None:\n            prompt_style = PromptStyle.from_config(self.config)\n        if self.fabric is None:\n            torch.save(self.state_dict(), save_path)\n        else:\n            self.fabric.save(save_path, self.state_dict())\n\n        if self.fabric is None or self.fabric.global_rank == 0:\n            # If initializing a model with random weights, the checkpoint dir can be None\n            if self.checkpoint_dir is not None:\n                copy_config_files(Path(self.checkpoint_dir), save_path.parent)\n            else:\n                save_config(self.config, out_dir)\n\n            save_prompt_style(prompt_style, save_path.parent)\n\n    @classmethod\n    def load(\n        cls,\n        model: str,\n        init: Optional[Literal[\"pretrained\", \"random\"]] = \"pretrained\",\n        tokenizer_dir: Optional[Path] = None,\n        access_token: Optional[str] = None,\n        distribute: 
Optional[Literal[\"auto\"]] = \"auto\",\n    ) -> \"LLM\":\n        \"\"\"\n        Loads the LLM from a local directory or model hub.\n\n        Arguments\n            model: A local path to a directory containing the model weights or a valid model name.\n               You can get a list of valid model names via the `litgpt download list` command line argument.\n            init: If \"pretrained\" (default), downloads the model from the HF Hub if a local model can't be found at the `model`\n                directory name; otherwise loads the model from the local directory.\n                If \"random\", initializes the `model` with random weights.\n            tokenizer_dir: An optional tokenizer directory if `model` is not a checkpoint directory, or if a user\n                wants to use a different tokenizer instead.\n            access_token: Optional API token to access models with restrictions when using `init=\"pretrained\"`.\n            distribute: If \"auto\" (default), initializes the model on a single GPU if available and otherwise on the CPU.\n                To have more control over the model distribution strategy and utilize multiple GPUs, you can set\n                `llm = LLM.load(..., distribute=None)` and call `llm.distribute(...)` manually.\n        \"\"\"\n\n        allowed_init = {\"pretrained\", \"random\"}\n\n        if init == \"pretrained\":\n            checkpoint_dir = auto_download_checkpoint(\n                model_name=model, access_token=access_token, ignore_tokenizer_files=tokenizer_dir is not None\n            )\n            config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n        elif init == \"random\":\n            checkpoint_dir = None\n            try:\n                config = Config.from_name(model)\n            except ValueError:\n                print(f\"Model name {model} is not supported.\\n\")\n                available_models = \"\\n\".join(sorted(name_to_config))\n                
print(f\"Available values:\\n{available_models}\")\n                return\n\n        else:\n            raise ValueError(f\"Invalid init option: {init}. Must be one of {allowed_init}\")\n\n        torch.set_float32_matmul_precision(\"high\")\n\n        if tokenizer_dir is not None:\n            tokenizer_dir = extend_checkpoint_dir(Path(tokenizer_dir))\n            tokenizer = Tokenizer(tokenizer_dir)\n        elif checkpoint_dir is not None:\n            tokenizer = Tokenizer(checkpoint_dir)\n        else:\n            raise ValueError(\"Provide a path to a tokenizer directory via the `tokenizer_dir` setting.\")\n\n        if checkpoint_dir is not None:\n            prompt_style = (\n                load_prompt_style(checkpoint_dir)\n                if has_prompt_style(checkpoint_dir)\n                else PromptStyle.from_config(config)\n            )\n        else:\n            prompt_style = PromptStyle.from_config(config)\n\n        if distribute == \"auto\":\n            if torch.cuda.is_available():\n                accelerator = \"cuda\"\n            elif torch.backends.mps.is_available():\n                accelerator = \"mps\"\n            else:\n                accelerator = \"cpu\"\n\n            fabric = L.Fabric(\n                accelerator=accelerator,\n                devices=1,\n                precision=get_default_supported_precision(training=False),\n            )\n\n            with fabric.init_module(empty_init=False):\n                model = GPT(config)\n            model.eval()\n            preprocessor = Preprocessor(tokenizer, device=fabric.device)\n\n            if checkpoint_dir is not None:\n                checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n                check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)\n                load_checkpoint(fabric, model, checkpoint_path)\n\n            model = fabric.setup_module(model)\n\n        else:\n            preprocessor = Preprocessor(tokenizer, device=\"cuda\" 
if torch.cuda.is_available() else \"cpu\")\n            model = None\n            fabric = None\n\n        return cls(\n            model=model,\n            preprocessor=preprocessor,\n            prompt_style=prompt_style,\n            config=config,\n            checkpoint_dir=checkpoint_dir,\n            fabric=fabric,\n            generate_strategy=None,\n            kv_cache_initialized=False,\n            fixed_kv_cache_size=False,\n        )\n\n    def distribute(\n        self,\n        accelerator: Literal[\"cpu\", \"cuda\", \"auto\"] = \"auto\",\n        devices: Union[int, Literal[\"auto\"]] = \"auto\",\n        precision: Optional[Any] = None,\n        quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\"]] = None,\n        generate_strategy: Optional[Literal[\"sequential\", \"tensor_parallel\"]] = None,\n        fixed_kv_cache_size: Union[int, Literal[\"max_model_supported\"], None] = None,\n    ) -> None:\n        \"\"\"\n        Moves the model onto specified devices for single-GPU or multi-GPU inference\n\n        accelerator: Which device type to load the model on (\"cpu\", \"gpu\", \"mps\", \"cuda\", or \"auto\")\n        devices: The number of devices (1, 2, etc.) or \"auto\", which uses all available devices\n        quantize: Whether to quantize the model and using which method:\n            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes\n            - bnb.int8: 8-bit quantization from bitsandbytes\n            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md\n        precision: Indicates the Fabric precision setting to use.\n            For instance, \"32-true\", \"16-mixed\", \"16-true\", \"bf16-mixed\", \"bf16-true\".\n            For more details, see https://lightning.ai/docs/fabric/stable/api/fabric_args.html#precision\n        generate_strategy: Whether to use a sequential model generation strategy. 
The \"sequential\" setting allows running\n            models that wouldn't fit in a single card by partitioning the transformer blocks across\n            all devices and running them sequentially. Sequential generation may be slower but allows using larger models.\n            Note that sequential generation sets `fixed_kv_cache_size=\"max_model_supported\"`. You can set it to a lower integer\n            value, e.g. `fixed_kv_cache_size=256`, to reduce memory. The `fixed_kv_cache_size` value determines the maximum number\n            of tokens that can be returned via `llm.generate(...)`.\n        fixed_kv_cache_size: If set to an integer value or to \"max_model_supported\", the kv-cache won't be resized dynamically\n            during `llm.generate` calls. Use this setting if you plan to compile the model or use `generate_strategy=\"sequential\"`.\n            Note that the chosen `fixed_kv_cache_size` value determines the maximum number of tokens that can be returned in `llm.generate(...)`.\n        \"\"\"\n\n        if self.checkpoint_dir is None:\n            raise NotImplementedError(\n                \"The LLM was initialized with init='random' but .distribute() \"\n                \"currently only supports pretrained weights.\"\n            )\n\n        allowed_accelerators = {\"cpu\", \"gpu\", \"cuda\", \"mps\", \"auto\"}\n        if accelerator not in allowed_accelerators:\n            raise ValueError(f\"Invalid accelerator: {accelerator}. 
Must be one of {allowed_accelerators}.\")\n\n        if accelerator == \"auto\":\n            if torch.cuda.is_available():\n                accelerator = \"cuda\"\n            elif torch.backends.mps.is_available():\n                accelerator = \"mps\"\n            else:\n                accelerator = \"cpu\"\n\n        if generate_strategy in (\"sequential\", \"tensor_parallel\") and accelerator not in (\"cuda\", \"gpu\"):\n            raise NotImplementedError(\n                f\"generate_strategy='{generate_strategy}' is only supported for accelerator='cuda'|'gpu'.\"\n            )\n\n        if devices == \"auto\":\n            if generate_strategy in (\"sequential\", \"tensor_parallel\"):\n                total_devices = CUDAAccelerator.auto_device_count()\n            else:\n                total_devices = 1\n        elif isinstance(devices, int) and accelerator == \"cuda\":\n            use_devices = calculate_number_of_devices(devices)\n            total_devices = CUDAAccelerator.auto_device_count()\n            if use_devices > total_devices:\n                raise ValueError(\n                    f\"You selected more devices ({use_devices}) than available in your system ({total_devices}).\"\n                )\n            else:\n                total_devices = use_devices\n\n            if total_devices > 1 and generate_strategy not in (\"sequential\", \"tensor_parallel\"):\n                raise NotImplementedError(\n                    \"Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'.\"\n                )\n        elif accelerator == \"cpu\" or accelerator == \"mps\":\n            total_devices = 1\n\n        else:\n            raise ValueError(f\"devices argument must be an integer or 'auto', got {devices}\")\n\n        print(f\"Using {total_devices} device(s)\", file=sys.stderr)\n\n        if precision is None:\n            precision = 
get_default_supported_precision(training=False)\n\n        print(\"Precision set\", file=sys.stderr)\n\n        plugins = None\n        if quantize is not None and quantize.startswith(\"bnb.\"):\n            if \"mixed\" in precision:\n                raise ValueError(\"The combination of quantization and mixed precision is not supported.\")\n            dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n            plugins = BitsandbytesPrecision(quantize[4:], dtype)\n            precision = None\n\n        # set \"ddp\" as the strategy for the launching functionality, but there's no data-parallelism\n        if generate_strategy != \"tensor_parallel\":\n            fabric = L.Fabric(\n                accelerator=accelerator,\n                devices=1,  # Otherwise sequential wouldn't work, see litgpt/generate/sequentially.py\n                # devices=devices,\n                precision=precision,\n                plugins=plugins,\n            )\n        else:\n            fabric = L.Fabric(\n                accelerator=accelerator, devices=total_devices, strategy=\"ddp\", precision=precision, plugins=plugins\n            )\n            if torch.cuda.is_available() and fabric.accelerator.auto_device_count() > 1:\n                check_nvlink_connectivity(fabric)\n                fabric.launch()\n\n        print(\"Fabric launched\", file=sys.stderr)\n\n        self.kv_cache_initialized = False\n        if generate_strategy is None:\n            with fabric.init_module(empty_init=(total_devices > 1)):\n                model = GPT(self.config)\n            model.eval()\n\n            if self.checkpoint_dir is not None:\n                load_checkpoint(fabric, model, self.checkpoint_dir / \"lit_model.pth\")\n\n            model = fabric.setup_module(model)\n\n            if fixed_kv_cache_size is not None:\n                if fixed_kv_cache_size is None or fixed_kv_cache_size == \"max_model_supported\":\n      
              kv_cache_size = model.max_seq_length\n                else:\n                    kv_cache_size = fixed_kv_cache_size\n                model.set_kv_cache(batch_size=1, max_seq_length=kv_cache_size, device=fabric.device)\n                self.kv_cache_initialized = True\n                self.fixed_kv_cache_size = fixed_kv_cache_size\n\n        elif generate_strategy in (\"sequential\", \"tensor_parallel\"):\n            with fabric.init_tensor(), torch.device(\"meta\"):\n                model = GPT(self.config)\n            model.eval()\n\n            if generate_strategy == \"sequential\":\n                state_dict = torch.load(\n                    str(self.checkpoint_dir / \"lit_model.pth\"), mmap=True, map_location=\"cpu\", weights_only=False\n                )\n                model.load_state_dict(state_dict, assign=True)\n                model = fabric.setup_module(model, move_to_device=False)\n\n                if fixed_kv_cache_size is None:\n                    fixed_kv_cache_size = \"max_model_supported\"\n                if fixed_kv_cache_size == \"max_model_supported\":\n                    kv_cache_size = model.max_seq_length\n                else:\n                    kv_cache_size = fixed_kv_cache_size\n\n                model = sequential(model, fabric.device, kv_cache_size, total_devices)\n                self.fixed_kv_cache_size = fixed_kv_cache_size\n\n            elif generate_strategy == \"tensor_parallel\":\n                if fabric.global_rank == 0:\n                    pbar = tqdm(total=fabric.world_size, desc=\"Loading model weights\")\n                for rank in range(fabric.world_size):\n                    if fabric.global_rank == rank:\n                        state_dict = torch.load(\n                            str(self.checkpoint_dir / \"lit_model.pth\"),\n                            mmap=True,\n                            map_location=\"cpu\",\n                            weights_only=False,\n                        
)\n                        model.load_state_dict(state_dict, assign=True)\n\n                        # cannot use `.setup_module` because it will wrap with DDP\n                        model = fabric._precision.convert_module(model)\n                        model = tensor_parallel(fabric, model)\n\n                        with fabric.init_tensor():\n                            if fixed_kv_cache_size is None:\n                                fixed_kv_cache_size = \"max_model_supported\"\n                            if fixed_kv_cache_size == \"max_model_supported\":\n                                kv_cache_size = model.max_seq_length\n                            else:\n                                kv_cache_size = fixed_kv_cache_size\n                            model.max_seq_length = kv_cache_size\n                            # the rope cache which is on meta device\n                            model.cos, model.sin = model.rope_cache()\n                            # enable the kv cache\n                            model.set_kv_cache(batch_size=1)\n                        model.eval()\n                        model = fabric.to_device(model)\n\n                    fabric.barrier()\n                    if fabric.global_rank == 0:\n                        pbar.update(1)\n\n                if fabric.global_rank == 0:\n                    pbar.close()\n\n            self.kv_cache_initialized = True\n\n        else:\n            raise ValueError(f\"Unsupported generate_strategy: {generate_strategy}\")\n\n        self.model = model\n        self.fabric = fabric\n        self.preprocessor.device = fabric.device\n\n    @torch.inference_mode()\n    def generate(\n        self,\n        prompt: str,\n        sys_prompt: Optional[str] = None,\n        max_new_tokens: int = 50,\n        temperature: float = 1.0,\n        top_k: Optional[int] = None,\n        top_p: float = 1.0,\n        return_as_token_ids: bool = False,\n        stream: bool = False,\n    ) -> Union[str, 
torch.Tensor]:\n        \"\"\"\n        Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.\n\n        Arguments:\n            model: The model to use.\n            prompt: The prompt string to use for generating the samples.\n            sys_prompt: The system prompt string to use for generating the samples.\n                The system prompt allows the user to provide additional instructions to shape all responses by providing additional context, behavioral guidelines, style, and constraints.\n            max_new_tokens: The maximum number of new tokens to return.\n            temperature: Scales the predicted logits by 1 / temperature.\n            top_k: If specified, only sample among the tokens with the k highest probabilities.\n            top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n                In top-p sampling, the next token is sampled from the highest probability tokens\n                whose cumulative probability exceeds the threshold `top_p`. When specified,\n                it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n                to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n                It can be used in conjunction with `top_k` and `temperature` with the following order\n                of application:\n\n                1. `top_k` sampling\n                2. `temperature` scaling\n                3. `top_p` sampling\n\n                For more details, see https://arxiv.org/abs/1904.09751\n                or https://huyenchip.com/2024/01/16/sampling.html#top_p\n            return_as_token_ids: If True, returns the token IDs as a torch.Tensor. 
Otherwise, returns the decoded text as a string.\n            stream: If True, returns a generator that yields tokens as they are generated.\n                At the moment, this setting is slower and may use more memory than the non-streaming version.\n                We plan to resolve this in the future.\n        \"\"\"\n        if self.model is None:\n            raise AttributeError(\n                \"The model is not initialized yet; use the .distribute() \"\n                \"or .trainer_setup() method to initialize the model.\"\n            )\n        input_ids = self._text_to_token_ids(prompt, sys_prompt)\n        prompt_length = input_ids.size(0)\n        max_returned_tokens = prompt_length + max_new_tokens\n\n        if not self.kv_cache_initialized:\n            if self.fabric is not None:\n                device = self.fabric.device\n            else:\n                device = self.preprocessor.device\n            self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=device)\n            self.kv_cache_initialized = True\n\n        # Dynamically grow the kv cache size if necessary\n        if not self.fixed_kv_cache_size and self.prev_generated_seq_length < max_returned_tokens:\n            tmp_device = self.model.mask_cache.device\n            self.model.clear_kv_cache()\n            self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device)\n\n        else:\n            for block in self.model.transformer.h:\n                block.attn.kv_cache.reset_parameters()\n\n        self.prev_generated_seq_length = max_returned_tokens\n        self.model.eval()\n\n        def iterator():\n            outputs = stream_generate_fn(\n                model=self.model,\n                prompt=input_ids,\n                max_returned_tokens=max_returned_tokens,\n                temperature=temperature,\n                top_k=top_k,\n                top_p=top_p,\n                
stop_tokens=([self.preprocessor.tokenizer.eos_id],),\n            )\n            if return_as_token_ids:\n                yield from outputs\n            else:\n                for output in outputs:\n                    yield self.preprocessor.decode(output)\n            return\n\n        if stream:\n            outputs = iterator()\n        else:\n            outputs = generate_fn(\n                model=self.model,\n                prompt=input_ids,\n                max_returned_tokens=max_returned_tokens,\n                temperature=temperature,\n                top_k=top_k,\n                top_p=top_p,\n                eos_id=self.preprocessor.tokenizer.eos_id,\n                include_prompt=False,\n            )\n\n        if stream:\n            return outputs\n        elif return_as_token_ids:\n            return outputs\n        else:\n            return self.preprocessor.decode(outputs)\n\n    def _text_to_token_ids(self, prompt: str, sys_prompt: Optional[str] = None) -> torch.Tensor:\n        \"\"\"Utility method to convert a prompt text to token IDs\"\"\"\n        prompt = self.prompt_style.apply(prompt, sys_prompt=sys_prompt)\n        input_ids = self.preprocessor.encode(prompt)\n        return input_ids\n\n    def benchmark(self, num_iterations=1, **kwargs):\n        \"\"\"\n        A wrapper around the .generate() method to calculate runtime performance.\n\n        Arguments:\n        num_iterations: How often the `.generate()` call is repeated.\n        kwargs: Keyword arguments that are passed to the .generate() method.\n        \"\"\"\n        benchmark_dict = {}\n\n        for i in range(num_iterations):\n            time_to_first_token = None\n            t0 = time.perf_counter()\n            outputs = self.generate(**kwargs)\n\n            if kwargs.get(\"stream\", False):\n                gen_outputs = []\n                for e in outputs:\n                    if time_to_first_token is None:\n                        t1 = 
time.perf_counter()\n                        time_to_first_token = t1 - t0\n                    gen_outputs.append(e)\n                outputs = \"\".join(gen_outputs)\n            else:\n                outputs = self.generate(\n                    **kwargs,\n                )\n            benchmark_dict.setdefault(\"Seconds total\", []).append(time.perf_counter() - t0)\n\n            benchmark_dict.setdefault(\"Seconds to first token\", []).append(time_to_first_token)\n            tokens_generated = self.preprocessor.encode(outputs).size(0)\n            benchmark_dict.setdefault(\"Tokens generated\", []).append(tokens_generated)\n            benchmark_dict.setdefault(\"Inference speed in tokens/sec\", []).append(\n                benchmark_dict[\"Tokens generated\"][-1] / benchmark_dict[\"Seconds total\"][-1]\n            )\n            if self.fabric is not None and self.fabric.device.type == \"cuda\":\n                benchmark_dict.setdefault(\"Total GPU memory allocated in GB\", []).append(\n                    torch.cuda.max_memory_allocated() / 1e9\n                )\n\n        return outputs, benchmark_dict\n\n\nclass Preprocessor:\n    \"\"\"\n    Preprocessor class for tokenization and de-tokenization.\n    \"\"\"\n\n    def __init__(self, tokenizer: Tokenizer, device: str = \"cpu\") -> None:\n        self.tokenizer = tokenizer\n        self.device = device\n\n    def encode(self, text: str) -> torch.Tensor:\n        return self.tokenizer.encode(text, device=self.device)\n\n    def decode(self, token_ids: torch.Tensor) -> str:\n        return self.tokenizer.decode(token_ids)\n\n\ndef calculate_number_of_devices(devices):\n    \"\"\"\n    Utility function to calculate the number of devices.\n    \"\"\"\n    num_devices = devices if isinstance(devices, int) else len(devices) if isinstance(devices, list) else 0\n    return num_devices\n\n\ndef benchmark_dict_to_markdown_table(data):\n    \"\"\"\n    Converts .benchmark() outputs to a markdown table\n    
\"\"\"\n    markdown_table = (\n        \"| Metric                              | Mean                        | Std Dev                     |\\n\"\n    )\n    markdown_table += (\n        \"|-------------------------------------|-----------------------------|-----------------------------|\\n\"\n    )\n\n    for key, values in data.items():\n        mean_value = np.mean(values)\n        std_dev_value = np.std(values, ddof=1)\n\n        formatted_mean = f\"{mean_value:.2f}\"\n        formatted_std_dev = f\"{std_dev_value:.2f}\"\n\n        markdown_table += f\"| {key.ljust(35)} | {formatted_mean.ljust(27)} | {formatted_std_dev.ljust(27)} |\\n\"\n\n    return markdown_table\n\n\ndef pull_request_benchmark_util(model_name=\"microsoft/phi-2\", num_iterations=6):\n    def print_table(header, data):\n        print(f\"\\n### {header}\\n\")\n        markdown_table = (\n            f\"| Metric                               | First Iteration | \"\n            f\"Iter 2-{num_iterations} Mean     | Iter 2-{num_iterations} Standard Dev.  
|\\n\"\n            f\"|--------------------------------------|-----------------|\"\n            f\"-------------------|-------------------------|\\n\"\n        )\n\n        for key, value in data.items():\n            first_iteration = f\"{value[0]:.2f}\" if value[0] is not None else \"N/A\"\n            clean_values = [v for v in value[1:] if v is not None]\n\n            if clean_values:\n                mean_value = np.mean(clean_values)\n                std_dev_value = np.std(clean_values, ddof=1)\n                mean_str = f\"{mean_value:.2f}\"\n                std_dev_str = f\"{std_dev_value:.2f}\"\n            else:\n                mean_str = \"N/A\"\n                std_dev_str = \"N/A\"\n\n            markdown_table += f\"| {key:<36} | {first_iteration:<15} | {mean_str:<17} | {std_dev_str:<23} |\\n\"\n        print(markdown_table)\n\n    import subprocess\n\n    try:\n        g_hash = subprocess.run(\n            [\"git\", \"rev-parse\", \"--short\", \"HEAD\"], capture_output=True, text=True, check=True\n        ).stdout.strip()\n        print(f\"Git Commit Hash: {g_hash}\")\n    except subprocess.CalledProcessError:\n        print(\"Git Commit Hash: N/A\")\n    print(f\"PyTorch version: {torch.__version__}\")\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    print(f\"Device: {device}\\n\")\n\n    # 1st table\n    llm = LLM.load(\n        model=model_name,\n    )\n    text, bench_d = llm.benchmark(num_iterations=num_iterations, prompt=\"What do llamas eat?\", top_k=1)\n    print_table(f\"Defaults ({model_name}), 1st time\", bench_d)\n    del llm\n\n    # 2nd table\n    llm = LLM.load(\n        model=model_name,\n    )\n    text, bench_d = llm.benchmark(num_iterations=num_iterations, prompt=\"What do llamas eat?\", top_k=1)\n    print_table(f\"Defaults ({model_name}), 2nd time\", bench_d)\n    del llm\n\n    # 3rd table\n    llm = LLM.load(\n        model=model_name,\n    )\n    text, bench_d = 
llm.benchmark(num_iterations=num_iterations, prompt=\"What do llamas eat?\", top_k=1, stream=True)\n    print_table(\"stream=True\", bench_d)\n    del llm\n\n    # 4th table\n    llm = LLM.load(model=model_name, distribute=None)\n    llm.distribute(fixed_kv_cache_size=500)\n\n    text, bench_d = llm.benchmark(num_iterations=num_iterations, prompt=\"What do llamas eat?\", top_k=1, stream=True)\n    print_table(\"stream=True + fixed_kv_cache=500\", bench_d)\n"
  },
  {
    "path": "litgpt/args.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport math\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Dict, Optional, Union\n\n\n@dataclass\nclass TrainArgs:\n    \"\"\"Training-related arguments\"\"\"\n\n    save_interval: Optional[int] = 1000\n    \"\"\"Number of optimizer steps between saving checkpoints\"\"\"\n    log_interval: int = 1\n    \"\"\"Number of iterations between logging calls\"\"\"\n    global_batch_size: int = 64\n    \"\"\"Number of samples between optimizer steps across data-parallel ranks\"\"\"\n    micro_batch_size: int = 4\n    \"\"\"Number of samples per data-parallel rank\"\"\"\n    lr_warmup_steps: Optional[int] = 100\n    \"\"\"Number of iterations with learning rate warmup active\"\"\"\n    lr_warmup_fraction: Optional[float] = None\n    \"\"\"The fraction of an epoch to use for learning rate warmup\"\"\"\n    epochs: Optional[int] = None\n    \"\"\"Number of epochs to train on\"\"\"\n    # TODO: `pretrain` is the only script using `max_tokens` explicitly. replace it with epoch_size*epochs?\n    max_tokens: Optional[int] = None\n    \"\"\"Total number of tokens to train on\"\"\"\n    max_steps: Optional[int] = None\n    \"\"\"Limits the number of optimizer steps to run\"\"\"\n    max_time: Optional[float] = None\n    \"\"\"Limits the number of seconds to train for\"\"\"\n    max_seq_length: Optional[int] = None\n    \"\"\"Limits the length of samples\"\"\"\n    tie_embeddings: Optional[bool] = None\n    \"\"\"Whether to tie the embedding weights with the language modeling head weights\"\"\"\n\n    # Optimization args\n    max_norm: Optional[float] = None\n    min_lr: float = 6e-5\n\n    def __post_init__(self) -> None:\n        if self.lr_warmup_fraction and self.lr_warmup_steps:\n            raise ValueError(\n                \"Can't provide both `--train.lr_warmup_fraction` and `--train.lr_warmup_steps`. 
Choose one.\"\n            )\n        if self.lr_warmup_fraction and not (0 <= self.lr_warmup_fraction <= 1):\n            raise ValueError(\"`--train.lr_warmup_fraction` must be between 0 and 1.\")\n\n        if self.lr_warmup_steps and self.max_steps and (self.lr_warmup_steps >= self.max_steps):\n            warnings.warn(\n                \"`--train.lr_warmup_steps` should be less than `--train.max_steps`.\"\n                f\" Got {self.lr_warmup_steps} lr_warmup_steps and {self.max_steps} max_steps.\",\n                UserWarning,\n            )\n\n    def gradient_accumulation_iters(self, devices: int, num_nodes: int = 1) -> int:\n        \"\"\"Number of iterations between gradient synchronizations\"\"\"\n        gradient_accumulation_iters = self.batch_size(devices, num_nodes) // self.micro_batch_size\n        assert gradient_accumulation_iters > 0\n        return gradient_accumulation_iters\n\n    def batch_size(self, devices: int, num_nodes: int = 1) -> int:\n        \"\"\"Number of samples between optimizer steps per data-parallel rank\"\"\"\n        batch_size = self.global_batch_size // (devices * num_nodes)\n        assert batch_size > 0\n        return batch_size\n\n    def warmup_iters(self, devices: int, num_nodes: int, max_iters: int, train_dataloader) -> int:\n        \"\"\"Number of iterations to warm up the learning rate.\"\"\"\n        if self.lr_warmup_fraction:\n            return min(max_iters, math.ceil(self.lr_warmup_fraction * len(train_dataloader)))\n        if self.lr_warmup_steps:\n            return min(max_iters, self.lr_warmup_steps * self.gradient_accumulation_iters(devices, num_nodes))\n        return 0\n\n\n@dataclass\nclass EvalArgs:\n    \"\"\"Evaluation-related arguments\"\"\"\n\n    interval: int = 600\n    \"\"\"Number of optimizer steps between evaluation calls\"\"\"\n    max_new_tokens: Optional[int] = None\n    \"\"\"Number of tokens to generate\"\"\"\n    max_iters: int = 100\n    \"\"\"Number of iterations\"\"\"\n    
initial_validation: bool = False\n    \"\"\"Whether to evaluate on the validation set at the beginning of the training\"\"\"\n    final_validation: bool = True\n    \"\"\"Whether to evaluate on the validation set at the end of the training\"\"\"\n    evaluate_example: Union[str, int] = \"first\"\n    \"\"\"How to pick an example instruction to evaluate periodically during training.\n       Can be \"first\", \"random\", or an integer index to pick a specific example.\"\"\"\n\n\n@dataclass\nclass LogArgs:\n    \"\"\"Logging-related arguments. Different loggers use different fields.\"\"\"\n\n    # === WandB Fields ===\n    project: Optional[str] = None\n    \"\"\"WandB project name\"\"\"\n    run: Optional[str] = None\n    \"\"\"WandB run name (defaults to generated name)\"\"\"\n    group: Optional[str] = None\n    \"\"\"WandB group name\"\"\"\n\n    # === LitLogger Fields (Lightning.ai) ===\n    teamspace: Optional[str] = None\n    \"\"\"Teamspace name where charts and artifacts will appear\"\"\"\n    metadata: Optional[Dict] = None\n    \"\"\"Extra metadata to associate with the experiment as tags\"\"\"\n    log_model: bool = False\n    \"\"\"If True, automatically log model checkpoints as artifacts\"\"\"\n    save_logs: bool = True\n    \"\"\"If True, capture and upload terminal logs\"\"\"\n    checkpoint_name: Optional[str] = None\n    \"\"\"Override the base name for logged checkpoints\"\"\"\n"
  },
  {
    "path": "litgpt/chat/__init__.py",
    "content": ""
  },
  {
    "path": "litgpt/chat/base.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport sys\nimport time\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Iterator, List, Literal, Optional, Tuple\n\nimport lightning as L\nimport torch\nfrom lightning.fabric.plugins import BitsandbytesPrecision\n\nfrom litgpt.config import Config\nfrom litgpt.model import GPT\nfrom litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style\nfrom litgpt.scripts.merge_lora import merge_lora\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.utils import (\n    auto_download_checkpoint,\n    check_file_size_on_cpu_and_warn,\n    extend_checkpoint_dir,\n    get_default_supported_precision,\n    load_checkpoint,\n)\n\n\n@torch.inference_mode()\ndef generate(\n    model: GPT,\n    prompt: torch.Tensor,\n    max_returned_tokens: int,\n    *,\n    temperature: float = 1.0,\n    top_k: Optional[int] = None,\n    top_p: float = 1.0,\n    stop_tokens: Tuple[List[int], ...] = (),\n) -> Iterator[torch.Tensor]:\n    \"\"\"Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as possible.\n\n    Arguments:\n        model: The model to use.\n        prompt: Tensor of shape (T) with indices of the prompt sequence.\n        max_returned_tokens: The maximum number of tokens to return (given plus generated).\n        temperature: Scales the predicted logits by 1 / temperature\n        top_k: If specified, only sample among the tokens with the k highest probabilities.\n        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n            In top-p sampling, the next token is sampled from the highest probability tokens\n            whose cumulative probability exceeds the threshold `top_p`. When specified,\n            it must be `0 <= top_p <= 1`. 
Here, `top_p=0` is equivalent\n            to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n            It can be used in conjunction with `top_k` and `temperature` with the following order\n            of application:\n\n            1. `top_k` sampling\n            2. `temperature` scaling\n            3. `top_p` sampling\n\n            For more details, see https://arxiv.org/abs/1904.09751\n            or https://huyenchip.com/2024/01/16/sampling.html#top_p\n        stop_tokens: If specified, stop generating any more token once one of this list is generated.\n    \"\"\"\n    from litgpt.generate.base import generate_fn\n\n    return generate_fn(\n        include_prompt=False,\n        include_eos=False,\n        model=model,\n        prompt=prompt,\n        max_returned_tokens=max_returned_tokens,\n        temperature=temperature,\n        top_k=top_k,\n        top_p=top_p,\n        stop_tokens=stop_tokens,\n    )\n\n\ndef process_prompt(\n    prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens\n):\n    prompt = prompt_style.apply(prompt=prompt)\n    encoded_prompt = tokenizer.encode(prompt, device=fabric.device)\n\n    if max_new_tokens is None:\n        max_returned_tokens = model.max_seq_length\n    else:\n        first_turn = model.mask_cache is None\n        max_returned_tokens = encoded_prompt.size(0) + max_new_tokens\n        if first_turn or max_returned_tokens > model.max_seq_length:\n            model.max_seq_length = max_returned_tokens\n            model.set_kv_cache(batch_size=1, device=fabric.device)\n\n    y: Iterator[torch.Tensor] = generate(\n        model,\n        encoded_prompt,\n        max_returned_tokens,\n        temperature=temperature,\n        top_k=top_k,\n        top_p=top_p,\n        stop_tokens=stop_tokens,\n    )\n    token_generator: Iterator[str] = tokenizer.decode_stream(y, device=fabric.device)\n\n    fabric.print(\">> Reply: \", 
end=\"\")\n\n    t0 = time.perf_counter()\n\n    tokens_generated = 0\n    for tok in token_generator:\n        tokens_generated += 1\n        fabric.print(tok, end=\"\", flush=True)\n\n    t = time.perf_counter() - t0\n\n    for block in model.transformer.h:\n        block.attn.kv_cache.reset_parameters()\n    fabric.print(\n        f\"\\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec, {tokens_generated} tokens\",\n        file=sys.stderr,\n    )\n    fabric.print()\n\n\ndef interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens):\n    while True:\n        try:\n            if not multiline:\n                prompt = input(\">> Prompt: \")\n            else:\n                print(\">> Prompt: (Type '!submit' on a new line to end input).\")\n                prompt_lines = []\n                while True:\n                    line = input()\n                    if line.strip().lower() in (\"!submit\", \"!quit\", \"!exit\"):\n                        break\n                    prompt_lines.append(line)\n                prompt = \"\\n\".join(prompt_lines)\n\n        except KeyboardInterrupt:\n            break\n\n        prompt = prompt.strip()\n        if not prompt or prompt.lower() in (\"!quit\", \"!exit\"):\n            break\n\n        process_prompt(\n            prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens\n        )\n\n\n@torch.inference_mode()\ndef main(\n    checkpoint_dir: Path,\n    *,\n    max_new_tokens: int = 50,\n    top_k: Optional[int] = 50,\n    top_p: float = 1.0,\n    temperature: float = 0.8,\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\"]] = None,\n    precision: Optional[str] = None,\n    compile: bool = False,\n    multiline: bool = False,\n    access_token: Optional[str] = None,\n) -> None:\n    \"\"\"Chat with a model.\n\n    Args:\n        
checkpoint_dir: A local path to a directory containing the model weights or a valid model name.\n            You can get a list of valid model names via the `litgpt download list` command line argument.\n        max_new_tokens: The number of generation steps to take.\n        top_k: The number of top most probable tokens to consider in the sampling process.\n        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n            In top-p sampling, the next token is sampled from the highest probability tokens\n            whose cumulative probability exceeds the threshold `top_p`. When specified,\n            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n            to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n            It can be used in conjunction with `top_k` and `temperature` with the following order\n            of application:\n\n            1. `top_k` sampling\n            2. `temperature` scaling\n            3. `top_p` sampling\n\n            For more details, see https://arxiv.org/abs/1904.09751\n            or https://huyenchip.com/2024/01/16/sampling.html#top_p\n        temperature: A value controlling the randomness of the sampling process. Higher values result in more random\n            samples.\n        quantize: Whether to quantize the model and using which method:\n            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes\n            - bnb.int8: 8-bit quantization from bitsandbytes\n            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md\n        precision: Indicates the Fabric precision setting to use.\n        compile: Whether to use compilation to speed up token generation. 
Will increase startup time.\n        multiline: Whether to support multiline input prompts.\n        access_token: Optional API token to access models with restrictions.\n    \"\"\"\n    checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)\n    pprint(locals())\n\n    precision = precision or get_default_supported_precision(training=False)\n\n    plugins = None\n    if quantize is not None and quantize.startswith(\"bnb.\"):\n        if \"mixed\" in precision:\n            raise ValueError(\"Quantization and mixed precision is not supported.\")\n        dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n        plugins = BitsandbytesPrecision(quantize[4:], dtype)\n        precision = None\n\n    fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)\n\n    # Merge if this is a raw LoRA checkpoint\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n    if (checkpoint_dir / \"lit_model.pth.lora\").is_file() and not checkpoint_path.is_file():\n        print(\"Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.\")\n        merge_lora(checkpoint_dir)\n\n    if not checkpoint_path.is_file():\n        checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)\n        checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n\n    check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n    with fabric.init_module(empty_init=True):\n        model = GPT(config)\n        if compile:\n            print(\n                \"IMPORTANT: with enabled compilation the KV-cache size is determined by model's maximum context size, which leads to \"\n                \"a higher memory consumption. 
In case of an OOM error, try to set `--compile=False`.\"\n            )\n            model.set_kv_cache(batch_size=1)\n    load_checkpoint(fabric, model, checkpoint_path)\n    model.eval()\n\n    if compile:\n        torch._dynamo.config.automatic_dynamic_shapes = True\n        torch._inductor.config.triton.unique_kernel_names = True\n        torch._inductor.config.coordinate_descent_tuning = True\n        global next_token\n        next_token = torch.compile(next_token, mode=\"reduce-overhead\", dynamic=True)\n\n    model = fabric.setup_module(model)\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    prompt_style = (\n        load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)\n    )\n    stop_tokens = prompt_style.stop_tokens(tokenizer)\n\n    if multiline:\n        exit_instruction = \"To exit, enter '!quit' or '!exit' on an empty prompt and press 'Enter'.\"\n    else:\n        exit_instruction = \"To exit, press 'Enter' on an empty prompt.\"\n\n    print(f\"Now chatting with {config.name}.\\n{exit_instruction}\\n\")\n    L.seed_everything(1234)\n\n    interact(\n        multiline=multiline,\n        model=model,\n        tokenizer=tokenizer,\n        prompt_style=prompt_style,\n        fabric=fabric,\n        temperature=temperature,\n        max_new_tokens=(None if compile else max_new_tokens),\n        top_k=top_k,\n        top_p=top_p,\n        stop_tokens=stop_tokens,\n    )\n\n    if fabric.device.type == \"cuda\":\n        fabric.print(f\"\\nMemory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)\n"
  },
  {
    "path": "litgpt/config.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nfrom copy import deepcopy\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Any, List, Literal, Optional, Type, Union\n\nimport yaml\nfrom typing_extensions import Self\n\n\ndef find_multiple(n: int, k: int) -> int:\n    \"\"\"Utility function for finding the nearest value to n which is a multiple of k.\n\n    NOTE: We define this function in this module rather than `litgpt.utils` so that users can import\n    this file to do configuration manipulations in Python environments which do not include all the dependencies\n    demanded by `litgpt.utils`.\n    \"\"\"\n    assert k > 0\n    if n % k == 0:\n        return n\n    return n + k - (n % k)\n\n\n@dataclass\nclass Config:\n    name: str = \"\"\n    hf_config: dict = field(default_factory=dict)\n    # General size parameters\n    block_size: int = 4096\n    n_layer: int = 16\n    n_embd: int = 4096\n    vocab_size: int = 50254\n    padding_multiple: int = 512\n    padded_vocab_size: Optional[int] = None\n    # Transformer block (structure, normalizations)\n    norm_class_name: Literal[\"LayerNorm\", \"RMSNorm\"] = \"LayerNorm\"\n    norm_eps: float = 1e-5\n    norm_qk: bool = False\n    norm_qk_type: Literal[\"default\", \"olmo2\"] = \"default\"\n    post_attention_norm: bool = False\n    post_mlp_norm: bool = False\n    parallel_residual: bool = True\n    shared_attention_norm: bool = False\n    # Transformer block (self-attention)\n    n_head: int = 32\n    head_size: Optional[int] = None\n    # to use multi-head attention (MHA), set this to `n_head` (default)\n    # to use multi-query attention (MQA), set this to 1\n    # to use grouped-query attention (GQA), set this to a value in between\n    # Example with `n_head=4`\n    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐\n    # │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │\n    # 
└───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘\n    #   │    │    │    │         │        │                 │\n    # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐\n    # │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │\n    # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘\n    #   │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐\n    # ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐\n    # │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │\n    # └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘\n    # ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶\n    #         MHA                    GQA                   MQA\n    #   n_query_groups=4       n_query_groups=2      n_query_groups=1\n    #\n    # credit https://arxiv.org/pdf/2305.13245.pdf\n    n_query_groups: Optional[int] = None\n    attn_bias: bool = False\n    attention_scores_scalar: Optional[int] = None\n    # If `sliding_window_size` is given, sliding window attention with this\n    # size is used in layers where `sliding_window_indices` has a 1. The\n    # default is all 1, so that sliding window attention is used in all\n    # layers. 
If `len(sliding_window_indices) > n_layer`, we only use the\n    # initial part.\n    sliding_window_size: Optional[int] = None\n    sliding_window_indices: Optional[List[int]] = None\n    # if `attention_logit_softcapping` is used, cannot use optimized\n    # `torch.nn.functional.scaled_dot_product_attention` (which implements\n    # Flash attention), may result in higher memory and runtime footprint.\n    attention_logit_softcapping: Optional[float] = None\n    # Rotary position embedding (RoPE)\n    rope_base: int = 10000\n    rotary_percentage: float = 0.25\n    rope_condense_ratio: int = 1\n    rope_adjustments: Optional[dict] = None\n    rope_interleave: bool = False\n    # Transformer block (MLP)\n    intermediate_size: Optional[int] = None\n    moe_intermediate_size: Optional[int] = None\n    bias: bool = True\n    mlp_class_name: Literal[\"GptNeoxMLP\", \"LLaMAMLP\", \"GemmaMLP\", \"LLaMAMoE\"] = \"GptNeoxMLP\"\n    gelu_approximate: str = \"none\"\n    n_expert: int = 0\n    n_shared_expert: Optional[int] = None\n    n_expert_groups: Optional[int] = None\n    n_topk_groups: Optional[int] = None\n    n_topk_scores_per_group: Optional[int] = None\n    n_expert_per_token: int = 0\n    first_k_dense_replace: Optional[int] = None\n    routed_scaling_factor: float = 1.0\n    norm_topk_prob: bool = False\n    # GPT before/after blocks\n    scale_embeddings: bool = False\n    lm_head_bias: bool = False\n    final_logit_softcapping: Optional[float] = None\n    norm_1: bool = True\n    norm_2: bool = True\n    latent_attention: Optional[dict] = None\n    # The base period of the RoPE embeddings for local attention.\n    # If not provided, `rope_base` will be used for both local and global attention.\n    rope_local_base_freq: Optional[float] = None\n    # If provided, must have `>= n_layer` entries, either 0 or 1. For 0,\n    # `rope_base` is used, for 1 `rope_local_base_freq` is used. 
If\n    # `len(rope_indices) > n_layer`, we only use the initial part.\n    rope_indices: Optional[List[int]] = None\n\n    def __post_init__(self):\n        if not self.name:\n            self.name = self.hf_config.get(\"name\", self.name)\n\n        if self.head_size is None:\n            assert self.n_embd % self.n_head == 0\n            self.head_size = self.n_embd // self.n_head\n\n        # vocab size should be a power of 2 to be optimal on hardware. compute the closest value\n        if self.padded_vocab_size is None:\n            self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)\n        else:\n            # vocab size shouldn't be larger than padded vocab size\n            self.vocab_size = min(self.vocab_size, self.padded_vocab_size)\n\n        # compute the number of query groups\n        if self.n_query_groups is not None:\n            assert self.n_head % self.n_query_groups == 0\n        else:\n            self.n_query_groups = self.n_head\n\n        # compute the intermediate size for MLP if not set\n        if self.intermediate_size is None:\n            if self.mlp_class_name == \"LLaMAMLP\":\n                raise ValueError(f\"The config {self.name!r}, needs to set the `intermediate_size`\")\n            self.intermediate_size = 4 * self.n_embd\n\n        self.rope_n_elem = int(self.rotary_percentage * self.head_size)\n\n        if self.sliding_window_size is not None:\n            self.sliding_window_indices = check_indicator_and_length(\n                self.sliding_window_indices,\n                name=\"sliding_window_indices\",\n                required_length=self.n_layer,\n            )\n\n        if self.rope_local_base_freq is not None:\n            self.rope_indices = check_indicator_and_length(\n                self.rope_indices,\n                name=\"rope_indices\",\n                required_length=self.n_layer,\n            )\n\n        if self.latent_attention is not None:\n            self.q_lora_rank 
= self.latent_attention.get(\"q_lora_rank\")\n            self.kv_lora_rank = self.latent_attention.get(\"kv_lora_rank\")\n            self.qk_rope_head_dim = self.latent_attention.get(\"qk_rope_head_dim\")\n            self.qk_nope_head_dim = self.latent_attention.get(\"qk_nope_head_dim\")\n            self.v_head_dim = self.latent_attention.get(\"v_head_dim\")\n            assert (\n                self.q_lora_rank\n                and self.kv_lora_rank\n                and self.qk_rope_head_dim\n                and self.qk_nope_head_dim\n                and self.v_head_dim\n            ) is not None\n            assert self.n_head == self.n_query_groups, \"Latent attention does not support MQA/GQA\"\n            self.qk_head_dim = self.qk_rope_head_dim + self.qk_nope_head_dim\n            self.rope_n_elem = self.qk_rope_head_dim\n        if self.first_k_dense_replace is not None:\n            assert self.mlp_class_name == \"LLaMAMoE\"\n        if self.n_expert_groups is not None:\n            assert self.n_expert % self.n_expert_groups == 0 and self.n_expert_groups > 1\n            assert self.n_topk_groups is not None\n            experts_per_group = self.n_expert // self.n_expert_groups\n            assert self.n_topk_scores_per_group is not None and self.n_topk_scores_per_group <= experts_per_group\n\n    @classmethod\n    def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:\n        if name not in name_to_config:\n            # search through all `config['hf_config']['name']`\n            try:\n                conf_dict = next(\n                    config\n                    for config in configs\n                    if name == config[\"hf_config\"][\"name\"]\n                    or config[\"hf_config\"][\"org\"] + \"/\" + config[\"hf_config\"][\"name\"] == name\n                )\n            except StopIteration:\n                raise ValueError(f\"{name!r} is not a supported config name\")\n        else:\n            conf_dict = 
name_to_config[name]\n\n        conf_dict = conf_dict.copy()\n        conf_dict.update(kwargs)\n        return cls(**conf_dict)\n\n    @classmethod\n    def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:\n        with open(path, encoding=\"utf-8\") as fp:\n            file_kwargs = yaml.safe_load(fp)\n            if file_kwargs is None:\n                raise ValueError(f\"{path} is empty which is likely unexpected.\")\n        file_kwargs.update(kwargs)\n        return cls(**file_kwargs)\n\n    @classmethod\n    def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:\n        \"\"\"Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`.\"\"\"\n        if (config_path := path / \"model_config.yaml\").is_file():\n            return cls.from_file(config_path, **kwargs)\n        if (model_name := path.name) in name_to_config:\n            return cls.from_name(model_name, **kwargs)\n        raise FileNotFoundError(f\"For {str(path)!r} neither 'model_config.yaml' nor matching config exists.\")\n\n    @property\n    def mlp_class(self) -> Type:\n        # `self.mlp_class_name` cannot be the type to keep the config serializable\n        import litgpt.model\n\n        return getattr(litgpt.model, self.mlp_class_name)\n\n    @property\n    def norm_class(self) -> Type:\n        # `self.norm_class_name` cannot be the type to keep the config serializable\n\n        from functools import partial\n\n        import torch  # Torch import is lazy to make config loading faster\n\n        if self.norm_class_name == \"RMSNorm\":\n            from litgpt.model import RMSNorm\n\n            return partial(RMSNorm, add_unit_offset=\"Gemma\" in self.name)\n\n        if self.norm_class_name == \"LayerNorm\" and \"OLMo\" in self.name:\n            # this makes it equivalent to `torch.nn.functional.layer_norm`\n            # that is used by OLMo\n            # Table 5 caption in the OLMo paper shows this - 
https://aclanthology.org/2024.acl-long.841\n            return partial(torch.nn.LayerNorm, elementwise_affine=False)\n\n        return getattr(torch.nn, self.norm_class_name)\n\n\ndef check_indicator_and_length(\n    params: Optional[List[int]],\n    name: str,\n    required_length: int,\n    use_initial_part: bool = True,\n    def_val: int = 1,\n) -> List[int]:\n    if params is None:\n        return [def_val] * required_length\n    if len(params) != required_length:\n        if use_initial_part and len(params) > required_length:\n            params = params[:required_length]\n        else:\n            raise ValueError(f\"{name} = {params}, must have length {required_length}\")\n    if not set(params).issubset({0, 1}):\n        raise ValueError(f\"{name} = {params}, must only contain 0 and 1\")\n    return params\n\n\n########################\n# Stability AI StableLM\n########################\nconfigs = [\n    # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json\n    dict(name=\"stablelm-base-alpha-3b\", hf_config=dict(org=\"stabilityai\", name=\"stablelm-base-alpha-3b\")),\n    # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json\n    dict(\n        name=\"stablelm-base-alpha-7b\",\n        hf_config=dict(org=\"stabilityai\", name=\"stablelm-base-alpha-7b\"),\n        n_head=48,\n        n_embd=6144,\n        padding_multiple=256,\n    ),\n    # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json\n    dict(name=\"stablelm-tuned-alpha-3b\", hf_config=dict(org=\"stabilityai\", name=\"stablelm-tuned-alpha-3b\"), n_head=32),\n    # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json\n    dict(\n        name=\"stablelm-tuned-alpha-7b\",\n        hf_config=dict(org=\"stabilityai\", name=\"stablelm-tuned-alpha-7b\"),\n        n_head=48,\n        n_embd=6144,\n        padding_multiple=256,\n    ),\n    # 
https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json\n    dict(\n        name=\"stablelm-3b-4e1t\",\n        hf_config=dict(org=\"stabilityai\", name=\"stablelm-3b-4e1t\"),\n        padded_vocab_size=50304,\n        n_layer=32,\n        n_head=32,\n        n_embd=2560,\n        parallel_residual=False,\n        bias=False,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=6912,\n    ),\n    # https://huggingface.co/stabilityai/stablelm-zephyr-3b/blob/main/config.json\n    dict(\n        name=\"stablelm-zephyr-3b\",\n        hf_config=dict(org=\"stabilityai\", name=\"stablelm-zephyr-3b\"),\n        padded_vocab_size=50304,\n        n_layer=32,\n        n_head=32,\n        n_embd=2560,\n        parallel_residual=False,\n        bias=False,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=6912,\n    ),\n]\n\n\n##########################\n# Stability AI StableCode\n##########################\nstablecode = [\n    # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b/blob/main/config.json\n    dict(\n        name=\"stablecode-completion-alpha-3b\",\n        hf_config=dict(org=\"stabilityai\", name=\"stablecode-completion-alpha-3b\"),\n        block_size=16384,\n        vocab_size=49152,\n        n_layer=32,\n        n_embd=2560,\n    ),\n    # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json\n    dict(\n        name=\"stablecode-completion-alpha-3b-4k\",\n        hf_config=dict(org=\"stabilityai\", name=\"stablecode-completion-alpha-3b-4k\"),\n        vocab_size=49152,\n        n_layer=32,\n        n_embd=2560,\n    ),\n    # https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json\n    dict(\n        name=\"stablecode-instruct-alpha-3b\",\n        hf_config=dict(org=\"stabilityai\", name=\"stablecode-instruct-alpha-3b\"),\n        vocab_size=49152,\n        n_layer=32,\n        n_embd=2560,\n    ),\n    # 
https://huggingface.co/stabilityai/stable-code-3b/blob/main/config.json\n    dict(\n        name=\"stable-code-3b\",\n        hf_config=dict(org=\"stabilityai\", name=\"stable-code-3b\"),\n        padded_vocab_size=50304,\n        n_layer=32,\n        n_embd=2560,\n        block_size=16384,\n        parallel_residual=False,\n        bias=False,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=6912,\n    ),\n]\nconfigs.extend(stablecode)\n\n\n####################\n# EleutherAI Pythia\n####################\npythia = [\n    # https://huggingface.co/EleutherAI/pythia-14m/blob/main/config.json\n    dict(\n        name=\"pythia-14m\",\n        hf_config=dict(org=\"EleutherAI\", name=\"pythia-14m\"),\n        block_size=512,\n        n_layer=6,\n        n_embd=128,\n        n_head=4,\n        padding_multiple=128,\n    ),\n    # https://huggingface.co/EleutherAI/pythia-31m/blob/main/config.json\n    dict(\n        name=\"pythia-31m\",\n        hf_config=dict(org=\"EleutherAI\", name=\"pythia-31m\"),\n        block_size=1024,\n        n_layer=6,\n        n_embd=256,\n        n_head=8,\n        padding_multiple=128,\n    ),\n    # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json\n    dict(\n        name=\"pythia-70m\",\n        hf_config=dict(org=\"EleutherAI\", name=\"pythia-70m\"),\n        block_size=2048,\n        n_layer=6,\n        n_embd=512,\n        n_head=8,\n        padding_multiple=128,\n    ),\n    # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json\n    dict(\n        name=\"pythia-160m\",\n        hf_config=dict(org=\"EleutherAI\", name=\"pythia-160m\"),\n        block_size=2048,\n        n_layer=12,\n        n_embd=768,\n        n_head=12,\n        padding_multiple=128,\n    ),\n    # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json\n    dict(\n        name=\"pythia-410m\",\n        hf_config=dict(org=\"EleutherAI\", name=\"pythia-410m\"),\n        block_size=2048,\n        n_layer=24,\n  
      n_embd=1024,\n        n_head=16,\n        padding_multiple=128,\n    ),\n    # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json\n    dict(\n        name=\"pythia-1b\",\n        hf_config=dict(org=\"EleutherAI\", name=\"pythia-1b\"),\n        block_size=2048,\n        n_embd=2048,\n        n_head=8,\n        padding_multiple=128,\n    ),\n    # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json\n    dict(\n        name=\"pythia-1.4b\",\n        hf_config=dict(org=\"EleutherAI\", name=\"pythia-1.4b\"),\n        block_size=2048,\n        n_layer=24,\n        n_embd=2048,\n        n_head=16,\n        padding_multiple=128,\n    ),\n    # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json\n    dict(\n        name=\"pythia-2.8b\",\n        hf_config=dict(org=\"EleutherAI\", name=\"pythia-2.8b\"),\n        block_size=2048,\n        n_layer=32,\n        n_embd=2560,\n        padding_multiple=128,\n    ),\n    # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json\n    dict(\n        name=\"pythia-6.9b\",\n        hf_config=dict(org=\"EleutherAI\", name=\"pythia-6.9b\"),\n        block_size=2048,\n        n_layer=32,\n        padding_multiple=256,\n    ),\n    # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json\n    dict(\n        name=\"pythia-12b\",\n        hf_config=dict(org=\"EleutherAI\", name=\"pythia-12b\"),\n        block_size=2048,\n        n_layer=36,\n        n_embd=5120,\n        n_head=40,\n    ),\n]\nconfigs.extend(pythia)\nfor c in pythia:\n    # \"pythia-14m\" and \"pythia-31m\" don't have deduped version\n    if c[\"name\"] in (\"pythia-14m\", \"pythia-31m\"):\n        continue\n    copy = deepcopy(c)\n    copy[\"name\"] = f\"{c['name']}-deduped\"\n    copy[\"hf_config\"][\"name\"] = f\"{c['hf_config']['name']}-deduped\"\n    configs.append(copy)\n\n\n#################\n# TII UAE Falcon\n#################\nfalcon = [\n    # 
https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json\n    dict(\n        name=\"falcon-7b{}\",\n        hf_config=dict(org=\"tiiuae\", name=\"falcon-7b{}\"),\n        block_size=2048,\n        vocab_size=65024,\n        padded_vocab_size=65024,\n        n_layer=32,\n        n_head=71,\n        n_embd=4544,\n        rotary_percentage=1.0,\n        n_query_groups=1,\n        bias=False,\n        # this is not in the config, but in the original model implementation, only for this config\n        shared_attention_norm=True,\n    ),\n    # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json\n    dict(\n        name=\"falcon-40b{}\",\n        hf_config=dict(org=\"tiiuae\", name=\"falcon-40b{}\"),\n        block_size=2048,\n        vocab_size=65024,\n        padded_vocab_size=65024,\n        n_layer=60,\n        n_head=128,\n        n_embd=8192,\n        rotary_percentage=1.0,\n        n_query_groups=8,\n        bias=False,\n    ),\n]\nfor c in falcon:\n    for kind in (\"\", \"-instruct\"):\n        copy = deepcopy(c)\n        copy[\"name\"] = c[\"name\"].format(kind)\n        copy[\"hf_config\"][\"name\"] = c[\"hf_config\"][\"name\"].format(kind)\n        configs.append(copy)\n\n# https://huggingface.co/tiiuae/falcon-180b/blob/main/config.json\nfalcon180b = dict(\n    name=\"falcon-180B{}\",\n    hf_config=dict(org=\"tiiuae\", name=\"falcon-180B{}\"),\n    block_size=2048,\n    vocab_size=65024,\n    padded_vocab_size=65024,\n    n_layer=80,\n    n_head=232,\n    n_embd=14848,\n    rotary_percentage=1.0,\n    n_query_groups=8,\n    bias=False,\n)\n\nfor kind in (\"\", \"-chat\"):\n    copy = deepcopy(falcon180b)\n    copy[\"name\"] = falcon180b[\"name\"].format(kind)\n    copy[\"hf_config\"][\"name\"] = falcon180b[\"hf_config\"][\"name\"].format(kind)\n    configs.append(copy)\n\nfalcon3 = [\n    # https://huggingface.co/tiiuae/Falcon3-1B-Base/blob/main/config.json\n    dict(\n        name=\"Falcon3-1B{}\",\n        hf_config=dict(org=\"tiiuae\", 
name=\"Falcon3-1B{}\"),\n        block_size=4096,\n        vocab_size=131072,\n        padded_vocab_size=131072,\n        n_layer=18,\n        n_head=8,\n        n_query_groups=4,\n        n_embd=2048,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        rope_base=1000042,\n        norm_eps=1e-6,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=8192,\n    ),\n    # https://huggingface.co/tiiuae/Falcon3-3B-Base/blob/main/config.json\n    dict(\n        name=\"Falcon3-3B{}\",\n        hf_config=dict(org=\"tiiuae\", name=\"Falcon3-3B{}\"),\n        block_size=32768,\n        vocab_size=131072,\n        padded_vocab_size=131072,\n        n_layer=22,\n        n_head=12,\n        n_query_groups=4,\n        n_embd=3072,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        rope_base=1000042,\n        norm_eps=1e-6,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=9216,\n    ),\n    # https://huggingface.co/tiiuae/Falcon3-7B-Base/blob/main/config.json\n    dict(\n        name=\"Falcon3-7B{}\",\n        hf_config=dict(org=\"tiiuae\", name=\"Falcon3-7B{}\"),\n        block_size=32768,\n        vocab_size=131072,\n        padded_vocab_size=131072,\n        n_layer=28,\n        n_head=12,\n        n_query_groups=4,\n        n_embd=3072,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        rope_base=1000042,\n        norm_eps=1e-6,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=23040,\n    ),\n    # https://huggingface.co/tiiuae/Falcon3-10B-Base/blob/main/config.json\n    dict(\n        name=\"Falcon3-10B{}\",\n        hf_config=dict(org=\"tiiuae\", name=\"Falcon3-10B{}\"),\n        block_size=32768,\n        vocab_size=131072,\n        padded_vocab_size=131072,\n        n_layer=40,\n        
n_head=12,\n        n_query_groups=4,\n        n_embd=3072,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        rope_base=1000042,\n        norm_eps=1e-6,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=23040,\n    ),\n]\nfor c in falcon3:\n    for kind in (\"-Base\", \"-Instruct\"):\n        copy = deepcopy(c)\n        copy[\"name\"] = c[\"name\"].format(kind)\n        copy[\"hf_config\"][\"name\"] = c[\"hf_config\"][\"name\"].format(kind)\n        configs.append(copy)\n\n\n#############################\n# OpenLM Research Open LLaMA\n#############################\nopen_LLaMA = [\n    # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json\n    dict(\n        name=\"open_llama_3b\",\n        hf_config=dict(org=\"openlm-research\", name=\"open_llama_3b\"),\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=26,\n        n_embd=3200,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-6,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=8640,\n    ),\n    # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json\n    dict(\n        name=\"open_llama_7b\",\n        hf_config=dict(org=\"openlm-research\", name=\"open_llama_7b\"),\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=32,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-6,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n    ),\n    # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json\n    dict(\n        name=\"open_llama_13b\",\n        hf_config=dict(org=\"openlm-research\", name=\"open_llama_13b\"),\n        
block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-6,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=13824,\n    ),\n]\nconfigs.extend(open_LLaMA)\n\n###############\n# Meta LLaMA 2\n###############\nllama_2 = [\n    # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json\n    dict(\n        name=\"Llama-2-7b{}-hf\",\n        hf_config=dict(org=\"meta-llama\", name=\"Llama-2-7b{}-hf\"),\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=32,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n    ),\n    # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json\n    dict(\n        name=\"Llama-2-13b{}-hf\",\n        hf_config=dict(org=\"meta-llama\", name=\"Llama-2-13b{}-hf\"),\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=13824,\n    ),\n    # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json\n    dict(\n        name=\"Llama-2-70b{}-hf\",\n        hf_config=dict(org=\"meta-llama\", name=\"Llama-2-70b{}-hf\"),\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        
intermediate_size=28672,\n    ),\n]\nfor c in llama_2:\n    for kind in (\"\", \"-chat\"):\n        copy = deepcopy(c)\n        copy[\"name\"] = c[\"name\"].format(kind)\n        copy[\"hf_config\"][\"name\"] = c[\"hf_config\"][\"name\"].format(kind)\n        configs.append(copy)\n\n\n###############\n# Meta LLaMA 3\n###############\nllama_3 = [\n    # https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json\n    dict(\n        name=\"Llama-3-8B{}\",\n        hf_config=dict(org=\"meta-llama\", name=\"Meta-Llama-3-8B{}\"),\n        block_size=8192,\n        vocab_size=128000,\n        padded_vocab_size=128256,\n        n_layer=32,\n        n_head=32,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=14336,\n        rope_base=500000,\n    ),\n    # https://huggingface.co/meta-llama/Meta-Llama-3.1-8B/blob/main/config.json\n    dict(\n        name=\"Llama-3.1-8B{}\",\n        hf_config=dict(org=\"meta-llama\", name=\"Meta-Llama-3.1-8B{}\"),\n        block_size=131072,\n        vocab_size=128000,\n        padded_vocab_size=128256,\n        n_layer=32,\n        n_head=32,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=14336,\n        rope_base=500000,\n        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),\n    ),\n    # https://huggingface.co/meta-llama/Meta-Llama-3-70B/blob/main/config.json\n    dict(\n        name=\"Llama-3-70B{}\",\n        hf_config=dict(org=\"meta-llama\", name=\"Meta-Llama-3-70B{}\"),\n        block_size=8192,\n        vocab_size=128000,\n        padded_vocab_size=128256,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        
n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n        rope_base=500000,\n    ),\n    # https://huggingface.co/meta-llama/Meta-Llama-3.1-70B/blob/main/config.json\n    dict(\n        name=\"Llama-3.1-70B{}\",\n        hf_config=dict(org=\"meta-llama\", name=\"Meta-Llama-3.1-70B{}\"),\n        block_size=131072,\n        vocab_size=128000,\n        padded_vocab_size=128256,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n        rope_base=500000,\n        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),\n    ),\n    # https://huggingface.co/meta-llama/Meta-Llama-3.1-405B/blob/main/config.json\n    dict(\n        name=\"Llama-3.1-405B{}\",\n        hf_config=dict(org=\"meta-llama\", name=\"Meta-Llama-3.1-405B{}\"),\n        block_size=131072,\n        vocab_size=128000,\n        padded_vocab_size=128256,\n        n_layer=126,\n        n_head=128,\n        n_embd=16384,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=53248,\n        rope_base=500000,\n        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),\n    ),\n    # https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json\n    dict(\n        name=\"Llama-3.2-1B{}\",\n        hf_config=dict(org=\"meta-llama\", name=\"Llama-3.2-1B{}\"),\n        block_size=131072,\n        vocab_size=128000,\n        
padded_vocab_size=128256,\n        n_layer=16,\n        n_embd=2048,\n        n_head=32,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=8192,\n        rope_base=500000,\n        rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),\n    ),\n    # https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json\n    dict(\n        name=\"Llama-3.2-3B{}\",\n        hf_config=dict(org=\"meta-llama\", name=\"Llama-3.2-3B{}\"),\n        block_size=131072,\n        vocab_size=128000,\n        padded_vocab_size=128256,\n        n_layer=28,\n        n_embd=3072,\n        n_head=24,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=8192,\n        rope_base=500000,\n        rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),\n    ),\n    # https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct/blob/main/config.json\n    dict(\n        name=\"Llama-3.3-70B-Instruct\",\n        hf_config=dict(org=\"meta-llama\", name=\"Llama-3.3-70B-Instruct\"),\n        block_size=131072,\n        vocab_size=128000,\n        padded_vocab_size=128256,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n        rope_base=500000,\n        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),\n    ),\n]\nfor c in llama_3:\n    if c[\"name\"] == \"Llama-3.3-70B-Instruct\":\n  
      configs.append(c)\n        continue\n    for kind in (\"\", \"-Instruct\"):\n        copy = deepcopy(c)\n        copy[\"name\"] = c[\"name\"].format(kind)\n        copy[\"hf_config\"][\"name\"] = c[\"hf_config\"][\"name\"].format(kind)\n        configs.append(copy)\n\n#########################\n# NVIDIA Llama Nemotron\n#########################\nconfigs.append(\n    dict(\n        name=\"Llama-3.1-Nemotron-70B-Instruct-HF\",\n        hf_config=dict(org=\"nvidia\", name=\"Llama-3.1-Nemotron-70B-Instruct-HF\"),\n        block_size=131072,\n        vocab_size=128000,\n        padded_vocab_size=128256,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n        rope_base=500000,\n        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),\n    ),\n)\n\n#################\n# Allen AI OLMo\n#################\nolmo = [\n    # https://huggingface.co/allenai/OLMo-1B-hf/blob/main/config.json\n    dict(\n        name=\"OLMo-1B-hf\",\n        hf_config=dict(org=\"allenai\", name=\"OLMo-1B-hf\"),\n        vocab_size=50280,\n        padded_vocab_size=50304,\n        block_size=2048,\n        n_embd=2048,\n        n_layer=16,\n        n_head=16,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"LayerNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=8192,\n    ),\n    # https://huggingface.co/allenai/OLMo-7B-hf/blob/main/config.json\n    dict(\n        name=\"OLMo-7B-hf\",\n        hf_config=dict(org=\"allenai\", name=\"OLMo-7B-hf\"),\n        vocab_size=50280,\n        padded_vocab_size=50304,\n        block_size=2048,\n        n_layer=32,\n        n_head=32,\n        rotary_percentage=1.0,\n        
parallel_residual=False,\n        bias=False,\n        norm_class_name=\"LayerNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n    ),\n    # https://huggingface.co/allenai/OLMo-7B-Instruct-hf/blob/main/config.json\n    dict(\n        name=\"OLMo-7B-Instruct-hf\",\n        hf_config=dict(org=\"allenai\", name=\"OLMo-7B-Instruct-hf\"),\n        vocab_size=50280,\n        padded_vocab_size=50304,\n        block_size=2048,\n        n_layer=32,\n        n_head=32,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"LayerNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n    ),\n]\n\nconfigs.extend(olmo)\n\nolmo2 = [\n    # https://huggingface.co/allenai/OLMo-2-1124-7B/blob/main/config.json\n    dict(\n        name=\"OLMo-2-1124-7B{}\",\n        hf_config=dict(org=\"allenai\", name=\"OLMo-2-1124-7B{}\"),\n        vocab_size=100278,\n        padded_vocab_size=100352,\n        block_size=4096,\n        n_embd=4096,\n        n_layer=32,\n        n_head=32,\n        n_query_groups=32,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        norm_eps=1e-06,\n        intermediate_size=11008,\n        rope_base=500000,\n        norm_qk=True,\n        post_mlp_norm=True,\n        norm_1=False,\n        norm_2=False,\n        norm_qk_type=\"olmo2\",\n        post_attention_norm=True,\n    ),\n    # https://huggingface.co/allenai/OLMo-2-1124-13B/blob/main/config.json\n    dict(\n        name=\"OLMo-2-1124-13B{}\",\n        hf_config=dict(org=\"allenai\", name=\"OLMo-2-1124-13B{}\"),\n        vocab_size=100278,\n        padded_vocab_size=100352,\n        block_size=4096,\n        n_embd=5120,\n        n_layer=40,\n        n_head=40,\n        n_query_groups=40,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n 
       norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        norm_eps=1e-06,\n        intermediate_size=13824,\n        rope_base=500000,\n        norm_qk=True,\n        post_mlp_norm=True,\n        norm_1=False,\n        norm_2=False,\n        norm_qk_type=\"olmo2\",\n        post_attention_norm=True,\n    ),\n]\n\nfor c in olmo2:\n    for kind in (\"\", \"-SFT\", \"-DPO\", \"-Instruct\"):\n        copy = deepcopy(c)\n        copy[\"name\"] = c[\"name\"].format(kind)\n        copy[\"hf_config\"][\"name\"] = c[\"hf_config\"][\"name\"].format(kind)\n        configs.append(copy)\n\n###############\n# Google Gemma\n###############\ngemma = [\n    # https://huggingface.co/google/gemma-2b/blob/main/config.json\n    dict(\n        name=\"Gemma-2b\",\n        hf_config=dict(org=\"google\", name=\"gemma-2b\"),\n        scale_embeddings=True,\n        vocab_size=256000,\n        padding_multiple=64,\n        n_embd=2048,\n        n_layer=18,\n        n_head=8,\n        n_query_groups=1,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"GemmaMLP\",\n        gelu_approximate=\"tanh\",\n        intermediate_size=16384,\n    ),\n    # https://huggingface.co/google/gemma-7b/blob/main/config.json\n    dict(\n        name=\"Gemma-7b\",\n        hf_config=dict(org=\"google\", name=\"gemma-7b\"),\n        scale_embeddings=True,\n        vocab_size=256000,\n        padding_multiple=64,\n        n_embd=3072,\n        n_layer=28,\n        n_head=16,\n        head_size=256,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"GemmaMLP\",\n        gelu_approximate=\"tanh\",\n        intermediate_size=24576,\n    ),\n    # https://huggingface.co/google/gemma-2-2b/blob/main/config.json\n    dict(\n        name=\"Gemma-2-2b\",\n        hf_config=dict(org=\"google\", 
name=\"gemma-2-2b\"),\n        scale_embeddings=True,\n        attention_scores_scalar=256,\n        vocab_size=256000,\n        block_size=8192,\n        sliding_window_size=4096,\n        # only layer with idx 0, 2, 4, ... have sliding window attention\n        sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(26)],\n        intermediate_size=9216,\n        n_embd=2304,\n        n_layer=26,\n        n_head=8,\n        n_query_groups=4,\n        head_size=256,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"GemmaMLP\",\n        gelu_approximate=\"tanh\",\n        post_attention_norm=True,\n        post_mlp_norm=True,\n        attention_logit_softcapping=50.0,\n        final_logit_softcapping=30.0,\n    ),\n    # https://huggingface.co/google/gemma-2-9b/blob/main/config.json\n    dict(\n        name=\"Gemma-2-9b\",\n        hf_config=dict(org=\"google\", name=\"gemma-2-9b\"),\n        scale_embeddings=True,\n        attention_scores_scalar=256,\n        vocab_size=256000,\n        block_size=8192,\n        sliding_window_size=4096,\n        # only layer with idx 0, 2, 4, ... 
have sliding window attention\n        sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(42)],\n        intermediate_size=14336,\n        n_embd=3584,\n        n_layer=42,\n        n_head=16,\n        n_query_groups=8,\n        head_size=256,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"GemmaMLP\",\n        gelu_approximate=\"tanh\",\n        post_attention_norm=True,\n        post_mlp_norm=True,\n        attention_logit_softcapping=50.0,\n        final_logit_softcapping=30.0,\n    ),\n    # https://huggingface.co/google/gemma-2-27b/blob/main/config.json\n    dict(\n        name=\"Gemma-2-27b\",\n        hf_config=dict(org=\"google\", name=\"gemma-2-27b\"),\n        scale_embeddings=True,\n        # In Gemma 2 27B attention scores are scaled not by `sqrt(head_size)` (11.31),\n        # but by `sqrt(n_emb // n_head)` = sqrt(4608 // 32) = 12\n        attention_scores_scalar=144,\n        vocab_size=256000,\n        block_size=8192,\n        sliding_window_size=4096,\n        # only layer with idx 0, 2, 4, ... 
have sliding window attention\n        sliding_window_indices=[1 if i % 2 == 0 else 0 for i in range(46)],\n        intermediate_size=36864,\n        n_embd=4608,\n        n_layer=46,\n        n_head=32,\n        n_query_groups=16,\n        head_size=128,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"GemmaMLP\",\n        gelu_approximate=\"tanh\",\n        post_attention_norm=True,\n        post_mlp_norm=True,\n        attention_logit_softcapping=50.0,\n        final_logit_softcapping=30.0,\n    ),\n]\nconfigs.extend(gemma)\nfor c in gemma:\n    copy = deepcopy(c)\n    copy[\"name\"] = f\"{c['name']}-it\"\n    copy[\"hf_config\"][\"name\"] = f\"{c['hf_config']['name']}-it\"\n    configs.append(copy)\n\n##################\n# Google Gemma 3\n##################\ngemma3 = [\n    # https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json\n    dict(\n        name=\"Gemma-3-1b-it\",\n        hf_config=dict(org=\"google\", name=\"gemma-3-1b-it\"),\n        scale_embeddings=True,\n        attention_scores_scalar=256,\n        vocab_size=262144,\n        block_size=131072,\n        sliding_window_size=512,\n        # 5 local layers for every global layer\n        sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(26)],\n        intermediate_size=6912,\n        n_embd=1152,\n        n_layer=26,\n        n_head=4,\n        n_query_groups=1,\n        head_size=256,\n        rotary_percentage=1.0,\n        rope_adjustments=None,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"GemmaMLP\",\n        gelu_approximate=\"tanh\",\n        post_attention_norm=True,\n        post_mlp_norm=True,\n        norm_qk=True,\n        rope_base=1000000,\n        rope_local_base_freq=10000,\n        # 5 local layers for every global layer\n        rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in 
range(26)],\n    ),\n    # https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json\n    dict(\n        name=\"Gemma-3-4b-it\",\n        hf_config=dict(org=\"google\", name=\"gemma-3-4b-it\"),\n        scale_embeddings=True,\n        attention_scores_scalar=256,\n        vocab_size=262144,\n        block_size=131072,\n        sliding_window_size=1024,\n        # 5 local layers for every global layer\n        sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(34)],\n        intermediate_size=10240,\n        n_embd=2560,\n        n_layer=34,\n        n_head=8,\n        n_query_groups=4,\n        head_size=256,\n        rotary_percentage=1.0,\n        rope_adjustments=dict(factor=8.0),\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"GemmaMLP\",\n        gelu_approximate=\"tanh\",\n        post_attention_norm=True,\n        post_mlp_norm=True,\n        norm_qk=True,\n        rope_base=1000000,\n        rope_local_base_freq=10000,\n        # 5 local layers for every global layer\n        rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(34)],\n    ),\n    # https://huggingface.co/google/gemma-3-12b-it/blob/main/config.json\n    dict(\n        name=\"Gemma-3-12b-it\",\n        hf_config=dict(org=\"google\", name=\"gemma-3-12b-it\"),\n        scale_embeddings=True,\n        attention_scores_scalar=256,\n        vocab_size=262144,\n        block_size=131072,\n        sliding_window_size=1024,\n        # 5 local layers for every global layer\n        sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(48)],\n        intermediate_size=15360,\n        n_embd=3840,\n        n_layer=48,\n        n_head=16,\n        n_query_groups=8,\n        head_size=256,\n        rotary_percentage=1.0,\n        rope_adjustments=dict(factor=8.0),\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        
mlp_class_name=\"GemmaMLP\",\n        gelu_approximate=\"tanh\",\n        post_attention_norm=True,\n        post_mlp_norm=True,\n        norm_qk=True,\n        rope_base=1000000,\n        rope_local_base_freq=10000,\n        # 5 local layers for every global layer\n        rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(48)],\n    ),\n    # https://huggingface.co/google/gemma-3-27b-it/blob/main/config.json\n    dict(\n        name=\"Gemma-3-27b-it\",\n        hf_config=dict(org=\"google\", name=\"gemma-3-27b-it\"),\n        scale_embeddings=True,\n        attention_scores_scalar=168,\n        vocab_size=262144,\n        block_size=131072,\n        sliding_window_size=1024,\n        # 5 local layers for every global layer\n        sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(62)],\n        intermediate_size=21504,\n        n_embd=5376,\n        n_layer=62,\n        n_head=32,\n        n_query_groups=16,\n        head_size=128,\n        rotary_percentage=1.0,\n        rope_adjustments=dict(factor=8.0),\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"GemmaMLP\",\n        gelu_approximate=\"tanh\",\n        post_attention_norm=True,\n        post_mlp_norm=True,\n        norm_qk=True,\n        rope_base=1000000,\n        rope_local_base_freq=10000,\n        # 5 local layers for every global layer\n        rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(62)],\n    ),\n]\nconfigs.extend(gemma3)\n\n##################\n# Google CodeGemma\n##################\ncodegemma = [\n    # https://huggingface.co/google/codegemma-7b-it/blob/main/config.json\n    dict(\n        name=\"CodeGemma-7b-it\",\n        hf_config=dict(org=\"google\", name=\"codegemma-7b-it\"),\n        scale_embeddings=True,\n        vocab_size=256000,\n        padding_multiple=64,\n        n_embd=3072,\n        n_layer=28,\n        n_head=16,\n        head_size=256,\n        
rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"GemmaMLP\",\n        gelu_approximate=\"tanh\",\n        intermediate_size=24576,\n    ),\n]\nconfigs.extend(codegemma)\n\n\n##########################\n# Stability AI FreeWilly2\n##########################\nfreewilly_2 = [\n    # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json\n    dict(\n        name=\"FreeWilly2\",\n        hf_config=dict(org=\"stabilityai\", name=\"FreeWilly2\"),\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n    )\n]\nconfigs.extend(freewilly_2)\n\n\n##################\n# Meta Code Llama\n##################\ncode_llama = [\n    # https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json\n    dict(\n        name=\"CodeLlama-7b-hf\",\n        hf_config=dict(org=\"codellama\", name=\"CodeLlama-7b-hf\"),\n        block_size=16384,\n        vocab_size=32016,\n        padding_multiple=16,\n        n_layer=32,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json\n    dict(\n        name=\"CodeLlama-13b-hf\",\n        hf_config=dict(org=\"codellama\", name=\"CodeLlama-13b-hf\"),\n        block_size=16384,\n        vocab_size=32016,\n        padding_multiple=16,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        
bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=13824,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json\n    dict(\n        name=\"CodeLlama-34b-hf\",\n        hf_config=dict(org=\"codellama\", name=\"CodeLlama-34b-hf\"),\n        block_size=16384,\n        vocab_size=32000,\n        padded_vocab_size=32000,\n        n_layer=48,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=22016,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/codellama/CodeLlama-70b-hf/blob/main/config.json\n    dict(\n        name=\"CodeLlama-70b-hf\",\n        hf_config=dict(org=\"codellama\", name=\"CodeLlama-70b-hf\"),\n        block_size=16384,\n        vocab_size=32016,\n        padding_multiple=16,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json\n    dict(\n        name=\"CodeLlama-7b-Python-hf\",\n        hf_config=dict(org=\"codellama\", name=\"CodeLlama-7b-Python-hf\"),\n        block_size=16384,\n        vocab_size=32000,\n        padded_vocab_size=32000,\n        n_layer=32,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n        
rope_base=1000000,\n    ),\n    # https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json\n    dict(\n        name=\"CodeLlama-13b-Python-hf\",\n        hf_config=dict(org=\"codellama\", name=\"CodeLlama-13b-Python-hf\"),\n        block_size=16384,\n        vocab_size=32000,\n        padded_vocab_size=32000,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=13824,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json\n    dict(\n        name=\"CodeLlama-34b-Python-hf\",\n        hf_config=dict(org=\"codellama\", name=\"CodeLlama-34b-Python-hf\"),\n        block_size=16384,\n        vocab_size=32000,\n        padded_vocab_size=32000,\n        n_layer=48,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=22016,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/codellama/CodeLlama-70b-Python-hf/blob/main/config.json\n    dict(\n        name=\"CodeLlama-70b-Python-hf\",\n        hf_config=dict(org=\"codellama\", name=\"CodeLlama-70b-Python-hf\"),\n        block_size=16384,\n        vocab_size=32016,\n        padding_multiple=16,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n        rope_base=1000000,\n    ),\n    # 
https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json\n    dict(\n        name=\"CodeLlama-7b-Instruct-hf\",\n        hf_config=dict(org=\"codellama\", name=\"CodeLlama-7b-Instruct-hf\"),\n        block_size=16384,\n        vocab_size=32016,\n        padding_multiple=16,\n        n_layer=32,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json\n    dict(\n        name=\"CodeLlama-13b-Instruct-hf\",\n        hf_config=dict(org=\"codellama\", name=\"CodeLlama-13b-Instruct-hf\"),\n        block_size=2048,\n        vocab_size=32016,\n        padding_multiple=16,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=13824,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json\n    dict(\n        name=\"CodeLlama-34b-Instruct-hf\",\n        hf_config=dict(org=\"codellama\", name=\"CodeLlama-34b-Instruct-hf\"),\n        block_size=16384,\n        vocab_size=32000,\n        padded_vocab_size=32000,\n        n_layer=48,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=22016,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/codellama/CodeLlama-70b-Instruct-hf/blob/main/config.json\n    dict(\n        
name=\"CodeLlama-70b-Instruct-hf\",\n        hf_config=dict(org=\"codellama\", name=\"CodeLlama-70b-Instruct-hf\"),\n        block_size=16384,\n        # 32016 is an added token, so not reported in vocab_size\n        # https://huggingface.co/codellama/CodeLlama-70b-Instruct-hf/blob/main/tokenizer_config.json\n        vocab_size=32015,\n        padding_multiple=16,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n        rope_base=1000000,\n    ),\n]\nconfigs.extend(code_llama)\n\n\n########################\n# garage-bAInd Platypus\n########################\nplatypus = [\n    # https://huggingface.co/garage-bAInd/Platypus-30B/blob/main/config.json\n    dict(\n        name=\"Platypus-30B\",\n        hf_config=dict(org=\"garage-bAInd\", name=\"Platypus-30B\"),\n        block_size=2048,\n        padded_vocab_size=32000,\n        n_layer=60,\n        n_head=52,\n        n_embd=6656,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-06,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=17920,\n    ),\n    # https://huggingface.co/garage-bAInd/Platypus2-7B/blob/main/config.json\n    dict(\n        name=\"Platypus2-7B\",\n        hf_config=dict(org=\"garage-bAInd\", name=\"Platypus2-7B\"),\n        padded_vocab_size=32000,\n        n_layer=32,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n    ),\n    # https://huggingface.co/garage-bAInd/Platypus2-13B/blob/main/config.json\n    dict(\n        name=\"Platypus2-13B\",\n        
hf_config=dict(org=\"garage-bAInd\", name=\"Platypus2-13B\"),\n        padded_vocab_size=32000,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=13824,\n    ),\n    # https://huggingface.co/garage-bAInd/Platypus2-70B/blob/main/config.json\n    dict(\n        name=\"Platypus2-70B\",\n        hf_config=dict(org=\"garage-bAInd\", name=\"Platypus2-70B\"),\n        padded_vocab_size=32000,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n    ),\n    # https://huggingface.co/garage-bAInd/Camel-Platypus2-13B/blob/main/config.json\n    dict(\n        name=\"Camel-Platypus2-13B\",\n        hf_config=dict(org=\"garage-bAInd\", name=\"Camel-Platypus2-13B\"),\n        padded_vocab_size=32000,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=13824,\n    ),\n    # https://huggingface.co/garage-bAInd/Camel-Platypus2-70B/blob/main/config.json\n    dict(\n        name=\"Camel-Platypus2-70B\",\n        hf_config=dict(org=\"garage-bAInd\", name=\"Camel-Platypus2-70B\"),\n        padded_vocab_size=32000,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n    ),\n    # 
https://huggingface.co/garage-bAInd/Stable-Platypus2-13B/blob/main/config.json\n    dict(\n        name=\"Stable-Platypus2-13B\",\n        hf_config=dict(org=\"garage-bAInd\", name=\"Stable-Platypus2-13B\"),\n        padded_vocab_size=32000,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=13824,\n    ),\n    # https://huggingface.co/garage-bAInd/Platypus2-70B-instruct/blob/main/config.json\n    dict(\n        name=\"Platypus2-70B-instruct\",\n        hf_config=dict(org=\"garage-bAInd\", name=\"Platypus2-70B-instruct\"),\n        padded_vocab_size=32000,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n    ),\n]\nconfigs.extend(platypus)\n\n\n##################################\n# togethercomputer LLaMA-2-7B-32K\n##################################\ntogether_llama2_32k = [\n    # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/config.json\n    dict(\n        name=\"LLaMA-2-7B-32K\",\n        hf_config=dict(org=\"togethercomputer\", name=\"LLaMA-2-7B-32K\"),\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=32,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n        rope_condense_ratio=8,\n    )\n]\nconfigs.extend(together_llama2_32k)\n\n\n################\n# Microsoft Phi\n################\nphi = [\n    # https://huggingface.co/microsoft/phi-1_5/blob/main/config.json\n    dict(\n        name=\"phi-1_5\",\n        
hf_config=dict(org=\"microsoft\", name=\"phi-1_5\"),\n        vocab_size=50257,\n        padded_vocab_size=51200,\n        block_size=2048,\n        n_embd=2048,\n        n_layer=24,\n        rotary_percentage=0.5,  # 32 / (n_embd / n_head) = 32 / 64\n        shared_attention_norm=True,\n        lm_head_bias=True,\n        gelu_approximate=\"tanh\",\n    ),\n    # https://huggingface.co/microsoft/phi-2/blob/main/config.json\n    dict(\n        name=\"phi-2\",\n        hf_config=dict(org=\"microsoft\", name=\"phi-2\"),\n        vocab_size=50257,\n        padded_vocab_size=51200,\n        block_size=2048,\n        n_embd=2560,\n        n_layer=32,\n        rotary_percentage=0.4,  # 32 / (n_embd / n_head) = 32 / 80\n        shared_attention_norm=True,\n        lm_head_bias=True,\n        gelu_approximate=\"tanh\",\n    ),\n    # https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json\n    dict(\n        name=\"Phi-3-mini-4k-instruct\",\n        hf_config=dict(org=\"microsoft\", name=\"Phi-3-mini-4k-instruct\"),\n        vocab_size=32000,\n        padded_vocab_size=32064,\n        block_size=4096,\n        n_embd=3072,\n        n_layer=32,\n        rotary_percentage=1.0,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        intermediate_size=8192,\n        mlp_class_name=\"LLaMAMLP\",\n        parallel_residual=False,\n        sliding_window_size=2048,\n    ),\n    # https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/blob/main/config.json\n    dict(\n        name=\"Phi-3-mini-128k-instruct\",\n        hf_config=dict(org=\"microsoft\", name=\"Phi-3-mini-128k-instruct\"),\n        vocab_size=32000,\n        padded_vocab_size=32064,\n        block_size=131072,\n        n_embd=3072,\n        n_layer=32,\n        rotary_percentage=1.0,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        intermediate_size=8192,\n        mlp_class_name=\"LLaMAMLP\",\n        parallel_residual=False,\n        
sliding_window_size=262145,\n    ),\n    # https://huggingface.co/microsoft/Phi-3.5-mini-instruct/blob/main/config.json\n    dict(\n        name=\"Phi-3.5-mini-instruct\",\n        hf_config=dict(org=\"microsoft\", name=\"Phi-3.5-mini-instruct\"),\n        vocab_size=32000,\n        padded_vocab_size=32064,\n        block_size=4096,\n        n_embd=3072,\n        n_layer=32,\n        rotary_percentage=1.0,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        intermediate_size=8192,\n        mlp_class_name=\"LLaMAMLP\",\n        parallel_residual=False,\n    ),\n    # https://huggingface.co/microsoft/phi-4/blob/main/config.json\n    dict(\n        name=\"phi-4\",\n        hf_config=dict(org=\"microsoft\", name=\"phi-4\"),\n        vocab_size=100352,\n        padded_vocab_size=100352,\n        block_size=16384,\n        n_embd=5120,\n        n_layer=40,\n        n_head=40,\n        n_query_groups=10,\n        rotary_percentage=1.0,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        intermediate_size=17920,\n        rope_base=250000,\n        mlp_class_name=\"LLaMAMLP\",\n        parallel_residual=False,\n    ),\n    # https://huggingface.co/microsoft/Phi-4-reasoning/blob/main/config.json\n    dict(\n        name=\"Phi-4-reasoning\",\n        hf_config=dict(org=\"microsoft\", name=\"Phi-4-reasoning\"),\n        vocab_size=100352,\n        padded_vocab_size=100352,\n        block_size=32768,\n        n_embd=5120,\n        n_layer=40,\n        n_head=40,\n        n_query_groups=10,\n        rotary_percentage=1.0,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        intermediate_size=17920,\n        rope_base=500000,\n        mlp_class_name=\"LLaMAMLP\",\n        parallel_residual=False,\n    ),\n    # https://huggingface.co/microsoft/Phi-4-reasoning-plus/blob/main/config.json\n    dict(\n        name=\"Phi-4-reasoning-plus\",\n        hf_config=dict(org=\"microsoft\", name=\"Phi-4-reasoning-plus\"),\n        
vocab_size=100352,\n        padded_vocab_size=100352,\n        block_size=32768,\n        n_embd=5120,\n        n_layer=40,\n        n_head=40,\n        n_query_groups=10,\n        rotary_percentage=1.0,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        intermediate_size=17920,\n        rope_base=500000,\n        mlp_class_name=\"LLaMAMLP\",\n        parallel_residual=False,\n    ),\n    # https://huggingface.co/microsoft/Phi-4-mini-instruct/blob/main/config.json\n    dict(\n        name=\"Phi-4-mini-instruct\",\n        hf_config=dict(org=\"microsoft\", name=\"Phi-4-mini-instruct\"),\n        vocab_size=200019,\n        padded_vocab_size=200064,\n        block_size=131072,\n        n_embd=3072,\n        n_layer=32,\n        n_head=24,\n        n_query_groups=8,\n        rotary_percentage=0.75,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        intermediate_size=8192,\n        mlp_class_name=\"LLaMAMLP\",\n        parallel_residual=False,\n        sliding_window_size=262145,\n    ),\n    # https://huggingface.co/microsoft/Phi-4-mini-reasoning/blob/main/config.json\n    dict(\n        name=\"Phi-4-mini-reasoning\",\n        hf_config=dict(org=\"microsoft\", name=\"Phi-4-mini-reasoning\"),\n        vocab_size=200019,\n        padded_vocab_size=200064,\n        block_size=131072,\n        n_embd=3072,\n        n_layer=32,\n        n_head=24,\n        n_query_groups=8,\n        rotary_percentage=0.75,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        intermediate_size=8192,\n        mlp_class_name=\"LLaMAMLP\",\n        parallel_residual=False,\n        sliding_window_size=262145,\n    ),\n]\nconfigs.extend(phi)\n\n\n#############\n# Mistral AI\n#############\n\nconfigs.append(\n    # https://huggingface.co/mistralai/mathstral-7B-v0.1/blob/main/config.json\n    dict(\n        name=\"Mathstral-7B-v0.1\",\n        hf_config=dict(org=\"mistralai\", name=\"mathstral-7B-v0.1\"),\n        padded_vocab_size=32768,\n        
block_size=32768,\n        n_layer=32,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=14336,\n        sliding_window_size=4096,\n    )\n)\n\nmistral = [\n    # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json\n    dict(\n        name=\"Mistral-7B-{}v0.1\",\n        hf_config=dict(org=\"mistralai\", name=\"Mistral-7B-{}v0.1\"),\n        padded_vocab_size=32000,\n        block_size=4096,  # should be 32768 but sliding window attention is not implemented\n        n_layer=32,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=14336,\n        sliding_window_size=4096,\n    ),\n    # https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json\n    dict(\n        name=\"Mixtral-8x7B-{}v0.1\",\n        hf_config=dict(org=\"mistralai\", name=\"Mixtral-8x7B-{}v0.1\"),\n        padded_vocab_size=32000,\n        block_size=32768,\n        n_layer=32,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMoE\",\n        intermediate_size=14336,\n        rope_base=1000000,\n        n_expert=8,\n        n_expert_per_token=2,\n    ),\n    # https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1/blob/main/config.json\n    dict(\n        name=\"Mixtral-8x22B-{}v0.1\",\n        hf_config=dict(org=\"mistralai\", name=\"Mixtral-8x22B-{}v0.1\"),\n        padded_vocab_size=32768,\n        block_size=65536,\n        n_layer=56,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        
parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMoE\",\n        intermediate_size=16384,\n        n_head=48,\n        n_embd=6144,\n        rope_base=1000000,\n        n_expert=8,\n        n_expert_per_token=2,\n    ),\n]\nfor c in mistral:\n    for kind in (\"\", \"Instruct-\"):\n        copy = deepcopy(c)\n        copy[\"name\"] = c[\"name\"].format(kind)\n        copy[\"hf_config\"][\"name\"] = c[\"hf_config\"][\"name\"].format(kind)\n        configs.append(copy)\nconfigs.append(\n    # https://huggingface.co/unsloth/mistral-7b-v0.2/blob/main/config.json\n    dict(\n        name=\"Mistral-7B-v0.2\",\n        hf_config=dict(org=\"unsloth\", name=\"Mistral-7B-v0.2\"),\n        padded_vocab_size=32000,\n        block_size=32768,\n        n_layer=32,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=14336,\n    )\n)\nconfigs.append(\n    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/config.json\n    dict(\n        name=\"Mistral-7B-Instruct-v0.2\",\n        hf_config=dict(org=\"mistralai\", name=\"Mistral-7B-Instruct-v0.2\"),\n        padded_vocab_size=32000,\n        block_size=32768,\n        n_layer=32,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=14336,\n    )\n)\nconfigs.append(\n    # https://huggingface.co/mistralai/Mistral-7B-v0.3/blob/main/config.json\n    dict(\n        name=\"Mistral-7B-v0.3\",\n        hf_config=dict(org=\"mistralai\", name=\"Mistral-7B-v0.3\"),\n        padded_vocab_size=32768,\n        block_size=32768,\n        n_layer=32,\n        
n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=14336,\n    )\n)\nconfigs.append(\n    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3/blob/main/config.json\n    dict(\n        name=\"Mistral-7B-Instruct-v0.3\",\n        hf_config=dict(org=\"mistralai\", name=\"Mistral-7B-Instruct-v0.3\"),\n        padded_vocab_size=32768,\n        block_size=32768,\n        n_layer=32,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=14336,\n    )\n)\nconfigs.append(\n    # https://huggingface.co/mistralai/Mistral-Large-Instruct-2407/blob/main/config.json\n    dict(\n        name=\"Mistral-Large-Instruct-2407\",\n        hf_config=dict(org=\"mistralai\", name=\"Mistral-Large-Instruct-2407\"),\n        padded_vocab_size=32768,\n        block_size=32768,\n        n_layer=88,\n        n_head=96,\n        n_embd=12288,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n    )\n)\nconfigs.append(\n    # https://huggingface.co/mistralai/Mistral-Large-Instruct-2411/blob/main/config.json\n    dict(\n        name=\"Mistral-Large-Instruct-2411\",\n        hf_config=dict(org=\"mistralai\", name=\"Mistral-Large-Instruct-2411\"),\n        padded_vocab_size=32768,\n        block_size=32768,\n        n_layer=88,\n        n_head=96,\n        n_embd=12288,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        
norm_eps=1e-05,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n    )\n)\n\n\n############\n# TinyLlama\n############\ntiny_llama = [\n    dict(\n        name=\"tiny-llama-1.1b{}\",\n        hf_config=dict(org=\"TinyLlama\", name=\"TinyLlama-1.1B{}\"),\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=22,\n        n_head=32,\n        n_embd=2048,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",  # original TinyLlama use FusedRMSNorm\n        norm_eps=1e-5,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=5632,\n        n_query_groups=4,\n    )\n]\nfor c in tiny_llama:\n    for kind, hf_postfix in ((\"\", \"-intermediate-step-1431k-3T\"), (\"-chat\", \"-Chat-v1.0\")):\n        copy = deepcopy(c)\n        copy[\"name\"] = c[\"name\"].format(kind)\n        copy[\"hf_config\"][\"name\"] = c[\"hf_config\"][\"name\"].format(hf_postfix)\n        configs.append(copy)\n\n\n############\n# MicroLlama\n############\nmicro_llama = [\n    dict(\n        name=\"micro-llama-300M\",\n        hf_config=dict(org=\"keeeeenw\", name=\"MicroLlama\"),\n        block_size=2048,\n        vocab_size=32000,\n        padding_multiple=64,\n        n_layer=12,\n        n_head=16,\n        n_embd=1024,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",  # original TinyLlama and MicroLlama use FusedRMSNorm\n        norm_eps=1e-5,\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=5632,\n        n_query_groups=4,\n    )\n]\nconfigs.extend(micro_llama)\n\n\n##########################\n# Trelis Function Calling\n##########################\nllama_2_function_calling = [\n    # https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2/blob/main/config.json\n    dict(\n        name=\"Llama-2-7b-chat-hf-function-calling-v2\",\n        
hf_config=dict(org=\"Trelis\", name=\"Llama-2-7b-chat-hf-function-calling-v2\"),\n        padding_multiple=64,\n        n_layer=32,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n        norm_eps=1e-6,\n        block_size=4096,\n        vocab_size=32000,\n        n_head=32,\n        n_embd=4096,\n        rope_base=10000,\n    )\n]\n\nconfigs.extend(llama_2_function_calling)\n\n##########\n# Qwen2.5\n##########\nqwen_2_5 = [\n    # https://huggingface.co/Qwen/Qwen2.5-0.5B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-0.5B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-0.5B{}\"),\n        block_size=32768,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=24,\n        n_head=14,\n        n_embd=896,\n        n_query_groups=2,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=4864,\n        norm_eps=1e-6,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-1.5B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-1.5B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-1.5B{}\"),\n        block_size=131072,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=28,\n        n_head=12,\n        n_embd=1536,\n        n_query_groups=2,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=8960,\n        norm_eps=1e-6,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-3B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-3B{}\",\n        
hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-3B{}\"),\n        block_size=32768,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=36,\n        n_head=16,\n        n_embd=2048,\n        n_query_groups=2,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n        norm_eps=1e-6,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-7B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-7B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-7B{}\"),\n        block_size=131072,\n        vocab_size=151643,\n        padded_vocab_size=152064,\n        n_layer=28,\n        n_head=28,\n        n_embd=3584,\n        n_query_groups=4,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=18944,\n        norm_eps=1e-6,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-14B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-14B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-14B{}\"),\n        block_size=131072,\n        vocab_size=151643,\n        padded_vocab_size=152064,\n        n_layer=48,\n        n_head=40,\n        n_embd=5120,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=13824,\n        norm_eps=1e-5,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-32B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-32B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-32B{}\"),\n        block_size=131072,\n     
   vocab_size=151643,\n        padded_vocab_size=152064,\n        n_layer=64,\n        n_head=40,\n        n_embd=5120,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=27648,\n        norm_eps=1e-5,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-72B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-72B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-72B{}\"),\n        block_size=131072,\n        vocab_size=151643,\n        padded_vocab_size=152064,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=29568,\n        norm_eps=1e-5,\n        rope_base=1000000,\n    ),\n]\n\nqwen_2_5_coder = [\n    # https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-Coder-0.5B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-Coder-0.5B{}\"),\n        block_size=32768,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=24,\n        n_head=14,\n        n_embd=896,\n        n_query_groups=2,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=4864,\n        norm_eps=1e-6,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-Coder-1.5B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-Coder-1.5B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-Coder-1.5B{}\"),\n        block_size=32768,\n        
vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=28,\n        n_head=12,\n        n_embd=1536,\n        n_query_groups=2,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=8960,\n        norm_eps=1e-6,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-Coder-3B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-Coder-3B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-Coder-3B{}\"),\n        block_size=32768,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=36,\n        n_head=16,\n        n_embd=2048,\n        n_query_groups=2,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n        norm_eps=1e-6,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-Coder-7B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-Coder-7B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-Coder-7B{}\"),\n        block_size=32768,\n        vocab_size=151643,\n        padded_vocab_size=152064,\n        n_layer=28,\n        n_head=28,\n        n_embd=3584,\n        n_query_groups=4,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=18944,\n        norm_eps=1e-6,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-Coder-14B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-Coder-14B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-Coder-14B{}\"),\n        block_size=32768,\n        vocab_size=151643,\n        
padded_vocab_size=152064,\n        n_layer=48,\n        n_head=40,\n        n_embd=5120,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=13824,\n        norm_eps=1e-5,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-Coder-32B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-Coder-32B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-Coder-32B{}\"),\n        block_size=32768,\n        vocab_size=151643,\n        padded_vocab_size=152064,\n        n_layer=64,\n        n_head=40,\n        n_embd=5120,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=27648,\n        norm_eps=1e-5,\n        rope_base=1000000,\n    ),\n]\n\nqwen_2_5.extend(qwen_2_5_coder)\n\nqwen_2_5_math = [\n    # https://huggingface.co/Qwen/Qwen2.5-Math-1.5B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-Math-1.5B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-Math-1.5B{}\"),\n        block_size=4096,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=28,\n        n_head=12,\n        n_embd=1536,\n        n_query_groups=2,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=8960,\n        norm_eps=1e-6,\n        rope_base=10000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-Math-7B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-Math-7B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-Math-7B{}\"),\n        block_size=4096,\n        
vocab_size=151643,\n        padded_vocab_size=152064,\n        n_layer=28,\n        n_head=28,\n        n_embd=3584,\n        n_query_groups=4,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=18944,\n        norm_eps=1e-6,\n        rope_base=10000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-Math-72B/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-Math-72B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-Math-72B{}\"),\n        block_size=4096,\n        vocab_size=151643,\n        padded_vocab_size=152064,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=29568,\n        norm_eps=1e-5,\n        rope_base=10000,\n    ),\n]\n\nqwen_2_5.extend(qwen_2_5_math)\n\nfor c in qwen_2_5:\n    for kind in (\"\", \"-Instruct\"):\n        copy = deepcopy(c)\n        copy[\"name\"] = c[\"name\"].format(kind)\n        copy[\"hf_config\"][\"name\"] = c[\"hf_config\"][\"name\"].format(kind)\n        configs.append(copy)\n\nqwen_2_5_1m = [\n    # https://huggingface.co/Qwen/Qwen2.5-7B-Instruct-1M/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-7B-Instruct-1M\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-7B-Instruct-1M\"),\n        block_size=1010000,\n        vocab_size=151643,\n        padded_vocab_size=152064,\n        n_layer=28,\n        n_head=28,\n        n_embd=3584,\n        n_query_groups=4,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=18944,\n        
norm_eps=1e-5,\n        rope_base=10000000,\n    ),\n    # https://huggingface.co/Qwen/Qwen2.5-14B-Instruct-1M/blob/main/config.json\n    dict(\n        name=\"Qwen2.5-14B-Instruct-1M\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen2.5-14B-Instruct-1M\"),\n        block_size=1010000,\n        vocab_size=151643,\n        padded_vocab_size=152064,\n        n_layer=48,\n        n_head=40,\n        n_embd=5120,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=13824,\n        norm_eps=1e-5,\n        rope_base=10000000,\n    ),\n]\n\nconfigs.extend(qwen_2_5_1m)\n\n##########\n# QwQ\n##########\nqwq = [\n    # https://huggingface.co/Qwen/QwQ-32B/blob/main/config.json\n    dict(\n        name=\"QwQ-32B\",\n        hf_config=dict(org=\"Qwen\", name=\"QwQ-32B\"),\n        block_size=131072,\n        vocab_size=151643,\n        padded_vocab_size=152064,\n        n_layer=64,\n        n_head=40,\n        n_embd=5120,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=27648,\n        norm_eps=1e-5,\n        rope_base=1000000,\n    ),\n    # https://huggingface.co/Qwen/QwQ-32B-Preview/blob/main/config.json\n    dict(\n        name=\"QwQ-32B-Preview\",\n        hf_config=dict(org=\"Qwen\", name=\"QwQ-32B-Preview\"),\n        block_size=32768,\n        vocab_size=151643,\n        padded_vocab_size=152064,\n        n_layer=64,\n        n_head=40,\n        n_embd=5120,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        attn_bias=True,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        
intermediate_size=27648,\n        norm_eps=1e-5,\n        rope_base=1000000,\n    ),\n]\n\nconfigs.extend(qwq)\n\n##########\n# Qwen3\n##########\nqwen_3 = [\n    # https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/config.json\n    dict(\n        name=\"Qwen3-0.6B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen3-0.6B{}\"),\n        block_size=40960,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=28,\n        n_head=16,\n        n_embd=1024,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=3072,\n        norm_eps=1e-6,\n        rope_base=1000000,\n        head_size=128,\n        norm_qk=True,\n    ),\n    # https://huggingface.co/Qwen/Qwen3-1.7B/blob/main/config.json\n    dict(\n        name=\"Qwen3-1.7B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen3-1.7B{}\"),\n        block_size=40960,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=28,\n        n_head=16,\n        n_embd=2048,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=6144,\n        norm_eps=1e-6,\n        rope_base=1000000,\n        norm_qk=True,\n    ),\n    # https://huggingface.co/Qwen/Qwen3-4B/blob/main/config.json\n    dict(\n        name=\"Qwen3-4B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen3-4B{}\"),\n        block_size=40960,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=36,\n        n_head=32,\n        n_embd=2560,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=9728,\n        
norm_eps=1e-6,\n        rope_base=1000000,\n        head_size=128,\n        norm_qk=True,\n    ),\n    # https://huggingface.co/Qwen/Qwen3-8B/blob/main/config.json\n    dict(\n        name=\"Qwen3-8B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen3-8B{}\"),\n        block_size=40960,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=36,\n        n_head=32,\n        n_embd=4096,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=12288,\n        norm_eps=1e-6,\n        rope_base=1000000,\n        norm_qk=True,\n    ),\n    # https://huggingface.co/Qwen/Qwen3-14B/blob/main/config.json\n    dict(\n        name=\"Qwen3-14B{}\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen3-14B{}\"),\n        block_size=40960,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=40,\n        n_head=40,\n        n_embd=5120,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=17408,\n        norm_eps=1e-6,\n        rope_base=1000000,\n        norm_qk=True,\n    ),\n]\nfor c in qwen_3:\n    for kind in (\"\", \"-Base\"):\n        copy = deepcopy(c)\n        copy[\"name\"] = c[\"name\"].format(kind)\n        copy[\"hf_config\"][\"name\"] = c[\"hf_config\"][\"name\"].format(kind)\n        configs.append(copy)\nqwen_3_32b = [\n    # https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json\n    dict(\n        name=\"Qwen3-32B\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen3-32B\"),\n        block_size=40960,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=64,\n        n_head=64,\n        n_embd=5120,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n   
     parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=25600,\n        norm_eps=1e-6,\n        rope_base=1000000,\n        head_size=128,\n        norm_qk=True,\n    ),\n]\nconfigs.extend(qwen_3_32b)\n\nqwen_3_moe = [\n    # https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/config.json\n    dict(\n        name=\"Qwen3-30B-A3B\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen3-30B-A3B\"),\n        block_size=40960,\n        head_size=128,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=48,\n        n_head=32,\n        n_embd=2048,\n        n_query_groups=4,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMoE\",\n        intermediate_size=6144,\n        moe_intermediate_size=768,\n        norm_eps=1e-6,\n        rope_base=1000000,\n        norm_qk=True,\n        n_expert=128,\n        n_expert_per_token=8,\n    ),\n    # https://huggingface.co/Qwen/Qwen3-30B-A3B-Base/blob/main/config.json\n    dict(\n        name=\"Qwen3-30B-A3B-Base\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen3-30B-A3B-Base\"),\n        block_size=40960,\n        head_size=128,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=48,\n        n_head=32,\n        n_embd=2048,\n        n_query_groups=4,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMoE\",\n        intermediate_size=6144,\n        moe_intermediate_size=768,\n        norm_eps=1e-6,\n        rope_base=1000000,\n        norm_qk=True,\n        n_expert=128,\n        n_expert_per_token=8,\n    ),\n    # https://huggingface.co/Qwen/Qwen3-235B-A22B/blob/main/config.json\n    dict(\n        name=\"Qwen3-235B-A22B\",\n        hf_config=dict(org=\"Qwen\", 
name=\"Qwen3-235B-A22B\"),\n        block_size=40960,\n        head_size=128,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=94,\n        n_head=64,\n        n_embd=4096,\n        n_query_groups=4,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMoE\",\n        intermediate_size=12288,\n        moe_intermediate_size=1536,\n        norm_eps=1e-6,\n        rope_base=1000000,\n        norm_qk=True,\n        n_expert=128,\n        n_expert_per_token=8,\n    ),\n]\nconfigs.extend(qwen_3_moe)\n\nqwen_3_2507_thinking_instruct = [\n    # https://huggingface.co/Qwen/Qwen3-235B-A22B-Thinking-2507/blob/main/config.json\n    dict(\n        name=\"Qwen3-235B-A22B-{}-2507\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen3-235B-A22B-{}-2507\"),\n        block_size=262144,\n        head_size=128,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=94,\n        n_head=64,\n        n_embd=4096,\n        n_query_groups=4,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMoE\",\n        intermediate_size=12288,\n        moe_intermediate_size=1536,\n        norm_eps=1e-6,\n        rope_base=5000000,\n        norm_qk=True,\n        n_expert=128,\n        n_expert_per_token=8,\n    ),\n    # https://huggingface.co/Qwen/Qwen3-30B-A3B-Thinking-2507/blob/main/config.json\n    dict(\n        name=\"Qwen3-30B-A3B-{}-2507\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen3-30B-A3B-{}-2507\"),\n        block_size=262144,\n        head_size=128,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=48,\n        n_head=32,\n        n_embd=2048,\n        n_query_groups=4,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n 
       mlp_class_name=\"LLaMAMoE\",\n        intermediate_size=6144,\n        moe_intermediate_size=768,\n        norm_eps=1e-6,\n        rope_base=10000000,\n        norm_qk=True,\n        n_expert=128,\n        n_expert_per_token=8,\n    ),\n    # https://huggingface.co/Qwen/Qwen3-4B-Thinking-2507/blob/main/config.json\n    dict(\n        name=\"Qwen3-4B-{}-2507\",\n        hf_config=dict(org=\"Qwen\", name=\"Qwen3-4B-{}-2507\"),\n        block_size=262144,\n        vocab_size=151643,\n        padded_vocab_size=151936,\n        n_layer=36,\n        n_head=32,\n        n_embd=2560,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=9728,\n        norm_eps=1e-6,\n        rope_base=5000000,\n        head_size=128,\n        norm_qk=True,\n    ),\n]\n\nfor c in qwen_3_2507_thinking_instruct:\n    for kind in (\"Thinking\", \"Instruct\"):\n        copy = deepcopy(c)\n        copy[\"name\"] = c[\"name\"].format(kind)\n        copy[\"hf_config\"][\"name\"] = c[\"hf_config\"][\"name\"].format(kind)\n        configs.append(copy)\n\n#############\n# Salamandra\n#############\nsalamandra = [\n    # https://huggingface.co/BSC-LT/salamandra-2b-instruct/blob/main/config.json\n    dict(\n        name=\"salamandra-2b{}\",\n        hf_config=dict(org=\"BSC-LT\", name=\"salamandra-2b{}\"),\n        block_size=8192,\n        vocab_size=256000,\n        padded_vocab_size=256000,\n        n_layer=24,\n        n_head=16,\n        n_embd=2048,\n        n_query_groups=16,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=5440,\n        norm_eps=1e-5,\n        rope_base=10000,\n    ),\n    # https://huggingface.co/BSC-LT/salamandra-7b-instruct/blob/main/config.json\n    dict(\n        
name=\"salamandra-7b{}\",\n        hf_config=dict(org=\"BSC-LT\", name=\"salamandra-7b{}\"),\n        block_size=8192,\n        vocab_size=256000,\n        padded_vocab_size=256000,\n        n_layer=32,\n        n_head=32,\n        n_embd=4096,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=11008,\n        norm_eps=1e-6,\n        rope_base=10000,\n    ),\n]\n\nfor c in salamandra:\n    for kind in (\"\", \"-instruct\"):\n        copy = deepcopy(c)\n        copy[\"name\"] = c[\"name\"].format(kind)\n        copy[\"hf_config\"][\"name\"] = c[\"hf_config\"][\"name\"].format(kind)\n        configs.append(copy)\n\n\n###############\n# SmolLM2\n###############\nsmollm2 = [\n    # https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config.json\n    dict(\n        name=\"SmolLM2-135M{}\",\n        hf_config=dict(org=\"HuggingFaceTB\", name=\"SmolLM2-135M{}\"),\n        block_size=8192,\n        vocab_size=49152,\n        padded_vocab_size=49152,\n        n_layer=30,\n        n_head=9,\n        n_embd=576,\n        n_query_groups=3,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=1536,\n        rope_base=100000,\n        norm_eps=1e-5,\n    ),\n    # https://huggingface.co/HuggingFaceTB/SmolLM2-360M/blob/main/config.json\n    dict(\n        name=\"SmolLM2-360M{}\",\n        hf_config=dict(org=\"HuggingFaceTB\", name=\"SmolLM2-360M{}\"),\n        block_size=8192,\n        vocab_size=49152,\n        padded_vocab_size=49152,\n        n_layer=32,\n        n_head=15,\n        n_embd=960,\n        n_query_groups=5,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        
mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=2560,\n        rope_base=100000,\n        norm_eps=1e-5,\n    ),\n    # https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B/blob/main/config.json\n    dict(\n        name=\"SmolLM2-1.7B{}\",\n        hf_config=dict(org=\"HuggingFaceTB\", name=\"SmolLM2-1.7B{}\"),\n        block_size=8192,\n        vocab_size=49152,\n        padded_vocab_size=49152,\n        n_layer=24,\n        n_head=32,\n        n_embd=2048,\n        n_query_groups=32,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=8192,\n        rope_base=130000,\n        norm_eps=1e-5,\n    ),\n]\n\nfor c in smollm2:\n    for kind in (\"\", \"-Instruct\"):\n        copy = deepcopy(c)\n        copy[\"name\"] = c[\"name\"].format(kind)\n        copy[\"hf_config\"][\"name\"] = c[\"hf_config\"][\"name\"].format(kind)\n        configs.append(copy)\n\n###############\n# DeepSeek R1 Distill\n###############\n\nr1_distill_llama = [\n    # https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/blob/main/config.json\n    dict(\n        name=\"R1-Distill-Llama-8B\",\n        hf_config=dict(org=\"deepseek-ai\", name=\"DeepSeek-R1-Distill-Llama-8B\"),\n        block_size=131072,\n        vocab_size=128000,\n        padded_vocab_size=128256,\n        n_layer=32,\n        n_head=32,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=14336,\n        rope_base=500000,\n        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),\n    ),\n    # https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B/blob/main/config.json\n    dict(\n        name=\"R1-Distill-Llama-70B\",\n        
hf_config=dict(org=\"deepseek-ai\", name=\"DeepSeek-R1-Distill-Llama-70B\"),\n        block_size=131072,\n        vocab_size=128000,\n        padded_vocab_size=128256,\n        n_layer=80,\n        n_head=64,\n        n_embd=8192,\n        n_query_groups=8,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=28672,\n        rope_base=500000,\n        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),\n    ),\n]\n\nconfigs.extend(r1_distill_llama)\n\nname_to_config = {config[\"name\"]: config for config in configs}\n"
  },
  {
    "path": "litgpt/constants.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\"\"\"Centralized package availability constants for optional dependencies.\"\"\"\n\nfrom lightning_utilities.core.imports import RequirementCache\n\n# Logger-related constants\n_SUPPORTED_LOGGERS: tuple[str, ...] = (\"csv\", \"tensorboard\", \"wandb\", \"mlflow\", \"litlogger\")\n\n# Logger-related optional dependencies\n_LITLOGGER_AVAILABLE = RequirementCache(\"litlogger>=0.1.7\")\n_TENSORBOARD_AVAILABLE = RequirementCache(\"tensorboard\")\n_WANDB_AVAILABLE = RequirementCache(\"wandb\")\n_MLFLOW_AVAILABLE = RequirementCache(\"mlflow\")\n_MLFLOW_SKINNY_AVAILABLE = RequirementCache(\"mlflow-skinny\")\n\n# PyTorch version-specific constants\n_TORCH_EQUAL_2_7 = RequirementCache(\"torch>=2.7.0,<2.8\")\n_TORCH_EQUAL_2_8 = RequirementCache(\"torch>=2.8.0,<2.9\")\n\n# Other optional dependencies\n_REQUESTS_AVAILABLE = RequirementCache(\"requests\")\n_THUNDER_AVAILABLE = RequirementCache(\"thunder\")\n_TRITON_AVAILABLE = RequirementCache(\"triton\")\n_BITANDBYTES_AVAILABLE = RequirementCache(\"bitsandbytes\")\n_BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0 = RequirementCache(\"bitsandbytes != 0.42.0\")\n_LITDATA_AVAILABLE = RequirementCache(\"litdata\")\n_LITSERVE_AVAILABLE = RequirementCache(\"litserve\")\n_JINJA2_AVAILABLE = RequirementCache(\"jinja2\")\n_SAFETENSORS_AVAILABLE = RequirementCache(\"safetensors\")\n_HF_TRANSFER_AVAILABLE = RequirementCache(\"hf_transfer\")\n"
  },
  {
    "path": "litgpt/data/__init__.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nfrom litgpt.data.alpaca import Alpaca\nfrom litgpt.data.alpaca_2k import Alpaca2k\nfrom litgpt.data.alpaca_gpt4 import AlpacaGPT4\nfrom litgpt.data.base import DataModule, SFTDataset, get_sft_collate_fn\nfrom litgpt.data.deita import Deita\nfrom litgpt.data.flan import FLAN\nfrom litgpt.data.json_data import JSON\nfrom litgpt.data.lima import LIMA\nfrom litgpt.data.lit_data import LitData\nfrom litgpt.data.longform import LongForm\nfrom litgpt.data.microllama import MicroLlama\nfrom litgpt.data.openwebtext import OpenWebText\nfrom litgpt.data.text_files import TextFiles\nfrom litgpt.data.tinyllama import TinyLlama\nfrom litgpt.data.tinystories import TinyStories\n\n__all__ = [\n    \"Alpaca\",\n    \"Alpaca2k\",\n    \"AlpacaGPT4\",\n    \"Deita\",\n    \"FLAN\",\n    \"JSON\",\n    \"LIMA\",\n    \"LitData\",\n    \"DataModule\",\n    \"LongForm\",\n    \"OpenWebText\",\n    \"SFTDataset\",\n    \"TextFiles\",\n    \"TinyLlama\",\n    \"TinyStories\",\n    \"MicroLlama\",\n    \"get_sft_collate_fn\",\n]\n"
  },
  {
    "path": "litgpt/data/alpaca.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\"\"\"Implementation derived from https://github.com/tloen/alpaca-lora\"\"\"\n\nimport json\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nimport torch\nfrom torch.utils.data import DataLoader, random_split\n\nfrom litgpt.constants import _REQUESTS_AVAILABLE\nfrom litgpt.data.base import DataModule, SFTDataset, get_sft_collate_fn\nfrom litgpt.prompts import PromptStyle\nfrom litgpt.tokenizer import Tokenizer\n\n_URL = \"https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json\"\n\n\n@dataclass\nclass Alpaca(DataModule):\n    \"\"\"Alpaca data module for supervised finetuning.\"\"\"\n\n    mask_prompt: bool = False\n    \"\"\"Whether to mask the prompt section from the label (with ``ignore_index``).\"\"\"\n    val_split_fraction: float = 0.03865  # to get exactly 2000 validation samples,\n    \"\"\"The fraction of the dataset to use for the validation dataset. The rest is used for training.\"\"\"\n    prompt_style: Union[str, PromptStyle] = \"alpaca\"\n    \"\"\"The style to apply to instruction prompts. 
See `litgpt.prompts` for a list of available styles.\"\"\"\n    ignore_index: int = -100\n    \"\"\"The index to use for elements to be ignored in the label.\"\"\"\n    seed: int = 42\n    \"\"\"The random seed for creating the train/val splits and shuffling the dataset.\"\"\"\n    num_workers: int = 4\n    \"\"\"How many DataLoader processes to use for loading.\"\"\"\n    download_dir: Path = Path(\"./data/alpaca\")\n    \"\"\"The directory in which the downloaded dataset gets saved.\"\"\"\n    file_url: str = field(repr=False, default=_URL)\n    \"\"\"The URL from where to download the dataset.\"\"\"\n    file_name: str = field(repr=False, default=\"alpaca_data_cleaned_archive.json\")\n    \"\"\"The name of the dataset file to download.\"\"\"\n\n    tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)\n    batch_size: int = field(default=1, init=False, repr=False)\n    max_seq_length: int = field(default=-1, init=False, repr=False)\n    train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)\n    test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)\n\n    def __post_init__(self) -> None:\n        super().__init__()\n        if isinstance(self.prompt_style, str):\n            self.prompt_style = PromptStyle.from_name(self.prompt_style)\n\n    def connect(\n        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.batch_size = batch_size\n        self.max_seq_length = -1 if max_seq_length is None else max_seq_length\n\n    def prepare_data(self) -> None:\n        self.download_dir.mkdir(parents=True, exist_ok=True)\n        download_if_missing(self.download_dir / self.file_name, self.file_url)\n\n    def setup(self, stage: str = \"\") -> None:\n        with open(self.download_dir / self.file_name, encoding=\"utf-8\") as file:\n            data = json.load(file)\n\n        
# Partition the dataset into train and test\n        train_data, test_data = random_split(\n            data,\n            [1.0 - self.val_split_fraction, self.val_split_fraction],\n            generator=torch.Generator().manual_seed(self.seed),\n        )\n        train_data, test_data = list(train_data), list(test_data)\n\n        self.train_dataset = SFTDataset(\n            data=train_data,\n            tokenizer=self.tokenizer,\n            prompt_style=self.prompt_style,\n            max_seq_length=self.max_seq_length,\n            mask_prompt=self.mask_prompt,\n            ignore_index=self.ignore_index,\n        )\n        self.test_dataset = SFTDataset(\n            data=test_data,\n            tokenizer=self.tokenizer,\n            prompt_style=self.prompt_style,\n            max_seq_length=self.max_seq_length,\n            mask_prompt=self.mask_prompt,\n            ignore_index=self.ignore_index,\n        )\n\n    def train_dataloader(self) -> DataLoader:\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.batch_size,\n            shuffle=True,\n            generator=torch.Generator().manual_seed(self.seed),\n            num_workers=self.num_workers,\n            collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),\n        )\n\n    def val_dataloader(self) -> DataLoader:\n        return DataLoader(\n            self.test_dataset,\n            batch_size=self.batch_size,\n            shuffle=False,\n            num_workers=self.num_workers,\n            collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),\n        )\n\n\ndef download_if_missing(file_path: Path, file_url: str, mode: str = \"w\", stream: bool = False) -> None:\n    \"\"\"Downloads the raw json data file and saves it in the given destination.\"\"\"\n    if file_path.exists() and file_path.stat().st_size > 0:\n        return\n    if not _REQUESTS_AVAILABLE:\n      
  raise ModuleNotFoundError(str(_REQUESTS_AVAILABLE))\n    import requests\n\n    response = requests.get(file_url, stream=stream)\n    with open(file_path, mode, encoding=None if mode == \"wb\" else \"utf-8\") as f:\n        if stream:\n            # credit: https://github.com/karpathy/llama2.c/blob/b3c4b6/tinystories.py#L25-L38\n            from tqdm import tqdm\n\n            pbar = tqdm(\n                desc=str(file_path),\n                total=int(response.headers.get(\"content-length\", 0)),\n                unit=\"iB\",\n                unit_scale=True,\n                unit_divisor=1024,\n            )\n            for data in response.iter_content(chunk_size=1024):\n                size = f.write(data)\n                pbar.update(size)\n            pbar.close()\n        else:\n            f.write(response.text)\n"
  },
  {
    "path": "litgpt/data/alpaca_2k.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\n\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\n\nfrom litgpt.data.alpaca import Alpaca\nfrom litgpt.data.base import SFTDataset\n\n\n@dataclass\nclass Alpaca2k(Alpaca):\n    \"\"\"Alpaca2k data module for supervised finetuning.\"\"\"\n\n    val_split_fraction: float = 0.05  # to get exactly 100 validation samples,\n    \"\"\"The fraction of the dataset to use for the validation dataset. The rest is used for training.\"\"\"\n    download_dir: Path = Path(\"./data/alpaca2k\")\n    \"\"\"The directory in which the downloaded datasetgets saved.\"\"\"\n    repo_id: str = field(repr=False, default=\"mhenrichsen/alpaca_2k_test\")\n    \"\"\"The URL from where to download the dataset.\"\"\"\n    file_name: str = field(repr=False, default=\"alpaca2k_data_cleaned_archive.json\")\n    \"\"\"The name of the dataset file to download.\"\"\"\n\n    def prepare_data(self) -> None:\n        from datasets import load_dataset\n\n        load_dataset(self.repo_id, cache_dir=self.download_dir)\n\n    def setup(self, stage: str = \"\") -> None:\n        from datasets import load_dataset\n\n        dataset = load_dataset(self.repo_id, cache_dir=self.download_dir)\n\n        train_validation_split = dataset[\"train\"].train_test_split(test_size=self.val_split_fraction, seed=self.seed)\n        train_data = train_validation_split[\"train\"]\n        test_data = train_validation_split[\"test\"]\n\n        self.train_dataset = SFTDataset(\n            data=train_data,\n            tokenizer=self.tokenizer,\n            prompt_style=self.prompt_style,\n            max_seq_length=self.max_seq_length,\n            mask_prompt=self.mask_prompt,\n            ignore_index=self.ignore_index,\n        )\n        self.test_dataset = SFTDataset(\n            data=test_data,\n            tokenizer=self.tokenizer,\n            prompt_style=self.prompt_style,\n            
max_seq_length=self.max_seq_length,\n            mask_prompt=self.mask_prompt,\n            ignore_index=self.ignore_index,\n        )\n"
  },
  {
    "path": "litgpt/data/alpaca_gpt4.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\n\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\n\nfrom litgpt.data.alpaca import Alpaca\n\n_URL = \"https://raw.githubusercontent.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/main/data/alpaca_gpt4_data.json\"\n\n\n@dataclass\nclass AlpacaGPT4(Alpaca):\n    \"\"\"AlpacaGPT4 data module for supervised finetuning.\"\"\"\n\n    val_split_fraction: float = 0.03847  # to get exactly 2000 test samples,\n    \"\"\"The fraction of the dataset to use for the validation dataset. The rest is used for training.\"\"\"\n    download_dir: Path = Path(\"./data/alpacagpt4\")\n    \"\"\"The directory in which the downloaded datasetgets saved.\"\"\"\n    file_url: str = field(repr=False, default=_URL)\n    \"\"\"The URL from where to download the dataset.\"\"\"\n    file_name: str = field(repr=False, default=\"alpacagpt4_data_cleaned_archive.json\")\n    \"\"\"The name of the dataset file to download.\"\"\"\n"
  },
  {
    "path": "litgpt/data/base.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nfrom abc import abstractmethod\nfrom functools import partial\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport torch\nfrom lightning import LightningDataModule\nfrom torch import Tensor\nfrom torch.utils.data import Dataset\n\nfrom litgpt.prompts import PromptStyle\nfrom litgpt.tokenizer import Tokenizer\n\n\nclass DataModule(LightningDataModule):\n    \"\"\"Base class for all data modules in LitGPT.\"\"\"\n\n    @abstractmethod\n    def connect(\n        self,\n        tokenizer: Optional[Tokenizer] = None,\n        batch_size: int = 1,\n        max_seq_length: Optional[int] = None,\n        **kwargs,\n    ) -> None:\n        \"\"\"All settings that can't be determined at the time of instantiation need to be passed through here\n        before any dataloaders can be accessed.\n        \"\"\"\n\n    def setup(self, stage: str = \"\") -> None:\n        # Stub is to redefine the default signature, because the concept of 'stage' does not exist in LitGPT\n        pass\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}()\"\n\n\nclass SFTDataset(Dataset):\n    \"\"\"An in-memory dataset for supervised finetuning with `input_ids` and `labels`.\n\n    Args:\n        data: A list of samples (dicts). The target/label must be stored under the key 'output' and the instruction\n            or other data can be stored under any key as long as it is compatible with the given prompt template.\n        tokenizer: The tokenizer to use. Should match the one that was used to pretrain the model.\n        prompt_style: The style to apply to prompts. See `litgpt.prompts` for a list of available styles.\n        max_seq_length: Truncate sequences that are longer than this value. 
By default, no truncation is applied.\n        mask_prompt: Whether to mask the prompt section from the label (with ``ignore_index``).\n        ignore_index: The index to use for elements to be ignored in the label.\n        transform: An optional transform to apply to the sample before it gets tokenized. Use this to rename the\n            keys in the dataset to the expected 'instruction' and 'output' keys.\n\n    Returns a dict with two keys:\n        input_ids: The encoded prompt + response\n        labels: Same as input_ids, unless ``mask_prompt=True`` in which case the 'prompt' part is replaced with\n            the ``ignore_index``.\n    \"\"\"\n\n    def __init__(\n        self,\n        data: List[Dict[str, str]],\n        tokenizer: Tokenizer,\n        prompt_style: Union[str, PromptStyle],\n        max_seq_length: int = -1,\n        mask_prompt: bool = True,\n        ignore_index: int = -100,\n        transform: Optional[Callable[[Any], Any]] = None,\n    ) -> None:\n        self.data = data\n        self.tokenizer = tokenizer\n        self.prompt_style = (\n            prompt_style if isinstance(prompt_style, PromptStyle) else PromptStyle.from_name(prompt_style)\n        )\n        self.max_seq_length = max_seq_length\n        self.mask_prompt = mask_prompt\n        self.ignore_index = ignore_index\n        self.transform = transform\n\n    def __len__(self) -> int:\n        return len(self.data)\n\n    def __getitem__(self, idx: int) -> Dict[str, Union[Tensor, Dict[str, int]]]:\n        example = self.data[idx]\n        if self.transform is not None:\n            example = self.transform(example)\n        prompt = self.prompt_style.apply(prompt=example[\"instruction\"], **example)\n        encoded_prompt = self.tokenizer.encode(prompt, max_length=self.max_seq_length)\n        encoded_response = self.tokenizer.encode(example[\"output\"], bos=False, eos=True, max_length=self.max_seq_length)\n        encoded_prompt_and_response = torch.cat((encoded_prompt, 
encoded_response)).type(torch.int64)\n        if self.max_seq_length > 0:  # do not slice off last token when self.max_seq_length = -1\n            encoded_prompt_and_response = encoded_prompt_and_response[: self.max_seq_length]\n\n        # The labels are the full prompt with response, but with the prompt masked out\n        labels = encoded_prompt_and_response.clone()\n        if self.mask_prompt:\n            labels[: len(encoded_prompt)] = self.ignore_index\n\n        raw_token_count = len(self.tokenizer.encode(example[\"instruction\"], max_length=self.max_seq_length)) + len(\n            encoded_response\n        )\n\n        return {\n            \"input_ids\": encoded_prompt_and_response,\n            \"labels\": labels,\n            \"token_counts\": {\n                \"raw\": raw_token_count,\n                \"raw_plus_prompt_template\": len(encoded_prompt_and_response),\n            },\n        }\n\n\ndef get_sft_collate_fn(max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100):\n    \"\"\"Returns the collate function for supervised finetuning (needed in the DataLoader).\n\n    The collate function gets a list of dicts with keys `input_ids` and `labels`.\n    It returns a dict with batched `input_ids` and `labels`. Also pads short sequences to the longest element in\n    the batch. 
Optionally truncates all sequences to the specified maximum length.\n    \"\"\"\n    return partial(_sft_collate_fn, max_seq_length=max_seq_length, pad_id=pad_id, ignore_index=ignore_index)\n\n\ndef _sft_collate_fn(\n    samples: List[Dict[str, Tensor]], max_seq_length: int = -1, pad_id: int = 0, ignore_index: int = -100\n) -> Dict[str, Tensor]:\n    batched = {}\n    for key in (\"input_ids\", \"labels\"):\n        pad_value = pad_id if key == \"input_ids\" else ignore_index\n\n        # Pad right based on the longest sequence\n        batched[key] = torch.nn.utils.rnn.pad_sequence(\n            [sample[key] for sample in samples], batch_first=True, padding_value=pad_value\n        )\n\n        # Truncate if needed\n        if max_seq_length > 0:\n            batched[key] = batched[key][:, :max_seq_length]\n\n    batched[\"token_counts\"] = {}\n    batched[\"token_counts\"][\"raw\"] = torch.tensor(  # Token count without padding and without prompt template\n        [sample[\"token_counts\"][\"raw\"] for sample in samples], dtype=torch.int64\n    ).unsqueeze(1)\n    batched[\"token_counts\"][\"raw_plus_prompt_template\"] = (\n        torch.tensor(  # Token count without padding but with prompt template\n            [sample[\"token_counts\"][\"raw_plus_prompt_template\"] for sample in samples], dtype=torch.int64\n        ).unsqueeze(1)\n    )\n\n    return batched\n"
  },
  {
    "path": "litgpt/data/deita.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\"\"\"Implementation derived from https://github.com/tloen/alpaca-lora\"\"\"\n\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import List, Optional, Union\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom litgpt.data import DataModule, SFTDataset, get_sft_collate_fn\nfrom litgpt.prompts import PromptStyle\nfrom litgpt.tokenizer import Tokenizer\n\n\n@dataclass\nclass Deita(DataModule):\n    \"\"\"Deita data module for supervised finetuning.\"\"\"\n\n    mask_prompt: bool = False\n    \"\"\"Whether to mask the prompt section from the label (with ``ignore_index``).\"\"\"\n    prompt_style: Union[str, PromptStyle] = \"alpaca\"\n    \"\"\"The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles.\"\"\"\n    ignore_index: int = -100\n    \"\"\"The index to use for elements to be ignored in the label.\"\"\"\n    seed: int = 42\n    \"\"\"The random seed for shuffling the dataset.\"\"\"\n    num_workers: int = 4\n    \"\"\"How many DataLoader processes to use for loading.\"\"\"\n    include_multiturn_conversations: bool = False\n    \"\"\"Whether to include multi-turn conversations in the dataset.\"\"\"\n    download_dir: Path = Path(\"./data/deita\")\n    \"\"\"The directory in which the downloaded dataset gets saved.\"\"\"\n    repo_id: str = \"HuggingFaceH4/deita-10k-v0-sft\"\n    \"\"\"The repo from where the data is downloaded\"\"\"\n\n    tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)\n    batch_size: int = field(default=1, init=False, repr=False)\n    max_seq_length: int = field(default=-1, init=False, repr=False)\n    train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)\n    test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)\n\n    def __post_init__(self) -> None:\n        super().__init__()\n  
      if isinstance(self.prompt_style, str):\n            self.prompt_style = PromptStyle.from_name(self.prompt_style)\n\n    def connect(\n        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.batch_size = batch_size\n        self.max_seq_length = -1 if max_seq_length is None else max_seq_length\n\n    def prepare_data(self) -> None:\n        from datasets import load_dataset\n\n        load_dataset(self.repo_id, split=[\"train_sft\", \"test_sft\"], cache_dir=self.download_dir)\n\n    def setup(self, stage: str = \"\") -> None:\n        from datasets import load_dataset\n\n        dataset = load_dataset(self.repo_id, split=[\"train_sft\", \"test_sft\"])\n        train_data = format_dataset(dataset[0], self.include_multiturn_conversations)\n        test_data = format_dataset(dataset[1], self.include_multiturn_conversations)\n\n        self.train_dataset = SFTDataset(\n            data=train_data,\n            tokenizer=self.tokenizer,\n            prompt_style=self.prompt_style,\n            max_seq_length=self.max_seq_length,\n            mask_prompt=self.mask_prompt,\n            ignore_index=self.ignore_index,\n        )\n        self.test_dataset = SFTDataset(\n            data=test_data,\n            tokenizer=self.tokenizer,\n            prompt_style=self.prompt_style,\n            max_seq_length=self.max_seq_length,\n            mask_prompt=self.mask_prompt,\n            ignore_index=self.ignore_index,\n        )\n\n    def train_dataloader(self) -> DataLoader:\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.batch_size,\n            shuffle=True,\n            generator=torch.Generator().manual_seed(self.seed),\n            num_workers=self.num_workers,\n            collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),\n        )\n\n    def 
val_dataloader(self) -> DataLoader:\n        return DataLoader(\n            self.test_dataset,\n            batch_size=self.batch_size,\n            shuffle=False,\n            num_workers=self.num_workers,\n            collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),\n        )\n\n\ndef format_dataset(dataset: List[dict], include_multi_turn_conversations: bool) -> List[dict]:\n    formatted = []\n\n    for entry in dataset:\n        convo = entry[\"messages\"]\n        if include_multi_turn_conversations:\n            for i in range(0, len(convo) - 1, 2):\n                formatted.append({\"instruction\": convo[i][\"content\"], \"input\": \"\", \"output\": convo[i + 1][\"content\"]})\n        else:\n            formatted.append({\"instruction\": convo[0][\"content\"], \"input\": \"\", \"output\": convo[1][\"content\"]})\n\n    return formatted\n"
  },
  {
    "path": "litgpt/data/flan.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport json\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Set, Union\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom litgpt.data import DataModule, SFTDataset, get_sft_collate_fn\nfrom litgpt.data.alpaca import download_if_missing\nfrom litgpt.prompts import PromptStyle\nfrom litgpt.tokenizer import Tokenizer\n\n_URL = \"https://huggingface.co/datasets/Muennighoff/flan/resolve/main\"\n\n\n# TODO: Including all subsets, FLAN is too large to be loaded in memory. Switch the implementation to cache\n#   on disk or use Lightning Data\n@dataclass\nclass FLAN(DataModule):\n    \"\"\"FLAN data module for supervised finetuning.\"\"\"\n\n    mask_prompt: bool = False\n    \"\"\"Whether to mask the prompt section from the label (with ``ignore_index``).\"\"\"\n    prompt_style: Union[str, PromptStyle] = \"flan\"\n    \"\"\"The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles.\"\"\"\n    ignore_index: int = -100\n    \"\"\"The index to use for elements to be ignored in the label.\"\"\"\n    seed: int = 42\n    \"\"\"The random seed for shuffling the dataset.\"\"\"\n    num_workers: int = 4\n    \"\"\"How many DataLoader processes to use for loading.\"\"\"\n    download_dir: Path = Path(\"./data/flan\")\n    \"\"\"The directory in which the downloaded dataset gets saved.\"\"\"\n    url: str = _URL\n    \"\"\"The URL from where to download the dataset.\"\"\"\n    subsets: Optional[str] = None\n    \"\"\"A comma separated list of subsets to use. 
If None, all subsets are used.\"\"\"\n\n    tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)\n    batch_size: int = field(default=1, init=False, repr=False)\n    max_seq_length: int = field(default=-1, init=False, repr=False)\n    train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)\n    test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)\n\n    def __post_init__(self):\n        super().__init__()\n        if isinstance(self.prompt_style, str):\n            self.prompt_style = PromptStyle.from_name(self.prompt_style)\n\n        supported_subsets = _supported_subsets()\n        if self.subsets is not None:\n            self.subsets = self.subsets.split(\",\")\n            for subset in self.subsets:\n                if subset not in supported_subsets:\n                    raise ValueError(f\"{subset} not in {supported_subsets}\")\n        else:\n            self.subsets = list(supported_subsets)\n\n    def connect(\n        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.batch_size = batch_size\n        self.max_seq_length = -1 if max_seq_length is None else max_seq_length\n\n    def prepare_data(self) -> None:\n        self.download_dir.mkdir(parents=True, exist_ok=True)\n        for subset in self.subsets:\n            for split in (\"train\", \"test\"):\n                data_file_path = self.download_dir / f\"{subset}_{split}.jsonl\"\n                data_file_url = f\"{self.url}/{split}/{subset}_{split}.jsonl\"\n                download_if_missing(data_file_path, data_file_url)\n\n    def train_dataloader(self):\n        return self._dataloader(\"train\")\n\n    def val_dataloader(self):\n        return self._dataloader(\"test\")\n\n    def _dataloader(self, split: str) -> DataLoader:\n        data = []\n        for subset in self.subsets:\n            
data_file_path = self.download_dir / f\"{subset}_{split}.jsonl\"\n            data.extend(load_jsonl(data_file_path))\n\n        dataset = SFTDataset(\n            data=data,\n            tokenizer=self.tokenizer,\n            prompt_style=self.prompt_style,\n            max_seq_length=self.max_seq_length,\n            mask_prompt=self.mask_prompt,\n            ignore_index=self.ignore_index,\n            transform=_transform,\n        )\n        return DataLoader(\n            dataset=dataset,\n            batch_size=self.batch_size,\n            shuffle=(split == \"train\"),\n            generator=torch.Generator().manual_seed(self.seed),\n            num_workers=self.num_workers,\n            collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),\n        )\n\n\ndef load_jsonl(filename: Path) -> List[Dict[str, str]]:\n    data = []\n    with open(filename, encoding=\"utf-8\") as f:\n        for line in f:\n            data.append(json.loads(line))\n    return data\n\n\ndef _transform(item: dict) -> dict:\n    item[\"instruction\"] = item.pop(\"inputs\")\n    item[\"output\"] = item.pop(\"targets\")\n    return item\n\n\ndef _supported_subsets() -> Set[str]:\n    return {\n        \"aeslc_10templates\",\n        \"ag_news_subset_10templates\",\n        \"anli_r1_10templates\",\n        \"anli_r2_10templates\",\n        \"anli_r3_10templates\",\n        \"arc_challenge_10templates\",\n        \"arc_easy_10templates\",\n        \"bool_q_10templates\",\n        \"cb_10templates\",\n        \"cnn_dailymail_10templates\",\n        \"cola_10templates\",\n        \"common_gen_10templates\",\n        \"copa_10templates\",\n        \"coqa_10templates\",\n        \"cosmos_qa_10templates\",\n        \"dart_10templates\",\n        \"definite_pronoun_resolution_10templates\",\n        \"drop_10templates\",\n        \"e2e_nlg_10templates\",\n        \"fix_punct_10templates\",\n        \"gigaword_10templates\",\n        
\"glue_mrpc_10templates\",\n        \"glue_qqp_10templates\",\n        \"hellaswag_10templates\",\n        \"imdb_reviews_10templates\",\n        \"math_dataset_10templates\",\n        \"mnli_matched_10templates\",\n        \"mnli_mismatched_10templates\",\n        \"multi_news_10templates\",\n        \"multirc_10templates\",\n        \"natural_questions_10templates\",\n        \"openbookqa_10templates\",\n        \"opinion_abstracts_idebate_10templates\",\n        \"opinion_abstracts_rotten_tomatoes_10templates\",\n        \"para_crawl_enes_10templates\",\n        \"paws_wiki_10templates\",\n        \"piqa_10templates\",\n        \"qnli_10templates\",\n        \"quac_10templates\",\n        \"record_10templates\",\n        \"rte_10templates\",\n        \"samsum_10templates\",\n        \"sentiment140_10templates\",\n        \"snli_10templates\",\n        \"squad_v1_10templates\",\n        \"squad_v2_10templates\",\n        \"sst2_10templates\",\n        \"story_cloze_10templates\",\n        \"stsb_10templates\",\n        \"trec_10templates\",\n        \"trivia_qa_10templates\",\n        \"true_case_10templates\",\n        \"web_nlg_en_10templates\",\n        \"wic_10templates\",\n        \"wiki_lingua_english_en_10templates\",\n        \"wmt14_enfr_10templates\",\n        \"wmt16_translate_csen_10templates\",\n        \"wmt16_translate_deen_10templates\",\n        \"wmt16_translate_fien_10templates\",\n        \"wmt16_translate_roen_10templates\",\n        \"wmt16_translate_ruen_10templates\",\n        \"wmt16_translate_tren_10templates\",\n        \"wnli_10templates\",\n        \"word_segment_10templates\",\n        \"wsc_10templates\",\n        \"yelp_polarity_reviews_10templates\",\n    }\n"
  },
  {
    "path": "litgpt/data/json_data.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport json\nimport warnings\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Any, Optional, Tuple, Union\n\nimport torch\nfrom torch.utils.data import DataLoader, random_split\n\nfrom litgpt.data import DataModule, SFTDataset, get_sft_collate_fn\nfrom litgpt.prompts import PromptStyle\nfrom litgpt.tokenizer import Tokenizer\n\n\n@dataclass\nclass JSON(DataModule):\n    \"\"\"Loads JSON or JSONL data for supervised finetuning.\"\"\"\n\n    json_path: Path\n    \"\"\"A path to a JSON file or a directory with `train.json` and `val.json` containing the data.\n    The file(s) should contain a list of samples (dicts). Each dict must have the keys 'instruction' and 'output',\n    and can optionally have a key 'input' (see Alpaca).\"\"\"\n    mask_prompt: bool = False\n    \"\"\"Whether to mask the prompt section from the label (with ``ignore_index``).\"\"\"\n    val_split_fraction: Optional[float] = None\n    \"\"\"The fraction of the dataset to use for the validation dataset. The rest is used for training.\n    Only applies if you passed in a single file to `json_path`.\"\"\"\n    prompt_style: Union[str, PromptStyle] = \"alpaca\"\n    \"\"\"The style to apply to instruction prompts. 
See `litgpt.prompts` for a list of available styles.\"\"\"\n    ignore_index: int = -100\n    \"\"\"The index to use for elements to be ignored in the label.\"\"\"\n    seed: int = 42\n    \"\"\"The random seed for creating the train/val splits and shuffling the dataset.\"\"\"\n    num_workers: int = 4\n    \"\"\"How many DataLoader processes to use for loading.\"\"\"\n\n    tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)\n    batch_size: int = field(default=1, init=False, repr=False)\n    max_seq_length: int = field(default=-1, init=False, repr=False)\n    train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)\n    val_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)\n\n    def __post_init__(self):\n        super().__init__()\n        if self.json_path.is_file() and self.val_split_fraction is None:\n            self.val_split_fraction = 0.05\n            warnings.warn(\n                \"The `json_path` points to a single file and `val_split_fraction` was not set. \"\n                \"Defaulting to `val_split_fraction=0.05`. Set `val_split_fraction` explicitly \"\n                \"to use a different split percentage.\",\n                UserWarning,\n                stacklevel=2,\n            )\n        if self.json_path.is_dir() and self.val_split_fraction is not None:\n            raise ValueError(\n                \"If `json_path` is a directory, it must contain 'train.json' and 'val.json' files and\"\n                f\" hence `val_split_fraction` should not be set. 
Got `{self.val_split_fraction=}`.\"\n            )\n        if not self.json_path.exists():\n            raise FileNotFoundError(\n                \"The `json_path` must be a file or a directory containing 'train.json' and 'val.json' files,\"\n                f\" but '{self.json_path!s}' does not exist.\"\n            )\n        if isinstance(self.prompt_style, str):\n            self.prompt_style = PromptStyle.from_name(self.prompt_style)\n\n    def connect(\n        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.batch_size = batch_size\n        self.max_seq_length = -1 if max_seq_length is None else max_seq_length\n\n    def setup(self, stage: str = \"\") -> None:\n        train_data, test_data = self.get_splits()\n\n        self.train_dataset = SFTDataset(\n            data=train_data,\n            tokenizer=self.tokenizer,\n            prompt_style=self.prompt_style,\n            max_seq_length=self.max_seq_length,\n            mask_prompt=self.mask_prompt,\n            ignore_index=self.ignore_index,\n        )\n        self.test_dataset = SFTDataset(\n            data=test_data,\n            tokenizer=self.tokenizer,\n            prompt_style=self.prompt_style,\n            max_seq_length=self.max_seq_length,\n            mask_prompt=self.mask_prompt,\n            ignore_index=self.ignore_index,\n        )\n\n    def train_dataloader(self) -> DataLoader:\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.batch_size,\n            shuffle=True,\n            generator=torch.Generator().manual_seed(self.seed),\n            num_workers=self.num_workers,\n            collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),\n        )\n\n    def val_dataloader(self) -> DataLoader:\n        return DataLoader(\n            self.test_dataset,\n            
batch_size=self.batch_size,\n            shuffle=False,\n            num_workers=self.num_workers,\n            collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),\n        )\n\n    def get_splits(self) -> Tuple:\n        # A single file (gets split into train and test)\n        if self.json_path.is_file():\n            data = load_split(self.json_path)\n\n            # Partition the dataset into train and test\n            train_data, test_data = random_split(\n                data,\n                [1.0 - self.val_split_fraction, self.val_split_fraction],\n                generator=torch.Generator().manual_seed(self.seed),\n            )\n            return train_data, test_data\n\n        # A directory containing train.json and val.json\n        if (train_file := self.find_split(\"train\")) and (val_file := self.find_split(\"val\")):\n            train_data = load_split(train_file)\n            test_data = load_split(val_file)\n            return train_data, test_data\n\n        raise FileNotFoundError(\n            \"The `json_path` must be a file or a directory containing 'train.json' and 'val.json' files.\"\n        )\n\n    def find_split(self, split_name: str) -> Optional[Path]:\n        for suffix in (\".json\", \".jsonl\"):\n            if (file := self.json_path / f\"{split_name}{suffix}\").is_file():\n                return file\n        return None\n\n\ndef load_split(json_path: Path) -> Any:\n    if json_path.suffix == \".json\":\n        with open(json_path, encoding=\"utf-8\") as file:\n            return json.load(file)\n    if json_path.suffix == \".jsonl\":\n        with open(json_path, encoding=\"utf-8\") as file:\n            return [json.loads(line) for line in file]\n    else:\n        raise ValueError(f\"Unsupported file format: {json_path.suffix}. Expected `.json` or `.jsonl`.\")\n"
  },
  {
    "path": "litgpt/data/lima.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\"\"\"Implementation derived from https://github.com/tloen/alpaca-lora\"\"\"\n\nimport os\nfrom dataclasses import dataclass, field\nfrom typing import List, Optional, Union\n\nimport torch\nfrom torch.utils.data import DataLoader, random_split\n\nfrom litgpt.data import DataModule, SFTDataset, get_sft_collate_fn\nfrom litgpt.prompts import PromptStyle\nfrom litgpt.tokenizer import Tokenizer\n\n\n@dataclass\nclass LIMA(DataModule):\n    \"\"\"LIMA data module for supervised finetuning.\"\"\"\n\n    mask_prompt: bool = False\n    \"\"\"Whether to mask the prompt section from the label (with ``ignore_index``).\"\"\"\n    val_split_fraction: float = 0.1\n    \"\"\"The fraction of the dataset to use for the validation dataset. The rest is used for training.\"\"\"\n    prompt_style: Union[str, PromptStyle] = \"alpaca\"\n    \"\"\"The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles.\"\"\"\n    ignore_index: int = -100\n    \"\"\"The index to use for elements to be ignored in the label.\"\"\"\n    seed: int = 42\n    \"\"\"The random seed for creating the train/val splits and shuffling the dataset.\"\"\"\n    num_workers: int = 4\n    \"\"\"How many DataLoader processes to use for loading.\"\"\"\n    include_multiturn_conversations: bool = False\n    \"\"\"Whether to include multi-turn conversations in the dataset.\"\"\"\n    repo_id: str = \"GAIR/lima\"\n    \"\"\"The Hugging Face dataset repository ID from where to download the data.\"\"\"\n    access_token: Optional[str] = field(repr=False, default=os.getenv(\"HF_TOKEN\"))\n    \"\"\"The Hugging Face API token to use for authentication. 
Can also be set through the\n    `HF_TOKEN` environment variable.\"\"\"\n\n    tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)\n    batch_size: int = field(default=1, init=False, repr=False)\n    max_seq_length: int = field(default=-1, init=False, repr=False)\n    train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)\n    test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)\n\n    def __post_init__(self):\n        super().__init__()\n        if self.access_token is None:\n            raise ValueError(\n                \"LIMA requires authentication, please set the `HF_TOKEN=your_token` environment\"\n                \" variable or pass --access_token=your_token. You can find your token by visiting\"\n                \" https://huggingface.co/settings/tokens\"\n            )\n        if isinstance(self.prompt_style, str):\n            self.prompt_style = PromptStyle.from_name(self.prompt_style)\n\n    def connect(\n        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.batch_size = batch_size\n        self.max_seq_length = -1 if max_seq_length is None else max_seq_length\n\n    def prepare_data(self) -> None:\n        from datasets import load_dataset\n\n        load_dataset(self.repo_id, token=self.access_token)\n\n    def setup(self, stage: str = \"\") -> None:\n        from datasets import load_dataset\n\n        dataset = load_dataset(self.repo_id, token=self.access_token)\n        data = format_dataset(dataset[\"train\"], self.include_multiturn_conversations)\n\n        # Partition the dataset into train and test\n        train_data, test_data = random_split(\n            data,\n            [1.0 - self.val_split_fraction, self.val_split_fraction],\n            generator=torch.Generator().manual_seed(self.seed),\n        )\n        train_data, test_data = 
list(train_data), list(test_data)\n\n        self.train_dataset = SFTDataset(\n            data=train_data,\n            tokenizer=self.tokenizer,\n            prompt_style=self.prompt_style,\n            max_seq_length=self.max_seq_length,\n            mask_prompt=self.mask_prompt,\n            ignore_index=self.ignore_index,\n        )\n        self.test_dataset = SFTDataset(\n            data=test_data,\n            tokenizer=self.tokenizer,\n            prompt_style=self.prompt_style,\n            max_seq_length=self.max_seq_length,\n            mask_prompt=self.mask_prompt,\n            ignore_index=self.ignore_index,\n        )\n\n    def train_dataloader(self) -> DataLoader:\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.batch_size,\n            shuffle=True,\n            generator=torch.Generator().manual_seed(self.seed),\n            num_workers=self.num_workers,\n            collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),\n        )\n\n    def val_dataloader(self) -> DataLoader:\n        return DataLoader(\n            self.test_dataset,\n            batch_size=self.batch_size,\n            shuffle=False,\n            num_workers=self.num_workers,\n            collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),\n        )\n\n\ndef format_dataset(dataset_partition: dict, include_multi_turn_conversations: bool) -> List[dict]:\n    formatted_ds = []\n\n    for entry in dataset_partition:\n        convo = entry[\"conversations\"]\n        if include_multi_turn_conversations:\n            for i in range(0, len(convo) - 1, 2):\n                formatted_ds.append({\"instruction\": convo[i], \"input\": \"\", \"output\": convo[i + 1]})\n        else:\n            formatted_ds.append({\"instruction\": convo[0], \"input\": \"\", \"output\": convo[1]})\n\n    return formatted_ds\n"
  },
  {
    "path": "litgpt/data/lit_data.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport os\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Optional, Tuple, Union\n\nfrom torch.utils.data import DataLoader\n\nfrom litgpt.data import DataModule\nfrom litgpt.tokenizer import Tokenizer\n\n\n@dataclass\nclass LitData(DataModule):\n    \"\"\"Loads data using LitData's StreamingDataset given a path to a folder of preprocessed data (chunks).\"\"\"\n\n    data_path: Union[str, Path] = Path(\"data/\")\n    \"\"\"The path to the data directory containing the preprocessed chunks for the streaming dataset.\n    The path can also be a remote path (e.g., s3://). See also ``split_names`` if this path contains subfolders\n    for training and validation splits.\"\"\"\n    split_names: Optional[Tuple[str, str]] = None\n    \"\"\"Optional tuple for names of subfolders for training and validation under ``data_path``. If not provided,\n    all data under data_path will be used for training, and the validation dataloader will be identical to the\n    train dataloader.\"\"\"\n    seed: int = 42\n    \"\"\"The random seed for shuffling the dataset.\"\"\"\n    num_workers: int = 8\n    \"\"\"How many DataLoader processes to use for loading.\"\"\"\n\n    batch_size: int = field(init=False, repr=False, default=1)\n    seq_length: int = field(init=False, repr=False, default=2048)\n\n    def __post_init__(self) -> None:\n        super().__init__()\n        if self.split_names is not None and len(self.split_names) != 2:\n            raise ValueError(\"If provided `split_names` must be a tuple of two strings, for example: ('train', 'val').\")\n\n    def connect(\n        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None\n    ) -> None:\n        self.batch_size = batch_size\n        self.seq_length = max_seq_length + 1  # Increase by one because we need the next token as well\n\n    def train_dataloader(self) -> DataLoader:\n        input_dir = os.path.join(self.data_path, self.split_names[0]) if self.split_names else str(self.data_path)\n        return self._dataloader(input_dir=input_dir, train=True)\n\n    def val_dataloader(self) -> DataLoader:\n        input_dir = os.path.join(self.data_path, self.split_names[1]) if self.split_names else str(self.data_path)\n        return self._dataloader(input_dir=input_dir, train=False)\n\n    def _dataloader(self, input_dir: str, train: bool):\n        from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader\n\n        dataset = StreamingDataset(\n            input_dir=input_dir,\n            item_loader=TokensLoader(block_size=self.seq_length),\n            shuffle=train,\n            seed=self.seed,\n        )\n        dataloader = StreamingDataLoader(\n            dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True\n        )\n        return dataloader\n"
  },
  {
    "path": "litgpt/data/longform.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport json\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom litgpt.data import DataModule, SFTDataset, get_sft_collate_fn\nfrom litgpt.data.alpaca import download_if_missing\nfrom litgpt.prompts import PromptStyle\nfrom litgpt.tokenizer import Tokenizer\n\n_URL = \"https://raw.githubusercontent.com/akoksal/LongForm/main/dataset\"\n\n\n@dataclass\nclass LongForm(DataModule):\n    \"\"\"LongForm data module for supervised finetuning.\"\"\"\n\n    mask_prompt: bool = False\n    \"\"\"Whether to mask the prompt section from the label (with ``ignore_index``).\"\"\"\n    prompt_style: Union[str, PromptStyle] = \"longform\"\n    \"\"\"The style to apply to instruction prompts. See `litgpt.prompts` for a list of available styles.\"\"\"\n    ignore_index: int = -100\n    \"\"\"The index to use for elements to be ignored in the label.\"\"\"\n    seed: int = 42\n    \"\"\"The random seed for shuffling the dataset.\"\"\"\n    num_workers: int = 4\n    \"\"\"How many DataLoader processes to use for loading.\"\"\"\n    download_dir: Path = Path(\"./data/longform\")\n    \"\"\"The directory in which the downloaded dataset gets saved.\"\"\"\n\n    tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)\n    batch_size: int = field(default=1, init=False, repr=False)\n    max_seq_length: int = field(default=-1, init=False, repr=False)\n    train_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)\n    test_dataset: Optional[SFTDataset] = field(default=None, init=False, repr=False)\n\n    def __post_init__(self) -> None:\n        super().__init__()\n        if isinstance(self.prompt_style, str):\n            self.prompt_style = PromptStyle.from_name(self.prompt_style)\n\n    def connect(\n        self, tokenizer: 
Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.batch_size = batch_size\n        self.max_seq_length = -1 if max_seq_length is None else max_seq_length\n\n    def prepare_data(self) -> None:\n        self.download_dir.mkdir(parents=True, exist_ok=True)\n        download_if_missing(self.download_dir / \"train.json\", f\"{_URL}/train.json\")\n        download_if_missing(self.download_dir / \"val.json\", f\"{_URL}/val.json\")\n\n    def train_dataloader(self):\n        return self._dataloader(\"train\")\n\n    def val_dataloader(self):\n        return self._dataloader(\"val\")\n\n    def _dataloader(self, split: str) -> DataLoader:\n        with open(self.download_dir / f\"{split}.json\", encoding=\"utf-8\") as file:\n            data = json.load(file)\n\n        dataset = SFTDataset(\n            data=data,\n            tokenizer=self.tokenizer,\n            prompt_style=self.prompt_style,\n            max_seq_length=self.max_seq_length,\n            mask_prompt=self.mask_prompt,\n            ignore_index=self.ignore_index,\n            transform=_transform,\n        )\n        return DataLoader(\n            dataset=dataset,\n            batch_size=self.batch_size,\n            shuffle=(split == \"train\"),\n            generator=torch.Generator().manual_seed(self.seed),\n            num_workers=self.num_workers,\n            collate_fn=get_sft_collate_fn(max_seq_length=self.max_seq_length, ignore_index=self.ignore_index),\n        )\n\n\ndef _transform(item: dict) -> dict:\n    item[\"instruction\"] = item.pop(\"input\")\n    return item\n"
  },
  {
    "path": "litgpt/data/microllama.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Union\n\nfrom litgpt.data.tinyllama import TinyLlama\n\n\n@dataclass\nclass MicroLlama(TinyLlama):\n    \"\"\"The MicroLlama data module is composed of only SlimPajama data.\"\"\"\n\n    def __init__(self, data_path: Union[str, Path] = Path(\"data/\"), seed: int = 42, num_workers: int = 8):\n        super().__init__(data_path=data_path, seed=seed, num_workers=num_workers, use_starcoder=False)\n"
  },
  {
    "path": "litgpt/data/openwebtext.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport os\nfrom dataclasses import dataclass, field\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nfrom torch.utils.data import DataLoader\n\nfrom litgpt.data import DataModule\nfrom litgpt.tokenizer import Tokenizer\n\n\n@dataclass\nclass OpenWebText(DataModule):\n    \"\"\"The OpenWebText data module for pretraining.\"\"\"\n\n    data_path: Union[str, Path] = Path(\"data/openwebtext\")\n    \"\"\"The path to the data directory, containing two folders 'train' and 'val'\n    which are the output of the preprocessing step. The path can also be a remote path (e.g., s3://).\"\"\"\n    val_split_fraction: float = 0.0005\n    \"\"\"The fraction of data that should be put aside for validation.\"\"\"\n    seed: int = 42\n    \"\"\"The seed to use for shuffling the training data.\"\"\"\n    num_workers: int = 8\n    \"\"\"The number of workers to use for the dataloaders.\"\"\"\n\n    tokenizer: Optional[Tokenizer] = field(default=None, repr=False, init=False)\n    batch_size: int = field(default=1, repr=False, init=False)\n    seq_length: int = field(default=2048, repr=False, init=False)\n\n    def __post_init__(self) -> None:\n        super().__init__()\n        # Could be a remote path (s3://) or a local path\n        self.data_path_train = str(self.data_path).rstrip(\"/\") + \"/train\"\n        self.data_path_val = str(self.data_path).rstrip(\"/\") + \"/val\"\n\n    def connect(\n        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = 2048\n    ) -> None:\n        self.tokenizer = tokenizer\n        self.batch_size = batch_size\n        self.seq_length = max_seq_length + 1  # Increase by one because we need the next token as well\n\n    def prepare_data(self) -> None:\n        from datasets import Dataset, load_dataset\n        from litdata import optimize\n\n        if 
str(self.data_path).startswith(\"s3://\"):\n            print(f\"The OpenWebText data path points to an S3 location: {self.data_path}. Skipping preprocessing.\")\n            return\n\n        if Path(self.data_path_train).is_dir() and Path(self.data_path_val).is_dir():\n            print(f\"Found OpenWebText train and val dir: {self.data_path}. Skipping preprocessing.\")\n            return\n\n        dataset = load_dataset(\"openwebtext\", num_proc=(os.cpu_count() // 2), trust_remote_code=True)\n\n        # Split the data in training and validation\n        split_dataset = dataset[\"train\"].train_test_split(\n            test_size=self.val_split_fraction, seed=self.seed, shuffle=True\n        )\n        split_dataset[\"val\"] = split_dataset.pop(\"test\")  # rename the test split to val\n\n        def tokenize(data: Dataset, index: int):\n            yield self.tokenizer.encode(data[index][\"text\"], eos=True)\n\n        optimize(\n            fn=partial(tokenize, split_dataset[\"train\"]),\n            inputs=list(range(len(split_dataset[\"train\"]))),\n            output_dir=self.data_path_train,\n            num_workers=min(64, os.cpu_count() - 1),\n            chunk_bytes=\"200MB\",\n        )\n        optimize(\n            fn=partial(tokenize, split_dataset[\"val\"]),\n            inputs=list(range(len(split_dataset[\"val\"]))),\n            output_dir=self.data_path_val,\n            num_workers=min(8, os.cpu_count() - 1),\n            chunk_bytes=\"200MB\",\n        )\n\n    def train_dataloader(self) -> DataLoader:\n        from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader\n\n        train_dataset = StreamingDataset(\n            input_dir=self.data_path_train,\n            item_loader=TokensLoader(block_size=self.seq_length),\n            shuffle=True,\n        )\n        train_dataloader = StreamingDataLoader(\n            train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, 
drop_last=True\n        )\n        return train_dataloader\n\n    def val_dataloader(self) -> DataLoader:\n        from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader\n\n        val_dataset = StreamingDataset(\n            input_dir=self.data_path_val,\n            item_loader=TokensLoader(block_size=self.seq_length),\n            shuffle=True,\n        )\n        val_dataloader = StreamingDataLoader(\n            val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True\n        )\n        return val_dataloader\n"
  },
  {
    "path": "litgpt/data/prepare_slimpajama.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport json\nimport os\nimport time\nfrom pathlib import Path\n\nfrom litgpt.data.prepare_starcoder import DataChunkRecipe\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.utils import CLI, extend_checkpoint_dir\n\n\nclass SlimPajamaDataRecipe(DataChunkRecipe):\n    is_generator = True\n\n    def __init__(self, tokenizer: Tokenizer, chunk_size: int):\n        super().__init__(chunk_size)\n        self.tokenizer = tokenizer\n\n    def prepare_structure(self, input_dir):\n        files = Path(input_dir).rglob(\"*.zst\")\n        return [str(file) for file in files]\n\n    def prepare_item(self, filepath):\n        import zstandard as zstd\n\n        with zstd.open(open(filepath, \"rb\"), \"rt\", encoding=\"utf-8\") as f:\n            for row in f:\n                text = json.loads(row)[\"text\"]\n                if json.loads(row)[\"meta\"][\"redpajama_set_name\"] == \"RedPajamaGithub\":\n                    continue  # exclude the GitHub data since it overlaps with starcoder\n                text_ids = self.tokenizer.encode(string=text, bos=False, eos=True)\n                yield text_ids\n\n\ndef prepare(\n    input_dir: Path = Path(\"data/SlimPajama-627B/train\"),\n    output_dir: Path = Path(\"data/slimpajama/train\"),\n    tokenizer_path: Path = Path(\"checkpoints/Llama-2-7b-hf/\"),\n    chunk_size: int = (2049 * 16384),\n    fast_dev_run: bool = False,\n) -> None:\n    from litdata.processing.data_processor import DataProcessor\n\n    tokenizer_path = extend_checkpoint_dir(tokenizer_path)\n    tokenizer = Tokenizer(tokenizer_path)\n    data_recipe = SlimPajamaDataRecipe(tokenizer=tokenizer, chunk_size=chunk_size)\n    data_processor = DataProcessor(\n        input_dir=str(input_dir),\n        output_dir=str(output_dir),\n        fast_dev_run=fast_dev_run,\n        num_workers=os.cpu_count(),\n        num_downloaders=1,\n    )\n\n    start_time = 
time.time()\n    data_processor.run(data_recipe)\n    elapsed_time = time.time() - start_time\n    print(f\"Time taken: {elapsed_time:.2f} seconds\")\n\n\nif __name__ == \"__main__\":\n    CLI(prepare)\n"
  },
  {
    "path": "litgpt/data/prepare_starcoder.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\nimport time\nimport traceback\nfrom pathlib import Path\n\nfrom litgpt.constants import _LITDATA_AVAILABLE\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.utils import CLI, extend_checkpoint_dir\n\nif _LITDATA_AVAILABLE:\n    from litdata.processing.data_processor import DataChunkRecipe\nelse:\n    DataChunkRecipe = object\n\n\nclass StarcoderDataRecipe(DataChunkRecipe):\n    is_generator = True\n\n    def __init__(self, tokenizer: Tokenizer, chunk_size: int):\n        super().__init__(chunk_size)\n        self.tokenizer = tokenizer\n\n    def prepare_structure(self, input_dir):\n        files = Path(input_dir).rglob(\"*.parquet\")\n        return [str(file) for file in files]\n\n    def prepare_item(self, item_metadata):\n        import pyarrow.parquet as pq\n\n        filepath = item_metadata\n        start = time.time()\n\n        try:\n            parquet_file = pq.ParquetFile(filepath)\n            # reduce RAM usage\n            for batch in parquet_file.iter_batches(batch_size=8192, columns=[\"content\"]):\n                for text in batch.to_pandas()[\"content\"]:\n                    yield self.tokenizer.encode(text, bos=False, eos=True)\n\n        except Exception:\n            print(traceback.format_exc())\n            print(f\"Error reading {filepath}\")\n            return\n\n        parquet_file.close()\n        end = time.time()\n        print(f\"Took {end - start:.2f} seconds total\", filepath)\n\n\ndef prepare(\n    input_dir: Path = Path(\"data/starcoderdata\"),\n    output_dir: Path = Path(\"data/starcoder\"),\n    tokenizer_path: Path = Path(\"checkpoints/Llama-2-7b-hf/\"),\n    chunk_size: int = (2049 * 8192),\n    fast_dev_run: bool = False,\n) -> None:\n    from litdata.processing.data_processor import DataProcessor\n\n    tokenizer_path = extend_checkpoint_dir(tokenizer_path)\n    tokenizer = Tokenizer(tokenizer_path)\n    
data_recipe = StarcoderDataRecipe(tokenizer=tokenizer, chunk_size=chunk_size)\n    data_processor = DataProcessor(\n        input_dir=str(input_dir),\n        output_dir=str(output_dir),\n        fast_dev_run=fast_dev_run,\n        num_workers=os.cpu_count(),\n        num_downloaders=1,\n    )\n\n    start_time = time.time()\n    data_processor.run(data_recipe)\n    elapsed_time = time.time() - start_time\n    print(f\"Time taken: {elapsed_time:.2f} seconds\")\n\n\nif __name__ == \"__main__\":\n    CLI(prepare)\n"
  },
  {
    "path": "litgpt/data/text_files.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport glob\nimport os\nfrom dataclasses import dataclass, field\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Optional\n\nfrom torch.utils.data import DataLoader\n\nfrom litgpt.data import DataModule\nfrom litgpt.tokenizer import Tokenizer\n\n\n@dataclass\nclass TextFiles(DataModule):\n    \"\"\"The TextFiles data module used for pretraining.\n\n    Reads in text data from plaintext files contained in a data folder\n    and provides training and validation dataloaders that return batches of tokens.\n    Every sample is set to a fixed length.\n    \"\"\"\n\n    train_data_path: Path\n    \"\"\"The path to the data directory used for training that contains .txt files.\"\"\"\n    val_data_path: Optional[Path] = None\n    \"\"\"The path to the data directory used for validation that\n    contains .txt files. Splits off data for validation from the\n    training set if None.\"\"\"\n    seed: int = 42\n    \"\"\"The seed to use for shuffling the dataset.\"\"\"\n    num_workers: int = 4\n    \"\"\"The number of workers to use for data loading.\"\"\"\n\n    tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)\n    batch_size: int = field(default=1, init=False, repr=False)\n    max_seq_length: int = field(default=-1, init=False, repr=False)\n\n    def __post_init__(self) -> None:\n        super().__init__()\n        self.out_path_train = self.train_data_path / \"train\"\n        if self.val_data_path is None:\n            self.out_path_val = self.train_data_path / \"val\"\n        else:\n            self.out_path_val = Path(self.val_data_path) / \"val\"\n\n    def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None:\n        self.tokenizer = tokenizer\n        self.batch_size = batch_size\n        self.max_seq_length = max_seq_length + 1  # Increase by one because we need the next token as well\n\n    def prepare_data(self) -> None:\n        from litdata import optimize\n        from litdata.streaming import TokensLoader\n\n        train_files = sorted(glob.glob(str(self.train_data_path / \"*.txt\")))\n        assert len(train_files) > 0, f\"No .txt files found in train data {train_files}\"\n\n        if self.val_data_path is not None:\n            self.val_data_path = Path(self.val_data_path)\n            val_files = sorted(glob.glob(str(self.val_data_path / \"*.txt\")))\n            assert len(val_files) > 0, f\"No .txt files found in validation data {val_files}\"\n        # Train/val split: use only the first file (shard 0) for validation and the rest for training\n        else:\n            assert len(train_files) > 1, f\"Expected at least two .txt files in {train_files}\"\n            val_files, *train_files = train_files\n            val_files = [val_files]\n\n        # It's ok to use almost all CPUs here because this runs in a single process\n        num_workers = os.cpu_count() - 1\n        use_workers = min(num_workers, len(train_files))\n        if not Path(self.out_path_train).is_dir():\n            validate_tokenizer(self.tokenizer)\n            optimize(\n                fn=partial(tokenize, tokenizer=self.tokenizer),\n                inputs=train_files,\n                output_dir=str(self.out_path_train),\n                num_workers=use_workers,\n                chunk_bytes=\"50MB\",\n                item_loader=TokensLoader(block_size=self.max_seq_length),\n            )\n        else:\n            print(\n                f\"\\nWarning: Preprocessed training data found in {self.out_path_train}.\"\n                \" For efficiency, reprocessing is skipped. If your text input has changed since\"\n                \" the last `litgpt pretrain` command, remove the preprocessed file(s) to trigger\"\n                f\" reprocessing: `rm -rf {self.out_path_train}`\\n\"\n            )\n        use_workers = min(num_workers, len(val_files))\n        if not Path(self.out_path_val).is_dir():\n            validate_tokenizer(self.tokenizer)\n            optimize(\n                fn=partial(tokenize, tokenizer=self.tokenizer),\n                inputs=val_files,\n                output_dir=str(self.out_path_val),\n                num_workers=use_workers,\n                chunk_bytes=\"50MB\",\n                item_loader=TokensLoader(block_size=self.max_seq_length),\n            )\n        else:\n            print(\n                f\"\\nWarning: Preprocessed validation data found in {self.out_path_val}.\"\n                \" For efficiency, reprocessing is skipped. If your text input has changed since\"\n                \" the last `litgpt pretrain` command, remove the preprocessed file(s) to trigger\"\n                f\" reprocessing: `rm -rf {self.out_path_val}`\\n\"\n            )\n\n    def train_dataloader(self) -> DataLoader:\n        from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader\n\n        train_dataset = StreamingDataset(\n            input_dir=str(self.out_path_train),\n            item_loader=TokensLoader(block_size=self.max_seq_length),\n            shuffle=True,\n        )\n\n        train_dataloader = StreamingDataLoader(\n            train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True\n        )\n        return train_dataloader\n\n    def val_dataloader(self) -> DataLoader:\n        from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader\n\n        val_dataset = StreamingDataset(\n            input_dir=str(self.out_path_val),\n            item_loader=TokensLoader(block_size=self.max_seq_length),\n            shuffle=True,\n        )\n        val_dataloader = StreamingDataLoader(\n            val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True\n        )\n        return val_dataloader\n\n\ndef tokenize(filename: str, tokenizer: Tokenizer):\n    with open(filename, encoding=\"utf-8\") as file:\n        text = file.read()\n    text = text.strip()\n    yield tokenizer.encode(text, bos=True, eos=False)\n\n\ndef validate_tokenizer(tokenizer: Tokenizer) -> None:\n    if tokenizer is None:\n        raise ValueError(\n            \"Tokenizer is None. If you are using this data module via `litgpt pretrain`, \"\n            \"please provide a valid `--tokenizer_dir` path.\"\n        )\n"
  },
  {
    "path": "litgpt/data/tinyllama.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nfrom torch.utils.data import DataLoader\n\nfrom litgpt.data import DataModule\nfrom litgpt.tokenizer import Tokenizer\n\n\n@dataclass\nclass TinyLlama(DataModule):\n    \"\"\"The TinyLlama data module is composed of a mix of SlimPajama and Starcoder data.\n\n    Provides training and validation streaming dataloaders that return batches of tokens.\n    \"\"\"\n\n    data_path: Union[str, Path] = Path(\"data/\")\n    \"\"\"The path to the data directory, containing two folders 'slimpajama' and 'starcoder'\n    which are the output of the preprocessing step done in advance. See the `tutorial/pretrain_tinyllama.md`\n    for instructions. The path can also be a remote path (e.g., s3://).\"\"\"\n    seed: int = 42\n    \"\"\"The random seed for shuffling the dataset.\"\"\"\n    num_workers: int = 8\n    \"\"\"How many DataLoader processes to use for loading.\"\"\"\n    use_starcoder: bool = True\n    \"\"\"Toggle for using Starcoder data.\"\"\"\n\n    batch_size: int = field(init=False, repr=False, default=1)\n    seq_length: int = field(init=False, repr=False, default=2048)\n\n    def __post_init__(self):\n        super().__init__()\n        # Could be a remote path (s3://) or a local path\n        self.slimpajama_train = str(self.data_path).rstrip(\"/\") + \"/slimpajama/train\"\n        self.slimpajama_val = str(self.data_path).rstrip(\"/\") + \"/slimpajama/val\"\n        self.required_paths = [self.slimpajama_train, self.slimpajama_val]\n\n        if self.use_starcoder:\n            self.starcoder_train = str(self.data_path).rstrip(\"/\") + \"/starcoder\"\n            self.required_paths += [self.starcoder_train]\n\n    def connect(\n        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None\n    ) -> None:\n       
 self.batch_size = batch_size\n        self.seq_length = max_seq_length + 1  # Increase by one because we need the next token as well\n\n    def prepare_data(self) -> None:\n        for path in self.required_paths:\n            if not path.startswith(\"s3://\") and not Path(path).is_dir():\n                raise FileNotFoundError(\n                    \"The data path for TinyLlama is expected to be the directory containing these subdirectories:\"\n                    f\" `slimpajama/train`, `slimpajama/val`, `starcoder`. The directory {path} does not exist.\"\n                    \" Set it via `--data.data_path=...`\"\n                )\n\n    def train_dataloader(self) -> DataLoader:\n        from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset, TokensLoader\n\n        slim_train_data = StreamingDataset(\n            input_dir=self.slimpajama_train,\n            item_loader=TokensLoader(block_size=self.seq_length),\n            shuffle=True,\n            drop_last=True,\n        )\n        train_data = slim_train_data\n\n        if self.use_starcoder:\n            train_datasets = [\n                slim_train_data,\n                StreamingDataset(\n                    input_dir=self.starcoder_train,\n                    item_loader=TokensLoader(block_size=self.seq_length),\n                    shuffle=True,\n                    drop_last=True,\n                ),\n            ]\n\n            # Mix SlimPajama data and Starcoder data with these proportions:\n            weights = (0.693584, 0.306416)\n            train_data = CombinedStreamingDataset(\n                datasets=train_datasets, seed=self.seed, weights=weights, iterate_over_all=False\n            )\n\n        train_dataloader = StreamingDataLoader(\n            train_data, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True\n        )\n        return train_dataloader\n\n    def val_dataloader(self) -> DataLoader:\n       
 from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader\n\n        val_dataset = StreamingDataset(\n            input_dir=self.slimpajama_val,\n            item_loader=TokensLoader(block_size=self.seq_length),\n            shuffle=True,\n        )\n        val_dataloader = StreamingDataLoader(\n            val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True\n        )\n        return val_dataloader\n"
  },
  {
    "path": "litgpt/data/tinystories.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport glob\nimport json\nimport os\nfrom dataclasses import dataclass, field\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Optional\n\nfrom torch.utils.data import DataLoader\nfrom tqdm import tqdm\n\nfrom litgpt.data import DataModule\nfrom litgpt.data.alpaca import download_if_missing\nfrom litgpt.data.text_files import validate_tokenizer\nfrom litgpt.tokenizer import Tokenizer\n\n\n@dataclass\nclass TinyStories(DataModule):\n    \"\"\"The TinyStories data module: https://huggingface.co/datasets/roneneldan/TinyStories\n\n    Provides training and validation dataloaders that return batches of tokens. Every sample is set to a fixed length.\n    \"\"\"\n\n    data_path: Path = Path(\"data/tinystories\")\n    \"\"\"The path to the data directory, containing two folders 'train' and 'val'\n    which are the output of the preprocessing step.\"\"\"\n    seed: int = 42\n    \"\"\"The seed to use for shuffling the dataset.\"\"\"\n    num_workers: int = 8\n    \"\"\"The number of workers to use for the dataloaders.\"\"\"\n\n    tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)\n    batch_size: int = field(default=1, init=False, repr=False)\n    max_seq_length: int = field(default=-1, init=False, repr=False)\n\n    def __post_init__(self) -> None:\n        super().__init__()\n        self.data_path_train = self.data_path / \"train\"\n        self.data_path_val = self.data_path / \"val\"\n\n    def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None:\n        self.tokenizer = tokenizer\n        self.batch_size = batch_size\n        self.max_seq_length = max_seq_length + 1  # Increase by one because we need the next token as well\n\n    def prepare_data(self) -> None:\n        from litdata import TokensLoader, optimize\n\n        download(self.data_path)\n\n        
files = sorted(glob.glob(str(self.data_path / \"TinyStories_all_data\" / \"*.json\")))\n        assert len(files) > 0, f\"No json files found in {files}\"\n        assert len(files) > 1, f\"Expected at least two json files in {files}\"\n        # train/test split. let's use only shard 0 for test split, rest train\n        val_file, *train_files = files\n        num_workers = os.cpu_count() - 1\n\n        if not Path(self.data_path_train).is_dir():\n            validate_tokenizer(self.tokenizer)\n            optimize(\n                fn=partial(tokenize, tokenizer=self.tokenizer),\n                inputs=train_files,\n                output_dir=str(self.data_path_train),\n                num_workers=num_workers,\n                chunk_bytes=\"200MB\",\n                item_loader=TokensLoader(),\n            )\n        if not Path(self.data_path_val).is_dir():\n            validate_tokenizer(self.tokenizer)\n            optimize(\n                fn=partial(tokenize, tokenizer=self.tokenizer),\n                inputs=[val_file],\n                output_dir=str(self.data_path_val),\n                num_workers=1,  # there's only 1 file\n                chunk_bytes=\"200MB\",\n                item_loader=TokensLoader(),\n            )\n\n    def train_dataloader(self) -> DataLoader:\n        from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader\n\n        train_dataset = StreamingDataset(\n            input_dir=str(self.data_path_train),\n            item_loader=TokensLoader(block_size=self.max_seq_length),\n            shuffle=True,\n        )\n        train_dataloader = StreamingDataLoader(\n            train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True\n        )\n        return train_dataloader\n\n    def val_dataloader(self) -> DataLoader:\n        from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader\n\n        val_dataset = StreamingDataset(\n          
  input_dir=str(self.data_path_val),\n            item_loader=TokensLoader(block_size=self.max_seq_length),\n            shuffle=True,\n        )\n        val_dataloader = StreamingDataLoader(\n            val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True\n        )\n        return val_dataloader\n\n\ndef tokenize(filename: str, tokenizer: Tokenizer):\n    with open(filename, encoding=\"utf-8\") as f:\n        data = json.load(f)\n    global_rank = int(os.environ[\"DATA_OPTIMIZER_GLOBAL_RANK\"])\n    num_workers = int(os.environ[\"DATA_OPTIMIZER_NUM_WORKERS\"])\n    local_rank = global_rank % num_workers\n    for example in tqdm(data, position=local_rank):\n        text = example[\"story\"]\n        text = text.strip()  # get rid of leading/trailing whitespace\n        tokens = tokenizer.encode(text, bos=True, eos=False)  # encode the text, use BOS\n        yield tokens\n\n\n_URL = \"https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz\"\n\n\ndef download(data_dir: Path):\n    data_dir.mkdir(exist_ok=True, parents=True)\n\n    data_tar = data_dir / \"TinyStories_all_data.tar.gz\"\n    data_dir = data_dir / \"TinyStories_all_data\"\n    shard_filenames = sorted(glob.glob(str(data_dir / \"*.json\")))\n    if shard_filenames:\n        print(f\"{data_dir} already exists, skipping unpacking...\")\n        return\n\n    # download the TinyStories dataset, unless it's already downloaded\n    download_if_missing(data_tar, _URL, stream=True, mode=\"wb\")\n\n    # unpack the tar.gz file into all the data shards (json files)\n    data_dir.mkdir(exist_ok=False)\n    tar_command = f\"tar -xzf {data_tar} -C {data_dir}\"\n    print(tar_command)\n    os.system(tar_command)\n    shard_filenames = sorted(glob.glob(str(data_dir / \"*.json\")))\n    print(f\"Number of shards: {len(shard_filenames)}\")\n"
  },
  {
    "path": "litgpt/deploy/__init__.py",
    "content": ""
  },
  {
    "path": "litgpt/deploy/serve.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport json\nimport sys\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Any, Dict, Literal, Optional\n\nimport torch\n\nfrom litgpt.api import LLM\nfrom litgpt.constants import _JINJA2_AVAILABLE, _LITSERVE_AVAILABLE\nfrom litgpt.utils import auto_download_checkpoint\n\nif _LITSERVE_AVAILABLE:\n    from litserve import LitAPI, LitServer\n    from litserve.specs.openai import ChatCompletionRequest, OpenAISpec\nelse:\n    LitAPI, LitServer = object, object\n\n\nclass BaseLitAPI(LitAPI):\n    def __init__(\n        self,\n        checkpoint_dir: Path,\n        quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\"]] = None,\n        precision: Optional[str] = None,\n        temperature: float = 0.8,\n        top_k: int = 50,\n        top_p: float = 1.0,\n        max_new_tokens: int = 50,\n        devices: int = 1,\n        api_path: Optional[str] = None,\n        generate_strategy: Optional[Literal[\"sequential\", \"tensor_parallel\"]] = None,\n    ) -> None:\n        if not _LITSERVE_AVAILABLE:\n            raise ImportError(str(_LITSERVE_AVAILABLE))\n\n        super().__init__(api_path=api_path)\n\n        self.checkpoint_dir = checkpoint_dir\n        self.quantize = quantize\n        self.precision = precision\n        self.temperature = temperature\n        self.top_k = top_k\n        self.max_new_tokens = max_new_tokens\n        self.top_p = top_p\n        self.devices = devices\n        self.generate_strategy = generate_strategy\n\n    def setup(self, device: str) -> None:\n        if \":\" in device:\n            accelerator, device = device.split(\":\")\n            device = f\"[{int(device)}]\"\n        else:\n            accelerator = device\n            device = 1\n\n        print(\"Initializing model...\", file=sys.stderr)\n        self.llm = LLM.load(model=self.checkpoint_dir, 
distribute=None)\n\n        self.llm.distribute(\n            devices=self.devices,\n            accelerator=accelerator,\n            quantize=self.quantize,\n            precision=self.precision,\n            generate_strategy=self.generate_strategy\n            or (\"sequential\" if self.devices is not None and self.devices > 1 else None),\n        )\n        print(\"Model successfully initialized.\", file=sys.stderr)\n\n    def decode_request(self, request: Dict[str, Any]) -> Any:\n        prompt = str(request[\"prompt\"])\n        return prompt\n\n\nclass SimpleLitAPI(BaseLitAPI):\n    def __init__(\n        self,\n        checkpoint_dir: Path,\n        quantize: Optional[str] = None,\n        precision: Optional[str] = None,\n        temperature: float = 0.8,\n        top_k: int = 50,\n        top_p: float = 1.0,\n        max_new_tokens: int = 50,\n        devices: int = 1,\n        api_path: Optional[str] = None,\n        generate_strategy: Optional[str] = None,\n    ):\n        super().__init__(\n            checkpoint_dir,\n            quantize,\n            precision,\n            temperature,\n            top_k,\n            top_p,\n            max_new_tokens,\n            devices,\n            api_path=api_path,\n            generate_strategy=generate_strategy,\n        )\n\n    def setup(self, device: str):\n        super().setup(device)\n\n    def predict(self, inputs: str) -> Any:\n        output = self.llm.generate(\n            inputs,\n            temperature=self.temperature,\n            top_k=self.top_k,\n            top_p=self.top_p,\n            max_new_tokens=self.max_new_tokens,\n        )\n        return output\n\n    def encode_response(self, output: str) -> Dict[str, Any]:\n        # Convert the model output to a response payload.\n        return {\"output\": output}\n\n\nclass StreamLitAPI(BaseLitAPI):\n    def __init__(\n        self,\n        checkpoint_dir: Path,\n        quantize: Optional[str] = None,\n        precision: 
Optional[str] = None,\n        temperature: float = 0.8,\n        top_k: int = 50,\n        top_p: float = 1.0,\n        max_new_tokens: int = 50,\n        devices: int = 1,\n        api_path: Optional[str] = None,\n        generate_strategy: Optional[str] = None,\n    ):\n        super().__init__(\n            checkpoint_dir,\n            quantize,\n            precision,\n            temperature,\n            top_k,\n            top_p,\n            max_new_tokens,\n            devices,\n            api_path=api_path,\n            generate_strategy=generate_strategy,\n        )\n\n    def setup(self, device: str):\n        super().setup(device)\n\n    def predict(self, inputs: torch.Tensor) -> Any:\n        yield from self.llm.generate(\n            inputs,\n            temperature=self.temperature,\n            top_k=self.top_k,\n            top_p=self.top_p,\n            max_new_tokens=self.max_new_tokens,\n            stream=True,\n        )\n\n    def encode_response(self, output):\n        for out in output:\n            yield {\"output\": out}\n\n\nclass OpenAISpecLitAPI(BaseLitAPI):\n    def __init__(\n        self,\n        checkpoint_dir: Path,\n        quantize: Optional[str] = None,\n        precision: Optional[str] = None,\n        temperature: float = 0.8,\n        top_k: int = 50,\n        top_p: float = 1.0,\n        max_new_tokens: int = 50,\n        devices: int = 1,\n        api_path: Optional[str] = None,\n        generate_strategy: Optional[str] = None,\n    ):\n        super().__init__(\n            checkpoint_dir,\n            quantize,\n            precision,\n            temperature,\n            top_k,\n            top_p,\n            max_new_tokens,\n            devices,\n            api_path=api_path,\n            generate_strategy=generate_strategy,\n        )\n\n    def setup(self, device: str):\n        super().setup(device)\n        if not _JINJA2_AVAILABLE:\n            raise ImportError(str(_JINJA2_AVAILABLE))\n        from jinja2 
import Template\n\n        config_path = self.checkpoint_dir / \"tokenizer_config.json\"\n        if not config_path.is_file():\n            raise FileNotFoundError(f\"Tokenizer config file not found at {config_path}\")\n\n        with open(config_path, encoding=\"utf-8\") as fp:\n            config = json.load(fp)\n            chat_template = config.get(\"chat_template\", None)\n            if chat_template is None:\n                print(\"The tokenizer config does not contain chat_template, falling back to a default.\")\n                chat_template = \"{% for m in messages %}{{ m.role }}: {{ m.content }}\\n{% endfor %}Assistant: \"\n            self.chat_template = chat_template\n\n        self.template = Template(self.chat_template)\n\n    def decode_request(self, request: \"ChatCompletionRequest\") -> Any:\n        # Apply chat template to request messages\n        return self.template.render(messages=request.messages)\n\n    def predict(self, inputs: str, context: dict) -> Any:\n        # Extract parameters from context with fallback to instance attributes\n        temperature = context.get(\"temperature\") or self.temperature\n        top_p = context.get(\"top_p\", self.top_p) or self.top_p\n        max_new_tokens = context.get(\"max_completion_tokens\") or self.max_new_tokens\n\n        # Run the model on the input and return the output.\n        yield from self.llm.generate(\n            inputs,\n            temperature=temperature,\n            top_k=self.top_k,\n            top_p=top_p,\n            max_new_tokens=max_new_tokens,\n            stream=True,\n        )\n\n\ndef run_server(\n    checkpoint_dir: Path,\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\"]] = None,\n    precision: Optional[str] = None,\n    temperature: float = 0.8,\n    top_k: int = 50,\n    top_p: float = 1.0,\n    max_new_tokens: int = 50,\n    devices: int = 1,\n    accelerator: str = \"auto\",\n    port: int = 8000,\n    
stream: bool = False,\n    openai_spec: bool = False,\n    access_token: Optional[str] = None,\n    api_path: Optional[str] = \"/predict\",\n    timeout: int = 30,\n    generate_strategy: Optional[Literal[\"sequential\", \"tensor_parallel\"]] = None,\n) -> None:\n    \"\"\"Serve a LitGPT model using LitServe.\n\n    Arguments:\n        checkpoint_dir: The checkpoint directory to load the model from.\n        quantize: Whether to quantize the model and which method to use:\n            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes\n            - bnb.int8: 8-bit quantization from bitsandbytes\n            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md\n        precision: Optional precision setting to instantiate the model weights in. By default, this will\n            automatically be inferred from the metadata in the given ``checkpoint_dir`` directory.\n        temperature: Temperature setting for the text generation. Values above 1 increase randomness.\n            Values below 1 decrease randomness.\n        top_k: The size of the pool of potential next tokens. Values larger than 1 result in more novel\n            generated text but can also lead to more incoherent texts.\n        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n            In top-p sampling, the next token is sampled from the highest probability tokens\n            whose cumulative probability exceeds the threshold `top_p`. When specified,\n            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n            to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n            It can be used in conjunction with `top_k` and `temperature` with the following order\n            of application:\n\n            1. `top_k` sampling\n            2. 
`temperature` scaling\n            3. `top_p` sampling\n\n            For more details, see https://arxiv.org/abs/1904.09751\n            or https://huyenchip.com/2024/01/16/sampling.html#top_p\n        max_new_tokens: The number of generation steps to take.\n        devices: How many devices/GPUs to use.\n        accelerator: The type of accelerator to use. For example, \"auto\", \"cuda\", \"cpu\", or \"mps\".\n            The \"auto\" setting (default) chooses a GPU if available, and otherwise uses a CPU.\n        port: The network port number on which the model is configured to be served.\n        stream: Whether to stream the responses.\n        openai_spec: Whether to use the OpenAISpec and enable OpenAI-compatible API endpoints. When True, the server will provide\n            `/v1/chat/completions` endpoints that work with the OpenAI SDK and other OpenAI-compatible clients,\n            making it easy to integrate with existing applications that use the OpenAI API.\n        access_token: Optional API token to access models with restrictions.\n        api_path: The custom API path for the endpoint (e.g., \"/my_api/classify\").\n        timeout: Request timeout in seconds. Defaults to 30.\n        generate_strategy: The generation strategy to use. The \"sequential\" strategy (default for devices > 1)\n            allows running models that wouldn't fit in a single card by partitioning the transformer blocks across\n            all devices and running them sequentially. 
\"tensor_parallel\" shards the model using tensor parallelism.\n            If None (default for devices = 1), the model is not distributed.\n    \"\"\"\n    checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)\n    pprint(locals())\n\n    api_class = OpenAISpecLitAPI if openai_spec else StreamLitAPI if stream else SimpleLitAPI\n\n    server = LitServer(\n        api_class(\n            checkpoint_dir=checkpoint_dir,\n            quantize=quantize,\n            precision=precision,\n            temperature=temperature,\n            top_k=top_k,\n            top_p=top_p,\n            max_new_tokens=max_new_tokens,\n            devices=devices,\n            api_path=api_path,\n            generate_strategy=generate_strategy,\n        ),\n        spec=OpenAISpec() if openai_spec else None,\n        accelerator=accelerator,\n        devices=1,\n        stream=stream,\n        timeout=timeout,\n    )\n\n    server.run(port=port, generate_client_file=False)\n"
  },
  {
    "path": "litgpt/eval/evaluate.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport json\nimport os\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Optional, Union\n\nimport torch\n\nfrom litgpt.scripts.convert_lit_checkpoint import convert_lit_checkpoint\nfrom litgpt.utils import auto_download_checkpoint, copy_config_files\n\n\ndef prepare_results(results, save_filepath, print_results=True):\n    from lm_eval.utils import make_table\n\n    if print_results:\n        print(make_table(results))\n        if \"groups\" in results:\n            print(make_table(results, \"groups\"))\n\n    json_result = json.dumps(results, indent=2, ensure_ascii=False, default=str)\n    save_filepath.open(\"w\", encoding=\"utf-8\").write(json_result)\n\n\ndef convert_and_evaluate(\n    checkpoint_dir: Path,\n    tasks: Optional[str] = None,\n    out_dir: Optional[Path] = None,\n    force_conversion: bool = False,\n    num_fewshot: Optional[int] = None,\n    batch_size: Union[int, str] = 1,\n    device: Optional[str] = None,\n    dtype: Optional[Union[str, torch.dtype]] = None,\n    limit: Optional[float] = None,\n    seed: int = 1234,\n    save_filepath: Optional[Path] = None,\n    access_token: Optional[str] = None,\n) -> None:\n    \"\"\"Evaluate a model with the LM Evaluation Harness.\n\n    Arguments:\n        checkpoint_dir: Directory where the `lit_model.pth` and tokenizer files are located.\n        out_dir: Directory in which to save the converted checkpoints for evaluation.\n            Saves to `checkpoint_dir`/evaluate by default.\n        force_conversion: Set to `True` to reconvert the model and override\n            an existing model.pth from a previous evaluation call.\n        tasks: CSV of task names to evaluate. 
Example: \"hellaswag,truthfulqa_mc2,mmlu\"\n        num_fewshot: Number of examples in few-shot context.\n        batch_size: Batch size configuration: a positive integer (default: 1),\n            \"auto\", or the format 'auto:N', where 'auto:4' recomputes the batch size 4 times.\n        device: Device to use for evaluation, for example, \"cuda\" or \"cuda:0\".\n        limit: Limit on number of examples per task.\n        seed: Random seed.\n        save_filepath: The file where the results will be saved.\n            Saves to `out_dir/results.json` by default.\n        access_token: Optional API token to access models with restrictions.\n    \"\"\"\n    if tasks is None:\n        from lm_eval.tasks import TaskManager\n\n        taskm = TaskManager()\n        print(\"\\n\".join(taskm.task_index.keys()))\n        print(\n            \"\\n\\nTo evaluate multiple tasks, you can chain the task names \"\n            \"listed above via a comma-separated list.\"\n            \"\\nFor example: `--tasks 'hellaswag,truthfulqa_mc2,mmlu'`. 
\"\n            \"\\nTo search for a specific task, use `litgpt evaluate list | grep task_name`.\"\n        )\n        return\n\n    checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)\n    pprint(locals())\n\n    if not (isinstance(batch_size, int) and batch_size > 0) and not (\n        isinstance(batch_size, str) and batch_size.startswith(\"auto\")\n    ):\n        raise ValueError(\"batch_size must be a positive integer, 'auto', or in the format 'auto:N'.\")\n\n    from lm_eval import evaluator\n\n    if device is None:\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n    if out_dir is None:\n        out_dir = checkpoint_dir / \"evaluate\"\n    else:\n        out_dir = Path(out_dir)\n    out_dir.mkdir(parents=True, exist_ok=True)\n\n    save_filepath = out_dir / Path(\"results.json\") if save_filepath is None else Path(save_filepath)\n\n    model_path = out_dir / \"pytorch_model.bin\"\n    if not model_path.exists() or force_conversion:\n        copy_config_files(source_dir=checkpoint_dir, out_dir=out_dir)\n        convert_lit_checkpoint(checkpoint_dir=checkpoint_dir, output_dir=out_dir)\n\n        # Hack: LitGPT's conversion doesn't save a pickle file that is compatible to be loaded with\n        # `torch.load(..., weights_only=True)`, which is a requirement in HFLM.\n        # So we're `torch.load`-ing and `torch.save`-ing it again to work around this.\n        state_dict = torch.load(out_dir / \"model.pth\")\n        torch.save(state_dict, model_path)\n        os.remove(out_dir / \"model.pth\")\n\n    from lm_eval.models.huggingface import HFLM\n\n    model = HFLM(pretrained=str(out_dir.resolve()), device=device, batch_size=batch_size, dtype=dtype)\n\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n    results = evaluator.simple_evaluate(\n        model=model,\n        tasks=tasks.split(\",\"),\n        num_fewshot=num_fewshot,\n        batch_size=batch_size,\n        
device=device,\n        limit=limit,\n        random_seed=seed,\n        numpy_random_seed=seed,\n        torch_random_seed=seed,\n    )\n    prepare_results(results, save_filepath)\n"
  },
  {
    "path": "litgpt/finetune/__init__.py",
    "content": ""
  },
  {
    "path": "litgpt/finetune/adapter.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport dataclasses\nimport math\nimport os\nimport time\nimport warnings\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Dict, List, Literal, Optional, Tuple, Union\n\nimport lightning as L\nimport torch\nfrom lightning.fabric.plugins import BitsandbytesPrecision\nfrom lightning.fabric.strategies import FSDPStrategy\nfrom lightning.fabric.utilities import ThroughputMonitor\nfrom torch.utils.data import ConcatDataset, DataLoader\nfrom torchmetrics import RunningMean\n\nfrom litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable\nfrom litgpt.args import EvalArgs, LogArgs, TrainArgs\nfrom litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0\nfrom litgpt.data import Alpaca, DataModule\nfrom litgpt.generate.base import generate\nfrom litgpt.parser_config import save_hyperparameters\nfrom litgpt.prompts import save_prompt_style\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.types import LoggerChoice\nfrom litgpt.utils import (\n    CycleIterator,\n    auto_download_checkpoint,\n    check_nvlink_connectivity,\n    check_valid_checkpoint_dir,\n    choose_logger,\n    chunked_cross_entropy,\n    copy_config_files,\n    create_finetuning_performance_report,\n    get_default_supported_precision,\n    init_out_dir,\n    instantiate_bnb_optimizer,\n    instantiate_torch_optimizer,\n    load_checkpoint,\n    num_parameters,\n    parse_devices,\n    select_sft_generate_example,\n)\n\n\ndef setup(\n    checkpoint_dir: Path,\n    out_dir: Path = Path(\"out/finetune/adapter\"),\n    precision: Optional[str] = None,\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8-training\"]] = None,\n    devices: Union[int, str] = 1,\n    num_nodes: int = 1,\n    data: Optional[DataModule] = None,\n    train: TrainArgs = TrainArgs(\n        save_interval=1000,\n        
log_interval=1,\n        global_batch_size=16,\n        micro_batch_size=1,\n        lr_warmup_steps=100,\n        epochs=5,\n        max_seq_length=None,\n    ),\n    eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),\n    log: LogArgs = LogArgs(),\n    optimizer: Union[str, Dict] = \"AdamW\",\n    logger_name: LoggerChoice = \"csv\",\n    seed: int = 1337,\n    access_token: Optional[str] = None,\n) -> None:\n    \"\"\"Finetune a model using the Adapter method.\n\n    Arguments:\n        checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.\n        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in\n            /teamspace/jobs/<job-name>/share.\n        precision: The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\".\n        quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.\n        devices: How many devices/GPUs to use.\n        num_nodes: How many nodes the code is being run on.\n        data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\n        train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.\n        eval: Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details.\n        optimizer: An optimizer name (such as \"AdamW\") or config.\n        logger_name: The name of the logger to send metrics to.\n        seed: The random seed to use for reproducibility.\n        access_token: Optional API token to access models with restrictions.\n    \"\"\"\n    checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)\n    pprint(locals())\n    data = Alpaca() if data is None else data\n    devices = parse_devices(devices)\n    out_dir = init_out_dir(out_dir)\n\n    check_valid_checkpoint_dir(checkpoint_dir)\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n    precision = precision or get_default_supported_precision(training=True)\n    logger = choose_logger(\n        logger_name,\n        out_dir,\n        name=f\"finetune-{config.name}\",\n        log_interval=train.log_interval,\n        log_args=dataclasses.asdict(log),\n    )\n\n    plugins = None\n    if quantize is not None and quantize.startswith(\"bnb.\"):\n        if \"mixed\" in precision:\n            raise ValueError(\"Quantization and mixed precision is not supported.\")\n        if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:\n            warnings.warn(\n                \"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization.\"\n            )\n        dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n        plugins = BitsandbytesPrecision(quantize[4:], dtype)\n        precision = None\n\n    if devices * num_nodes > 1:\n        if quantize:\n            raise NotImplementedError(\n                \"Quantization is currently not supported for multi-GPU training. 
Please set devices=1 and num_nodes=1\"\n                \" when using the --quantize flag.\"\n            )\n        strategy = FSDPStrategy(\n            auto_wrap_policy={Block},\n            activation_checkpointing_policy={Block},\n            state_dict_type=\"full\",\n            limit_all_gathers=True,\n            cpu_offload=False,\n        )\n    else:\n        strategy = \"auto\"\n\n    fabric = L.Fabric(\n        devices=devices,\n        num_nodes=num_nodes,\n        strategy=strategy,\n        precision=precision,\n        loggers=logger,\n        plugins=plugins,\n    )\n\n    if torch.cuda.is_available() and devices > 1:\n        check_nvlink_connectivity(fabric)\n\n    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)\n\n\ndef main(\n    fabric: L.Fabric,\n    devices: int,\n    seed: int,\n    config: Config,\n    data: DataModule,\n    checkpoint_dir: Path,\n    out_dir: Path,\n    train: TrainArgs,\n    eval: EvalArgs,\n    optimizer: Union[str, Dict],\n    num_nodes: int = 1,\n) -> None:\n    validate_args(train, eval)\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)\n    steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)\n    lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float(\"inf\")))\n\n    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)\n\n    if fabric.global_rank == 0:\n        os.makedirs(out_dir, exist_ok=True)\n\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n    with fabric.init_module(empty_init=(fabric.world_size > 1)):\n        model = GPT(config)\n    mark_only_adapter_as_trainable(model)\n\n    fabric.print(f\"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}\")\n    fabric.print(f\"Number of non-trainable parameters: {num_parameters(model, 
requires_grad=False):,}\")\n\n    model = fabric.setup_module(model)\n    if isinstance(fabric.strategy.precision, BitsandbytesPrecision):\n        optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())\n\n        from bitsandbytes.nn import StableEmbedding\n\n        old_embedding = model.transformer.wte\n        model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)\n        with torch.no_grad():\n            model.transformer.wte.weight.copy_(old_embedding.weight)\n        model.transformer.wte = model.transformer.wte.to(\n            device=old_embedding.weight.device, dtype=old_embedding.weight.dtype\n        )\n    else:\n        optimizer = instantiate_torch_optimizer(optimizer, model.parameters())\n\n    optimizer = fabric.setup_optimizers(optimizer)\n    scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)\n\n    # strict=False because missing keys due to Adapter weights not contained in state dict\n    load_checkpoint(fabric, model, checkpoint_path, strict=False)\n\n    train_time = time.perf_counter()\n    token_counts = fit(\n        fabric=fabric,\n        model=model,\n        optimizer=optimizer,\n        scheduler=scheduler,\n        train_dataloader=train_dataloader,\n        val_dataloader=val_dataloader,\n        devices=devices,\n        num_nodes=num_nodes,\n        checkpoint_dir=checkpoint_dir,\n        out_dir=out_dir,\n        train=train,\n        eval=eval,\n        data=data,\n    )\n    training_time = time.perf_counter() - train_time\n    output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)\n    fabric.print(output)\n\n    # Final evaluation\n    if eval.final_validation:\n        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))\n        metrics = {\"val_loss\": val_loss, \"val_ppl\": math.exp(val_loss)}\n        
fabric.log_dict(metrics)\n        fabric.print(f\"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}\")\n\n    # Save the final Adapter checkpoint at the end of training\n    save_path = out_dir / \"final\" / \"lit_model.pth.adapter\"\n    save_path.parent.mkdir(parents=True, exist_ok=True)\n    save_adapter_checkpoint(fabric, model, save_path)\n    if fabric.global_rank == 0:\n        # Copy checkpoint files from original checkpoint dir\n        copy_config_files(checkpoint_dir, save_path.parent)\n        save_hyperparameters(setup, save_path.parent)\n        save_prompt_style(data.prompt_style, save_path.parent)\n\n\ndef fit(\n    fabric: L.Fabric,\n    model: GPT,\n    optimizer: torch.optim.Optimizer,\n    scheduler: torch.optim.lr_scheduler,\n    train_dataloader: DataLoader,\n    val_dataloader: DataLoader,\n    devices: int,\n    checkpoint_dir: Path,\n    out_dir: Path,\n    train: TrainArgs,\n    eval: EvalArgs,\n    data: DataModule,\n    num_nodes: int = 1,\n) -> None:\n    tokenizer = Tokenizer(checkpoint_dir)\n    longest_seq_length, longest_seq_ix = get_longest_seq_length(\n        ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])\n    )\n    model.max_seq_length = min(longest_seq_length, train.max_seq_length or float(\"inf\"))\n    fabric.print(\n        f\"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is\"\n        f\" {model.max_seq_length} and context length is {model.config.block_size}\"\n    )\n\n    if eval.initial_validation:\n        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))\n        val_loss = f\"{val_loss:.3f}\"\n    else:\n        fabric.print(\"Verifying settings ...\")\n        validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False)  # sanity check\n        val_loss = \"n/a\"\n\n    train_iterator = 
CycleIterator(train_dataloader)\n    throughput = ThroughputMonitor(fabric, window_size=50)\n    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(\n        fabric.device\n    )\n    max_steps = train.max_steps or float(\"inf\")\n    step_count = 0\n    iter_num = 0\n    total_lengths = 0\n    total_t0 = time.perf_counter()\n\n    token_counts = {\n        \"raw_tokens\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n        \"raw_tokens_plus_prompt_template\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n        \"raw_tokens_plus_prompt_template_and_padding\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n    }\n\n    while step_count < max_steps:\n        iter_num += 1\n        iter_t0 = time.perf_counter()\n        batch = next(train_iterator)\n        if train_iterator.epoch >= train.epochs:\n            break\n        input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n\n        is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0\n        with fabric.no_backward_sync(model, enabled=is_accumulating):\n            logits = model(input_ids, lm_head_chunk_size=128)\n            # shift the targets such that output n predicts token n+1\n            logits[-1] = logits[-1][..., :-1, :]\n            loss = chunked_cross_entropy(logits, targets[..., 1:])\n            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))\n\n        running_loss.update(loss.detach())\n\n        if not is_accumulating:\n            optimizer.step()\n            optimizer.zero_grad()\n            scheduler.step()\n            step_count += 1\n\n        token_counts[\"raw_tokens\"] += batch[\"token_counts\"][\"raw\"].sum().item()\n        token_counts[\"raw_tokens_plus_prompt_template\"] += (\n            batch[\"token_counts\"][\"raw_plus_prompt_template\"].sum().item()\n        )\n        
token_counts[\"raw_tokens_plus_prompt_template_and_padding\"] += input_ids.numel()\n\n        total_lengths += input_ids.numel()\n        if iter_num % train.log_interval == 0:\n            loss = running_loss.compute().item()  # expensive device-to-host synchronization\n            t1 = time.perf_counter()\n            throughput.update(\n                time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths\n            )\n            throughput.compute_and_log(step=iter_num)\n            metrics = {\n                \"loss\": loss,\n                \"iter\": iter_num,\n                \"step\": step_count,\n                \"epoch\": train_iterator.epoch,\n                \"iter_time\": t1 - iter_t0,\n                \"tokens\": token_counts[\"raw_tokens_plus_prompt_template\"],\n                \"total_tokens\": token_counts[\"raw_tokens_plus_prompt_template\"] * fabric.world_size,\n                \"learning_rate\": scheduler.get_last_lr()[0],\n            }\n            if isinstance(val_loss, torch.Tensor):\n                val_loss = f\"{val_loss:.3f}\"\n            fabric.print(\n                f\"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |\"\n                f\" loss train: {metrics['loss']:.3f},\"\n                f\" val: {val_loss} |\"\n                f\" iter time: {metrics['iter_time'] * 1000:.2f} ms\"\n                f\"{' (step)' if not is_accumulating else ''}\"\n            )\n            fabric.log_dict(metrics, step=iter_num)\n\n        if not is_accumulating and step_count % eval.interval == 0:\n            t0 = time.perf_counter()\n            val_loss = validate(fabric, model, val_dataloader, eval)\n            generate_example(fabric, model, tokenizer, eval, data)\n            t1 = time.perf_counter() - t0\n\n            val_loss_tensor = val_loss.detach().clone().to(fabric.device)\n            val_time_tensor = torch.tensor(t1, device=fabric.device, 
dtype=torch.float32)\n\n            fabric.all_reduce(val_loss_tensor, reduce_op=\"mean\")\n            fabric.all_reduce(val_time_tensor, reduce_op=\"mean\")\n\n            fabric.print(\n                f\"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms\"\n            )\n            metrics = {\"val_loss\": val_loss_tensor, \"val_ppl\": math.exp(val_loss_tensor)}\n            fabric.log_dict(metrics, step=iter_num)\n            fabric.barrier()\n\n        if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:\n            checkpoint_file = out_dir / f\"step-{step_count:06d}\" / \"lit_model.pth.adapter\"\n            checkpoint_file.parent.mkdir(parents=True, exist_ok=True)\n            save_adapter_checkpoint(fabric, model, checkpoint_file)\n            if fabric.global_rank == 0:\n                copy_config_files(checkpoint_dir, checkpoint_file.parent)\n                save_hyperparameters(setup, checkpoint_file.parent)\n                save_prompt_style(data.prompt_style, checkpoint_file.parent)\n\n    total_token_counts = {}\n    for key in token_counts:\n        total = fabric.all_reduce(token_counts[key], reduce_op=\"sum\")\n        total_token_counts[key] = total.item()\n\n    return total_token_counts\n\n\n# FSDP has issues with `inference_mode`\n@torch.no_grad()\ndef validate(\n    fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True\n) -> torch.Tensor:\n    if verbose:\n        fabric.print(\"Validating ...\")\n    model.eval()\n    losses = torch.zeros(min(len(val_dataloader), eval.max_iters))\n    for k, batch in enumerate(val_dataloader):\n        if k >= eval.max_iters:\n            break\n        input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n        logits = model(input_ids)\n        losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)\n\n    val_loss = 
losses.mean()\n    model.train()\n    return val_loss\n\n\n# the adapter \"kv cache\" cannot be initialized under `inference_mode`\n@torch.no_grad()\ndef generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):\n    instruction = select_sft_generate_example(eval, data)\n    fabric.print(instruction)\n    prompt = data.prompt_style.apply(instruction)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    model.eval()\n\n    with fabric.init_tensor():\n        # do not set `max_seq_length=max_returned_token` because memory is not a concern here\n        model.set_kv_cache(batch_size=1)\n\n    max_returned_tokens = len(encoded) + eval.max_new_tokens\n\n    if max_returned_tokens < model.max_seq_length:\n        with fabric.init_tensor():\n            # do not set `max_seq_length=max_returned_token` because memory is not a concern here\n            model.set_kv_cache(batch_size=1)\n        output = generate(\n            model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id\n        )\n        model.clear_kv_cache()\n        model.train()\n        output = tokenizer.decode(output)\n        fabric.print(f\"{output}\\n\")\n    else:\n        print(\n            f\"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) \"\n            f\"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. 
\"\n            f\"The model's supported context size (post-training) is {model.config.block_size}.\"\n        )\n\n\ndef get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):\n    # linear warmup followed by cosine annealing\n    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)\n    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))\n    return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])\n\n\ndef get_dataloaders(\n    fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs\n) -> Tuple[DataLoader, DataLoader]:\n    data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)\n    with fabric.rank_zero_first():\n        data.prepare_data()\n    data.setup()\n    train_dataloader = data.train_dataloader()\n    val_dataloader = data.val_dataloader()\n    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)\n    return train_dataloader, val_dataloader\n\n\ndef get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:\n    # find out the minimum max_seq_length required during fine-tuning (saves memory!)\n    lengths = [len(d[\"input_ids\"]) for d in data]\n    longest_seq_length = max(lengths)\n    longest_seq_ix = lengths.index(longest_seq_length)\n    return longest_seq_length, longest_seq_ix\n\n\ndef save_adapter_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:\n    fabric.print(f\"Saving adapter weights to {str(file_path)!r}\")\n    fabric.save(file_path, {\"model\": model}, filter={\"model\": adapter_filter})\n\n\ndef validate_args(train: TrainArgs, eval: EvalArgs) -> None:\n    issues = []\n    unsupported = [(train, [\"max_tokens\", \"max_norm\", \"tie_embeddings\", \"lr_warmup_fraction\"])]\n    for args, names in unsupported:\n        for name in names:\n            
if getattr(args, name) is not None:\n                issues.append(f\"{__file__} doesn't support the {name!r} argument. This is set in {args}\")\n    required = [(train, [\"epochs\"]), (eval, [\"max_new_tokens\"])]\n    for args, names in required:\n        for name in names:\n            if getattr(args, name) is None:\n                issues.append(f\"{__file__} requires the {name!r} argument. This is set in {args}\")\n    if not train.epochs and not train.max_steps:\n        issues.append(f\"{__file__} requires either epochs or max_steps to be set. This is set in {train}\")\n    if issues:\n        raise ValueError(\"\\n\".join(issues))\n"
  },
  {
    "path": "litgpt/finetune/adapter_v2.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport dataclasses\nimport math\nimport os\nimport time\nimport warnings\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Dict, List, Literal, Optional, Tuple, Union\n\nimport lightning as L\nimport torch\nfrom lightning.fabric.plugins import BitsandbytesPrecision\nfrom lightning.fabric.strategies import FSDPStrategy\nfrom lightning.fabric.utilities import ThroughputMonitor\nfrom torch.utils.data import ConcatDataset, DataLoader\nfrom torchmetrics import RunningMean\n\nfrom litgpt.adapter_v2 import GPT, Block, Config, adapter_filter, mark_only_adapter_v2_as_trainable\nfrom litgpt.args import EvalArgs, LogArgs, TrainArgs\nfrom litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0\nfrom litgpt.data import Alpaca, DataModule\nfrom litgpt.generate.base import generate\nfrom litgpt.parser_config import save_hyperparameters\nfrom litgpt.prompts import save_prompt_style\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.types import LoggerChoice\nfrom litgpt.utils import (\n    CycleIterator,\n    auto_download_checkpoint,\n    check_nvlink_connectivity,\n    check_valid_checkpoint_dir,\n    choose_logger,\n    chunked_cross_entropy,\n    copy_config_files,\n    create_finetuning_performance_report,\n    get_default_supported_precision,\n    init_out_dir,\n    instantiate_bnb_optimizer,\n    instantiate_torch_optimizer,\n    load_checkpoint,\n    load_checkpoint_update,\n    num_parameters,\n    parse_devices,\n    select_sft_generate_example,\n)\n\n\ndef setup(\n    checkpoint_dir: Path,\n    out_dir: Path = Path(\"out/finetune/adapter-v2\"),\n    precision: Optional[str] = None,\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8-training\"]] = None,\n    devices: Union[int, str] = 1,\n    num_nodes: int = 1,\n    resume: Optional[bool] = False,\n    data: Optional[DataModule] = None,\n   
 train: TrainArgs = TrainArgs(\n        save_interval=1000,\n        log_interval=1,\n        global_batch_size=16,\n        micro_batch_size=1,\n        lr_warmup_steps=100,\n        epochs=5,\n        max_seq_length=None,\n    ),\n    eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),\n    log: LogArgs = LogArgs(),\n    optimizer: Union[str, Dict] = \"AdamW\",\n    logger_name: LoggerChoice = \"csv\",\n    seed: int = 1337,\n    access_token: Optional[str] = None,\n) -> None:\n    \"\"\"Finetune a model using the Adapter V2 method.\n\n    Arguments:\n        checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.\n        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in\n            /teamspace/jobs/<job-name>/share.\n        precision: The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\".\n        quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.\n        devices: How many devices/GPUs to use.\n        num_nodes: How many nodes the code is being run on.\n        data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\n        train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.\n        eval: Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details.\n        optimizer: An optimizer name (such as \"AdamW\") or config.\n        logger_name: The name of the logger to send metrics to.\n        seed: The random seed to use for reproducibility.\n        access_token: Optional API token to access models with restrictions.\n    \"\"\"\n    checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)\n    pprint(locals())\n    data = Alpaca() if data is None else data\n    devices = parse_devices(devices)\n    out_dir = init_out_dir(out_dir)\n\n    check_valid_checkpoint_dir(checkpoint_dir)\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n    precision = precision or get_default_supported_precision(training=True)\n    logger = choose_logger(\n        logger_name,\n        out_dir,\n        name=f\"finetune-{config.name}\",\n        log_interval=train.log_interval,\n        log_args=dataclasses.asdict(log),\n    )\n\n    plugins = None\n    if quantize is not None and quantize.startswith(\"bnb.\"):\n        if \"mixed\" in precision:\n            raise ValueError(\"Quantization and mixed precision is not supported.\")\n        if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:\n            warnings.warn(\n                \"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization.\"\n            )\n        dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n        plugins = BitsandbytesPrecision(quantize[4:], dtype)\n        precision = None\n\n    if devices * num_nodes > 1:\n        if quantize:\n            raise NotImplementedError(\n                \"Quantization is currently not supported for multi-GPU training. 
Please set devices=1 and num_nodes=1\"\n                \" when using the --quantize flag.\"\n            )\n        strategy = FSDPStrategy(\n            auto_wrap_policy={Block},\n            activation_checkpointing_policy={Block},\n            state_dict_type=\"full\",\n            limit_all_gathers=True,\n            cpu_offload=False,\n        )\n    else:\n        strategy = \"auto\"\n\n    fabric = L.Fabric(\n        devices=devices,\n        num_nodes=num_nodes,\n        strategy=strategy,\n        precision=precision,\n        loggers=logger,\n        plugins=plugins,\n    )\n\n    if torch.cuda.is_available() and devices > 1:\n        check_nvlink_connectivity(fabric)\n\n    fabric.launch(main, devices, seed, config, data, resume, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)\n\n\ndef main(\n    fabric: L.Fabric,\n    devices: int,\n    seed: int,\n    config: Config,\n    data: DataModule,\n    resume: bool,\n    checkpoint_dir: Path,\n    out_dir: Path,\n    train: TrainArgs,\n    eval: EvalArgs,\n    optimizer: Union[str, Dict],\n    num_nodes: int = 1,\n) -> None:\n    validate_args(train, eval)\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)\n    steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)\n    lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float(\"inf\")))\n\n    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)\n\n    if fabric.global_rank == 0:\n        os.makedirs(out_dir, exist_ok=True)\n\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n    with fabric.init_module(empty_init=(fabric.world_size > 1)):\n        model = GPT(config)\n    mark_only_adapter_v2_as_trainable(model)\n\n    fabric.print(f\"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}\")\n    fabric.print(f\"Number of non-trainable parameters: 
{num_parameters(model, requires_grad=False):,}\")\n\n    model = fabric.setup_module(model)\n    if isinstance(fabric.strategy.precision, BitsandbytesPrecision):\n        optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())\n\n        from bitsandbytes.nn import StableEmbedding\n\n        old_embedding = model.transformer.wte\n        model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)\n        with torch.no_grad():\n            model.transformer.wte.weight.copy_(old_embedding.weight)\n        model.transformer.wte = model.transformer.wte.to(\n            device=old_embedding.weight.device, dtype=old_embedding.weight.dtype\n        )\n    else:\n        optimizer = instantiate_torch_optimizer(optimizer, model.parameters())\n\n    optimizer = fabric.setup_optimizers(optimizer)\n    scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)\n    if resume:\n        # Finding last trace of adapter training\n        try:\n            resume = max(out_dir.rglob(\"step-*/*.pth.adapter_v2\"), key=(lambda p: int(p.parent.name.split(\"-\")[1])))\n            fabric.print(f\"Resuming training from {resume}\")\n            load_checkpoint_update(fabric, resume, model, checkpoint_path, strict=False)\n            resume = True\n        except ValueError:\n            fabric.print(\"No previous adapter found. 
Finetune from start.\")\n            resume = False\n            load_checkpoint(fabric, model, checkpoint_path, strict=False)\n    else:\n        # strict=False because missing keys due to Adapter weights not contained in state dict\n        load_checkpoint(fabric, model, checkpoint_path, strict=False)\n\n    mark_only_adapter_v2_as_trainable(model)\n\n    train_time = time.perf_counter()\n    token_counts = fit(\n        fabric=fabric,\n        model=model,\n        optimizer=optimizer,\n        scheduler=scheduler,\n        train_dataloader=train_dataloader,\n        val_dataloader=val_dataloader,\n        devices=devices,\n        resume=resume,\n        num_nodes=num_nodes,\n        checkpoint_dir=checkpoint_dir,\n        out_dir=out_dir,\n        train=train,\n        eval=eval,\n        data=data,\n    )\n    training_time = time.perf_counter() - train_time\n    output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)\n    fabric.print(output)\n\n    # Final evaluation\n    if eval.final_validation:\n        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))\n        metrics = {\"val_loss\": val_loss, \"val_ppl\": math.exp(val_loss)}\n        fabric.log_dict(metrics)\n        fabric.print(f\"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}\")\n\n    # Save the final Adapter checkpoint at the end of training\n    save_path = out_dir / \"final\" / \"lit_model.pth.adapter_v2\"\n    save_path.parent.mkdir(parents=True, exist_ok=True)\n    save_adapter_v2_checkpoint(fabric, model, save_path)\n    if fabric.global_rank == 0:\n        # Copy checkpoint files from original checkpoint dir\n        copy_config_files(checkpoint_dir, save_path.parent)\n        save_hyperparameters(setup, save_path.parent)\n        save_prompt_style(data.prompt_style, save_path.parent)\n\n\ndef fit(\n    fabric: L.Fabric,\n    model: GPT,\n    optimizer: 
torch.optim.Optimizer,\n    scheduler: torch.optim.lr_scheduler,\n    train_dataloader: DataLoader,\n    val_dataloader: DataLoader,\n    devices: int,\n    resume: bool,\n    checkpoint_dir: Path,\n    out_dir: Path,\n    train: TrainArgs,\n    eval: EvalArgs,\n    data: DataModule,\n    num_nodes: int = 1,\n) -> None:\n    tokenizer = Tokenizer(checkpoint_dir)\n    longest_seq_length, longest_seq_ix = get_longest_seq_length(\n        ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])\n    )\n    model.max_seq_length = min(longest_seq_length, train.max_seq_length or float(\"inf\"))\n    fabric.print(\n        f\"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is\"\n        f\" {model.max_seq_length} and context length is {model.config.block_size}\"\n    )\n\n    if eval.initial_validation:\n        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))\n        val_loss = f\"{val_loss:.3f}\"\n    else:\n        fabric.print(\"Verifying settings ...\")\n        validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False)  # sanity check\n        val_loss = \"n/a\"\n\n    train_iterator = CycleIterator(train_dataloader)\n    throughput = ThroughputMonitor(fabric, window_size=50)\n    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(\n        fabric.device\n    )\n    max_steps = train.max_steps or float(\"inf\")\n    step_count = 0\n    iter_num = 0\n    total_lengths = 0\n    total_t0 = time.perf_counter()\n\n    token_counts = {\n        \"raw_tokens\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n        \"raw_tokens_plus_prompt_template\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n        \"raw_tokens_plus_prompt_template_and_padding\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n    }\n\n    if 
not resume:\n        try:\n            iter_match = max(out_dir.rglob(\"step-*/*.pth.adapter_v2\"), key=lambda p: int(p.parent.name.split(\"-\")[1]))\n            step_count = int(iter_match.parent.name.split(\"-\")[1]) if iter_match else 0\n        except ValueError:\n            step_count = 0\n\n    fabric.print(f\"Starting at step count {step_count}\")\n    while step_count < max_steps and train_iterator.epoch < train.epochs:\n        iter_num += 1\n        iter_t0 = time.perf_counter()\n        batch = next(train_iterator)\n        if train_iterator.epoch >= train.epochs:\n            break\n\n        input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n\n        is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0\n        with fabric.no_backward_sync(model, enabled=is_accumulating):\n            logits = model(input_ids, lm_head_chunk_size=128)\n            # shift the targets such that output n predicts token n+1\n            logits[-1] = logits[-1][..., :-1, :]\n            loss = chunked_cross_entropy(logits, targets[..., 1:])\n            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))\n\n        running_loss.update(loss.detach())\n\n        if not is_accumulating:\n            optimizer.step()\n            optimizer.zero_grad()\n            scheduler.step()\n            step_count += 1\n\n        token_counts[\"raw_tokens\"] += batch[\"token_counts\"][\"raw\"].sum().item()\n        token_counts[\"raw_tokens_plus_prompt_template\"] += (\n            batch[\"token_counts\"][\"raw_plus_prompt_template\"].sum().item()\n        )\n        token_counts[\"raw_tokens_plus_prompt_template_and_padding\"] += input_ids.numel()\n\n        total_lengths += input_ids.numel()\n        if iter_num % train.log_interval == 0:\n            loss = running_loss.compute().item()  # expensive device-to-host synchronization\n            t1 = time.perf_counter()\n            throughput.update(\n            
    time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths\n            )\n            throughput.compute_and_log(step=iter_num)\n            metrics = {\n                \"loss\": loss,\n                \"iter\": iter_num,\n                \"step\": step_count,\n                \"epoch\": train_iterator.epoch,\n                \"iter_time\": t1 - iter_t0,\n                \"tokens\": token_counts[\"raw_tokens_plus_prompt_template\"],\n                \"total_tokens\": token_counts[\"raw_tokens_plus_prompt_template\"] * fabric.world_size,\n                \"learning_rate\": scheduler.get_last_lr()[0],\n            }\n            if isinstance(val_loss, torch.Tensor):\n                val_loss = f\"{val_loss:.3f}\"\n            fabric.print(\n                f\"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |\"\n                f\" loss train: {metrics['loss']:.3f},\"\n                f\" val: {val_loss} |\"\n                f\" iter time: {metrics['iter_time'] * 1000:.2f} ms\"\n                f\"{' (step)' if not is_accumulating else ''}\"\n            )\n            fabric.log_dict(metrics, step=iter_num)\n\n        if not is_accumulating and step_count % eval.interval == 0:\n            t0 = time.perf_counter()\n            val_loss = validate(fabric, model, val_dataloader, eval)\n            generate_example(fabric, model, tokenizer, eval, data)\n            t1 = time.perf_counter() - t0\n\n            val_loss_tensor = val_loss.detach().clone().to(fabric.device)\n            val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32)\n\n            fabric.all_reduce(val_loss_tensor, reduce_op=\"mean\")\n            fabric.all_reduce(val_time_tensor, reduce_op=\"mean\")\n\n            fabric.print(\n                f\"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms\"\n            )\n            metrics = 
{\"val_loss\": val_loss_tensor, \"val_ppl\": math.exp(val_loss_tensor)}\n            fabric.log_dict(metrics, step=iter_num)\n            fabric.barrier()\n\n        if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:\n            checkpoint_file = out_dir / f\"step-{step_count:06d}\" / \"lit_model.pth.adapter_v2\"\n            checkpoint_file.parent.mkdir(parents=True, exist_ok=True)\n            save_adapter_v2_checkpoint(fabric, model, checkpoint_file)\n            if fabric.global_rank == 0:\n                copy_config_files(checkpoint_dir, checkpoint_file.parent)\n                save_hyperparameters(setup, checkpoint_file.parent)\n                save_prompt_style(data.prompt_style, checkpoint_file.parent)\n\n    total_token_counts = {}\n    for key in token_counts:\n        total = fabric.all_reduce(token_counts[key], reduce_op=\"sum\")\n        total_token_counts[key] = total.item()\n\n    return total_token_counts\n\n\n# FSDP has issues with `inference_mode`\n@torch.no_grad()\ndef validate(\n    fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True\n) -> torch.Tensor:\n    if verbose:\n        fabric.print(\"Validating ...\")\n    model.eval()\n    losses = torch.zeros(min(len(val_dataloader), eval.max_iters))\n    for k, batch in enumerate(val_dataloader):\n        if k >= eval.max_iters:\n            break\n        input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n        logits = model(input_ids)\n        losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)\n\n    val_loss = losses.mean()\n    model.train()\n    return val_loss\n\n\n# the adapter \"kv cache\" cannot be initialized under `inference_mode`\n@torch.no_grad()\ndef generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):\n    instruction = select_sft_generate_example(eval, data)\n    fabric.print(instruction)\n 
   prompt = data.prompt_style.apply(instruction)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    model.eval()\n\n    max_returned_tokens = len(encoded) + eval.max_new_tokens\n\n    if max_returned_tokens < model.max_seq_length:\n        with fabric.init_tensor():\n            # do not set `max_seq_length=max_returned_token` because memory is not a concern here\n            model.set_kv_cache(batch_size=1)\n        output = generate(\n            model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id\n        )\n        model.clear_kv_cache()\n        model.train()\n        output = tokenizer.decode(output)\n        fabric.print(f\"{output}\\n\")\n    else:\n        print(\n            f\"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) \"\n            f\"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. \"\n            f\"The model's supported context size (post-training) is {model.config.block_size}.\"\n        )\n\n\ndef get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):\n    # linear warmup followed by cosine annealing\n    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)\n    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))\n    return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])\n\n\ndef get_dataloaders(\n    fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs\n) -> Tuple[DataLoader, DataLoader]:\n    data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)\n    with fabric.rank_zero_first():\n        data.prepare_data()\n    data.setup()\n    train_dataloader = data.train_dataloader()\n    val_dataloader = data.val_dataloader()\n    train_dataloader, 
val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)\n    return train_dataloader, val_dataloader\n\n\ndef get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:\n    # find out the minimum max_seq_length required during fine-tuning (saves memory!)\n    lengths = [len(d[\"input_ids\"]) for d in data]\n    longest_seq_length = max(lengths)\n    longest_seq_ix = lengths.index(longest_seq_length)\n    return longest_seq_length, longest_seq_ix\n\n\ndef save_adapter_v2_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:\n    fabric.print(f\"Saving adapter v2 weights to {str(file_path)!r}\")\n    fabric.save(file_path, {\"model\": model}, filter={\"model\": adapter_filter})\n\n\ndef validate_args(train: TrainArgs, eval: EvalArgs) -> None:\n    issues = []\n    unsupported = [(train, [\"max_tokens\", \"max_norm\", \"tie_embeddings\", \"lr_warmup_fraction\"])]\n    for args, names in unsupported:\n        for name in names:\n            if getattr(args, name) is not None:\n                issues.append(f\"{__file__} doesn't support the {name!r} argument. This is set in {args}\")\n    required = [(train, [\"epochs\"]), (eval, [\"max_new_tokens\"])]\n    for args, names in required:\n        for name in names:\n            if getattr(args, name) is None:\n                issues.append(f\"{__file__} requires the {name!r} argument. This is set in {args}\")\n    if not train.epochs and not train.max_steps:\n        issues.append(f\"{__file__} requires either epochs or max_steps to be set. This is set in {train}\")\n    if issues:\n        raise ValueError(\"\\n\".join(issues))\n"
  },
  {
    "path": "litgpt/finetune/full.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport dataclasses\nimport math\nimport os\nimport time\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Dict, List, Literal, Optional, Tuple, Union\n\nimport lightning as L\nimport torch\nfrom lightning.fabric.strategies import FSDPStrategy\nfrom torch.utils.data import ConcatDataset, DataLoader\nfrom torchmetrics import RunningMean\n\nfrom litgpt.args import EvalArgs, LogArgs, TrainArgs\nfrom litgpt.data import Alpaca, DataModule\nfrom litgpt.generate.base import generate\nfrom litgpt.model import GPT, Block, Config\nfrom litgpt.parser_config import save_hyperparameters\nfrom litgpt.prompts import save_prompt_style\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.types import LoggerChoice\nfrom litgpt.utils import (\n    CycleIterator,\n    auto_download_checkpoint,\n    check_nvlink_connectivity,\n    check_valid_checkpoint_dir,\n    choose_logger,\n    chunked_cross_entropy,\n    copy_config_files,\n    create_finetuning_performance_report,\n    find_resume_path,\n    get_default_supported_precision,\n    init_out_dir,\n    instantiate_torch_optimizer,\n    load_checkpoint,\n    num_parameters,\n    parse_devices,\n    select_sft_generate_example,\n)\n\n\ndef setup(\n    checkpoint_dir: Path,\n    out_dir: Path = Path(\"out/finetune/full\"),\n    precision: Optional[str] = None,\n    devices: Union[int, str] = 1,\n    num_nodes: int = 1,\n    resume: Union[bool, Literal[\"auto\"], Path] = False,\n    data: Optional[DataModule] = None,\n    train: TrainArgs = TrainArgs(\n        save_interval=1000,\n        log_interval=1,\n        global_batch_size=16,\n        micro_batch_size=1,\n        lr_warmup_steps=100,\n        epochs=5,\n        max_seq_length=None,\n    ),\n    eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),\n    log: LogArgs = LogArgs(),\n    optimizer: Union[str, Dict] = \"AdamW\",\n    
logger_name: LoggerChoice = \"csv\",\n    seed: int = 1337,\n    access_token: Optional[str] = None,\n) -> None:\n    \"\"\"Finetune a model.\n\n    Arguments:\n        checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.\n        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in\n            /teamspace/jobs/<job-name>/share.\n        precision: The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\".\n        devices: How many devices/GPUs to use\n        num_nodes: How many nodes the code is being run on.\n        resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume\n            from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing\n            ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.\n        data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\n        train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.\n        eval: Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details.\n        optimizer: An optimizer name (such as \"AdamW\") or config.\n        logger_name: The name of the logger to send metrics to.\n        seed: The random seed to use for reproducibility.\n        access_token: Optional API token to access models with restrictions.\n    \"\"\"\n    checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)\n    pprint(locals())\n    data = Alpaca() if data is None else data\n    devices = parse_devices(devices)\n    out_dir = init_out_dir(out_dir)\n\n    check_valid_checkpoint_dir(checkpoint_dir)\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n    precision = precision or get_default_supported_precision(training=True)\n    logger = choose_logger(\n        logger_name,\n        out_dir,\n        name=f\"finetune-{config.name}\",\n        resume=bool(resume),\n        log_interval=train.log_interval,\n        log_args=dataclasses.asdict(log),\n    )\n\n    if devices * num_nodes > 1:\n        strategy = FSDPStrategy(\n            auto_wrap_policy={Block},\n            activation_checkpointing_policy={Block},\n            state_dict_type=\"full\",\n            limit_all_gathers=True,\n            cpu_offload=False,\n        )\n    else:\n        strategy = \"auto\"\n\n    fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision=precision, loggers=logger)\n\n    if torch.cuda.is_available() and devices > 1:\n        check_nvlink_connectivity(fabric)\n\n    fabric.launch(main, devices, resume, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)\n\n\ndef main(\n    fabric: L.Fabric,\n    devices: int,\n    resume: Union[bool, Literal[\"auto\"], Path],\n    seed: int,\n    config: Config,\n    data: DataModule,\n    checkpoint_dir: Path,\n    out_dir: Path,\n    train: TrainArgs,\n    eval: EvalArgs,\n    optimizer: Union[str, Dict],\n    num_nodes: int = 1,\n) -> None:\n    
validate_args(train, eval)\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)\n    steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)\n    lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float(\"inf\")))\n\n    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)\n\n    if fabric.global_rank == 0:\n        os.makedirs(out_dir, exist_ok=True)\n\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n    with fabric.init_module(empty_init=(fabric.world_size > 1)):\n        model = GPT(config)\n\n    fabric.print(f\"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}\")\n\n    model = fabric.setup(model)\n\n    optimizer = instantiate_torch_optimizer(optimizer, model.parameters())\n    optimizer = fabric.setup_optimizers(optimizer)\n    scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)\n    state = {\"model\": model, \"optimizer\": optimizer, \"scheduler\": scheduler, \"iter_num\": 0, \"step_count\": 0}\n\n    resume = find_resume_path(resume, out_dir)\n    if resume:\n        fabric.print(f\"Resuming training from {resume}\")\n        fabric.load(resume, state)\n    else:\n        load_checkpoint(fabric, state[\"model\"], checkpoint_path)\n\n    train_time = time.perf_counter()\n    token_counts = fit(\n        fabric=fabric,\n        state=state,\n        train_dataloader=train_dataloader,\n        val_dataloader=val_dataloader,\n        devices=devices,\n        num_nodes=num_nodes,\n        resume=resume,\n        checkpoint_dir=checkpoint_dir,\n        out_dir=out_dir,\n        train=train,\n        eval=eval,\n        data=data,\n    )\n    training_time = time.perf_counter() - train_time\n    output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)\n    
fabric.print(output)\n\n    # Final evaluation\n    if eval.final_validation:\n        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))\n        metrics = {\"val_loss\": val_loss, \"val_ppl\": math.exp(val_loss)}\n        fabric.log_dict(metrics, step=state[\"iter_num\"])\n        fabric.print(f\"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}\")\n\n    # Save the final checkpoint at the end of training\n    save_path = out_dir / \"final\" / \"lit_model.pth\"\n    save_path.parent.mkdir(parents=True, exist_ok=True)\n    fabric.save(save_path, {\"model\": state[\"model\"]})\n    if fabric.global_rank == 0:\n        # Copy checkpoint files from original checkpoint dir\n        copy_config_files(checkpoint_dir, save_path.parent)\n        save_hyperparameters(setup, save_path.parent)\n        save_prompt_style(data.prompt_style, save_path.parent)\n\n\ndef fit(\n    fabric: L.Fabric,\n    state: Dict,\n    train_dataloader: DataLoader,\n    val_dataloader: DataLoader,\n    devices: int,\n    resume: Union[bool, Literal[\"auto\"], Path],\n    checkpoint_dir: Path,\n    out_dir: Path,\n    train: TrainArgs,\n    eval: EvalArgs,\n    data: DataModule,\n    num_nodes: int = 1,\n) -> None:\n    model = state[\"model\"]\n    optimizer = state[\"optimizer\"]\n    scheduler = state[\"scheduler\"]\n    tokenizer = Tokenizer(checkpoint_dir)\n    longest_seq_length, longest_seq_ix = get_longest_seq_length(\n        ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])\n    )\n    model.max_seq_length = min(longest_seq_length, train.max_seq_length or float(\"inf\"))\n    fabric.print(\n        f\"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is\"\n        f\" {model.max_seq_length} and context length is {model.config.block_size}\"\n    )\n\n    token_counts = {\n        \"raw_tokens\": torch.tensor(0, 
device=fabric.device, dtype=torch.long),\n        \"raw_tokens_plus_prompt_template\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n        \"raw_tokens_plus_prompt_template_and_padding\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n    }\n\n    if eval.initial_validation:\n        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))\n        val_loss = f\"{val_loss:.3f}\"\n    else:\n        fabric.print(\"Verifying settings ...\")\n        validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False)  # sanity check\n        val_loss = \"n/a\"\n\n    initial_iter = state[\"iter_num\"]\n    max_steps = train.max_steps or float(\"inf\")\n    train_iterator = CycleIterator(train_dataloader)\n\n    # resume data loader state by fast-forwarding through all seen batches\n    if resume:\n        resume_t0 = time.perf_counter()\n        for resume_iter in range(initial_iter):\n            next(train_iterator)\n            if resume_iter % 1000 == 0:\n                fabric.print(f\"Resuming dataset: {resume_iter} / {initial_iter}\")\n        fabric.barrier()\n        fabric.print(\n            f\"Resuming data loader finished. 
Took {time.perf_counter() - resume_t0:.1f} seconds to reach iteration\"\n            f\" {initial_iter}.\"\n        )\n\n    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(\n        fabric.device\n    )\n    fabric.barrier()\n\n    while state[\"step_count\"] < max_steps:\n        state[\"iter_num\"] += 1\n        iter_t0 = time.perf_counter()\n        batch = next(train_iterator)\n        if train_iterator.epoch >= train.epochs:\n            break\n        input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n\n        is_accumulating = state[\"iter_num\"] % train.gradient_accumulation_iters(devices, num_nodes) != 0\n        with fabric.no_backward_sync(model, enabled=is_accumulating):\n            logits = model(input_ids)\n            # shift the targets such that output n predicts token n+1\n            loss = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:])\n            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))\n\n        running_loss.update(loss.detach())\n\n        if not is_accumulating:\n            optimizer.step()\n            optimizer.zero_grad()\n            scheduler.step()\n            state[\"step_count\"] += 1\n\n        token_counts[\"raw_tokens\"] += batch[\"token_counts\"][\"raw\"].sum().item()\n        token_counts[\"raw_tokens_plus_prompt_template\"] += (\n            batch[\"token_counts\"][\"raw_plus_prompt_template\"].sum().item()\n        )\n        token_counts[\"raw_tokens_plus_prompt_template_and_padding\"] += input_ids.numel()\n\n        if state[\"iter_num\"] % train.log_interval == 0:\n            loss = running_loss.compute().item()  # expensive device-to-host synchronization\n            t1 = time.perf_counter()\n            metrics = {\n                \"loss\": loss,\n                \"iter\": state[\"iter_num\"],\n                \"step\": state[\"step_count\"],\n                \"epoch\": 
train_iterator.epoch,\n                \"iter_time\": t1 - iter_t0,\n                \"tokens\": token_counts[\"raw_tokens_plus_prompt_template\"],\n                \"total_tokens\": token_counts[\"raw_tokens_plus_prompt_template\"] * fabric.world_size,\n                \"learning_rate\": scheduler.get_last_lr()[0],\n            }\n            if isinstance(val_loss, torch.Tensor):\n                val_loss = f\"{val_loss:.3f}\"\n            fabric.print(\n                f\"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |\"\n                f\" loss train: {metrics['loss']:.3f},\"\n                f\" val: {val_loss} |\"\n                f\" iter time: {metrics['iter_time'] * 1000:.2f} ms\"\n                f\"{' (step)' if not is_accumulating else ''}\"\n            )\n            fabric.log_dict(metrics, step=state[\"iter_num\"])\n\n        if not is_accumulating and state[\"step_count\"] % eval.interval == 0:\n            t0 = time.perf_counter()\n            val_loss = validate(fabric, model, val_dataloader, eval)\n            generate_example(fabric, model, tokenizer, eval, data)\n            t1 = time.perf_counter() - t0\n\n            val_loss_tensor = val_loss.detach().clone().to(fabric.device)\n            val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32)\n\n            fabric.all_reduce(val_loss_tensor, reduce_op=\"mean\")\n            fabric.all_reduce(val_time_tensor, reduce_op=\"mean\")\n\n            fabric.print(\n                f\"iter {state['iter_num']}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms\"\n            )\n            metrics = {\"val_loss\": val_loss_tensor, \"val_ppl\": math.exp(val_loss_tensor)}\n            fabric.log_dict(metrics, step=state[\"iter_num\"])\n            fabric.barrier()\n        if train.save_interval is not None and not is_accumulating and state[\"step_count\"] % train.save_interval == 0:\n            
checkpoint_file = out_dir / f\"step-{state['step_count']:06d}\" / \"lit_model.pth\"\n            checkpoint_file.parent.mkdir(parents=True, exist_ok=True)\n            fabric.print(f\"Saving checkpoint to {str(checkpoint_file.parent)!r}\")\n            fabric.save(checkpoint_file, state)\n            if fabric.global_rank == 0:\n                copy_config_files(checkpoint_dir, checkpoint_file.parent)\n                save_hyperparameters(setup, checkpoint_file.parent)\n                save_prompt_style(data.prompt_style, checkpoint_file.parent)\n\n    total_token_counts = {}\n    for key in token_counts:\n        total = fabric.all_reduce(token_counts[key], reduce_op=\"sum\")\n        total_token_counts[key] = total.item()\n\n    return total_token_counts\n\n\n# FSDP has issues with `inference_mode`\n@torch.no_grad()\ndef validate(\n    fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True\n) -> torch.Tensor:\n    if verbose:\n        fabric.print(\"Validating ...\")\n    model.eval()\n    losses = torch.zeros(min(len(val_dataloader), eval.max_iters))\n    for k, batch in enumerate(val_dataloader):\n        if k >= eval.max_iters:\n            break\n        input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n        logits = model(input_ids)\n        losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)\n\n    val_loss = losses.mean()\n    model.train()\n    return val_loss\n\n\n@torch.no_grad()\ndef generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):\n    instruction = select_sft_generate_example(eval, data)\n    fabric.print(instruction)\n    prompt = data.prompt_style.apply(instruction)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    model.eval()\n\n    with fabric.init_tensor():\n        # do not set `max_seq_length=max_returned_token` because memory is not a concern here\n        
model.set_kv_cache(batch_size=1)\n\n    max_returned_tokens = len(encoded) + eval.max_new_tokens\n\n    if max_returned_tokens < model.max_seq_length:\n        with fabric.init_tensor():\n            # do not set `max_seq_length=max_returned_token` because memory is not a concern here\n            model.set_kv_cache(batch_size=1)\n        output = generate(\n            model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id\n        )\n        model.clear_kv_cache()\n        model.train()\n        output = tokenizer.decode(output)\n        fabric.print(f\"{output}\\n\")\n    else:\n        print(\n            f\"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) \"\n            f\"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. \"\n            f\"The model's supported context size (post-training) is {model.config.block_size}.\"\n        )\n\n\ndef get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):\n    # linear warmup followed by cosine annealing\n    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)\n    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))\n    return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])\n\n\ndef get_dataloaders(\n    fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs\n) -> Tuple[DataLoader, DataLoader]:\n    data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)\n    with fabric.rank_zero_first():\n        data.prepare_data()\n    data.setup()\n    train_dataloader = data.train_dataloader()\n    val_dataloader = data.val_dataloader()\n    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)\n    return 
train_dataloader, val_dataloader\n\n\ndef get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:\n    # find out the minimum max_seq_length required during fine-tuning (saves memory!)\n    lengths = [len(d[\"input_ids\"]) for d in data]\n    longest_seq_length = max(lengths)\n    longest_seq_ix = lengths.index(longest_seq_length)\n    return longest_seq_length, longest_seq_ix\n\n\ndef validate_args(train: TrainArgs, eval: EvalArgs) -> None:\n    issues = []\n    unsupported = [(train, [\"max_tokens\", \"max_norm\", \"tie_embeddings\", \"lr_warmup_fraction\"])]\n    for args, names in unsupported:\n        for name in names:\n            if getattr(args, name) is not None:\n                issues.append(f\"{__file__} doesn't support the {name!r} argument. This is set in {args}\")\n    required = [(train, [\"epochs\"]), (eval, [\"max_new_tokens\"])]\n    for args, names in required:\n        for name in names:\n            if getattr(args, name) is None:\n                issues.append(f\"{__file__} requires the {name!r} argument. This is set in {args}\")\n    if not train.epochs and not train.max_steps:\n        issues.append(f\"{__file__} requires either epochs or max_steps to be set. This is set in {train}\")\n    if issues:\n        raise ValueError(\"\\n\".join(issues))\n"
  },
  {
    "path": "litgpt/finetune/lora.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport dataclasses\nimport math\nimport os\nimport time\nimport warnings\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Dict, List, Literal, Optional, Tuple, Union\n\nimport lightning as L\nimport torch\nfrom lightning.fabric.plugins import BitsandbytesPrecision\nfrom lightning.fabric.strategies import ModelParallelStrategy\nfrom lightning.fabric.utilities import ThroughputMonitor\nfrom torch.utils.data import ConcatDataset, DataLoader\nfrom torchmetrics import RunningMean\n\nfrom litgpt.args import EvalArgs, LogArgs, TrainArgs\nfrom litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0\nfrom litgpt.data import Alpaca, DataModule\nfrom litgpt.generate.base import generate\nfrom litgpt.lora import GPT, Block, Config, mark_only_lora_as_trainable\nfrom litgpt.parser_config import save_hyperparameters\nfrom litgpt.prompts import save_prompt_style\nfrom litgpt.scripts.merge_lora import merge_lora\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.types import LoggerChoice\nfrom litgpt.utils import (\n    CycleIterator,\n    auto_download_checkpoint,\n    check_nvlink_connectivity,\n    check_valid_checkpoint_dir,\n    choose_logger,\n    chunked_cross_entropy,\n    copy_config_files,\n    create_finetuning_performance_report,\n    get_default_supported_precision,\n    init_out_dir,\n    instantiate_bnb_optimizer,\n    instantiate_torch_optimizer,\n    load_checkpoint,\n    num_parameters,\n    parse_devices,\n    select_sft_generate_example,\n)\n\n\ndef setup(\n    checkpoint_dir: Path,\n    out_dir: Path = Path(\"out/finetune/lora\"),\n    precision: Optional[str] = None,\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8-training\"]] = None,\n    devices: Union[int, str] = 1,\n    num_nodes: int = 1,\n    lora_r: int = 8,\n    lora_alpha: int = 16,\n    lora_dropout: float = 0.05,\n  
  lora_query: bool = True,\n    lora_key: bool = False,\n    lora_value: bool = True,\n    lora_projection: bool = False,\n    lora_mlp: bool = False,\n    lora_head: bool = False,\n    data: Optional[DataModule] = None,\n    train: TrainArgs = TrainArgs(\n        save_interval=1000,\n        log_interval=1,\n        global_batch_size=16,\n        micro_batch_size=1,\n        lr_warmup_steps=100,\n        epochs=5,\n        max_seq_length=None,\n        max_time=None,\n    ),\n    log: LogArgs = LogArgs(),\n    eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),\n    optimizer: Union[str, Dict] = \"AdamW\",\n    logger_name: LoggerChoice = \"csv\",\n    seed: int = 1337,\n    access_token: Optional[str] = None,\n) -> None:\n    \"\"\"Finetune a model using the LoRA method.\n\n    Arguments:\n        checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.\n        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in\n            /teamspace/jobs/<job-name>/share.\n        precision: The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\".\n        quantize: If set, quantize the model with this algorithm. 
See ``tutorials/quantize.md`` for more information.\n        devices: How many devices/GPUs to use.\n        num_nodes: How many nodes the code is being run on.\n        lora_r: The LoRA rank.\n        lora_alpha: The LoRA alpha.\n        lora_dropout: The LoRA dropout value.\n        lora_query: Whether to apply LoRA to the query weights in attention.\n        lora_key: Whether to apply LoRA to the key weights in attention.\n        lora_value: Whether to apply LoRA to the value weights in attention.\n        lora_projection: Whether to apply LoRA to the output projection in the attention block.\n        lora_mlp: Whether to apply LoRA to the weights of the MLP in the attention block.\n        lora_head: Whether to apply LoRA to output head in GPT.\n        data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\n        train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.\n        eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.\n        optimizer: An optimizer name (such as \"AdamW\") or config.\n        logger_name: The name of the logger to send metrics to.\n        seed: The random seed to use for reproducibility.\n        access_token: Optional API token to access models with restrictions.\n    \"\"\"\n\n    checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)\n    pprint(locals())\n    data = Alpaca() if data is None else data\n    devices = parse_devices(devices)\n    out_dir = init_out_dir(out_dir)\n\n    check_valid_checkpoint_dir(checkpoint_dir)\n    config = Config.from_file(\n        checkpoint_dir / \"model_config.yaml\",\n        lora_r=lora_r,\n        lora_alpha=lora_alpha,\n        lora_dropout=lora_dropout,\n        lora_query=lora_query,\n        lora_key=lora_key,\n        lora_value=lora_value,\n        lora_projection=lora_projection,\n        lora_mlp=lora_mlp,\n        lora_head=lora_head,\n    )\n\n    precision = 
precision or get_default_supported_precision(training=True)\n    logger = choose_logger(\n        logger_name,\n        out_dir,\n        name=f\"finetune-{config.name}\",\n        log_interval=train.log_interval,\n        log_args=dataclasses.asdict(log),\n    )\n\n    plugins = None\n    if quantize is not None and quantize.startswith(\"bnb.\"):\n        if \"mixed\" in precision:\n            raise ValueError(\"Quantization and mixed precision is not supported.\")\n        if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:\n            warnings.warn(\n                \"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization.\"\n            )\n        dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n        plugins = BitsandbytesPrecision(quantize[4:], dtype)\n        precision = None\n\n    if devices * num_nodes > 1:\n        if quantize:\n            raise NotImplementedError(\n                \"Quantization is currently not supported for multi-GPU training. 
Please set devices=1 and num_nodes=1\"\n                \" when using the --quantize flag.\"\n            )\n        strategy = ModelParallelStrategy(\n            parallelize_fn=parallelize_fn,\n            data_parallel_size=devices * num_nodes,\n            tensor_parallel_size=1,\n        )\n    else:\n        strategy = \"auto\"\n\n    fabric = L.Fabric(\n        devices=devices,\n        num_nodes=num_nodes,\n        strategy=strategy,\n        precision=precision,\n        loggers=logger,\n        plugins=plugins,\n    )\n\n    if torch.cuda.is_available() and devices > 1:\n        check_nvlink_connectivity(fabric)\n\n    fabric.launch(\n        main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes, precision\n    )\n\n\ndef main(\n    fabric: L.Fabric,\n    devices: int,\n    seed: int,\n    config: Config,\n    data: DataModule,\n    checkpoint_dir: Path,\n    out_dir: Path,\n    train: TrainArgs,\n    eval: EvalArgs,\n    optimizer: Union[str, Dict],\n    num_nodes: int = 1,\n    precision: Optional[str] = None,\n) -> None:\n    validate_args(train, eval)\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)\n    steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)\n    lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float(\"inf\")))\n\n    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)\n\n    if fabric.global_rank == 0:\n        os.makedirs(out_dir, exist_ok=True)\n\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n    with fabric.init_module(empty_init=(fabric.world_size > 1)):\n        model = GPT(config)\n    mark_only_lora_as_trainable(model)\n\n    fabric.print(f\"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}\")\n    fabric.print(f\"Number of non-trainable parameters: {num_parameters(model, 
requires_grad=False):,}\")\n\n    model = fabric.setup_module(model)\n    if isinstance(fabric.strategy.precision, BitsandbytesPrecision):\n        optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())\n\n        from bitsandbytes.nn import StableEmbedding\n\n        old_embedding = model.transformer.wte\n        model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)\n        with torch.no_grad():\n            model.transformer.wte.weight.copy_(old_embedding.weight)\n        model.transformer.wte = model.transformer.wte.to(\n            device=old_embedding.weight.device, dtype=old_embedding.weight.dtype\n        )\n    else:\n        optimizer = instantiate_torch_optimizer(optimizer, model.parameters())\n\n    optimizer = fabric.setup_optimizers(optimizer)\n    scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)\n\n    load_checkpoint(fabric, model, checkpoint_path, strict=False)\n\n    train_time = time.perf_counter()\n    token_counts = fit(\n        fabric=fabric,\n        model=model,\n        optimizer=optimizer,\n        scheduler=scheduler,\n        train_dataloader=train_dataloader,\n        val_dataloader=val_dataloader,\n        devices=devices,\n        num_nodes=num_nodes,\n        checkpoint_dir=checkpoint_dir,\n        out_dir=out_dir,\n        train=train,\n        eval=eval,\n        data=data,\n    )\n\n    training_time = time.perf_counter() - train_time\n    output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)\n    fabric.print(output)\n\n    # Final evaluation\n    if eval.final_validation:\n        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))\n        metrics = {\"val_loss\": val_loss, \"val_ppl\": math.exp(val_loss)}\n        fabric.log_dict(metrics)\n        fabric.print(f\"Final evaluation | val loss: {val_loss.item():.3f} | val 
ppl: {math.exp(val_loss):.3f}\")\n\n    # Save the final LoRA checkpoint at the end of training\n    save_path = out_dir / \"final\" / \"lit_model.pth.lora\"\n    save_path.parent.mkdir(parents=True, exist_ok=True)\n    save_lora_checkpoint(fabric, model, save_path)\n\n    fabric.barrier()\n    if fabric.global_rank == 0:\n        # Copy checkpoint files from original checkpoint dir\n        copy_config_files(checkpoint_dir, save_path.parent)\n        save_hyperparameters(setup, save_path.parent)\n        save_prompt_style(data.prompt_style, save_path.parent)\n        merge_lora(\n            checkpoint_dir=save_path.parent,\n            pretrained_checkpoint_dir=checkpoint_dir,\n            precision=precision,\n        )\n    fabric.barrier()\n\n\ndef fit(\n    fabric: L.Fabric,\n    model: GPT,\n    optimizer: torch.optim.Optimizer,\n    scheduler: torch.optim.lr_scheduler,\n    train_dataloader: DataLoader,\n    val_dataloader: DataLoader,\n    devices: int,\n    checkpoint_dir: Path,\n    out_dir: Path,\n    train: TrainArgs,\n    eval: EvalArgs,\n    data: DataModule,\n    num_nodes: int = 1,\n) -> dict:\n    tokenizer = Tokenizer(checkpoint_dir)\n    longest_seq_length, longest_seq_ix = get_longest_seq_length(\n        ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])\n    )\n    model.max_seq_length = min(longest_seq_length, train.max_seq_length or float(\"inf\"))\n    fabric.print(\n        f\"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is\"\n        f\" {model.max_seq_length} and context length is {model.config.block_size}\"\n    )\n\n    if eval.initial_validation:\n        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))\n        val_loss = f\"{val_loss:.3f}\"\n    else:\n        fabric.print(\"Verifying settings ...\")\n        validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), 
verbose=False)  # sanity check\n        val_loss = \"n/a\"\n\n    train_iterator = CycleIterator(train_dataloader)\n    throughput = ThroughputMonitor(fabric, window_size=50)\n    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(\n        fabric.device\n    )\n    max_steps = train.max_steps or float(\"inf\")\n    step_count = 0\n    iter_num = 0\n    total_lengths = 0\n    total_t0 = time.perf_counter()\n\n    max_time = train.max_time or float(\"inf\")\n\n    token_counts = {\n        \"raw_tokens\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n        \"raw_tokens_plus_prompt_template\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n        \"raw_tokens_plus_prompt_template_and_padding\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n    }\n\n    while step_count < max_steps:\n        iter_num += 1\n        iter_t0 = time.perf_counter()\n        batch = next(train_iterator)\n        if train_iterator.epoch >= train.epochs:\n            generate_example(fabric, model, tokenizer, eval, data)\n            fabric.print(f\"Number of epochs {train.epochs} reached, stopping training...\")\n            break\n        if iter_t0 - total_t0 > max_time:\n            generate_example(fabric, model, tokenizer, eval, data)\n            fabric.print(f\"Max time ({max_time / 60.0:.2f}m) reached, stopping training...\")\n            break\n        input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n\n        is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0\n        with fabric.no_backward_sync(model, enabled=is_accumulating):\n            logits = model(input_ids, lm_head_chunk_size=128)\n            # shift the targets such that output n predicts token n+1\n            logits[-1] = logits[-1][..., :-1, :]\n            loss = chunked_cross_entropy(logits, targets[..., 1:])\n            fabric.backward(loss / 
train.gradient_accumulation_iters(devices, num_nodes))\n\n        running_loss.update(loss.detach())\n\n        if not is_accumulating:\n            optimizer.step()\n            optimizer.zero_grad()\n            scheduler.step()\n            step_count += 1\n\n        token_counts[\"raw_tokens\"] += batch[\"token_counts\"][\"raw\"].sum().item()\n        token_counts[\"raw_tokens_plus_prompt_template\"] += (\n            batch[\"token_counts\"][\"raw_plus_prompt_template\"].sum().item()\n        )\n        token_counts[\"raw_tokens_plus_prompt_template_and_padding\"] += input_ids.numel()\n\n        total_lengths += input_ids.numel()\n        if iter_num % train.log_interval == 0:\n            loss = running_loss.compute().item()  # expensive device-to-host synchronization\n            t1 = time.perf_counter()\n            throughput.update(\n                time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths\n            )\n            throughput.compute_and_log(step=iter_num)\n            metrics = {\n                \"loss\": loss,\n                \"iter\": iter_num,\n                \"step\": step_count,\n                \"epoch\": train_iterator.epoch,\n                \"iter_time\": t1 - iter_t0,\n                \"tokens\": token_counts[\"raw_tokens_plus_prompt_template\"],\n                \"total_tokens\": token_counts[\"raw_tokens_plus_prompt_template\"] * fabric.world_size,\n                \"learning_rate\": scheduler.get_last_lr()[0],\n            }\n            if isinstance(val_loss, torch.Tensor):\n                val_loss = f\"{val_loss:.3f}\"\n            fabric.print(\n                f\"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |\"\n                f\" loss train: {metrics['loss']:.3f},\"\n                f\" val: {val_loss} |\"\n                f\" iter time: {metrics['iter_time'] * 1000:.2f} ms\"\n                f\"{' (step)' if not is_accumulating else 
''}\"\n            )\n            fabric.log_dict(metrics, step=iter_num)\n\n        if not is_accumulating and step_count % eval.interval == 0:\n            t0 = time.perf_counter()\n            val_loss = validate(fabric, model, val_dataloader, eval)\n            generate_example(fabric, model, tokenizer, eval, data)\n            t1 = time.perf_counter() - t0\n\n            val_loss_tensor = val_loss.detach().clone().to(fabric.device)\n            val_time_tensor = torch.tensor(t1, device=fabric.device, dtype=torch.float32)\n\n            fabric.all_reduce(val_loss_tensor, reduce_op=\"mean\")\n            fabric.all_reduce(val_time_tensor, reduce_op=\"mean\")\n\n            fabric.print(\n                f\"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms\"\n            )\n            metrics = {\"val_loss\": val_loss_tensor, \"val_ppl\": math.exp(val_loss_tensor)}\n            fabric.log_dict(metrics, step=iter_num)\n            fabric.barrier()\n\n        if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:\n            checkpoint_file = out_dir / f\"step-{step_count:06d}\" / \"lit_model.pth.lora\"\n            checkpoint_file.parent.mkdir(parents=True, exist_ok=True)\n            save_lora_checkpoint(fabric, model, checkpoint_file)\n            if fabric.global_rank == 0:\n                copy_config_files(checkpoint_dir, checkpoint_file.parent)\n                save_hyperparameters(setup, checkpoint_file.parent)\n                save_prompt_style(data.prompt_style, checkpoint_file.parent)\n\n    total_token_counts = {}\n    for key in token_counts:\n        total = fabric.all_reduce(token_counts[key], reduce_op=\"sum\")\n        total_token_counts[key] = total.item()\n\n    return total_token_counts\n\n\n# FSDP has issues with `inference_mode`\n@torch.no_grad()\ndef validate(\n    fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, 
verbose: bool = True\n) -> torch.Tensor:\n    if verbose:\n        fabric.print(\"Validating ...\")\n    model.eval()\n    losses = torch.zeros(min(len(val_dataloader), eval.max_iters))\n    for k, batch in enumerate(val_dataloader):\n        if k >= eval.max_iters:\n            break\n        input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n        logits = model(input_ids)\n        losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)\n\n    val_loss = losses.mean()\n\n    model.train()\n    return val_loss\n\n\n@torch.no_grad()\ndef generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):\n    instruction = select_sft_generate_example(eval, data)\n\n    fabric.print(instruction)\n    prompt = data.prompt_style.apply(instruction)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    model.eval()\n\n    max_returned_tokens = len(encoded) + eval.max_new_tokens\n\n    if max_returned_tokens < model.max_seq_length:\n        with fabric.init_tensor():\n            # do not set `max_seq_length=max_returned_token` because memory is not a concern here\n            model.set_kv_cache(batch_size=1)\n        output = generate(\n            model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id\n        )\n        model.clear_kv_cache()\n        model.train()\n        output = tokenizer.decode(output)\n        fabric.print(f\"{output}\\n\")\n    else:\n        print(\n            f\"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) \"\n            f\"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. 
\"\n            f\"The model's supported context size (post-training) is {model.config.block_size}.\"\n        )\n\n\ndef get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):\n    # linear warmup followed by cosine annealing\n    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)\n    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))\n    return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])\n\n\ndef get_dataloaders(\n    fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs\n) -> Tuple[DataLoader, DataLoader]:\n    data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)\n    with fabric.rank_zero_first():\n        data.prepare_data()\n    data.setup()\n    train_dataloader = data.train_dataloader()\n    val_dataloader = data.val_dataloader()\n    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)\n    return train_dataloader, val_dataloader\n\n\ndef get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:\n    # find out the minimum max_seq_length required during fine-tuning (saves memory!)\n    lengths = [len(d[\"input_ids\"]) for d in data]\n    longest_seq_length = max(lengths)\n    longest_seq_ix = lengths.index(longest_seq_length)\n    return longest_seq_length, longest_seq_ix\n\n\ndef parallelize_fn(model, device_mesh, activation_checkpointing=True):\n    from torch.distributed._composable.fsdp.fully_shard import fully_shard\n    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper, checkpoint_wrapper\n\n    if activation_checkpointing:\n        model.transformer.h = torch.nn.ModuleList(\n            [checkpoint_wrapper(el, preserve_rng_state=False) for el in model.transformer.h]\n        )\n\n    dp_mesh = device_mesh[\"data_parallel\"]\n\n    for m 
in reversed(list(model.modules())):\n        if (\n            (isinstance(m, torch.nn.Linear) and m.weight.requires_grad)\n            or isinstance(m, CheckpointWrapper)\n            or isinstance(m, Block)\n        ):\n            fully_shard(m, mesh=dp_mesh)\n\n    fully_shard(model, mesh=dp_mesh)\n\n    return model\n\n\ndef save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:\n    cpu_state_dict = {}\n    sharded_sd = model.state_dict()\n    for param_name, param in sharded_sd.items():\n        if \"lora_\" not in param_name:\n            continue\n        if param.is_cpu:\n            param = param.to(fabric.device)\n        if hasattr(param, \"_local_tensor\"):\n            param = param.full_tensor()\n        if fabric.is_global_zero:\n            cpu_state_dict[param_name] = param.cpu()\n        fabric.barrier()\n    if fabric.is_global_zero:\n        torch.save({\"model\": cpu_state_dict}, file_path)\n\n\ndef validate_args(train: TrainArgs, eval: EvalArgs) -> None:\n    issues = []\n    unsupported = [(train, [\"max_tokens\", \"max_norm\", \"tie_embeddings\", \"lr_warmup_fraction\"])]\n    for args, names in unsupported:\n        for name in names:\n            if getattr(args, name) is not None:\n                issues.append(f\"{__file__} doesn't support the {name!r} argument. This is set in {args}\")\n    required = [(train, [\"epochs\"]), (eval, [\"max_new_tokens\"])]\n    for args, names in required:\n        for name in names:\n            if getattr(args, name) is None:\n                issues.append(f\"{__file__} requires the {name!r} argument. This is set in {args}\")\n    if not train.epochs and not train.max_steps:\n        issues.append(f\"{__file__} requires either epochs or max_steps to be set. This is set in {train}\")\n    if issues:\n        raise ValueError(\"\\n\".join(issues))\n"
  },
  {
    "path": "litgpt/finetune/lora_legacy.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport dataclasses\nimport math\nimport os\nimport time\nimport warnings\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Dict, List, Literal, Optional, Tuple, Union\n\nimport lightning as L\nimport torch\nfrom lightning.fabric.plugins import BitsandbytesPrecision\nfrom lightning.fabric.strategies import FSDPStrategy\nfrom lightning.fabric.utilities import ThroughputMonitor\nfrom torch.utils.data import ConcatDataset, DataLoader\nfrom torchmetrics import RunningMean\n\nfrom litgpt.args import EvalArgs, LogArgs, TrainArgs\nfrom litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0\nfrom litgpt.data import Alpaca, DataModule\nfrom litgpt.generate.base import generate\nfrom litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable\nfrom litgpt.parser_config import save_hyperparameters\nfrom litgpt.prompts import save_prompt_style\nfrom litgpt.scripts.merge_lora import merge_lora\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.types import LoggerChoice\nfrom litgpt.utils import (\n    CycleIterator,\n    auto_download_checkpoint,\n    check_nvlink_connectivity,\n    check_valid_checkpoint_dir,\n    choose_logger,\n    chunked_cross_entropy,\n    copy_config_files,\n    create_finetuning_performance_report,\n    get_default_supported_precision,\n    init_out_dir,\n    instantiate_bnb_optimizer,\n    instantiate_torch_optimizer,\n    load_checkpoint,\n    num_parameters,\n    parse_devices,\n    select_sft_generate_example,\n)\n\n\ndef setup(\n    checkpoint_dir: Path,\n    out_dir: Path = Path(\"out/finetune/lora\"),\n    precision: Optional[str] = None,\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8-training\"]] = None,\n    devices: Union[int, str] = 1,\n    num_nodes: int = 1,\n    lora_r: int = 8,\n    lora_alpha: int = 16,\n    lora_dropout: float = 
0.05,\n    lora_query: bool = True,\n    lora_key: bool = False,\n    lora_value: bool = True,\n    lora_projection: bool = False,\n    lora_mlp: bool = False,\n    lora_head: bool = False,\n    data: Optional[DataModule] = None,\n    train: TrainArgs = TrainArgs(\n        save_interval=1000,\n        log_interval=1,\n        global_batch_size=16,\n        micro_batch_size=1,\n        lr_warmup_steps=100,\n        epochs=5,\n        max_seq_length=None,\n    ),\n    log: LogArgs = LogArgs(),\n    eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),\n    optimizer: Union[str, Dict] = \"AdamW\",\n    logger_name: LoggerChoice = \"csv\",\n    seed: int = 1337,\n    access_token: Optional[str] = None,\n) -> None:\n    \"\"\"Finetune a model using the LoRA method.\n\n    Arguments:\n        checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.\n        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in\n            /teamspace/jobs/<job-name>/share.\n        precision: The precision to use for finetuning. Possible choices: \"bf16-true\", \"bf16-mixed\", \"32-true\".\n        quantize: If set, quantize the model with this algorithm. 
See ``tutorials/quantize.md`` for more information.\n        devices: How many devices/GPUs to use.\n        num_nodes: How many nodes the code is being run on.\n        lora_r: The LoRA rank.\n        lora_alpha: The LoRA alpha.\n        lora_dropout: The LoRA dropout value.\n        lora_query: Whether to apply LoRA to the query weights in attention.\n        lora_key: Whether to apply LoRA to the key weights in attention.\n        lora_value: Whether to apply LoRA to the value weights in attention.\n        lora_projection: Whether to apply LoRA to the output projection in the attention block.\n        lora_mlp: Whether to apply LoRA to the weights of the MLP in the attention block.\n        lora_head: Whether to apply LoRA to output head in GPT.\n        data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.\n        train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.\n        eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.\n        optimizer: An optimizer name (such as \"AdamW\") or config.\n        logger_name: The name of the logger to send metrics to.\n        seed: The random seed to use for reproducibility.\n        access_token: Optional API token to access models with restrictions.\n    \"\"\"\n\n    checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)\n    pprint(locals())\n    data = Alpaca() if data is None else data\n    devices = parse_devices(devices)\n    out_dir = init_out_dir(out_dir)\n\n    check_valid_checkpoint_dir(checkpoint_dir)\n    config = Config.from_file(\n        checkpoint_dir / \"model_config.yaml\",\n        lora_r=lora_r,\n        lora_alpha=lora_alpha,\n        lora_dropout=lora_dropout,\n        lora_query=lora_query,\n        lora_key=lora_key,\n        lora_value=lora_value,\n        lora_projection=lora_projection,\n        lora_mlp=lora_mlp,\n        lora_head=lora_head,\n    )\n\n    precision = 
precision or get_default_supported_precision(training=True)\n    logger = choose_logger(\n        logger_name,\n        out_dir,\n        name=f\"finetune-{config.name}\",\n        log_interval=train.log_interval,\n        log_args=dataclasses.asdict(log),\n    )\n\n    plugins = None\n    if quantize is not None and quantize.startswith(\"bnb.\"):\n        if \"mixed\" in precision:\n            raise ValueError(\"Quantization and mixed precision is not supported.\")\n        if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:\n            warnings.warn(\n                \"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization.\"\n            )\n        dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n        plugins = BitsandbytesPrecision(quantize[4:], dtype)\n        precision = None\n\n    if devices * num_nodes > 1:\n        if quantize:\n            raise NotImplementedError(\n                \"Quantization is currently not supported for multi-GPU training. 
Please set devices=1 and num_nodes=1\"\n                \" when using the --quantize flag.\"\n            )\n\n        strategy = FSDPStrategy(\n            auto_wrap_policy={torch.nn.Linear},\n            activation_checkpointing_policy={Block},\n            state_dict_type=\"full\",\n            limit_all_gathers=True,\n            cpu_offload=False,\n        )\n    else:\n        strategy = \"auto\"\n\n    fabric = L.Fabric(\n        devices=devices,\n        num_nodes=num_nodes,\n        strategy=strategy,\n        precision=precision,\n        loggers=logger,\n        plugins=plugins,\n    )\n\n    if torch.cuda.is_available() and devices > 1:\n        check_nvlink_connectivity(fabric)\n\n    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)\n\n\ndef main(\n    fabric: L.Fabric,\n    devices: int,\n    seed: int,\n    config: Config,\n    data: DataModule,\n    checkpoint_dir: Path,\n    out_dir: Path,\n    train: TrainArgs,\n    eval: EvalArgs,\n    optimizer: Union[str, Dict],\n    num_nodes: int = 1,\n) -> None:\n    validate_args(train, eval)\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train)\n    steps_per_epoch = len(train_dataloader) // train.gradient_accumulation_iters(devices, num_nodes)\n    lr_max_steps = min(train.epochs * steps_per_epoch, (train.max_steps or float(\"inf\")))\n\n    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)\n\n    if fabric.global_rank == 0:\n        os.makedirs(out_dir, exist_ok=True)\n\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n    with fabric.init_module(empty_init=(fabric.world_size > 1)):\n        model = GPT(config)\n    mark_only_lora_as_trainable(model)\n\n    fabric.print(f\"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}\")\n    fabric.print(f\"Number of non-trainable parameters: 
{num_parameters(model, requires_grad=False):,}\")\n\n    model = fabric.setup_module(model)\n    if isinstance(fabric.strategy.precision, BitsandbytesPrecision):\n        optimizer = instantiate_bnb_optimizer(optimizer, model.parameters())\n\n        from bitsandbytes.nn import StableEmbedding\n\n        old_embedding = model.transformer.wte\n        model.transformer.wte = StableEmbedding(old_embedding.num_embeddings, old_embedding.embedding_dim)\n        with torch.no_grad():\n            model.transformer.wte.weight.copy_(old_embedding.weight)\n        model.transformer.wte = model.transformer.wte.to(\n            device=old_embedding.weight.device, dtype=old_embedding.weight.dtype\n        )\n    else:\n        optimizer = instantiate_torch_optimizer(optimizer, model.parameters())\n\n    optimizer = fabric.setup_optimizers(optimizer)\n    scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)\n\n    # strict=False because missing keys due to LoRA weights not contained in state dict\n    load_checkpoint(fabric, model, checkpoint_path, strict=False)\n\n    train_time = time.perf_counter()\n    token_counts = fit(\n        fabric=fabric,\n        model=model,\n        optimizer=optimizer,\n        scheduler=scheduler,\n        train_dataloader=train_dataloader,\n        val_dataloader=val_dataloader,\n        devices=devices,\n        num_nodes=num_nodes,\n        checkpoint_dir=checkpoint_dir,\n        out_dir=out_dir,\n        train=train,\n        eval=eval,\n        data=data,\n    )\n\n    training_time = time.perf_counter() - train_time\n    output = create_finetuning_performance_report(training_time, token_counts, fabric.device.type)\n    fabric.print(output)\n\n    # Final evaluation\n    if eval.final_validation:\n        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))\n        metrics = {\"val_loss\": val_loss, \"val_ppl\": math.exp(val_loss)}\n      
  fabric.log_dict(metrics)\n        fabric.print(f\"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}\")\n\n    # Save the final LoRA checkpoint at the end of training\n    save_path = out_dir / \"final\" / \"lit_model.pth.lora\"\n    save_path.parent.mkdir(parents=True, exist_ok=True)\n    save_lora_checkpoint(fabric, model, save_path)\n    if fabric.global_rank == 0:\n        # Copy checkpoint files from original checkpoint dir\n        copy_config_files(checkpoint_dir, save_path.parent)\n        save_hyperparameters(setup, save_path.parent)\n        save_prompt_style(data.prompt_style, save_path.parent)\n        merge_lora(checkpoint_dir=save_path.parent)\n\n\ndef fit(\n    fabric: L.Fabric,\n    model: GPT,\n    optimizer: torch.optim.Optimizer,\n    scheduler: torch.optim.lr_scheduler,\n    train_dataloader: DataLoader,\n    val_dataloader: DataLoader,\n    devices: int,\n    checkpoint_dir: Path,\n    out_dir: Path,\n    train: TrainArgs,\n    eval: EvalArgs,\n    data: DataModule,\n    num_nodes: int = 1,\n) -> dict:\n    tokenizer = Tokenizer(checkpoint_dir)\n    longest_seq_length, longest_seq_ix = get_longest_seq_length(\n        ConcatDataset([train_dataloader.dataset, val_dataloader.dataset])\n    )\n    model.max_seq_length = min(longest_seq_length, train.max_seq_length or float(\"inf\"))\n    fabric.print(\n        f\"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is\"\n        f\" {model.max_seq_length} and context length is {model.config.block_size}\"\n    )\n\n    if eval.initial_validation:\n        val_loss = validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=len(val_dataloader)))\n        val_loss = f\"{val_loss:.3f}\"\n    else:\n        fabric.print(\"Verifying settings ...\")\n        validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2), verbose=False)  # sanity check\n        val_loss = 
\"n/a\"\n\n    train_iterator = CycleIterator(train_dataloader)\n    throughput = ThroughputMonitor(fabric, window_size=50)\n    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(\n        fabric.device\n    )\n    max_steps = train.max_steps or float(\"inf\")\n    step_count = 0\n    iter_num = 0\n    total_lengths = 0\n    total_t0 = time.perf_counter()\n\n    token_counts = {\n        \"raw_tokens\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n        \"raw_tokens_plus_prompt_template\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n        \"raw_tokens_plus_prompt_template_and_padding\": torch.tensor(0, device=fabric.device, dtype=torch.long),\n    }\n\n    while step_count < max_steps:\n        iter_num += 1\n        iter_t0 = time.perf_counter()\n        batch = next(train_iterator)\n        if train_iterator.epoch >= train.epochs:\n            break\n        input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n\n        is_accumulating = iter_num % train.gradient_accumulation_iters(devices, num_nodes) != 0\n        with fabric.no_backward_sync(model, enabled=is_accumulating):\n            logits = model(input_ids, lm_head_chunk_size=128)\n            # shift the targets such that output n predicts token n+1\n            logits[-1] = logits[-1][..., :-1, :]\n            loss = chunked_cross_entropy(logits, targets[..., 1:])\n            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))\n\n        running_loss.update(loss.detach())\n\n        if not is_accumulating:\n            optimizer.step()\n            optimizer.zero_grad()\n            scheduler.step()\n            step_count += 1\n\n        token_counts[\"raw_tokens\"] += batch[\"token_counts\"][\"raw\"].sum().item()\n        token_counts[\"raw_tokens_plus_prompt_template\"] += (\n            batch[\"token_counts\"][\"raw_plus_prompt_template\"].sum().item()\n        )\n       
 token_counts[\"raw_tokens_plus_prompt_template_and_padding\"] += input_ids.numel()\n\n        total_lengths += input_ids.numel()\n        if iter_num % train.log_interval == 0:\n            loss = running_loss.compute().item()  # expensive device-to-host synchronization\n            t1 = time.perf_counter()\n            throughput.update(\n                time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths\n            )\n            throughput.compute_and_log(step=iter_num)\n            metrics = {\n                \"loss\": loss,\n                \"iter\": iter_num,\n                \"step\": step_count,\n                \"epoch\": train_iterator.epoch,\n                \"iter_time\": t1 - iter_t0,\n                \"tokens\": token_counts[\"raw_tokens_plus_prompt_template\"],\n                \"total_tokens\": token_counts[\"raw_tokens_plus_prompt_template\"] * fabric.world_size,\n                \"learning_rate\": scheduler.get_last_lr()[0],\n            }\n            if isinstance(val_loss, torch.Tensor):\n                val_loss = f\"{val_loss:.3f}\"\n            fabric.print(\n                f\"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |\"\n                f\" loss train: {metrics['loss']:.3f},\"\n                f\" val: {val_loss} |\"\n                f\" iter time: {metrics['iter_time'] * 1000:.2f} ms\"\n                f\"{' (step)' if not is_accumulating else ''}\"\n            )\n            fabric.log_dict(metrics, step=iter_num)\n\n        if not is_accumulating and step_count % eval.interval == 0:\n            t0 = time.perf_counter()\n            val_loss = validate(fabric, model, val_dataloader, eval)\n            generate_example(fabric, model, tokenizer, eval, data)\n            t1 = time.perf_counter() - t0\n\n            val_loss_tensor = val_loss.detach().clone().to(fabric.device)\n            val_time_tensor = torch.tensor(t1, device=fabric.device, 
dtype=torch.float32)\n\n            fabric.all_reduce(val_loss_tensor, reduce_op=\"mean\")\n            fabric.all_reduce(val_time_tensor, reduce_op=\"mean\")\n\n            fabric.print(\n                f\"iter {iter_num}: val loss {val_loss_tensor.item():.4f}, val time: {val_time_tensor.item() * 1000:.2f} ms\"\n            )\n            metrics = {\"val_loss\": val_loss_tensor, \"val_ppl\": math.exp(val_loss_tensor)}\n            fabric.log_dict(metrics, step=iter_num)\n            fabric.barrier()\n\n        if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:\n            checkpoint_file = out_dir / f\"step-{step_count:06d}\" / \"lit_model.pth.lora\"\n            checkpoint_file.parent.mkdir(parents=True, exist_ok=True)\n            save_lora_checkpoint(fabric, model, checkpoint_file)\n            if fabric.global_rank == 0:\n                copy_config_files(checkpoint_dir, checkpoint_file.parent)\n                save_hyperparameters(setup, checkpoint_file.parent)\n                save_prompt_style(data.prompt_style, checkpoint_file.parent)\n\n    total_token_counts = {}\n    for key in token_counts:\n        total = fabric.all_reduce(token_counts[key], reduce_op=\"sum\")\n        total_token_counts[key] = total.item()\n\n    return total_token_counts\n\n\n# FSDP has issues with `inference_mode`\n@torch.no_grad()\ndef validate(\n    fabric: L.Fabric, model: GPT, val_dataloader: DataLoader, eval: EvalArgs, verbose: bool = True\n) -> torch.Tensor:\n    if verbose:\n        fabric.print(\"Validating ...\")\n    model.eval()\n    losses = torch.zeros(min(len(val_dataloader), eval.max_iters))\n    for k, batch in enumerate(val_dataloader):\n        if k >= eval.max_iters:\n            break\n        input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n        logits = model(input_ids)\n        losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)\n\n    val_loss = 
losses.mean()\n\n    model.train()\n    return val_loss\n\n\n@torch.no_grad()\ndef generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: EvalArgs, data: DataModule):\n    instruction = select_sft_generate_example(eval, data)\n\n    fabric.print(instruction)\n    prompt = data.prompt_style.apply(instruction)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    model.eval()\n\n    max_returned_tokens = len(encoded) + eval.max_new_tokens\n\n    if max_returned_tokens < model.max_seq_length:\n        with fabric.init_tensor():\n            # do not set `max_seq_length=max_returned_token` because memory is not a concern here\n            model.set_kv_cache(batch_size=1)\n        output = generate(\n            model, encoded, max_returned_tokens=max_returned_tokens, temperature=0.8, eos_id=tokenizer.eos_id\n        )\n        model.clear_kv_cache()\n        model.train()\n        output = tokenizer.decode(output)\n        fabric.print(f\"{output}\\n\")\n    else:\n        print(\n            f\"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) \"\n            f\"exceeds model.max_seq_length ({model.max_seq_length}) used for training. Skipping example generation for efficiency. 
\"\n            f\"The model's supported context size (post-training) is {model.config.block_size}.\"\n        )\n\n\ndef get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int):\n    # linear warmup followed by cosine annealing\n    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)\n    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(max_steps - warmup_steps))\n    return torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[warmup_steps])\n\n\ndef get_dataloaders(\n    fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs\n) -> Tuple[DataLoader, DataLoader]:\n    data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=train.max_seq_length)\n    with fabric.rank_zero_first():\n        data.prepare_data()\n    data.setup()\n    train_dataloader = data.train_dataloader()\n    val_dataloader = data.val_dataloader()\n    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)\n    return train_dataloader, val_dataloader\n\n\ndef get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:\n    # find out the minimum max_seq_length required during fine-tuning (saves memory!)\n    lengths = [len(d[\"input_ids\"]) for d in data]\n    longest_seq_length = max(lengths)\n    longest_seq_ix = lengths.index(longest_seq_length)\n    return longest_seq_length, longest_seq_ix\n\n\ndef save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:\n    fabric.print(f\"Saving LoRA weights to {str(file_path)!r}\")\n    fabric.save(file_path, {\"model\": model}, filter={\"model\": lora_filter})\n\n\ndef validate_args(train: TrainArgs, eval: EvalArgs) -> None:\n    issues = []\n    unsupported = [(train, [\"max_tokens\", \"max_norm\", \"tie_embeddings\", \"lr_warmup_fraction\"])]\n    for args, names in unsupported:\n        for name in names:\n            if 
getattr(args, name) is not None:\n                issues.append(f\"{__file__} doesn't support the {name!r} argument. This is set in {args}\")\n    required = [(train, [\"epochs\"]), (eval, [\"max_new_tokens\"])]\n    for args, names in required:\n        for name in names:\n            if getattr(args, name) is None:\n                issues.append(f\"{__file__} requires the {name!r} argument. This is set in {args}\")\n    if not train.epochs and not train.max_steps:\n        issues.append(f\"{__file__} requires either epochs or max_steps to be set. This is set in {train}\")\n    if issues:\n        raise ValueError(\"\\n\".join(issues))\n"
  },
  {
    "path": "litgpt/generate/__init__.py",
    "content": ""
  },
  {
    "path": "litgpt/generate/adapter.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport sys\nimport time\nimport warnings\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Literal, Optional\n\nimport lightning as L\nimport torch\nfrom lightning.fabric.plugins import BitsandbytesPrecision\n\nfrom litgpt import PromptStyle, Tokenizer\nfrom litgpt.adapter import GPT, Config\nfrom litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0\nfrom litgpt.generate.base import generate\nfrom litgpt.prompts import has_prompt_style, load_prompt_style\nfrom litgpt.utils import (\n    check_file_size_on_cpu_and_warn,\n    check_valid_checkpoint_dir,\n    extend_checkpoint_dir,\n    get_default_supported_precision,\n    lazy_load,\n)\n\n\ndef main(\n    checkpoint_dir: Path,\n    prompt: str = \"What food do llamas eat?\",\n    input: str = \"\",\n    sys_prompt: Optional[str] = None,\n    adapter_path: Path = Path(\"out/finetune/adapter/final/lit_model.pth.adapter\"),\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\"]] = None,\n    max_new_tokens: int = 100,\n    top_k: Optional[int] = 50,\n    top_p: float = 1.0,\n    temperature: float = 0.8,\n    precision: Optional[str] = None,\n) -> None:\n    \"\"\"For models finetuned with `litgpt finetune_adapter`.\n\n    Generates a response based on a given instruction and an optional input. This script will only work with\n    checkpoints from the instruction-tuned adapter model. 
See ``litgpt.finetune.adapter``.\n\n    Args:\n        checkpoint_dir: The path to the checkpoint folder with pretrained model weights.\n        prompt: The prompt/instruction (Alpaca style).\n        input: Optional input (Alpaca style).\n        sys_prompt: Optional system prompt.\n        adapter_path: Path to the checkpoint with trained adapter weights, which are the output of\n            ``litgpt.finetune.adapter``.\n        quantize: Whether to quantize the model and using which method:\n            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes\n            - bnb.int8: 8-bit quantization from bitsandbytes\n            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md\n        max_new_tokens: The number of generation steps to take.\n        top_k: The number of top most probable tokens to consider in the sampling process.\n        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n            In top-p sampling, the next token is sampled from the highest probability tokens\n            whose cumulative probability exceeds the threshold `top_p`. When specified,\n            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n            to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n            It can be used in conjunction with `top_k` and `temperature` with the following order\n            of application:\n\n            1. `top_k` sampling\n            2. `temperature` scaling\n            3. `top_p` sampling\n\n            For more details, see https://arxiv.org/abs/1904.09751\n            or https://huyenchip.com/2024/01/16/sampling.html#top_p\n        temperature: A value controlling the randomness of the sampling process. 
Higher values result in more random\n            samples.\n        precision: Indicates the Fabric precision setting to use.\n    \"\"\"\n    checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)\n    pprint(locals())\n\n    precision = precision or get_default_supported_precision(training=False)\n\n    plugins = None\n    if quantize is not None and quantize.startswith(\"bnb.\"):\n        if \"mixed\" in precision:\n            raise ValueError(\"Quantization and mixed precision is not supported.\")\n        if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:\n            warnings.warn(\n                \"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization.\"\n            )\n        dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n        plugins = BitsandbytesPrecision(quantize[4:], dtype)\n        precision = None\n\n    fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)\n    fabric.launch()\n\n    check_valid_checkpoint_dir(checkpoint_dir)\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n    check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    prompt_style = (\n        load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)\n    )\n\n    prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt, input=input)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    prompt_length = encoded.size(0)\n    max_returned_tokens = prompt_length + max_new_tokens\n\n    fabric.print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n    t0 = time.perf_counter()\n    with fabric.init_module(empty_init=True):\n        model = GPT(config)\n    fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", 
file=sys.stderr)\n    with fabric.init_tensor():\n        # set the max_seq_length to limit the memory usage to what we need\n        model.max_seq_length = max_returned_tokens\n        # enable the kv cache\n        model.set_kv_cache(batch_size=1)\n    model.eval()\n\n    t0 = time.perf_counter()\n    checkpoint = lazy_load(checkpoint_path)\n    adapter_checkpoint = lazy_load(adapter_path)\n    checkpoint.update(adapter_checkpoint.get(\"model\", adapter_checkpoint))\n    model.load_state_dict(checkpoint)\n    fabric.print(f\"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    model = fabric.setup(model)\n\n    L.seed_everything(1234)\n    t0 = time.perf_counter()\n    y = generate(\n        model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id\n    )\n    t = time.perf_counter() - t0\n\n    output = tokenizer.decode(y)\n    output = output.split(\"### Response:\")[1].strip()\n    fabric.print(output)\n\n    tokens_generated = y.size(0) - prompt_length\n    fabric.print(f\"\\n\\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr)\n    if fabric.device.type == \"cuda\":\n        fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)\n"
  },
  {
    "path": "litgpt/generate/adapter_v2.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport sys\nimport time\nimport warnings\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Literal, Optional\n\nimport lightning as L\nimport torch\nfrom lightning.fabric.plugins import BitsandbytesPrecision\n\nfrom litgpt import PromptStyle, Tokenizer\nfrom litgpt.adapter_v2 import GPT, Config\nfrom litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0\nfrom litgpt.generate.base import generate\nfrom litgpt.prompts import has_prompt_style, load_prompt_style\nfrom litgpt.utils import (\n    check_file_size_on_cpu_and_warn,\n    check_valid_checkpoint_dir,\n    extend_checkpoint_dir,\n    get_default_supported_precision,\n    lazy_load,\n)\n\n\ndef main(\n    checkpoint_dir: Path,\n    prompt: str = \"What food do llamas eat?\",\n    input: str = \"\",\n    sys_prompt: Optional[str] = None,\n    adapter_path: Path = Path(\"out/finetune/adapter-v2/final/lit_model.pth.adapter_v2\"),\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\"]] = None,\n    max_new_tokens: int = 100,\n    top_k: Optional[int] = 50,\n    top_p: float = 1.0,\n    temperature: float = 0.8,\n    precision: Optional[str] = None,\n) -> None:\n    \"\"\"For models finetuned with `litgpt finetune adapter_v2`.\n\n    Generates a response based on a given instruction and an optional input. This script will only work with\n    checkpoints from the instruction-tuned adapter v2 model. 
See ``litgpt.finetune.adapter_v2``.\n\n    Args:\n        checkpoint_dir: The path to the checkpoint folder with pretrained model weights.\n        prompt: The prompt/instruction (Alpaca style).\n        input: Optional input (Alpaca style).\n        sys_prompt: Optional system prompt.\n        adapter_path: Path to the checkpoint with trained adapter weights, which are the output of\n            ``litgpt.finetune.adapter_v2``.\n        quantize: Whether to quantize the model and using which method:\n            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes\n            - bnb.int8: 8-bit quantization from bitsandbytes\n            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md\n        max_new_tokens: The number of generation steps to take.\n        top_k: The number of top most probable tokens to consider in the sampling process.\n        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n            In top-p sampling, the next token is sampled from the highest probability tokens\n            whose cumulative probability exceeds the threshold `top_p`. When specified,\n            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n            to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n            It can be used in conjunction with `top_k` and `temperature` with the following order\n            of application:\n\n            1. `top_k` sampling\n            2. `temperature` scaling\n            3. `top_p` sampling\n\n            For more details, see https://arxiv.org/abs/1904.09751\n            or https://huyenchip.com/2024/01/16/sampling.html#top_p\n        temperature: A value controlling the randomness of the sampling process. 
Higher values result in more random\n            samples.\n        precision: Indicates the Fabric precision setting to use.\n    \"\"\"\n    checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)\n    pprint(locals())\n\n    precision = precision or get_default_supported_precision(training=False)\n\n    plugins = None\n    if quantize is not None and quantize.startswith(\"bnb.\"):\n        if \"mixed\" in precision:\n            raise ValueError(\"Quantization and mixed precision is not supported.\")\n        if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:\n            warnings.warn(\n                \"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization.\"\n            )\n        dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n        plugins = BitsandbytesPrecision(quantize[4:], dtype)\n        precision = None\n\n    fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)\n    fabric.launch()\n\n    check_valid_checkpoint_dir(checkpoint_dir)\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n    check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    prompt_style = (\n        load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)\n    )\n\n    prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt, input=input)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    prompt_length = encoded.size(0)\n    max_returned_tokens = prompt_length + max_new_tokens\n\n    fabric.print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n    t0 = time.perf_counter()\n    with fabric.init_module(empty_init=True):\n        model = GPT(config)\n    fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", 
file=sys.stderr)\n    with fabric.init_tensor():\n        # set the max_seq_length to limit the memory usage to what we need\n        model.max_seq_length = max_returned_tokens\n        # enable the kv cache\n        model.set_kv_cache(batch_size=1)\n    model.eval()\n\n    t0 = time.perf_counter()\n    checkpoint = lazy_load(checkpoint_path)\n    adapter_checkpoint = lazy_load(adapter_path)\n    checkpoint.update(adapter_checkpoint.get(\"model\", adapter_checkpoint))\n    model.load_state_dict(checkpoint)\n    fabric.print(f\"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    model = fabric.setup(model)\n\n    L.seed_everything(1234)\n    t0 = time.perf_counter()\n    y = generate(\n        model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id\n    )\n    t = time.perf_counter() - t0\n\n    output = tokenizer.decode(y)\n    output = output.split(\"### Response:\")[1].strip()\n    fabric.print(output)\n\n    tokens_generated = y.size(0) - prompt_length\n    fabric.print(f\"\\n\\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr)\n    if fabric.device.type == \"cuda\":\n        fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)\n"
  },
  {
    "path": "litgpt/generate/base.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport sys\nimport time\nimport warnings\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union\n\nimport lightning as L\nimport torch\nimport torch._dynamo.config\nimport torch._inductor.config\nfrom lightning.fabric.plugins import BitsandbytesPrecision\n\nfrom litgpt.config import Config\nfrom litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0\nfrom litgpt.model import GPT\nfrom litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.utils import (\n    check_file_size_on_cpu_and_warn,\n    check_valid_checkpoint_dir,\n    extend_checkpoint_dir,\n    get_default_supported_precision,\n    load_checkpoint,\n)\n\n\ndef multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:\n    if torch._dynamo.is_compiling():\n        # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly\n        distribution = torch.empty_like(probs).exponential_(1)\n        return torch.argmax(probs / distribution, dim=-1, keepdim=True)\n    return torch.multinomial(probs, num_samples=1)\n\n\ndef sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:\n    sorted_logits, sorted_indices = torch.sort(logits, descending=False)\n    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)\n    # Example:\n    # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]\n    # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7\n    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)\n    # Keep at least 1 token always to prevent the case where no token is selected\n    # In this case the most probable one is always kept\n    sorted_indices_to_remove[-1:] = 0\n    indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, 
sorted_indices_to_remove)\n    logits = logits.masked_fill(indices_to_remove, float(\"-inf\"))\n    return logits\n\n\ndef sample(\n    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0\n) -> torch.Tensor:\n    if top_p < 0.0 or top_p > 1.0:\n        raise ValueError(f\"top_p must be in [0, 1], got {top_p}\")\n    logits = logits[0, -1]\n    # optionally crop the logits to only the top k options\n    if top_k is not None:\n        v, i = torch.topk(logits, min(top_k, logits.size(-1)))\n        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions\n        logits = torch.full_like(logits, float(\"-inf\")).scatter_(-1, i, v)\n    # optionally scale the logits and sample from a probability distribution\n    if temperature > 0.0 or top_p > 0.0:\n        if temperature > 0.0:\n            logits = logits / temperature\n        # optionally crop the logits to smallest set of logits with a cumulative probability above top_p\n        if top_p < 1.0:\n            logits = sample_top_p(logits, top_p)\n        probs = torch.nn.functional.softmax(logits, dim=-1)\n        return multinomial_num_samples_1(probs)\n    return torch.argmax(logits, dim=-1, keepdim=True)\n\n\ndef next_token(\n    model: GPT,\n    input_pos: torch.Tensor,\n    x: torch.Tensor,\n    input_pos_maxp1: Optional[int] = None,\n    **sample_kwargs: Dict[str, Any],\n) -> torch.Tensor:\n    logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1)\n    _next = sample(logits, **sample_kwargs).to(dtype=torch.int64)\n    return _next\n\n\ndef batched_sample(logits: list[torch.Tensor], kwargs: list[dict]) -> torch.Tensor:\n    assert len(logits) == len(kwargs), \"logits and kwargs must have the same length.\"\n    return torch.stack(\n        [sample(l, **sample_args).to(dtype=torch.int64) for sample_args, l in zip(kwargs, logits)], dim=0\n    )\n\n\ndef batched_next_token(\n    model: GPT, input_pos: torch.Tensor, x: torch.Tensor, 
kwargs: Union[dict, list[dict]]\n) -> torch.Tensor:\n    # Where:\n    # input_pos is a 1d tensor of shape [seq_length...]\n    # x is context tokens to add to the kvcache.\n    # For prefill, x is a 2d tensor of shape [batch_size, prompt_length].\n    # For subsequent tokens, x is a 2d tensor of shape [batch_size, 1].\n    # kwargs is a list of dictionaries, each containing the keyword arguments for the sample function.\n    # If one dictionary is passed, it's repeated for each sample in the batch.\n\n    # In the future, we would like input_pos to be a 2d tensor of shape [batch_size, seq_length].\n    # That way, we can support prompts of different sizes.\n    # This means making the rope cache and kvcache forward() work with batches. Currently, they do not.\n    # This is relatively complicated, given the current implementation. It will require some rewriting.\n    # Relevant thread: https://discuss.pytorch.org/t/batched-index-select/9115\n    # We will also need the same with tensor.index_copy_(). These do not work for batches, and the replacement\n    # is somewhat nontrivial. Until then, we can only accept prompts that are all the same length.\n    # After this problem is resolved, there will be another problem. That being, continuous batched prefill.\n    # If you have any ideas on this, let me know. 
I don't think that padding input_pos is viable.\n\n    _kwargs = kwargs if isinstance(kwargs, list) else [kwargs] * x.size(0)\n\n    # Run the model on the batch.\n    logits_stack = model(x, input_pos)\n\n    # Unbind the logits stack into a list of logits.\n    logits_list = [logits_stack] if logits_stack.ndim == 1 else logits_stack.unbind(0)\n    logits_list = [l.unsqueeze(0) for l in logits_list]\n\n    # Return the next token for each sample in the batch.\n    return batched_sample(logits_list, kwargs=_kwargs)\n\n\n@torch.inference_mode()\ndef generate_fn(\n    model: GPT,\n    prompt: torch.Tensor,\n    max_returned_tokens: int,\n    *,\n    temperature: float = 1.0,\n    top_k: Optional[int] = None,\n    top_p: float = 1.0,\n    stop_tokens: Tuple[List[int], ...] = (),\n    include_prompt: bool,\n    include_eos: bool,\n) -> Iterator[torch.Tensor]:\n    \"\"\"\n    Generates tokens for a single prompt.\n\n    Args:\n        model: The model to use.\n        prompt: The tokenized prompt to generate from.\n        max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens.\n        temperature: The temp to pass to sample().\n        top_k: The top_k to pass to sample().\n        top_p: The top_p to pass to sample().\n        stop_tokens: A tuple of stop sequences. 
If any of the sequences are generated, the generation stops early before max_returned_tokens.\n        include_prompt: Whether to output the prompt tokens.\n        include_eos: Whether to output the stop tokens if generation stops early.\n    \"\"\"\n\n    prompt_size = prompt.size(0)\n    device = prompt.device\n\n    assert max_returned_tokens > prompt_size, (\n        f\"Not enough space for {prompt_size} prompt tokens in a context length of {max_returned_tokens}.\"\n    )\n    if model.max_seq_length < max_returned_tokens - 1:\n        raise NotImplementedError(f\"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}\")\n\n    # Yield the prompt if include_prompt is True\n    if include_prompt:\n        yield prompt\n\n    stop_progress = [0] * len(stop_tokens)\n    yielded_idx = 0\n\n    # Generate output tokens.\n    # The first token generated is the prefill token.\n    # The input_pos for this token is the width of the entire prompt.\n    # For subsequent iterations, it's the index in the context for the token that we're generating.\n    tokens = []\n    token = prompt\n    prefill_token = True\n    input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)\n    # input_pos_maxp1 introduces data-dependent shapes and control flow.\n    # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc.\n    input_pos_maxp1 = prompt_size if all(m.__class__.__name__ != \"ThunderModule\" for m in model.modules()) else None\n    for current_idx in range(max_returned_tokens - prompt_size):\n        # Generate the token\n        token = next_token(\n            model,\n            input_pos,\n            token.view(1, -1),\n            input_pos_maxp1=input_pos_maxp1,\n            temperature=temperature,\n            top_k=top_k,\n            top_p=top_p,\n        )\n        tokens.append(token)\n        int_token = token.item()\n\n        # Check for stop sequences\n        # For each 
stop sequence, we keep a running total of how many are matched in stop_progress.\n        # If the current token matches the next token in the stop sequence, we increment the\n        # running total and hold off on yielding the token.\n        for i, seq in enumerate(stop_tokens):\n            if int_token == seq[stop_progress[i]]:\n                stop_progress[i] += 1\n                if stop_progress[i] == len(seq):\n                    if include_eos:\n                        yield from tokens[yielded_idx:]\n                    return\n            else:\n                stop_progress[i] = 0\n\n        # Yield tokens that are not part of a stop sequence in progress.\n        # If there are no stop sequences, then that's all of them.\n        if stop_tokens:\n            safe_idx = len(tokens) - max(stop_progress)\n        else:\n            safe_idx = current_idx + 1  # include the token just generated\n\n        if yielded_idx < safe_idx:\n            y_tokens = tokens[yielded_idx:safe_idx]\n            yield from y_tokens\n            yielded_idx = safe_idx\n\n        # Update input_pos for the next iteration.\n        if prefill_token:\n            prefill_token = False\n            input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)\n        else:\n            input_pos.add_(1)\n        if input_pos_maxp1 is not None:\n            input_pos_maxp1 += 1\n\n    # Yield any remaining tokens\n    if yielded_idx < len(tokens):\n        yield from tokens[yielded_idx:]\n\n\n# TODO: Make include_eos work.\n# TODO: Rewrite unbatched generate_fn to use batched_generate_fn.\n@torch.inference_mode()\ndef batched_generate_fn(\n    model: GPT,\n    prompts: torch.Tensor,\n    max_returned_tokens: int,\n    *,\n    sample_args: Union[list[dict], dict],\n    stop_tokens: Tuple[List[int], ...] 
= (),\n    include_prompt: bool,\n    include_eos: bool,\n) -> Iterator[list[Union[torch.Tensor, None]]]:\n    \"\"\"\n    Generates tokens for a batch of prompts.\n\n    Args:\n        model: The model to use.\n        prompts: A 2D tensor of shape [batch_size, prompt_length].\n        max_returned_tokens: The maximum number of tokens to return, including the prompt tokens.\n        sample_args: The dictionary of kwargs to pass to sample() for each each token for each index in the batch.\n        stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens.\n        include_prompt: Whether to output the prompt tokens.\n        include_eos: Whether to output the stop tokens if generation stops early.\n\n    Yields:\n        A list of tokens for each prompt in the batch, or None if a stop sequence has already been encountered for that index in the batch.\n    \"\"\"\n\n    if prompts.ndim == 1:\n        prompts = prompts.unsqueeze(0)\n    assert prompts.ndim == 2, \"Prompts must be a 2D tensor.\"\n\n    batch_size = prompts.size(0)\n    max_prompt_size = prompts.size(1)\n    device = prompts.device\n\n    if isinstance(sample_args, dict):\n        sample_args = [sample_args] * len(prompts)\n    else:\n        assert len(sample_args) == batch_size, \"sample_args must have the length as the batch size.\"\n\n    # TODO: This check (and the one in generate_fn) is not sufficient. 
We do the proper checks in LLM.generate().\n    assert max_returned_tokens > max_prompt_size, (\n        f\"Not enough space for {max_prompt_size} prompt tokens in a context length of {max_returned_tokens}.\"\n    )\n    if model.max_seq_length < max_returned_tokens - 1:\n        raise NotImplementedError(f\"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}\")\n\n    # Yield the prompts if include_prompt is True\n    if include_prompt:\n        # TODO: Prompt length is padded, but they shouldn't all be the same length.\n        for i in range(max_prompt_size):\n            yield [prompt[i].view(-1) for prompt in prompts]\n\n    stop_progresses = [[0] * len(stop_tokens) for _ in range(batch_size)]  # [batch_size, ~len(stop_tokens)]\n    stop_idxes = [-1] * batch_size\n    yielded_idx = 0\n\n    # Generate output tokens.\n    # The first token generated is the prefill token.\n    # The input_pos for this token is the width of the entire prompt.\n    # For subsequent iterations, it's the index in the context for the token that we're generating.\n    token_lists = [[] for _ in range(batch_size)]\n    tokens: torch.Tensor = prompts\n    prefill_token = True\n    input_pos = torch.arange(0, max_prompt_size, device=device, dtype=torch.int64)\n    for current_idx in range(max_returned_tokens - max_prompt_size):\n        # Generate the next token for each prompt in the batch.\n        # This is of shape [batch_size, 1].\n        tokens = batched_next_token(model, input_pos, tokens, sample_args)\n        for i in range(batch_size):\n            token_lists[i].append(tokens[i])\n        int_tokens = [token.item() for token in tokens]\n\n        # Check for stop sequences\n        # For each stop sequence, we keep a running total of how many are matched in stop_progress.\n        # If the current token matches the next token in the stop sequence, we increment the\n        # running total and hold off on yielding the token.\n        for batch_idx, 
int_token in enumerate(int_tokens):\n            if stop_idxes[batch_idx] != -1:\n                continue\n            for seq_idx, seq in enumerate(stop_tokens):\n                seq_pos = stop_progresses[batch_idx][seq_idx]\n                if seq_pos >= len(seq):\n                    continue\n                if int_token == seq[seq_pos]:\n                    stop_progresses[batch_idx][seq_idx] += 1\n                    if stop_progresses[batch_idx][seq_idx] == len(seq):\n                        stop_idxes[batch_idx] = current_idx\n                else:\n                    stop_progresses[batch_idx][seq_idx] = 0\n\n        # Yield tokens that are not part of a stop sequence in progress.\n        # If there are no stop sequences, then that's all of them.\n        if len(stop_tokens) != 0:\n            safe_idxes = [len(token_lists[i]) - max(stop_progresses[i]) for i in range(batch_size)]\n        else:\n            safe_idxes = [current_idx + 1]  # include the token just generated\n        safe_idx = min(safe_idxes)\n\n        if yielded_idx < safe_idx:\n            for idx in range(yielded_idx, safe_idx):\n                y_tokens = [\n                    token_lists[i][idx] if (stop_idxes[i] == -1 or idx < stop_idxes[i]) else None\n                    for i in range(batch_size)\n                ]\n                if all(y is None for y in y_tokens):\n                    return\n                yield y_tokens\n            yielded_idx = safe_idx\n\n        # Update input_pos for the next iteration.\n        if prefill_token:\n            prefill_token = False\n\n            # TODO: Make the model support a batched input_pos of shape [batch_size, 1].\n            # The kvcache has been fixed, but the rope cache is still broken.\n            input_pos = torch.tensor([max_prompt_size], device=device, dtype=torch.int64)\n        else:\n            input_pos.add_(1)\n\n    # Yield any remaining tokens\n    max_token_lists = max(len(l) for l in token_lists)\n    if 
yielded_idx < max_token_lists:\n        for idx in range(yielded_idx, max_token_lists):\n            y_tokens = [\n                token_lists[i][idx] if (stop_idxes[i] == -1 or idx < stop_idxes[i]) else None for i in range(batch_size)\n            ]\n            if all(y is None for y in y_tokens):\n                return\n            yield y_tokens\n    return\n\n\n@torch.inference_mode()\ndef generate(\n    model: GPT,\n    prompt: torch.Tensor,\n    max_returned_tokens: int,\n    *,\n    temperature: float = 1.0,\n    top_k: Optional[int] = None,\n    top_p: float = 1.0,\n    eos_id: Optional[int] = None,\n    include_prompt: bool = True,\n) -> torch.Tensor:\n    \"\"\"\n    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.\n    The implementation of this function is modified from A. Karpathy's nanoGPT.\n\n    Args:\n        model: The model to use.\n        prompt: Tensor of shape (T) with indices of the prompt sequence.\n        max_returned_tokens: The maximum number of tokens to return (given plus generated).\n        temperature: Scales the predicted logits by 1 / temperature.\n        top_k: If specified, only sample among the tokens with the k highest probabilities.\n        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n            In top-p sampling, the next token is sampled from the highest probability tokens\n            whose cumulative probability exceeds the threshold `top_p`. When specified,\n            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n            to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n            It can be used in conjunction with `top_k` and `temperature` with the following order\n            of application:\n\n            1. `top_k` sampling\n            2. `temperature` scaling\n            3. 
`top_p` sampling\n\n            For more details, see https://arxiv.org/abs/1904.09751\n            or https://huyenchip.com/2024/01/16/sampling.html#top_p\n        eos_id: If specified, stop generating any more token once the <eos> token is triggered.\n        include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output.\n    \"\"\"\n\n    token_list = list(\n        generate_fn(\n            include_prompt=include_prompt,\n            include_eos=True,\n            model=model,\n            prompt=prompt,\n            max_returned_tokens=max_returned_tokens,\n            temperature=temperature,\n            top_k=top_k,\n            top_p=top_p,\n            stop_tokens=(([eos_id],) if eos_id is not None else ()),\n        )\n    )\n\n    return torch.cat(token_list) if not len(token_list) == 0 else torch.Tensor()\n\n\n@torch.inference_mode()\ndef main(\n    checkpoint_dir: Path,\n    prompt: str = \"What food do llamas eat?\",\n    *,\n    sys_prompt: Optional[str] = None,\n    num_samples: int = 1,\n    max_new_tokens: int = 50,\n    top_k: Optional[int] = 50,\n    top_p: float = 1.0,\n    temperature: float = 0.8,\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\"]] = None,\n    precision: Optional[str] = None,\n    compile: bool = False,\n) -> None:\n    \"\"\"Default generation option.\n\n    Generates text samples based on a pre-trained model and tokenizer.\n\n    Args:\n        checkpoint_dir: The checkpoint directory to load.\n        prompt: The prompt string to use for generating the samples.\n        sys_prompt: The system prompt to use for generating the samples.\n        num_samples: The number of text samples to generate.\n        max_new_tokens: The number of generation steps to take.\n        top_k: The number of top most probable tokens to consider in the sampling process.\n        top_p: If specified, it represents the cumulative probability threshold 
to consider in the sampling process.\n            In top-p sampling, the next token is sampled from the highest probability tokens\n            whose cumulative probability exceeds the threshold `top_p`. When specified,\n            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n            to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n            It can be used in conjunction with `top_k` and `temperature` with the following order\n            of application:\n\n            1. `top_k` sampling\n            2. `temperature` scaling\n            3. `top_p` sampling\n\n            For more details, see https://arxiv.org/abs/1904.09751\n            or https://huyenchip.com/2024/01/16/sampling.html#top_p\n        temperature: A value controlling the randomness of the sampling process. Higher values result in more random\n            samples.\n        quantize: Whether to quantize the model and using which method:\n            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes\n            - bnb.int8: 8-bit quantization from bitsandbytes\n            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md\n        precision: Indicates the Fabric precision setting to use.\n        compile: Whether to compile the model.\n    \"\"\"\n    checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)\n    pprint(locals())\n\n    precision = precision or get_default_supported_precision(training=False)\n\n    plugins = None\n    if quantize is not None and quantize.startswith(\"bnb.\"):\n        if \"mixed\" in precision:\n            raise ValueError(\"Quantization and mixed precision is not supported.\")\n        if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:\n            warnings.warn(\n                \"LitGPT only supports bitsandbytes v0.42.0. 
This may result in errors when using quantization.\"\n            )\n        dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n        plugins = BitsandbytesPrecision(quantize[4:], dtype)\n        precision = None\n\n    fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)\n\n    check_valid_checkpoint_dir(checkpoint_dir)\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n    check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    prompt_style = (\n        load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)\n    )\n\n    prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    prompt_length = encoded.size(0)\n    max_returned_tokens = prompt_length + max_new_tokens\n\n    fabric.print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n    t0 = time.perf_counter()\n    with fabric.init_module(empty_init=True):\n        model = GPT(config)\n    fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n    with fabric.init_tensor():\n        # set the max_seq_length to limit the memory usage to what we need\n        model.max_seq_length = max_returned_tokens\n        # enable the kv cache\n        model.set_kv_cache(batch_size=1)\n    model.eval()\n\n    if compile:\n        torch._dynamo.config.automatic_dynamic_shapes = True\n        torch._inductor.config.triton.unique_kernel_names = True\n        torch._inductor.config.coordinate_descent_tuning = True\n        global next_token\n        next_token = torch.compile(next_token, mode=\"reduce-overhead\")\n\n    model = fabric.setup_module(model)\n\n    t0 = time.perf_counter()\n    load_checkpoint(fabric, 
model, checkpoint_path)\n    fabric.print(f\"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    L.seed_everything(1234)\n    for i in range(num_samples):\n        t0 = time.perf_counter()\n        y = generate(\n            model,\n            encoded,\n            max_returned_tokens,\n            temperature=temperature,\n            top_k=top_k,\n            top_p=top_p,\n            eos_id=tokenizer.eos_id,\n        )\n        t = time.perf_counter() - t0\n        for block in model.transformer.h:\n            block.attn.kv_cache.reset_parameters()\n        fabric.print(tokenizer.decode(y))\n        tokens_generated = y.size(0) - prompt_length\n        fabric.print(\n            f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n        )\n    if fabric.device.type == \"cuda\":\n        fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)\n"
  },
  {
    "path": "litgpt/generate/full.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport sys\nimport time\nimport warnings\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Literal, Optional\n\nimport lightning as L\nimport torch\nfrom lightning.fabric.plugins import BitsandbytesPrecision\n\nfrom litgpt import GPT, Config, PromptStyle, Tokenizer\nfrom litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0\nfrom litgpt.generate.base import generate\nfrom litgpt.prompts import has_prompt_style, load_prompt_style\nfrom litgpt.utils import (\n    check_file_size_on_cpu_and_warn,\n    check_valid_checkpoint_dir,\n    extend_checkpoint_dir,\n    get_default_supported_precision,\n    load_checkpoint,\n)\n\n\ndef main(\n    checkpoint_dir: Path,\n    prompt: str = \"What food do llamas eat?\",\n    input: str = \"\",\n    sys_prompt: Optional[str] = None,\n    finetuned_path: Path = Path(\"out/full/alpaca/lit_model_finetuned.pth\"),\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\"]] = None,\n    max_new_tokens: int = 100,\n    top_k: Optional[int] = 50,\n    top_p: float = 1.0,\n    temperature: float = 0.8,\n    precision: Optional[str] = None,\n) -> None:\n    \"\"\"For models finetuned with `litgpt finetune_full`.\n\n    Generates a response based on a given instruction and an optional input. This script will only work with\n    checkpoints from the instruction-tuned model. 
See ``litgpt.finetune.full``.\n\n    Args:\n        checkpoint_dir: The path to the checkpoint folder with pretrained model weights.\n        prompt: The prompt/instruction (Alpaca style).\n        input: Optional input (Alpaca style).\n        sys_prompt: Optional system prompt.\n        finetuned_path: Path to the checkpoint with trained weights, which are the output of\n            ``litgpt.finetune.full``.\n        quantize: Whether to quantize the model and using which method:\n            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes\n            - bnb.int8: 8-bit quantization from bitsandbytes\n            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md\n        max_new_tokens: The number of generation steps to take.\n        top_k: The number of top most probable tokens to consider in the sampling process.\n        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n            In top-p sampling, the next token is sampled from the highest probability tokens\n            whose cumulative probability exceeds the threshold `top_p`. When specified,\n            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n            to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n            It can be used in conjunction with `top_k` and `temperature` with the following order\n            of application:\n\n            1. `top_k` sampling\n            2. `temperature` scaling\n            3. `top_p` sampling\n\n            For more details, see https://arxiv.org/abs/1904.09751\n            or https://huyenchip.com/2024/01/16/sampling.html#top_p\n        temperature: A value controlling the randomness of the sampling process. 
Higher values result in more random\n            samples.\n        precision: Indicates the Fabric precision setting to use.\n    \"\"\"\n    checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)\n    pprint(locals())\n\n    precision = precision or get_default_supported_precision(training=False)\n\n    plugins = None\n    if quantize is not None and quantize.startswith(\"bnb.\"):\n        if \"mixed\" in precision:\n            raise ValueError(\"Quantization and mixed precision is not supported.\")\n        if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:\n            warnings.warn(\n                \"LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization.\"\n            )\n        dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n        plugins = BitsandbytesPrecision(quantize[4:], dtype)\n        precision = None\n\n    fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)\n    fabric.launch()\n\n    check_valid_checkpoint_dir(checkpoint_dir)\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n    checkpoint_path = finetuned_path\n    check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)\n    tokenizer = Tokenizer(checkpoint_dir)\n    prompt_style = (\n        load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)\n    )\n\n    prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt, input=input)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    prompt_length = encoded.size(0)\n    max_returned_tokens = prompt_length + max_new_tokens\n\n    fabric.print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n    t0 = time.perf_counter()\n    with fabric.init_module(empty_init=True):\n        model = GPT(config)\n    fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n    
with fabric.init_tensor():\n        # set the max_seq_length to limit the memory usage to what we need\n        model.max_seq_length = max_returned_tokens\n        # enable the kv cache\n        model.set_kv_cache(batch_size=1)\n    model.eval()\n\n    model = fabric.setup(model)\n\n    t0 = time.perf_counter()\n    load_checkpoint(fabric, model, checkpoint_path)\n    fabric.print(f\"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    L.seed_everything(1234)\n    t0 = time.perf_counter()\n    y = generate(\n        model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id\n    )\n    t = time.perf_counter() - t0\n\n    output = tokenizer.decode(y)\n    output = output.split(\"### Response:\")[1].strip()\n    fabric.print(output)\n\n    tokens_generated = y.size(0) - prompt_length\n    fabric.print(f\"\\n\\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr)\n    if fabric.device.type == \"cuda\":\n        fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)\n"
  },
  {
    "path": "litgpt/generate/sequentially.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport itertools\nimport logging\nimport re\nimport sys\nimport time\nimport warnings\nfrom collections import OrderedDict\nfrom functools import partial\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import List, Literal, Optional, Type\n\nimport lightning as L\nimport torch\nfrom lightning.fabric.accelerators import CUDAAccelerator\nfrom lightning.fabric.plugins import BitsandbytesPrecision\nfrom lightning.fabric.utilities.init import _materialize_meta_tensors\nfrom tqdm import tqdm\n\nimport litgpt.generate.base as generate_base\nfrom litgpt.config import Config\nfrom litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0\nfrom litgpt.model import GPT, Block, build_mask_cache\nfrom litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.utils import (\n    check_valid_checkpoint_dir,\n    extend_checkpoint_dir,\n    get_default_supported_precision,\n)\n\n\n@torch.inference_mode()\ndef sequential(model: GPT, root: torch.device, max_seq_length: int, devices: int):\n    if model.config.n_layer < devices:\n        raise ValueError(\n            f\"The number of layers in the model must be larger than the number of devices, but got\"\n            f\" n_layer={model.config.n_layer} and devices={devices}.\"\n        )\n\n    # Dictates where each block should be instantiated\n    mapping = layer_to_device(\n        model,\n        chunk_on=Block,\n        chunk_sizes=chunk_sizes(model.config.n_layer, devices),\n    )\n    num_layers_per_device = {i: sum(1 for v in mapping.values() if v == i) for i in range(devices)}\n\n    # materialize each block on the appropriate device\n    with tqdm(total=len(mapping), desc=\"Moving submodules\") as pbar:\n        for path, target_index in mapping.items():\n            submodule = model.get_submodule(path)\n            target_device = 
torch.device(root.type, target_index)\n\n            pbar.set_description(f\"Moving {path!r} to {target_device}\")\n            pbar.update(1)\n\n            # submodules loaded by the checkpoint will be on CPU (if no quantization). move them\n            replace_device(submodule, replace=torch.device(\"cpu\"), by=target_device)\n            # in case the checkpoint was partial, materialize leftover metas\n            _materialize_meta_tensors(submodule, target_device)\n            # and build the kv cache\n            submodule.attn.kv_cache = submodule.attn.build_kv_cache(\n                1, max_seq_length, model.rope_cache_length(), target_device\n            )\n    # rebuild odd ends\n    with root:\n        model.max_seq_length = max_seq_length\n        # the rope cache which is on meta device\n        model.cos, model.sin = model.rope_cache()\n        # the mask cache which cannot be created with `set_kv_cache` because that will set it for all layers\n        model.mask_cache = build_mask_cache(max_seq_length)\n    # and everything that is not a block in the root\n    _materialize_meta_tensors(model, root)\n    replace_device(model, replace=torch.device(\"cpu\"), by=root)\n\n    if devices > 1:\n        # install hooks to move layer inputs/output between devices\n        for layer_num, (path, target_index) in enumerate(mapping.items()):\n            submodule = model.get_submodule(path)\n            if layer_num >= num_layers_per_device[target_index]:\n                # we need to move the block input on the boundaries between devices\n                # and also on every non-root device because the RoPE and mask cache is shared\n                # TODO: the second case could be optimized and then we would only need this hook for\n                # `layer_num in [layers_per_rank * i - 1 for i in range(1, devices + 1)]`\n                target_device = torch.device(root.type, target_index)\n                
submodule.register_forward_pre_hook(partial(move_block_input, target_device))\n            if layer_num == model.config.n_layer - 1:\n                submodule.register_forward_hook(partial(move_block_output, root))\n\n    return model\n\n\ndef chunk_sizes(num_units: int, devices: int) -> List[int]:\n    cs = num_units // devices\n    k = devices * (cs + 1) - num_units\n    return [cs] * k + [cs + 1] * (devices - k)\n\n\ndef layer_to_device(\n    module: torch.nn.Module,\n    chunk_on: Type[torch.nn.Module],\n    chunk_sizes: List[int],\n) -> \"OrderedDict[str, int]\":\n    \"\"\"Create a mapping from layer (block) to device.\"\"\"\n    # this assumes that the definition order is the same as the execution order\n    hits = [name for name, submodule in module.named_modules() if isinstance(submodule, chunk_on)]\n    if sum(chunk_sizes) != len(hits):\n        raise ValueError(f\"Found {len(hits)} for chunk_on={chunk_on}, not covered by chunk_sizes={chunk_sizes}\")\n    _devices = [[d] * cs for d, cs in enumerate(chunk_sizes)]\n    devices = [d for lst in _devices for d in lst]\n    return OrderedDict(zip(hits, devices))\n\n\ndef move_block_input(device: torch.device, module: torch.nn.Module, ins):\n    \"\"\"``forward_pre_hook`` to move a Block's input before forward.\"\"\"\n    # during inference, none of the inputs are None: x, cos, sin, mask, input_pos\n    return tuple(t.to(device) if torch.is_tensor(t) else t for t in ins)\n\n\ndef move_block_output(device: torch.device, module: torch.nn.Module, ins, outs) -> torch.Tensor:\n    \"\"\"``forward_hook`` to move a Block's output after forward.\"\"\"\n    return outs.to(device)\n\n\ndef replace_device(module: torch.nn.Module, replace: torch.device, by: torch.device) -> torch.nn.Module:\n    for name, submodule in module.named_modules():\n        tensors = dict(\n            itertools.chain(submodule.named_parameters(recurse=False), submodule.named_buffers(recurse=False))\n        )\n        if not tensors:\n           
 continue\n        devices = {t.device for t in tensors.values()}\n        if len(devices) != 1:\n            # since this is using `submodule.to`, different devices in the same submodule is a problem\n            path_to_device = {f\"{name}.{p}\": t.device for p, t in tensors.items()}\n            raise ValueError(f\"Found multiple devices: {path_to_device}\")\n        if devices.pop() == replace:\n            submodule.to(by)\n    return module\n\n\n@torch.inference_mode()\ndef main(\n    checkpoint_dir: Path,\n    prompt: str = \"What food do llamas eat?\",\n    *,\n    sys_prompt: Optional[str] = None,\n    num_samples: int = 1,\n    max_new_tokens: int = 50,\n    top_k: Optional[int] = 50,\n    top_p: float = 1.0,\n    temperature: float = 0.8,\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\"]] = None,\n    precision: Optional[str] = None,\n    compile: bool = False,\n) -> None:\n    \"\"\"Generation script that partitions layers across devices to be run sequentially.\n\n    Generates text samples based on a pre-trained model and tokenizer.\n\n    Args:\n        checkpoint_dir: The checkpoint directory to load.\n        prompt: The prompt string to use for generating the samples.\n        sys_prompt: The system prompt to use for generating the samples.\n        num_samples: The number of text samples to generate.\n        max_new_tokens: The number of generation steps to take.\n        top_k: The number of top most probable tokens to consider in the sampling process.\n        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n            In top-p sampling, the next token is sampled from the highest probability tokens\n            whose cumulative probability exceeds the threshold `top_p`. When specified,\n            it must be `0 <= top_p <= 1`. 
Here, `top_p=0` is equivalent\n            to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n            It can be used in conjunction with `top_k` and `temperature` with the following order\n            of application:\n\n            1. `top_k` sampling\n            2. `temperature` scaling\n            3. `top_p` sampling\n\n            For more details, see https://arxiv.org/abs/1904.09751\n            or https://huyenchip.com/2024/01/16/sampling.html#top_p\n        temperature: A value controlling the randomness of the sampling process. Higher values result in more random\n            samples.\n        quantize: Whether to quantize the model and using which method:\n            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes\n            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md\n        precision: Indicates the Fabric precision setting to use.\n        compile: Whether to compile the model.\n    \"\"\"\n    checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)\n    pprint(locals())\n\n    precision = precision or get_default_supported_precision(training=False)\n\n    plugins = None\n    if quantize is not None:\n        if compile:\n            raise NotImplementedError  # untested\n        if \"mixed\" in precision:\n            raise ValueError(\"Quantization and mixed precision is not supported.\")\n        if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:\n            warnings.warn(\n                \"LitGPT only supports bitsandbytes v0.42.0. 
This may result in errors when using quantization.\"\n            )\n        dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n        logging.getLogger(\"lightning.fabric.plugins.precision.bitsandbytes\").setLevel(logging.DEBUG)\n        plugins = BitsandbytesPrecision(quantize[4:], dtype)\n        precision = None\n\n    fabric = L.Fabric(devices=1, precision=precision, accelerator=\"cuda\", plugins=plugins)\n\n    total_devices = CUDAAccelerator.auto_device_count()\n    print(f\"Using {total_devices} devices\", file=sys.stderr)\n\n    check_valid_checkpoint_dir(checkpoint_dir)\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    prompt_style = (\n        load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)\n    )\n    prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    prompt_length = encoded.size(0)\n    max_returned_tokens = prompt_length + max_new_tokens\n\n    print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n    t0 = time.perf_counter()\n    # cannot use `init_module` because if bitsandbytes is used, the Linear layers will be replaced\n    # which means that the weights will get quantized on cuda:0 on checkpoint load. we need to load and then convert\n    # still, use init_tensor for the precision\n    with fabric.init_tensor(), torch.device(\"meta\"):\n        model = GPT(config)\n    print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    t0 = time.perf_counter()\n    state_dict = torch.load(str(checkpoint_path), mmap=True, map_location=\"cpu\")\n    # TODO: this assumes that the model fits on CPU. 
Use lazy_load and make the materialization checkpoint aware\n    model.load_state_dict(state_dict, assign=True)\n    print(f\"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    model = fabric.setup_module(model, move_to_device=False)\n\n    t0 = time.perf_counter()\n    model = sequential(model, fabric.device, max_returned_tokens, total_devices)\n    print(f\"Time to sequential-ize the model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    if compile:\n        # TODO: raises an internal compile AssertionError caused by fabric.strategy.precision.forward_context\n        raise NotImplementedError\n        # silence developer warning on nightly builds\n        # https://github.com/pytorch/pytorch/blob/v2.2.0-rc5/torch/_inductor/ir.py#L4166\n        pattern = re.compile(\".*DeviceCopy in input program.*\")\n        logging.getLogger(\"torch._inductor.utils\").addFilter(lambda record: not pattern.search(record.getMessage()))\n        torch._dynamo.config.automatic_dynamic_shapes = True\n        torch._inductor.config.triton.unique_kernel_names = True\n        torch._inductor.config.coordinate_descent_tuning = True\n        # cannot use cudagraphs because it doesn't support multiple device indices\n        # https://github.com/pytorch/pytorch/blob/v2.2.0-rc5/torch/_inductor/compile_fx.py#L371-L375\n        generate_base.next_token = torch.compile(generate_base.next_token)\n\n    L.seed_everything(1234)\n    for i in range(num_samples):\n        t0 = time.perf_counter()\n        y = generate_base.generate(\n            model=model,\n            prompt=encoded,\n            max_returned_tokens=max_returned_tokens,\n            temperature=temperature,\n            top_k=top_k,\n            top_p=top_p,\n            eos_id=tokenizer.eos_id,\n        )\n        t = time.perf_counter() - t0\n        for block in model.transformer.h:\n            block.attn.kv_cache.reset_parameters()\n        
print(tokenizer.decode(y))\n        tokens_generated = y.size(0) - prompt_length\n        print(\n            f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n        )\n    print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)\n"
  },
  {
    "path": "litgpt/generate/speculative_decoding.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport sys\nimport time\nimport warnings\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Any, Dict, Iterator, List, Literal, Optional, Tuple\n\nimport lightning as L\nimport torch\nimport torch._dynamo.config\nimport torch._inductor.config\nimport torch.nn.functional as F\nfrom lightning.fabric.plugins import BitsandbytesPrecision\n\nfrom litgpt.config import Config\nfrom litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0\nfrom litgpt.generate.base import multinomial_num_samples_1, next_token, sample_top_p\nfrom litgpt.model import GPT\nfrom litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.utils import (\n    check_file_size_on_cpu_and_warn,\n    check_valid_checkpoint_dir,\n    extend_checkpoint_dir,\n    get_default_supported_precision,\n    load_checkpoint,\n)\n\n\ndef sample(\n    logits: torch.Tensor,\n    temperature: float = 1.0,\n    top_k: Optional[int] = None,\n    top_p: float = 1.0,\n    apply_softmax: bool = True,\n) -> torch.Tensor:\n    if top_p < 0.0 or top_p > 1.0:\n        raise ValueError(f\"top_p must be in [0, 1], got {top_p}\")\n    logits = logits[0, -1]\n    # optionally crop the logits to only the top k options\n    if top_k is not None:\n        v, i = torch.topk(logits, min(top_k, logits.size(-1)))\n        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions\n        fill_value = float(\"-inf\") if apply_softmax else float(0)\n        logits = torch.full_like(logits, fill_value).scatter_(-1, i, v)\n    # optionally scale the logits and sample from a probability distribution\n    if temperature > 0.0 or top_p > 0.0:\n        if temperature > 0.0:\n            logits = logits / temperature\n        # optionally crop the logits to smallest set of logits with a cumulative probability above top_p\n       
 if top_p < 1.0:\n            logits = sample_top_p(logits, top_p)\n        probs = F.softmax(logits, dim=-1) if apply_softmax else logits\n        return multinomial_num_samples_1(probs), probs\n    return torch.argmax(logits, dim=-1, keepdim=True), F.softmax(logits, dim=-1)\n\n\ndef speculative_decoding(\n    draft_model: GPT,\n    target_model: GPT,\n    token: torch.Tensor,\n    input_pos: torch.Tensor,\n    input_pos_maxp1: int,\n    speculative_k: int,\n    **sample_kwargs: Dict[str, Any],\n) -> torch.Tensor:\n    \"\"\"Performs speculative decoding using a draft and a target model.\n\n    This implements the speculative decoding algorithm from \"Fast Inference from Transformers via Speculative Decoding\"\n    (https://arxiv.org/pdf/2211.17192).\n\n    The core idea is to:\n    1. Use a faster draft model to predict multiple tokens ahead\n    2. Verify those predictions with the slower but more accurate target model\n    3. Accept tokens where the target model agrees with high probability\n    4. 
Reject and resample tokens where there is a disagreement\n\n    This allows leveraging a smaller/faster model to speed up generation while maintaining\n    the quality of the larger target model.\n\n    Args:\n        draft_model: Smaller/faster model used for initial token predictions\n        target_model: Larger/slower model used for verification\n        token: Current input token tensor of shape [1]\n        input_pos: Position index of the token tensor for KV-cache\n        input_pos_maxp1: Maximum position + 1 for managing KV-cache buffer\n        speculative_k: Number of tokens to speculatively generate at once\n        sample_kwargs: Additional sampling parameters (temperature, top_k, top_p)\n\n    Returns:\n        torch.Tensor: Generated tokens that were either accepted from draft model\n                      or resampled from target model\n    \"\"\"\n\n    if speculative_k < 1:\n        raise ValueError(f\"speculative_k must be >= 1, got {speculative_k}\")\n\n    # Step 1: Generate candidate tokens using draft model\n    # The draft model autoregressively generates k tokens, keeping track of probabilities\n    draft_input_pos = input_pos.clone()\n    draft_input_pos_maxp1 = input_pos_maxp1\n    draft_tokens, draft_probs = [], []\n    draft_token = token\n    for idx in range(speculative_k):\n        logits = draft_model(\n            idx=draft_token.unsqueeze(0), input_pos=draft_input_pos, input_pos_maxp1=draft_input_pos_maxp1\n        )\n        draft_token, draft_prob = sample(logits, **sample_kwargs)\n        draft_input_pos.add_(1)\n        draft_input_pos_maxp1 += 1\n        draft_tokens.append(draft_token)\n        draft_probs.append(draft_prob)\n    draft_tokens = torch.cat(draft_tokens)\n\n    # Step 2: Get target model predictions for comparison\n    # Feed both original token and draft tokens to get target probabilities\n    candidate_tokens = torch.cat((token, draft_tokens))\n    candidate_input_pos = input_pos + torch.arange(0, 
speculative_k + 1, device=input_pos.device)\n    candidate_input_pos_maxp1 = input_pos_maxp1 + speculative_k\n    target_logits = target_model(\n        idx=candidate_tokens.unsqueeze(0), input_pos=candidate_input_pos, input_pos_maxp1=candidate_input_pos_maxp1\n    )\n\n    # Step 3: Convert target logits to probabilities using same sampling params\n    target_probs = []\n    for target_logit in target_logits.split(1, dim=1):\n        _, target_prob = sample(target_logit, **sample_kwargs)\n        target_probs.append(target_prob)\n\n    # Step 4: Accept/reject draft tokens based on probability comparison\n    # Using rejection sampling: keep token if target_prob >= draft_prob.\n    # Otherwise reject with probability 1 - target_prob / draft_prob.\n    # If rejected, sample from an adjusted distribution: norm(max(0, target_prob_distribution - draft_prob_distribution)) instead.\n    accepted_tokens = []\n    for idx in range(len(draft_tokens)):\n        draft_token = draft_tokens[idx].unsqueeze(0)\n        draft_prob = draft_probs[idx][draft_token]\n        target_prob = target_probs[idx][draft_token]\n\n        # Accept the draft token if the target model is \"confident\" in it\n        if target_prob >= draft_prob:\n            accepted_tokens.append(draft_token)\n            continue\n\n        # If not accepted, probabilistically reject it\n        discard_prob = 1 - target_prob / draft_prob\n        should_discard_token = torch.rand(1, device=discard_prob.device) <= discard_prob\n\n        if not should_discard_token:\n            accepted_tokens.append(draft_token)\n            continue\n\n        # On rejection: sample new token from adjusted distribution\n        # p'(x) = normalize(max(0, p_target(x) - p_draft(x)))\n        adjusted_distribution = target_probs[idx] - draft_probs[idx]\n        adjusted_distribution = torch.clamp(adjusted_distribution, 0.0)\n        adjusted_distribution = adjusted_distribution / adjusted_distribution.sum()\n        new_token, 
_ = sample(adjusted_distribution[None, None, ...], apply_softmax=False, **sample_kwargs)\n        return torch.cat((*accepted_tokens, new_token))\n\n    # If all draft tokens were accepted:\n    # 1. Update draft model's key-value cache\n    # 2. Sample one more token from target model\n    draft_model(idx=draft_token.unsqueeze(0), input_pos=draft_input_pos, input_pos_maxp1=draft_input_pos_maxp1)\n    new_token, _ = sample(target_logits, **sample_kwargs)\n    return torch.cat((*accepted_tokens, new_token))\n\n\n@torch.inference_mode()\ndef generate(\n    draft_model: GPT,\n    target_model: GPT,\n    prompt: torch.Tensor,\n    max_returned_tokens: int,\n    *,\n    temperature: float = 1.0,\n    top_k: Optional[int] = None,\n    top_p: float = 1.0,\n    stop_tokens: Tuple[List[int], ...] = (),\n    include_prompt: bool = True,\n    speculative_k: int,\n) -> Iterator[torch.Tensor]:\n    \"\"\"Generates tokens using speculative decoding with a draft and a target model.\n\n    This function implements token generation using speculative decoding, where a faster draft model\n    makes initial token predictions that are verified by a slower but more accurate target model.\n\n    Args:\n        draft_model: Smaller/faster model used for initial token predictions\n        target_model: Larger/more accurate model used to verify draft predictions\n        prompt: Input tensor of token ids to generate from, shape [sequence_length]\n        max_returned_tokens: Maximum total tokens (prompt + generated) to return\n        temperature: Sampling temperature (higher = more random, lower = more deterministic)\n        top_k: If set, only sample from the top k most likely next tokens\n        top_p: If <1.0, only sample from tokens whose cumulative probability exceeds top_p\n        stop_tokens: List of token sequences that will stop generation if produced\n        include_prompt: Whether to include prompt tokens in the returned sequence\n        speculative_k: Number of tokens to 
speculatively generate at each step\n\n    Returns:\n        - tokens: Tensor of generated token ids\n        - acceptance_rate: Ratio of accepted draft model predictions\n\n    This implements an optimized decoding process:\n    1. Both models process the initial prompt\n    2. Draft model speculatively generates k tokens ahead\n    3. Target model verifies the draft predictions\n    4. Accepted tokens are kept, rejected ones trigger resampling\n    5. Process repeats until max tokens or stop sequence reached\n    \"\"\"\n\n    prompt_size = prompt.size(0)\n    device = prompt.device\n\n    assert max_returned_tokens > prompt_size, (\n        f\"Not enough space for {prompt_size} prompt tokens in a context length of {max_returned_tokens}.\"\n    )\n    if draft_model.max_seq_length < max_returned_tokens - 1:\n        raise NotImplementedError(\n            f\"max_seq_length {draft_model.max_seq_length} needs to be >= {max_returned_tokens - 1}\"\n        )\n    if target_model.max_seq_length < max_returned_tokens - 1:\n        raise NotImplementedError(\n            f\"max_seq_length {target_model.max_seq_length} needs to be >= {max_returned_tokens - 1}\"\n        )\n\n    # Step 1: Prefill draft and target models with the prompt.\n    input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)\n    # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc.\n    input_pos_maxp1 = (\n        prompt_size if all(m.__class__.__name__ != \"ThunderModule\" for m in target_model.modules()) else None\n    )\n    next_token(\n        draft_model,\n        input_pos,\n        prompt.view(1, -1),\n        input_pos_maxp1=input_pos_maxp1,\n        temperature=temperature,\n        top_k=top_k,\n        top_p=top_p,\n    )\n    token = next_token(\n        target_model,\n        input_pos,\n        prompt.view(1, -1),\n        input_pos_maxp1=input_pos_maxp1,\n        temperature=temperature,\n        top_k=top_k,\n       
 top_p=top_p,\n    )\n    # Update position trackers after prompt\n    input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)\n    input_pos_maxp1 += 1\n\n    # Step 2: Main generation loop.\n    tokens = []\n    total_generated, total_accepted = 0, 0  # Track acceptance statistics\n    while input_pos < max_returned_tokens - 1:\n        # Calculate speculative tokens to generate\n        _speculative_k = min(speculative_k, (max_returned_tokens - input_pos - 1).item())\n\n        # Get new tokens via speculative decoding\n        new_tokens = speculative_decoding(\n            draft_model=draft_model,\n            target_model=target_model,\n            token=token,\n            input_pos=input_pos,\n            input_pos_maxp1=input_pos_maxp1,\n            speculative_k=_speculative_k,\n            temperature=temperature,\n            top_k=top_k,\n            top_p=top_p,\n        )\n\n        # Update statistics\n        accepted_tokens_len = len(new_tokens)\n        total_generated += _speculative_k\n        total_accepted += accepted_tokens_len - 1  # accepted +1 sampled from a target model\n\n        # Process tokens and check for stop condition\n        should_break = False\n        for new_token in new_tokens:\n            if new_token in stop_tokens:\n                should_break = True\n                break\n            tokens.append(new_token)\n\n        if should_break:\n            break\n\n        # Update positions for next iteration\n        input_pos.add_(accepted_tokens_len)\n        input_pos_maxp1 += accepted_tokens_len\n        token = new_tokens[-1].unsqueeze(0)\n\n    # Finalize generated sequence\n    tokens = torch.stack(tokens)\n    if include_prompt:\n        tokens = torch.cat([prompt, tokens])\n    acceptance_rate = total_accepted / total_generated if total_generated > 0 else 0.0\n    return tokens, acceptance_rate\n\n\ndef setup_model(config: Config, max_returned_tokens: int, fabric: L.Fabric) -> GPT:\n    
\"\"\"Helper function to setup a model with common configuration.\"\"\"\n    with fabric.init_module(empty_init=True):\n        model = GPT(config)\n    with fabric.init_tensor():\n        # set the max_seq_length to limit the memory usage to what we need\n        model.max_seq_length = max_returned_tokens\n        # enable the kv cache\n        model.set_kv_cache(batch_size=1)\n    model.eval()\n    return fabric.setup_module(model)\n\n\ndef load_model(checkpoint_dir: Path, fabric: L.Fabric) -> Tuple[Config, Path]:\n    \"\"\"Helper function to validate and load model configuration.\"\"\"\n    check_valid_checkpoint_dir(checkpoint_dir)\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n    checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n    check_file_size_on_cpu_and_warn(checkpoint_path, fabric.device)\n    return config, checkpoint_path\n\n\n@torch.inference_mode()\ndef main(\n    draft_model_checkpoint_dir: Path,\n    target_model_checkpoint_dir: Path,\n    prompt: str = \"What food do llamas eat?\",\n    *,\n    sys_prompt: Optional[str] = None,\n    num_samples: int = 1,\n    max_new_tokens: int = 50,\n    speculative_k: int = 3,\n    top_k: Optional[int] = 50,\n    top_p: float = 1.0,\n    temperature: float = 0.8,\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\"]] = None,\n    precision: Optional[str] = None,\n    compile: bool = False,\n) -> None:\n    \"\"\"Default generation option.\n\n    Generates text samples based on pre-trained models and a tokenizer.\n\n    Args:\n        draft_model_checkpoint_dir: Checkpoint directory of the smaller/faster model used for initial token\n            predictions.\n        target_model_checkpoint_dir: Checkpoint directory of the larger/more accurate model used to verify\n            draft predictions.\n        prompt: The prompt string to use for generating the samples.\n        sys_prompt: The system prompt to use for generating the samples.\n        num_samples: The number of text samples to generate.\n        max_new_tokens: The number of 
generation steps to take.\n        speculative_k: Number of tokens to speculatively generate at each step\n        top_k: The number of top most probable tokens to consider in the sampling process.\n        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n            In top-p sampling, the next token is sampled from the highest probability tokens\n            whose cumulative probability exceeds the threshold `top_p`. When specified,\n            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n            to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n            It can be used in conjunction with `top_k` and `temperature` with the following order\n            of application:\n\n            1. `top_k` sampling\n            2. `temperature` scaling\n            3. `top_p` sampling\n\n            For more details, see https://arxiv.org/abs/1904.09751\n            or https://huyenchip.com/2024/01/16/sampling.html#top_p\n        temperature: A value controlling the randomness of the sampling process. 
Higher values result in more random\n            samples.\n        quantize: Whether to quantize the model and using which method:\n            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes\n            - bnb.int8: 8-bit quantization from bitsandbytes\n            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md\n        precision: Indicates the Fabric precision setting to use.\n        compile: Whether to compile the model.\n    \"\"\"\n    draft_model_checkpoint_dir = extend_checkpoint_dir(draft_model_checkpoint_dir)\n    target_model_checkpoint_dir = extend_checkpoint_dir(target_model_checkpoint_dir)\n    pprint(locals())\n\n    # Setup Fabric\n    precision = precision or get_default_supported_precision(training=False)\n    plugins = None\n    if quantize is not None and quantize.startswith(\"bnb.\"):\n        if \"mixed\" in precision:\n            raise ValueError(\"Quantization and mixed precision is not supported.\")\n        if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:\n            warnings.warn(\n                \"LitGPT only supports bitsandbytes v0.42.0. 
This may result in errors when using quantization.\"\n            )\n        dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n        plugins = BitsandbytesPrecision(quantize[4:], dtype)\n        precision = None\n\n    fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)\n\n    # Load model configs and checkpoints\n    draft_config, draft_checkpoint_path = load_model(draft_model_checkpoint_dir, fabric)\n    target_config, target_checkpoint_path = load_model(target_model_checkpoint_dir, fabric)\n\n    # Setup tokenizer and validate\n    draft_tokenizer = Tokenizer(draft_model_checkpoint_dir)\n    target_tokenizer = Tokenizer(target_model_checkpoint_dir)\n    if draft_tokenizer.vocab_size != target_tokenizer.vocab_size:\n        raise ValueError(\"Draft and target models have different vocab sizes.\")\n    tokenizer = target_tokenizer\n\n    # Setup prompt\n    prompt_style = (\n        load_prompt_style(target_model_checkpoint_dir)\n        if has_prompt_style(target_model_checkpoint_dir)\n        else PromptStyle.from_config(target_config)\n    )\n    prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    prompt_length = encoded.size(0)\n    max_returned_tokens = prompt_length + max_new_tokens\n\n    # Initialize models\n    fabric.print(f\"Loading draft model {str(draft_checkpoint_path)!r} with {draft_config.__dict__}\", file=sys.stderr)\n    fabric.print(f\"Loading target model {str(target_checkpoint_path)!r} with {target_config.__dict__}\", file=sys.stderr)\n    t0 = time.perf_counter()\n    draft_model = setup_model(draft_config, max_returned_tokens, fabric)\n    target_model = setup_model(target_config, max_returned_tokens, fabric)\n    fabric.print(f\"Time to instantiate models: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    # Setup compilation if needed\n    if compile:\n        
torch._dynamo.config.automatic_dynamic_shapes = True\n        torch._inductor.config.triton.unique_kernel_names = True\n        torch._inductor.config.coordinate_descent_tuning = True\n        global next_token\n        next_token = torch.compile(next_token, mode=\"reduce-overhead\")\n\n    # Load model weights\n    t0 = time.perf_counter()\n    load_checkpoint(fabric, draft_model, draft_checkpoint_path)\n    load_checkpoint(fabric, target_model, target_checkpoint_path)\n    fabric.print(f\"Time to load the models weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    # Generate samples\n    L.seed_everything(1234)\n    for i in range(num_samples):\n        t0 = time.perf_counter()\n        y, acceptance_rate = generate(\n            draft_model,\n            target_model,\n            encoded,\n            max_returned_tokens,\n            temperature=temperature,\n            top_k=top_k,\n            top_p=top_p,\n            stop_tokens=([tokenizer.eos_id] if tokenizer.eos_id is not None else []),\n            speculative_k=speculative_k,\n        )\n        t = time.perf_counter() - t0\n\n        # Reset KV cache\n        for model in (draft_model, target_model):\n            for block in model.transformer.h:\n                block.attn.kv_cache.reset_parameters()\n\n        # Print results\n        fabric.print(tokenizer.decode(y))\n        tokens_generated = y.size(0) - prompt_length\n        print(f\"Acceptance rate: {acceptance_rate * 100:.2f}%\")\n        fabric.print(\n            f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n        )\n\n    if fabric.device.type == \"cuda\":\n        fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)\n"
  },
  {
    "path": "litgpt/generate/tp.py",
    "content": "\"\"\"Tensor-parallel implementation adapted from https://github.com/pytorch-labs/gpt-fast/blob/14df27/tp.py\"\"\"\n\nimport logging\nimport sys\nimport time\nimport warnings\nfrom functools import partial\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Literal, Optional, Union\n\nimport lightning as L\nimport torch\nimport torch._dynamo.config\nimport torch._inductor.config\nfrom lightning.fabric.plugins import BitsandbytesPrecision\nfrom lightning.fabric.utilities import rank_zero_only\n\nimport litgpt.generate.base as generate_base\nfrom litgpt.config import Config\nfrom litgpt.constants import _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0\nfrom litgpt.model import GPT, CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE\nfrom litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.utils import (\n    check_nvlink_connectivity,\n    check_valid_checkpoint_dir,\n    extend_checkpoint_dir,\n    get_default_supported_precision,\n)\n\n\ndef tensor_parallel_linear(fabric: L.Fabric, linear: torch.nn.Linear, style: str) -> None:\n    world_size = fabric.world_size\n    dim, attr = {\"colwise\": (0, \"out_features\"), \"rowwise\": (1, \"in_features\")}[style]\n    size = getattr(linear, attr)\n    if size % world_size != 0:\n        raise ValueError(\n            f\"This linear's {attr} value ({size}) is not evenly divisible by the world size ({world_size})\"\n        )\n\n    shard = torch.tensor_split(linear.weight, world_size, dim=dim)[fabric.global_rank]\n    # overwrite `.data` instead of recreating the parameter for quantization (bitsandbytes) support.\n    # the bitsandbytes linear classes use custom `torch.nn.Parameter` subclasses\n    linear.weight.data = shard\n    setattr(linear, attr, shard.size(dim))\n\n    if linear.bias is not None and dim == 0:\n        shard = torch.tensor_split(linear.bias, world_size)[fabric.global_rank]\n        linear.bias = 
torch.nn.Parameter(shard, requires_grad=linear.bias.requires_grad)\n\n\ndef tensor_parallel_mlp(fabric: L.Fabric, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMAMoE]) -> None:\n    if isinstance(mlp, LLaMAMLP):\n        tensor_parallel_linear(fabric, mlp.fc_1, \"colwise\")\n        tensor_parallel_linear(fabric, mlp.fc_2, \"colwise\")\n        tensor_parallel_linear(fabric, mlp.proj, \"rowwise\")\n        mlp.register_forward_hook(partial(all_reduce_output, fabric.world_size))\n    elif isinstance(mlp, GptNeoxMLP):\n        tensor_parallel_linear(fabric, mlp.fc, \"colwise\")\n        tensor_parallel_linear(fabric, mlp.proj, \"rowwise\")\n        mlp.register_forward_hook(partial(all_reduce_output, fabric.world_size))\n    elif isinstance(mlp, LLaMAMoE):\n        # we use expert slicing across ranks, alternatively, we could create an expert parallelism group\n        # when the number of experts is a multiple of the world size\n        for expert in mlp.experts:\n            tensor_parallel_mlp(fabric, expert)\n    else:\n        raise NotImplementedError\n\n\ndef tensor_parallel_attn(fabric: L.Fabric, attn: CausalSelfAttention) -> None:\n    tensor_parallel_linear(fabric, attn.qkv, \"colwise\")\n    tensor_parallel_linear(fabric, attn.proj, \"rowwise\")\n    attn.register_forward_hook(partial(all_reduce_output, fabric.world_size))\n\n\ndef all_reduce_output(world_size: int, module: torch.nn.Module, ins, outs) -> torch.Tensor:\n    from torch.distributed._functional_collectives import all_reduce\n\n    return all_reduce(outs, \"sum\", list(range(world_size)))\n\n\ndef tensor_parallel(fabric: L.Fabric, model: GPT) -> GPT:\n    for block in model.transformer.h:\n        tensor_parallel_mlp(fabric, block.mlp)\n        tensor_parallel_attn(fabric, block.attn)\n\n    # update the config values to the shard sizes\n    # this is only relevant for `tensor_parallel_attn`, but it needs to run only once\n    world_size = fabric.world_size\n    attrs = [\"n_head\", \"n_embd\", 
\"n_query_groups\"]\n    for attr in attrs:\n        size = getattr(model.config, attr)\n        if size % world_size != 0:\n            raise ValueError(f\"This {attr} value ({size}) is not evenly divisible by the world size ({world_size})\")\n        setattr(model.config, attr, size // world_size)\n\n    return model\n\n\n@torch.inference_mode()\ndef main(\n    checkpoint_dir: Path,\n    prompt: str = \"What food do llamas eat?\",\n    *,\n    sys_prompt: Optional[str] = None,\n    num_samples: int = 1,\n    max_new_tokens: int = 50,\n    top_k: Optional[int] = 50,\n    top_p: float = 1.0,\n    temperature: float = 0.8,\n    quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\"]] = None,\n    precision: Optional[str] = None,\n    compile: bool = False,\n) -> None:\n    \"\"\"Generation script that uses tensor parallelism to run across devices.\n\n    Generates text samples based on a pre-trained model and tokenizer.\n\n    Args:\n        checkpoint_dir: The checkpoint directory to load.\n        prompt: The prompt string to use for generating the samples.\n        sys_prompt: The system prompt to use for generating the samples.\n        num_samples: The number of text samples to generate.\n        max_new_tokens: The number of generation steps to take.\n        top_k: The number of top most probable tokens to consider in the sampling process.\n        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n            In top-p sampling, the next token is sampled from the highest probability tokens\n            whose cumulative probability exceeds the threshold `top_p`. When specified,\n            it must be `0 <= top_p <= 1`. 
Here, `top_p=0` is equivalent\n            to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n            It can be used in conjunction with `top_k` and `temperature` with the following order\n            of application:\n\n            1. `top_k` sampling\n            2. `temperature` scaling\n            3. `top_p` sampling\n\n            For more details, see https://arxiv.org/abs/1904.09751\n            or https://huyenchip.com/2024/01/16/sampling.html#top_p\n        temperature: A value controlling the randomness of the sampling process. Higher values result in more random\n            samples.\n        quantize: Whether to quantize the model and using which method:\n            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes\n            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md\n        precision: Indicates the Fabric precision setting to use.\n        compile: Whether to compile the model.\n    \"\"\"\n    checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)\n    pprint(locals())\n\n    precision = precision or get_default_supported_precision(training=False)\n\n    plugins = None\n    if quantize is not None:\n        if compile:\n            raise NotImplementedError  # untested\n        if \"mixed\" in precision:\n            raise ValueError(\"Quantization and mixed precision is not supported.\")\n        if _BITANDBYTES_AVAILABLE_NOT_EQUAL_0_42_0:\n            warnings.warn(\n                \"LitGPT only supports bitsandbytes v0.42.0. 
This may result in errors when using quantization.\"\n            )\n        dtype = {\"16-true\": torch.float16, \"bf16-true\": torch.bfloat16, \"32-true\": torch.float32}[precision]\n        bnb_logger = logging.getLogger(\"lightning.fabric.plugins.precision.bitsandbytes\")\n        bnb_logger.setLevel(logging.DEBUG)\n        bnb_logger.debug = rank_zero_only(bnb_logger.debug)\n        plugins = BitsandbytesPrecision(quantize[4:], dtype)\n        precision = None\n\n    # set \"ddp\" as the strategy for the launching functionality, but there's no data-parallelism\n    fabric = L.Fabric(devices=\"auto\", strategy=\"ddp\", precision=precision, plugins=plugins)\n    if torch.cuda.is_available() and fabric.accelerator.auto_device_count() > 1:\n        check_nvlink_connectivity(fabric)\n    fabric.launch()\n\n    check_valid_checkpoint_dir(checkpoint_dir)\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n    model_file = \"lit_model.pth\"\n    checkpoint_path = checkpoint_dir / model_file\n\n    tokenizer = Tokenizer(checkpoint_dir)\n    prompt_style = (\n        load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)\n    )\n    prompt = prompt_style.apply(prompt, sys_prompt=sys_prompt)\n    encoded = tokenizer.encode(prompt, device=fabric.device)\n    prompt_length = encoded.size(0)\n    max_returned_tokens = prompt_length + max_new_tokens\n\n    fabric.print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n    t0 = time.perf_counter()\n    # cannot use `init_module` because if bitsandbytes is used, the Linear layers will be replaced\n    # which means that the weights will get quantized on cuda:0 on checkpoint load. 
we need to load and then convert\n    # still, use init_tensor for the precision\n    with fabric.init_tensor(), torch.device(\"meta\"):\n        model = GPT(config)\n    fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n    # sequentially do: load the checkpoint on CPU -> quantize -> apply tp -> move to device\n    # so that the CPU RAM doesn't OOM with larger models\n    for rank in range(fabric.world_size):\n        if fabric.global_rank == rank:\n            t0 = time.perf_counter()\n            state_dict = torch.load(str(checkpoint_path), mmap=True, map_location=\"cpu\")\n            model.load_state_dict(state_dict, assign=True)\n            print(f\"[{rank}] Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n\n            # cannot use `.setup_module` because it will wrap with DDP\n            model = fabric._precision.convert_module(model)\n\n            t0 = time.perf_counter()\n            model = tensor_parallel(fabric, model)\n            print(\n                f\"[{rank}] Time to tensor-parallelize the model: {time.perf_counter() - t0:.02f} seconds.\",\n                file=sys.stderr,\n            )\n\n            with fabric.init_tensor():\n                # set the max_seq_length to limit the memory usage to what we need\n                model.max_seq_length = max_returned_tokens\n                # the rope cache which is on meta device\n                model.cos, model.sin = model.rope_cache()\n                # enable the kv cache\n                model.set_kv_cache(batch_size=1)\n            model.eval()\n\n            t0 = time.perf_counter()\n            model = fabric.to_device(model)\n            print(f\"[{rank}] Time to move the model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n        fabric.barrier()\n\n    if compile:\n        torch._dynamo.config.automatic_dynamic_shapes = True\n        
torch._inductor.config.triton.unique_kernel_names = True\n        torch._inductor.config.coordinate_descent_tuning = True\n        generate_base.next_token = torch.compile(generate_base.next_token, mode=\"reduce-overhead\")\n\n    L.seed_everything(1234)\n    for i in range(num_samples):\n        t0 = time.perf_counter()\n        y = generate_base.generate(\n            model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id\n        )\n        t = time.perf_counter() - t0\n        for block in model.transformer.h:\n            block.attn.kv_cache.reset_parameters()\n        fabric.print(tokenizer.decode(y))\n        tokens_generated = y.size(0) - prompt_length\n        fabric.print(\n            f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n        )\n    if fabric.device.type == \"cuda\":\n        fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)\n"
  },
  {
    "path": "litgpt/lora.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\n# Derived from https://github.com/microsoft/LoRA\n#  ------------------------------------------------------------------------------------------\n#  Copyright (c) Microsoft Corporation. All rights reserved.\n#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.\n#  ------------------------------------------------------------------------------------------\n\nr\"\"\"\n    Low Ranking Adaptation for LLMs scheme.\n\n             ┌───────────────────┐\n             ┆         h         ┆\n             └───────────────────┘\n                       ▲\n                       |\n                       +\n                    /     \\\n    ┌─────────────────┐    ╭───────────────╮     Matrix initialization:\n    ┆                 ┆     \\      B      /      B = 0\n    ┆   pretrained    ┆      \\    r*d    /       A = N(0, sigma^2)\n    ┆    weights      ┆       ╰─────────╯\n    ┆                 ┆       |    r    |        r - rank\n    ┆   W e R^(d*d)   ┆       | ◀─────▶ |\n    ┆                 ┆       ╭─────────╮\n    └─────────────────┘      /     A     \\\n              ▲             /     d*r     \\\n               \\           ╰───────────────╯\n                \\                ▲\n                 \\              /\n                  \\            /\n             ┌───────────────────┐\n             ┆         x         ┆\n             └───────────────────┘\n\nWith LoRA (Low Ranking Adaptation: https://arxiv.org/abs/2106.09685) instead of learning weights of size d*d,\nwe can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they will store weight updates\nfor the pretrained weights): the number of parameters in this case will be reduced drastically (depending on the rank of\ncourse) yet after multiplication of matrices d*r and r*d we will get a matrix d*d which we can sum with frozen\npretrained 
weights and thus fine-tune the model.\n\nThe goal of this approach is to move weight updates into a separate matrix which is decomposed with\ntwo matrices of a lower rank.\n\"\"\"\n\nimport math\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, Optional, Tuple, Type, Union\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nfrom typing_extensions import Self\n\nimport litgpt\nfrom litgpt.config import Config as BaseConfig\nfrom litgpt.model import GPT as BaseModel\nfrom litgpt.model import Block as BaseBlock\nfrom litgpt.model import CausalSelfAttention as BaseCausalSelfAttention\nfrom litgpt.scripts.convert_hf_checkpoint import qkv_reassemble\nfrom litgpt.utils import map_old_state_dict_weights\n\n\nclass LoRALayer(nn.Module):\n    def __init__(self, r: int, lora_alpha: int, lora_dropout: float):\n        \"\"\"Store LoRA specific attributes in a class.\n\n        Args:\n            r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of\n                the weights of the model. 
The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)\n            lora_alpha: alpha is needed for scaling updates as alpha/r\n                \"This scaling helps to reduce the need to retune hyperparameters when we vary r\"\n                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)\n            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)\n        \"\"\"\n        super().__init__()\n        assert r >= 0\n        self.r = r\n        self.lora_alpha = lora_alpha\n        # Optional dropout\n        if lora_dropout > 0.0:\n            self.lora_dropout = nn.Dropout(p=lora_dropout)\n        else:\n            self.lora_dropout = lambda x: x\n        # Mark the weight as unmerged\n        self.merged = False\n\n\nclass LoRALinear(LoRALayer):\n    # LoRA implemented in a dense layer\n    def __init__(\n        self,\n        # ↓ this part is for pretrained weights\n        in_features: int,\n        out_features: int,\n        # ↓ the remaining part is for LoRA\n        r: int = 0,\n        lora_alpha: int = 1,\n        lora_dropout: float = 0.0,\n        **kwargs: Any,\n    ):\n        \"\"\"LoRA wrapper around linear class.\n\n        This class has three weight matrices:\n            1. Pretrained weights are stored as `self.linear.weight`\n            2. LoRA A matrix as `self.lora_A`\n            3. LoRA B matrix as `self.lora_B`\n        Only LoRA's A and B matrices are updated, pretrained weights stay frozen.\n\n        Args:\n            in_features: number of input features of the pretrained weights\n            out_features: number of output features of the pretrained weights\n            r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of\n                the weights of the model. 
The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)\n            lora_alpha: alpha is needed for scaling updates as alpha/r\n                \"This scaling helps to reduce the need to retune hyperparameters when we vary r\"\n                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)\n            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)\n        \"\"\"\n        super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)\n        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)\n\n        # Actual trainable parameters\n        if r > 0:\n            self.lora_A = nn.Parameter(torch.empty((r, in_features)))\n            self.lora_B = nn.Parameter(torch.empty((out_features, r)))\n            self.scaling = self.lora_alpha / self.r\n            self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        \"\"\"Reset all the weights, even including pretrained ones.\"\"\"\n        if hasattr(self, \"lora_A\"):\n            # initialize A the same way as the default for nn.Linear and B to zero\n            # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314\n            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))\n            nn.init.zeros_(self.lora_B)\n\n    def get_lora_AB(self) -> torch.Tensor:\n        \"\"\"Return merged lora_A and lora_B matrices with the same shape as the pretrained weights.\"\"\"\n        return (self.lora_B @ self.lora_A) * self.scaling\n\n    def merge(self) -> None:\n        \"\"\"Merges the LoRA weights into the full-rank weights (W = W + delta_W).\"\"\"\n        if self.r > 0 and not self.merged:\n            pretrained_dtype = self.linear.weight.data.dtype\n            lora_data = self.get_lora_AB()\n            # if only the pretrained are in quantized form - dequantize, sum with LoRA and quantize the result\n            if pretrained_dtype == 
torch.uint8:\n                import bitsandbytes as bnb\n\n                weight = self.linear.weight\n                # dequantize the pretrained weights\n                weight_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state).to(lora_data.dtype)\n                # add pretrained and LoRA weights\n                weight_data += lora_data\n                # assign updated weights and quantize by moving to CUDA device\n                self.linear.weight = bnb.nn.Params4bit(weight_data, requires_grad=False, **weight.__dict__)\n                self.linear.weight.cuda(weight.device)\n            else:\n                # self.linear might be on CPU and lora_data on CUDA\n                # the inplace add will preserve the dtype of linear.weight\n                self.linear.weight.data += lora_data.to(device=self.linear.weight.data.device)\n            self.merged = True\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # if weights are merged or rank is less or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass;\n        # otherwise in addition do the forward pass with LoRA weights and add its output to the output from pretrained weights\n        pretrained = self.linear(x)\n        if self.r == 0 or self.merged:\n            return pretrained\n        lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling\n        return pretrained + lora\n\n\nclass LoRAQKVLinear(LoRALinear):\n    # LoRA implemented in a dense layer\n    def __init__(\n        self,\n        # ↓ this part is for pretrained weights\n        in_features: int,\n        out_features: int,\n        # ↓ the remaining part is for LoRA\n        head_size: int,\n        n_head: int,\n        n_query_groups: int,\n        r: int = 0,\n        lora_alpha: int = 1,\n        lora_dropout: float = 0.0,\n        enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,\n        **kwargs: Any,\n    
):\n        \"\"\"LoRA wrapper around linear class that is used for calculation of q, k and v matrices.\n\n        This class has three weight matrices:\n            1. Pretrained weights are stored as `self.linear.weight`\n            2. LoRA A matrix as `self.lora_A`\n            3. LoRA B matrix as `self.lora_B`\n        Only LoRA's A and B matrices are updated, pretrained weights stay frozen.\n\n        Args:\n            in_features: number of input features of the pretrained weights\n            out_features: number of output features of the pretrained weights\n            head_size: size of a single attention head\n            n_head: number of attention heads\n            n_query_groups: number of query groups (see diagram in `litgpt/config.py`)\n            r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of\n                the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)\n            lora_alpha: alpha is needed for scaling updates as alpha/r\n                \"This scaling helps to reduce the need to retune hyperparameters when we vary r\"\n                https://arxiv.org/pdf/2106.09685.pdf (section 4.1)\n            lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)\n            enable_lora: MergeLinear class is for attention mechanism where qkv are calculated with a single weight matrix. If we\n                don't want to apply LoRA we can set it as False. 
For example if we want to apply LoRA only to `query`\n                and `value` but keep `key` without weight updates we should pass `[True, False, True]`\n        \"\"\"\n        super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)\n        self.linear = torch.nn.Linear(in_features, out_features, **kwargs)\n        self.head_size = head_size\n        self.n_head = n_head\n        self.n_query_groups = n_query_groups\n        if isinstance(enable_lora, bool):\n            enable_lora = [enable_lora] * 3\n        assert len(enable_lora) == 3\n        self.enable_lora = enable_lora\n\n        # Actual trainable parameters\n        # To better understand initialization let's imagine that we have such parameters:\n        # ⚬ in_features: 128 (embeddings_size)\n        # ⚬ out_features: 384 (3 * embedding_size)\n        # ⚬ r: 2\n        # ⚬ enable_lora: [True, False, True]\n        if r > 0 and any(enable_lora):\n            self.lora_A = nn.Parameter(torch.empty((r * sum(enable_lora), in_features)))  # (4, 128)\n            enable_q, enable_k, enable_v = enable_lora\n            # qkv_shapes will be used to split a tensor with weights correctly\n            qkv_shapes = (\n                # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)\n                # might not be equal to `head_size * n_head`, thus we use it directly here\n                head_size * n_head * enable_q,\n                head_size * n_query_groups * enable_k,\n                head_size * n_query_groups * enable_v,\n            )\n            self.qkv_shapes = [s for s in qkv_shapes if s]\n            self.lora_B = nn.Parameter(torch.empty(sum(self.qkv_shapes), r))  # (256, 2))\n            # Notes about shapes above\n            # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;\n            # 128 is the input size of the x (embedding size). 
(4, 128) and not (128, 4) because later on in\n            # F.linear function weights are automatically transposed. In addition conv1d requires channels to\n            # be before seq length\n            # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is\n            # 128*2; 2 tells to have two channels per group for group convolution\n\n            # Scaling:\n            # This balances the pretrained model's knowledge and the new task-specific adaptation\n            # https://lightning.ai/pages/community/tutorial/lora-llm/\n            # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set\n            # alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can\n            # tune these values to your needs. This value can be even slightly greater than 1.0!\n            # https://github.com/cloneofsimo/lora\n            self.scaling = self.lora_alpha / self.r\n\n            self.reset_parameters()\n\n    @property\n    def lora_ind(self) -> torch.Tensor:\n        \"\"\"Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used.\"\"\"\n        # Indices are needed to properly pad weight updates with zeros.\n        if not hasattr(self, \"_lora_ind\"):\n            enable_q, enable_k, enable_v = self.enable_lora\n            q_embd_size = self.head_size * self.n_head\n            kv_embd_size = self.head_size * self.n_query_groups\n            lora_ind = []\n            if enable_q:\n                lora_ind.extend(range(0, q_embd_size))\n            if enable_k:\n                lora_ind.extend(range(q_embd_size, q_embd_size + kv_embd_size))\n            if enable_v:\n                lora_ind.extend(range(q_embd_size + kv_embd_size, self.linear.out_features))\n            self.register_buffer(\n                \"_lora_ind\", torch.tensor(lora_ind, 
device=self.linear.weight.device), persistent=False\n            )\n\n        return self._lora_ind\n\n    def zero_pad(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Properly pad the last dimension of weight updates with zeros.\n\n        If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,\n        then the weights update should be:\n\n        [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],\n         [....................................],\n         [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]\n            ↑              ↑            ↑\n        ________________________________________\n        | query         | key       | value    |\n        ----------------------------------------\n\n        Args:\n            x: tensor with weights update that will be padded with zeros if necessary\n\n        Returns:\n            A tensor with weight updates and zeros for deselected q, k or v\n        \"\"\"\n        # we need to do zero padding only if LoRA is disabled for one of QKV matrices\n        if all(self.enable_lora):\n            return x\n\n        # Let's imagine that:\n        # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)\n        # ⚬ embeddings_size: 128\n        # ⚬ self.linear.out_features: 384 (3 * embeddings_size)\n        # ⚬ enable_lora: [True, False, True]\n        # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected\n        # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but\n        # only for key updates (this is where self.lora_ind comes in handy)\n\n        result = x.new_zeros(*x.shape[:-1], self.linear.out_features)  # (64, 64, 384)\n        if result.device.type == \"mps\":\n            result[..., self.lora_ind] = x\n            return result\n        else:\n            return result.index_copy_(dim=-1, index=self.lora_ind, source=x)  # (64, 64, 384)\n\n    
def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:\n        \"\"\"An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.\n\n        If the number of heads is equal to the number of query groups - grouped queries are disabled\n        (see scheme in `litgpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized\n        query, key and value parts, which means we can utilize `groups` argument from `conv1d`: with this argument the\n        input and weight matrices will be split in equally sized parts and applied separately (like having multiple\n        conv layers side by side).\n\n        Otherwise QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually,\n        apply each part of the weight matrix to the corresponding input's part and concatenate the result.\n\n        Args:\n            input: input matrix of shape (B, C, T)\n            weight: weight matrix of shape (C_output, rank, 1).\n                \"C_output\" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).\n\n        Returns:\n            A tensor with a shape (B, C_output, T)\n\n        \"\"\"\n        if self.n_head == self.n_query_groups:\n            return F.conv1d(input, weight, groups=sum(self.enable_lora))  # (B, C_output, T)\n\n        # Notation:\n        # ⚬ N: number of enabled LoRA layers (self.enable_lora)\n        # ⚬ C_output': embeddings size for each LoRA layer (not equal in size)\n        # ⚬ r: rank of all LoRA layers (equal in size)\n\n        input_splitted = input.chunk(sum(self.enable_lora), dim=1)  # N * (B, C // N, T)\n        weight_splitted = weight.split(self.qkv_shapes)  # N * (C_output', r, 1)\n        return torch.cat(\n            [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)],\n            dim=1,  # (B, C_output', T)\n        )  # (B, C_output, T)\n\n    
def get_lora_AB(self) -> torch.Tensor:\n        \"\"\"Return merged lora_A and lora_B matrices with the same shape as the pretrained weights.\"\"\"\n        # Let's assume that:\n        # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size)\n        # ⚬ self.lora_A.data: (4, 128)\n        # ⚬ self.lora_B.data: (256, 2)\n        lora = self.conv1d(\n            self.lora_A.data.unsqueeze(0),  # (4, 128) -> (1, 4, 128)\n            self.lora_B.data.unsqueeze(-1),  # (256, 2) -> (256, 2, 1)\n        ).squeeze(0)  # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)\n        return self.zero_pad(lora.T * self.scaling).T  # (256, 128) after zero_pad (384, 128)\n\n    def merge(self) -> None:\n        \"\"\"Merges the LoRA weights into the full-rank weights (W = W + delta_W).\"\"\"\n        if self.r > 0 and any(self.enable_lora) and not self.merged:\n            super().merge()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"Do the forward pass.\n\n        If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication.\n        If not, then multiply pretrained weights with input, apply LoRA on input and do summation.\n\n        Args:\n            x: input tensor of shape (batch_size, context_length, embedding_size)\n\n        Returns:\n            Output tensor of shape (batch_size, context_length, 3 * embedding_size)\n        \"\"\"\n\n        # Let's assume that:\n        # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)\n        # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size)\n        # ⚬ self.lora_A.data: (4, 128)\n        # ⚬ self.lora_B.data: (256, 2)\n\n        # if weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False) - it's only a regular nn.Linear forward pass;\n        # otherwise in addition do the forward pass with LoRA weights and add its output to the output from pretrained weights\n        
pretrained = self.linear(x)\n        if self.r == 0 or not any(self.enable_lora) or self.merged:\n            return pretrained\n        after_A = F.linear(self.lora_dropout(x), self.lora_A)  # (64, 64, 128) @ (4, 128) -> (64, 64, 4)\n        # For F.conv1d:\n        # ⚬ input: input tensor of shape (mini-batch, in_channels, iW)\n        # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)\n        after_B = self.conv1d(\n            after_A.transpose(-2, -1),  # (64, 64, 4) -> (64, 4, 64)\n            self.lora_B.unsqueeze(-1),  # (256, 2) -> (256, 2, 1)\n        ).transpose(-2, -1)  # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)\n        lora = self.zero_pad(after_B) * self.scaling  # (64, 64, 256) after zero_pad (64, 64, 384)\n        return pretrained + lora\n\n\ndef mark_only_lora_as_trainable(model: nn.Module, bias: str = \"none\") -> None:\n    \"\"\"Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights.\n\n    Args:\n        model: model with LoRA layers\n        bias:\n            ``\"none\"``: all bias weights will be frozen,\n            ``\"lora_only\"``: only bias weight for LoRA layers will be unfrozen,\n            ``\"all\"``: all bias weights will be unfrozen.\n\n    Raises:\n        NotImplementedError: if `bias` not in [\"none\", \"lora_only\", \"all\"]\n    \"\"\"\n    # freeze all layers except LoRA's\n    for n, p in model.named_parameters():\n        if \"lora_\" not in n:\n            p.requires_grad = False\n\n    # depending on the `bias` value unfreeze bias weights\n    if bias == \"none\":\n        return\n    if bias == \"all\":\n        for n, p in model.named_parameters():\n            if \"bias\" in n:\n                p.requires_grad = True\n    elif bias == \"lora_only\":\n        for m in model.modules():\n            if isinstance(m, LoRALayer) and hasattr(m, \"bias\") and m.bias is not None:\n                m.bias.requires_grad = True\n    else:\n        raise 
NotImplementedError\n\n\ndef lora_filter(key: str, value: Any) -> bool:\n    return \"lora_\" in key\n\n\n@dataclass\nclass Config(BaseConfig):\n    \"\"\"\n    Args:\n        lora_r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of\n            the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)\n        lora_alpha: alpha is needed for scaling updates as alpha/r\n            \"This scaling helps to reduce the need to retune hyperparameters when we vary r\"\n            https://arxiv.org/pdf/2106.09685.pdf (section 4.1)\n        lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)\n        lora_*: whether to apply LoRA to the specified weights or not\n    \"\"\"\n\n    lora_r: int = 0\n    lora_alpha: int = 1\n    lora_dropout: float = 0.0\n    lora_query: bool = False\n    lora_key: bool = False\n    lora_value: bool = False\n    lora_projection: bool = False\n    lora_mlp: bool = False\n    lora_head: bool = False\n\n    @property\n    def mlp_class(self) -> Type:\n        return getattr(litgpt.lora, self.mlp_class_name)\n\n\nclass GPT(BaseModel):\n    # Copy & paste from :class:`model.GPT`. 
Note that :class:`Block` is new here.\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        assert config.padded_vocab_size is not None\n        self.config = config\n\n        self.lm_head = create_lora_linear(\n            config,\n            config.n_embd,\n            config.padded_vocab_size,\n            bias=config.lm_head_bias,\n            use_r=config.lora_head,\n        )\n        self.transformer = nn.ModuleDict(\n            dict(\n                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),\n                h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),\n                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),\n            )\n        )\n        self.mask_cache: Optional[torch.Tensor] = None\n        self.max_seq_length = self.config.block_size\n\n    @classmethod\n    def from_name(cls, name: str, **kwargs: Any) -> Self:\n        return cls(Config.from_name(name, **kwargs))\n\n    def _init_weights(self, module: nn.Module) -> None:\n        \"\"\"Meant to be used with `gpt.apply(gpt._init_weights)`. 
Unused method left for completeness.\"\"\"\n        super()._init_weights(module)\n        if isinstance(module, LoRALinear):\n            module.reset_parameters()\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\"lm_head.weight\": \"lm_head.linear.weight\", \"lm_head.bias\": \"lm_head.linear.bias\"}\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass Block(BaseBlock):\n    def __init__(self, config: Config, block_idx: int) -> None:\n        super().__init__(config, block_idx)\n        self.attn = CausalSelfAttention(config, block_idx)\n        self.mlp = config.mlp_class(config)\n\n\nclass CausalSelfAttention(BaseCausalSelfAttention):\n    def __init__(self, config: Config, block_idx: int) -> None:\n        super().__init__(config, block_idx)\n        # key, query, value projections for all heads, but in a batch\n        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size\n        self.qkv = LoRAQKVLinear(\n            in_features=config.n_embd,\n            out_features=shape,\n            r=config.lora_r,\n            lora_alpha=config.lora_alpha,\n            lora_dropout=config.lora_dropout,\n            enable_lora=(config.lora_query, config.lora_key, config.lora_value),\n            bias=config.bias or config.attn_bias,\n            # for MQA/GQA support\n            head_size=config.head_size,\n            n_head=config.n_head,\n            n_query_groups=config.n_query_groups,\n        )\n        # output projection\n        self.proj = create_lora_linear(\n            config,\n            config.head_size * config.n_head,\n            config.n_embd,\n            use_r=config.lora_projection,\n        )\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, 
**kwargs: Any) -> None:\n        \"\"\"For compatibility with base and/or legacy checkpoints.\"\"\"\n        mapping = {\n            \"qkv.weight\": \"qkv.linear.weight\",\n            \"qkv.bias\": \"qkv.linear.bias\",\n            \"proj.weight\": \"proj.linear.weight\",\n            \"proj.bias\": \"proj.linear.bias\",\n        }\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n\n        for attr in (\"weight\", \"bias\"):\n            legacy_key = f\"{prefix}attn.linear.{attr}\"\n            current_key = f\"{prefix}qkv.linear.{attr}\"\n            if legacy_key in state_dict:\n                state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config)\n\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\ndef create_lora_linear(\n    config: Config,\n    in_size: int,\n    out_size: int,\n    bias: Optional[Union[float, bool]] = None,\n    use_r: Optional[bool] = None,\n) -> LoRALinear:\n    if bias is None:\n        bias = config.bias\n    if use_r is None:\n        use_r = config.lora_mlp\n    return LoRALinear(\n        in_size,\n        out_size,\n        bias=bias,\n        r=(config.lora_r if use_r else 0),\n        lora_alpha=config.lora_alpha,\n        lora_dropout=config.lora_dropout,\n    )\n\n\nclass GptNeoxMLP(litgpt.model.GptNeoxMLP):\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        self.fc = create_lora_linear(config, config.n_embd, config.intermediate_size)\n        self.proj = create_lora_linear(config, config.intermediate_size, config.n_embd)\n        self.config = config\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\n            \"fc.weight\": \"fc.linear.weight\",\n            \"fc.bias\": \"fc.linear.bias\",\n            \"proj.weight\": \"proj.linear.weight\",\n            
\"proj.bias\": \"proj.linear.bias\",\n        }\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass LLaMAMLP(litgpt.model.LLaMAMLP):\n    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:\n        nn.Module.__init__(self)\n        self.intermediate_size = intermediate_size or config.intermediate_size\n        self.fc_1 = create_lora_linear(config, config.n_embd, self.intermediate_size)\n        self.fc_2 = create_lora_linear(config, config.n_embd, self.intermediate_size)\n        self.proj = create_lora_linear(config, self.intermediate_size, config.n_embd)\n        self.config = config\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\n            \"fc_1.weight\": \"fc_1.linear.weight\",\n            \"fc_1.bias\": \"fc_1.linear.bias\",\n            \"fc_2.weight\": \"fc_2.linear.weight\",\n            \"fc_2.bias\": \"fc_2.linear.bias\",\n            \"proj.weight\": \"proj.linear.weight\",\n            \"proj.bias\": \"proj.linear.bias\",\n        }\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass GemmaMLP(LLaMAMLP):\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x_fc_1 = self.fc_1(x)\n        x_fc_2 = self.fc_2(x)\n        x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2\n        return self.proj(x)\n\n\nclass LLaMAMoE(litgpt.model.LLaMAMoE):\n    def __init__(self, config: Config) -> None:\n        nn.Module.__init__(self)\n        self.gate = create_lora_linear(config, config.n_embd, config.n_expert, bias=False)\n        self.experts = nn.ModuleList(\n            LLaMAMLP(config, 
intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)\n        )\n        self.config = config\n\n    def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with base checkpoints.\"\"\"\n        mapping = {\"gate.weight\": \"gate.linear.weight\"}\n        state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\ndef merge_lora_weights(model: GPT) -> None:\n    \"\"\"Merge LoRA weights into the full-rank weights to speed up inference.\"\"\"\n    for module in model.modules():\n        if isinstance(module, LoRALinear):\n            module.merge()\n"
  },
  {
    "path": "litgpt/model.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\n\"\"\"Full definition of a decoder-only transformer-based language model, all of it in this single file.\n\nBased on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and\nhttps://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.\n\"\"\"\n\nimport math\nfrom functools import partial\nfrom typing import Any, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing_extensions import Self\n\nfrom litgpt.config import Config\nfrom litgpt.scripts.convert_hf_checkpoint import qkv_reassemble\n\n\nclass GPT(nn.Module):\n    def __init__(self, config: Config) -> None:\n        super().__init__()\n        assert config.padded_vocab_size is not None\n        self.config = config\n\n        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)\n        self.transformer = nn.ModuleDict(\n            dict(\n                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),\n                h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),\n                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),\n            )\n        )\n        self.mask_cache: Optional[torch.Tensor] = None\n        self.max_seq_length = self.config.block_size\n\n    @property\n    def max_seq_length(self) -> int:\n        return self._max_seq_length\n\n    @max_seq_length.setter\n    def max_seq_length(self, value: int) -> None:\n        \"\"\"\n        When doing inference, the sequences used might be shorter than the model's context length.\n        This allows setting a smaller number to avoid allocating unused memory\n        \"\"\"\n        if value > self.config.block_size:\n            raise ValueError(\n                f\"Cannot attend to {value}, block size is only {self.config.block_size}.\"\n                \" This is likely 
because the input text exceeds the supported context length of this model.\"\n            )\n        self._max_seq_length = value\n        if not hasattr(self, \"cos\"):\n            # first call\n            cos, sin = self.rope_cache()\n            self.register_buffer(\"cos\", cos, persistent=False)\n            self.register_buffer(\"sin\", sin, persistent=False)\n        # override\n        elif value != self.cos.size(0):\n            self.cos, self.sin = self.rope_cache(device=self.cos.device)\n        # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know\n        # if the kv cache is expected\n        if self.mask_cache is not None and self.mask_cache.shape[-1] < value:\n            print(\n                f\"Warning: KV cache has length {self.mask_cache.shape[-1]} < {value} = max_seq_length. Call 'set_kv_cache' before doing any forwards!\"\n            )\n\n    def reset_parameters(self) -> None:\n        # Trigger resetting the rope-cache\n        self.cos, self.sin = self.rope_cache(device=self.cos.device)\n\n    def _init_weights(self, module: nn.Module) -> None:\n        \"\"\"Meant to be used with `gpt.apply(gpt._init_weights)`.\"\"\"\n        if isinstance(module, GroupedTopkRouter):\n            torch.nn.init.normal_(module.weight.data, mean=0.0, std=0.02)\n        elif isinstance(module, nn.Linear):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n            if module.bias is not None:\n                torch.nn.init.zeros_(module.bias)\n        elif isinstance(module, nn.Embedding):\n            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n\n    def forward(\n        self,\n        idx: torch.Tensor,\n        input_pos: Optional[torch.Tensor] = None,\n        input_pos_maxp1: Optional[int] = None,\n        lm_head_chunk_size: int = 0,\n    ) -> Union[torch.Tensor, List[torch.Tensor]]:\n        \"\"\"\n        If `input_pos` is provided, the KV cache uses K 
and V vectors for\n        positions smaller than entries in `input_pos`. For efficiency, pass\n        `input_pos_maxp1` as `max(input_pos) + 1` if already available from\n        your forward algorithm. This slices the KV cache buffers and speeds\n        up multi-head attention.\n\n        Without `input_pos_maxp1`, the computation uses the full KV cache\n        (`max_seq_length`) with masking applied. Note that inferring\n        `input_pos_maxp1` from `input_pos` causes graph breaks and prevents\n        compilation.\n\n        Args:\n            idx: Token indices of input sequences, shape `(B, T)`, where `B`\n                is batch size.\n            input_pos: Optional. Positions of input tokens. The default is\n                `arange(T)`. Can have shape `(T,)` or `(B, T)` (batched index).\n            input_pos_maxp1: Optional. See above.\n            lm_head_chunk_size: Optional. If `lm_head_chunk_size > 0`, the final\n                `lm_head` computation is done in chunks of this size.\n\n        Returns:\n            Logit outputs, shape `(B, T, config.padded_vocab_size)`. 
If\n            `lm_head_chunk_size > 0`, this is a list of chunks of shape\n            `(B, lm_head_chunk_size, config.padded_vocab_size)`, the final\n            entry can be shorter.\n\n        \"\"\"\n        T = idx.size(1)\n        if self.max_seq_length < T:\n            raise ValueError(f\"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.\")\n\n        if input_pos is not None:  # use the kv cache\n            if input_pos.dim() > 2:\n                # otherwise, things go wrong in `apply_rope`\n                raise ValueError(f\"input_pos must have 1 or 2 dimensions, input_pos.shape = {input_pos.shape}\")\n            if input_pos.shape[-1] != T:\n                raise ValueError(f\"input_pos.shape[-1] = {input_pos.shape[-1]} != {T} = idx.shape[1], must be the same\")\n            cos = batched_index_select(self.cos, 0, input_pos)\n            sin = batched_index_select(self.sin, 0, input_pos)\n            if input_pos.dim() == 1:\n                cos = cos.unsqueeze(0)\n                sin = sin.unsqueeze(0)\n            if self.mask_cache is None:\n                raise TypeError(\"You need to call `gpt.set_kv_cache()`\")\n            mask = batched_index_select(self.mask_cache, 2, input_pos)\n            if mask.dim() > 4:\n                # the mask cache has a batch dim of 1 in addition to the one\n                # we get if input_pos has a batch dimension\n                mask = mask.view(*(mask.shape[0:1] + mask.shape[2:]))\n            if input_pos_maxp1 is not None:\n                # Shorten final dimension so it just covers all `input_pos` entries\n                if input_pos_maxp1 > self.max_seq_length:\n                    raise ValueError(f\"Positions in 'input_pos' must be in [0,{self.max_seq_length})\")\n                mask = mask[..., :input_pos_maxp1]\n        else:\n            # unsqueeze to have a batch dimension\n            cos = self.cos[:T].unsqueeze(0)\n            sin = 
self.sin[:T].unsqueeze(0)\n            # `cos`, `sin` have shape (1, T, config.rope_n_elem)\n            mask = None  # defaults to causal mask\n            input_pos_maxp1 = None\n\n        x = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)\n        if self.config.scale_embeddings:\n            x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype)\n\n        for block_idx, block in enumerate(self.transformer.h):\n            if self.config.rope_indices is not None:\n                x = block(\n                    x,\n                    cos[..., self.config.rope_indices[block_idx]],\n                    sin[..., self.config.rope_indices[block_idx]],\n                    mask,\n                    input_pos,\n                    input_pos_maxp1,\n                )\n            else:\n                x = block(x, cos, sin, mask, input_pos, input_pos_maxp1)\n        x = self.transformer.ln_f(x)\n        clamp_head = (\n            partial(do_softcapping, thresh=self.config.final_logit_softcapping)\n            if self.config.final_logit_softcapping is not None\n            else nn.Identity()\n        )\n        if lm_head_chunk_size > 0:\n            # chunk the lm head logits to reduce the peak memory used by autograd\n            return [clamp_head(self.lm_head(x_i)) for x_i in x.split(lm_head_chunk_size, dim=1)]\n        else:\n            return clamp_head(self.lm_head(x))  # (B, T, padded_vocab_size)\n\n    @classmethod\n    def from_name(cls, name: str, **kwargs: Any) -> Self:\n        return cls(Config.from_name(name, **kwargs))\n\n    def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:\n        if self.config.rope_adjustments is None:\n            extra_config = None\n\n        else:\n            # Check for mutually exclusive parameter sets\n            llama3_params = [\"low_freq_factor\", \"high_freq_factor\"]\n            yarn_params = [\"beta_fast\", \"beta_slow\"]\n\n            
has_llama3 = any(param in self.config.rope_adjustments for param in llama3_params)\n            has_yarn = any(param in self.config.rope_adjustments for param in yarn_params)\n\n            if has_llama3 and has_yarn:\n                raise ValueError(\n                    \"RoPE adjustments cannot contain both Llama3 parameters (low_freq_factor, high_freq_factor) \"\n                    \"and YaRN parameters (beta_fast, beta_slow). These are mutually exclusive.\"\n                )\n\n            # Llama3-style RoPE\n            if has_llama3:\n                adjusted_params_required = [\"factor\", \"low_freq_factor\", \"high_freq_factor\", \"original_max_seq_len\"]\n                params_present = [param in self.config.rope_adjustments for param in adjusted_params_required]\n                if all(params_present):\n                    extra_config = {name: self.config.rope_adjustments[name] for name in adjusted_params_required}\n                else:\n                    missing_params = [\n                        param for param, present in zip(adjusted_params_required, params_present) if not present\n                    ]\n                    raise ValueError(\n                        f\"The following Llama3 RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. 
\"\n                        \"All Llama3 parameters must be specified together.\"\n                    )\n\n            # YaRN-style RoPE\n            elif has_yarn:\n                # Required: factor, beta_fast, beta_slow, original_max_seq_len\n                # Optional: mscale, mscale_all_dim\n                yarn_required_params = [\"factor\", \"beta_fast\", \"beta_slow\", \"original_max_seq_len\"]\n                params_present = [param in self.config.rope_adjustments for param in yarn_required_params]\n\n                if not all(params_present):\n                    missing_params = [\n                        param for param, present in zip(yarn_required_params, params_present) if not present\n                    ]\n                    raise ValueError(\n                        f\"The following YaRN RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. \"\n                        \"All YaRN required parameters must be specified together.\"\n                    )\n\n                extra_config = {name: self.config.rope_adjustments[name] for name in yarn_required_params}\n\n                # Add optional YaRN parameters\n                for param in [\"mscale\", \"mscale_all_dim\"]:\n                    if param in self.config.rope_adjustments:\n                        extra_config[param] = self.config.rope_adjustments[param]\n\n            # Linear or standard RoPE\n            elif \"factor\" in self.config.rope_adjustments:\n                # linear RoPE\n                adjusted_params_required = [\"factor\"]\n                extra_config = {name: self.config.rope_adjustments[name] for name in adjusted_params_required}\n            else:\n                extra_config = None  # uses standard RoPE\n\n        return build_rope_cache(\n            seq_len=self.max_seq_length,\n            n_elem=self.config.rope_n_elem,\n            device=device,\n            condense_ratio=self.config.rope_condense_ratio,\n            
base=self.config.rope_base,\n            extra_config=extra_config,\n            rope_local_base_freq=self.config.rope_local_base_freq,\n        )\n\n    def rope_cache_length(self) -> int:\n        \"\"\"\n        Extract the head dimension (n_elem) from RoPE cache regardless of shape.\n\n        The RoPE cache can have different shapes depending on model configuration:\n        - Standard RoPE: (seq_len, n_elem) - 2D tensor\n        - Dual RoPE (local/global): (seq_len, n_elem, 2) - 3D tensor\n\n        Returns:\n            int: n_elem (head dimension for RoPE)\n        \"\"\"\n        return self.cos.size(1)\n\n    def set_kv_cache(\n        self,\n        batch_size: int,\n        max_seq_length: Optional[int] = None,\n        rope_cache_length: Optional[int] = None,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ) -> None:\n        if rope_cache_length is None:\n            rope_cache_length = self.rope_cache_length()\n\n        if max_seq_length is None:\n            max_seq_length = self.max_seq_length\n\n        # initialize the kv cache for all blocks\n        for block in self.transformer.h:\n            block.attn.kv_cache = block.attn.build_kv_cache(\n                batch_size,\n                max_seq_length,\n                rope_cache_length,\n                device,\n                dtype,\n            )\n\n        if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:\n            # passing `attn_mask` to SDPA disables the flash implementation. 
since we only need the mask\n            # for the kv-cache support (only during inference), we only create it in that situation\n            self.mask_cache = build_mask_cache(max_seq_length, device)\n\n    def clear_kv_cache(self) -> None:\n        self.mask_cache = None\n        for block in self.transformer.h:\n            block.attn.kv_cache = None\n\n\nclass Block(nn.Module):\n    def __init__(\n        self,\n        config: Config,\n        block_idx: int,\n    ) -> None:\n        super().__init__()\n        if not config.parallel_residual and config.shared_attention_norm:\n            raise NotImplementedError(\n                \"No checkpoint amongst the ones we support uses this configuration\"\n                \" (non-parallel residual and shared attention norm).\"\n            )\n\n        self.norm_1 = nn.Identity() if not config.norm_1 else config.norm_class(config.n_embd, eps=config.norm_eps)\n        self.attn = (\n            CausalSelfAttention(config, block_idx)\n            if not config.latent_attention\n            else MultiheadLatentAttention(config, block_idx)\n        )\n        self.post_attention_norm = (\n            config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()\n        )\n        self.norm_2 = (\n            nn.Identity()\n            if not config.norm_2\n            else (None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps))\n        )\n        self.mlp = config.mlp_class(config)\n        if config.first_k_dense_replace is not None and block_idx < config.first_k_dense_replace:\n            self.mlp = LLaMAMLP(config)\n        self.post_mlp_norm = (\n            config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()\n        )\n\n        self.config = config\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        mask: 
Optional[torch.Tensor] = None,\n        input_pos: Optional[torch.Tensor] = None,\n        input_pos_maxp1: Optional[int] = None,\n    ) -> torch.Tensor:\n        \"\"\"\n        Non-parallel residual       Parallel residual\n           ┌─ x                     ┌─ x ──────────────────┐             Note: if `shared_attention_norm` is True,\n           │  ↓                     │  ↓                   ↓                   the output from `norm_1` is reused\n           │  norm_1                │  norm_1  ───────►    norm_2\n           │  ↓                     │  ↓                   ↓\n           │  attn                  │  attn                MLP\n           │  ↓                     │  ↓                   ↓\n           |  post_attn_norm        |  post_attn_norm      post_mlp_norm\n           |  ↓                     |  ↓                   ↓\n        ┌─ └► +                     └► + ◄─────────────────┘\n        |     ↓\n        │     norm_2\n        │     ↓\n        │     MLP\n        │     ↓\n        |     post_mlp_norm\n        |     ↓\n        └───► +\n        \"\"\"\n\n        x_normed = self.norm_1(x)\n        attention_output = self.attn(x_normed, cos, sin, mask, input_pos, input_pos_maxp1)\n        attention_output = self.post_attention_norm(attention_output)\n\n        if self.config.parallel_residual:\n            if not self.config.shared_attention_norm:\n                x_normed = self.norm_2(x)\n            x = attention_output + x\n        else:\n            x = attention_output + x\n            x_normed = self.norm_2(x)\n\n        return self.post_mlp_norm(self.mlp(x_normed)) + x\n\n\nclass CausalSelfAttention(nn.Module):\n    def __init__(self, config: Config, block_idx: int) -> None:\n        super().__init__()\n        # key, query and value projections for all heads, but in a batch\n        self.qkv = nn.Linear(\n            config.n_embd,\n            (config.n_head + 2 * config.n_query_groups) * config.head_size,  # support for grouped/multi queries\n  
          bias=config.bias or config.attn_bias,\n        )\n        # output projection\n        self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)\n        # disabled by default\n        self.kv_cache: Optional[KVCache] = None\n        self.apply_sliding_window_attention = False\n        if config.sliding_window_size is not None and config.sliding_window_indices is not None:\n            self.apply_sliding_window_attention = config.sliding_window_indices[block_idx]\n\n        if config.norm_qk:\n            norm_q_size = config.n_head * config.head_size if config.norm_qk_type == \"olmo2\" else config.head_size\n            norm_k_size = (\n                config.n_query_groups * config.head_size if config.norm_qk_type == \"olmo2\" else config.head_size\n            )\n            self.norm_q = config.norm_class(norm_q_size, eps=config.norm_eps)\n            self.norm_k = config.norm_class(norm_k_size, eps=config.norm_eps)\n        else:\n            self.norm_q = self.norm_k = None\n\n        if config.rope_adjustments is not None:\n            mscale_all_dim = config.rope_adjustments.get(\"mscale_all_dim\", None)\n            scaling_factor = config.rope_adjustments.get(\"factor\", None)\n            if mscale_all_dim and scaling_factor:  # YaRN\n                self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)\n            else:\n                self.mscale = 1.0\n        else:\n            self.mscale = 1.0\n\n        self.config = config\n        self.block_idx = block_idx\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        mask: Optional[torch.Tensor] = None,\n        input_pos: Optional[torch.Tensor] = None,\n        input_pos_maxp1: Optional[int] = None,\n    ) -> torch.Tensor:\n        # Notation:\n        # - B          | batch size\n        # - T          | time-step (sequence length)\n        # - C          | model's embeddings 
size (n_embd)\n        # - C*         | attentions's embeddings size\n        # - hs         | head size\n        # - nh_(q,k,v) | number of heads for query, key and value\n        # - n_query_groups = nh_k = nh_v | number of query groups sharing key and value heads\n        # alternative notation: num_kv_groups = n_query_groups\n        # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐\n        # │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │\n        # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘\n        #   │    │    │    │         │        │                 │\n        # ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐\n        # │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │\n        # └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘\n        #   │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐\n        # ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐\n        # │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │\n        # └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘\n        # ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶\n        #         MHA                    GQA                   MQA\n        #   n_query_groups=4       n_query_groups=2      n_query_groups=1\n        #\n        # credit https://arxiv.org/pdf/2305.13245.pdf\n        head_size = self.config.head_size\n        n_head = self.config.n_head\n        n_query_groups = self.config.n_query_groups\n        rope_n_elem = self.config.rope_n_elem\n        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)\n\n        # Perform a single multiplication operation using a combined QKV matrix to calculate `query`, `key`, and `value`\n        # instead of individually multiplying the input `x` with the respective weight matrices.\n        qkv = self.qkv(x)  # (B, T, 3xC*)\n\n        # Define query, key and value sizes.\n        # If 
grouped/multi query is enabled, these sizes are not equal (see the diagram above).\n        query_size = n_head * head_size\n        key_size = value_size = n_query_groups * head_size\n        # Split qkv into query, key and value matrices.\n        q, k, v = qkv.split((query_size, key_size, value_size), dim=-1)  # 3x(B, T, C*)\n\n        if self.config.norm_qk and self.config.norm_qk_type == \"olmo2\":\n            q = self.norm_q(q)\n            k = self.norm_k(k)\n\n        # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the\n        # embedding size (C) into num_heads (nh) and head_size (hs).\n\n        # The original GQA paper is followed here and the term query groups is used.\n        # alternative notation: Query groups are also referred to as KV groups.\n        q = q.view(B, T, n_head, head_size)  # (B, T, nh_q, hs)\n        k = k.view(B, T, n_query_groups, head_size)  # (B, T, n_query_groups, hs)\n        v = v.view(B, T, n_query_groups, head_size)  # (B, T, n_query_groups, hs)\n\n        # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are\n        # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector\n        # of size `hs`.\n        q = q.transpose(1, 2)  # (B, nh_q, T, hs)\n        k = k.transpose(1, 2)  # (B, nh_k, T, hs)\n        v = v.transpose(1, 2)  # (B, nh_v, T, hs)\n\n        if self.config.norm_qk and self.config.norm_qk_type == \"default\":\n            q = self.norm_q(q)\n            k = self.norm_k(k)\n\n        # Unlike standard positional embeddings rotary embeddings must be applied at every layer.\n        if self.config.rope_interleave:\n            q_roped = apply_rope_interleave(q[..., :rope_n_elem], cos, sin)\n            k_roped = apply_rope_interleave(k[..., :rope_n_elem], cos, sin)\n        else:\n            q_roped = apply_rope(q[..., :rope_n_elem], 
cos, sin)\n            k_roped = apply_rope(k[..., :rope_n_elem], cos, sin)\n        q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1)  # (B, nh_q, T, hs)\n        k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1)  # (B, nh_k, T, hs)\n\n        # Apply kv-cache during inference.\n        if input_pos is not None:\n            if not isinstance(self.kv_cache, KVCache):\n                raise TypeError(\"You need to call `gpt.set_kv_cache()`\")\n            k, v = self.kv_cache(input_pos, k, v)\n\n            if self.apply_sliding_window_attention:\n                actual_kv_len = k.size(2)\n                if mask is not None and mask.size(-1) != actual_kv_len:\n                    mask = mask[..., :actual_kv_len]\n\n            if input_pos_maxp1 is not None:\n                # Subselect along sequence dimension\n                k = k[..., :input_pos_maxp1, :]\n                v = v[..., :input_pos_maxp1, :]\n            # k, v: (B, nh_k, input_pos_maxp1, hs)\n            # If input_pos_maxp1 is None -> max_seq_length\n\n        # Grouped queries: balance the number of heads across all three matrices.\n        # NOTE: flash attention requires it in training mode.\n        # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting.\n        if n_query_groups != n_head and (input_pos is None or n_query_groups != 1):\n            q_per_kv = n_head // n_query_groups\n            k = k.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)\n            v = v.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)\n\n        if self.apply_sliding_window_attention:\n            \"\"\"\n                  Global Window              Sliding window             Sliding window\n                  attention mask      +            bias          =      attention mask\n            ┌────────────────────────┐  ┌───────────────────────┐  ┌─────────────────────────┐\n            │ True False False False │  │ True  True  True True │  │ 
True  False False False │\n            │ True True  False False │  │ True  True  True True │  │ True  True  False False │\n            │ True True  True  False │  │ False True  True True │  │ False True  True  False │\n            │ True True  True  True  │  │ False False True True │  │ False False True  True  │\n            └────────────────────────┘  └───────────────────────┘  └─────────────────────────┘\n            \"\"\"\n            if input_pos is None:\n                if mask is None:\n                    mask = torch.ones(T, T, dtype=q.dtype, device=q.device).triu(diagonal=1)\n                    mask.masked_fill_(mask.bool(), float(\"-inf\"))\n                    mask = mask.view(1, 1, *mask.shape)\n\n                sliding_window_mask = torch.full((T, T), float(\"-inf\"), dtype=q.dtype, device=q.device)\n                for i in range(T):\n                    window_start = max(0, i - self.config.sliding_window_size + 1)\n                    sliding_window_mask[i, window_start : i + 1] = 0.0\n                sliding_window_mask = sliding_window_mask.view(1, 1, T, T)\n                mask = sliding_window_mask\n\n        # Efficient attention using Flash Attention CUDA kernels.\n        # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.\n        # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)\n        y = self.scaled_dot_product_attention(q, k, v, mask)\n\n        # Re-assemble all head outputs side by side.\n        y = y.reshape(B, T, head_size * n_head)\n\n        # Output projection.\n        return self.proj(y)  # (B, T, C)\n\n    def scaled_dot_product_attention(\n        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None\n    ) -> torch.Tensor:\n        scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size)\n        scale = scale * self.mscale * self.mscale\n\n        # with softcapping 
we cannot use SDPA\n        if self.config.attention_logit_softcapping is not None:\n            scores = q @ k.mT * scale\n            scores = do_softcapping(scores, self.config.attention_logit_softcapping)\n            if mask is None:\n                mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)\n                mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)\n            scores = scores + mask\n            scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype)\n            y = scores @ v\n        else:\n            y = F.scaled_dot_product_attention(\n                q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None\n            )\n        return y.transpose(1, 2)\n\n    def build_kv_cache(\n        self,\n        batch_size: int,\n        max_seq_length: int,\n        rope_cache_length: Optional[int] = None,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ) -> \"KVCache\":\n        if self.apply_sliding_window_attention and self.config.sliding_window_size is not None:\n            effective_cache_size = min(max_seq_length, self.config.sliding_window_size)\n        else:\n            effective_cache_size = max_seq_length\n\n        v_shape = (batch_size, self.config.n_query_groups, effective_cache_size, self.config.head_size)\n\n        if rope_cache_length is None:\n            if self.config.rotary_percentage != 1.0:\n                raise TypeError(\n                    \"Please pass the `rope_cache_length` parameter. 
\"\n                    \"Use `rope_cache_length=model.rope_cache_length()` to extract it automatically.\"\n                )\n            k_shape = v_shape\n        else:\n            k_shape = (\n                batch_size,\n                self.config.n_query_groups,\n                effective_cache_size,\n                rope_cache_length + self.config.head_size - self.config.rope_n_elem,\n            )\n\n        return KVCache(\n            k_shape,\n            v_shape,\n            device=device,\n            dtype=dtype,\n            is_sliding_window=self.apply_sliding_window_attention,\n            sliding_window_size=self.config.sliding_window_size if self.apply_sliding_window_attention else None,\n        )\n\n    def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any) -> None:\n        \"\"\"For compatibility with legacy checkpoints.\"\"\"\n\n        for attr in (\"weight\", \"bias\"):\n            legacy_key = f\"{prefix}attn.{attr}\"\n            current_key = f\"{prefix}qkv.{attr}\"\n            if legacy_key in state_dict:\n                state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config)\n\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n\nclass MultiheadLatentAttention(nn.Module):\n    def __init__(self, config: Config, block_idx: int) -> None:\n        super().__init__()\n\n        self.q_a_proj = nn.Linear(config.n_embd, config.q_lora_rank, bias=config.attn_bias)\n        self.q_a_norm = RMSNorm(config.q_lora_rank, eps=config.norm_eps)\n        self.q_b_proj = nn.Linear(config.q_lora_rank, config.n_head * config.qk_head_dim, bias=config.bias)\n\n        self.kv_a_proj_with_mqa = nn.Linear(\n            config.n_embd, config.kv_lora_rank + config.qk_rope_head_dim, bias=config.attn_bias\n        )\n        self.kv_a_norm = RMSNorm(config.kv_lora_rank, eps=config.norm_eps)\n        self.kv_b_proj = nn.Linear(\n            config.kv_lora_rank,\n         
   config.n_query_groups * (config.qk_nope_head_dim + config.v_head_dim),\n            bias=config.bias,\n        )\n\n        # output projection\n        self.proj = nn.Linear(config.n_head * config.v_head_dim, config.n_embd, bias=config.bias)\n        # disabled by default\n        self.kv_cache: Optional[KVCache] = None\n\n        if config.rope_adjustments is not None:\n            mscale_all_dim = config.rope_adjustments.get(\"mscale_all_dim\", None)\n            scaling_factor = config.rope_adjustments.get(\"factor\", None)\n            if mscale_all_dim and scaling_factor:  # YaRN\n                self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)\n            else:\n                self.mscale = 1.0\n        else:\n            self.mscale = 1.0\n\n        self.config = config\n        self.block_idx = block_idx\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        mask: Optional[torch.Tensor] = None,\n        input_pos: Optional[torch.Tensor] = None,\n        input_pos_maxp1: Optional[int] = None,\n    ) -> torch.Tensor:\n        # Notation:\n        # - B          | batch size\n        # - T          | time-step (sequence length)\n        # - C          | model's embeddings size (n_embd)\n        # - C*         | attentions's embeddings size\n        # - hs         | head size\n        # - nh_(q,k,v) | number of heads for query, key and value\n        # - n_query_groups = nh_k = nh_v | number of query groups sharing key and value heads\n        # alternative notation: num_kv_groups = n_query_groups\n        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)\n\n        q = self.q_b_proj(self.q_a_norm(self.q_a_proj(x)))  # (B, T, n_head * qk_head_dim)\n        q = q.view(B, T, -1, self.config.qk_head_dim)  # (B, T, n_head, qk_head_dim)\n        q = q.transpose(1, 2)  # (B, n_head, T, qk_head_dim)\n        q_pass, q_rot = torch.split(q, 
[self.config.qk_nope_head_dim, self.config.qk_rope_head_dim], dim=-1)\n\n        compressed_kv = self.kv_a_proj_with_mqa(x)  # (B, T, kv_lora_rank + qk_rope_head_dim)\n        k_pass, k_rot = torch.split(compressed_kv, [self.config.kv_lora_rank, self.config.qk_rope_head_dim], dim=-1)\n\n        k_pass = self.kv_b_proj(self.kv_a_norm(k_pass))\n        k_pass = k_pass.view(B, T, self.config.n_query_groups, -1)\n        k_pass = k_pass.transpose(1, 2)\n\n        k_pass, v = torch.split(k_pass, [self.config.qk_nope_head_dim, self.config.v_head_dim], dim=-1)\n        k_rot = k_rot.view(B, 1, T, self.config.qk_rope_head_dim)  # (B, 1, T, qk_rope_head_dim)\n\n        # Unlike standard positional embeddings rotary embeddings must be applied at every layer.\n        if self.config.rope_interleave:\n            q_roped = apply_rope_interleave(q_rot, cos, sin)\n            k_roped = apply_rope_interleave(k_rot, cos, sin)\n        else:\n            q_roped = apply_rope(q_rot, cos, sin)\n            k_roped = apply_rope(k_rot, cos, sin)\n        k_roped = k_roped.expand(*k_pass.shape[:-1], -1)  # (B, n_head, T, qk_rope_head_dim)\n\n        q = torch.cat((q_pass, q_roped), dim=-1)\n        k = torch.cat((k_pass, k_roped), dim=-1)\n\n        # Apply kv-cache during inference.\n        if input_pos is not None:\n            if not isinstance(self.kv_cache, KVCache):\n                raise TypeError(\"You need to call `gpt.set_kv_cache()`\")\n            k, v = self.kv_cache(input_pos, k, v)\n            if input_pos_maxp1 is not None:\n                # Subselect along sequence dimension\n                k = k[..., :input_pos_maxp1, :]\n                v = v[..., :input_pos_maxp1, :]\n            # k, v: (B, nh_k, input_pos_maxp1, hs)\n            # If input_pos_maxp1 is None -> max_seq_length\n\n        # Grouped queries: balance the number of heads across all three matrices.\n        # NOTE: flash attention requires it in training mode.\n        # Multi-query: this step can be 
skipped since there is only 1 head, allowing us to use broadcasting.\n        if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):\n            q_per_kv = self.config.n_head // self.config.n_query_groups\n            k = k.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)\n            v = v.repeat_interleave(q_per_kv, dim=1)  # (B, nh_q, T, hs)\n\n        # Efficient attention using Flash Attention CUDA kernels.\n        # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.\n        # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)\n        y = self.scaled_dot_product_attention(q, k, v, mask)\n\n        # Re-assemble all head outputs side by side.\n        y = y.reshape(B, T, self.config.n_head * self.config.v_head_dim)\n\n        # Output projection.\n        return self.proj(y)  # (B, T, C)\n\n    def scaled_dot_product_attention(\n        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None\n    ) -> torch.Tensor:\n        scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.qk_head_dim)\n        scale = scale * self.mscale * self.mscale\n\n        # with softcapping we cannot use SDPA\n        if self.config.attention_logit_softcapping is not None:\n            scores = q @ k.mT * scale\n            scores = do_softcapping(scores, self.config.attention_logit_softcapping)\n            if mask is None:\n                mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)\n                mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)\n            scores = scores + mask\n            scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype)\n            y = scores @ v\n        else:\n            y = F.scaled_dot_product_attention(\n                q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, 
is_causal=mask is None\n            )\n        return y.transpose(1, 2)\n\n    def build_kv_cache(\n        self,\n        batch_size: int,\n        max_seq_length: int,\n        rope_cache_length: Optional[int] = None,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ) -> \"KVCache\":\n        v_shape = (batch_size, self.config.n_head, max_seq_length, self.config.v_head_dim)\n        k_shape = (batch_size, self.config.n_head, max_seq_length, self.config.qk_head_dim)\n\n        if rope_cache_length is not None:\n            print(\"Warning: `rope_cache_length` has no effect on MultiheadLatentAttention!\")\n        if self.config.rotary_percentage != 1.0:\n            print(\"Warning: `rotary_percentage` has no effect on MultiheadLatentAttention!\")\n\n        return KVCache(k_shape, v_shape, device=device, dtype=dtype)\n\n\nclass GptNeoxMLP(nn.Module):\n    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:\n        super().__init__()\n        self.intermediate_size = intermediate_size or config.intermediate_size\n        self.fc = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)\n        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)\n        self.config = config\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.fc(x)\n        x = F.gelu(x, approximate=self.config.gelu_approximate)\n        return self.proj(x)\n\n\nclass LLaMAMLP(nn.Module):\n    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:\n        super().__init__()\n        self.intermediate_size = intermediate_size or config.intermediate_size\n        self.fc_1 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)\n        self.fc_2 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)\n        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)\n        
self.config = config\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x_fc_1 = self.fc_1(x)\n        x_fc_2 = self.fc_2(x)\n        x = F.silu(x_fc_1) * x_fc_2\n        return self.proj(x)\n\n\nclass GemmaMLP(LLaMAMLP):\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x_fc_1 = self.fc_1(x)\n        x_fc_2 = self.fc_2(x)\n        x = F.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2\n        return self.proj(x)\n\n\nclass LLaMAMoE(nn.Module):\n    def __init__(self, config: Config) -> None:\n        super().__init__()\n        self.gate = (\n            nn.Linear(config.n_embd, config.n_expert, bias=False)\n            if not config.n_expert_groups\n            else GroupedTopkRouter(config)\n        )\n        self.experts = nn.ModuleList(\n            LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)\n        )\n        if config.n_shared_expert:\n            self.shared_experts = LLaMAMLP(\n                config, intermediate_size=config.moe_intermediate_size * config.n_shared_expert\n            )\n        self.config = config\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219\n        See also figure 1 in https://arxiv.org/abs/2211.15841\n        \"\"\"\n        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)\n        residual_x = x.clone()\n        x = x.view(-1, C)  # (B*T, C)\n        if not self.config.n_expert_groups:\n            router = self.gate(x)  # (B*T, n_expert)\n            probs, indices = torch.topk(router, self.config.n_expert_per_token)  # (B*T, n_expert_per_token)\n            probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)\n        else:\n            probs, indices = self.gate(x)\n        if self.config.routed_scaling_factor != 1.0:\n            probs = probs * 
self.config.routed_scaling_factor\n        masks = indices.unsqueeze(-1) == torch.arange(self.config.n_expert, device=x.device)\n        masks = masks.permute(2, 0, 1)  # (n_expert, B*T, n_expert_per_token)\n        y = torch.zeros_like(x)  # (B*T, C)\n        for mask, expert in zip(masks, self.experts):\n            token_idx, expert_idx = torch.where(mask)\n            y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])\n\n        y = y.view(B, T, C)\n        if self.config.n_shared_expert:\n            y = y + self.shared_experts(residual_x)\n        return y\n\n\nclass GroupedTopkRouter(nn.Module):\n    \"\"\"\n    Derived from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py.\n    DeepseekV3TopkRouter class.\n    \"\"\"\n\n    def __init__(self, config: Config) -> None:\n        super().__init__()\n        self.config = config\n        self.weight = nn.Parameter(torch.empty(config.n_expert, config.n_embd))\n        self.register_buffer(\"e_score_correction_bias\", torch.zeros(config.n_expert))\n\n    @torch.no_grad()\n    def get_topk_indices(self, scores: torch.Tensor) -> torch.Tensor:\n        scores_for_choice = scores.view(-1, self.config.n_expert) + self.e_score_correction_bias.unsqueeze(0)\n        group_scores = (\n            scores_for_choice.view(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups)\n            .topk(self.config.n_topk_scores_per_group, dim=-1)[0]  # Top k scores for each group\n            .sum(dim=-1)\n        )\n\n        group_idx = torch.topk(group_scores, k=self.config.n_topk_groups, dim=-1, sorted=False)[1]\n        group_mask = torch.zeros_like(group_scores)\n        group_mask.scatter_(1, group_idx, 1)\n        score_mask = (\n            group_mask.unsqueeze(-1)\n            .expand(-1, self.config.n_expert_groups, self.config.n_expert // self.config.n_expert_groups)\n            .reshape(-1, 
self.config.n_expert)\n        )\n        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)\n        topk_indices = torch.topk(scores_for_choice, k=self.config.n_expert_per_token, dim=-1, sorted=False)[1]\n        return topk_indices\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        router_logits = F.linear(x.type(torch.float32), self.weight.type(torch.float32))\n        scores = router_logits.sigmoid()\n        topk_indices = self.get_topk_indices(scores)\n        topk_weights = scores.gather(1, topk_indices)\n        if self.config.norm_topk_prob:\n            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20\n            topk_weights /= denominator\n        return topk_weights, topk_indices\n\n\n# ROPE: YaRN (Yet another RoPE extensioN) scaling function for extended context\ndef yarn_get_mscale(scale=1, mscale=1):\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\n\ndef build_rope_cache(\n    seq_len: int,\n    n_elem: int,\n    device: Optional[torch.device] = None,\n    base: int = 10000,\n    condense_ratio: int = 1,\n    extra_config: Optional[dict] = None,\n    rope_local_base_freq: Optional[float] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Enhanced Transformer with Rotary Position Embedding.\n\n    Args:\n        seq_len (int): Sequence length.\n        n_elem (int): Number of elements (head dimension).\n        device (torch.device, optional): Device for tensor allocations.\n        base (int, optional): Base for computing inverse frequencies.\n        condense_ratio (int, optional): Ratio to condense the position indices.\n        extra_config (dict, optional): Configuration parameters for frequency adjustments (used by Llama 3.1 and 3.2)\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE.\n            Shapes are `(seq_len, n_elem)`.\n    \"\"\"\n\n    # Compute the inverse frequencies theta\n    
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))\n\n    # Initialize attention scaling factor (modified for YaRN)\n    attention_scaling = 1.0\n\n    if extra_config is not None:\n        factor = extra_config[\"factor\"]\n        # Check YaRN first (has beta_fast/beta_slow)\n        if \"beta_fast\" in extra_config or \"beta_slow\" in extra_config:\n            # YaRN-style RoPE scaling\n            beta_fast = extra_config[\"beta_fast\"]\n            beta_slow = extra_config[\"beta_slow\"]\n            original_max_seq_len = extra_config[\"original_max_seq_len\"]\n\n            # Calculate attention scaling factor based on mscale and mscale_all_dim\n            mscale = extra_config.get(\"mscale\")\n            mscale_all_dim = extra_config.get(\"mscale_all_dim\")\n            if mscale and mscale_all_dim:\n                attention_scaling = yarn_get_mscale(factor, mscale) / yarn_get_mscale(factor, mscale_all_dim)\n            elif mscale_all_dim:\n                attention_scaling = yarn_get_mscale(factor, mscale_all_dim)\n            elif mscale:\n                attention_scaling = yarn_get_mscale(factor, mscale)\n            # else: attention_scaling remains 1.0\n\n            # Create two frequency sets: extrapolation (unscaled) and interpolation (scaled)\n            pos_freqs = base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)\n            theta_extrapolation = 1.0 / pos_freqs\n            theta_interpolation = 1.0 / (factor * pos_freqs)\n\n            # Find correction range based on rotation counts\n            # Inverse dimension formula to find dimension based on number of rotations\n            def find_correction_dim(num_rotations, dim, base_val, max_pos):\n                return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base_val))\n\n            low_dim = find_correction_dim(beta_fast, n_elem, base, original_max_seq_len)\n            high_dim = 
find_correction_dim(beta_slow, n_elem, base, original_max_seq_len)\n\n            # Apply truncation if specified\n            if extra_config.get(\"truncate\", True):\n                low_dim = math.floor(low_dim)\n                high_dim = math.ceil(high_dim)\n\n            low_dim = max(low_dim, 0)\n            high_dim = min(high_dim, n_elem // 2 - 1)\n\n            # Create linear ramp factor for blending\n            dim_range = torch.arange(n_elem // 2, device=device, dtype=torch.float32)\n            if low_dim == high_dim:\n                high_dim += 0.001  # Prevent singularity\n\n            linear_func = (dim_range - low_dim) / (high_dim - low_dim)\n            ramp_func = torch.clamp(linear_func, 0.0, 1.0)\n\n            # Blend extrapolation and interpolation frequencies\n            # ramp_func = 0 -> use interpolation (scaled), ramp_func = 1 -> use extrapolation (unscaled)\n            theta_extrapolation_factor = ramp_func\n            theta = (\n                theta_interpolation * (1 - theta_extrapolation_factor)\n                + theta_extrapolation * theta_extrapolation_factor\n            )\n        elif \"original_max_seq_len\" in extra_config:\n            # Llama3-style RoPE scaling\n            orig_context_len = extra_config[\"original_max_seq_len\"]\n            low_freq_factor = extra_config[\"low_freq_factor\"]\n            high_freq_factor = extra_config[\"high_freq_factor\"]\n\n            wavelen = 2 * torch.pi / theta\n            ratio = orig_context_len / wavelen\n            smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)\n            smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0)\n\n            # Compute adjusted_theta without masked indexing\n            adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta\n            theta = adjusted_theta\n        else:\n            # Linear scaling fallback\n            theta = theta / factor\n\n    # Create 
position indices `[0, 1, ..., seq_len - 1]`\n    seq_idx = torch.arange(seq_len, device=device).float() / condense_ratio\n\n    # Calculate the product of position index and $\\theta_i$\n    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)\n    # If `n_elem` is odd, the final dimension of `idx_theta` has size\n    # `n_elem + 1`, so need to cut something off.\n    # Due to a current bug in Hugging Face, in the case `n_elem == 1`, we leave\n    # `idx_theta`, `cos`, `sin` as is. Things work out in `apply_rope` due to\n    # broadcasting. If we shorten `idx_theta`, unit tests comparing to\n    # Hugging Face fail.\n    # https://github.com/huggingface/transformers/issues/35233\n    if idx_theta.shape[-1] > n_elem > 1:\n        idx_theta = idx_theta[..., :n_elem]\n\n    # if rope_local_base_freq is given, have a separate rope value for local embedding\n    # For now, we use default RoPE for local embedding\n    if rope_local_base_freq is not None:\n        local_theta = 1.0 / (rope_local_base_freq ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))\n        local_idx_theta = torch.outer(seq_idx, local_theta)\n        local_idx_theta = local_idx_theta.repeat(1, 2)\n        if local_idx_theta.shape[-1] > n_elem > 1:\n            local_idx_theta = local_idx_theta[..., :n_elem]\n\n        idx_theta = torch.stack((idx_theta, local_idx_theta), dim=-1)\n\n    cos = torch.cos(idx_theta) * attention_scaling\n    sin = torch.sin(idx_theta) * attention_scaling\n    return cos, sin\n\n\ndef batched_index_select(t, dim, idx):\n    \"\"\"index_select for batched index and unbatched t\"\"\"\n    if idx.dim() == 1:\n        return torch.index_select(t, dim, idx)\n\n    *batch_shape, idx_size = idx.shape\n    res = torch.index_select(t, dim, idx.reshape(-1))  # flat index\n    # split out single batch idx\n    res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :])\n    if dim > 0:\n        # move batch dim to front, this is np.rollaxis(res, dim, 0) for 
tensors\n        dims = [dim] + list(range(res.dim()))\n        del dims[dim + 1]\n        res = res.permute(dims)\n    # unflatten batch dims\n    res = res.view(*batch_shape, *res.shape[1:])\n    return res\n\n\ndef batched_index_copy_(t, dim, idx, val):\n    \"\"\"Index copy for batched t, idx, val\"\"\"\n\n    if t.device.type == \"mps\":\n        # Normalize negative dimensions\n        if dim < 0:\n            dim = t.dim() + dim\n        if idx.dim() == 1:\n            idx_shape = [1] * val.dim()\n            idx_shape[dim] = -1\n            idx_expanded = idx.view(*idx_shape)\n            idx_expanded = idx_expanded.expand_as(val)\n            t.scatter_(dim, idx_expanded, val)\n            return t\n\n        elif idx.dim() == 2:\n            assert dim != 0, \"Cannot index the batch dimension\"\n            batch_size = idx.size(0)\n            idx_size = idx.size(1)\n            assert batch_size == t.size(0) == val.size(0)\n\n            idx_shape = [batch_size] + [1] * (val.dim() - 1)\n            idx_shape[dim] = idx_size\n            idx_expanded = idx.view(*idx_shape)\n            idx_expanded = idx_expanded.expand_as(val)\n\n            t.scatter_(dim, idx_expanded, val)\n            return t\n        else:\n            raise NotImplementedError(f\"idx.dim() == {idx.dim()} not supported\")\n\n    else:\n        if idx.dim() == 1:\n            return t.index_copy_(dim, idx, val)\n\n        assert idx.dim() == 2, f\"multiple batch dims not yet {idx.shape=}\"\n        assert dim != 0, f\"cannot index batch dim {dim=}\"\n        batch_size, idx_size = idx.shape\n        assert batch_size == t.size(0)\n        assert batch_size == val.size(0)\n\n        # if we can view the batch and indexed dimensions together, we could\n        # do index trickery. 
This is, sadly, not the case for kvcache so we\n        # fall back to for loop\n        for i in range(batch_size):\n            unbatched_dim = dim if dim < 0 else dim - 1\n            t[i].index_copy_(unbatched_dim, idx[i], val[i])\n        return t\n\n\ndef apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Applies RoPE transform to `x`. Note that `cos`, `sin` need to have a batch\n    dimension.\n\n    Args:\n        x: Input tensor, `(B, ..., T, head_size)`\n        cos: Cached cosines, `(B, T, head_size)` or `(1, T, head_size)`\n        sin: Cached sines, `(B, T, head_size)` or `(1, T, head_size)`\n\n    Returns:\n        Encoded tensor, `(B, ..., T, head_size)`\n    \"\"\"\n    if cos.dim() != 3:\n        raise ValueError(f\"cos must be three-dimensional, but shape is {cos.shape}\")\n    if cos.shape != sin.shape:\n        raise ValueError(f\"cos, sin must have same shape, but cos.shape={cos.shape}, sin.shape={sin.shape}\")\n    head_size_half = x.size(-1) // 2\n    x1 = x[..., :head_size_half]  # (B, ..., T, head_size/2)\n    x2 = x[..., head_size_half:]  # (B, ..., T, head_size/2)\n    rotated = torch.cat((-x2, x1), dim=-1)  # (B, ..., T, head_size)\n    dims_diff = x.dim() - cos.dim()\n    if dims_diff > 0:\n        # Ensure that shapes of `x`, `cos`, `sin` align\n        new_shape = cos.shape[0:1] + (1,) * dims_diff + cos.shape[1:]\n        cos = cos.view(*new_shape)\n        sin = sin.view(*new_shape)\n\n    roped = (x * cos) + (rotated * sin)\n    return roped.to(dtype=x.dtype)\n\n\ndef apply_rope_interleave(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n    \"\"\"Apply rotary position embeddings with interleaved tensor layout.\n\n    This version rearranges the input tensor to group even/odd indices separately\n    before applying the standard RoPE rotation, matching HuggingFace's\n    apply_rotary_pos_emb_interleave behavior.\n\n    Args:\n        x: Input tensor of shape 
(..., seq_len, head_dim)\n        cos: Cosine component of shape (B, seq_len, head_dim) or (1, seq_len, head_dim)\n        sin: Sine component of shape (B, seq_len, head_dim) or (1, seq_len, head_dim)\n\n    Returns:\n        Tensor with RoPE applied, same shape as input\n    \"\"\"\n    if cos.dim() != 3:\n        raise ValueError(f\"cos must be three-dimensional, but shape is {cos.shape}\")\n    if cos.shape != sin.shape:\n        raise ValueError(f\"cos, sin must have same shape, but cos.shape={cos.shape}, sin.shape={sin.shape}\")\n\n    # Rearrange tensor to group even/odd indices: [x0,x1,x2,x3,...] -> [x0,x2,x4,...,x1,x3,x5,...]\n    *batch_dims, d = x.shape\n    x = x.view(*batch_dims, d // 2, 2).transpose(-1, -2).reshape(*batch_dims, d)\n\n    # Standard rotation logic (same as apply_rope)\n    head_size_half = x.size(-1) // 2\n    x1 = x[..., :head_size_half]\n    x2 = x[..., head_size_half:]\n    rotated = torch.cat((-x2, x1), dim=-1)\n\n    # Auto-detect dimension mismatch and reshape cos/sin\n    dims_diff = x.dim() - cos.dim()\n    if dims_diff > 0:\n        new_shape = cos.shape[0:1] + (1,) * dims_diff + cos.shape[1:]\n        cos = cos.view(*new_shape)\n        sin = sin.view(*new_shape)\n\n    roped = (x * cos) + (rotated * sin)\n    return roped.to(dtype=x.dtype)\n\n\ndef do_softcapping(x: torch.Tensor, thresh: float) -> torch.Tensor:\n    return torch.tanh(x / thresh) * thresh\n\n\nclass KVCache(nn.Module):\n    \"\"\"\n    Buffers `k`, `v` have shape\n    `(batch_size, n_query_groups, max_seq_length, head_size)`.\n    \"\"\"\n\n    def __init__(\n        self,\n        k_shape: Tuple[int, int, int, int],\n        v_shape: Tuple[int, int, int, int],\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n        is_sliding_window: bool = False,\n        sliding_window_size: Optional[int] = None,\n    ) -> None:\n        super().__init__()\n        self.register_buffer(\"k\", torch.zeros(k_shape, device=device, 
dtype=dtype), persistent=False)\n        self.register_buffer(\"v\", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)\n        self.is_sliding_window = is_sliding_window\n        self.sliding_window_size = sliding_window_size\n        self.max_cache_len = k_shape[2]\n\n    def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Writes new values `k` and `v` into the cache at the positions specified\n        by `input_pos` along the sequence dimension (`max_seq_length`). The batch\n        size of `k` and `v` (`bs`) must be smaller or equal to `KVCache` batch\n        size. Returns the full buffers, adjusted to the batch size `bs`.\n\n        Args:\n            input_pos: Position index, `(bs, T)` or `(T,)`\n            k: New values, `(bs, n_query_groups, T, head_size)`\n            v: New values, `(bs, n_query_groups, T, head_size)`\n\n        Returns:\n            k_full, v_full, `(bs, n_query_groups, max_seq_length, head_size)`\n\n        \"\"\"\n        # move the buffer to the activation dtype for when AMP is used\n        if self.k.dtype != k.dtype:\n            self.k = self.k.to(k.dtype)\n        if self.v.dtype != v.dtype:\n            self.v = self.v.to(v.dtype)\n        # update the cache\n        bs = k.size(0)\n        if self.is_sliding_window:\n            # Circular buffer for sliding window\n            cache_positions = input_pos % self.max_cache_len\n            k = batched_index_copy_(self.k[:bs, ...], -2, cache_positions, k)\n            v = batched_index_copy_(self.v[:bs, ...], -2, cache_positions, v)\n\n            max_pos = input_pos.max().item()\n            if max_pos < self.max_cache_len:\n                k = k[:, :, : max_pos + 1, :]\n                v = v[:, :, : max_pos + 1, :]\n        else:\n            # Standard KV cache (global attention)\n            k = batched_index_copy_(self.k[:bs, ...], -2, input_pos, k)\n            v = 
batched_index_copy_(self.v[:bs, ...], -2, input_pos, v)\n\n        return k, v\n\n    def reset_parameters(self) -> None:\n        torch.nn.init.zeros_(self.k)\n        torch.nn.init.zeros_(self.v)\n\n\ndef build_mask_cache(max_seq_length: int, device: Optional[torch.device] = None) -> torch.Tensor:\n    ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)\n    return torch.tril(ones).unsqueeze(0).unsqueeze(0)\n\n\nclass RMSNorm(torch.nn.Module):\n    \"\"\"Root Mean Square Layer Normalization.\n\n    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:\n    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.\n    \"\"\"\n\n    def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None:\n        super().__init__()\n        self.weight = torch.nn.Parameter(torch.ones(size))\n        self.eps = eps\n        self.dim = dim\n        self.add_unit_offset = add_unit_offset\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        dtype = x.dtype\n        x = x.float()\n        # NOTE: the original RMSNorm paper implementation is not equivalent\n        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)\n        x_normed = x * torch.rsqrt(norm_x + self.eps)\n        weight = (1 + self.weight) if self.add_unit_offset else self.weight\n        return (x_normed * weight.float()).to(dtype=dtype)\n\n    def reset_parameters(self) -> None:\n        torch.nn.init.ones_(self.weight)\n"
  },
  {
    "path": "litgpt/parser_config.py",
    "content": "import sys\nfrom pathlib import Path\nfrom typing import List, Optional\n\nfrom litgpt.utils import CLI\n\n\ndef parser_commands() -> List[str]:\n    return [\n        \"download\",\n        \"chat\",\n        \"finetune\",\n        \"finetune_lora\",\n        \"finetune_full\",\n        \"finetune_adapter\",\n        \"finetune_adapter_v2\",\n        \"pretrain\",\n        \"generate\",\n        \"generate_full\",\n        \"generate_adapter\",\n        \"generate_adapter_v2\",\n        \"generate_sequentially\",\n        \"generate_speculatively\",\n        \"generate_tp\",\n        \"convert_to_litgpt\",\n        \"convert_from_litgpt\",\n        \"convert_pretrained_checkpoint\",\n        \"merge_lora\",\n        \"evaluate\",\n        \"serve\",\n    ]\n\n\ndef save_hyperparameters(\n    function: callable,\n    checkpoint_dir: Path,\n    known_commands: Optional[List[str]] = None,\n) -> None:\n    \"\"\"Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint.\"\"\"\n    from jsonargparse import capture_parser\n\n    # TODO: Make this more robust\n    # This hack strips away the subcommands from the top-level CLI\n    # to parse the file as if it was called as a script\n    if known_commands is None:\n        known_commands = parser_commands()\n    known_commands = [(c,) for c in known_commands]\n    for known_command in known_commands:\n        unwanted = slice(1, 1 + len(known_command))\n        if tuple(sys.argv[unwanted]) == known_command:\n            sys.argv[unwanted] = []\n\n    parser = capture_parser(lambda: CLI(function))\n    config = parser.parse_args()\n    parser.save(config, checkpoint_dir / \"hyperparameters.yaml\", overwrite=True)\n"
  },
  {
    "path": "litgpt/pretrain.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport math\nimport pprint\nimport time\nimport warnings\nfrom dataclasses import asdict\nfrom datetime import timedelta\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Dict, Optional, Tuple, Union\n\nimport lightning as L\nimport torch\nimport torch.nn as nn\nfrom lightning.fabric.strategies import FSDPStrategy\nfrom lightning.fabric.utilities.throughput import ThroughputMonitor, measure_flops\nfrom torch.utils.data import DataLoader\nfrom torchmetrics.aggregation import RunningMean\nfrom typing_extensions import Literal\n\nfrom litgpt import Tokenizer\nfrom litgpt.args import EvalArgs, LogArgs, TrainArgs\nfrom litgpt.config import name_to_config\nfrom litgpt.constants import _TORCH_EQUAL_2_7, _TORCH_EQUAL_2_8\nfrom litgpt.data import DataModule, TinyLlama\nfrom litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP\nfrom litgpt.parser_config import save_hyperparameters\nfrom litgpt.types import LoggerChoice\nfrom litgpt.utils import (\n    CycleIterator,\n    capture_hparams,\n    check_nvlink_connectivity,\n    choose_logger,\n    chunked_cross_entropy,\n    copy_config_files,\n    extend_checkpoint_dir,\n    find_resume_path,\n    get_default_supported_precision,\n    init_out_dir,\n    instantiate_torch_optimizer,\n    num_parameters,\n    parse_devices,\n    reset_parameters,\n    save_config,\n)\n\n\ndef setup(\n    model_name: str,\n    model_config: Optional[Config] = None,\n    out_dir: Path = Path(\"out/pretrain\"),\n    precision: Literal[\"bf16-true\", \"bf16-mixed\", \"32-true\", None] = None,\n    initial_checkpoint_dir: Optional[Path] = None,\n    resume: Union[bool, Literal[\"auto\"], Path] = False,\n    data: Optional[DataModule] = None,\n    train: TrainArgs = TrainArgs(\n        save_interval=1000,\n        log_interval=1,\n        global_batch_size=512,\n        micro_batch_size=4,\n        
max_tokens=int(3e12),  # 3 trillion\n        max_norm=1.0,\n        min_lr=4e-5,\n        lr_warmup_steps=2000,\n        tie_embeddings=False,\n    ),\n    eval: EvalArgs = EvalArgs(interval=1000, max_iters=100),\n    log: LogArgs = LogArgs(),\n    optimizer: Union[str, Dict] = \"AdamW\",\n    devices: Union[int, str] = \"auto\",\n    num_nodes: int = 1,\n    tokenizer_dir: Optional[Path] = None,\n    logger_name: LoggerChoice = \"tensorboard\",\n    seed: int = 42,\n):\n    \"\"\"Pretrain a model.\n\n    Arguments:\n        model_name: The name of the model to pretrain. Choose from names in ``litgpt.config``. Use \"list\" to list the supported models.\n        model_config: A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with\n            ``model_name``. Overrides the `model_name` if specified.\n        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in\n            /teamspace/jobs/<job-name>/share.\n        precision: The precision to use for pretraining. Determines a compatible precision setting by default.\n        initial_checkpoint_dir: Optional path to a checkpoint directory to initialize the model from.\n            Useful for continued pretraining. Mutually exclusive with ``resume``.\n        resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume\n            from the latest checkpoint in ``out_dir``. An error will be raised if no checkpoint is found. Passing\n            ``'auto'`` will resume from the latest checkpoint but not error if no checkpoint exists.\n        data: Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.\n        train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.\n        eval: Evaluation-related arguments. 
See ``litgpt.args.EvalArgs`` for details.\n        log: Logging-related arguments. See ``litgpt.args.LogArgs`` for details.\n        optimizer: An optimizer name (such as \"AdamW\") or config.\n\n        devices: How many devices/GPUs to use. Uses all GPUs by default.\n        num_nodes: How many nodes the code is being run on.\n        tokenizer_dir: Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data\n            modules require this.\n        logger_name: The name of the logger to send metrics to.\n        seed: The random seed to use for reproducibility.\n    \"\"\"\n    if model_name == \"list\":\n        available_models = \"\\n\".join(sorted(name_to_config))\n        print(f\"Available values:\\n{available_models}\")\n        quit()\n\n    if initial_checkpoint_dir is not None:\n        initial_checkpoint_dir = extend_checkpoint_dir(initial_checkpoint_dir)\n\n    if tokenizer_dir is not None:\n        tokenizer_dir = extend_checkpoint_dir(tokenizer_dir)\n\n    if model_config is None:\n        # Support both model_name options: meta-llama/Meta-Llama-3-8B & Meta-Llama-3-8B\n        try:\n            model_config = Config.from_name(model_name)\n        except ValueError:\n            print(f\"Model name {model_name} is not supported.\\n\")\n            available_models = \"\\n\".join(sorted(name_to_config))\n            print(f\"Available values:\\n{available_models}\")\n            quit()\n\n    hparams = capture_hparams()\n    data = TinyLlama() if data is None else data\n\n    config = Config.from_name(model_name) if model_config is None else model_config\n    precision = precision or get_default_supported_precision(training=True)\n    devices = parse_devices(devices)\n    out_dir = init_out_dir(out_dir)\n    # in case the dataset requires the Tokenizer\n    tokenizer = Tokenizer(tokenizer_dir) if tokenizer_dir is not None else None\n\n    logger = choose_logger(\n        logger_name,\n        out_dir,\n        name=f\"pretrain-{config.name}\",\n        resume=bool(resume),\n        
log_interval=train.log_interval,\n        log_args=asdict(log),\n    )\n\n    if devices * num_nodes > 1:\n        strategy = FSDPStrategy(auto_wrap_policy={Block}, state_dict_type=\"full\", sharding_strategy=\"HYBRID_SHARD\")\n    else:\n        strategy = \"auto\"\n\n    fabric = L.Fabric(devices=devices, num_nodes=num_nodes, strategy=strategy, precision=precision, loggers=[logger])\n\n    if torch.cuda.is_available() and devices > 1:\n        check_nvlink_connectivity(fabric)\n\n    fabric.launch()\n\n    fabric.print(pprint.pformat(hparams))\n    if logger_name in (\"tensorboard\", \"wandb\", \"mlflow\"):\n        fabric.logger.log_hyperparams(hparams)\n\n    main(\n        fabric=fabric,\n        devices=devices,\n        num_nodes=num_nodes,\n        seed=seed,\n        initial_checkpoint_dir=initial_checkpoint_dir,\n        resume=resume,\n        config=config,\n        data=data,\n        out_dir=out_dir,\n        tokenizer_dir=tokenizer_dir,\n        tokenizer=tokenizer,\n        train=train,\n        eval=eval,\n        optimizer=optimizer,\n    )\n\n\ndef main(\n    fabric: L.Fabric,\n    devices: int,\n    seed: int,\n    initial_checkpoint_dir: Optional[Path],\n    resume: Union[bool, Literal[\"auto\"], Path],\n    config: Config,\n    data: DataModule,\n    out_dir: Path,\n    tokenizer_dir: Optional[Path],\n    tokenizer: Optional[Tokenizer],\n    train: TrainArgs,\n    eval: EvalArgs,\n    optimizer: Union[str, Dict],\n    num_nodes: int = 1,\n) -> None:\n    validate_args(train, eval, initial_checkpoint_dir, resume)\n\n    if fabric.global_rank == 0:\n        out_dir.mkdir(parents=True, exist_ok=True)\n\n    fabric.seed_everything(seed)  # same seed for every process to init model (FSDP)\n\n    t0 = time.perf_counter()\n    with fabric.init_module(empty_init=True):\n        model = GPT(config)\n\n    initialize_weights(fabric, model, n_layer=config.n_layer, n_embd=config.n_embd)\n\n    if train.tie_embeddings:\n        model.transformer.wte.weight 
= model.lm_head.weight\n    if train.max_seq_length:\n        model.max_seq_length = train.max_seq_length\n\n    fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\")\n    fabric.print(f\"Total parameters: {num_parameters(model):,}\")\n\n    model = torch.compile(model)\n    model = fabric.setup(model)\n\n    extra_kwargs = {\"fused\": fabric.device.type == \"cuda\"}\n    optimizer = instantiate_torch_optimizer(optimizer, model.parameters(), **extra_kwargs)\n    optimizer = fabric.setup_optimizers(optimizer)\n\n    train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train, model.max_seq_length)\n    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)\n\n    if initial_checkpoint_dir:\n        fabric.load_raw(initial_checkpoint_dir / \"lit_model.pth\", model)\n\n    state = {\n        \"model\": model,\n        \"optimizer\": optimizer,\n        \"train_dataloader\": train_dataloader,\n        \"iter_num\": 0,\n        \"step_count\": 0,\n    }\n\n    resume = find_resume_path(resume, out_dir)\n    if resume:\n        fabric.print(f\"Resuming training from {resume}\")\n        fabric.load(resume, state)\n\n    train_time = time.perf_counter()\n\n    # work around PyTorch issue https://github.com/pytorch/pytorch/issues/152162\n    # which does not like the lazy initialization to be called in dynamo.\n    # TODO: Happens with PyTorch 2.7+\n    if (\n        (_TORCH_EQUAL_2_7 or _TORCH_EQUAL_2_8)\n        and (model._forward_module.__class__.__name__ == \"OptimizedModule\")\n        and (model._forward_module._orig_mod.__class__.__name__ == \"FullyShardedDataParallel\")\n    ):\n        from torch.distributed.fsdp._runtime_utils import _root_pre_forward\n\n        _root_pre_forward(model._forward_module._orig_mod, model._forward_module._orig_mod, [], {})\n\n    fit(\n        fabric=fabric,\n        devices=devices,\n        num_nodes=num_nodes,\n        state=state,\n     
   train_dataloader=train_dataloader,\n        val_dataloader=val_dataloader,\n        out_dir=out_dir,\n        tokenizer_dir=tokenizer_dir,\n        train=train,\n        eval=eval,\n    )\n\n    # Save final checkpoint\n    save_checkpoint(fabric, state, tokenizer_dir, out_dir / \"final\" / \"lit_model.pth\")\n\n    total_tokens = state[\"iter_num\"] * train.micro_batch_size * model.max_seq_length * fabric.world_size\n\n    # Print formatted output\n    separator = \"-\" * 40\n    fabric.print(separator)\n    fabric.print(\"| Performance\")\n    fabric.print(f\"| - Total tokens  : {total_tokens:,}\")\n    fabric.print(f\"| - Training Time : {(time.perf_counter() - train_time):.2f} s\")\n    fabric.print(f\"| - Tok/sec       : {total_tokens / (time.perf_counter() - train_time):.2f} tok/s\")\n    fabric.print(\"| \" + \"-\" * 40)\n\n    if fabric.device.type == \"cuda\":\n        memory_used = torch.cuda.max_memory_allocated() / 1e9\n        fabric.print(\"| Memory Usage\")\n        fabric.print(f\"| - Memory Used   : {memory_used:.2f} GB\")\n    fabric.print(separator)\n\n\ndef fit(\n    fabric: L.Fabric,\n    devices: int,\n    state: dict,\n    train_dataloader: DataLoader,\n    val_dataloader: DataLoader,\n    out_dir: Path,\n    tokenizer_dir: Optional[Path],\n    train: TrainArgs,\n    eval: EvalArgs,\n    num_nodes: int = 1,\n) -> None:\n    model = state[\"model\"]\n    optimizer = state[\"optimizer\"]\n\n    if eval.initial_validation:\n        val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters)\n        val_loss = f\"{val_loss:.3f}\"\n    else:\n        fabric.print(\"Verifying settings ...\")\n        validate(fabric, model, val_dataloader, max_iters=2, verbose=False)  # sanity check\n        val_loss = \"n/a\"\n\n    throughput = ThroughputMonitor(fabric, window_size=5)\n\n    with torch.device(\"meta\"):\n        meta_model = GPT(model.config)\n        x = torch.randint(0, 1, (train.micro_batch_size, meta_model.max_seq_length))\n        model_fwd = 
lambda: meta_model(x)  # noqa: F821\n        model_loss = lambda y: chunked_cross_entropy(y, x, chunk_size=0)  # noqa: F821\n        measured_flops = measure_flops(meta_model, model_fwd, model_loss)\n        fabric.print(f\"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}\")\n        del meta_model, x\n\n    max_tokens_per_device = train.max_tokens // fabric.world_size\n    tokens_per_iter = train.micro_batch_size * model.max_seq_length\n    max_iters = max_tokens_per_device // tokens_per_iter\n    log_iter_interval = train.log_interval * train.gradient_accumulation_iters(devices, num_nodes)\n    initial_iter = state[\"iter_num\"]\n    train_iterator = CycleIterator(train_dataloader)\n\n    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices, num_nodes), sync_on_compute=False).to(\n        fabric.device\n    )\n    fabric.barrier()\n    total_t0 = time.perf_counter()\n\n    warmup_iters = train.warmup_iters(devices, num_nodes, max_iters, train_dataloader)\n\n    for train_data in train_iterator:\n        if state[\"iter_num\"] >= max_iters:\n            break\n\n        # determine and set the learning rate for this iteration\n        lr = get_lr(optimizer.defaults[\"lr\"], state[\"iter_num\"], warmup_iters, max_iters, train.min_lr)\n        for param_group in optimizer.param_groups:\n            param_group[\"lr\"] = lr\n\n        state[\"iter_num\"] += 1\n        iter_t0 = time.perf_counter()\n\n        input_ids = train_data[:, 0 : model.max_seq_length].contiguous().long()\n        targets = train_data[:, 1 : (model.max_seq_length + 1)].contiguous().long()\n\n        is_accumulating = state[\"iter_num\"] % train.gradient_accumulation_iters(devices, num_nodes) != 0\n        with fabric.no_backward_sync(model, enabled=is_accumulating):\n            logits = model(input_ids)\n            loss = chunked_cross_entropy(logits, targets)\n            fabric.backward(loss / train.gradient_accumulation_iters(devices, num_nodes))\n\n 
       running_loss.update(loss.detach())\n\n        if not is_accumulating:\n            fabric.clip_gradients(model, optimizer, max_norm=train.max_norm)\n            optimizer.step()\n            optimizer.zero_grad()\n            state[\"step_count\"] += 1\n\n        if state[\"iter_num\"] % log_iter_interval == 0:\n            loss = running_loss.compute().item()  # expensive device-to-host synchronization\n            t1 = time.perf_counter()\n            throughput.update(\n                time=(t1 - total_t0),\n                flops=(measured_flops * log_iter_interval),\n                batches=state[\"iter_num\"],\n                samples=(state[\"iter_num\"] * train.micro_batch_size),\n                lengths=(state[\"iter_num\"] * train.micro_batch_size * model.max_seq_length),\n            )\n            metrics = {\n                \"loss\": loss,\n                \"iter\": state[\"iter_num\"],\n                \"step\": state[\"step_count\"],\n                \"epoch\": train_iterator.epoch,\n                \"iter_time\": t1 - iter_t0,\n                \"remaining_time\": (\n                    (t1 - total_t0) / (state[\"iter_num\"] - initial_iter) * (max_iters - state[\"iter_num\"])\n                ),\n                \"tokens\": state[\"iter_num\"] * train.micro_batch_size * model.max_seq_length,\n                \"total_tokens\": (state[\"iter_num\"] * train.micro_batch_size * model.max_seq_length * fabric.world_size),\n                \"learning_rate\": lr,\n            }\n            if isinstance(val_loss, float):\n                val_loss = f\"{val_loss:.3f}\"\n            fabric.print(\n                f\"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |\"\n                f\" loss train: {metrics['loss']:.3f},\"\n                f\" val: {val_loss} |\"\n                f\" iter time: {metrics['iter_time'] * 1000:.2f} ms\"\n                f\"{' (step)' if not is_accumulating else ''}\"\n                f\" 
remaining time: {timedelta(seconds=int(metrics['remaining_time']))!s}\"\n            )\n\n            throughput_metrics = throughput.compute()\n            metrics.update(throughput_metrics)\n            fabric.log_dict(metrics, step=state[\"iter_num\"] - 1)\n\n        if val_dataloader is not None and not is_accumulating and state[\"step_count\"] % eval.interval == 0:\n            t0 = time.perf_counter()\n            val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters)\n            val_loss = val_loss.item()\n            td = time.perf_counter() - t0\n\n            fabric.print(f\"iter {state['iter_num']}: val loss {val_loss:.4f}, val time: {td * 1000:.2f} ms\")\n            metrics = {\"val_loss\": val_loss, \"val_ppl\": math.exp(val_loss)}\n            fabric.log_dict(metrics, step=state[\"iter_num\"] - 1)\n            fabric.barrier()\n\n        if train.save_interval is not None and not is_accumulating and state[\"step_count\"] % train.save_interval == 0:\n            save_checkpoint(fabric, state, tokenizer_dir, out_dir / f\"step-{state['step_count']:08d}\" / \"lit_model.pth\")\n\n    # Final validation\n    if eval.final_validation:\n        val_loss = validate(fabric, model, val_dataloader, max_iters=eval.max_iters)\n        metrics = {\"val_loss\": val_loss, \"val_ppl\": math.exp(val_loss)}\n        fabric.log_dict(metrics, step=state[\"iter_num\"])\n        fabric.print(f\"Final evaluation | val loss: {val_loss.item():.3f} | val ppl: {math.exp(val_loss):.3f}\")\n\n\n@torch.no_grad()\ndef validate(\n    fabric: L.Fabric, model: nn.Module, val_dataloader: DataLoader, max_iters: int, verbose: bool = True\n) -> torch.Tensor:\n    fabric.barrier()\n    if verbose:\n        fabric.print(\"Validating ...\")\n    model.eval()\n\n    losses = []\n    for k, batch in enumerate(val_dataloader):\n        if k >= max_iters:\n            break\n        input_ids = batch[:, 0 : model.max_seq_length].contiguous().long()\n        targets = 
batch[:, 1 : (model.max_seq_length + 1)].contiguous().long()\n        logits = model(input_ids)\n        loss = chunked_cross_entropy(logits, targets)\n        losses.append(loss)\n\n    val_loss = torch.stack(losses).mean()\n    model.train()\n    fabric.barrier()\n    return val_loss\n\n\ndef get_dataloaders(\n    fabric: L.Fabric, data: DataModule, tokenizer: Tokenizer, train: TrainArgs, block_size: int\n) -> Tuple[DataLoader, DataLoader]:\n    data.connect(tokenizer=tokenizer, batch_size=train.micro_batch_size, max_seq_length=block_size)\n    with fabric.rank_zero_first():\n        data.prepare_data()\n    data.setup()\n    train_dataloader = data.train_dataloader()\n    val_dataloader = data.val_dataloader()\n    return train_dataloader, val_dataloader\n\n\n# learning rate decay scheduler (cosine with linear warmup)\ndef get_lr(learning_rate: float, it: int, warmup_iters: int, max_iters: int, min_lr: float) -> float:\n    # 1) linear warmup for warmup_iters steps\n    if it < warmup_iters:\n        return learning_rate * it / warmup_iters\n    # 2) if it > max_iters, return min learning rate\n    if it > max_iters:\n        return min_lr\n    # 3) in between, use cosine decay down to min learning rate\n    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)\n    assert 0 <= decay_ratio <= 1\n    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1\n    return min_lr + coeff * (learning_rate - min_lr)\n\n\ndef initialize_weights(fabric: L.Fabric, model: GPT, n_layer: int, n_embd: int) -> None:\n    \"\"\"GPT-NeoX weight initialization (https://arxiv.org/abs/2204.06745).\"\"\"\n    # Adapted from https://github.com/jzhang38/TinyLlama\n\n    def init_weights(module, std):\n        nn.init.normal_(module.weight, mean=0.0, std=std)\n        if getattr(module, \"bias\", None) is not None:\n            nn.init.zeros_(module.bias)\n\n    for mod in model.modules():\n        if isinstance(mod, (nn.Embedding, nn.Linear)):\n            
mod.reset_parameters = partial(init_weights, mod, std=math.sqrt(2.0 / 5 / n_embd))\n\n    # need a separate loop because `mod.proj` below is a `nn.Linear` too\n    for mod in model.modules():\n        if isinstance(mod, (LLaMAMLP, CausalSelfAttention)):\n            mod.proj.reset_parameters = partial(init_weights, mod.proj, std=(1 / math.sqrt(n_embd) / n_layer))\n\n    if not isinstance(fabric.strategy, FSDPStrategy):\n        reset_parameters(model)\n\n\ndef save_checkpoint(fabric, state, tokenizer_dir, checkpoint_file):\n    model = state[\"model\"]\n    checkpoint_file.parent.mkdir(parents=True, exist_ok=True)\n    fabric.print(f\"Saving checkpoint to {str(checkpoint_file)!r}\")\n    fabric.save(checkpoint_file, state)\n    if fabric.global_rank == 0:\n        save_hyperparameters(setup, checkpoint_file.parent)\n        if tokenizer_dir is not None:\n            copy_config_files(tokenizer_dir, checkpoint_file.parent)\n        save_config(model.config, checkpoint_file.parent)\n\n\ndef validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resume) -> None:\n    issues = []\n    unsupported = [(train, [\"epochs\"]), (eval, [\"max_new_tokens\"])]\n    for args, names in unsupported:\n        for name in names:\n            if getattr(args, name) is not None:\n                issues.append(f\"{__file__} doesn't support the {name!r} argument. This is set in {args}\")\n    if train.max_steps is not None:\n        warnings.warn(\n            \"`train.max_steps` is intended for profiling or debug runs only. \"\n            \"For full pretraining runs, prefer `train.max_tokens` or `train.max_time`.\",\n            UserWarning,\n        )\n    required = [(train, [\"max_tokens\", \"max_norm\"])]\n    for args, names in required:\n        for name in names:\n            if getattr(args, name) is None:\n                issues.append(f\"{__file__} requires the {name!r} argument. 
This is set in {args}\")\n    if initial_checkpoint_dir and resume:\n        issues.append(\"Can't provide both `--resume` and `--initial_checkpoint_dir`. Choose one.\")\n    if issues:\n        raise ValueError(\"\\n\".join(issues))\n"
  },
  {
    "path": "litgpt/prompts.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport importlib\nimport re\nfrom abc import abstractmethod\nfrom json import dumps\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union\n\nimport yaml\n\nfrom litgpt.config import Config\n\nif TYPE_CHECKING:\n    from litgpt import Tokenizer\n\n\nclass PromptStyle:\n    \"\"\"Base interface for prompt styles.\"\"\"\n\n    @abstractmethod\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        return prompt\n\n    def stop_tokens(self, tokenizer: \"Tokenizer\") -> Tuple[List[int], ...]:\n        return ([tokenizer.eos_id],)\n\n    @classmethod\n    def from_name(cls, name: str) -> \"PromptStyle\":\n        return prompt_styles[name]()\n\n    @classmethod\n    def from_config(cls, config: Config) -> \"PromptStyle\":\n        return model_name_to_prompt_style(config.name)\n\n\nclass Default(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        return prompt\n\n    def stop_tokens(self, tokenizer: \"Tokenizer\") -> Tuple[List[int], ...]:\n        return ([tokenizer.eos_id],)\n\n\nclass Alpaca(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        if kwargs.get(\"input\"):\n            sys_prompt = sys_prompt or (\n                \"Below is an instruction that describes a task, paired with an input that provides further context. \"\n                \"Write a response that appropriately completes the request.\\n\\n\"\n            )\n            return f\"{sys_prompt}### Instruction:\\n{prompt}\\n\\n### Input:\\n{kwargs['input']}\\n\\n### Response:\\n\"\n\n        sys_prompt = sys_prompt or (\n            \"Below is an instruction that describes a task. 
\"\n            \"Write a response that appropriately completes the request.\\n\\n\"\n        )\n        return f\"{sys_prompt}### Instruction:\\n{prompt}\\n\\n### Response:\\n\"\n\n\nclass FLAN(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        sys_prompt = sys_prompt or (\n            \"Below is an instruction that describes a task. \"\n            \"Write a response that appropriately completes the request.\\n\\n\"\n        )\n        return f\"{sys_prompt}### Instruction:\\n{prompt}\\n\\n### Response:\\n\"\n\n\nclass Longform(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        sys_prompt = sys_prompt or (\n            \"Below is an instruction that describes a task, paired with an input that provides further context. \"\n            \"Write a response that appropriately completes the request.\\n\\n\"\n        )\n        return f\"{sys_prompt}### Instruction:\\n{prompt}\\n\\n### Response:\\n\"\n\n\nclass StableLMAlpha(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        sys_prompt = sys_prompt or (\n            \"# StableLM Tuned (Alpha version)\\n- StableLM is a helpful and harmless open-source AI language\"\n            \" model developed by StabilityAI.\\n- StableLM is excited to be able to help the user, but will refuse to do\"\n            \" anything that could be considered harmful to the user.\\n- StableLM is more than just an information\"\n            \" source, StableLM is also able to write poetry, short stories, and make jokes.\\n- StableLM will refuse to\"\n            \" participate in anything that could harm a human.\"\n        )\n        return f\"<|SYSTEM|>{sys_prompt}<|USER|>{prompt}<|ASSISTANT|>\"\n\n    def stop_tokens(self, tokenizer: \"Tokenizer\") -> Tuple[List[int], ...]:\n        return (\n            [tokenizer.eos_id],\n            
[tokenizer.token_to_id(\"<|SYSTEM|>\")],\n            [tokenizer.token_to_id(\"<|ASSISTANT|>\")],\n            [tokenizer.token_to_id(\"<|USER|>\")],\n        )\n\n\nclass StableLMZephyr(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        return f\"<|user|>\\n{prompt}<|endoftext|>\\n<|assistant|>\\n\"\n\n\nclass Falcon(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        return f\"{prompt}\\nAnswer:\"\n\n    def stop_tokens(self, tokenizer: \"Tokenizer\") -> Tuple[List[int], ...]:\n        return (\n            [tokenizer.eos_id],\n            # the model rarely emits the eos token and instead outputs newlines, but we cannot use them\n            # to stop or else things like code generation wouldn't work\n            [tokenizer.token_to_id(\"User\"), tokenizer.token_to_id(\":\")],\n            [193, tokenizer.token_to_id(\"User\")],  # 193: '\\n'\n        )\n\n\nclass Falcon3(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        return f\"<|user|>\\n{prompt}<|endoftext|>\\n<|assistant|>\\n\"\n\n    def stop_tokens(self, tokenizer: \"Tokenizer\") -> Tuple[List[int], ...]:\n        return (\n            [tokenizer.eos_id],\n            [tokenizer.token_to_id(\"<|endoftext|>\")],\n        )\n\n\nclass Llama2FunctionCalling(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        # Has to be before the llama config\n        b_func, e_func = \"<FUNCTIONS>\", \"</FUNCTIONS>\\n\\n\"\n        b_inst, e_inst = \"[INST]\", \"[/INST]\"\n        b_sys, e_sys = \"<<SYS>>\\n\", \"\\n<</SYS>>\\n\\n\"\n        # This is an example for how to format functions for the model\n        function_metadata = {\n            \"function\": \"search_bing\",\n            \"description\": (\n                \"Search the web for content on Bing. 
This allows users to search online/the internet/the web for\"\n                \" content.\"\n            ),\n            \"arguments\": [{\"name\": \"query\", \"type\": \"string\", \"description\": \"The search query string\"}],\n        }\n\n        system_prompt = sys_prompt or (\n            \"You are a helpful, respectful and honest assistant. Always answer as helpfully as\"\n            \" possible. Your only response should be JSON formatted functions\"\n        )\n        # replace the curly braces with double curly braces to escape them\n        function_list = dumps(function_metadata).replace(\"{\", \"{{\").replace(\"}\", \"}}\")\n        return (\n            f\"{b_func}{function_list.strip()}{e_func}{b_inst}{b_sys}{system_prompt.strip()}{e_sys}{prompt}{e_inst}\\n\\n\"\n        )\n\n\nclass Llama2(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        b_inst, e_inst = \"[INST]\", \"[/INST]\"\n        b_sys, e_sys = \"<<SYS>>\\n\", \"\\n<</SYS>>\\n\\n\"\n        sys_prompt = sys_prompt or (\n            \"You are a helpful, respectful and honest assistant. Always answer as helpfully as\"\n            \" possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist,\"\n            \" toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and\"\n            \" positive in nature.\\n\\nIf a question does not make any sense, or is not factually coherent, explain why\"\n            \" instead of answering something not correct. 
If you don't know the answer to a question, please don't\"\n            \" share false information.\"\n        )\n        return f\"{b_inst} {b_sys}{sys_prompt}{e_sys} {prompt} {e_inst} \"\n\n\nclass Llama3(PromptStyle):\n    def apply(\n        self, prompt: Union[str, List[Dict[str, str]]], *, sys_prompt: Optional[str] = None, **kwargs: str\n    ) -> str:\n        default_system_prompt = sys_prompt or \"You are a helpful assistant.\"\n\n        # https://github.com/meta-llama/llama3/blob/359887376f0aaf30e433f23e25df858d8c2a9833/llama/tokenizer.py#L202-L229\n        if isinstance(prompt, str):\n            return (\n                \"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\n\"\n                f\"{default_system_prompt}<|eot_id|>\"  # No newline\n                \"<|start_header_id|>user<|end_header_id|>\\n\\n\"\n                f\"{prompt}<|eot_id|>\"  # No newline\n                \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n            )\n        elif isinstance(prompt, list):\n\n            def encode_header(role: str) -> List[str]:\n                return [f\"<|start_header_id|>{role}<|end_header_id|>\\n\\n\"]\n\n            def encode_message(message: Dict[str, str]) -> List[str]:\n                tokens = encode_header(message[\"role\"])\n                # NOTE: Meta stripped this. 
I'm not sure I agree, but who am I to argue?\n                tokens.append(message[\"content\"].strip())\n                tokens.append(\"<|eot_id|>\")\n                return tokens\n\n            def has_system_prompt(messages: List[Dict[str, str]]) -> bool:\n                return messages[0].get(\"role\", \"\") == \"system\" if len(messages) else False\n\n            tokens = [\"<|begin_of_text|>\"]\n            if not has_system_prompt(prompt):\n                tokens.extend(encode_message({\"role\": \"system\", \"content\": default_system_prompt}))\n            for i, message in enumerate(prompt):\n                if i != 0 and message[\"role\"] == \"system\":\n                    raise ValueError(\"'system' role is only allowed at the beginning of the conversation list.\")\n                if message[\"role\"] not in [\"assistant\", \"user\", \"system\"]:\n                    raise ValueError(\n                        f\"Unknown role: '{message['role']}'. Supported roles are 'assistant', 'user', and 'system'.\"\n                    )\n                tokens.extend(encode_message(message))\n            tokens.extend(encode_header(\"assistant\"))\n            return \"\".join(tokens)\n        else:\n            raise ValueError(f\"Unsupported prompt type: {type(prompt)}\")\n\n    def stop_tokens(self, tokenizer: \"Tokenizer\") -> Tuple[List[int], ...]:\n        return (\n            [tokenizer.eos_id],\n            [tokenizer.token_to_id(\"<|eot_id|>\")],\n        )\n\n\nclass R1Base(PromptStyle):\n    def apply(\n        self, prompt: Union[str, List[Dict[str, str]]], *, sys_prompt: Optional[str] = None, **kwargs: str\n    ) -> str:\n        default_system_prompt = sys_prompt or \"\"\n\n        bos_token = \"<｜begin▁of▁sentence｜>\"\n        eos_token = \"\"\n\n        if isinstance(prompt, str):\n            return f\"{default_system_prompt}<｜User｜>{prompt}<｜Assistant｜>\"  # Prepares for assistant response\n        elif isinstance(prompt, list):\n\n          
  def encode_message(message: Dict[str, str]) -> str:\n                role = message[\"role\"]\n                content = message[\"content\"].strip()\n\n                if role == \"system\":\n                    return content  # System prompt is prepended at the start\n                elif role == \"user\":\n                    return f\"<｜User｜>{content}\"\n                elif role == \"assistant\":\n                    return f\"<｜Assistant｜>{content}{eos_token}\"\n                else:\n                    raise ValueError(f\"Unknown role: '{role}'. Supported roles are 'assistant', 'user', and 'system'.\")\n\n            # Extract system prompt (if any)\n            system_prompt = \"\"\n            if prompt[0].get(\"role\") == \"system\":\n                system_prompt = prompt[0][\"content\"]\n                prompt = prompt[1:]  # Remove system message from the list\n\n            # Construct the formatted prompt\n            formatted_prompt = system_prompt\n            for message in prompt:\n                formatted_prompt += encode_message(message)\n\n            formatted_prompt += \"<｜Assistant｜>\"  # Prepares for assistant response\n            return formatted_prompt\n        else:\n            raise ValueError(f\"Unsupported prompt type: {type(prompt)}\")\n\n    def stop_tokens(self, tokenizer: \"Tokenizer\") -> Tuple[List[int], ...]:\n        return (\n            [tokenizer.eos_id],\n            [tokenizer.token_to_id(\"<｜end▁of▁sentence｜>\")],\n        )\n\n\nclass FreeWilly2(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        sys_prompt = sys_prompt or \"This is a system prompt, please behave and help the user.\"\n        return f\"### System:\\n{sys_prompt}\\n\\n### User:\\n{prompt}\\n\\n### Assistant:\\n\"\n\n\nclass Platypus(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        return f\"### 
Instruction:\\n\\n{prompt}\\n\\n### Response:\\n\"\n\n\nclass StableCode(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        return f\"###Instruction\\n{prompt}###Response\\n\"\n\n\nclass CodeLlama(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        # for CodeLLama, we don't set a default system prompt, but it is supported:\n        # https://huggingface.co/blog/codellama#conversational-instructions\n        # Mistral does not: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format\n        b_inst, e_inst = \"[INST]\", \"[/INST]\"\n        if sys_prompt:\n            b_sys, e_sys = \"<<SYS>>\\n\", \"\\n<</SYS>>\\n\\n\"\n            return f\"{b_inst} {b_sys}{sys_prompt}{e_sys}{prompt} {e_inst}\"\n        return f\"{b_inst} {prompt} {e_inst}\"\n\n\nclass Phi1(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        return f\"{prompt}\\n\\nAnswer:\"\n\n    def stop_tokens(self, tokenizer: \"Tokenizer\") -> Tuple[List[int], ...]:\n        return (\n            [tokenizer.eos_id],\n            [tokenizer.token_to_id(\"Answer\"), tokenizer.token_to_id(\":\")],\n            [198, tokenizer.token_to_id(\"Answer\"), tokenizer.token_to_id(\":\")],\n            # the model rarely emits the eos token and instead outputs newlines, but we cannot use them\n            # to stop or else things like code generation wouldn't work\n            # [198, 198],  # '\\n', '\\n'\n        )\n\n\nclass Phi2(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        return f\"Instruct: {prompt}\\nOutput:\"\n\n\nclass Phi3(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        sys_prompt = sys_prompt or \"You are a helpful assistant.\"\n        return 
f\"<|system|>\\n{sys_prompt}<|end|>\\n<|user|>\\n{prompt}<|end|>\\n<|assistant|>\\n\"\n\n\nclass Phi4(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        res = \"\"\n        if sys_prompt:\n            res += f\"<|im_start|>system<|im_sep|>{sys_prompt}<|im_end|>\"\n        res += f\"<|im_start|>user<|im_sep|>{prompt}<|im_end|><|im_start|>assistant<|im_sep|>\"\n        return res\n\n\nclass Phi4Reasoning(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        sys_prompt = (\n            sys_prompt\n            or \"You are Phi, a language model trained by Microsoft to help users. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> {Thought section} </think> {Solution section}. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion. 
Now, try to solve the following question through the above guidelines:\"\n        )\n        return f\"<|im_start|>system<|im_sep|>{sys_prompt}<|im_end|><|im_start|>user<|im_sep|>{prompt}<|im_end|><|im_start|>assistant<|im_sep|>\"\n\n\nclass Phi4Mini(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        res = \"\"\n        if sys_prompt:\n            res += f\"<|system|>{sys_prompt}<|end|>\"\n        res += f\"<|user|>{prompt}<|end|><|assistant|>\"\n        return res\n\n\nclass Phi4MiniReasoning(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        sys_prompt = sys_prompt or \"Your name is Phi, an AI math expert developed by Microsoft.\"\n        return f\"<|system|>{sys_prompt}<|end|><|user|>{prompt}<|end|><|assistant|>\"\n\n\nclass TinyLlama(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        sys_prompt = sys_prompt or \"You are a friendly chatbot who always gives helpful, detailed, and polite answers.\"\n        return f\"<|system|>\\n{sys_prompt}</s>\\n<|user|>\\n{prompt}</s>\\n<|assistant|>\\n\"\n\n\nclass Gemma(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        return f\"<start_of_turn>user\\n{prompt}<end_of_turn>\\n<start_of_turn>model\\n\"\n\n\nclass OLMo(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        return f\"<|endoftext|><|user|>\\n{prompt}\\n<|assistant|>\\n\"\n\n\nclass ChatML(PromptStyle):\n    def __init__(self, system_message: Optional[str] = None):\n        self.system_message = system_message\n\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:\n        sys_prompt = sys_prompt or self.system_message\n        return (\n            
f\"<|im_start|>system\\n{sys_prompt}<|im_end|>\\n<|im_start|>user\\n{prompt}<|im_end|>\\n<|im_start|>assistant\\n\"\n        )\n\n\nclass Qwen2_5(ChatML):\n    def __init__(self):\n        super().__init__(\"You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\")\n\n\nclass Qwen2_5_Math(ChatML):\n    def __init__(self):\n        super().__init__(\"Please reason step by step, and put your final answer within \\\\boxed{}.\")\n\n\nclass QwQ(ChatML):\n    def __init__(self):\n        super().__init__(\n            \"You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.\"\n        )\n\n\nclass Qwen3(ChatML):\n    def __init__(self):\n        super().__init__()\n\n\nclass SmolLM2(ChatML):\n    def __init__(self):\n        super().__init__(\"You are a helpful AI assistant named SmolLM, trained by Hugging Face\")\n\n\nclass Salamandra(ChatML):\n    def __init__(self):\n        super().__init__(\n            \"I am Salamandra, an AI language model developed at the Barcelona Supercomputing Centre (BSC) by the Language Technologies Unit. My knowledge base was last updated on August 2023. Today Date: 2024-09-30\\nSoy Salamandra, un modelo lingüístico de IA desarrollado en el Barcelona Supercomputing Centre (BSC) por la Language Technologies Unit. 
Mi base de conocimientos se actualizó por última vez en agosto de 2023.\\nSoc Salamandra, un model de llenguatge d'IA desenvolupat al Barcelona Supercomputing Centre (BSC) per la Language Technologies Unit.\"\n        )\n\n\n# Maps prompt style names to PromptStyle classes\nprompt_styles: Dict[str, Type[PromptStyle]] = {\n    # Dataset-specific prompt styles\n    \"alpaca\": Alpaca,\n    \"flan\": FLAN,\n    \"longform\": Longform,\n    # Model-specific prompt styles\n    \"stablelm-alpha\": StableLMAlpha,\n    \"stablelm-zephyr\": StableLMZephyr,\n    \"falcon\": Falcon,\n    \"llama2-function-calling\": Llama2FunctionCalling,\n    \"llama2\": Llama2,\n    \"freewilly2\": FreeWilly2,\n    \"platypus\": Platypus,\n    \"stablecode\": StableCode,\n    \"codellama\": CodeLlama,\n    \"phi-1\": Phi1,\n    \"phi-2\": Phi2,\n    \"phi-3\": Phi3,\n    \"phi-4\": Phi4,\n    \"phi-4-reasoning\": Phi4Reasoning,\n    \"phi-4-mini\": Phi4Mini,\n    \"phi-4-mini-reasoning\": Phi4MiniReasoning,\n    \"tinyllama\": TinyLlama,\n    \"gemma\": Gemma,\n    \"llama3\": Llama3,\n    \"olmo\": OLMo,\n    \"qwen2.5\": Qwen2_5,\n    \"qwen2.5-math\": Qwen2_5_Math,\n    \"qwq\": QwQ,\n    \"qwen3\": Qwen3,\n    \"smollm2\": SmolLM2,\n    \"salamandra\": Salamandra,\n}\n\n\ndef model_name_to_prompt_style(model_name: str) -> PromptStyle:\n    if re.search(r\"stablelm-tuned-alpha\", model_name):\n        return StableLMAlpha()\n    if re.search(r\"stablelm-zephyr-3b\", model_name):\n        return StableLMZephyr()\n    if re.search(\"stablecode-instruct\", model_name):\n        return StableCode()\n    if re.search(r\"Falcon3.*-Instruct\", model_name):\n        return Falcon3()\n    if re.search(r\"falcon.*-instruct\", model_name):\n        return Falcon()\n    if re.search(\"Llama-2-7b-chat-hf-function-calling-v2\", model_name):\n        return Llama2FunctionCalling()\n    if re.search(\"Llama-2.*-chat\", model_name):\n        return Llama2()\n    if re.search(\"Llama-3.*-Instruct\", 
model_name):\n        return Llama3()\n    if re.search(\"Llama-3.*-Instruct-*\", model_name):\n        return Llama3()\n    if re.search(\"OLMo-2.*-(Instruct|SFT|DPO)\", model_name):\n        return Llama3()\n    if re.search(\"R1\", model_name):\n        return R1Base()\n    if re.search(\"FreeWilly2\", model_name):\n        return FreeWilly2()\n    if re.search(\"Platypus\", model_name):\n        return Platypus()\n    if re.search(\"CodeLlama|Mi[sx]tral.*Instruct\", model_name):\n        return CodeLlama()\n    if re.search(\"phi-1\", model_name):\n        return Phi1()\n    if re.search(\"phi-2\", model_name):\n        return Phi2()\n    if re.search(\"Phi-3\", model_name):\n        return Phi3()\n    if re.search(\"Phi-4-reasoning\", model_name):\n        return Phi4Reasoning()\n    if re.search(\"Phi-4-mini-reasoning\", model_name):\n        return Phi4MiniReasoning()\n    if re.search(\"Phi-4-mini\", model_name):\n        return Phi4Mini()\n    if re.search(\"phi-4\", model_name):\n        return Phi4()\n    if re.search(r\"tiny-llama.*chat\", model_name):\n        return TinyLlama()\n    if re.search(r\"(Code)?Gemma.*-it\", model_name):\n        return Gemma()\n    if re.search(r\"OLMo.*-hf\", model_name):\n        return OLMo()\n    if re.search(r\"Qwen2\\.5-Math-.*\", model_name):\n        return Qwen2_5_Math()\n    if re.search(r\"Qwen2\\.5-.*\", model_name):\n        return Qwen2_5()\n    if re.search(r\"QwQ-.*\", model_name):\n        return QwQ()\n    if re.search(r\"Qwen3-.*\", model_name):\n        return Qwen3()\n    if re.search(r\"SmolLM2.*-Instruct\", model_name):\n        return SmolLM2()\n    if re.search(r\"salamandra-.*-instruct\", model_name):\n        return Salamandra()\n    return Default()\n\n\ndef save_prompt_style(style: Union[str, PromptStyle], checkpoint_dir: Path) -> None:\n    style = PromptStyle.from_name(style) if isinstance(style, str) else style\n    cls = type(style)\n    # Allow saving the full module path for user-defined 
prompt classes\n    config = {\"class_path\": f\"{cls.__module__}.{cls.__name__}\"}\n    with open(checkpoint_dir / \"prompt_style.yaml\", \"w\", encoding=\"utf-8\") as file:\n        yaml.dump(config, file)\n\n\ndef load_prompt_style(checkpoint_dir: Path) -> PromptStyle:\n    with open(checkpoint_dir / \"prompt_style.yaml\", encoding=\"utf-8\") as file:\n        config = yaml.safe_load(file)\n    # Support loading the full module path for user-defined prompt classes\n    full_module_path, cls_name = config[\"class_path\"].rsplit(\".\", 1)\n    module = importlib.import_module(full_module_path)\n    cls = getattr(module, cls_name)\n    return cls()\n\n\ndef has_prompt_style(checkpoint_dir: Path) -> bool:\n    return (checkpoint_dir / \"prompt_style.yaml\").is_file()\n"
  },
  {
    "path": "litgpt/scripts/__init__.py",
    "content": ""
  },
  {
    "path": "litgpt/scripts/convert_hf_checkpoint.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport gc\nimport json\nimport os\nimport re\nimport warnings\nfrom collections import defaultdict\nfrom functools import partial\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor\nfrom safetensors.torch import load_file as load_safetensors\nfrom tqdm import tqdm\n\nfrom litgpt.config import Config\nfrom litgpt.utils import (\n    extend_checkpoint_dir,\n    incremental_save,\n    lazy_load,\n    save_config,\n)\n\n\ndef copy_weights_gpt_neox(\n    config: Config,\n    state_dict: Dict[str, torch.Tensor],\n    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n    dtype: Optional[torch.dtype] = None,\n    pbar: Optional[tqdm] = None,\n    progress_per_file: Optional[float] = None,\n    debug_mode: Optional[bool] = False,\n) -> None:\n    weight_map = {\n        \"gpt_neox.embed_in.weight\": \"transformer.wte.weight\",\n        \"gpt_neox.layers.{}.input_layernorm.bias\": \"transformer.h.{}.norm_1.bias\",\n        \"gpt_neox.layers.{}.input_layernorm.weight\": \"transformer.h.{}.norm_1.weight\",\n        \"gpt_neox.layers.{}.attention.query_key_value.bias\": \"transformer.h.{}.attn.qkv.bias\",\n        \"gpt_neox.layers.{}.attention.query_key_value.weight\": \"transformer.h.{}.attn.qkv.weight\",\n        \"gpt_neox.layers.{}.attention.dense.bias\": \"transformer.h.{}.attn.proj.bias\",\n        \"gpt_neox.layers.{}.attention.dense.weight\": \"transformer.h.{}.attn.proj.weight\",\n        \"gpt_neox.layers.{}.attention.rotary_emb.inv_freq\": None,\n        \"gpt_neox.layers.{}.attention.bias\": None,\n        \"gpt_neox.layers.{}.attention.masked_bias\": None,\n        \"gpt_neox.layers.{}.post_attention_layernorm.bias\": 
\"transformer.h.{}.norm_2.bias\",\n        \"gpt_neox.layers.{}.post_attention_layernorm.weight\": \"transformer.h.{}.norm_2.weight\",\n        \"gpt_neox.layers.{}.mlp.dense_h_to_4h.bias\": \"transformer.h.{}.mlp.fc.bias\",\n        \"gpt_neox.layers.{}.mlp.dense_h_to_4h.weight\": \"transformer.h.{}.mlp.fc.weight\",\n        \"gpt_neox.layers.{}.mlp.dense_4h_to_h.bias\": \"transformer.h.{}.mlp.proj.bias\",\n        \"gpt_neox.layers.{}.mlp.dense_4h_to_h.weight\": \"transformer.h.{}.mlp.proj.weight\",\n        \"gpt_neox.final_layer_norm.bias\": \"transformer.ln_f.bias\",\n        \"gpt_neox.final_layer_norm.weight\": \"transformer.ln_f.weight\",\n        \"embed_out.weight\": \"lm_head.weight\",\n    }\n\n    if progress_per_file is not None:\n        progress_per_file = progress_per_file / max(1, len(hf_weights))\n\n    for from_name, param in hf_weights.items():\n        name_template, layer_idx = layer_template(from_name)\n        to_name = weight_map[name_template]\n        if to_name is None:\n            continue\n        to_name = to_name.format(layer_idx)\n        param = load_param(param, from_name, dtype, verbose=debug_mode)\n        if from_name.endswith((\".query_key_value.weight\", \".query_key_value.bias\")):\n            # Reassemble [q, k, v, q, k, v, ...] 
--> [q, q, ..., k, k, ..., v, v, ...]\n            param = qkv_reassemble(param, config)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n        if progress_per_file is not None:\n            pbar.update(progress_per_file)\n\n\ndef copy_weights_falcon(\n    config: Config,\n    state_dict: Dict[str, torch.Tensor],\n    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n    dtype: Optional[torch.dtype] = None,\n    pbar: Optional[tqdm] = None,\n    progress_per_file: Optional[float] = None,\n    debug_mode: Optional[bool] = False,\n) -> None:\n    weight_map = {\n        \"transformer.word_embeddings.weight\": \"transformer.wte.weight\",\n        \"transformer.h.{}.self_attention.query_key_value.weight\": \"transformer.h.{}.attn.qkv.weight\",\n        \"transformer.h.{}.self_attention.dense.weight\": \"transformer.h.{}.attn.proj.weight\",\n        \"transformer.h.{}.mlp.dense_h_to_4h.weight\": \"transformer.h.{}.mlp.fc.weight\",\n        \"transformer.h.{}.mlp.dense_4h_to_h.weight\": \"transformer.h.{}.mlp.proj.weight\",\n        \"transformer.ln_f.bias\": \"transformer.ln_f.bias\",\n        \"transformer.ln_f.weight\": \"transformer.ln_f.weight\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n    # the original model definition is different for each size\n    if \"7b\" in config.name:\n        weight_map.update(\n            {\n                \"transformer.h.{}.input_layernorm.bias\": \"transformer.h.{}.norm_1.bias\",\n                \"transformer.h.{}.input_layernorm.weight\": \"transformer.h.{}.norm_1.weight\",\n            }\n        )\n    elif \"40b\" in config.name or \"180B\" in config.name:\n        weight_map.update(\n            {\n                \"transformer.h.{}.ln_attn.bias\": \"transformer.h.{}.norm_1.bias\",\n                \"transformer.h.{}.ln_attn.weight\": \"transformer.h.{}.norm_1.weight\",\n           
     \"transformer.h.{}.ln_mlp.bias\": \"transformer.h.{}.norm_2.bias\",\n                \"transformer.h.{}.ln_mlp.weight\": \"transformer.h.{}.norm_2.weight\",\n            }\n        )\n    else:\n        raise NotImplementedError\n\n    if progress_per_file is not None:\n        progress_per_file = progress_per_file / max(1, len(hf_weights))\n\n    for from_name, param in hf_weights.items():\n        name_template, layer_idx = layer_template(from_name)\n        to_name = weight_map[name_template].format(layer_idx)\n        param = load_param(param, from_name, dtype, verbose=debug_mode)\n        if from_name.endswith((\".query_key_value.weight\", \".query_key_value.bias\")):\n            # Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...]\n            param = qkv_reassemble(param, config)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n        if progress_per_file is not None:\n            pbar.update(progress_per_file)\n\n\ndef copy_weights_hf_llama(\n    config: Config,\n    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],\n    state_dict: Dict[str, torch.Tensor],\n    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n    dtype: Optional[torch.dtype] = None,\n    pbar: Optional[tqdm] = None,\n    progress_per_file: Optional[float] = None,\n    debug_mode: Optional[bool] = False,\n) -> None:\n    weight_map = {\n        \"model.embed_tokens.weight\": \"transformer.wte.weight\",\n        \"model.layers.{}.input_layernorm.weight\": \"transformer.h.{}.norm_1.weight\",\n        \"model.layers.{}.input_layernorm.bias\": \"transformer.h.{}.norm_1.bias\",\n        \"model.layers.{}.self_attn.q_proj.weight\": None,\n        \"model.layers.{}.self_attn.k_proj.weight\": None,\n        \"model.layers.{}.self_attn.v_proj.weight\": None,\n        \"model.layers.{}.self_attn.o_proj.weight\": 
\"transformer.h.{}.attn.proj.weight\",\n        \"model.layers.{}.self_attn.rotary_emb.inv_freq\": None,\n        \"model.layers.{}.post_attention_layernorm.weight\": \"transformer.h.{}.norm_2.weight\",\n        \"model.layers.{}.post_attention_layernorm.bias\": \"transformer.h.{}.norm_2.bias\",\n        \"model.norm.weight\": \"transformer.ln_f.weight\",\n        \"model.norm.bias\": \"transformer.ln_f.bias\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n    if config.mlp_class_name == \"LLaMAMoE\":\n        weight_map.update(\n            {\n                \"model.layers.{}.block_sparse_moe.gate.weight\": \"transformer.h.{}.mlp.gate.weight\",\n                \"model.layers.{}.block_sparse_moe.experts.{}.w1.weight\": \"transformer.h.{}.mlp.experts.{}.fc_1.weight\",\n                \"model.layers.{}.block_sparse_moe.experts.{}.w3.weight\": \"transformer.h.{}.mlp.experts.{}.fc_2.weight\",\n                \"model.layers.{}.block_sparse_moe.experts.{}.w2.weight\": \"transformer.h.{}.mlp.experts.{}.proj.weight\",\n            }\n        )\n    elif config.mlp_class_name in (\"LLaMAMLP\", \"GemmaMLP\"):\n        weight_map.update(\n            {\n                \"model.layers.{}.mlp.gate_proj.weight\": \"transformer.h.{}.mlp.fc_1.weight\",\n                \"model.layers.{}.mlp.up_proj.weight\": \"transformer.h.{}.mlp.fc_2.weight\",\n                \"model.layers.{}.mlp.down_proj.weight\": \"transformer.h.{}.mlp.proj.weight\",\n            }\n        )\n    else:\n        raise NotImplementedError\n\n    if progress_per_file is not None:\n        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))\n\n    for from_name, param in hf_weights.items():\n        name_template, *ids = layer_template(from_name, num_matches=2)\n        to_name = weight_map[name_template]\n        param = load_param(param, from_name, dtype, verbose=debug_mode)\n        if any(w in from_name for w in (\"q_proj\", \"k_proj\", \"v_proj\")):\n          
  qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))\n            weight_name, weight_type = from_name.split(\".\")[-2:]\n            qkv[weight_type][weight_name] = param\n        if to_name is None:\n            continue\n        to_name = to_name.format(*ids)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n        if progress_per_file is not None:\n            pbar.update(progress_per_file)\n\n    if \"lm_head.weight\" not in state_dict:\n        state_dict[\"lm_head.weight\"] = state_dict[\"transformer.wte.weight\"]\n\n    for i in list(qkv_weights):\n        for weight_type in list(qkv_weights[i]):\n            qkv = qkv_weights[i][weight_type]\n            if len(qkv) != 3:\n                # qkv is split across different .bin files\n                continue\n            q = load_param(qkv[\"q_proj\"], f\"layer {i} q {weight_type}\", dtype, verbose=debug_mode)\n            k = load_param(qkv[\"k_proj\"], f\"layer {i} k {weight_type}\", dtype, verbose=debug_mode)\n            v = load_param(qkv[\"v_proj\"], f\"layer {i} v {weight_type}\", dtype, verbose=debug_mode)\n            qkv = torch.cat((q, k, v))\n            state_dict[f\"transformer.h.{i}.attn.qkv.{weight_type}\"] = qkv\n            del qkv_weights[i][weight_type]\n\n            if progress_per_file is not None:\n                pbar.update(progress_per_file)\n\n\ndef copy_weights_gemma_2(\n    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],\n    state_dict: Dict[str, torch.Tensor],\n    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n    dtype: Optional[torch.dtype] = None,\n    pbar: Optional[tqdm] = None,\n    progress_per_file: Optional[float] = None,\n    debug_mode: Optional[bool] = False,\n) -> None:\n    weight_map = {\n        \"model.embed_tokens.weight\": \"transformer.wte.weight\",\n        \"model.layers.{}.self_attn.q_proj.weight\": 
None,\n        \"model.layers.{}.self_attn.k_proj.weight\": None,\n        \"model.layers.{}.self_attn.v_proj.weight\": None,\n        \"model.layers.{}.self_attn.o_proj.weight\": \"transformer.h.{}.attn.proj.weight\",\n        \"model.layers.{}.mlp.gate_proj.weight\": \"transformer.h.{}.mlp.fc_1.weight\",\n        \"model.layers.{}.mlp.up_proj.weight\": \"transformer.h.{}.mlp.fc_2.weight\",\n        \"model.layers.{}.mlp.down_proj.weight\": \"transformer.h.{}.mlp.proj.weight\",\n        \"model.layers.{}.input_layernorm.weight\": \"transformer.h.{}.norm_1.weight\",\n        \"model.layers.{}.post_attention_layernorm.weight\": \"transformer.h.{}.post_attention_norm.weight\",\n        \"model.layers.{}.pre_feedforward_layernorm.weight\": \"transformer.h.{}.norm_2.weight\",\n        \"model.layers.{}.post_feedforward_layernorm.weight\": \"transformer.h.{}.post_mlp_norm.weight\",\n        \"model.norm.weight\": \"transformer.ln_f.weight\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n\n    if progress_per_file is not None:\n        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))\n\n    for from_name, param in hf_weights.items():\n        name_template, *ids = layer_template(from_name, num_matches=2)\n        to_name = weight_map[name_template]\n        param = load_param(param, from_name, dtype, verbose=debug_mode)\n        if any(w in from_name for w in (\"q_proj\", \"k_proj\", \"v_proj\")):\n            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))\n            weight_name, weight_type = from_name.split(\".\")[-2:]\n            qkv[weight_type][weight_name] = param\n        if to_name is None:\n            continue\n        to_name = to_name.format(*ids)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n        if progress_per_file is not None:\n            pbar.update(progress_per_file)\n\n    if \"lm_head.weight\" not in state_dict:\n        
state_dict[\"lm_head.weight\"] = state_dict[\"transformer.wte.weight\"]\n\n    for i in list(qkv_weights):\n        for weight_type in list(qkv_weights[i]):\n            qkv = qkv_weights[i][weight_type]\n            if len(qkv) != 3:\n                # qkv is split across different .bin files\n                continue\n            q = load_param(qkv[\"q_proj\"], f\"layer {i} q {weight_type}\", dtype, verbose=debug_mode)\n            k = load_param(qkv[\"k_proj\"], f\"layer {i} k {weight_type}\", dtype, verbose=debug_mode)\n            v = load_param(qkv[\"v_proj\"], f\"layer {i} v {weight_type}\", dtype, verbose=debug_mode)\n            qkv = torch.cat((q, k, v))\n            state_dict[f\"transformer.h.{i}.attn.qkv.{weight_type}\"] = qkv\n            del qkv_weights[i][weight_type]\n\n            if progress_per_file is not None:\n                pbar.update(progress_per_file)\n\n\ndef copy_weights_gemma_3(\n    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],\n    state_dict: Dict[str, torch.Tensor],\n    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n    dtype: Optional[torch.dtype] = None,\n    pbar: Optional[tqdm] = None,\n    progress_per_file: Optional[float] = None,\n    debug_mode: Optional[bool] = False,\n    config: Optional[Config] = None,\n) -> None:\n    GEMMA3_LANGUAGE_MODEL_PREFIX = (\n        \"model.language_model\"\n        if any(k.startswith(\"model.language_model\") for k in hf_weights)\n        else \"language_model.model\"\n    )\n\n    GEMMA3_VISION_MODEL_PREFIX = (\n        \"model.vision_tower\" if any(k.startswith(\"model.vision_tower\") for k in hf_weights) else \"vision_tower\"\n    )\n\n    GEMMA3_MM_PROJECTOR_PREFIX = (\n        \"model.multi_modal_projector\"\n        if any(k.startswith(\"model.multi_modal_projector\") for k in hf_weights)\n        else \"multi_modal_projector\"\n    )\n\n    weight_map = {\n        \"model.embed_tokens.weight\": 
\"transformer.wte.weight\",\n        \"model.layers.{}.self_attn.q_proj.weight\": None,\n        \"model.layers.{}.self_attn.k_proj.weight\": None,\n        \"model.layers.{}.self_attn.v_proj.weight\": None,\n        \"model.layers.{}.self_attn.o_proj.weight\": \"transformer.h.{}.attn.proj.weight\",\n        \"model.layers.{}.mlp.gate_proj.weight\": \"transformer.h.{}.mlp.fc_1.weight\",\n        \"model.layers.{}.mlp.up_proj.weight\": \"transformer.h.{}.mlp.fc_2.weight\",\n        \"model.layers.{}.mlp.down_proj.weight\": \"transformer.h.{}.mlp.proj.weight\",\n        \"model.layers.{}.input_layernorm.weight\": \"transformer.h.{}.norm_1.weight\",\n        \"model.layers.{}.post_attention_layernorm.weight\": \"transformer.h.{}.post_attention_norm.weight\",\n        \"model.layers.{}.pre_feedforward_layernorm.weight\": \"transformer.h.{}.norm_2.weight\",\n        \"model.layers.{}.post_feedforward_layernorm.weight\": \"transformer.h.{}.post_mlp_norm.weight\",\n        \"model.norm.weight\": \"transformer.ln_f.weight\",\n        \"lm_head.weight\": \"lm_head.weight\",\n        \"model.layers.{}.self_attn.q_norm.weight\": \"transformer.h.{}.attn.norm_q.weight\",\n        \"model.layers.{}.self_attn.k_norm.weight\": \"transformer.h.{}.attn.norm_k.weight\",\n    }\n\n    if progress_per_file is not None:\n        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))\n    # gemma3 4b+ are multimodel models, but we are only loading the text weights\n    is_multimodal = any(k.startswith(GEMMA3_LANGUAGE_MODEL_PREFIX) for k in hf_weights)\n    if is_multimodal:\n        warnings.warn(\"For Gemma3 models only the text component is supported.\")\n        new_weight_map = dict()\n        prefix = \"model\"\n        for k, v in weight_map.items():\n            if k.startswith(prefix):\n                k = GEMMA3_LANGUAGE_MODEL_PREFIX + k[len(prefix) :]\n            new_weight_map[k] = v\n        weight_map = new_weight_map\n    for from_name, param in 
hf_weights.items():\n        if from_name.startswith(GEMMA3_VISION_MODEL_PREFIX) or from_name.startswith(GEMMA3_MM_PROJECTOR_PREFIX):\n            continue\n        name_template, *ids = layer_template(from_name, num_matches=2)\n        to_name = weight_map.get(name_template)\n        param = load_param(param, from_name, dtype, verbose=debug_mode)\n        # in multimodal models, the text weights are the first part of the weights\n        if is_multimodal and to_name == \"transformer.wte.weight\" and config is not None:\n            param = param[: config.vocab_size]\n        if any(w in from_name for w in (\"q_proj\", \"k_proj\", \"v_proj\")):\n            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))\n            weight_name, weight_type = from_name.split(\".\")[-2:]\n            qkv[weight_type][weight_name] = param\n\n        if to_name is None:\n            continue\n        to_name = to_name.format(*ids)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n        if progress_per_file is not None:\n            pbar.update(progress_per_file)\n\n    if \"lm_head.weight\" not in state_dict:\n        state_dict[\"lm_head.weight\"] = state_dict[\"transformer.wte.weight\"]\n\n    for i in list(qkv_weights):\n        for weight_type in list(qkv_weights[i]):\n            qkv = qkv_weights[i][weight_type]\n            if len(qkv) != 3:\n                # qkv is split across different .bin files\n                continue\n            q = load_param(qkv[\"q_proj\"], f\"layer {i} q {weight_type}\", dtype, verbose=debug_mode)\n            k = load_param(qkv[\"k_proj\"], f\"layer {i} k {weight_type}\", dtype, verbose=debug_mode)\n            v = load_param(qkv[\"v_proj\"], f\"layer {i} v {weight_type}\", dtype, verbose=debug_mode)\n            qkv = torch.cat((q, k, v))\n            state_dict[f\"transformer.h.{i}.attn.qkv.{weight_type}\"] = qkv\n            del qkv_weights[i][weight_type]\n\n        
    if progress_per_file is not None:\n                pbar.update(progress_per_file)\n\n\ndef copy_weights_phi(\n    config: Config,\n    qkv_weights: dict,\n    state_dict: Dict[str, torch.Tensor],\n    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n    dtype: Optional[torch.dtype] = None,\n    pbar: Optional[tqdm] = None,\n    progress_per_file: Optional[float] = None,\n    debug_mode: Optional[bool] = False,\n) -> None:\n    if any(layer_name.startswith((\"layers.\", \"transformer.\")) for layer_name in hf_weights):\n        raise ValueError(\n            \"You are using an outdated Phi checkpoint. Please reload it as described in 'tutorials/download_phi.md'\"\n        )\n\n    weight_map = {\n        \"model.embed_tokens.weight\": \"transformer.wte.weight\",\n        \"model.layers.{}.input_layernorm.weight\": \"transformer.h.{}.norm_1.weight\",\n        \"model.layers.{}.input_layernorm.bias\": \"transformer.h.{}.norm_1.bias\",\n        \"model.layers.{}.self_attn.q_proj.weight\": None,\n        \"model.layers.{}.self_attn.q_proj.bias\": None,\n        \"model.layers.{}.self_attn.k_proj.weight\": None,\n        \"model.layers.{}.self_attn.k_proj.bias\": None,\n        \"model.layers.{}.self_attn.v_proj.weight\": None,\n        \"model.layers.{}.self_attn.v_proj.bias\": None,\n        \"model.layers.{}.self_attn.dense.weight\": \"transformer.h.{}.attn.proj.weight\",\n        \"model.layers.{}.self_attn.dense.bias\": \"transformer.h.{}.attn.proj.bias\",\n        \"model.layers.{}.mlp.fc1.weight\": \"transformer.h.{}.mlp.fc.weight\",\n        \"model.layers.{}.mlp.fc1.bias\": \"transformer.h.{}.mlp.fc.bias\",\n        \"model.layers.{}.mlp.fc2.weight\": \"transformer.h.{}.mlp.proj.weight\",\n        \"model.layers.{}.mlp.fc2.bias\": \"transformer.h.{}.mlp.proj.bias\",\n        \"model.final_layernorm.weight\": \"transformer.ln_f.weight\",\n        \"model.final_layernorm.bias\": 
\"transformer.ln_f.bias\",\n        \"lm_head.weight\": \"lm_head.weight\",\n        \"lm_head.bias\": \"lm_head.bias\",\n    }\n\n    if config.name.startswith((\"Phi-3\", \"phi-4\", \"Phi-4\")):\n        weight_map.update(\n            {\n                \"model.layers.{}.self_attn.qkv_proj.weight\": \"transformer.h.{}.attn.qkv.weight\",\n                \"model.layers.{}.self_attn.o_proj.weight\": \"transformer.h.{}.attn.proj.weight\",\n                \"model.layers.{}.post_attention_layernorm.weight\": \"transformer.h.{}.norm_2.weight\",\n                \"model.layers.{}.mlp.down_proj.weight\": \"transformer.h.{}.mlp.proj.weight\",\n                \"model.norm.weight\": \"transformer.ln_f.weight\",\n            }\n        )\n\n    if progress_per_file is not None:\n        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))\n\n    for from_name, param in hf_weights.items():\n        name_template, layer_idx = layer_template(from_name)\n        param = load_param(param, from_name, dtype, verbose=debug_mode)\n        if any(w in from_name for w in (\"q_proj\", \"k_proj\", \"v_proj\")):\n            qkv = qkv_weights.setdefault(layer_idx, defaultdict(dict))\n            weight_name, weight_type = from_name.split(\".\")[-2:]\n            qkv[weight_type][weight_name] = param\n        elif from_name.endswith(\"gate_up_proj.weight\"):\n            weight = load_param(param, f\"layer {layer_idx} gate_up_proj\", dtype, verbose=debug_mode)\n            fc_1, fc_2 = weight.chunk(2, dim=0)\n            state_dict[f\"transformer.h.{layer_idx}.mlp.fc_1.weight\"] = fc_1\n            state_dict[f\"transformer.h.{layer_idx}.mlp.fc_2.weight\"] = fc_2\n            continue\n        to_name = weight_map[name_template]\n        if to_name is None:\n            continue\n        to_name = to_name.format(layer_idx)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n        if 
progress_per_file is not None:\n            pbar.update(progress_per_file)\n\n    if \"lm_head.weight\" not in state_dict and config.name.startswith(\"Phi-4\"):\n        state_dict[\"lm_head.weight\"] = state_dict[\"transformer.wte.weight\"]\n\n    for i in list(qkv_weights):\n        for weight_type in list(qkv_weights[i]):\n            qkv = qkv_weights[i][weight_type]\n            if len(qkv) != 3:\n                # qkv is split across different .bin files\n                continue\n            q = load_param(qkv[\"q_proj\"], f\"layer {i} q {weight_type}\", dtype, verbose=debug_mode)\n            k = load_param(qkv[\"k_proj\"], f\"layer {i} k {weight_type}\", dtype, verbose=debug_mode)\n            v = load_param(qkv[\"v_proj\"], f\"layer {i} v {weight_type}\", dtype, verbose=debug_mode)\n            qkv = torch.cat((q, k, v))\n            state_dict[f\"transformer.h.{i}.attn.qkv.{weight_type}\"] = qkv\n            del qkv_weights[i][weight_type]\n\n            if progress_per_file is not None:\n                pbar.update(progress_per_file)\n\n\ndef copy_weights_qwen_2_5(\n    config: Config,\n    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],\n    state_dict: Dict[str, torch.Tensor],\n    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n    dtype: Optional[torch.dtype] = None,\n    pbar: Optional[tqdm] = None,\n    progress_per_file: Optional[float] = None,\n    debug_mode: Optional[bool] = False,\n) -> None:\n    weight_map = {\n        \"model.embed_tokens.weight\": \"transformer.wte.weight\",\n        \"model.layers.{}.input_layernorm.weight\": \"transformer.h.{}.norm_1.weight\",\n        \"model.layers.{}.self_attn.q_proj.weight\": None,\n        \"model.layers.{}.self_attn.k_proj.weight\": None,\n        \"model.layers.{}.self_attn.v_proj.weight\": None,\n        \"model.layers.{}.self_attn.q_proj.bias\": None,\n        \"model.layers.{}.self_attn.k_proj.bias\": None,\n        
\"model.layers.{}.self_attn.v_proj.bias\": None,\n        \"model.layers.{}.self_attn.o_proj.weight\": \"transformer.h.{}.attn.proj.weight\",\n        \"model.layers.{}.post_attention_layernorm.weight\": \"transformer.h.{}.norm_2.weight\",\n        \"model.layers.{}.mlp.gate_proj.weight\": \"transformer.h.{}.mlp.fc_1.weight\",\n        \"model.layers.{}.mlp.up_proj.weight\": \"transformer.h.{}.mlp.fc_2.weight\",\n        \"model.layers.{}.mlp.down_proj.weight\": \"transformer.h.{}.mlp.proj.weight\",\n        \"model.norm.weight\": \"transformer.ln_f.weight\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n\n    if progress_per_file is not None:\n        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))\n\n    for from_name, param in hf_weights.items():\n        name_template, *ids = layer_template(from_name, num_matches=2)\n        to_name = weight_map[name_template]\n        param = load_param(param, from_name, dtype, verbose=debug_mode)\n        if any(w in from_name for w in (\"q_proj\", \"k_proj\", \"v_proj\")):\n            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))\n            weight_name, weight_type = from_name.split(\".\")[-2:]\n            qkv[weight_type][weight_name] = param\n        if to_name is None:\n            continue\n        to_name = to_name.format(*ids)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n        if progress_per_file is not None:\n            pbar.update(progress_per_file)\n\n    if \"lm_head.weight\" not in state_dict:\n        state_dict[\"lm_head.weight\"] = state_dict[\"transformer.wte.weight\"]\n\n    for i in list(qkv_weights):\n        for weight_type in list(qkv_weights[i]):\n            qkv = qkv_weights[i][weight_type]\n            if len(qkv) != 3:\n                # qkv is split across different .bin files\n                continue\n            q = load_param(qkv[\"q_proj\"], f\"layer {i} q 
{weight_type}\", dtype, verbose=debug_mode)\n            k = load_param(qkv[\"k_proj\"], f\"layer {i} k {weight_type}\", dtype, verbose=debug_mode)\n            v = load_param(qkv[\"v_proj\"], f\"layer {i} v {weight_type}\", dtype, verbose=debug_mode)\n            qkv = torch.cat((q, k, v))\n            state_dict[f\"transformer.h.{i}.attn.qkv.{weight_type}\"] = qkv\n            del qkv_weights[i][weight_type]\n\n            if progress_per_file is not None:\n                pbar.update(progress_per_file)\n\n\ndef copy_weights_olmo2(\n    config: Config,\n    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],\n    state_dict: Dict[str, torch.Tensor],\n    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n    dtype: Optional[torch.dtype] = None,\n    pbar: Optional[tqdm] = None,\n    progress_per_file: Optional[float] = None,\n    debug_mode: Optional[bool] = False,\n) -> None:\n    weight_map = {\n        \"model.embed_tokens.weight\": \"transformer.wte.weight\",\n        \"model.layers.{}.self_attn.q_norm.weight\": \"transformer.h.{}.attn.norm_q.weight\",\n        \"model.layers.{}.self_attn.q_proj.weight\": None,\n        \"model.layers.{}.self_attn.k_norm.weight\": \"transformer.h.{}.attn.norm_k.weight\",\n        \"model.layers.{}.self_attn.k_proj.weight\": None,\n        \"model.layers.{}.self_attn.v_proj.weight\": None,\n        \"model.layers.{}.self_attn.o_proj.weight\": \"transformer.h.{}.attn.proj.weight\",\n        \"model.layers.{}.self_attn.rotary_emb.inv_freq\": None,\n        \"model.layers.{}.post_attention_layernorm.weight\": \"transformer.h.{}.post_attention_norm.weight\",\n        \"model.layers.{}.post_attention_layernorm.bias\": \"transformer.h.{}.post_attention_norm.bias\",\n        \"model.layers.{}.post_feedforward_layernorm.weight\": \"transformer.h.{}.post_mlp_norm.weight\",\n        \"model.norm.weight\": \"transformer.ln_f.weight\",\n        \"model.norm.bias\": 
\"transformer.ln_f.bias\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n    if config.mlp_class_name in (\"LLaMAMLP\", \"GemmaMLP\"):\n        weight_map.update(\n            {\n                \"model.layers.{}.mlp.gate_proj.weight\": \"transformer.h.{}.mlp.fc_1.weight\",\n                \"model.layers.{}.mlp.up_proj.weight\": \"transformer.h.{}.mlp.fc_2.weight\",\n                \"model.layers.{}.mlp.down_proj.weight\": \"transformer.h.{}.mlp.proj.weight\",\n            }\n        )\n    else:\n        raise NotImplementedError\n\n    if progress_per_file is not None:\n        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))\n\n    for from_name, param in hf_weights.items():\n        name_template, *ids = layer_template(from_name, num_matches=2)\n        to_name = weight_map[name_template]\n        param = load_param(param, from_name, dtype, verbose=debug_mode)\n        if any(w in from_name for w in (\"q_proj\", \"k_proj\", \"v_proj\")):\n            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))\n            weight_name, weight_type = from_name.split(\".\")[-2:]\n            qkv[weight_type][weight_name] = param\n        if to_name is None:\n            continue\n        to_name = to_name.format(*ids)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n        if progress_per_file is not None:\n            pbar.update(progress_per_file)\n\n    if \"lm_head.weight\" not in state_dict:\n        state_dict[\"lm_head.weight\"] = state_dict[\"transformer.wte.weight\"]\n\n    for i in list(qkv_weights):\n        for weight_type in list(qkv_weights[i]):\n            qkv = qkv_weights[i][weight_type]\n            if len(qkv) != 3:\n                # qkv is split across different .bin files\n                continue\n            q = load_param(qkv[\"q_proj\"], f\"layer {i} q {weight_type}\", dtype, verbose=debug_mode)\n            k = 
load_param(qkv[\"k_proj\"], f\"layer {i} k {weight_type}\", dtype, verbose=debug_mode)\n            v = load_param(qkv[\"v_proj\"], f\"layer {i} v {weight_type}\", dtype, verbose=debug_mode)\n            qkv = torch.cat((q, k, v))\n            state_dict[f\"transformer.h.{i}.attn.qkv.{weight_type}\"] = qkv\n            del qkv_weights[i][weight_type]\n\n            if progress_per_file is not None:\n                pbar.update(progress_per_file)\n\n\ndef copy_weights_qwen_3(\n    config: Config,\n    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],\n    state_dict: Dict[str, torch.Tensor],\n    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n    dtype: Optional[torch.dtype] = None,\n    pbar: Optional[tqdm] = None,\n    progress_per_file: Optional[float] = None,\n    debug_mode: Optional[bool] = False,\n) -> None:\n    weight_map = {\n        \"model.embed_tokens.weight\": \"transformer.wte.weight\",\n        \"model.layers.{}.input_layernorm.weight\": \"transformer.h.{}.norm_1.weight\",\n        \"model.layers.{}.self_attn.q_proj.weight\": None,\n        \"model.layers.{}.self_attn.k_proj.weight\": None,\n        \"model.layers.{}.self_attn.v_proj.weight\": None,\n        \"model.layers.{}.self_attn.q_norm.weight\": \"transformer.h.{}.attn.norm_q.weight\",\n        \"model.layers.{}.self_attn.k_norm.weight\": \"transformer.h.{}.attn.norm_k.weight\",\n        \"model.layers.{}.self_attn.o_proj.weight\": \"transformer.h.{}.attn.proj.weight\",\n        \"model.layers.{}.post_attention_layernorm.weight\": \"transformer.h.{}.norm_2.weight\",\n        \"model.norm.weight\": \"transformer.ln_f.weight\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n    if config.mlp_class_name == \"LLaMAMoE\":\n        weight_map.update(\n            {\n                \"model.layers.{}.mlp.experts.{}.gate_proj.weight\": \"transformer.h.{}.mlp.experts.{}.fc_1.weight\",\n                
\"model.layers.{}.mlp.experts.{}.up_proj.weight\": \"transformer.h.{}.mlp.experts.{}.fc_2.weight\",\n                \"model.layers.{}.mlp.experts.{}.down_proj.weight\": \"transformer.h.{}.mlp.experts.{}.proj.weight\",\n                \"model.layers.{}.mlp.gate.weight\": \"transformer.h.{}.mlp.gate.weight\",\n            }\n        )\n    elif config.mlp_class_name == \"LLaMAMLP\":\n        weight_map.update(\n            {\n                \"model.layers.{}.mlp.gate_proj.weight\": \"transformer.h.{}.mlp.fc_1.weight\",\n                \"model.layers.{}.mlp.up_proj.weight\": \"transformer.h.{}.mlp.fc_2.weight\",\n                \"model.layers.{}.mlp.down_proj.weight\": \"transformer.h.{}.mlp.proj.weight\",\n            }\n        )\n    else:\n        raise NotImplementedError\n\n    if progress_per_file is not None:\n        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))\n\n    for from_name, param in hf_weights.items():\n        name_template, *ids = layer_template(from_name, num_matches=2)\n        to_name = weight_map[name_template]\n        param = load_param(param, from_name, dtype, verbose=debug_mode)\n        if any(w in from_name for w in (\"q_proj\", \"k_proj\", \"v_proj\")):\n            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))\n            weight_name, weight_type = from_name.split(\".\")[-2:]\n            qkv[weight_type][weight_name] = param\n        if to_name is None:\n            continue\n        to_name = to_name.format(*ids)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n        if progress_per_file is not None:\n            pbar.update(progress_per_file)\n\n    if \"lm_head.weight\" not in state_dict:\n        state_dict[\"lm_head.weight\"] = state_dict[\"transformer.wte.weight\"]\n\n    for i in list(qkv_weights):\n        for weight_type in list(qkv_weights[i]):\n            qkv = qkv_weights[i][weight_type]\n            
if len(qkv) != 3:\n                # qkv is split across different .bin files\n                continue\n            q = load_param(qkv[\"q_proj\"], f\"layer {i} q {weight_type}\", dtype, verbose=debug_mode)\n            k = load_param(qkv[\"k_proj\"], f\"layer {i} k {weight_type}\", dtype, verbose=debug_mode)\n            v = load_param(qkv[\"v_proj\"], f\"layer {i} v {weight_type}\", dtype, verbose=debug_mode)\n            qkv = torch.cat((q, k, v))\n            state_dict[f\"transformer.h.{i}.attn.qkv.{weight_type}\"] = qkv\n            del qkv_weights[i][weight_type]\n\n            if progress_per_file is not None:\n                pbar.update(progress_per_file)\n\n\ndef qkv_reassemble(\n    param: Union[torch.Tensor, NotYetLoadedTensor], config: Config\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"Reassemble from a normal to an interleaved placement in a QKV matrix.\n    [Q, K, V, Q, K, V, ...] --> [Q, Q, ..., K, K, ..., V, V, ...]\n    \"\"\"\n    q_per_kv = config.n_head // config.n_query_groups\n    qs = []\n    ks = []\n    vs = []\n    for chunk in torch.chunk(param, config.n_query_groups):\n        split = torch.split(chunk, [config.head_size * q_per_kv, config.head_size, config.head_size])\n        qs.append(split[0])\n        ks.append(split[1])\n        vs.append(split[2])\n    q = torch.cat(qs)\n    k = torch.cat(ks)\n    v = torch.cat(vs)\n    return torch.cat((q, k, v))\n\n\ndef layer_template(layer_name: str, num_matches: int = 1) -> Tuple[str, int]:\n    pattern = r\"\\.(\\d+)\\.\"\n    if not (search_res := re.findall(pattern, layer_name)):\n        return layer_name, -1\n    layer_name_template = re.sub(pattern, \".{}.\", layer_name, count=num_matches)\n    return layer_name_template, *(int(x) for x in search_res[:num_matches])\n\n\ndef load_param(\n    param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose: bool = False\n) -> torch.Tensor:\n    if hasattr(param, \"_load_tensor\"):\n 
       # support tensors loaded via `lazy_load()`\n        if verbose:\n            print(f\"Loading {name!r} into RAM\")\n        param = param._load_tensor()\n    if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype:\n        if verbose:\n            print(f\"Converting {name!r} from {param.dtype} to {dtype}\")\n        param = param.to(dtype)\n    return param\n\n\n@torch.inference_mode()\ndef convert_hf_checkpoint(\n    checkpoint_dir: Path,\n    *,\n    model_name: Optional[str] = None,\n    dtype: Optional[str] = None,\n    debug_mode: Optional[bool] = False,\n) -> None:\n    \"\"\"\n    Convert a Hugging Face Transformers checkpoint into a LitGPT compatible checkpoint.\n\n    Arguments:\n        checkpoint_dir: Where to save the downloaded files.\n        model_name: The existing config name to load. This is useful to download alternative weights of existing\n            architectures.\n        dtype: The data type to convert the checkpoint files to. 
If not specified, the weights will remain in the\n            dtype they are downloaded in.\n        debug_mode: Prints the individual layers being loaded instead of a progress bar, which can be useful when\n            developing and adding new models to LitGPT.\n    \"\"\"\n    checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)\n    pprint(locals())\n\n    if model_name is None:\n        model_name = checkpoint_dir.name\n    if dtype is not None:\n        dtype = getattr(torch, dtype)\n\n    config = Config.from_name(model_name)\n    save_config(config, checkpoint_dir)\n\n    if \"falcon\" in model_name:\n        copy_fn = partial(copy_weights_falcon, config)\n    elif model_name.lower().startswith(\"gemma-2\"):\n        qkv_weights = {}\n        copy_fn = partial(copy_weights_gemma_2, qkv_weights)\n    elif model_name.lower().startswith(\"gemma-3\"):\n        qkv_weights = {}\n        copy_fn = partial(copy_weights_gemma_3, qkv_weights, config=config)\n    elif model_name.lower().startswith(\"phi\"):\n        # holder to reconstitute the split q, k, v\n        qkv_weights = {}\n        copy_fn = partial(copy_weights_phi, config, qkv_weights)\n    elif model_name.lower().startswith((\"qwen2.5\", \"qwq\")):\n        # holder to reconstitute the split q, k, v\n        qkv_weights = {}\n        copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights)\n    elif model_name.lower().startswith(\"olmo-2-\"):\n        # holder to reconstitute the split q, k, v\n        qkv_weights = {}\n        copy_fn = partial(copy_weights_olmo2, config, qkv_weights)\n    elif model_name.lower().startswith(\"qwen3\"):\n        # holder to reconstitute the split q, k, v\n        qkv_weights = {}\n        copy_fn = partial(copy_weights_qwen_3, config, qkv_weights)\n    elif config.mlp_class_name in (\"LLaMAMLP\", \"GemmaMLP\", \"LLaMAMoE\"):\n        # holder to reconstitute the split q, k, v\n        qkv_weights = {}\n        copy_fn = partial(copy_weights_hf_llama, config, 
qkv_weights)\n    else:\n        copy_fn = partial(copy_weights_gpt_neox, config)\n\n    # initialize a new empty state dict to hold our new weights\n    sd = {}\n\n    # Load the json file containing weight mapping\n    pytorch_bin_map_json_path = checkpoint_dir / \"pytorch_model.bin.index.json\"\n    model_safetensor_map_json_path = checkpoint_dir / \"model.safetensors.index.json\"\n    if pytorch_bin_map_json_path.is_file():  # not all checkpoints have this file\n        with open(pytorch_bin_map_json_path, encoding=\"utf-8\") as json_map:\n            bin_index = json.load(json_map)\n        bin_files = {checkpoint_dir / bin for bin in bin_index[\"weight_map\"].values()}\n    elif model_safetensor_map_json_path.is_file():\n        with open(model_safetensor_map_json_path, encoding=\"utf-8\") as json_map:\n            bin_index = json.load(json_map)\n        bin_files = {checkpoint_dir / bin for bin in bin_index[\"weight_map\"].values()}\n    else:\n        bin_files = set(checkpoint_dir.glob(\"*.bin\")) | set(checkpoint_dir.glob(\"*.safetensors\"))\n        # some checkpoints serialize the training arguments\n        bin_files = {f for f in bin_files if f.name != \"training_args.bin\"}\n    if not bin_files:\n        raise ValueError(f\"Expected {str(checkpoint_dir)!r} to contain .bin or .safetensors files\")\n\n    with incremental_save(checkpoint_dir / \"lit_model.pth\") as saver:\n        # for checkpoints that split the QKV across several files, we need to keep all the bin files\n        # open, so we use `ExitStack` to close them all together at the end\n\n        if not debug_mode:\n            # Using tqdm progress bar when not in debug mode\n\n            total_size = max(1, sum(os.path.getsize(bin_file) for bin_file in bin_files))\n            total_progress = 100\n\n            with tqdm(\n                total=total_progress,\n                desc=\"Initializing\",\n                bar_format=\"{desc}{percentage:3.0f}%|{bar}| {elapsed}<{remaining}, 
{rate_fmt}\",\n            ) as pbar:\n                for bin_file in sorted(bin_files):\n                    pbar.set_description(f\"Loading weights: {bin_file.name}\")\n                    current_file_size = os.path.getsize(bin_file)\n                    progress_per_file = (current_file_size / total_size) * total_progress\n\n                    hf_weights = (\n                        load_safetensors(bin_file) if bin_file.suffix == \".safetensors\" else lazy_load(bin_file)\n                    )\n                    copy_fn(\n                        sd,\n                        hf_weights,\n                        saver=saver,\n                        dtype=dtype,\n                        pbar=pbar,\n                        progress_per_file=progress_per_file,\n                        debug_mode=debug_mode,\n                    )\n                gc.collect()\n\n                if pbar.n < total_progress:\n                    pbar.update(total_progress - pbar.n)\n                pbar.close()\n        else:\n            # Handling files without progress bar in debug mode\n            for bin_file in sorted(bin_files):\n                hf_weights = load_safetensors(bin_file) if bin_file.suffix == \".safetensors\" else lazy_load(bin_file)\n                copy_fn(sd, hf_weights, saver=saver, dtype=dtype, debug_mode=debug_mode)\n        print(f\"Saving converted checkpoint to {checkpoint_dir}\")\n        saver.save(sd)\n"
  },
  {
    "path": "litgpt/scripts/convert_lit_checkpoint.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport gc\nfrom collections import defaultdict\nfrom functools import partial\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Dict, Optional, Union\n\nimport torch\nfrom lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor\n\nfrom litgpt import Config\nfrom litgpt.scripts.convert_hf_checkpoint import layer_template, load_param\nfrom litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load\n\n\ndef copy_weights_falcon(\n    config: Config,\n    state_dict: Dict[str, torch.Tensor],\n    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n) -> None:\n    weight_map = {\n        \"transformer.wte.weight\": \"transformer.word_embeddings.weight\",\n        \"transformer.h.{}.attn.qkv.weight\": \"transformer.h.{}.self_attention.query_key_value.weight\",\n        \"transformer.h.{}.attn.proj.weight\": \"transformer.h.{}.self_attention.dense.weight\",\n        \"transformer.h.{}.mlp.fc.weight\": \"transformer.h.{}.mlp.dense_h_to_4h.weight\",\n        \"transformer.h.{}.mlp.proj.weight\": \"transformer.h.{}.mlp.dense_4h_to_h.weight\",\n        \"transformer.ln_f.bias\": \"transformer.ln_f.bias\",\n        \"transformer.ln_f.weight\": \"transformer.ln_f.weight\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n    # the original model definition is different for each size\n    if \"7b\" in config.name:\n        weight_map.update(\n            {\n                \"transformer.h.{}.norm_1.bias\": \"transformer.h.{}.input_layernorm.bias\",\n                \"transformer.h.{}.norm_1.weight\": \"transformer.h.{}.input_layernorm.weight\",\n            }\n        )\n    elif \"40b\" in config.name or \"180B\" in config.name:\n        weight_map.update(\n            {\n                \"transformer.h.{}.norm_1.bias\": 
\"transformer.h.{}.ln_attn.bias\",\n                \"transformer.h.{}.norm_1.weight\": \"transformer.h.{}.ln_attn.weight\",\n                \"transformer.h.{}.norm_2.bias\": \"transformer.h.{}.ln_mlp.bias\",\n                \"transformer.h.{}.norm_2.weight\": \"transformer.h.{}.ln_mlp.weight\",\n            }\n        )\n    else:\n        raise NotImplementedError\n\n    for from_name, param in lit_weights.items():\n        name_template, layer_idx = layer_template(from_name)\n        to_name = weight_map[name_template].format(layer_idx)\n        param = load_param(param, from_name, None)\n        if from_name.endswith((\".attn.qkv.weight\", \".attn.qkv.bias\")):\n            # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...]\n            param = qkv_reassemble(param, config)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n\ndef copy_weights_gpt_neox(\n    config: Config,\n    state_dict: Dict[str, torch.Tensor],\n    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n) -> None:\n    weight_map = {\n        \"transformer.wte.weight\": \"gpt_neox.embed_in.weight\",\n        \"transformer.h.{}.norm_1.bias\": \"gpt_neox.layers.{}.input_layernorm.bias\",\n        \"transformer.h.{}.norm_1.weight\": \"gpt_neox.layers.{}.input_layernorm.weight\",\n        \"transformer.h.{}.attn.qkv.bias\": \"gpt_neox.layers.{}.attention.query_key_value.bias\",\n        \"transformer.h.{}.attn.qkv.weight\": \"gpt_neox.layers.{}.attention.query_key_value.weight\",\n        \"transformer.h.{}.attn.proj.bias\": \"gpt_neox.layers.{}.attention.dense.bias\",\n        \"transformer.h.{}.attn.proj.weight\": \"gpt_neox.layers.{}.attention.dense.weight\",\n        \"transformer.h.{}.norm_2.bias\": \"gpt_neox.layers.{}.post_attention_layernorm.bias\",\n        \"transformer.h.{}.norm_2.weight\": 
\"gpt_neox.layers.{}.post_attention_layernorm.weight\",\n        \"transformer.h.{}.mlp.fc.bias\": \"gpt_neox.layers.{}.mlp.dense_h_to_4h.bias\",\n        \"transformer.h.{}.mlp.fc.weight\": \"gpt_neox.layers.{}.mlp.dense_h_to_4h.weight\",\n        \"transformer.h.{}.mlp.proj.bias\": \"gpt_neox.layers.{}.mlp.dense_4h_to_h.bias\",\n        \"transformer.h.{}.mlp.proj.weight\": \"gpt_neox.layers.{}.mlp.dense_4h_to_h.weight\",\n        \"transformer.ln_f.bias\": \"gpt_neox.final_layer_norm.bias\",\n        \"transformer.ln_f.weight\": \"gpt_neox.final_layer_norm.weight\",\n        \"lm_head.weight\": \"embed_out.weight\",\n    }\n\n    for from_name, param in lit_weights.items():\n        name_template, layer_idx = layer_template(from_name)\n        to_name = weight_map[name_template].format(layer_idx)\n        param = load_param(param, from_name, None)\n        if from_name.endswith((\".attn.qkv.weight\", \".attn.qkv.bias\")):\n            # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...]\n            param = qkv_reassemble(param, config)\n        if saver is not None:\n            param = saver.store_early(param)\n        state_dict[to_name] = param\n\n\ndef copy_weights_llama(\n    config: Config,\n    state_dict: Dict[str, torch.Tensor],\n    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    untie_weights: bool = False,\n    saver: Optional[incremental_save] = None,\n) -> None:\n    weight_map = {\n        \"transformer.wte.weight\": \"model.embed_tokens.weight\",\n        \"transformer.h.{}.norm_1.weight\": \"model.layers.{}.input_layernorm.weight\",\n        \"transformer.h.{}.norm_1.bias\": \"model.layers.{}.input_layernorm.bias\",\n        \"transformer.h.{}.attn.proj.weight\": \"model.layers.{}.self_attn.o_proj.weight\",\n        \"transformer.h.{}.norm_2.weight\": \"model.layers.{}.post_attention_layernorm.weight\",\n        \"transformer.h.{}.norm_2.bias\": \"model.layers.{}.post_attention_layernorm.bias\",\n     
   \"transformer.ln_f.weight\": \"model.norm.weight\",\n        \"transformer.ln_f.bias\": \"model.norm.bias\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n    if config.mlp_class_name == \"LLaMAMoE\":\n        weight_map.update(\n            {\n                \"transformer.h.{}.mlp.gate.weight\": \"model.layers.{}.block_sparse_moe.gate.weight\",\n                \"transformer.h.{}.mlp.experts.{}.fc_1.weight\": \"model.layers.{}.block_sparse_moe.experts.{}.w1.weight\",\n                \"transformer.h.{}.mlp.experts.{}.fc_2.weight\": \"model.layers.{}.block_sparse_moe.experts.{}.w3.weight\",\n                \"transformer.h.{}.mlp.experts.{}.proj.weight\": \"model.layers.{}.block_sparse_moe.experts.{}.w2.weight\",\n            }\n        )\n    elif config.mlp_class_name in (\"LLaMAMLP\", \"GemmaMLP\"):\n        weight_map.update(\n            {\n                \"transformer.h.{}.mlp.fc_1.weight\": \"model.layers.{}.mlp.gate_proj.weight\",\n                \"transformer.h.{}.mlp.fc_2.weight\": \"model.layers.{}.mlp.up_proj.weight\",\n                \"transformer.h.{}.mlp.proj.weight\": \"model.layers.{}.mlp.down_proj.weight\",\n            }\n        )\n    else:\n        raise NotImplementedError\n\n    for from_name, param in lit_weights.items():\n        if from_name == \"lm_head.weight\" and untie_weights:\n            continue\n        name_template, *ids = layer_template(from_name, num_matches=2)\n        param = load_param(param, from_name, None)\n        if from_name.endswith(\".attn.qkv.weight\"):\n            to_names = (\n                \"model.layers.{}.self_attn.q_proj.weight\".format(*ids),\n                \"model.layers.{}.self_attn.k_proj.weight\".format(*ids),\n                \"model.layers.{}.self_attn.v_proj.weight\".format(*ids),\n            )\n            params = param.split(\n                (\n                    config.n_head * config.head_size,\n                    config.n_query_groups * config.head_size,\n             
       config.n_query_groups * config.head_size,\n                )\n            )\n        else:\n            to_names = (weight_map[name_template].format(*ids),)\n            params = (param,)\n\n        for to_name, param in zip(to_names, params):\n            if saver is not None:\n                param = saver.store_early(param)\n            state_dict[to_name] = param\n\n\ndef copy_weights_gemma_2(\n    config: Config,\n    state_dict: Dict[str, torch.Tensor],\n    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    untie_weights: bool = True,\n    saver: Optional[incremental_save] = None,\n) -> None:\n    weight_map = {\n        \"transformer.wte.weight\": \"model.embed_tokens.weight\",\n        \"transformer.h.{}.attn.proj.weight\": \"model.layers.{}.self_attn.o_proj.weight\",\n        \"transformer.h.{}.mlp.fc_1.weight\": \"model.layers.{}.mlp.gate_proj.weight\",\n        \"transformer.h.{}.mlp.fc_2.weight\": \"model.layers.{}.mlp.up_proj.weight\",\n        \"transformer.h.{}.mlp.proj.weight\": \"model.layers.{}.mlp.down_proj.weight\",\n        \"transformer.h.{}.norm_1.weight\": \"model.layers.{}.input_layernorm.weight\",\n        \"transformer.h.{}.post_attention_norm.weight\": \"model.layers.{}.post_attention_layernorm.weight\",\n        \"transformer.h.{}.norm_2.weight\": \"model.layers.{}.pre_feedforward_layernorm.weight\",\n        \"transformer.h.{}.post_mlp_norm.weight\": \"model.layers.{}.post_feedforward_layernorm.weight\",\n        \"transformer.ln_f.weight\": \"model.norm.weight\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n\n    for from_name, param in lit_weights.items():\n        if from_name == \"lm_head.weight\" and untie_weights:\n            continue\n        name_template, *ids = layer_template(from_name, num_matches=2)\n        param = load_param(param, from_name, None)\n        if from_name.endswith(\".attn.qkv.weight\"):\n            to_names = (\n                
\"model.layers.{}.self_attn.q_proj.weight\".format(*ids),\n                \"model.layers.{}.self_attn.k_proj.weight\".format(*ids),\n                \"model.layers.{}.self_attn.v_proj.weight\".format(*ids),\n            )\n            params = param.split(\n                (\n                    config.n_head * config.head_size,\n                    config.n_query_groups * config.head_size,\n                    config.n_query_groups * config.head_size,\n                )\n            )\n        else:\n            to_names = (weight_map[name_template].format(*ids),)\n            params = (param,)\n\n        for to_name, param in zip(to_names, params):\n            if saver is not None:\n                param = saver.store_early(param)\n            state_dict[to_name] = param\n\n\ndef copy_weights_gemma_3(\n    config: Config,\n    state_dict: Dict[str, torch.Tensor],\n    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    untie_weights: bool = True,\n    saver: Optional[incremental_save] = None,\n) -> None:\n    weight_map = {\n        \"transformer.wte.weight\": \"model.embed_tokens.weight\",\n        \"transformer.h.{}.attn.proj.weight\": \"model.layers.{}.self_attn.o_proj.weight\",\n        \"transformer.h.{}.mlp.fc_1.weight\": \"model.layers.{}.mlp.gate_proj.weight\",\n        \"transformer.h.{}.mlp.fc_2.weight\": \"model.layers.{}.mlp.up_proj.weight\",\n        \"transformer.h.{}.mlp.proj.weight\": \"model.layers.{}.mlp.down_proj.weight\",\n        \"transformer.h.{}.norm_1.weight\": \"model.layers.{}.input_layernorm.weight\",\n        \"transformer.h.{}.post_attention_norm.weight\": \"model.layers.{}.post_attention_layernorm.weight\",\n        \"transformer.h.{}.norm_2.weight\": \"model.layers.{}.pre_feedforward_layernorm.weight\",\n        \"transformer.h.{}.post_mlp_norm.weight\": \"model.layers.{}.post_feedforward_layernorm.weight\",\n        \"transformer.ln_f.weight\": \"model.norm.weight\",\n        \"lm_head.weight\": 
\"lm_head.weight\",\n        \"transformer.h.{}.attn.norm_q.weight\": \"model.layers.{}.self_attn.q_norm.weight\",\n        \"transformer.h.{}.attn.norm_k.weight\": \"model.layers.{}.self_attn.k_norm.weight\",\n    }\n\n    for from_name, param in lit_weights.items():\n        if from_name == \"lm_head.weight\" and untie_weights:\n            continue\n        name_template, *ids = layer_template(from_name, num_matches=2)\n        param = load_param(param, from_name, None)\n        if from_name.endswith(\".attn.qkv.weight\"):\n            to_names = (\n                \"model.layers.{}.self_attn.q_proj.weight\".format(*ids),\n                \"model.layers.{}.self_attn.k_proj.weight\".format(*ids),\n                \"model.layers.{}.self_attn.v_proj.weight\".format(*ids),\n            )\n            params = param.split(\n                (\n                    config.n_head * config.head_size,\n                    config.n_query_groups * config.head_size,\n                    config.n_query_groups * config.head_size,\n                )\n            )\n        else:\n            to_names = (weight_map[name_template].format(*ids),)\n            params = (param,)\n\n        for to_name, param in zip(to_names, params):\n            if saver is not None:\n                param = saver.store_early(param)\n            state_dict[to_name] = param\n\n\ndef copy_weights_phi(\n    config: Config,\n    state_dict: Dict[str, torch.Tensor],\n    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    saver: Optional[incremental_save] = None,\n) -> None:\n    weight_map = {\n        \"transformer.wte.weight\": \"model.embed_tokens.weight\",\n        \"transformer.h.{}.norm_1.weight\": \"model.layers.{}.input_layernorm.weight\",\n        \"transformer.h.{}.norm_1.bias\": \"model.layers.{}.input_layernorm.bias\",\n        \"transformer.h.{}.attn.proj.weight\": \"model.layers.{}.self_attn.dense.weight\",\n        \"transformer.h.{}.attn.proj.bias\": 
\"model.layers.{}.self_attn.dense.bias\",\n        \"transformer.h.{}.mlp.fc.weight\": \"model.layers.{}.mlp.fc1.weight\",\n        \"transformer.h.{}.mlp.fc.bias\": \"model.layers.{}.mlp.fc1.bias\",\n        \"transformer.h.{}.mlp.proj.weight\": \"model.layers.{}.mlp.fc2.weight\",\n        \"transformer.h.{}.mlp.proj.bias\": \"model.layers.{}.mlp.fc2.bias\",\n        \"transformer.ln_f.weight\": \"model.final_layernorm.weight\",\n        \"transformer.ln_f.bias\": \"model.final_layernorm.bias\",\n        \"lm_head.weight\": \"lm_head.weight\",\n        \"lm_head.bias\": \"lm_head.bias\",\n    }\n    if config.name.lower().startswith((\"phi-3\", \"phi-4\")):\n        weight_map.update(\n            {\n                \"transformer.h.{}.attn.qkv.weight\": \"model.layers.{}.self_attn.qkv_proj.weight\",\n                \"transformer.h.{}.attn.proj.weight\": \"model.layers.{}.self_attn.o_proj.weight\",\n                \"transformer.h.{}.norm_2.weight\": \"model.layers.{}.post_attention_layernorm.weight\",\n                \"transformer.h.{}.mlp.proj.weight\": \"model.layers.{}.mlp.down_proj.weight\",\n                \"transformer.ln_f.weight\": \"model.norm.weight\",\n            }\n        )\n        gate_up_proj_weights = defaultdict(dict)\n\n    for from_name, param in lit_weights.items():\n        if from_name == \"lm_head.weight\" and config.name.startswith(\"Phi-4\"):\n            continue\n        name_template, layer_idx = layer_template(from_name)\n        param = load_param(param, from_name, None)\n        if from_name.endswith((\".attn.qkv.weight\", \".attn.qkv.bias\")):\n            if config.name.lower().startswith((\"phi-3\", \"phi-4\")):\n                to_names = (weight_map[name_template].format(layer_idx),)\n                params = (param,)\n            else:\n                weight_type = from_name.split(\".\")[-1]  # weight or bias\n                to_names = (\n                    
f\"model.layers.{{}}.self_attn.q_proj.{weight_type}\".format(layer_idx),\n                    f\"model.layers.{{}}.self_attn.k_proj.{weight_type}\".format(layer_idx),\n                    f\"model.layers.{{}}.self_attn.v_proj.{weight_type}\".format(layer_idx),\n                )\n                params = param.split(\n                    (\n                        config.n_head * config.head_size,\n                        config.n_query_groups * config.head_size,\n                        config.n_query_groups * config.head_size,\n                    )\n                )\n        elif from_name.endswith((\".fc_1.weight\", \".fc_2.weight\")):\n            weight = load_param(param, from_name, None)\n            weight_name = from_name.split(\".\")[-2]\n            gate_up_proj_weights[layer_idx][weight_name] = weight\n        else:\n            to_names = (weight_map[name_template].format(layer_idx),)\n            params = (param,)\n\n        for to_name, param in zip(to_names, params):\n            if saver is not None:\n                param = saver.store_early(param)\n            state_dict[to_name] = param\n\n    if config.name.lower().startswith((\"phi-3\", \"phi-4\")):\n        for layer_idx in list(gate_up_proj_weights):\n            fc_1_weight = gate_up_proj_weights[layer_idx][\"fc_1\"]\n            fc_2_weight = gate_up_proj_weights[layer_idx][\"fc_2\"]\n            weight = torch.concat([fc_1_weight, fc_2_weight], dim=0)\n            layer_name = f\"model.layers.{layer_idx}.mlp.gate_up_proj.weight\"\n            state_dict[layer_name] = weight\n            del gate_up_proj_weights[layer_idx]\n\n\ndef copy_weights_qwen_2_5(\n    config: Config,\n    state_dict: Dict[str, torch.Tensor],\n    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    untie_weights: bool = False,\n    saver: Optional[incremental_save] = None,\n) -> None:\n    weight_map = {\n        \"transformer.wte.weight\": \"model.embed_tokens.weight\",\n        
\"transformer.h.{}.norm_1.weight\": \"model.layers.{}.input_layernorm.weight\",\n        \"transformer.h.{}.norm_2.weight\": \"model.layers.{}.post_attention_layernorm.weight\",\n        \"transformer.h.{}.attn.proj.weight\": \"model.layers.{}.self_attn.o_proj.weight\",\n        \"transformer.h.{}.mlp.fc_1.weight\": \"model.layers.{}.mlp.gate_proj.weight\",\n        \"transformer.h.{}.mlp.fc_2.weight\": \"model.layers.{}.mlp.up_proj.weight\",\n        \"transformer.h.{}.mlp.proj.weight\": \"model.layers.{}.mlp.down_proj.weight\",\n        \"transformer.ln_f.weight\": \"model.norm.weight\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n\n    for from_name, param in lit_weights.items():\n        if from_name == \"lm_head.weight\" and untie_weights:\n            continue\n        name_template, *ids = layer_template(from_name, num_matches=2)\n        param = load_param(param, from_name, None)\n        if from_name.endswith((\".attn.qkv.weight\", \".attn.qkv.bias\")):\n            weight_type = from_name.split(\".\")[-1]  # weight or bias\n            to_names = (\n                \"model.layers.{}.self_attn.q_proj.{}\".format(*ids, weight_type),\n                \"model.layers.{}.self_attn.k_proj.{}\".format(*ids, weight_type),\n                \"model.layers.{}.self_attn.v_proj.{}\".format(*ids, weight_type),\n            )\n            params = param.split(\n                (\n                    config.n_head * config.head_size,\n                    config.n_query_groups * config.head_size,\n                    config.n_query_groups * config.head_size,\n                )\n            )\n        else:\n            to_names = (weight_map[name_template].format(*ids),)\n            params = (param,)\n\n        for to_name, param in zip(to_names, params):\n            if saver is not None:\n                param = saver.store_early(param)\n            state_dict[to_name] = param\n\n\ndef copy_weights_olmo2(\n    config: Config,\n    state_dict: Dict[str, 
torch.Tensor],\n    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    untie_weights: bool = False,\n    saver: Optional[incremental_save] = None,\n) -> None:\n    weight_map = {\n        \"transformer.wte.weight\": \"model.embed_tokens.weight\",\n        \"transformer.h.{}.attn.proj.weight\": \"model.layers.{}.self_attn.o_proj.weight\",\n        \"transformer.h.{}.attn.norm_q.weight\": \"model.layers.{}.self_attn.q_norm.weight\",\n        \"transformer.h.{}.attn.norm_k.weight\": \"model.layers.{}.self_attn.k_norm.weight\",\n        \"transformer.h.{}.norm_2.weight\": \"model.layers.{}.post_attention_layernorm.weight\",\n        \"transformer.h.{}.norm_2.bias\": \"model.layers.{}.post_attention_layernorm.bias\",\n        \"transformer.h.{}.post_mlp_norm.weight\": \"model.layers.{}.post_feedforward_layernorm.weight\",\n        \"transformer.ln_f.weight\": \"model.norm.weight\",\n        \"transformer.ln_f.bias\": \"model.norm.bias\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n    if config.mlp_class_name in (\"LLaMAMLP\", \"GemmaMLP\"):\n        weight_map.update(\n            {\n                \"transformer.h.{}.mlp.fc_1.weight\": \"model.layers.{}.mlp.gate_proj.weight\",\n                \"transformer.h.{}.mlp.fc_2.weight\": \"model.layers.{}.mlp.up_proj.weight\",\n                \"transformer.h.{}.mlp.proj.weight\": \"model.layers.{}.mlp.down_proj.weight\",\n            }\n        )\n    else:\n        raise NotImplementedError\n\n    for from_name, param in lit_weights.items():\n        if from_name == \"lm_head.weight\" and untie_weights:\n            continue\n        name_template, *ids = layer_template(from_name, num_matches=2)\n        param = load_param(param, from_name, None)\n        if from_name.endswith(\".attn.qkv.weight\"):\n            to_names = (\n                \"model.layers.{}.self_attn.q_proj.weight\".format(*ids),\n                \"model.layers.{}.self_attn.k_proj.weight\".format(*ids),\n                
\"model.layers.{}.self_attn.v_proj.weight\".format(*ids),\n            )\n            params = param.split(\n                (\n                    config.n_head * config.head_size,\n                    config.n_query_groups * config.head_size,\n                    config.n_query_groups * config.head_size,\n                )\n            )\n        else:\n            to_names = (weight_map[name_template].format(*ids),)\n            params = (param,)\n\n        for to_name, param in zip(to_names, params):\n            if saver is not None:\n                param = saver.store_early(param)\n            state_dict[to_name] = param\n\n\ndef copy_weights_qwen_3(\n    config: Config,\n    state_dict: Dict[str, torch.Tensor],\n    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],\n    untie_weights: bool = False,\n    saver: Optional[incremental_save] = None,\n) -> None:\n    weight_map = {\n        \"transformer.wte.weight\": \"model.embed_tokens.weight\",\n        \"transformer.h.{}.norm_1.weight\": \"model.layers.{}.input_layernorm.weight\",\n        \"transformer.h.{}.norm_2.weight\": \"model.layers.{}.post_attention_layernorm.weight\",\n        \"transformer.h.{}.attn.proj.weight\": \"model.layers.{}.self_attn.o_proj.weight\",\n        \"transformer.h.{}.attn.norm_q.weight\": \"model.layers.{}.self_attn.q_norm.weight\",\n        \"transformer.h.{}.attn.norm_k.weight\": \"model.layers.{}.self_attn.k_norm.weight\",\n        \"transformer.ln_f.weight\": \"model.norm.weight\",\n        \"lm_head.weight\": \"lm_head.weight\",\n    }\n    if config.mlp_class_name == \"LLaMAMoE\":\n        weight_map.update(\n            {\n                \"transformer.h.{}.mlp.gate.weight\": \"model.layers.{}.mlp.gate.weight\",\n                \"transformer.h.{}.mlp.experts.{}.fc_1.weight\": \"model.layers.{}.mlp.experts.{}.gate_proj.weight\",\n                \"transformer.h.{}.mlp.experts.{}.fc_2.weight\": \"model.layers.{}.mlp.experts.{}.up_proj.weight\",\n              
  \"transformer.h.{}.mlp.experts.{}.proj.weight\": \"model.layers.{}.mlp.experts.{}.down_proj.weight\",\n            }\n        )\n    elif config.mlp_class_name == \"LLaMAMLP\":\n        weight_map.update(\n            {\n                \"transformer.h.{}.mlp.fc_1.weight\": \"model.layers.{}.mlp.gate_proj.weight\",\n                \"transformer.h.{}.mlp.fc_2.weight\": \"model.layers.{}.mlp.up_proj.weight\",\n                \"transformer.h.{}.mlp.proj.weight\": \"model.layers.{}.mlp.down_proj.weight\",\n            }\n        )\n    else:\n        raise NotImplementedError\n\n    for from_name, param in lit_weights.items():\n        if from_name == \"lm_head.weight\" and untie_weights:\n            continue\n        name_template, *ids = layer_template(from_name, num_matches=2)\n        param = load_param(param, from_name, None)\n        if from_name.endswith(\".attn.qkv.weight\"):\n            weight_type = from_name.split(\".\")[-1]  # weight or bias\n            to_names = (\n                \"model.layers.{}.self_attn.q_proj.{}\".format(*ids, weight_type),\n                \"model.layers.{}.self_attn.k_proj.{}\".format(*ids, weight_type),\n                \"model.layers.{}.self_attn.v_proj.{}\".format(*ids, weight_type),\n            )\n            params = param.split(\n                (\n                    config.n_head * config.head_size,\n                    config.n_query_groups * config.head_size,\n                    config.n_query_groups * config.head_size,\n                )\n            )\n        else:\n            to_names = (weight_map[name_template].format(*ids),)\n            params = (param,)\n\n        for to_name, param in zip(to_names, params):\n            if saver is not None:\n                param = saver.store_early(param)\n            state_dict[to_name] = param\n\n\ndef qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor:\n    \"\"\"Reassemble from a normal to an interleaved placement in a 
QKV matrix.\n    [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...]\n    \"\"\"\n    q, k, v = param.split(\n        (\n            config.n_head * config.head_size,\n            config.n_query_groups * config.head_size,\n            config.n_query_groups * config.head_size,\n        )\n    )\n    qs = q.split(config.n_head // config.n_query_groups * config.head_size)\n    ks = k.split(config.head_size)\n    vs = v.split(config.head_size)\n    interleaved = [t for group in zip(qs, ks, vs) for t in group]\n    return torch.cat(interleaved)\n\n\ndef check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None:\n    if any(\"lora\" in wn for wn in lit_weights):\n        raise ValueError(\"Checkpoints with LoRA weights cannot be converted. Call `scripts/merge_lora.py` first.\")\n    if any(\"adapter\" in wn or \"gating_factor\" in wn for wn in lit_weights):\n        raise NotImplementedError(\"Converting adapter models is not supported.\")\n\n\n@torch.inference_mode()\ndef convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:\n    \"\"\"Convert a LitGPT trained checkpoint into a Hugging Face Transformers checkpoint.\"\"\"\n    checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)\n    pprint(locals())\n\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n\n    output_dir.mkdir(parents=True, exist_ok=True)\n    output_path = output_dir / \"model.pth\"\n\n    if \"falcon\" in config.name:\n        copy_fn = partial(copy_weights_falcon, config)\n    elif config.name.startswith(\"Gemma-2\"):\n        copy_fn = partial(copy_weights_gemma_2, config)\n    elif config.name.startswith(\"Gemma-3\"):\n        copy_fn = partial(copy_weights_gemma_3, config)\n    elif config.name.lower().startswith(\"phi\"):\n        copy_fn = partial(copy_weights_phi, config)\n    elif config.name.lower().startswith((\"qwen2.5\", \"qwq\")):\n        copy_fn = partial(copy_weights_qwen_2_5, config)\n    elif 
config.name.lower().startswith(\"olmo-2-\"):\n        copy_fn = partial(copy_weights_olmo2, config)\n    elif config.name.lower().startswith(\"qwen3\"):\n        copy_fn = partial(copy_weights_qwen_3, config)\n    elif config.mlp_class_name in (\"LLaMAMLP\", \"GemmaMLP\", \"LLaMAMoE\"):\n        untie_weights = \"Gemma\" in config.name\n        copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights)\n    else:\n        copy_fn = partial(copy_weights_gpt_neox, config)\n\n    # initialize a new empty state dict to hold our new weights\n    sd = {}\n    with incremental_save(output_path) as saver:\n        lit_weights = lazy_load(checkpoint_dir / \"lit_model.pth\")\n        lit_weights = lit_weights.get(\"model\", lit_weights)\n        check_conversion_supported(lit_weights)\n        copy_fn(sd, lit_weights, saver=saver)\n        gc.collect()\n        saver.save(sd)\n"
  },
  {
    "path": "litgpt/scripts/convert_pretrained_checkpoint.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nfrom pathlib import Path\nfrom pprint import pprint\n\nimport torch\n\nfrom litgpt.utils import copy_config_files, extend_checkpoint_dir, incremental_save\n\n\n@torch.inference_mode()\ndef convert_pretrained_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:\n    \"\"\"Convert a checkpoint after pretraining.\n\n    The pretrained checkpoint contains optimizer states and several other metadata that are not needed after training\n    is finished. This script will export the state-dict of the model and place it in the chosen output folder,\n    which then can be loaded by other scripts for inference, evaluation, etc.\n\n    Args:\n        checkpoint_dir: Path to a checkpoint directory produced by ``litgpt.pretrain``.\n        output_dir: The output folder where the converted state-dict file and config files will be saved to.\n    \"\"\"\n    checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)\n    pprint(locals())\n\n    if output_dir.is_dir() and output_dir.glob(\"*\"):\n        raise FileExistsError(\n            f\"The output folder exists and is not empty: {str(output_dir)}.\"\n            \" Please delete it first or choose a different name.\"\n        )\n\n    output_dir.mkdir(parents=True)\n    checkpoint_file = checkpoint_dir / \"lit_model.pth\"\n    output_checkpoint_file = output_dir / \"lit_model.pth\"\n\n    # TODO: Consolidate sharded checkpoint if applicable\n    # Extract the model state dict and save to output folder\n    with incremental_save(output_checkpoint_file) as saver:\n        print(\"Processing\", checkpoint_file)\n        full_checkpoint = torch.load(str(checkpoint_file), mmap=True)\n        loaded_state_dict = full_checkpoint[\"model\"]\n        converted_state_dict = {}\n        for param_name, param in loaded_state_dict.items():\n            saver.store_early(param)\n            # remove prefix for compiled model (if any)\n    
        param_name = param_name.replace(\"_orig_mod.\", \"\")\n            converted_state_dict[param_name] = param\n        print(f\"Saving converted checkpoint to {str(output_checkpoint_file)}.\")\n        saver.save(converted_state_dict)\n\n    copy_config_files(checkpoint_dir, output_dir)\n"
  },
  {
    "path": "litgpt/scripts/download.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport importlib.util\nimport os\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import List, Optional, Tuple\n\nfrom litgpt.config import configs\nfrom litgpt.constants import _HF_TRANSFER_AVAILABLE, _SAFETENSORS_AVAILABLE\nfrom litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint\n\n\ndef download_from_hub(\n    repo_id: str,\n    access_token: Optional[str] = os.getenv(\"HF_TOKEN\"),\n    tokenizer_only: bool = False,\n    convert_checkpoint: bool = True,\n    dtype: Optional[str] = None,\n    checkpoint_dir: Path = Path(\"checkpoints\"),\n    model_name: Optional[str] = None,\n) -> None:\n    \"\"\"Download weights or tokenizer data from the Hugging Face Hub.\n\n    Arguments:\n        repo_id: The repository ID in the format ``org/name`` or ``user/name`` as shown in Hugging Face.\n            If \"list\" is provided as input, a list of the currently supported models in LitGPT and quits.\n        access_token: Optional API token to access models with restrictions.\n        tokenizer_only: Whether to download only the tokenizer files.\n        convert_checkpoint: Whether to convert the checkpoint files to the LitGPT format after downloading.\n        dtype: The data type to convert the checkpoint files to. If not specified, the weights will remain in the\n            dtype they are downloaded in.\n        checkpoint_dir: Where to save the downloaded files.\n        model_name: The existing config name to use for this repo_id. This is useful to download alternative weights of\n            existing architectures.\n    \"\"\"\n    options = [f\"{config['hf_config']['org']}/{config['hf_config']['name']}\" for config in configs]\n\n    if repo_id == \"list\":\n        print(\"Please specify --repo_id <repo_id>. 
Available values:\")\n        print(\"\\n\".join(sorted(options, key=lambda x: x.lower())))\n        return\n\n    if model_name is None and repo_id not in options:\n        print(\n            f\"Unsupported `repo_id`: {repo_id}.\"\n            \"\\nIf you are trying to download alternative \"\n            \"weights for a supported model, please specify the corresponding model via the `--model_name` option, \"\n            \"for example, `litgpt download NousResearch/Hermes-2-Pro-Llama-3-8B --model_name Llama-3-8B`.\"\n            \"\\nAlternatively, please choose a valid `repo_id` from the list of supported models, which can be obtained via \"\n            \"`litgpt download list`.\"\n        )\n        return\n\n    from huggingface_hub import snapshot_download\n\n    if importlib.util.find_spec(\"hf_transfer\") is None:\n        print(\n            \"It is recommended to install hf_transfer for faster checkpoint download speeds: `pip install hf_transfer`\"\n        )\n\n    download_files = [\"tokenizer*\", \"generation_config.json\", \"config.json\"]\n    if not tokenizer_only:\n        bins, safetensors = find_weight_files(repo_id, access_token)\n        if bins:\n            # covers `.bin` files and `.bin.index.json`\n            download_files.append(\"*.bin*\")\n        elif safetensors:\n            if not _SAFETENSORS_AVAILABLE:\n                raise ModuleNotFoundError(str(_SAFETENSORS_AVAILABLE))\n            download_files.append(\"*.safetensors*\")\n        else:\n            raise ValueError(f\"Couldn't find weight files for {repo_id}\")\n\n    import huggingface_hub._snapshot_download as download\n    import huggingface_hub.constants as constants\n\n    previous = constants.HF_HUB_ENABLE_HF_TRANSFER\n    if _HF_TRANSFER_AVAILABLE and not previous:\n        print(\"Setting HF_HUB_ENABLE_HF_TRANSFER=1\")\n        constants.HF_HUB_ENABLE_HF_TRANSFER = True\n        download.HF_HUB_ENABLE_HF_TRANSFER = True\n\n    directory = checkpoint_dir / 
repo_id\n    with gated_repo_catcher(repo_id, access_token):\n        snapshot_download(\n            repo_id,\n            local_dir=directory,\n            allow_patterns=download_files,\n            token=access_token,\n        )\n\n    constants.HF_HUB_ENABLE_HF_TRANSFER = previous\n    download.HF_HUB_ENABLE_HF_TRANSFER = previous\n\n    if convert_checkpoint and not tokenizer_only:\n        print(\"Converting checkpoint files to LitGPT format.\")\n        convert_hf_checkpoint(checkpoint_dir=directory, dtype=dtype, model_name=model_name)\n\n\ndef find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[str], List[str]]:\n    from huggingface_hub import repo_info\n    from huggingface_hub.utils import filter_repo_objects\n\n    with gated_repo_catcher(repo_id, access_token):\n        info = repo_info(repo_id, token=access_token)\n    filenames = [f.rfilename for f in info.siblings]\n    bins = list(filter_repo_objects(items=filenames, allow_patterns=[\"*model*.bin*\"]))\n    safetensors = list(filter_repo_objects(items=filenames, allow_patterns=[\"*.safetensors*\"]))\n    return bins, safetensors\n\n\n@contextmanager\ndef gated_repo_catcher(repo_id: str, access_token: Optional[str]):\n    try:\n        yield\n    except OSError as e:\n        err_msg = str(e)\n        if \"Repository Not Found\" in err_msg:\n            raise ValueError(\n                f\"Repository at https://huggingface.co/api/models/{repo_id} not found.\"\n                \" Please make sure you specified the correct `repo_id`.\"\n            ) from None\n        elif \"gated repo\" in err_msg:\n            if not access_token:\n                raise ValueError(\n                    f\"https://huggingface.co/{repo_id} requires authentication, please set the `HF_TOKEN=your_token`\"\n                    \" environment variable or pass `--access_token=your_token`. 
You can find your token by visiting\"\n                    \" https://huggingface.co/settings/tokens.\"\n                ) from None\n            else:\n                raise ValueError(\n                    f\"https://huggingface.co/{repo_id} requires authentication. The access token provided by `HF_TOKEN=your_token`\"\n                    \" environment variable or `--access_token=your_token` may not have sufficient access rights. Please\"\n                    f\" visit https://huggingface.co/{repo_id} for more information.\"\n                ) from None\n        raise e from None\n"
  },
  {
    "path": "litgpt/scripts/merge_lora.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\n\"\"\"This script merges the LoRA weights with the base model\"\"\"\n\nfrom pathlib import Path\nfrom pprint import pprint\nfrom typing import Any, Dict, Optional, Tuple\n\nimport lightning as L\nimport torch\nimport yaml\n\nfrom litgpt.lora import GPT, Config, lora_filter, merge_lora_weights\nfrom litgpt.utils import check_valid_checkpoint_dir, extend_checkpoint_dir\n\n\ndef merge_lora(\n    checkpoint_dir: Path, pretrained_checkpoint_dir: Optional[Path] = None, precision: Optional[str] = None\n) -> None:\n    \"\"\"Merges the LoRA weights with the base model.\n\n    See ``litgpt finetune lora``.\n\n    Creates a new ``lit_model.pth`` file by merging the LoRA weights (``lit_model.pth.lora``)\n    with the original checkpoint weights.\n\n    Arguments:\n        checkpoint_dir: Path to the checkpoint directory with trained LoRA weights, which is the output of\n            ``litgpt finetune lora``.\n        pretrained_checkpoint_dir: Optional path to the checkpoint directory with the weights of the base model\n            corresponding to the LoRA checkpoint. By default, this will automatically be inferred from the metadata\n            in the given `checkpoint_dir` directory. Only set this if the base model's checkpoint directory\n            has moved or was renamed.\n        precision: Optional precision setting to instantiate the model weights in. 
By default, this will\n            automatically be inferred from the metadata in the given ``checkpoint_dir`` directory.\n    \"\"\"\n    checkpoint_dir = extend_checkpoint_dir(checkpoint_dir)\n    if pretrained_checkpoint_dir is not None:\n        pretrained_checkpoint_dir = extend_checkpoint_dir(pretrained_checkpoint_dir)\n    pprint(locals())\n\n    check_valid_checkpoint_dir(checkpoint_dir, model_filename=\"lit_model.pth.lora\")\n    if pretrained_checkpoint_dir is not None:\n        check_valid_checkpoint_dir(pretrained_checkpoint_dir)\n    if (checkpoint_dir / \"lit_model.pth\").is_file():\n        print(\"LoRA weights have already been merged in this checkpoint.\")\n        return\n\n    lora_params, meta_pretrained_checkpoint_dir, lora_precision = load_lora_metadata(checkpoint_dir)\n    precision = precision if precision is not None else lora_precision\n\n    if pretrained_checkpoint_dir is None:\n        pretrained_checkpoint_dir = meta_pretrained_checkpoint_dir\n        pretrained_checkpoint_dir = extend_checkpoint_dir(pretrained_checkpoint_dir)\n\n    fabric = L.Fabric(devices=1, precision=precision, accelerator=\"cpu\")\n    config = Config.from_file(checkpoint_dir / \"model_config.yaml\", **lora_params)\n\n    with fabric.init_module(), torch.device(\"meta\"):\n        model = GPT(config)\n        # we don't care about these to perform merging\n        model.cos = None\n        model.sin = None\n\n    lora_path = checkpoint_dir / \"lit_model.pth.lora\"\n    pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / \"lit_model.pth\"), mmap=True)\n    lora_checkpoint = torch.load(str(lora_path), mmap=True)\n    lora_checkpoint = lora_checkpoint.get(\"model\", lora_checkpoint)\n\n    # Merge LoRA weights into the base model\n    pretrained_checkpoint.update(lora_checkpoint)\n    model.load_state_dict(pretrained_checkpoint, assign=True)\n    # since LoRA finetuning only saves the LoRA weights, we treat the lora weights dtype as the expected 
dtype\n    lora_dtype = next(iter(lora_checkpoint.values())).dtype\n    model.to(dtype=lora_dtype, device=\"cpu\")\n    merge_lora_weights(model)\n\n    # Remove LoRA parameters and the LoRA linear substring\n    state_dict = {k.replace(\"linear.\", \"\"): v for k, v in model.state_dict().items() if not lora_filter(k, v)}\n    save_path = checkpoint_dir / \"lit_model.pth\"\n    torch.save(state_dict, save_path)\n\n    fabric.print(f\"Saved merged weights to {str(checkpoint_dir / 'lit_model.pth')!r}\")\n\n\ndef load_lora_metadata(checkpoint_dir: Path) -> Tuple[Dict[str, Any], Path, Optional[str]]:\n    hparams_file = checkpoint_dir / \"hyperparameters.yaml\"\n    if not hparams_file.is_file():\n        raise FileNotFoundError(\n            f\"The path {str(hparams_file)!r} is not a valid checkpoint directory. It is missing a\"\n            f\" `hyperparameters.yaml` file. Please point to the checkpoint directory that was produced by\"\n            f\" the `litgpt/finetune/lora.py` script.\"\n        )\n\n    with open(hparams_file, encoding=\"utf-8\") as file:\n        hparams = yaml.safe_load(file)\n\n    lora_params = {k: v for k, v in hparams.items() if k.startswith(\"lora_\")}\n    pretrained_checkpoint_dir = Path(hparams[\"checkpoint_dir\"])\n    precision = hparams.get(\"precision\")\n    return lora_params, pretrained_checkpoint_dir, precision\n"
  },
  {
    "path": "litgpt/tokenizer.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport json\nfrom pathlib import Path\nfrom typing import Iterable, Iterator, Optional, Union\n\nimport torch\n\nfrom litgpt.utils import fix_and_load_json\n\n\nclass Tokenizer:\n    def __init__(self, checkpoint_dir: Union[Path, str]) -> None:\n        checkpoint_dir = Path(checkpoint_dir)\n        if not checkpoint_dir.exists():\n            raise NotADirectoryError(f\"The checkpoint directory does not exist: {str(checkpoint_dir)}\")\n\n        self.model_name = checkpoint_dir.stem\n        self.use_bos = self.check_if_bos_token_used(checkpoint_dir)\n        self.bos_id = None\n        self.eos_id = None\n\n        # some checkpoints have both files, `.json` takes precedence\n        if (vocabulary_path := checkpoint_dir / \"tokenizer.json\").is_file():\n            from tokenizers import Tokenizer as HFTokenizer\n\n            self.processor = HFTokenizer.from_file(str(vocabulary_path))\n            self.backend = \"huggingface\"\n\n            if (special_tokens_path := checkpoint_dir / \"tokenizer_config.json\").is_file():\n                with open(special_tokens_path, encoding=\"utf-8\") as fp:\n                    config = json.load(fp)\n                bos_token = config.get(\"bos_token\")\n                eos_token = config.get(\"eos_token\")\n                if bos_token is not None and isinstance(bos_token, dict):\n                    bos_token = bos_token.get(\"content\")\n                if eos_token is not None and isinstance(eos_token, dict):\n                    eos_token = eos_token.get(\"content\")\n                self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None\n                self.eos_id = self.token_to_id(eos_token) if eos_token is not None else None\n            if (special_tokens_path := checkpoint_dir / \"generation_config.json\").is_file():\n                try:\n                    with 
open(special_tokens_path, encoding=\"utf-8\") as fp:\n                        config = json.load(fp)\n                except json.JSONDecodeError:  # Some files like the Llama 3.2 one have bugs\n                    with open(special_tokens_path, encoding=\"utf-8\") as fp:\n                        json_string = fp.read()\n                        config = fix_and_load_json(json_string)\n                if self.bos_id is None:\n                    self.bos_id = config.get(\"bos_token_id\")\n                if self.eos_id is None:\n                    self.eos_id = config.get(\"eos_token_id\")\n\n        elif (vocabulary_path := checkpoint_dir / \"tokenizer.model\").is_file():\n            from sentencepiece import SentencePieceProcessor\n\n            self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))\n            self.backend = \"sentencepiece\"\n            self.bos_id = self.processor.bos_id()\n            self.eos_id = self.processor.eos_id()\n        else:\n            raise NotImplementedError\n\n        # NOTE: A temporary fix until it's resolved on Tokenizers side.\n        # LlaMA tokenizer strips leading spaces if to decode a single token at a time.\n        # https://github.com/huggingface/transformers/issues/31643\n        self.apply_decoding_fix = None\n        if (config_path := checkpoint_dir / \"tokenizer_config.json\").is_file():\n            with open(config_path, encoding=\"utf-8\") as fp:\n                self.apply_decoding_fix = \"LlamaTokenizer\" in json.load(fp)[\"tokenizer_class\"]\n\n    @property\n    def vocab_size(self) -> int:\n        if self.backend == \"huggingface\":\n            return self.processor.get_vocab_size(with_added_tokens=False)\n        if self.backend == \"sentencepiece\":\n            return self.processor.vocab_size()\n        raise RuntimeError\n\n    def token_to_id(self, token: str) -> int:\n        if self.backend == \"huggingface\":\n            id_ = self.processor.token_to_id(token)\n       
 elif self.backend == \"sentencepiece\":\n            id_ = self.processor.piece_to_id(token)\n        else:\n            raise RuntimeError\n        if id_ is None:\n            raise ValueError(f\"token {token!r} not found in the collection.\")\n        return id_\n\n    def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:\n        if not (tokenizer_config_path := checkpoint_dir / \"tokenizer_config.json\").is_file():\n            return False\n        with open(tokenizer_config_path, encoding=\"utf-8\") as fp:\n            config = json.load(fp)\n        # for LlaMA-3 tokenizer there is no `add_bos_token` at all and `tokenizer_class` is only\n        # `PreTrainedTokenizerFast`\n        if checkpoint_dir.stem.startswith((\"Meta-Llama-3\", \"Llama-3\")):\n            return True\n        if checkpoint_dir.stem.startswith(\"SmolLM2\") and checkpoint_dir.name.endswith(\"Instruct\"):\n            return True\n        if \"add_bos_token\" in config:\n            return config[\"add_bos_token\"]\n        # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.\n        # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2\n        return config.get(\"tokenizer_class\") == \"LlamaTokenizer\"\n\n    def encode(\n        self,\n        string: str,\n        device: Optional[torch.device] = None,\n        bos: Optional[bool] = None,\n        eos: bool = False,\n        max_length: int = -1,\n    ) -> torch.Tensor:\n        if self.backend == \"huggingface\":\n            tokens = self.processor.encode(string).ids\n        elif self.backend == \"sentencepiece\":\n            tokens = self.processor.encode(string)\n        else:\n            raise RuntimeError(f\"`{self.backend}` is not supported.\")\n        if tokens is None:\n            raise ValueError(\"`self.processor` returned tokens of None value.\")\n\n        if bos or (bos is None and self.use_bos):\n            if self.bos_id is 
None:\n                raise NotImplementedError(\"This tokenizer does not have a defined bos token.\")\n            if not tokens or tokens[0] != self.bos_id:\n                tokens = [self.bos_id] + tokens\n        # if the processor misbehaves and adds `bos` token no matter what\n        elif tokens and tokens[0] == self.bos_id:\n            tokens = tokens[1:]\n\n        if eos and (not tokens or tokens[-1] != self.eos_id):\n            tokens = tokens + [self.eos_id]\n        # if the processor misbehaves and adds `eos` token no matter what\n        elif tokens and tokens[-1] == self.eos_id:\n            tokens = tokens[:-1]\n\n        if max_length > 0:\n            tokens = tokens[:max_length]\n        return torch.tensor(tokens, dtype=torch.int, device=device)\n\n    def decode(self, tensor: torch.Tensor) -> str:\n        tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()\n        if len(tokens) == 1 and self.apply_decoding_fix:\n            dummy_token_id = 33  # \\x1e\n            dummy_token = self.processor.decode([dummy_token_id])\n            if dummy_token != \"\\x1e\":\n                dummy_token_id = 165  # \\x1e is different in salamandra tokenizers\n                dummy_token = self.processor.decode([dummy_token_id])\n            return self.processor.decode([dummy_token_id] + tokens)[len(dummy_token) :]\n        return self.processor.decode(tokens)\n\n    def decode_stream(\n        self, token_stream: Iterable[torch.Tensor], device: Optional[torch.device] = None\n    ) -> Iterator[str]:\n        if self.backend == \"huggingface\":\n            try:\n                for token in token_stream:\n                    yield self.decode(token)\n            except KeyboardInterrupt:\n                return\n        elif self.backend == \"sentencepiece\":\n            # TODO: Is there a way to not have to do this?\n            # This may actually affect our tokens per second.\n\n            # sentencepiece does not support decoding 
token-by-token because it adds spaces based on the surrounding tokens\n            # meaning that we need to decode everything each time\n            so_far = torch.tensor([], dtype=torch.long, device=device)\n            decoded_so_far = \"\"\n            try:\n                for token in token_stream:\n                    so_far = so_far.to(device=token.device)\n                    so_far = torch.cat((so_far, token.view(-1)))\n                    decoded_new = self.decode(so_far)\n                    yield decoded_new[len(decoded_so_far) :]\n                    decoded_so_far = decoded_new\n            except KeyboardInterrupt:\n                return\n        else:\n            raise NotImplementedError(self.backend)\n"
  },
  {
    "path": "litgpt/types.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\"\"\"Type aliases used across LitGPT modules.\"\"\"\n\nfrom typing import Literal\n\n# Logger-related types\nLoggerChoice = Literal[\"csv\", \"tensorboard\", \"wandb\", \"mlflow\", \"litlogger\"]\n\"\"\"Valid logger choices for experiment tracking.\n\nAvailable options:\n- \"csv\": Local CSV file logging (default for most scripts)\n- \"tensorboard\": TensorBoard visualization (default for pretrain)\n- \"wandb\": Weights & Biases cloud tracking\n- \"mlflow\": MLflow experiment tracking\n- \"litlogger\": Lightning.ai native tracking\n\"\"\"\n"
  },
  {
    "path": "litgpt/utils.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\n\"\"\"Utility functions for training and inference.\"\"\"\n\nimport inspect\nimport json\nimport math\nimport os\nimport pickle\nimport random\nimport re\nimport shutil\nimport subprocess\nimport sys\nimport warnings\nfrom dataclasses import asdict, is_dataclass\nfrom io import BytesIO\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping, Optional, TypeVar, Union\n\nimport lightning as L\nimport psutil\nimport torch\nimport torch.nn as nn\nimport torch.utils._device\nimport yaml\nfrom lightning.fabric.loggers import CSVLogger, TensorBoardLogger\nfrom lightning.fabric.strategies import FSDPStrategy, ModelParallelStrategy\nfrom lightning.fabric.utilities.load import _lazy_load as lazy_load\nfrom lightning.pytorch.cli import instantiate_class\nfrom lightning.pytorch.loggers import MLFlowLogger, WandbLogger\nfrom packaging import version\nfrom torch.serialization import normalize_storage_type\nfrom typing_extensions import Self\n\nfrom litgpt.constants import (\n    _LITLOGGER_AVAILABLE,\n    _SUPPORTED_LOGGERS,\n    _THUNDER_AVAILABLE,\n)\nfrom litgpt.types import LoggerChoice\n\nif TYPE_CHECKING:\n    from litgpt import GPT, Config\n\n\ndef init_out_dir(out_dir: Path) -> Path:\n    if not isinstance(out_dir, Path):\n        out_dir = Path(out_dir)\n    if not out_dir.is_absolute() and \"LIGHTNING_ARTIFACTS_DIR\" in os.environ:\n        return Path(os.getenv(\"LIGHTNING_ARTIFACTS_DIR\")) / out_dir\n    return out_dir\n\n\ndef find_resume_path(resume: Union[bool, Literal[\"auto\"], Path], out_dir: Path) -> Optional[Path]:\n    if not resume or isinstance(resume, Path):\n        return resume\n\n    resume_path = max(out_dir.rglob(\"step-*/*.pth\"), key=(lambda p: int(p.parent.name.split(\"-\")[1])), default=None)\n    if resume == \"auto\":\n        return resume_path\n    if resume is True and resume_path is None:\n  
      raise FileNotFoundError(\n            f\"You passed `--resume=True`, but no checkpoint file was found in `--out_dir={out_dir}`.\"\n        )\n    return resume_path\n\n\ndef num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:\n    total = 0\n    for p in module.parameters():\n        if requires_grad is None or p.requires_grad == requires_grad:\n            if hasattr(p, \"quant_state\"):\n                # bitsandbytes 4bit layer support\n                total += math.prod(p.quant_state.shape)\n            else:\n                total += p.numel()\n    return total\n\n\ndef reset_parameters(module: nn.Module) -> None:\n    \"\"\"Calls `reset_parameters` on the module and all its submodules.\"\"\"\n    for mod in module.modules():\n        if callable(getattr(mod, \"reset_parameters\", None)):\n            mod.reset_parameters()\n\n\ndef check_valid_checkpoint_dir(\n    checkpoint_dir: Path,\n    model_filename: str = \"lit_model.pth\",\n    verbose: bool = True,\n    raise_error: bool = False,\n    ignore_tokenizer_files: bool = False,\n) -> None:\n    files = {\n        model_filename: (checkpoint_dir / model_filename).is_file(),\n        \"model_config.yaml\": (checkpoint_dir / \"model_config.yaml\").is_file(),\n    }\n    if not ignore_tokenizer_files:\n        files.update(\n            {\n                \"tokenizer.json OR tokenizer.model\": (checkpoint_dir / \"tokenizer.json\").is_file()\n                or (checkpoint_dir / \"tokenizer.model\").is_file(),\n                \"tokenizer_config.json\": (checkpoint_dir / \"tokenizer_config.json\").is_file(),\n            }\n        )\n\n    if checkpoint_dir.is_dir():\n        if all(files.values()):\n            # we're good\n            return\n        problem = f\" is missing the files: {[f for f, exists in files.items() if not exists]!r}\"\n    else:\n        problem = \" is not a checkpoint directory\"\n\n    # list locally available checkpoints\n    available = 
list(Path(\"checkpoints\").glob(\"*/*\"))\n    if available:\n        options = \"\\n\".join([\"\"] + [repr(str(p.resolve())) for p in available])\n        extra = f\"\\nYou have downloaded locally:{options}\\n\"\n    else:\n        extra = \"\"\n\n    if verbose:\n        error_message = (\n            f\"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}.\"\n            \"\\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\\n\"\n            f\"{extra}\\nSee all download options by running:\\n litgpt download\"\n        )\n        print(error_message, file=sys.stderr)\n\n    if raise_error:\n        raise FileNotFoundError(f\"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}.\")\n    else:\n        raise SystemExit(1)\n\n\nclass SavingProxyForStorage:\n    def __init__(self, obj, saver, protocol_version=5):\n        self.protocol_version = protocol_version\n        self.saver = saver\n        if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):\n            raise TypeError(f\"expected storage, not {type(obj)}\")\n\n        # this logic is taken from PyTorch 2.0+ torch/serialization.py\n        if isinstance(obj, torch.storage.TypedStorage):\n            # PT upstream wants to deprecate this eventually...\n            storage = obj._untyped_storage\n            storage_type_str = obj._pickle_storage_type()\n            storage_type = getattr(torch, storage_type_str)\n            storage_numel = obj._size()\n        else:\n            storage = obj\n            storage_type = normalize_storage_type(type(obj))\n            storage_numel = storage.nbytes()\n\n        storage_key = saver._write_storage_and_return_key(storage)\n        location = torch.serialization.location_tag(storage)\n\n        self.storage_info = (\"storage\", storage_type, storage_key, location, storage_numel)\n\n    def __reduce_ex__(self, protocol_version):\n        assert False, \"this should be handled with out 
of band\"\n\n\nclass SavingProxyForTensor:\n    def __init__(self, tensor, saver, protocol_version=5):\n        self.protocol_version = protocol_version\n        self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)\n        if reduce_args[0] == torch._utils._rebuild_tensor_v2:\n            # for Tensors with Python attributes\n            (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args\n            assert isinstance(storage, (torch.storage.TypedStorage, torch.storage.UntypedStorage)), (\n                \"Please check for updates\"\n            )\n            storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)\n            self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)\n        else:\n            (storage, *other_reduce_args) = reduce_args\n            assert isinstance(storage, (torch.storage.TypedStorage, torch.storage.UntypedStorage)), (\n                \"Please check for updates\"\n            )\n            storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)\n            self.reduce_args = (storage_proxy, *other_reduce_args)\n\n    def __reduce_ex__(self, protocol_version):\n        if protocol_version != self.protocol_version:\n            raise RuntimeError(f\"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}\")\n        return self.reduce_ret_fn, self.reduce_args\n\n\nclass IncrementalPyTorchPickler(pickle.Pickler):\n    def __init__(self, saver, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.storage_dtypes = {}\n        self.saver = saver\n        self.id_map = {}\n\n    # this logic is taken from PyTorch 2.0+ torch/serialization.py\n    def persistent_id(self, obj):\n        # FIXME: the docs say that persistent_id should only return a string\n        # but torch store returns tuples. 
This works only in the binary protocol\n        # see\n        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects\n        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537\n        if isinstance(obj, SavingProxyForStorage):\n            return obj.storage_info\n\n        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):\n            if isinstance(obj, torch.storage.TypedStorage):\n                # TODO: Once we decide to break serialization FC, this case\n                # can be deleted\n                storage = obj._untyped_storage\n                storage_dtype = obj.dtype\n                storage_type_str = obj._pickle_storage_type()\n                storage_type = getattr(torch, storage_type_str)\n                storage_numel = obj._size()\n\n            else:\n                storage = obj\n                storage_dtype = torch.uint8\n                storage_type = normalize_storage_type(type(obj))\n                storage_numel = storage.nbytes()\n\n            # If storage is allocated, ensure that any other saved storages\n            # pointing to the same data all have the same dtype. 
If storage is\n            # not allocated, don't perform this check\n            if storage.data_ptr() != 0:\n                if storage.data_ptr() in self.storage_dtypes:\n                    if storage_dtype != self.storage_dtypes[storage.data_ptr()]:\n                        raise RuntimeError(\n                            \"Cannot save multiple tensors or storages that view the same data as different types\"\n                        )\n                else:\n                    self.storage_dtypes[storage.data_ptr()] = storage_dtype\n\n            storage_key = self.id_map.get(storage._cdata)\n            if storage_key is None:\n                storage_key = self.saver._write_storage_and_return_key(storage)\n                self.id_map[storage._cdata] = storage_key\n            location = torch.serialization.location_tag(storage)\n\n            return (\"storage\", storage_type, storage_key, location, storage_numel)\n\n        return None\n\n\nclass incremental_save:\n    def __init__(self, name):\n        self.name = name\n        self.zipfile = torch._C.PyTorchFileWriter(str(name))\n        self.has_saved = False\n        self.next_key = 0\n        self.protocol_version = 2\n\n    def __enter__(self):\n        return self\n\n    def store_early(self, tensor):\n        if isinstance(tensor, torch.Tensor):\n            return SavingProxyForTensor(tensor, self, protocol_version=self.protocol_version)\n        raise TypeError(f\"can only store tensors early, not {type(tensor)}\")\n\n    def save(self, obj):\n        if self.has_saved:\n            raise RuntimeError(\"have already saved\")\n        # Write the pickle data for `obj`\n        data_buf = BytesIO()\n        pickler = IncrementalPyTorchPickler(self, data_buf, protocol=self.protocol_version)\n        pickler.dump(obj)\n        data_value = data_buf.getvalue()\n        self.zipfile.write_record(\"data.pkl\", data_value, len(data_value))\n        self.has_saved = True\n\n    def 
_write_storage_and_return_key(self, storage):\n        if self.has_saved:\n            raise RuntimeError(\"have already saved\")\n        key = self.next_key\n        self.next_key += 1\n        name = f\"data/{key}\"\n        if storage.device.type != \"cpu\":\n            storage = storage.cpu()\n        num_bytes = storage.nbytes()\n\n        current_version = version.parse(torch.__version__)\n        threshold_version = version.parse(\"2.2.2\")\n        if current_version <= threshold_version:\n            self.zipfile.write_record(name, storage.data_ptr(), num_bytes)\n        else:\n            self.zipfile.write_record(name, storage, num_bytes)\n\n        return key\n\n    def __exit__(self, type, value, traceback):\n        self.zipfile.write_end_of_file()\n\n\nT = TypeVar(\"T\")\n\n\ndef chunked_cross_entropy(\n    logits: Union[torch.Tensor, List[torch.Tensor]],\n    targets: torch.Tensor,\n    chunk_size: int = 128,\n    ignore_index: int = -100,\n) -> torch.Tensor:\n    # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate\n    # the memory usage in fine-tuning settings with low number of parameters.\n    # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing\n    # the memory spike's magnitude\n\n    # lm_head was chunked (we are fine-tuning)\n    if isinstance(logits, list):\n        # don't want to chunk cross entropy\n        if chunk_size == 0:\n            logits = torch.cat(logits, dim=1)\n            logits = logits.reshape(-1, logits.size(-1))\n            targets = targets.reshape(-1)\n            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=ignore_index)\n\n        # chunk cross entropy\n        logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]\n        target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]\n        
loss_chunks = [\n            torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=ignore_index, reduction=\"none\")\n            for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)\n        ]\n        non_masked_elems = (targets != ignore_index).sum()\n        # See [non_masked_elems div note]\n        return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))\n\n    # no chunking at all\n    logits = logits.reshape(-1, logits.size(-1))\n    targets = targets.reshape(-1)\n    if chunk_size == 0:\n        return torch.nn.functional.cross_entropy(logits, targets, ignore_index=ignore_index)\n\n    # lm_head wasn't chunked, chunk cross entropy\n    logit_chunks = logits.split(chunk_size)\n    target_chunks = targets.split(chunk_size)\n    loss_chunks = [\n        torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=ignore_index, reduction=\"none\")\n        for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)\n    ]\n    non_masked_elems = (targets != ignore_index).sum()\n    # [non_masked_elems div note]:\n    #   max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that\n    #   results in a python int which is then passed back to torch division. 
By using the\n    #   `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize.\n    return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(torch.ones_like(non_masked_elems))\n\n\ndef map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:\n    for checkpoint_name, attribute_name in mapping.items():\n        full_checkpoint_name = prefix + checkpoint_name\n        if full_checkpoint_name in state_dict:\n            full_attribute_name = prefix + attribute_name\n            state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)\n    return state_dict\n\n\ndef get_default_supported_precision(training: bool) -> str:\n    \"\"\"\n    Return the default precision that is supported by the hardware: either `bf16` or `16`.\n\n    Args:\n        training: If True, returns '-mixed' version of the precision; if False, returns '-true' version.\n\n    Returns:\n        The default precision that is suitable for the task and is supported by the hardware.\n    \"\"\"\n    import torch\n\n    if torch.cuda.is_available():\n        if torch.cuda.is_bf16_supported():\n            return \"bf16-mixed\" if training else \"bf16-true\"\n        else:\n            return \"16-mixed\" if training else \"16-true\"\n    return \"bf16-mixed\" if training else \"bf16-true\"\n\n\ndef load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:\n    if isinstance(fabric.strategy, FSDPStrategy):\n        fabric.load_raw(checkpoint_path, model, strict=strict)\n    elif isinstance(fabric.strategy, ModelParallelStrategy):\n        state_dict = torch.load(checkpoint_path, mmap=True)\n        load_from_full_model_state_dict(\n            model=model,\n            full_sd=state_dict,\n            device=fabric.device,\n            strict=strict,\n            cpu_offload=True,\n        )\n    else:\n        state_dict = lazy_load(checkpoint_path)\n        state_dict = 
state_dict.get(\"model\", state_dict)\n        model.load_state_dict(state_dict, strict=strict)\n\n\ndef load_checkpoint_update(\n    fabric: L.Fabric, adapter_path: Path, model: nn.Module, checkpoint_path: Path, strict: bool = True\n) -> None:\n    if isinstance(fabric.strategy, FSDPStrategy):\n        fabric.load_raw(checkpoint_path, model, strict=strict)\n    else:\n        state_dict = lazy_load(checkpoint_path)\n        state_dict = state_dict.get(\"model\", state_dict)\n        adapter_cp = lazy_load(adapter_path)\n        state_dict.update(adapter_cp)\n        model.load_state_dict(state_dict, strict=strict)\n\n\ndef load_from_full_model_state_dict(\n    model: torch.nn.Module,\n    full_sd: Dict[str, Any],\n    device: torch.device,\n    strict: bool = False,\n    cpu_offload: bool = False,\n):\n    from torch.distributed._tensor import distribute_tensor\n\n    meta_sharded_sd = model.state_dict()\n    sharded_sd = {}\n    print(meta_sharded_sd.keys())\n    for param_name, full_tensor in full_sd.items():\n        if \"norm\" not in param_name and \"wte\" not in param_name and \"ln_f\" not in param_name:\n            param_name = param_name.replace(\".weight\", \".linear.weight\")\n            param_name = param_name.replace(\".bias\", \".linear.bias\")\n        else:\n            param_name = param_name\n\n        print(param_name)\n\n        sharded_meta_param = meta_sharded_sd.get(param_name)\n        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)\n        sharded_tensor = distribute_tensor(\n            full_tensor,\n            sharded_meta_param.device_mesh,\n            sharded_meta_param.placements,\n        )\n        if cpu_offload:\n            sharded_tensor = sharded_tensor.cpu()\n        sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor)\n    # choose `assign=True` since we cannot call `copy_` on meta tensor\n    return model.load_state_dict(sharded_sd, strict=strict, assign=True)\n\n\ndef 
flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:\n    flops_per_token = 2 * n_params  # each parameter is used for a MAC (2 FLOPS) per network operation\n    # this assumes that all samples have a fixed length equal to the block size\n    # which is most likely false during finetuning\n    flops_per_seq = flops_per_token * max_seq_length\n    attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))\n    return flops_per_seq + attn_flops_per_seq\n\n\ndef estimate_flops(model: \"GPT\", training: bool) -> int:\n    \"\"\"Measures estimated FLOPs for MFU.\n\n    Refs:\n        * https://ar5iv.labs.arxiv.org/html/2205.05198#A1\n        * https://ar5iv.labs.arxiv.org/html/2204.02311#A2\n    \"\"\"\n    # using all parameters for this is a naive over estimation because not all model parameters actually contribute to\n    # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage\n    # (~10%) compared to the measured FLOPs, making those lower but more realistic.\n    # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.\n    n_trainable_params = num_parameters(model, requires_grad=True)\n    trainable_flops = flops_per_param(\n        model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params\n    )\n    # forward + backward + gradients (assumes no gradient accumulation)\n    ops_per_step = 3 if training else 1\n    n_frozen_params = num_parameters(model, requires_grad=False)\n    frozen_flops = flops_per_param(model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params)\n    # forward + backward\n    frozen_ops_per_step = 2 if training else 1\n    return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops\n\n\nclass CycleIterator:\n    \"\"\"An iterator that cycles through an iterable indefinitely.\n\n    Example:\n        >>> iterator = CycleIterator([1, 2, 
3])\n        >>> [next(iterator) for _ in range(5)]\n        [1, 2, 3, 1, 2]\n\n    Note:\n        Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable.\n    \"\"\"\n\n    def __init__(self, iterable: Iterable) -> None:\n        self.iterable = iterable\n        self.epoch = 0\n        self._iterator = None\n\n    def __next__(self) -> Any:\n        if self._iterator is None:\n            self._iterator = iter(self.iterable)\n        try:\n            return next(self._iterator)\n        except StopIteration:\n            self._iterator = iter(self.iterable)\n            self.epoch += 1\n            return next(self._iterator)\n\n    def __iter__(self) -> Self:\n        return self\n\n\ndef copy_config_files(source_dir: Path, out_dir: Path) -> None:\n    \"\"\"Copies the specified configuration and tokenizer files into the output directory.\"\"\"\n\n    config_files = [\"config.json\", \"generation_config.json\", \"model_config.yaml\"]\n    tokenizer_files = [\"tokenizer.json\", \"tokenizer.model\", \"tokenizer_config.json\"]\n\n    for file_name in config_files + tokenizer_files:\n        src_path = source_dir / file_name\n        if src_path.exists():\n            shutil.copy(src_path, out_dir)\n\n\ndef CLI(*args: Any, **kwargs: Any) -> Any:\n    from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options\n\n    set_docstring_parse_options(attribute_docstrings=True)\n    set_config_read_mode(urls_enabled=True)\n\n    return CLI(*args, **kwargs)\n\n\ndef capture_hparams() -> Dict[str, Any]:\n    \"\"\"Captures the local variables ('hyperparameters') from where this function gets called.\"\"\"\n    caller_frame = inspect.currentframe().f_back\n    locals_of_caller = caller_frame.f_locals\n    hparams = {}\n    for name, value in locals_of_caller.items():\n        if value is None or isinstance(value, (int, float, str, bool, Path)):\n            hparams[name] = value\n        elif is_dataclass(value):\n            
hparams[name] = asdict(value)\n        else:\n            hparams[name] = str(value)\n    return hparams\n\n\ndef save_config(config: \"Config\", checkpoint_dir: Path) -> None:\n    config_dict = asdict(config)\n    with open(checkpoint_dir / \"model_config.yaml\", \"w\", encoding=\"utf-8\") as fp:\n        yaml.dump(config_dict, fp)\n\n\ndef parse_devices(devices: Union[str, int]) -> int:\n    if devices in (-1, \"auto\"):\n        return torch.cuda.device_count() or 1\n    if isinstance(devices, int) and devices > 0:\n        return devices\n    raise ValueError(f\"Devices must be 'auto' or a positive integer, got: {devices!r}\")\n\n\ndef choose_logger(\n    logger_name: LoggerChoice,\n    out_dir: Path,\n    name: str,\n    log_interval: int = 1,\n    log_args: Optional[Dict] = None,\n    resume: Optional[bool] = None,\n    **kwargs: Any,\n):\n    if logger_name == \"csv\":\n        return CSVLogger(root_dir=(out_dir / \"logs\"), name=\"csv\", flush_logs_every_n_steps=log_interval, **kwargs)\n    if logger_name == \"tensorboard\":\n        return TensorBoardLogger(root_dir=(out_dir / \"logs\"), name=\"tensorboard\", **kwargs)\n    if logger_name == \"wandb\":\n        project = log_args.pop(\"project\", name)\n        run = log_args.pop(\"run\", os.environ.get(\"WANDB_RUN_NAME\"))\n        group = log_args.pop(\"group\", os.environ.get(\"WANDB_RUN_GROUP\"))\n        return WandbLogger(project=project, name=run, group=group, resume=resume, **kwargs)\n    if logger_name == \"mlflow\":\n        return MLFlowLogger(experiment_name=name, **kwargs)\n    if logger_name == \"litlogger\":\n        if not _LITLOGGER_AVAILABLE:\n            raise ModuleNotFoundError(_LITLOGGER_AVAILABLE)\n        from lightning.pytorch.loggers import LitLogger\n\n        # Extract litlogger-specific args\n        teamspace = log_args.pop(\"teamspace\", None) if log_args else None\n        metadata = log_args.pop(\"metadata\", None) if log_args else None\n        log_model = 
log_args.pop(\"log_model\", False) if log_args else False\n        save_logs = log_args.pop(\"save_logs\", True) if log_args else True\n        checkpoint_name = log_args.pop(\"checkpoint_name\", None) if log_args else None\n\n        return LitLogger(\n            root_dir=(out_dir / \"logs\"),\n            name=name,\n            teamspace=teamspace,\n            metadata=metadata,\n            log_model=log_model,\n            save_logs=save_logs,\n            checkpoint_name=checkpoint_name,\n            **kwargs,\n        )\n    raise ValueError(\n        f\"`--logger_name={logger_name}` is not a valid option. Choose from {', '.join(_SUPPORTED_LOGGERS)}.\"\n    )\n\n\ndef get_argument_names(cls):\n    sig = inspect.signature(cls.__init__)\n    return {\n        name\n        for name, param in sig.parameters.items()\n        if param.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY]\n    }\n\n\ndef instantiate_bnb_optimizer(optimizer, model_parameters):\n    if (isinstance(optimizer, str) and \"AdamW\" not in optimizer) or (\n        isinstance(optimizer, dict) and \"AdamW\" not in optimizer.get(\"class_path\", \"\")\n    ):\n        raise ValueError(\"The chosen quantization format only supports the AdamW optimizer.\")\n\n    import bitsandbytes as bnb\n\n    if isinstance(optimizer, str):\n        optimizer = bnb.optim.PagedAdamW(model_parameters)\n    else:\n        optim_args = get_argument_names(bnb.optim.PagedAdamW)\n        allowed_kwargs = {key: optimizer[\"init_args\"][key] for key in optim_args & optimizer[\"init_args\"].keys()}\n        optimizer = bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs)\n    return optimizer\n\n\ndef instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):\n    # Special care taken where some optimizers do not have some parameters referenced in some of the code, for example \"fused\" in the pretrain.py script:\n    #   bnb.optim.AdamW8bit\n    #   grokadamw.GrokAdamW\n    
#   torch.optim.RMSprop\n\n    if isinstance(optimizer, str):\n        if \".\" in optimizer:\n            class_module, class_name = optimizer.rsplit(\".\", 1)\n        else:\n            class_module, class_name = \"torch.optim\", optimizer\n\n        module = __import__(class_module, fromlist=[class_name])\n        optimizer_cls = getattr(module, class_name)\n\n        valid_params = set(inspect.signature(optimizer_cls).parameters)\n        kwargs = {key: value for key, value in dict(kwargs).items() if key in valid_params}\n        optimizer = optimizer_cls(model_parameters, **kwargs)\n    elif isinstance(optimizer, dict):\n        optimizer = dict(optimizer)\n        class_module, class_name = optimizer[\"class_path\"].rsplit(\".\", 1)\n        module = __import__(class_module, fromlist=[class_name])\n        optimizer_cls = getattr(module, class_name)\n\n        valid_params = set(inspect.signature(optimizer_cls).parameters)\n        kwargs = {key: value for key, value in dict(kwargs).items() if key in valid_params}\n\n        optimizer[\"init_args\"].update(kwargs)\n        optimizer = instantiate_class(model_parameters, optimizer)\n    else:\n        raise ValueError(f'Unrecognized \"optimizer\" value: {optimizer}')\n\n    return optimizer\n\n\ndef extend_checkpoint_dir(checkpoint_dir: Path) -> Path:\n    new_checkpoint_dir = \"checkpoints\" / checkpoint_dir\n    should_return_new_dir = (\n        not checkpoint_dir.is_dir()\n        and checkpoint_dir.parts[0] != \"checkpoints\"\n        and not checkpoint_dir.is_absolute()\n        and new_checkpoint_dir.exists()\n    )\n    return new_checkpoint_dir if should_return_new_dir else checkpoint_dir\n\n\ndef check_file_size_on_cpu_and_warn(checkpoint_path, device, size_limit=4_509_715_660):\n    \"\"\"\n    Checks the file size and raises a warning if it exceeds the size_limit.\n    The default size limit is 4.2 GB, the size of TinyLlama 1.1B: 4.2 * 1024 * 1024 * 1024 = 4_509_715_660\n    \"\"\"\n    size = 
0.0\n    if os.path.exists(checkpoint_path):\n        size = os.path.getsize(checkpoint_path)\n        if size > size_limit and str(device) == \"cpu\":\n            warnings.warn(\n                f\"The file size of {checkpoint_path} is over {size_limit / 1024 / 1024 / 1024:.1f} GB. Using a model \"\n                \"with more than 1B parameters on a CPU can be slow, it is recommended to switch to a GPU.\"\n            )\n    return size\n\n\ndef auto_download_checkpoint(model_name, access_token=None, ignore_tokenizer_files=False):\n    from litgpt.scripts.download import download_from_hub  # moved here due to circular import issue\n\n    checkpoint_dir = extend_checkpoint_dir(Path(model_name))\n    try:\n        check_valid_checkpoint_dir(\n            checkpoint_dir, verbose=False, raise_error=True, ignore_tokenizer_files=ignore_tokenizer_files\n        )\n    except FileNotFoundError as e:\n        if access_token is None:\n            access_token = os.getenv(\"HF_TOKEN\")\n\n        if checkpoint_dir.parts[0] != \"checkpoints\" and not checkpoint_dir.is_absolute():\n            download_from_hub(repo_id=str(model_name), access_token=access_token)\n            checkpoint_dir = Path(\"checkpoints\") / checkpoint_dir\n        else:\n            raise e\n\n    return checkpoint_dir\n\n\ndef check_nvlink_connectivity(fabric=None):\n    \"\"\"Checks GPU connectivity for both NVIDIA and AMD GPUs.\n\n    This function delegates to vendor-specific implementations based on\n    the detected GPU vendor.\n    \"\"\"\n    if fabric is not None:\n        custom_print = fabric.print\n    else:\n        custom_print = print\n\n    if os.getenv(\"RANK\", \"0\") == \"0\":\n        try:\n            if torch.cuda.is_available():\n                device_properties = torch.cuda.get_device_properties(0)\n                gpu_name = device_properties.name.lower()\n                if \"nvidia\" in gpu_name:\n                    _check_nvidia_connectivity(custom_print)\n              
  elif \"advanced micro devices\" in gpu_name or \"amd\" in gpu_name:\n                    _check_amd_connectivity(custom_print)\n                else:\n                    custom_print(f\"Unrecognized GPU vendor: {device_properties.name}\")\n            else:\n                custom_print(\"No GPUs available\")\n        except Exception as e:\n            custom_print(f\"An error occurred while checking GPU connectivity: {e}\")\n\n\ndef _check_nvidia_connectivity(custom_print):\n    \"\"\"Checks NVLink connectivity on NVIDIA GPUs.\"\"\"\n    result = subprocess.run([\"nvidia-smi\", \"topo\", \"-m\"], stdout=subprocess.PIPE, text=True)\n    if result.returncode != 0:\n        custom_print(\"Failed to run nvidia-smi\")\n        return\n\n    lines = result.stdout.strip().split(\"\\n\")\n    start_index = next((i for i, line in enumerate(lines) if \"GPU0\" in line), None)\n    if start_index is None:\n        custom_print(\"Failed to parse nvidia-smi output\")\n        return\n\n    headers_line = lines[start_index]\n    headers = headers_line.split()\n    gpu_regex = re.compile(r\"^GPU\\d+$\")\n    gpu_count = len([header for header in headers if gpu_regex.match(header)])\n\n    all_nvlink = True\n    for line in lines[start_index + 1 : start_index + 1 + gpu_count]:\n        columns = line.split()\n        connections = columns[1 : 1 + gpu_count]\n        if not all(\"NV\" in conn for conn in connections if conn != \"X\"):\n            all_nvlink = False\n            break\n\n    if all_nvlink:\n        custom_print(\"All GPUs are fully connected via NVLink.\")\n    else:\n        custom_print(\n            \"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. 
\"\n            \"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance.\"\n        )\n\n\ndef _check_amd_connectivity(custom_print):\n    \"\"\"Checks XGMI connectivity on AMD GPUs.\"\"\"\n    result = subprocess.run([\"rocm-smi\", \"--showtopotype\"], stdout=subprocess.PIPE, text=True)\n    if result.returncode != 0:\n        custom_print(\"Failed to run rocm-smi\")\n        return\n\n    lines = result.stdout.strip().split(\"\\n\")\n    gpu_header_index = next((i for i, line in enumerate(lines) if re.match(r\"^\\s*GPU0\", line)), None)\n    if gpu_header_index is None or gpu_header_index == 0:\n        custom_print(\"Failed to parse rocm-smi output (no GPU headers found)\")\n        return\n\n    header_line = lines[gpu_header_index - 1]\n    headers = header_line.strip().split()\n    gpu_regex = re.compile(r\"^GPU\\d+$\")\n    gpu_count = len([header for header in headers if gpu_regex.match(header)])\n\n    gpu_lines = []\n    for line in lines[gpu_header_index : gpu_header_index + gpu_count]:\n        if re.match(r\"^\\s*GPU\\d+\", line):\n            gpu_lines.append(line.strip())\n    if len(gpu_lines) != gpu_count:\n        custom_print(\"Mismatch in GPU count when parsing rocm-smi output\")\n        return\n\n    all_xgmi = True\n    for line in gpu_lines:\n        columns = line.split()\n        connections = columns[1 : 1 + gpu_count]\n        for conn in connections:\n            if conn not in (\"XGMI\", \"0\"):\n                all_xgmi = False\n                break\n        if not all_xgmi:\n            break\n\n    if all_xgmi:\n        custom_print(\"All GPUs are fully connected via XGMI.\")\n    else:\n        custom_print(\n            \"Warning: Not all GPUs are fully connected via XGMI. Some GPUs are connected via slower interfaces. 
\"\n            \"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance.\"\n        )\n\n\ndef fix_and_load_json(s):\n    # Remove trailing commas before } or ]\n    s = re.sub(r\",(\\s*[}\\]])\", r\"\\1\", s)\n\n    # Insert missing commas between properties\n    # Match positions where a value is followed by a newline and then a quote without a comma\n    pattern = r'(?<=[}\\]0-9truefalsenull\"])\\s*(\\n\\s*)\"'\n    replacement = r',\\1\"'\n    s = re.sub(pattern, replacement, s)\n\n    # Now try to parse the JSON\n    try:\n        return json.loads(s)\n    except json.JSONDecodeError as e:\n        raise ValueError(f\"Failed to parse JSON after fixing: {e}\")\n\n\ndef create_finetuning_performance_report(training_time, token_counts, device_type):\n    tok_sec = token_counts[\"raw_tokens_plus_prompt_template_and_padding\"] / training_time\n    output = f\"\"\"\n| ------------------------------------------------------\n| Token Counts\n| - Input Tokens              :  {token_counts[\"raw_tokens\"]:>5}\n| - Tokens w/ Prompt          :  {token_counts[\"raw_tokens_plus_prompt_template\"]:>5}\n| - Total Tokens (w/ Padding) :  {token_counts[\"raw_tokens_plus_prompt_template_and_padding\"]:>5}\n| -----------------------------------------------------\n| Performance\n| - Training Time             :  {training_time:.2f} s\n| - Tok/sec                   :  {tok_sec:.2f} tok/s\n| -----------------------------------------------------\n\"\"\"\n\n    if device_type == \"cuda\":\n        memory_used = torch.cuda.max_memory_allocated() / 1e9\n        output += \"| Memory Usage                                                                 \\n\"\n        output += f\"| - Memory Used               :  {memory_used:.02f} GB                                        \\n\"\n    output += \"-------------------------------------------------------\\n\"\n\n    return output\n\n\ndef select_sft_generate_example(eval, 
data):\n    if eval.evaluate_example == \"first\":\n        if len(data.test_dataset.data):\n            instruction = data.test_dataset.data[0][\"instruction\"]\n        else:\n            instruction = data.train_dataset.data[0][\"instruction\"]\n\n    elif eval.evaluate_example == \"random\":\n        if len(data.test_dataset.data):\n            random_idx = random.randint(0, len(data.test_dataset.data) - 1)\n            instruction = data.test_dataset.data[random_idx][\"instruction\"]\n        else:\n            random_idx = random.randint(0, len(data.train_dataset.data) - 1)\n            instruction = data.train_dataset.data[random_idx][\"instruction\"]\n\n    elif isinstance(eval.evaluate_example, int):\n        index = eval.evaluate_example\n        if len(data.test_dataset.data) > index:\n            instruction = data.test_dataset.data[index][\"instruction\"]\n        elif len(data.train_dataset.data) > index:\n            instruction = data.train_dataset.data[index][\"instruction\"]\n        else:\n            raise IndexError(f\"Index {index} is out of range for both test and training datasets.\")\n\n    else:\n        raise ValueError(f\"Unknown evaluation example type: {eval.evaluate_example}\")\n    return instruction\n\n\ndef _RunIf(thunder: bool = False, **kwargs):\n    import pytest\n    from lightning.fabric.utilities.testing import _runif_reasons\n\n    reasons, marker_kwargs = _runif_reasons(**kwargs)\n\n    if thunder and not _THUNDER_AVAILABLE:\n        # if we require Thunder, but it's not available, we should skip\n        reasons.append(\"Thunder\")\n\n    return pytest.mark.skipif(condition=len(reasons) > 0, reason=f\"Requires: [{' + '.join(reasons)}]\", **marker_kwargs)\n\n\ndef kill_process_tree(pid: int):\n    \"\"\"\n    Kill a process and all its child processes given the parent PID.\n    \"\"\"\n    try:\n        parent = psutil.Process(pid)\n        children = parent.children(recursive=True)\n        for child in children:\n         
   child.kill()\n        parent.kill()\n    except psutil.NoSuchProcess:\n        pass  # Process already exited\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nbuild-backend = \"setuptools.build_meta\"\n\nrequires = [\n  \"setuptools>=68.2.2\",\n  \"wheel>=0.41.2\",\n]\n\n[project]\nname = \"litgpt\"\nversion = \"0.5.12\"\ndescription = \"Hackable implementation of state-of-the-art open-source LLMs\"\nreadme = \"README.md\"\nlicense = { file = \"LICENSE\" }\n\nauthors = [\n  { name = \"Lightning AI\", email = \"contact@lightning.ai\" },\n]\nrequires-python = \">=3.10\"\nclassifiers = [\n  \"Programming Language :: Python :: 3 :: Only\",\n  \"Programming Language :: Python :: 3.10\",\n  \"Programming Language :: Python :: 3.11\",\n  \"Programming Language :: Python :: 3.12\",\n  \"Programming Language :: Python :: 3.13\",\n  \"Programming Language :: Python :: 3.14\",\n]\ndependencies = [\n  # download models:\n  \"huggingface-hub>=0.30,<1.4\",\n  \"jsonargparse[signatures]>=4.37,<=4.41; python_version>='3.10'\", # required to work with Python >=3.10\n  \"lightning>=2.6.1\",\n  \"psutil==7.1.3\",\n  \"safetensors>=0.4.3\",\n  # tokenization in most models:\n  \"tokenizers>=0.21\",\n  \"torch>=2.7\",\n  # convert_hf_checkpoint\n  \"tqdm>4.66\",\n]\n\noptional-dependencies.compiler = [\n  # compilation:\n  \"lightning-thunder>=0.2.dev20250119; python_version>='3.10' and sys_platform=='linux'\",\n]\noptional-dependencies.extra = [\n  \"bitsandbytes>=0.42,<0.43; sys_platform=='darwin'\",\n  # quantization:\n  \"bitsandbytes>=0.45.2,<0.50; sys_platform=='linux' or sys_platform=='win32'\",\n  # litgpt.evaluate:\n  \"datasets>=2.18,<4\",\n  # download:\n  \"huggingface-hub[hf-transfer]>=0.21\",\n  \"litdata==0.2.59\",\n  # litgpt logging:\n  \"litlogger>=0.1.7\",\n  # litgpt.deploy:\n  \"litserve>0.2\",\n  # lm-eval: pinned <0.4.9.1 due to trust_remote_code issues with datasets like logiqa.\n  # See: https://github.com/EleutherAI/lm-evaluation-harness/issues/3171\n  \"lm-eval>=0.4.2,<0.4.9.1\",\n  # litgpt.data.prepare_starcoder.py:\n  \"pandas>=1.9\",\n  \"pyarrow>=15.0.2\",\n  # litgpt.data:\n  
\"requests>=2.31\",\n  # llama-based models:\n  \"sentencepiece>=0.2\",\n  # litgpt.pretrain:\n  \"tensorboard>=2.14\",\n  \"torchmetrics>=1.3.1\",\n  \"transformers>=4.51.3,<4.57\",\n  # litdata, only on non-Windows:\n  \"uvloop>=0.2; sys_platform!='win32'\",\n  # litgpt.data.prepare_slimpajama.py:\n  \"zstandard>=0.22\",\n]\noptional-dependencies.test = [\n  \"einops>=0.7\",\n  \"protobuf>=4.23.4\",\n  \"pytest>=8.1.1\",\n  \"pytest-benchmark>=5.1\",\n  \"pytest-dependency>=0.6\",\n  \"pytest-rerunfailures>=14\",\n  \"pytest-timeout>=2.3.1\",\n]\nurls.documentation = \"https://github.com/lightning-AI/litgpt/tutorials\"\nurls.homepage = \"https://github.com/lightning-AI/litgpt\"\nscripts.litgpt = \"litgpt.__main__:main\"\n\n[tool.setuptools.packages.find]\ninclude = [\n  \"litgpt\",\n  \"litgpt.*\",\n]\nexclude = [  ]\n\n[tool.setuptools.package-data]\nlitgpt = [\n  \"LICENSE\",\n  \"README.md\",\n]\n\n[tool.ruff]\ntarget-version = \"py38\"\nline-length = 120\nexclude = [\n  \"build\",\n  \"dist\",\n  \"docs\",\n]\n\nlint.select = [\n  \"E\",\n  \"F\",  # see: https://pypi.org/project/pyflakes\n  \"I\",  # implementation for isort\n  \"UP\", # see: https://docs.astral.sh/ruff/rules/#pyupgrade-up\n  \"W\",  # see: https://pypi.org/project/pycodestyle\n]\n#extend-select = [\n#    \"C4\",  # see: https://pypi.org/project/flake8-comprehensions\n#    \"PT\",  # see: https://pypi.org/project/flake8-pytest-style\n#    \"RET\",  # see: https://pypi.org/project/flake8-return\n#    \"SIM\",  # see: https://pypi.org/project/flake8-simplify\n#]\nlint.ignore = [\n  \"E501\", # Line too long\n  \"E731\", # Do not assign a lambda expression, use a def\n  \"E741\", # todo: Ambiguous variable name\n  \"F841\", # todo: Local variable is assigned to but never used\n]\n# Use Google-style docstrings.\nlint.pydocstyle.convention = \"google\"\n\n[tool.codespell]\n#skip = '*.py'\nquiet-level = 3\nignore-words-list = \"\"\"\n  tral, \\\n  
Rockerfeller\n\"\"\"\n\n[tool.pytest.ini_options]\naddopts = [\n  \"--strict-markers\",\n  #\"--doctest-modules\",\n  \"--color=yes\",\n  \"--disable-pytest-warnings\",\n]\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\nimport shutil\nimport sys\nfrom pathlib import Path\nfrom typing import List, Optional\n\nimport pytest\nimport torch\n\n# support running without installing as a package, adding extensions to the Python path\nwd = Path(__file__).parent.parent.resolve()\nif wd.is_dir():\n    sys.path.append(str(wd))\nelse:\n    import warnings\n\n    warnings.warn(f\"Could not find extensions directory at {wd}\")\n\n\n@pytest.fixture()\ndef fake_checkpoint_dir(tmp_path):\n    os.chdir(tmp_path)\n    checkpoint_dir = tmp_path / \"checkpoints\" / \"tmp\"\n    checkpoint_dir.mkdir(parents=True)\n    (checkpoint_dir / \"lit_model.pth\").touch()\n    (checkpoint_dir / \"model_config.yaml\").touch()\n    (checkpoint_dir / \"tokenizer.json\").touch()\n    (checkpoint_dir / \"tokenizer_config.json\").touch()\n    return checkpoint_dir\n\n\nclass TensorLike:\n    def __eq__(self, other):\n        return isinstance(other, torch.Tensor)\n\n\n@pytest.fixture()\ndef tensor_like():\n    return TensorLike()\n\n\nclass FloatLike:\n    def __eq__(self, other):\n        return not isinstance(other, int) and isinstance(other, float)\n\n\n@pytest.fixture()\ndef float_like():\n    return FloatLike()\n\n\n@pytest.fixture(autouse=True)\ndef restore_default_dtype():\n    # just in case\n    torch.set_default_dtype(torch.float32)\n\n\n@pytest.fixture(autouse=True)\ndef destroy_process_group():\n    yield\n\n    import torch.distributed\n\n    if torch.distributed.is_available() and torch.distributed.is_initialized():\n        torch.distributed.destroy_process_group()\n\n\n@pytest.fixture\ndef turn_off_tf32_and_set_seed(monkeypatch):\n    monkeypatch.setenv(\"NVIDIA_TF32_OVERRIDE\", \"0\")\n    torch.manual_seed(42)\n    yield\n    torch.seed()\n\n\nclass MockTokenizer:\n    \"\"\"A dummy tokenizer that encodes each character as its ASCII code.\"\"\"\n\n    bos_id = 0\n    eos_id = 1\n\n    def 
encode(self, text: str, bos: Optional[bool] = None, eos: bool = False, max_length: int = -1) -> torch.Tensor:\n        output = []\n        if bos:\n            output.append(self.bos_id)\n        output.extend([ord(c) for c in text])\n        if eos:\n            output.append(self.eos_id)\n        output = output[:max_length] if max_length > 0 else output\n        return torch.tensor(output)\n\n    def decode(self, tokens: torch.Tensor) -> str:\n        return \"\".join(chr(int(t)) for t in tokens.tolist())\n\n\n@pytest.fixture()\ndef mock_tokenizer():\n    return MockTokenizer()\n\n\n@pytest.fixture()\ndef alpaca_path(tmp_path):\n    file = Path(__file__).parent / \"data\" / \"_fixtures\" / \"alpaca.json\"\n    shutil.copyfile(file, tmp_path / \"alpaca.json\")\n    return tmp_path / \"alpaca.json\"\n\n\n@pytest.fixture()\ndef dolly_path(tmp_path):\n    file = Path(__file__).parent / \"data\" / \"_fixtures\" / \"dolly.json\"\n    shutil.copyfile(file, tmp_path / \"dolly.json\")\n    return tmp_path / \"dolly.json\"\n\n\n@pytest.fixture()\ndef longform_path(tmp_path):\n    path = tmp_path / \"longform\"\n    path.mkdir()\n    for split in (\"train\", \"val\"):\n        file = Path(__file__).parent / \"data\" / \"_fixtures\" / f\"longform_{split}.json\"\n        shutil.copyfile(file, path / f\"{split}.json\")\n    return path\n\n\n# https://github.com/Lightning-AI/lightning/blob/6e517bd55b50166138ce6ab915abd4547702994b/tests/tests_fabric/conftest.py#L140\ndef pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None:\n    initial_size = len(items)\n    conditions = []\n    filtered, skipped = 0, 0\n\n    options = {\"standalone\": \"PL_RUN_STANDALONE_TESTS\", \"min_cuda_gpus\": \"RUN_ONLY_CUDA_TESTS\"}\n    if os.getenv(options[\"standalone\"], \"0\") == \"1\" and os.getenv(options[\"min_cuda_gpus\"], \"0\") == \"1\":\n        # special case: we don't have a CPU job for standalone tests, so we shouldn't run only cuda tests.\n        
# by deleting the key, we avoid filtering out the CPU tests\n        del options[\"min_cuda_gpus\"]\n\n    for kwarg, env_var in options.items():\n        # this will compute the intersection of all tests selected per environment variable\n        if os.getenv(env_var, \"0\") == \"1\":\n            conditions.append(env_var)\n            for i, test in reversed(list(enumerate(items))):  # loop in reverse, since we are going to pop items\n                already_skipped = any(marker.name == \"skip\" for marker in test.own_markers)\n                if already_skipped:\n                    # the test was going to be skipped anyway, filter it out\n                    items.pop(i)\n                    skipped += 1\n                    continue\n                has_runif_with_kwarg = any(\n                    marker.name == \"skipif\" and marker.kwargs.get(kwarg) for marker in test.own_markers\n                )\n                if not has_runif_with_kwarg:\n                    # the test has `@_RunIf(kwarg=True)`, filter it out\n                    items.pop(i)\n                    filtered += 1\n\n    if config.option.verbose >= 0 and (filtered or skipped):\n        writer = config.get_terminal_writer()\n        writer.write(\n            f\"\\nThe number of tests has been filtered from {initial_size} to {initial_size - filtered} after the\"\n            f\" filters {conditions}.\\n{skipped} tests are marked as unconditional skips.\\nIn total,\"\n            f\" {len(items)} tests will run.\\n\",\n            flush=True,\n            bold=True,\n            purple=True,  # oh yeah, branded pytest messages\n        )\n\n    for test in items:\n        if \"test_hf_for_nemo\" in test.nodeid and \"Qwen/Qwen2.5-7B-Instruct\" in test.nodeid:\n            test.add_marker(\n                # Don't use `raises=TypeError` because the actual exception is\n                # wrapped inside `torch._dynamo.exc.BackendCompilerFailed`,\n                # which prevents pytest from 
recognizing it as a TypeError.\n                pytest.mark.xfail(\n                    reason=\"currently not working, see https://github.com/Lightning-AI/lightning-thunder/issues/2085\",\n                )\n            )\n"
  },
  {
    "path": "tests/convert/__init__.py",
    "content": ""
  },
  {
    "path": "tests/convert/test_hf_checkpoint.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nfrom unittest import mock\n\nimport pytest\nimport torch\n\nfrom litgpt import Config\nfrom litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint, copy_weights_hf_llama, qkv_reassemble\n\n\ndef test_llama2_70b_conversion():\n    shapes = {\n        \"model.embed_tokens.weight\": (32000, 8192),\n        \"model.layers.0.input_layernorm.weight\": (8192,),\n        \"model.layers.0.mlp.down_proj.weight\": (8192, 28672),\n        \"model.layers.0.mlp.gate_proj.weight\": (28672, 8192),\n        \"model.layers.0.mlp.up_proj.weight\": (28672, 8192),\n        \"model.layers.0.post_attention_layernorm.weight\": (8192,),\n        \"model.layers.0.self_attn.q_proj.weight\": (8192, 8192),\n        \"model.layers.0.self_attn.k_proj.weight\": (1024, 8192),\n        \"model.layers.0.self_attn.v_proj.weight\": (1024, 8192),\n        \"model.layers.0.self_attn.o_proj.weight\": (8192, 8192),\n        \"model.layers.1.input_layernorm.weight\": (8192,),\n        \"model.layers.1.mlp.down_proj.weight\": (8192, 28672),\n        \"model.layers.1.mlp.gate_proj.weight\": (28672, 8192),\n        \"model.layers.1.mlp.up_proj.weight\": (28672, 8192),\n        \"model.layers.1.post_attention_layernorm.weight\": (8192,),\n        \"model.layers.1.self_attn.o_proj.weight\": (8192, 8192),\n        \"model.layers.2.input_layernorm.weight\": (8192,),\n        \"model.layers.2.mlp.down_proj.weight\": (8192, 28672),\n        \"model.layers.2.mlp.gate_proj.weight\": (28672, 8192),\n        \"model.layers.2.mlp.up_proj.weight\": (28672, 8192),\n        \"model.layers.2.post_attention_layernorm.weight\": (8192,),\n        \"model.layers.2.self_attn.o_proj.weight\": (8192, 8192),\n        \"model.layers.3.input_layernorm.weight\": (8192,),\n        \"model.layers.3.mlp.down_proj.weight\": (8192, 28672),\n        \"model.layers.3.mlp.gate_proj.weight\": (28672, 8192),\n        
\"model.layers.3.mlp.up_proj.weight\": (28672, 8192),\n        \"model.layers.3.post_attention_layernorm.weight\": (8192,),\n        \"model.layers.3.self_attn.o_proj.weight\": (8192, 8192),\n        \"model.layers.4.input_layernorm.weight\": (8192,),\n        \"model.layers.4.mlp.down_proj.weight\": (8192, 28672),\n        \"model.layers.4.mlp.gate_proj.weight\": (28672, 8192),\n        \"model.layers.4.mlp.up_proj.weight\": (28672, 8192),\n        \"model.layers.4.post_attention_layernorm.weight\": (8192,),\n        \"model.layers.4.self_attn.o_proj.weight\": (8192, 8192),\n        \"model.layers.5.mlp.gate_proj.weight\": (28672, 8192),\n        \"model.layers.5.self_attn.o_proj.weight\": (8192, 8192),\n    }\n\n    config = Config.from_name(\"Llama-2-70b-hf\")\n    holder = {}\n    qkv_weights = {}\n    with torch.device(\"meta\"):\n        weight_map = {k: torch.empty(s) for k, s in shapes.items()}\n    copy_weights_hf_llama(config, qkv_weights, holder, weight_map)\n\n    # NOTE: there are 5 layers, but only in the first layer we have `q`, `k` and `v`\n    assert len(qkv_weights) == 1\n    # there are no loaded qkv weights\n    assert all(v is None for qkv in qkv_weights.values() for v in qkv)\n    # the shapes are correct\n    holder = {k: tuple(t.shape) for k, t in holder.items()}\n    assert holder == {\n        \"transformer.h.0.attn.qkv.weight\": (10240, 8192),\n        \"transformer.h.0.attn.proj.weight\": (8192, 8192),\n        \"transformer.h.0.mlp.fc_1.weight\": (28672, 8192),\n        \"transformer.h.0.mlp.fc_2.weight\": (28672, 8192),\n        \"transformer.h.0.mlp.proj.weight\": (8192, 28672),\n        \"transformer.h.0.norm_1.weight\": (8192,),\n        \"transformer.h.0.norm_2.weight\": (8192,),\n        \"transformer.h.1.attn.proj.weight\": (8192, 8192),\n        \"transformer.h.1.mlp.fc_1.weight\": (28672, 8192),\n        \"transformer.h.1.mlp.fc_2.weight\": (28672, 8192),\n        \"transformer.h.1.mlp.proj.weight\": (8192, 28672),\n        
\"transformer.h.1.norm_1.weight\": (8192,),\n        \"transformer.h.1.norm_2.weight\": (8192,),\n        \"transformer.h.2.attn.proj.weight\": (8192, 8192),\n        \"transformer.h.2.mlp.fc_1.weight\": (28672, 8192),\n        \"transformer.h.2.mlp.fc_2.weight\": (28672, 8192),\n        \"transformer.h.2.mlp.proj.weight\": (8192, 28672),\n        \"transformer.h.2.norm_1.weight\": (8192,),\n        \"transformer.h.2.norm_2.weight\": (8192,),\n        \"transformer.h.3.attn.proj.weight\": (8192, 8192),\n        \"transformer.h.3.mlp.fc_1.weight\": (28672, 8192),\n        \"transformer.h.3.mlp.fc_2.weight\": (28672, 8192),\n        \"transformer.h.3.mlp.proj.weight\": (8192, 28672),\n        \"transformer.h.3.norm_1.weight\": (8192,),\n        \"transformer.h.3.norm_2.weight\": (8192,),\n        \"transformer.h.4.attn.proj.weight\": (8192, 8192),\n        \"transformer.h.4.mlp.fc_1.weight\": (28672, 8192),\n        \"transformer.h.4.mlp.fc_2.weight\": (28672, 8192),\n        \"transformer.h.4.mlp.proj.weight\": (8192, 28672),\n        \"transformer.h.4.norm_1.weight\": (8192,),\n        \"transformer.h.4.norm_2.weight\": (8192,),\n        \"transformer.h.5.attn.proj.weight\": (8192, 8192),\n        \"transformer.h.5.mlp.fc_1.weight\": (28672, 8192),\n        \"transformer.wte.weight\": (32000, 8192),\n        \"lm_head.weight\": (32000, 8192),  # due to weight tying lm_head is in the converted weights\n    }\n\n\n@pytest.mark.parametrize(\"model_name\", (\"pythia-14m\", \"falcon-7b\", \"Llama-2-7b-hf\", \"phi-2\"))\ndef test_convert_hf_checkpoint(tmp_path, model_name):\n    with pytest.raises(ValueError, match=\"to contain .bin\"):\n        convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name)\n\n    bin_file = tmp_path / \"foo.bin\"\n    bin_file.touch()\n    with mock.patch(\"litgpt.scripts.convert_hf_checkpoint.lazy_load\") as load:\n        # bypass if-statement for weight tying\n        if model_name == \"Llama-2-7b-hf\":\n            
load.return_value = {\"model.embed_tokens.weight\": torch.rand((10, 10))}\n        convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name)\n    load.assert_called_with(bin_file)\n\n    assert {p.name for p in tmp_path.glob(\"*\")} == {\"foo.bin\", \"model_config.yaml\", \"lit_model.pth\"}\n\n    # ensure that the config dict can be loaded\n    config = Config.from_file(tmp_path / \"model_config.yaml\")\n    assert isinstance(config, Config)\n\n\ndef test_qkv_reassemble():\n    # MHA\n    config = Config(n_embd=4, n_head=4)\n    qkv_interleaved = torch.tensor(\n        [\n            [0, 1, 2, 3],  # query\n            [16, 17, 18, 19],  # key\n            [32, 33, 34, 35],  # value\n            [4, 5, 6, 7],  # query\n            [20, 21, 22, 23],  # key\n            [36, 37, 38, 39],  # value\n            [8, 9, 10, 11],  # query\n            [24, 25, 26, 27],  # key\n            [40, 41, 42, 43],  # value\n            [12, 13, 14, 15],  # query\n            [28, 29, 30, 31],  # key\n            [44, 45, 46, 47],  # value\n        ]\n    )\n    qkv = qkv_reassemble(qkv_interleaved, config)\n    torch.testing.assert_close(\n        qkv,\n        torch.tensor(\n            [\n                [0, 1, 2, 3],  # query\n                [4, 5, 6, 7],  # query\n                [8, 9, 10, 11],  # query\n                [12, 13, 14, 15],  # query\n                [16, 17, 18, 19],  # key\n                [20, 21, 22, 23],  # key\n                [24, 25, 26, 27],  # key\n                [28, 29, 30, 31],  # key\n                [32, 33, 34, 35],  # value\n                [36, 37, 38, 39],  # value\n                [40, 41, 42, 43],  # value\n                [44, 45, 46, 47],  # value\n            ]\n        ),\n    )\n\n    # GQA\n    config = Config(n_embd=4, n_head=4, n_query_groups=2)\n    qkv_interleaved = torch.tensor(\n        [\n            [0, 1, 2, 3],  # query\n            [4, 5, 6, 7],  # query\n            [16, 17, 18, 19],  # key\n            [24, 
25, 26, 27],  # value\n            [8, 9, 10, 11],  # query\n            [12, 13, 14, 15],  # query\n            [20, 21, 22, 23],  # key\n            [28, 29, 30, 31],  # value\n        ]\n    )\n    qkv = qkv_reassemble(qkv_interleaved, config)\n    torch.testing.assert_close(\n        qkv,\n        torch.tensor(\n            [\n                [0, 1, 2, 3],  # query\n                [4, 5, 6, 7],  # query\n                [8, 9, 10, 11],  # query\n                [12, 13, 14, 15],  # query\n                [16, 17, 18, 19],  # key\n                [20, 21, 22, 23],  # key\n                [24, 25, 26, 27],  # value\n                [28, 29, 30, 31],  # value\n            ]\n        ),\n    )\n\n    # MQA\n    config = Config(n_embd=4, n_head=4, n_query_groups=1)\n    qkv_interleaved = torch.tensor(\n        [\n            [0, 1, 2, 3],  # query\n            [4, 5, 6, 7],  # query\n            [8, 9, 10, 11],  # query\n            [12, 13, 14, 15],  # query\n            [16, 17, 18, 19],  # key\n            [20, 21, 22, 23],  # value\n        ]\n    )\n    qkv = qkv_reassemble(qkv_interleaved, config)\n    torch.testing.assert_close(\n        qkv,\n        torch.tensor(\n            [\n                [0, 1, 2, 3],  # query\n                [4, 5, 6, 7],  # query\n                [8, 9, 10, 11],  # query\n                [12, 13, 14, 15],  # query\n                [16, 17, 18, 19],  # key\n                [20, 21, 22, 23],  # value\n            ]\n        ),\n    )\n"
  },
  {
    "path": "tests/convert/test_lit_checkpoint.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\nfrom dataclasses import asdict\nfrom unittest.mock import ANY\n\nimport pytest\nimport torch\nimport yaml\nfrom transformers import AutoConfig, AutoModelForCausalLM\nfrom transformers.models.falcon import FalconConfig, FalconForCausalLM\nfrom transformers.models.gemma import GemmaConfig, GemmaForCausalLM\nfrom transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM\nfrom transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig\nfrom transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM\nfrom transformers.models.llama import LlamaConfig, LlamaForCausalLM\nfrom transformers.models.mixtral import MixtralConfig, MixtralForCausalLM\nfrom transformers.models.olmo import OlmoConfig, OlmoForCausalLM\nfrom transformers.models.phi.configuration_phi import PhiConfig\nfrom transformers.models.phi.modeling_phi import PhiForCausalLM\nfrom transformers.models.phi3.configuration_phi3 import Phi3Config\nfrom transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM\nfrom transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM\n\nfrom litgpt import GPT, Config\nfrom litgpt.scripts.convert_lit_checkpoint import (\n    check_conversion_supported,\n    convert_lit_checkpoint,\n    copy_weights_falcon,\n    copy_weights_gemma_2,\n    copy_weights_gemma_3,\n    copy_weights_gpt_neox,\n    copy_weights_llama,\n    copy_weights_phi,\n    copy_weights_qwen_2_5,\n    qkv_reassemble,\n)\nfrom litgpt.utils import _RunIf\n\n\n@pytest.mark.parametrize(\"model_name\", (\"pythia-14m\", \"falcon-7b\", \"Llama-2-7b-hf\", \"phi-2\"))\ndef test_convert_lit_checkpoint(tmp_path, model_name):\n    ours_config = Config.from_name(model_name, block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128)\n    ours_model = GPT(ours_config)\n    checkpoint_path = tmp_path / \"lit_model.pth\"\n    config_path = tmp_path / \"model_config.yaml\"\n 
   torch.save(ours_model.state_dict(), checkpoint_path)\n    with open(config_path, \"w\", encoding=\"utf-8\") as fp:\n        yaml.dump(asdict(ours_config), fp)\n    output_dir = tmp_path / \"out_dir\"\n\n    convert_lit_checkpoint(checkpoint_path.parent, output_dir)\n    assert set(os.listdir(tmp_path)) == {\"lit_model.pth\", \"model_config.yaml\", \"out_dir\"}\n    assert os.path.isfile(output_dir / \"model.pth\")\n\n    # check checkpoint is unwrapped\n    torch.save({\"model\": ours_model.state_dict()}, checkpoint_path)\n    convert_lit_checkpoint(checkpoint_path.parent, output_dir)\n    converted_sd = torch.load(output_dir / \"model.pth\")\n    assert \"model\" not in converted_sd\n\n\n@torch.inference_mode()\ndef test_against_falcon_40b():\n    ours_config = Config.from_name(\"falcon-40b\", n_layer=2, n_head=8, n_query_groups=4, n_embd=32)\n    theirs_config = FalconConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        num_hidden_layers=ours_config.n_layer,\n        num_attention_heads=ours_config.n_head,\n        num_kv_heads=ours_config.n_query_groups,\n        new_decoder_architecture=True,\n        parallel_attn=ours_config.parallel_residual,\n        bias=ours_config.bias,\n    )\n\n    ours_model = GPT(ours_config)\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict = {}\n    copy_weights_falcon(ours_config, theirs_state_dict, ours_state_dict)\n\n    theirs_model = FalconForCausalLM(theirs_config)\n    # assign must be set to True for torch.testing.assert_close to pass\n    theirs_model.load_state_dict(theirs_state_dict, assign=True)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"]\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\ndef test_against_original_gpt_neox():\n    ours_config = Config(block_size=64, vocab_size=100, n_layer=4, 
n_head=8, n_embd=16)\n    assert ours_config.padded_vocab_size == 512\n    theirs_config = GPTNeoXConfig(\n        hidden_act=\"gelu\",\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        initializer_range=0.02,\n        intermediate_size=ours_config.intermediate_size,\n        layer_norm_eps=1e-05,\n        max_position_embeddings=ours_config.block_size,\n        rotary_emb_base=10000,\n        rotary_pct=ours_config.rotary_percentage,\n        vocab_size=ours_config.padded_vocab_size,\n        use_parallel_residual=ours_config.parallel_residual,\n    )\n\n    ours_model = GPT(ours_config)\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict = {}\n    copy_weights_gpt_neox(ours_config, theirs_state_dict, ours_state_dict)\n    theirs_model = GPTNeoXForCausalLM(theirs_config)\n    # strict=False because we don't save the rotary embeddings inv frequency\n    keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)\n    assert not keys.unexpected_keys\n    assert all(\"inv_freq\" in k for k in keys.missing_keys)\n\n    # test end to end\n    x = torch.randint(0, ours_config.padded_vocab_size, size=(2, ours_config.block_size), dtype=torch.int64)\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"]\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    \"ours_kwargs\", [{\"name\": \"Llama-2-7b-hf\"}, {\"name\": \"CodeLlama-7b-hf\"}, {\"name\": \"Llama-2-70b-chat-hf\"}]\n)\ndef test_against_hf_llama2(ours_kwargs):\n    ours_config = Config.from_name(\n        padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs\n    )\n    T = 5\n    theirs_config = LlamaConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        
num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_query_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    ours_model = GPT(ours_config)\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict = {}\n    copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict)\n    theirs_model = LlamaForCausalLM(theirs_config)\n    theirs_model.load_state_dict(theirs_state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"]\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"Mixtral-8x7B-Instruct-v0.1\", \"Mixtral-8x22B-Instruct-v0.1\"))\ndef test_against_mixtral(model_name):\n    ours_config = Config.from_name(\n        model_name,\n        padded_vocab_size=10000,\n        n_layer=2,\n        n_embd=32,\n        n_head=8,\n        n_query_groups=2,\n        intermediate_size=86,\n        n_expert=4,\n    )\n    T = 5\n    theirs_config = MixtralConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        num_local_experts=ours_config.n_expert,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    ours_model = GPT(ours_config)\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict 
= {}\n    copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict)\n    theirs_model = MixtralForCausalLM(theirs_config)\n    theirs_model.load_state_dict(theirs_state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32)\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"]\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"OLMo-1B-hf\", \"OLMo-7B-hf\"))\ndef test_against_olmo(model_name):\n    ours_config = Config.from_name(\n        model_name,\n        padded_vocab_size=10000,\n        n_layer=2,\n        n_head=8,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    T = 5\n    theirs_config = OlmoConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        intermediate_size=ours_config.intermediate_size,\n        num_hidden_layers=ours_config.n_layer,\n        num_attention_heads=ours_config.n_head,\n        num_key_value_heads=ours_config.n_query_groups,\n        max_positional_embeddings=T,\n        attention_bias=ours_config.bias,\n        rope_theta=ours_config.rope_base,\n        tie_word_embeddings=(model_name == \"OLMo-1B-hf\"),\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    ours_model = GPT(ours_config)\n    # tie weights\n    ours_model.lm_head.weight = ours_model.transformer.wte.weight\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict = {}\n    copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict, untie_weights=(model_name == \"OLMo-1B-hf\"))\n    theirs_model = OlmoForCausalLM(theirs_config)\n    keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)\n    assert not keys.unexpected_keys\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)\n    assert x.size(1) == T\n    ours_y = 
ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"]\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\ndef test_against_original_open_llama_3b():\n    ours_config = Config.from_name(\"open_llama_3b\", n_layer=2, n_head=8, n_embd=32, intermediate_size=86)\n    T = 5\n    theirs_config = LlamaConfig(\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    ours_model = GPT(ours_config)\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict = {}\n    copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict)\n    theirs_model = LlamaForCausalLM(theirs_config)\n    theirs_model.load_state_dict(theirs_state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"]\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"phi-1_5\", \"phi-2\"))\ndef test_against_hf_phi(model_name):\n    ours_config = Config.from_name(\n        model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5\n    )\n    T = 5\n    theirs_config = PhiConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        max_position_embeddings=ours_config.block_size,\n        hidden_size=ours_config.n_embd,\n        intermediate_size=ours_config.intermediate_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        partial_rotary_factor=ours_config.rotary_percentage,\n    )\n\n    ours_model = GPT(ours_config)\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict = {}\n    
copy_weights_phi(ours_config, theirs_state_dict, ours_state_dict)\n    theirs_model = PhiForCausalLM(theirs_config)\n    # strict=False because we don't save the rotary embeddings inv frequency\n    keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)\n    assert not keys.unexpected_keys\n    assert all(\"inv_freq\" in k for k in keys.missing_keys)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"]\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"Phi-3-mini-4k-instruct\",))\ndef test_against_hf_phi_3(model_name):\n    ours_config = Config.from_name(model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256)\n    T = 5\n    theirs_config = Phi3Config(\n        attention_bias=ours_config.bias,\n        head_dim=ours_config.head_size,\n        hidden_size=ours_config.n_embd,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        num_key_value_heads=ours_config.n_query_groups,\n        pad_token_id=ours_config.padded_vocab_size - 1,\n        partial_rotary_factor=ours_config.rotary_percentage,\n        rms_norm_eps=ours_config.norm_eps,\n        rope_theta=ours_config.rope_base,\n        vocab_size=ours_config.padded_vocab_size,\n    )\n\n    ours_model = GPT(ours_config)\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict = {}\n    copy_weights_phi(ours_config, theirs_state_dict, ours_state_dict)\n    theirs_model = Phi3ForCausalLM(theirs_config)\n    # strict=False because we don't save the rotary embeddings inv frequency\n    keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)\n    assert not keys.unexpected_keys\n    assert all(\"inv_freq\" in k 
for k in keys.missing_keys)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"]\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\ndef test_against_original_stablelm_zephyr_3b():\n    T = 5\n    ours_config = Config.from_name(\"stablelm-zephyr-3b\", n_layer=2, n_head=16, n_embd=32, intermediate_size=86)\n    theirs_config = AutoConfig.from_pretrained(\n        \"stabilityai/stablelm-zephyr-3b\",\n        trust_remote_code=True,\n        num_hidden_layers=ours_config.n_layer,\n        num_attention_heads=ours_config.n_head,\n        num_key_value_heads=ours_config.n_head,\n        hidden_size=ours_config.n_embd,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    ours_model = GPT(ours_config)\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict = {}\n    copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict)\n    theirs_model = AutoModelForCausalLM.from_config(theirs_config, trust_remote_code=True, torch_dtype=torch.float32)\n    theirs_model.load_state_dict(theirs_state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"]\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", [\"gemma-2b\", \"gemma-7b\"])\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_gemma(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 5\n    ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)\n    theirs_config = GemmaConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    ours_model = GPT(ours_config).to(device)\n    # tie weights\n    ours_model.lm_head.weight = ours_model.transformer.wte.weight\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict = {}\n    copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True)\n    theirs_model = GemmaForCausalLM(theirs_config).to(device)\n    theirs_model.load_state_dict(\n        theirs_state_dict,\n        strict=False,\n    )\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", 
(\"gemma-2-2b\", \"gemma-2-9b\", \"gemma-2-27b\"))\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_gemma_2(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        sliding_window_size=T // 2,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    theirs_config = Gemma2Config(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        sliding_window=ours_config.sliding_window_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n        attn_logit_softcapping=ours_config.attention_logit_softcapping,\n        final_logit_softcapping=ours_config.final_logit_softcapping,\n        initializer_range=1.0,  # to make the affect of attention_logit_softcapping more prominent\n        attn_implementation=\"eager\",\n        query_pre_attn_scalar=ours_config.attention_scores_scalar,\n    )\n\n    assert 
ours_config.intermediate_size == theirs_config.intermediate_size\n\n    ours_model = GPT(ours_config).to(device)\n    # tie weights\n    ours_model.lm_head.weight = ours_model.transformer.wte.weight\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict = {}\n    copy_weights_gemma_2(ours_config, theirs_state_dict, ours_state_dict)\n    theirs_model = Gemma2ForCausalLM(theirs_config).to(device)\n    keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)\n    assert not keys.unexpected_keys\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"gemma-3-1b-it\", \"gemma-3-4b-it\", \"gemma-3-12b-it\", \"gemma-3-27b-it\"))\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        pytest.param(torch.device(\"cpu\"), torch.float32, marks=[pytest.mark.flaky(reruns=3)]),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # todo: the reference does softmax upscaled to fp32 during attention\n                # additionally, the final layernorm input is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_gemma_3(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        sliding_window_size=T // 2,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    theirs_config = Gemma3TextConfig(\n        
vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        sliding_window=ours_config.sliding_window_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n        attn_logit_softcapping=ours_config.attention_logit_softcapping,\n        final_logit_softcapping=ours_config.final_logit_softcapping,\n        initializer_range=1.0,  # to make the affect of attention_logit_softcapping more prominent\n        attn_implementation=\"eager\",\n        query_pre_attn_scalar=ours_config.attention_scores_scalar,\n    )\n\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    ours_model = GPT(ours_config).to(device)\n    # tie weights\n    ours_model.lm_head.weight = ours_model.transformer.wte.weight\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict = {}\n    copy_weights_gemma_3(ours_config, theirs_state_dict, ours_state_dict)\n    theirs_model = Gemma3ForCausalLM(theirs_config).to(device)\n    keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)\n    assert not keys.unexpected_keys\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)\n\n\ndef test_check_conversion_supported_adapter():\n    lit_weights = {\"some.key.name\": ANY, 
\"error.key.gating_factor\": ANY}\n    with pytest.raises(NotImplementedError, match=\"Converting adapter\"):\n        check_conversion_supported(lit_weights=lit_weights)\n\n    lit_weights = {\"some.key.name\": ANY, \"error.key.adapter_bias\": ANY}\n    with pytest.raises(NotImplementedError, match=\"Converting adapter\"):\n        check_conversion_supported(lit_weights=lit_weights)\n\n\ndef test_check_conversion_supported_lora():\n    lit_weights = {\"some.key.name\": ANY, \"error.key.lora\": ANY}\n    with pytest.raises(ValueError, match=r\"LoRA.*cannot be converted\"):\n        check_conversion_supported(lit_weights=lit_weights)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    \"model_name\",\n    (\n        \"Qwen2.5-1.5B\",\n        \"Qwen2.5-Coder-1.5B\",\n        \"Qwen2.5-Math-1.5B\",\n        \"QwQ-32B-Preview\",\n        \"QwQ-32B\",\n        \"Qwen2.5-7B-Instruct-1M\",\n    ),\n)\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_qwen_2_5(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    theirs_config = Qwen2Config(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.attn_bias,\n        tie_word_embeddings=True,\n    )\n\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    ours_model = GPT(ours_config).to(device)\n    # tie weights\n    ours_model.lm_head.weight = ours_model.transformer.wte.weight\n    ours_state_dict = ours_model.state_dict()\n    theirs_state_dict = {}\n    copy_weights_qwen_2_5(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True)\n    theirs_model = Qwen2ForCausalLM(theirs_config).to(device)\n    keys = theirs_model.load_state_dict(theirs_state_dict, strict=False)\n    assert not keys.unexpected_keys\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    
torch.testing.assert_close(ours_y, theirs_y)\n\n\ndef test_qkv_reassemble():\n    # MHA\n    config = Config(n_embd=4, n_head=4)\n    qkv = torch.tensor(\n        [\n            [0, 1, 2, 3],  # query\n            [4, 5, 6, 7],  # query\n            [8, 9, 10, 11],  # query\n            [12, 13, 14, 15],  # query\n            [16, 17, 18, 19],  # key\n            [20, 21, 22, 23],  # key\n            [24, 25, 26, 27],  # key\n            [28, 29, 30, 31],  # key\n            [32, 33, 34, 35],  # value\n            [36, 37, 38, 39],  # value\n            [40, 41, 42, 43],  # value\n            [44, 45, 46, 47],  # value\n        ]\n    )\n    qkv_interleaved = qkv_reassemble(qkv, config)\n    torch.testing.assert_close(\n        qkv_interleaved,\n        torch.tensor(\n            [\n                [0, 1, 2, 3],  # query\n                [16, 17, 18, 19],  # key\n                [32, 33, 34, 35],  # value\n                [4, 5, 6, 7],  # query\n                [20, 21, 22, 23],  # key\n                [36, 37, 38, 39],  # value\n                [8, 9, 10, 11],  # query\n                [24, 25, 26, 27],  # key\n                [40, 41, 42, 43],  # value\n                [12, 13, 14, 15],  # query\n                [28, 29, 30, 31],  # key\n                [44, 45, 46, 47],  # value\n            ]\n        ),\n    )\n\n    # GQA\n    config = Config(n_embd=4, n_head=4, n_query_groups=2)\n    qkv = torch.tensor(\n        [\n            [0, 1, 2, 3],  # query\n            [4, 5, 6, 7],  # query\n            [8, 9, 10, 11],  # query\n            [12, 13, 14, 15],  # query\n            [16, 17, 18, 19],  # key\n            [20, 21, 22, 23],  # key\n            [24, 25, 26, 27],  # value\n            [28, 29, 30, 31],  # value\n        ]\n    )\n    qkv_interleaved = qkv_reassemble(qkv, config)\n    torch.testing.assert_close(\n        qkv_interleaved,\n        torch.tensor(\n            [\n                [0, 1, 2, 3],  # query\n                [4, 5, 6, 7],  # query\n  
              [16, 17, 18, 19],  # key\n                [24, 25, 26, 27],  # value\n                [8, 9, 10, 11],  # query\n                [12, 13, 14, 15],  # query\n                [20, 21, 22, 23],  # key\n                [28, 29, 30, 31],  # value\n            ]\n        ),\n    )\n\n    # MQA\n    config = Config(n_embd=4, n_head=4, n_query_groups=1)\n    qkv = torch.tensor(\n        [\n            [0, 1, 2, 3],  # query\n            [4, 5, 6, 7],  # query\n            [8, 9, 10, 11],  # query\n            [12, 13, 14, 15],  # query\n            [16, 17, 18, 19],  # key\n            [20, 21, 22, 23],  # value\n        ]\n    )\n    qkv_interleaved = qkv_reassemble(qkv, config)\n    torch.testing.assert_close(\n        qkv_interleaved,\n        torch.tensor(\n            [\n                [0, 1, 2, 3],  # query\n                [4, 5, 6, 7],  # query\n                [8, 9, 10, 11],  # query\n                [12, 13, 14, 15],  # query\n                [16, 17, 18, 19],  # key\n                [20, 21, 22, 23],  # value\n            ]\n        ),\n    )\n"
  },
  {
    "path": "tests/convert/test_pretrained_checkpoint.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\n\nimport torch\n\nfrom litgpt.scripts.convert_pretrained_checkpoint import convert_pretrained_checkpoint\n\n\ndef test_convert_pretrained_checkpoint(tmp_path, fake_checkpoint_dir):\n    # Pretend we made a checkpoint from pretraining\n    pretrained_checkpoint = {\n        \"model\": {\"some.module.weight\": torch.rand(2, 2), \"_orig_mod.some.other.module.weight\": torch.rand(2, 2)},\n        \"the_optimizer\": \"optimizer_state\",\n        \"other\": 1,\n    }\n    torch.save(pretrained_checkpoint, fake_checkpoint_dir / \"lit_model.pth\")\n\n    convert_pretrained_checkpoint(checkpoint_dir=fake_checkpoint_dir, output_dir=(tmp_path / \"converted\"))\n\n    assert set(os.listdir(tmp_path / \"converted\")) == {\n        \"lit_model.pth\",\n        \"model_config.yaml\",\n        \"tokenizer_config.json\",\n        \"tokenizer.json\",\n    }\n    converted_checkpoint = torch.load(tmp_path / \"converted\" / \"lit_model.pth\")\n    assert list(converted_checkpoint.keys()) == [\"some.module.weight\", \"some.other.module.weight\"]\n"
  },
  {
    "path": "tests/data/__init__.py",
    "content": ""
  },
  {
    "path": "tests/data/_fixtures/alpaca.json",
    "content": "[\n  {\n    \"instruction\": \"Give three tips for staying healthy.\",\n    \"input\": \"\",\n    \"output\": \"1. Eat a balanced diet and make sure to include plenty of fruits and vegetables. \\n2. Exercise regularly to keep your body active and strong. \\n3. Get enough sleep and maintain a consistent sleep schedule.\"\n  },\n  {\n    \"instruction\": \"What are the three primary colors?\",\n    \"input\": \"\",\n    \"output\": \"The three primary colors are red, blue, and yellow.\"\n  },\n  {\n    \"instruction\": \"Describe the structure of an atom.\",\n    \"input\": \"\",\n    \"output\": \"An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.\"\n  },\n  {\n    \"instruction\": \"How can we reduce air pollution?\",\n    \"input\": \"\",\n    \"output\": \"There are a number of ways to reduce air pollution, such as shifting to renewable energy sources, encouraging the use of public transportation, prohibiting the burning of fossil fuels, implementing policies to reduce emissions from industrial sources, and implementing vehicle emissions standards. Additionally, individuals can do their part to reduce air pollution by reducing car use, avoiding burning materials such as wood, and changing to energy efficient appliances.\"\n  },\n  {\n    \"instruction\": \"Pretend you are a project manager of a construction company. Describe a time when you had to make a difficult decision.\",\n    \"input\": \"\",\n    \"output\": \"I had to make a difficult decision when I was working as a project manager at a construction company. I was in charge of a project that needed to be completed by a certain date in order to meet the client\\u2019s expectations. 
However, due to unexpected delays, we were not able to meet the deadline and so I had to make a difficult decision. I decided to extend the deadline, but I had to stretch the team\\u2019s resources even further and increase the budget. Although it was a risky decision, I ultimately decided to go ahead with it to ensure that the project was completed on time and that the client\\u2019s expectations were met. The project was eventually successfully completed and this was seen as a testament to my leadership and decision-making abilities.\"\n  },\n  {\n    \"instruction\": \"Identify the odd one out.\",\n    \"input\": \"Twitter, Instagram, Telegram\",\n    \"output\": \"Telegram\"\n  },\n  {\n    \"instruction\": \"Explain why the following fraction is equivalent to 1/4\",\n    \"input\": \"4/16\",\n    \"output\": \"The fraction 4/16 is equivalent to 1/4 because both numerators and denominators are divisible by 4. Dividing both the top and bottom numbers by 4 yields the fraction 1/4.\"\n  },\n  {\n    \"instruction\": \"Write a short story in third person narration about a protagonist who has to make an important career decision.\",\n    \"input\": \"\",\n    \"output\": \"John was at a crossroads in his life. He had just graduated college and was now facing the big decision of what career to pursue. After much deliberation, he decided that he wanted to be an accountant and help the financially disadvantaged. He had always been good with numbers and enjoyed seeing the tangible results of his work. \\n\\nJohn enrolled in accounting courses and initially found it quite challenging. He had to learn multiple systems and regulations quickly, but he worked hard and eventually excelled in his studies. After a few years, John started working at an accounting firm in his city. He was eager to put his knowledge of taxes and accounting to use in a real-world setting.\\n\\nJohn loved his job, as it let him express his creativity in finding strategies to save his clients money. 
After a few years at the firm, he became a senior accountant and was asked to manage bigger and more challenging cases. He was now a respected figure in the financial industry, but he still remembers when he was just a recent college graduate, unsure of the direction in which his life would take him.\"\n  },\n  {\n    \"instruction\": \"Evaluate this sentence for spelling and grammar mistakes\",\n    \"input\": \"He finished his meal and left the restaurant\",\n    \"output\": \"He finished his meal and left the restaurant.\"\n  },\n  {\n    \"instruction\": \"How did Julius Caesar die?\",\n    \"input\": \"\",\n    \"output\": \"Julius Caesar was assassinated by a group of up to 60 conspirators, led by Gaius Cassius Longinus and Marcus Junius Brutus, in the Senate House on the Ides of March (15 March) of 44 BC.\"\n  },\n  {\n    \"instruction\": \"What is the capital of France?\",\n    \"input\": \"\",\n    \"output\": \"The capital of France is Paris.\"\n  },\n  {\n    \"instruction\": \"Generate a list of ten items a person might need for a camping trip\",\n    \"input\": \"\",\n    \"output\": \"1. Tent\\n2. Sleeping bags\\n3. Flashlight\\n4. Matches/lighter\\n5. Insect repellent\\n6. Extra clothes\\n7. First aid kit\\n8. Camp stove\\n9. Cooking utensils\\n10. Biodegradable soaps/shampoos\"\n  },\n  {\n    \"instruction\": \"Discuss the causes of the Great Depression\",\n    \"input\": \"\",\n    \"output\": \"The Great Depression was a severe global economic downturn that lasted from 1929 to 1939. It was caused by several factors, including the stock market crash of 1929, structural weaknesses in the economy, low consumer spending, high unemployment, and a lack of confidence in the banking system. 
Other contributing factors included overproduction of commodities, reliance on borrowing to finance investments, agricultural overproduction, and a decrease in international trade.\"\n  },\n  {\n    \"instruction\": \"Classify the following into animals, plants, and minerals\",\n    \"input\": \"Oak tree, copper or, elephant\",\n    \"output\": \"Oak tree: Plant\\n Copper or: Mineral\\n Elephant: Animal\"\n  },\n  {\n    \"instruction\": \"Explain the use of word embeddings in Natural Language Processing\",\n    \"input\": \"\",\n    \"output\": \"Word embeddings are one of the most powerful tools available for Natural Language Processing (NLP). They are mathematical representations of words or phrases in a vector space, allowing similarities between words and the context in which they are used to be measured. Word embeddings are useful for tasks such as sentiment analysis, text classification, predicting the next word in a sequence, and understanding synonyms and analogies. They allow for words to be processed as numerical values, giving machines an easier way to perform NLP tasks.\"\n  },\n  {\n    \"instruction\": \"Describe the function of a computer motherboard\",\n    \"input\": \"\",\n    \"output\": \"A computer motherboard is the main circuit board in a computer. It is the backbone of the system, connecting all other components such as the processor, RAM, storage, graphics card and peripheral ports. It provides communication pathways between all of these components, allowing them to interact and work together. 
The motherboard is responsible for regulating power, and also houses slots for expansion cards, as well as ports for USB, audio, Ethernet, and other connection types.\"\n  },\n  {\n    \"instruction\": \"Reverse engineer this code to create a new version\",\n    \"input\": \"def factorialize(num):\\n  factorial = 1\\n  for i in range(1, num):\\n    factorial *= i\\n  \\n  return factorial\",\n    \"output\": \"def factorialize(num):\\n  factorial = 1\\n  for i in range(num, 0, -1):\\n    factorial *= i\\n  \\n  return factorial\"\n  },\n  {\n    \"instruction\": \"Propose an ethical solution to the problem of data privacy\",\n    \"input\": \"\",\n    \"output\": \"One ethical solution to the problem of data privacy is to implement an opt-in data collection policy. This means that companies cannot collect and use personal data without the explicit permission of the data owners. This approach gives users the power to determine which data companies can use and for what purpose. Companies should also be transparent about how they use the collected data, so that users can make informed decisions and have control over their own data.\"\n  },\n  {\n    \"instruction\": \"Generate three verbs that mean the same as \\\"to apologize\\\"\",\n    \"input\": \"\",\n    \"output\": \"Acknowledge, express regret, make amends.\"\n  },\n  {\n    \"instruction\": \"What does DNA stand for?\",\n    \"input\": \"\",\n    \"output\": \"DNA stands for deoxyribonucleic acid.\"\n  },\n  {\n    \"instruction\": \"Compare and contrast the Cuban Missile Crisis and the Vietnam War.\",\n    \"input\": \"\",\n    \"output\": \"The Cuban Missile Crisis and the Vietnam War were two of the largest and deadliest conflicts of the 20th century. 
The Cuban Missile Crisis involved a standoff between the United States and the Soviet Union over Soviet nuclear missile bases in Cuba, while the Vietnam War was a prolonged military conflict in Indochina between the North, which was backed by the Soviet Union and the South, which was backed by the United States. The Cuban Missile Crisis unfolded over a period of 13 days in October 1962, with the resolution coming in the form of a diplomatic agreement. The Vietnam War lasted for 19 years, beginning in 1954 and ending in 1973, with the signing of the Paris Peace Accords. While both of these conflicts were Cold War related, they were also very different in terms of causes, objectives, strategic objectives and outcomes. The Cuban Missile Crisis did not result in any direct military conflict and was resolved primarily through diplomacy and the decision to withdraw nuclear weapons from Cuba. On the other hand, the Vietnam War was a prolonged series of battles resulting in significant losses for both sides. The Vietnam War also had a significant impact on the United States, as it highlighted the realities of a prolonged conflict and the limits of U.S. military intervention. Additionally, the Vietnam War also drastically changed the way public opinion viewed the military and the use of force.\"\n  },\n  {\n    \"instruction\": \"Generate a list of random words.\",\n    \"input\": \"\",\n    \"output\": \"sneeze, conflict, ancestor, thunder, companion, amulet, forge, serpent, magnify, jester, cascade, universe, surge, nausea, oasis, anomaly, pursue, rapture, pagoda, mosaic.\"\n  },\n  {\n    \"instruction\": \"Transcribe the recording into text.\",\n    \"input\": \"Recording about the history of Waterloo\",\n    \"output\": \"The Battle of Waterloo, fought on June 18, 1815, was the last major battle of the Napoleonic Wars. It marked the end of Napoleon Bonaparte's rule. 
The battle was fought between Napoleon Bonaparte's French forces and an allied force of British, Dutch, Belgian and Prussian troops, led by the Duke of Wellington. It was fought near the town of Waterloo in the former Austrian Netherlands, now present-day Belgium. The battle was a defeat for Napoleon's forces and ended Napoleon's rule over France. Napoleon himself was exiled to St. Helena, where he would die in 1821. The battle marked the end of French domination of Europe and the beginning of a period of peace in the region. It was also the last major battle of the Napoleonic Wars and is remembered as one of the most important battles in history. The victory of the allied forces at Waterloo helped to bring an end to French imperial ambitions in Europe. It also established the British as one of the most powerful nations in Europe. The battle was fought with a combination of infantry, cavalry, and artillery tactics, and showed the beginning of the changing nature of warfare. Although the French forces greatly outnumbered the allies, their strategic and tactical mistake led to the loss of the battle. This defeat signaled the end of French imperial power in Europe. The legacy of Waterloo still stands today, and it cemented the British Empire's position for the next one hundred years.\"\n  }\n]\n"
  },
  {
    "path": "tests/data/_fixtures/dolly.json",
    "content": "[\n  {\n    \"instruction\": \"When did Virgin Australia start operating?\",\n    \"context\": \"Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.\",\n    \"response\": \"Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.\",\n    \"category\": \"closed_qa\"\n  },\n  { \"instruction\": \"Which is a species of fish? Tope or Rope\", \"context\": \"\", \"response\": \"Tope\", \"category\": \"classification\" },\n  {\n    \"instruction\": \"Why can camels survive for long without water?\",\n    \"context\": \"\",\n    \"response\": \"Camels use the fat in their humps to keep them filled with energy and hydration for long periods of time.\",\n    \"category\": \"open_qa\"\n  },\n  {\n    \"instruction\": \"Alice's parents have three daughters: Amy, Jessy, and what\\u2019s the name of the third daughter?\",\n    \"context\": \"\",\n    \"response\": \"The name of the third daughter is Alice\",\n    \"category\": \"open_qa\"\n  },\n  {\n    \"instruction\": \"When was Tomoaki Komorida born?\",\n    \"context\": \"Komorida was born in Kumamoto Prefecture on July 10, 1981. After graduating from high school, he joined the J1 League club Avispa Fukuoka in 2000. Although he debuted as a midfielder in 2001, he did not play much and the club was relegated to the J2 League at the end of the 2001 season. In 2002, he moved to the J2 club Oita Trinita. 
He became a regular player as a defensive midfielder and the club won the championship in 2002 and was promoted in 2003. He played many matches until 2005. In September 2005, he moved to the J2 club Montedio Yamagata. In 2006, he moved to the J2 club Vissel Kobe. Although he became a regular player as a defensive midfielder, his gradually was played less during the summer. In 2007, he moved to the Japan Football League club Rosso Kumamoto (later Roasso Kumamoto) based in his local region. He played as a regular player and the club was promoted to J2 in 2008. Although he did not play as much, he still played in many matches. In 2010, he moved to Indonesia and joined Persela Lamongan. In July 2010, he returned to Japan and joined the J2 club Giravanz Kitakyushu. He played often as a defensive midfielder and center back until 2012 when he retired.\",\n    \"response\": \"Tomoaki Komorida was born on July 10,1981.\",\n    \"category\": \"closed_qa\"\n  },\n  {\n    \"instruction\": \"If I have more pieces at the time of stalemate, have I won?\",\n    \"context\": \"Stalemate is a situation in chess where the player whose turn it is to move is not in check and has no legal move. Stalemate results in a draw. During the endgame, stalemate is a resource that can enable the player with the inferior position to draw the game rather than lose. In more complex positions, stalemate is much rarer, usually taking the form of a swindle that succeeds only if the superior side is inattentive.[citation needed] Stalemate is also a common theme in endgame studies and other chess problems.\\n\\nThe outcome of a stalemate was standardized as a draw in the 19th century. Before this standardization, its treatment varied widely, including being deemed a win for the stalemating player, a half-win for that player, or a loss for that player; not being permitted; and resulting in the stalemated player missing a turn. 
Stalemate rules vary in other games of the chess family.\",\n    \"response\": \"No. \\nStalemate is a drawn position. It doesn't matter who has captured more pieces or is in a winning position\",\n    \"category\": \"information_extraction\"\n  },\n  {\n    \"instruction\": \"Given a reference text about Lollapalooza, where does it take place, who started it and what is it?\",\n    \"context\": \"Lollapalooza /\\u02ccl\\u0252l\\u0259p\\u0259\\u02c8lu\\u02d0z\\u0259/ (Lolla) is an annual American four-day music festival held in Grant Park in Chicago. It originally started as a touring event in 1991, but several years later, Chicago became its permanent location. Music genres include but are not limited to alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. Lollapalooza has also featured visual arts, nonprofit organizations, and political organizations. The festival, held in Grant Park, hosts an estimated 400,000 people each July and sells out annually. Lollapalooza is one of the largest and most iconic music festivals in the world and one of the longest-running in the United States.\\n\\nLollapalooza was conceived and created in 1991 as a farewell tour by Perry Farrell, singer of the group Jane's Addiction.\",\n    \"response\": \"Lollapalooze is an annual musical festival held in Grant Park in Chicago, Illinois. It was started in 1991 as a farewell tour by Perry Farrell, single of the group Jane's Addiction. The festival includes an array of musical genres including alternative rock, heavy metal, punk rock, hip hop, and electronic dance music. The festivals welcomes an estimated 400,000 people each year and sells out annually. Some notable headliners include: the Red Hot Chili Peppers, Chance the Rapper, Metallica, and Lady Gauge. 
Lollapalooza is one of the largest and most iconic festivals in the world and a staple of Chicago.\",\n    \"category\": \"closed_qa\"\n  },\n  {\n    \"instruction\": \"Who gave the UN the land in NY to build their HQ\",\n    \"context\": \"\",\n    \"response\": \"John D Rockerfeller\",\n    \"category\": \"open_qa\"\n  },\n  {\n    \"instruction\": \"Why mobile is bad for human\",\n    \"context\": \"\",\n    \"response\": \"We are always engaged one phone which is not good.\",\n    \"category\": \"brainstorming\"\n  },\n  {\n    \"instruction\": \"Who was John Moses Browning?\",\n    \"context\": \"John Moses Browning (January 23, 1855 \\u2013 November 26, 1926) was an American firearm designer who developed many varieties of military and civilian firearms, cartridges, and gun mechanisms \\u2013 many of which are still in use around the world. He made his first firearm at age 13 in his father's gun shop and was awarded the first of his 128 firearm patents on October 7, 1879, at the age of 24. He is regarded as one of the most successful firearms designers of the 19th and 20th centuries and pioneered the development of modern repeating, semi-automatic, and automatic firearms.\\n\\nBrowning influenced nearly all categories of firearms design, especially the autoloading of ammunition. He invented, or made significant improvements to, single-shot, lever-action, and pump-action rifles and shotguns. He developed the first reliable and compact autoloading pistols by inventing the telescoping bolt, then integrating the bolt and barrel shroud into what is known as the pistol slide. Browning's telescoping bolt design is now found on nearly every modern semi-automatic pistol, as well as several modern fully automatic weapons. He also developed the first gas-operated firearm, the Colt\\u2013Browning Model 1895 machine gun \\u2013 a system that surpassed mechanical recoil operation to become the standard for most high-power self-loading firearm designs worldwide. 
He also made significant contributions to automatic cannon development.\\n\\nBrowning's most successful designs include the M1911 pistol, the water-cooled M1917, the air-cooled M1919, and heavy M2 machine guns, the M1918 Browning Automatic Rifle, and the Browning Auto-5 \\u2013 the first semi-automatic shotgun. Some of these arms are still manufactured, often with only minor changes in detail and cosmetics to those assembled by Browning or his licensees. The Browning-designed M1911 and Hi-Power are some of the most copied firearms in the world.\",\n    \"response\": \"John Moses Browning is one of the most well-known designer of modern firearms.  He started building firearms in his father's shop at the age of 13, and was awarded his first patent when he was 24.\\n\\nHe  designed the first reliable automatic pistol, and the first gas-operated firearm, as well inventing or improving single-shot, lever-action, and pump-action rifles and shotguns.\\n\\nToday, he is most well-known for the M1911 pistol, the Browning Automatic Rifle, and the Auto-5 shotgun, all of which are in still in current production in either their original design, or with minor changes.  His M1911 and Hi-Power pistols designs are some of the most reproduced firearms in the world today.\",\n    \"category\": \"information_extraction\"\n  },\n  {\n    \"instruction\": \"Who is Thomas Jefferson?\",\n    \"context\": \"Thomas Jefferson (April 13, 1743 \\u2013 July 4, 1826) was an American statesman, diplomat, lawyer, architect, philosopher, and Founding Father who served as the third president of the United States from 1801 to 1809. Among the Committee of Five charged by the Second Continental Congress with authoring the Declaration of Independence, Jefferson was the Declaration's primary author. 
Following the American Revolutionary War and prior to becoming the nation's third president in 1801, Jefferson was the first United States secretary of state under George Washington and then the nation's second vice president under John Adams.\",\n    \"response\": \"Thomas Jefferson (April 13, 1743 \\u2013 July 4, 1826) was an American statesman, diplomat, lawyer, architect, philosopher, and Founding Father who served as the third president of the United States from 1801 to 1809. Among the Committee of Five charged by the Second Continental Congress with authoring the Declaration of Independence, Jefferson was the Declaration's primary author. Following the American Revolutionary War and prior to becoming the nation's third president in 1801, Jefferson was the first United States secretary of state under George Washington and then the nation's second vice president under John Adams.\\n\\nAmong the nation's Founding Fathers, Jefferson is considered unmatched in his intellectual depth and breadth. His passionate writings and advocacy for human rights, including freedom of thought, speech, and religion, were a leading inspiration behind the American Revolution, which ultimately gave rise to the American Revolutionary War, American independence, and the United States Constitution. Jefferson's ideas were globally influential in shaping and inspiring the Age of Enlightenment, which proved transformational in the late 17th and 18th centuries. He was a leading proponent of democracy, republicanism, and individual rights, and produced formative documents and decisions at the state, national, and international levels.\\n\\nDuring the American Revolution, Jefferson represented Virginia in the Second Continental Congress in Philadelphia, which adopted the Declaration of Independence on July 4, 1776. As a Virginia legislator, he drafted a state law for religious freedom. He served as the second Governor of Virginia from 1779 to 1781, during the Revolutionary War. 
In 1785, Jefferson was appointed the United States Minister to France, and subsequently, the nation's first secretary of state under President George Washington from 1790 to 1793. Jefferson and James Madison organized the Democratic-Republican Party to oppose the Federalist Party during the formation of the First Party System. With Madison, he anonymously wrote the Kentucky and Virginia Resolutions in 1798 and 1799, which sought to strengthen states' rights by nullifying the federal Alien and Sedition Acts.\\n\\nJefferson and Federalist John Adams became friends as well as political rivals, serving in the Continental Congress and drafting the Declaration of Independence together. In the 1796 presidential election between the two, Jefferson came in second, which according to electoral procedure at the time, made him vice president to Adams. Jefferson challenged Adams again in 1800 and won the presidency. After his term in office, Jefferson eventually reconciled with Adams and they shared a correspondence that lasted 14 years. He and Adams both died on the same day, July 4, 1826, which was also the 50th anniversary of Declaration of Independence.\\n\\nAs president, Jefferson pursued the nation's shipping and trade interests against Barbary pirates and aggressive British trade policies. Starting in 1803, he promoted a western expansionist policy with the Louisiana Purchase, which doubled the nation's claimed land area. To make room for settlement, Jefferson began the process of Indian tribal removal from the newly acquired territory. As a result of peace negotiations with France, his administration reduced military forces. He was re-elected in 1804, but his second term was beset with difficulties at home, including the trial of former vice president Aaron Burr. In 1807, American foreign trade was diminished when Jefferson implemented the Embargo Act in response to British threats to U.S. shipping. 
The same year, Jefferson signed the Act Prohibiting Importation of Slaves.\\n\\nJefferson was a plantation owner, lawyer, and politician, and mastered many disciplines including surveying, mathematics, horticulture, and mechanics. He was also an architect in the Palladian tradition. Jefferson's keen interest in religion and philosophy led to his appointment as president of the American Philosophical Society. He largely shunned organized religion but was influenced by Christianity, Epicureanism, and deism. Jefferson rejected fundamental Christianity, denying Christ's divinity. A philologist, Jefferson knew several languages. He was a prolific letter writer and corresponded with many prominent people, including Edward Carrington, John Taylor of Caroline, and James Madison. In 1785, Jefferson authored Notes on the State of Virginia, considered perhaps the most important American book published before 1800. Jefferson championed the ideals, values, and teachings of the Enlightenment.\\n\\nSince the 1790s, Jefferson was rumored to have had children by his sister-in-law and slave Sally Hemings, leading to what is known as the Jefferson-Hemings controversy. A 1998 DNA test concluded that one of Sally Hemings's children, Eston Hemings, was of the Jefferson male line. According to scholarly consensus, based on documentary and statistical evaluation, as well as oral history, Jefferson probably fathered at least six children with Hemings, including four that survived to adulthood.\\n\\nAfter retiring from public office, Jefferson founded the University of Virginia. Presidential scholars and historians generally praise Jefferson's public achievements, including his advocacy of religious freedom and tolerance in Virginia, his peaceful acquisition of the Louisiana Territory from France without war or controversy, and his ambitious and successful Lewis and Clark Expedition. Some modern historians are critical of Jefferson's personal involvement with slavery. 
Jefferson is consistently ranked among the top ten presidents of American history.\",\n    \"category\": \"information_extraction\"\n  },\n  {\n    \"instruction\": \"Who was Kyle Van Zyl playing against when he scored 36 of hisa teams 61 points?\",\n    \"context\": \"Van Zyl joined the Eastern Province Kings Academy, where he played for the Eastern Province U19 side in the 2010 Under-19 Provincial Championship. He was a key player for the Eastern Province U21 side in the 2012 Under-21 Provincial Championship, scoring 71 points in eight appearances. Van Zyl was under the Top SARU Performers, scoring the most tries at 6 in the 2012 Provincial Under 21 in the Rugby Junior Provincials.\\n\\nThis included a record and a remarkable personal haul in their opening match, when he scored 36 of his team's points in a 61\\u20133 victory over Boland U21, consisting of four tries and eight conversions and was awarded Man of the Match.\",\n    \"response\": \"Kyle Van Zyl was playing against Boland U21 when he scored 36 points, leading his team to victory in a 61-3 win.\",\n    \"category\": \"closed_qa\"\n  }\n]\n"
  },
  {
    "path": "tests/data/_fixtures/longform_train.json",
    "content": "[\n  {\n    \"input\": \"What are the positions held by Beto O'Rourke, Lupe Valdez, and Veronica Escobar on decriminalizing unauthorized border crossings in the US?\",\n    \"output\": \"The liberal mainstream media squealed with delight last week when they announced Ted Cruz was only 4 points ahead of Democrat Beto O\\u2019Rourke.\\nThe the bottom dropped out on Beto\\u2019s campaign.\\nFirst O\\u2019Rourke told a liberal Texas crowd that there was nothing more American than kneeling for the National Anthem.\\nBeto O\\u2019Rourke is now pushing to legalize unauthorized border crossers into America.\\nA trio of Democrats running for top offices in Texas want to decriminalize unauthorized border crossings.\\nU.S. Rep. Beto O\\u2019Rourke, who\\u2019s challenging Republican Sen. Ted Cruz for his Senate seat, gubernatorial candidate Lupe Valdez, and congressional candidate Veronica Escobar told HuffPost they are in favor of such a move, which the report says would be simpler to implement than abolishing Immigration and Customs Enforcement, as some Democrats have demanded.\\nDecriminalizing unauthorized border crossings would go against the Trump administration\\u2019s \\u201czero tolerance\\u201d policy that led to family separations at the border. The Department of Homeland Security would still be able to charge unauthorized immigrants with civil violations, the report notes, but it would save the federal government billions in incarceration costs.\",\n    \"source\": \"C4\",\n    \"subset\": \"search-engine\"\n  },\n  {\n    \"input\": \"For example, for the original quote we shall use, 'One swallow does not a summer make'. If someone were to say 'One good quote does not a philosopher make' (No dig at Aristotle intended...), or 'A fancy suit does not a businessman make'. 
Essentially, a clear allusion to a quote, but not the quote itself.\",\n    \"output\": \"The word I've heard used for this on the linguistics blog Language Log is snowclone (it's derived from phrases of the format 'If Eskimos have N words for snow, X surely have Y words for Z.') Here's a Language Log post that describes the moment the word was coined and that links to some examples: Snowclones: Lexicographical Dating to the Second And here's how it's defined by Collins Dictionaries online: a verbal formula that is adapted for reuse by changing only a few words so that the allusion to the original phrase remains clear You might also be interested in looking at the answers to the following questions: This book will change your wife trope? Is there a general name for this: 'The more you X the more you'll Y' A term describing the replacement of a specific word in a saying with one that ryhmes\",\n    \"source\": \"Natural Instructions\",\n    \"subset\": \"task225 english language answer generation\"\n  },\n  {\n    \"input\": \"Describe the U.S.-Mexico border. Respond in detail.\",\n    \"output\": \"President Donald Trump\\u2019s border wall with Mexico will have an initial price tag of between $12 billion and $15 billion, according to Congressional Republican leaders, but the method of paying for Trump\\u2019s promised project is still not settled.\\nAt the Congressional Republican Retreat in Philadelphia on Thursday, House Speaker Paul Ryan said the wall would cost up to $15 billion, more than the $8 billion Trump has estimated the wall will cost.\\nA U.S. Border Patrol agent removes a ladder used by undocumented immigrants to climb a border fence on Oct. 18, 2016 near McAllen, Texas.\\nU.S. Customs and Border Protection agents fly near the U.S.-Mexico border while on helicopter patrol on Oct. 18, 2016 near McAllen, Texas.\\nA U.S. Border Patrol agent detains a group of undocumented immigrants on Oct. 
18, 2016 near McAllen, Texas.\\nPeople stand in line to cross legally into the United States from Mexico on Sept. 24, 2016 in Tijuana, Mexico.\\nThe Rio Grande forms a stretch of the U.S.-Mexico border on Oct. 15, 2016 in the Big Bend region of West Texas near Lajitas, Texas. Big Bend is a rugged, vast and remote region along the U.S.-Mexico border and includes the Big Bend National Park.\\nU.S. Border Patrol agents with a K-9 unit detain undocumented immigrants after they illegally crossed the U.S.-Mexico border on Oct. 18, 2016, in McAllen, Texas.\\nThis photo made with a smart phone through night vision goggles shows the Rio Grande flowing along the U.S.-Mexico border, as seen from a U.S. Customs and Border Protection helicopter during a patrol over the U.S.-Mexico border on Oct. 18, 2016 in McAllen, Texas.\\nThe moon rises over the swirling current of the Rio Grande on Oct. 15, 2016 in the Big Bend region of West Texas near Lajitas, Texas.\\nA bullet-proof shield stands to aid U.S. Border Patrol agents on the U.S.-Mexico border on Oct. 3, 2016 in El Paso, Texas.\\nA child plays in the Pacific surf near the U.S.-Mexico border fence on Sept. 25, 2016 in Tijuana, Mexico. The nearby Friendship Park is one of the few places on the 2,000-mile border where separated families are allowed to meet.\\nDunes stretch into the distance near the U.S.-Mexico border on Sept. 27, 2016 in the Imperial Sand Dunes recreation area, California.\\nMexican farm workers hoe a cabbage field on Sept. 27, 2016, in Holtville, Calif. Thousands of Mexican seasonal workers legally cross over daily from Mexicali, Mexico to work the fields of Imperial Valley, Calif., which is some of the most productive farmland in the United States.\\nA man looks through the U.S.-Mexico border fence into the United States on Septt. 25, 2016 in Tijuana, Mexico.\\nA Studebaker sits outside a cotton field near the U.S.-Mexico border on Oct. 
14, 2016 near Fort Hancock, Texas.\\nBodies of suspected undocumented immigrants lie in the morgue at the Office of the Pima County Medical Examiner on Sept. 29, 2016 in Tucson, Ariz. Hundreds of migrants die every year in the desert while crossing illegally from Mexico into the United States.\\nPersonal effects found on the body of a suspected undocumented immigrant are stored at the Office of the Pima County Medical Examiner on Sept. 29, 2016 in Tucson, Ariz.\\nDinner tables await immigrants, many of them deportees from the United States, at the Casa del Migrante Sept. 23, 2016 in Tijuana, Mexico. The shelter, run by Catholic priests, is part of the Coalicion Pro Defensa, which aids immigrants and asylum applicants seeking entry into the United States. The coalition of immigrant advocates has helped thousands of migrants on their journey north and many more after being being deported from the United States.\\nMexicans enjoy a late afternoon near the U.S.-Mexico border fence, which ends in the Pacific Ocean, on Sept. 25, 2016 in Tijuana, Mexico. Friendship Park, located on the border between the two countries is one of the few places on the 2,000-mile border where separated families are allowed to meet.\\nTrump signed two executive orders on Wednesday directing the construction of the border wall and increasing the number of immigration enforcement officers carrying out deportation, as well as border patrol forces.\",\n    \"source\": \"C4\",\n    \"subset\": \"instruction\"\n  },\n  {\n    \"input\": \"Must a Boss use their Level Up ability?\\n\\nMost Level Up abilities clearly benefit the Boss, but Robobo's is:\\n\\nEach opponent mush choose and destroy one Room in his or her Dungeon.\\n\\nThe opponents get to destroy a room (often quite beneficial), and Robobo doesn't!\\nIs there any way - short of only building a four room dungeon - that Robobo can avoid using this ability? 
What play strategies can Robobo use to maximize his Level Up ability?\",\n    \"output\": \"According to the rules, your boss's \\\"Level Up\\\" ability is triggered when your dungeon is 5 rooms in size. \\\"Triggering\\\" implies necessity. So yes, the only way to avoid levelling up is by not building a fifth room.\\nSecondly, and more to your concern... according to the Advanced Rules (see the entry for Robobo in the glossary using the link above), using Robobo's level up does not trigger the \\\"destroy this Room\\\" ability for the destroyed room.\",\n    \"source\": \"StackExchange\",\n    \"subset\": \"boardgames\"\n  },\n  {\n    \"input\": \"What is Podd\\u0119bice?\",\n    \"output\": \"Podd\\u0119bice () is a village in the administrative district of Gmina W\\u0142oc\\u0142awek, within W\\u0142oc\\u0142awek County, Kuyavian-Pomeranian Voivodeship, in north-central Poland. It lies approximately south of W\\u0142oc\\u0142awek and south-east of Toru\\u0144.\",\n    \"source\": \"Wikipedia\",\n    \"subset\": \"chatbot\"\n  },\n  {\n    \"input\": \"Describe the history of MV \\\"Alexia\\\" Respond in 4 sentences.\",\n    \"output\": \"MV \\\"Alexia\\\" was one of nine Royal Dutch/Shell oil tankers converted to become a Merchant Aircraft Carrier (MAC ship). The group is collectively known as the \\\"Rapana\\\" class.\\nMV \\\"Alexia\\\" was built at Bremer Vulkan and completed in April, 1935 as an oil tanker for the Anglo Saxon Royal Dutch/Shell line. She sustained severe damage in two separate U-boat attacks in 1940 and 1942.\",\n    \"source\": \"Wikipedia\",\n    \"subset\": \"instruction\"\n  },\n  {\n    \"input\": \"What is the definition of absurdity?\",\n    \"output\": \"1. the state or quality of being absurd.\\nAttending the Gerald R. Ford Symposium on Humor and the Presidency is sort of like attending the Ayatollah Khomeini Symposium on the sexual revolution \\u2014Pat Paulsen, at September 19, 1986 symposium in Grand Rapids, Michigan.\\n2. 
absurdity - a ludicrous folly; \\\"the crowd laughed at the absurdity of the clown's behavior\\\"\\nnoun ridiculousness, nonsense, folly, stupidity, foolishness, silliness, idiocy, irrationality, incongruity, meaninglessness, daftness (informal), senselessness, illogicality, ludicrousness, unreasonableness, preposterousness, farcicality, craziness (informal), b\\u00eatise (rare), farcicalness, illogicalness I get angry at the absurdity of a situation.\\nfolly, foolery, foolishness, idiocy, imbecility, insanity, lunacy, madness, nonsense, preposterousness, senselessness, silliness, tomfoolery, zaniness.\\nBut in the Epic poem the absurdity passes unnoticed.\\nFor if absurdity be the subject of laughter, doubt you not but great boldness is seldom without some absurdity.\\nI am temperate to the verge of absurdity,\\\" replied the Tramp.\\nPoets, of course, may be satisfactorily read in volumes of, selections; but to me, at least, a book of brief extracts from twenty or a hundred prose authors is an absurdity.\\nThen,\\\" suggested the idea, with a blush for its own absurdity, \\\"why not go on pilgrimage and seek her?\\nJudges and starters have been conveniently blind to this absurdity, but the public demonstration off St.\\nAnd now that the providential occurrence was apparently close at hand, it would have been sheer absurdity to think that the supply would be short of the need: as absurd as a faith that believed in half a miracle for want of strength to believe in a whole one.\\nutterances, the absurdity being attested by his motley costume.\\nOr what greater absurdity can there be than putting before us an old man as a swashbuckler, a young man as a poltroon, a lackey using fine language, a page giving sage advice, a king plying as a porter, a princess who is a kitchen-maid?\\nNow to say that the honour I here mean, and which was, I thought, all the honour I could be supposed to mean, will uphold, much less dictate an untruth, is to assert an absurdity too 
shocking to be conceived.\\nHence the absurdity of the interview; the gulf between them was economic as well as spiritual.\",\n    \"source\": \"C4\",\n    \"subset\": \"chatbot\"\n  },\n  {\n    \"input\": \"Can felons run for federal office in Minnesota?\",\n    \"output\": \"Minnesota law doesn't block felons from running for federal office.\\nEven if enough voters choose Leonard J. Richards as the DFL candidate for U.S. Senate, there\\u2019s no way he will ever get to take the oath of office and begin a six-year term.\\nRichards is already serving a lifetime term. In Stillwater prison. For murder. Make that two murders.\\nYes, it is legal in Minnesota for felons to run for office, so long as it is a federal seat. Nobody knows that better than Richards, who is trying to wrest the party nod away from incumbent Amy Klobuchar.\\nNow 75 years old and sporting Department of Corrections ID No. 149837, Richards has run for federal office several times \\u2014 without a victory \\u2014 since his imprisonment, most recently when he sought the seat that U.S. Rep. Tom Emmer now holds.\\nIn 1992, Richards ran in the DFL primary for the Eighth Congressional District seat and received more than 14,500 votes. He ran for the U.S. Senate in the DFL primary in 1994, winning more than 4,000 votes.\\nMinnesota law does not permit inmates to run for a state-level office.\\nThe official ballot for Minnesota\\u2019s primary next month lists double murderer Leonard Richards among those seeking the DFL nomination for U.S. Senate.\\nRichards was convicted of murder twice in Hennepin County for the 1982 slaying of his half-sister, May Wilson, and the 1987 shooting death of his attorney, Robert Stratton. 
His life sentence imposed nearly 30 years ago offers no parole.\\nRichards did not respond to a message Thursday seeking an interview.\",\n    \"source\": \"C4\",\n    \"subset\": \"chatbot\"\n  },\n  {\n    \"input\": \"What is the purpose of consciousness?\",\n    \"output\": \"Cardiff University and University College London provide funding as founding partners of The Conversation UK.\\nMost experts think that consciousness can be divided into two parts: the experience of consciousness (or personal awareness), and the contents of consciousness, which include things such as thoughts, beliefs, sensations, perceptions, intentions, memories and emotions.\\nIt\\u2019s easy to assume that these contents of consciousness are somehow chosen, caused or controlled by our personal awareness \\u2013 after all, thoughts don\\u2019t exist until until we think them. But in a new research paper in Frontiers of Psychology, we argue that this is a mistake.\\nWe suggest that our personal awareness does not create, cause or choose our beliefs, feelings or perceptions. Instead, the contents of consciousness are generated \\u201cbehind the scenes\\u201d by fast, efficient, non-conscious systems in our brains. 
All this happens without any interference from our personal awareness, which sits passively in the passenger seat while these processes occur.\\nPut simply, we don\\u2019t consciously choose our thoughts or our feelings \\u2013 we become aware of them.\\nIf this sounds strange, consider how effortlessly we regain consciousness each morning after losing it the night before; how thoughts and emotions \\u2013 welcome or otherwise \\u2013 arrive already formed in our minds; how the colours and shapes we see are constructed into meaningful objects or memorable faces without any effort or input from our conscious mind.\\nConsider that all the neuropsychological processes responsible for moving your body or using words to form sentences take place without involving your personal awareness. We believe that the processes responsible for generating the contents of consciousness do the same.\\nOur thinking has been influenced by research into neuropsychological and neuropsychiatric disorders, as well as more recent cognitive neuroscience studies using hypnosis. The studies using hypnosis show that a person\\u2019s mood, thoughts and perceptions can be profoundly altered by suggestion.\\nIn such studies, participants go through a hypnosis induction procedure, to help them to enter a mentally focused and absorbed state. Then, suggestions are made to change their perceptions and experiences.\\nFor example, in one study, researchers recorded the brain activity of participants when they raised their arm intentionally, when it was lifted by a pulley, and when it moved in response to a hypnotic suggestion that it was being lifted by a pulley.\\nSimilar areas of the brain were active during the involuntary and the suggested \\u201calien\\u201d movement, while brain activity for the intentional action was different. 
So, hypnotic suggestion can be seen as a means of communicating an idea or belief that, when accepted, has the power to alter a person\\u2019s perceptions or behaviour.\\nAll this may leave one wondering where our thoughts, emotions and perceptions actually come from. We argue that the contents of consciousness are a subset of the experiences, emotions, thoughts and beliefs that are generated by non-conscious processes within our brains.\\nThis subset takes the form of a personal narrative, which is constantly being updated. The personal narrative exists in parallel with our personal awareness, but the latter has no influence over the former.\\nThe personal narrative is important because it provides information to be stored in your autobiographical memory (the story you tell yourself, about yourself), and gives human beings a way of communicating the things we have perceived and experienced to others.\\nThis, in turn, allows us to generate survival strategies; for example, by learning to predict other people\\u2019s behaviour. Interpersonal skills like this underpin the development of social and cultural structures, which have promoted the survival of human kind for millennia.\\nSo, we argue that it is the ability to communicate the contents of one\\u2019s personal narrative \\u2013\\u2013 and not personal awareness \\u2013 that gives humans their unique evolutionary advantage.\\nIf the experience of consciousness does not confer any particular advantage, it\\u2019s not clear what its purpose is. But as a passive accompaniment to non-conscious processes, we don\\u2019t think that the phenomenon of personal awareness has a purpose, in much the same way that rainbows do not. Rainbows simply result from the reflection, refraction and dispersion of sunlight through water droplets \\u2013 none of which serves any particular purpose.\\nOur conclusions also raise questions about the notions of free will and personal responsibility. 
If our personal awareness does not control the contents of the personal narrative which reflects our thoughts, feelings, emotions, actions and decisions, then perhaps we should not be held responsible for them.\\nIn response to this, we argue that free will and personal responsibility are notions that have been constructed by society. As such, they are built into the way we see and understand ourselves as individuals, and as a species. Because of this, they are represented within the non-conscious processes that create our personal narratives, and in the way we communicate those narratives to others.\\nJust because consciousness has been placed in the passenger seat, does not mean we need to dispense with important everyday notions such as free will and personal responsibility. In fact, they are embedded in the workings of our non-conscious brain systems. They have a powerful purpose in society and have a deep impact on the way we understand ourselves.\",\n    \"source\": \"C4\",\n    \"subset\": \"chatbot\"\n  },\n  {\n    \"input\": \"What is Nemaha, IA? Respond briefly.\",\n    \"output\": \"Nemaha is a city in Sac County, Iowa, United States. The population was 85 at the 2010 census.\\nAccording to the United States Census Bureau, the city has a total area of , all of it land.\",\n    \"source\": \"Wikipedia\",\n    \"subset\": \"chatbot\"\n  },\n  {\n    \"input\": \"In the book I am reading (Patrick Fairbairn, The Interpretation of Prophecy, 2nd ed., 1865), the abbreviation 'sq.' is used sometimes in referencing both Scripture and other books (as 'Typology of Scripture, vol. i. p. 100 sq.' or 'Ezek xxvi. 7, sq.'). What does 'sq.' mean in this context? (Instances of this abbreviation's usage may be found at Google Books here.).\",\n    \"output\": \"Sq. is short for the Latin word sequiturque, which means the following. So p. 100 sq. means page 100 and those following pages. 
Here's the source.\",\n    \"source\": \"Natural Instructions\",\n    \"subset\": \"task225 english language answer generation\"\n  },\n  {\n    \"input\": \"\\\"Describe Brugnens.\\\" Respond in 11 words.\",\n    \"output\": \" Brugnens is a commune in the Gers department in southwestern France.\",\n    \"source\": \"Wikipedia\",\n    \"subset\": \"instruction\"\n  },\n  {\n    \"input\": \"What is the DeLand Municipal Airport? Respond in 1 sentences.\",\n    \"output\": \"DeLand Municipal Airport , also known as Sidney H. Taylor Field, is a city-owned public-use airport located three nautical miles (5\\u00a0km) northeast of the central business district of DeLand, a city in Volusia County, Florida, United States.\",\n    \"source\": \"Wikipedia\",\n    \"subset\": \"chatbot\"\n  },\n  {\n    \"input\": \"Respond in 203 words. Describe the plot of the 1976 Filipino drama film \\\"Minsa'y Isang Gamu-gamo\\\".\",\n    \"output\": \"Minsa'y Isang Gamu-gamo (\\\"\\\") is a 1976 Filipino drama film directed by Lupita Aquino-Kashiwahara and written by Marina Feleo-Gonzales. It concerns a Filipino nurse, Cora de la Cruz, who dreams of moving to America. When her brother is killed, her ideas change. The film criticises American military presence in the Philippines. It was an entry to the 1976 Metro Manila Film Festival and won five awards at the 25th FAMAS Awards, including those for best picture, director and screenplay. The film was restored by the ABS-CBN Film Restoration through Kantana Post Production in Thailand and the restored version was released on October 12, 2018 as part of Cinema One Originals film festival.\\nThe de la Cruises and the Santos are two lower-middle-class families who live in Pampanga. Cora de la Cruz is a nurse who dreams of living in the United States. Her papers are ready and she organises employment at an American hospital. 
She hopes to get a green card allowing her to stay, achieve immigrant status and then bring her family to America for a better life. Cora's mother and younger brother encourage her but her father, Ingkong, disagrees. He believes that moving to America is a betrayal of their country.\",\n    \"source\": \"Wikipedia\",\n    \"subset\": \"instruction\"\n  },\n  {\n    \"input\": \"Hey, when does Marvel's Runaways Season 2 start?\",\n    \"output\": \"The second season picks up where the first left off.\\nHulu has announced that the second season of Marvel's Runaways will premiere December 21st, Deadline reports. And there won't be any weekly installments this time around as there were with the first season because all 13 episodes are getting dropped at once. Hulu renewed the series for a second season just as the first one wrapped up.\\nSeason 2 will kick off right where the first ended -- with the kids having escaped their evil parents and setting out their own. Now they'll have to figure out how to live with each other, how to take care of themselves and how to take down their parents' nefarious Pride squad. But they might also have to sniff out a mole since someone's sending the Pride leader mysterious messages. Meanwhile, the Pride is bound and determined to get them back.\",\n    \"source\": \"C4\",\n    \"subset\": \"chatbot\"\n  },\n  {\n    \"input\": \"I need to serve at a dinner party, but I'm not sure where to start. Can you give me 10 actionable steps?\",\n    \"output\": \"Which way should you pass the plates? How should you clear the table? Serving at a dinner party isn't exactly as easy as pie. Here are some basic guidelines to help you out in your next dinner party.\\n1. Steamed or roasted vegetables, rice, salad, potatoes, and condiments are some of the things you can usually leave in the middle of the table for guests to help themselves to. This isn't strictly necessary, though. 
If you prefer to keep things simple, you don't have to leave anything more than salt and pepper on the table.\\n2. This is known as \\\"plating-up.\\\" Don't expect guests to help themselves to food that requires assembly or is difficult to manage. The only way they'll get to see your creativity is if you put it out on the plate for them before you even get to the dinner table.\\nKeep the rim of the plate clear of sauces, spills, anything; it frames the food.\\n3. You can do this if you want to keep with tradition and the occasion is quite formal. Otherwise, choose one end of the table and move around it in a clockwise fashion, regardless of the genders of your guests.\\n\\nPass all dishes from the left. Guests and servers should pass dishes from left. The logic behind this is that most people are right-handed and this allows them them to serve themselves from the dish while it is being held by the passer. Nowadays it is less likely that the passer will continue to hold the dish, but will expect you to take it, so if you are left-handed, it won't be a problem. Place the dish down on your side plate to serve from it.\\n\\nAs the cook, or host, always serve yourself last. This is polite and also sensible, since you'll probably be busy anyway with host's duties.\\n4. They'll get fidgety, anxious and gossipy about what you're doing.\\n5. On the other hand, do not ever go into details about how the flesh portion of the meal was hunted/killed. This is bad taste and makes some guests very queasy. Leave it for discussion around the fireplace with a like-minded friend after dinner.\\n6. The host or hired help should clear no more than two plates at a time to avoid bumping guests and interfering with their eating. There is nothing more annoying than the server's elbows in your face when you're just about to take the next bite.\\n7. Preferably the noises should not reach the guests but this is unrealistic for most homes. 
Just do it as quietly as possible and try not to clank, crack, break or drop the dishes. The last thing you need on top of anything else is a dropped plate to clean up.\\n8. This means all the dishes on the table, the condiments and the side plates. If you haven't already set out the dessert spoons, this is the time to do so.\\n9. Chocolates will have their own method of getting around the table; it has been suspected they have legs...\\n10. Don't take advantage of the situation because that guest is there to enjoy himself too but don't hesitate to ask for a quick hand with a simple task that won't risk spills on their clothes.\\n\",\n    \"source\": \"WikiHow\",\n    \"subset\": \"main\"\n  },\n  {\n    \"input\": \"What can you tell me about Brushcreek Township? Respond in 2 sentences.\",\n    \"output\": \"Brushcreek Township is one of the seventeen townships of Highland County, Ohio, United States. As of the 2010 census the population was 1,381, of whom 1,248 lived in the unincorporated portion of the township.\",\n    \"source\": \"Wikipedia\",\n    \"subset\": \"chatbot\"\n  }\n]\n"
  },
  {
    "path": "tests/data/_fixtures/longform_val.json",
    "content": "[\n  {\n    \"input\": \"The Big Mistake\\n\\nThis day was full of joy and happiness, but something went wrong after when she told him that she is pregnant. \\nShe was so happy about this news and she thought that he will be happy as well, but... At the moment when she told him that she is pregnant, he just frozen. It was too much for him.\\nHe thought it's easier to walk away and avoid this conversation with her. He tried to understand what's going on with him. It was weird how many questions and doubts came up on his head. One part of him was really happy, because to be a father is his one of the biggest dream, but on the other hand he couldn't handle this news. \\nThis woman, who stands next to him, is his girlfriend, his future wife. He should be happy and hug her and celebrate this good news. But somehow he didn't feel anything. All that he felt was anxiety. He couldn't handle the feeling, when she tried to touch him, to kiss and hug him. He knew that she loves him as much as he loves her, but everything was just too wrong on this moment.  Her voice, her smile and even her body was to wrong on this moment.\\nShe couldn't understand what\\u2019s going on, because it was just three weeks ago when they had a conversation about kids and family and he said that he would like to have kids with her. And now, he just avoids her and he is not happy about the news. \\nIt was a big mess. She was more than happy and he couldn't handle this news. They both knew that they love each other and they care about each other, but it was too much for him.  He stood up and said to her girlfriend that he needs a break and he have to be alone.  Michael closed the door and knew at that moment he had made a mistake.\\n\\nFix all grammar errors and style issues and rewrite this document\",\n    \"output\": \"The Big Mistake\\n\\nThis day was full of joy and happiness, but something went wrong  when she told him that she was pregnant. 
\\nShe was so happy about this news and she thought that he would be happy as well, but... At the moment when she told him that she was pregnant, he just frozen. It was too much for him.\\nHe thought it was easier to walk away and avoid this conversation with her. He tried to understand what was going on with him. It was weird how many questions and doubts came into his head. One part of him was really happy, because to be a father was one of his  biggest dreams, but on the other hand, he couldn't handle this news. \\nThis woman, who was standing next to him, was his girlfriend, his future wife. He should be happy and hug her and celebrate this good news. But somehow he didn't feel anything. All that he felt was anxiety. He couldn't handle the feeling, when she tried to touch him, to kiss and hug him. He knew that she loved him as much as he loved her, but everything was just all wrong at this moment.  Her voice, her smile and even her body was  wrong on this moment.\\nShe couldn't understand what was going on, because it was just three weeks ago that they had a conversation about kids and family and he said that he would like to have kids with her. And now, he was just avoiding her and he was not happy about the news. \\nIt was a big mess. She was more than happy and he couldn't handle this news. They both knew that they loved each other and they cared about each other, but it was too much for him.  He stood up and said to his girlfriend that he needed a break and he had to be alone.  Michael closed the door and knew at that moment he had made a mistake.\\n\",\n    \"source\": \"BEA-GEC\",\n    \"subset\": \"main\"\n  },\n  {\n    \"input\": \"Although the fact that personal vehicle like car now is popular, in the future, the stance of public transport is still very crucial. 
\\n\\nNowaday, people use fossil fuel as the most important source of energy for travelling, manufacturing, this source of energy doesn't always maintain at the constant level, but someday it will be run out. People must find ways to use other energy sources like nuclear electricity, wind power, public transportation such as electric metro use this kind of energy. As a result, governments will invest more in researching the usage of these new types of energy, travelling using public transport will become much cheaper than personal car. The range of public services will be expanded to remote areas, it become much more convenient. Commuting, travelling using public services will be the first prior choice.\\n\\nIn addition, the population is increasing rapidly, in the near future, almost cities in the world will become overcrowded, It will lead to many other problems to be solved, for example, traffic jams, air pollution. In some perespectives, these problems are led by the over-using of personal vehicle. People will realise that using their own car is just to waste their time for traffic jams or to gradually killing themselves for breathing polluted air. So the best way for them to travel is public transpotations. \\nEach person should practice saving when using any source of eneny to protect his own live. \\nIn conclusion, investments on developing public transport will be increased considerably, public services of travelling have a bright potential and it's stance in the future can't be replaceble\\n\\nTweak this document to make it read smoother and fix any mistakes you see.\",\n    \"output\": \"Despite the fact that personal vehicles like the car are now popular, in the future, the existence of public transport will still absolutelyery crucial. \\n\\nNowadays, people use fossil fuels as the most important source of energy for travelling and manufacturing. This source of energy doesn't always remain at a constant level, and someday it will  run out. 
People must find ways to use other energy sources like nuclear power and wind power. Public transportation such as the electric metro use this kind of energy. As a result, governments will invest more in researching the use of these new types of energy, and travelling using public transport will become much cheaper than privatepersonal car. The range of public services will be expanded to remote areas, and it will become much more convenient. Commuting and travelling using public services will be the first priority.\\n\\nIn addition, the population is increasing rapidly. In the near future, most cities in the world will become overcrowded. This will lead to many other problems that will need to be solved. For example, traffic jams and air pollution. In some ways, these problems are caused by the overuse of personal vehicles. People will realise that using their own car is just  wasting their time in traffic jams or  gradually killing themselves by breathing polluted air. So the best way for them to travel is public transport. \\nEach person should practice saving energy when using any source of energy to protect his own life. \\nIn conclusion, investments in developing public transport will be increased considerably. Public transport services have a bright future and their existence in the future can't be replaced.\\n\",\n    \"source\": \"BEA-GEC\",\n    \"subset\": \"main\"\n  },\n  {\n    \"input\": \"What can you tell me about Nang Ka Teih? Respond in 58 words.\",\n    \"output\": \"Founded by the people of Chin (Zo), Nang Ka Teih is a town ward within Sagain Division or next to Khampat town in the Myanmar state of Chin and Sagain Division. It has a very high concentration of the religion. 
Nang Ka Teih is one of the best location in Chin State and listed among the peace area.\",\n    \"source\": \"Wikipedia\",\n    \"subset\": \"chatbot\"\n  },\n  {\n    \"input\": \"I'm having a serious argument with a friend on the status of the word 'suck' when I used it about him by saying 'You suck!' because he missed a train. We are both non-native English speakers. He claimed that I used it as a foul or slang term. I vehemently disagreed with him about it. I told him that the word 'suck' can be used to describe something inefficient or not good enough as well like it is used in the example 'Samsung mobile sucks'. Am I right to describe the use of the word 'suck' in such context?.\",\n    \"output\": \"Yes, you can use the word in the way that you have described, but it's considered more harsh than polite, and it has somewhat vulgar overtones. How it's regarded or received might be generational. I typed is suck vulgar? on Google, and found mixed responses. Feel free to do the same if you want diverse opinions on the matter. I thought this excerpt from a blog post, though, was worth pasting into an answer here: Some may not believe this, but suck as in 'Man, this class sucks' was also in the raw obscenity category when I was a teenager. It was used plenty in the school hallways but not in front of your teacher and never in front of your mother. I remember some agitation by certain culturally-advanced youngsters who tried to railroad their elders into accepting sucks as a safe and harmless substitute for stinks. The elders weren't having any of it, last I checked, but the liberalizing linguists seem to have carried the day. I have always assumedrightly or wrongly, I do not know that the word was originally intended to carry sexual overtones, which was the reason for its suppression. Today, the sexual overtones are either forgotten or are now acceptable in mixed company. I'm not sure which explanation disturbs me more. 
I think you and your friend are unlikely to come up with an agreed-upon viewpoint, because you're both right in a way. Feel free to use it on message boards and the like when you want to express a negative opinion, but realize you'll risk sounding a bit uncouth to some when you do. Then again, maybe I'm just showing my age here. As a footnote, you might want to check out our sister site, English Language Learners.\",\n    \"source\": \"Natural Instructions\",\n    \"subset\": \"task225 english language answer generation\"\n  },\n  {\n    \"input\": \"How is BeeHighve Inc. in Corner Brook, NL infusing cannabis with honey products and bringing them to markets, both local and global?\",\n    \"output\": \"BeeHighve CEO Rita Hall intends to bring Newfoundland honey and honey-based products to the market, some infused with cannabis, while others will be \\\"buzz free.\\\"\\nA selection of some of the products available from Corner Brook, N.L.-based BeeHighve Inc.\\nThere's a lot of buzz around a joint venture between two Newfoundland companies that want to bring cannabis-infused honey and honey products to local and global markets.\\nBeeHighve Inc., based in Corner Brook, will be cultivating the cannabis crop, while G and M Family Farms, near Placentia, will supply the honey.\\n\\\"It ranges from pure honey to sauces as well as chocolate and health bars, and everything is organic,\\\" said BeeHighve CEO Rita Hall. \\\"And everything is very healthy for you. 
We don't use sugar in our products.\\\"\\nThe partnership is the brainchild of Hall, a trailblazer who is on track to become the first Indigenous woman to gain licensed producer [of marijuana] status in Canada, as well as one of the country's few Indigenous female CEOs.\\nHall intends for Newfoundland honey to be the backbone of her operation, and utilizes it in all the products, including the flagship Nuts About Honey bars.\\nAlthough no official date has been set for the legalization of cannabis-infused edibles \\u2014 with recreational marijuana use legal as of Oct. 17, 2018 \\u2014 BeeHighve plans to produce the same line of products, without the cannabis elements before and after the legalization of edibles.\\n\\\"I don't think it's going to have a negative impact on the business at all. The honey is really generating a lot of interest,\\\" Hall said.\\nPart of the interest is because of Newfoundland's uniquely thriving honey-bee population, who enjoy the benefits of a closed ecosystem comparatively free from mites and diseases associated with colony collapse.\\n\\\"They love the idea of mite-free, antibiotic-free honey. So I have no doubt that the production and sale of honey and the consumables without cannabis infusions will go very well.\\\"\\nEventually BeeHighve intends to get into the beekeeping business as well, allowing them to produce larger volumes of product in less time.\\nAside from the plans to export the cannabis-infused products \\u2014 where legal \\u2014 as well as the \\\"buzz-free\\\" ones globally, BeeHighve is looking to expand its production to another province.\\nPartnering with the Madawaska Maliseet First Nation reserve in New Brunswick, crops will be cultivated on the reserve to later be infused with Newfoundland honey. Hall believes the partnership will be a fruitful one, and has great respect for the Madawaska Maliseet, whose senior leadership is entirely made of women.\\n\\\"It shows the strength of women in any marketplace. 
Women are really underrepresented in the cannabis industry right now, so it's, I'll say, a feather in our caps, no pun intended, to be a part of this industry as well.\\\"\\nHall has just as much faith in this venture as the one in Newfoundland.\\n\\\"We're very strong women and I think very successful, and we'll succeed at this as well.\\\"\",\n    \"source\": \"C4\",\n    \"subset\": \"search-engine\"\n  },\n  {\n    \"input\": \"Respond briefly. What is the history of the Eagle Ranger Station?\",\n    \"output\": \"The Eagle Ranger Station, also known as the Eagle Guard Station and presently known as the Sol Duc Ranger Station, is a complex of three buildings built in the 1930s in what would become Olympic National Park. The primary structures were built by the U.S. Forest Service in what was at the time the Olympic National Forest., While the main residence was built by the USFS, the generating plant and landscaping were built by the National Park Service using labor provided by the Civilian Conservation Corps.\",\n    \"source\": \"Wikipedia\",\n    \"subset\": \"chatbot\"\n  },\n  {\n    \"input\": \"How to deal with non negotiable change\\n\\nOne of the issues my organisation has is where change requests are requested during a sprint cycle in a particular high volume worksteam. The development team are often forced to accept these change requests, where any resistance results in an escalation, which results with somebody high up in the business just telling the dev team to do it. Thus, team is constantly over-committing every sprint cycle.\\nI have highlighted this as an issue to Senior management, showing the impact the additional work is having on morale and quality, but nothing seems to be changing at org level, and they seem to think the problem is at team level. 
Hence, keep trying to address the issue by adding and removing Scrum Masters; where they all are having the same problem.\\nWhat is the best way to address this?\\n\\nA:\\n\\nTL;DR\\nYou have both a political problem and a process problem. There is no silver bullet, and you can't fix these problems solely within the team. Instead, you must rely on the agile concepts of transparency and visibility to ensure that the costs associated with bypassing the process are charged to the project budget.\\nMake Costs Visible\\nIn agile frameworks, change is never non-negotiable. If you are faced with \\\"non-negotiable\\\" requirements that make the goal or plan for the current iteration obsolete, then the team must scrap the plan.\\nThis is true regardless of the agile framework you're using. In Scrum, the Product Owner cancels the Sprint and the Scrum Team returns to Sprint Planning. In Kanban, work is pulled off the board to meet work-in-progress (WIP) limits, and queued in the appropriate bucket until capacity is once again available. Your specific framework may vary in the way it implements this freeing up of capacity, but it's an essential component of any viable agile (and even non-agile) process.\\nThe cost of replanning and re-queuing is made visible through metrics such as cycle time, cumulative flow, velocity, burn-down, or other framework- and implementation-specific metrics. Communicating about these costs is also essential. By presenting the data to stakeholders and company officers, you place the responsibility for those costs squarely where it belongs.\\nSchedule risk (like any risk) can be accepted, transferred, or mitigated. Which of these three options is best is a business decision, not a team decision. 
Regardless of the choice, senior management owns both the choice and the outcome.\\nSee Also\\nhttps://pm.stackexchange.com/a/10829/4271\",\n    \"output\": \"I agree with Todd about it being a political and process problem.\\nI want to add that apart from showing them the impact on morale and quality -which I believe that they are quite important for your team- I would recommend having a discussion with them to understand why they believe it is useful to have this way of working and also prepare for them a \\\"business case\\\" that explores what is they key impact of this way of working for the senior management (trying to talk their own language basically). \\nMost of the times, this comes down to how much money they are losing by having this way of working. Examples of the impact to senior management are:\\n\\nDelayed features due to the introduction of these changes \\nBugs introduced due to quality issues that required to introduce more tasks in future releases.\",\n    \"source\": \"StackExchange\",\n    \"subset\": \"pm\"\n  },\n  {\n    \"input\": \"Use of LaTeX commands\\n\\nIt appears that LaTeX commands cannot be used on SE Linguistics, but they can on other SE sites. Personally, I wanted to use this feature a number of times to insert a formula and here's another question with the same problem. Expressing formulae can be useful when discussing things as simple as the number of combinations that arise from the different conditions in an experiment, or when discussing statistics questions relevant to linguistics. \\nCould we please allow the use of LaTeX commands? 
I don't think having this feature will be a problem for anyone not wanting to use it and it will help those who do.\\n\\nA:\\n\\nAbsolutely subscribing here!\\nI am surprised that this is not a feature already, I am really missing TeX support - for a wide range of uses that I consider essential in the field of linguistics:  \\n\\nMathematical formulas.\\nAs someone who frequently answers questions on formal semantics, like here, here and here, I feel that answering (and asking) these questions in the current state is a mess, to an extent where it has sometimes kept me from writing answers to those questions altogether, simply because the process of setting them up is so annoying.\\nWhile I am aware that there are tools which ease the inclusion of Unicode symbols as HTML, all of these solutions (at least the ones I am aware of) still make typing more than one line of math quite cumbersome given that every symbol has to be clicked on individually, but worse than that, many characters and necessary formatting options, like pretty much half of the inventory I needed in this linguistically originated question, aren't even available as predefined Unicode symbols at all. Not to mention the fact that the output is visually not particularly pleasing and often makes the question hard to read, as can well be seen in this post.\\nBesides, I doubt that the average user is aware of the existence of such tools; it took me more than a year on SE to find out about these, while new users have repeatedly been asking for LaTeX support, because this what users are used to.\\nApart from formulas for questions on logic, MathJax would prove useful in questions involving statistics or combinatorics, which time and time again turn up this site.  
\\nIPA, as already discussed here and here, to frequently be used in posts like here, here, here and many posts by other users.\\nAt this point I am not sure whether enabling MathJax will also provide the opportunity to support other LaTeX packages, but if so, being able to use tipa would be of great help in typing IPA symbols.\\nI am aware of IPA typeit which I have used so far, and some other external tools, but again, clicking is cumbersome, and users will be more comfortable writing questions and answers with the tools thy are used to (and the standard way of typesetting IPA symbols is LaTeX tipa).     \\nOther, less urgent cases where I could imagine TeX support to be useful - again under the premise that support for special packages is in principal possible - include\\n\\ntree structures, most importantly syntax trees, like here, here and here; a similar request was already made on meta here.\\nlinguistic example formatting, which I could have used here, and interlinear glosses for posts like here, which are rather inconvenient to do appropriately without automatic alignment.\\nThe current solution to these use cases is to typeset them externally and then include them in the post via a screenshot - I could live with this solution for these particular cases, but if MathJax comes, then this would certainly be a nice for these less frequent use cases as well.  \\n\\nA:\\n\\nIt won't be a problem, but we need to show that the site actually needs it. \\nOn Physics SE, the same question has been asked and you can see it got at least 32 votes. If there is not enough support from the community, which means enough people that want it, then it's probably not gonna happen. It's the same as for Migration paths in other sites, you need to prove there is a case for it to be implemented.\\nIf it gets enough support, it'll be added. \",\n    \"output\": \"A clear yes. 
And do not forget to include tipa support for IPA in LaTeX, I find tipa input from the keyboard very convenient (compared to searching, cutting, and pasting IPA characters from a Unicode character table).\",\n    \"source\": \"StackExchange\",\n    \"subset\": \"linguistics\"\n  },\n  {\n    \"input\": \"In this task, you are given text for US Congressional and California state bills, your task is to generate a summary for this bill.\\n\\\"SECTION 1. SHORT TITLE.\\n\\n    This Act may be cited as the ``Sequestration Relief Act of 2013''.\\n\\nSEC. 2. FINDINGS AND PURPOSE.\\n\\n    (a) Findings.--Congress finds the following:\\n            (1) Congress must enact a comprehensive, deficit reduction \\n        plan to solve the country's fiscal challenges and to promote \\n        national security, economic stability, and the continued growth \\n        and prosperity of the United States.\\n            (2) The keys to a comprehensive, deficit reduction solution \\n        are increased revenues and changes in mandatory spending.\\n            (3) The Budget Control Act of 2011 was enacted to avert a \\n        default on Federal debt obligations, and it reduced \\n        discretionary spending by approximately $1 trillion through \\n        fiscal year 2021.\\n            (4) Because the Joint Select Committee on Deficit Reduction \\n        failed to recommend legislation providing an additional $1.2 \\n        trillion in deficit reduction, Federal law mandates that the \\n        additional savings be sequestered.\\n            (5) Sequestration was designed as a forcing mechanism for \\n        an agreement on a comprehensive, deficit reduction plan. It has \\n        failed to produce the intended results.\\n            (6) It no longer makes sense to rely on sequestration as a \\n        forcing mechanism for a balanced solution. 
The costs to our \\n        government and to the economy are too great.\\n            (7) Under sequestration, automatic, indiscriminate cuts \\n        would be applied, through fiscal year 2021, to a wide variety \\n        of discretionary spending programs to achieve $1.2 trillion in \\n        savings, forestalling the sound planning needed for prudent and \\n        meaningful investments in national security, the workforce, \\n        transportation infrastructure, education, health care, public \\n        safety, housing, innovation, small business development, and \\n        many other facets of enduring national strength.\\n            (8) Even the prospect of sequestration is disruptive to \\n        regular order and to the congressional appropriations process, \\n        and it fosters damaging economic uncertainty, while short-term \\n        solutions only suspend the prospect and continue to undermine \\n        the certainty needed for economic recovery.\\n            (9) Therefore, Congress must eliminate the threat of \\n        sequestration.\\n            (10) Given the magnitude of the Federal deficit, it is \\n        likely that additional cuts to discretionary spending will be \\n        necessary for a comprehensive deficit reduction solution.\\n            (11) Congress must establish a manageable, long-term \\n        discretionary spending plan. An additional $320 billion in \\n        targetable cuts to discretionary appropriations from fiscal \\n        year 2014 through fiscal year 2021 represents one-third of the \\n        net amount that would have been indiscriminately cut by \\n        sequestration over fiscal years 2013 through 2021.\\n            (12) It is recognized that a reduction of $167 billion to \\n        discretionary appropriations within budget function 050 from \\n        fiscal year 2014 through fiscal year 2021 will affect the \\n        National Military Strategy. 
The Department of Defense is highly \\n        encouraged to revisit its current strategic guidance and to \\n        work closely with Congress in building a new National Military \\n        Strategy that accounts for available resource levels.\\n    (b) Purposes.--The purposes of this Act are to--\\n            (1) eliminate the threat of sequestration to the American \\n        economy;\\n            (2) offer the Federal Government, industry, and the \\n        American people the predictability that economic recovery \\n        demands;\\n            (3) enable the Congress to pass appropriations legislation \\n        in regular order with a clear discretionary spending budget and \\n        grant the legislative and executive branches of government the \\n        flexibility needed to identify and implement specific \\n        discretionary spending reductions in a responsible and \\n        deliberate manner; and\\n            (4) provide a practicable, long-term discretionary spending \\n        plan that will contribute to a comprehensive, balanced, long-\\n        term, deficit reduction solution that includes affordable \\n        revisions to mandatory spending and new revenues.\\n\\nSEC. 3. REPEAL OF SECTION 251A SEQUESTRATIONS.\\n\\n    Section 251A of the Balanced Budget and Emergency Deficit Control \\nAct of 1985 is repealed.\\n\\nSEC. 4. 
$320 BILLION REDUCTION IN DISCRETIONARY SPENDING LIMITS.\\n\\n    The discretionary spending limits set forth in paragraphs (3) \\nthrough (10) of section 251(c) of the Balanced Budget and Emergency \\nDeficit Control Act of 1985 are amended to read as follows:\\n            ``(3) for fiscal year 2014--\\n                    ``(A) for the security category, $546,000,000,000 \\n                in budget authority; and\\n                    ``(B) for the nonsecurity category, \\n                $501,000,000,000 in budget authority;\\n            ``(4) with respect to fiscal year 2015--\\n                    ``(A) for the security category, $550,000,000,000 \\n                in new budget authority; and\\n                    ``(B) for the nonsecurity category, \\n                $505,000,000,000 in new budget authority;\\n            ``(5) with respect to fiscal year 2016--\\n                    ``(A) for the security category, $559,000,000,000 \\n                in new budget authority; and\\n                    ``(B) for the nonsecurity category, \\n                $513,000,000,000 in new budget authority;\\n            ``(6) with respect to fiscal year 2017--\\n                    ``(A) for the security category, $569,000,000,000 \\n                in new budget authority; and\\n                    ``(B) for the nonsecurity category, \\n                $522,000,000,000 in new budget authority;\\n            ``(7) with respect to fiscal year 2018--\\n                    ``(A) for the security category, $579,000,000,000 \\n                in new budget authority; and\\n                    ``(B) for the nonsecurity category, \\n                $531,000,000,000 in new budget authority;\\n            ``(8) with respect to fiscal year 2019--\\n                    ``(A) for the security category, $589,500,000,000 \\n                in new budget authority; and\\n                    ``(B) for the nonsecurity category, \\n                $541,000,000,000 in new budget 
authority;\\n            ``(9) with respect to fiscal year 2020--\\n                    ``(A) for the security category, $602,500,000,000 \\n                in new budget authority; and\\n                    ``(B) for the nonsecurity category, \\n                $553,000,000,000 in new budget authority;\\n            ``(10) with respect to fiscal year 2021--\\n                    ``(A) for the security category, $616,000,000,000 \\n                in new budget authority; and\\n                    ``(B) for the nonsecurity category, \\n                $565,000,000,000 in new budget authority;''.\\n\\nSEC. 5. DEFINITION OF SECURITY CATEGORY.\\n\\n    Section 250(c)(4)(B) of the Balanced Budget and Emergency Deficit \\nControl Act of 1985 is amended to read as follows:\\n            ``(B)(i) For fiscal years 2012 and 2013, the term `security \\n        category' means discretionary appropriations associated with \\n        agency budgets for the Department of Defense, the Department of \\n        Homeland Security, the Department of Veterans Affairs, the \\n        National Nuclear Security Administration, the intelligence \\n        community management account (95-0401-0-1-054), and all budget \\n        accounts in budget function 150 (international affairs).\\n            ``(ii) For fiscal years 2014 through 2021, the term \\n        `security category' means discretionary appropriations in \\n        budget function 050 (national defense).''.\\n\\nSEC. 6. 
SUSPENSION OF STATUTORY LIMIT ON THE PUBLIC DEBT UNTIL FEBRUARY \\n              1, 2017.\\n\\n    Section 2 of the No Budget, No Pay Act of 2013 is amended--\\n            (1) in subsection (a), by striking ``May 18, 2013'' and \\n        inserting ``January 31, 2017''; and\\n            (2) in subsection (b), by striking ``May 19, 2013'' each \\n        place it appears and inserting ``February 1, 2017''.\\\"\",\n    \"output\": \"Sequestration Relief Act of 2013 - Amends the Balanced Budget and Emergency Deficit Controlled Act of 1985 (Gramm-Rudman-Hollings Act), as amended by the Budget Control Act of 2011, to repeal automatic sequestration to enforce a specified budget goal. Reduces the discretionary spending limits for security and nonsecurity categories by $320 billion for FY2014-FY2021. Confines the current meaning of &quot;security category&quot; under the Gramm-Rudman-Hollings Act to FY2012-FY2013. Revises the definition of &quot;security category&quot; for FY2014-FY2021 to mean discretionary appropriations in all of budget function 050 (national defense). Amends the No Budget, No Pay Act of 2013 to suspend through January 31, 2017, the current $16.394 trillion public debt limit. Postpones until February 1, 2017, an automatic increase in the public debt limit to the extent that: (1) the face amount of obligations issued and the face amount of obligations whose principal and interest are guaranteed by the federal government (except guaranteed obligations held by the Secretary of the Treasury) outstanding on February 1, 2017, exceeds (2) the face amount of such obligations outstanding on the date of enactment of the No Budget, No Pay Act of 2013 (February 24, 2013). 
Prohibits an obligation from being taken into account unless its issuance was necessary to fund a commitment incurred by the federal government that required payment before February 1, 2017.\",\n    \"source\": \"Natural Instructions\",\n    \"subset\": \"task1658 billsum summarization\"\n  }\n]\n"
  },
  {
    "path": "tests/data/test_alpaca.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nfrom litgpt.data import Alpaca\nfrom litgpt.prompts import Alpaca as AlpacaPromptStyle\n\n\ndef test_alpaca(mock_tokenizer, alpaca_path):\n    alpaca = Alpaca(val_split_fraction=0.5, download_dir=alpaca_path.parent, file_name=alpaca_path.name, num_workers=0)\n    assert isinstance(alpaca.prompt_style, AlpacaPromptStyle)\n    alpaca.connect(mock_tokenizer, batch_size=2, max_seq_length=10)\n    alpaca.prepare_data()\n    alpaca.setup()\n\n    train_dataloader = alpaca.train_dataloader()\n    val_dataloader = alpaca.val_dataloader()\n\n    assert len(train_dataloader) == 6\n    assert len(val_dataloader) == 6\n\n    train_batch = next(iter(train_dataloader))\n    val_batch = next(iter(val_dataloader))\n\n    assert train_batch.keys() == val_batch.keys() == {\"input_ids\", \"labels\", \"token_counts\"}\n    for key in [\"input_ids\", \"labels\"]:\n        assert train_batch[key].shape == (2, 10), f\"Unexpected shape for train_batch[{key}]\"\n        assert val_batch[key].shape == (2, 10), f\"Unexpected shape for val_batch[{key}]\"\n\n    assert isinstance(train_dataloader.dataset.prompt_style, AlpacaPromptStyle)\n    assert isinstance(val_dataloader.dataset.prompt_style, AlpacaPromptStyle)\n\n    # has attributes from super class `LightningDataModule`\n    assert alpaca.prepare_data_per_node\n"
  },
  {
    "path": "tests/data/test_base.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nfrom typing import Optional\n\nimport pytest\nimport torch\n\nfrom litgpt.data.base import SFTDataset, get_sft_collate_fn\nfrom litgpt.prompts import PromptStyle\n\n\n@pytest.mark.parametrize(\"mask_prompt\", [True, False])\n@pytest.mark.parametrize(\"ignore_index\", [-1, -100])\n@pytest.mark.parametrize(\"max_seq_length\", [1000, 5, -1])\ndef test_sft_dataset(max_seq_length, ignore_index, mask_prompt, mock_tokenizer):\n    class Style(PromptStyle):\n        def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs) -> str:\n            return f\"In: {prompt} Out:\"\n\n    i = ignore_index\n    data = [{\"instruction\": \"Foo\", \"output\": \"Bar\"}, {\"instruction\": \"Boo\", \"output\": \"Ahh\"}]\n\n    dataset = SFTDataset(\n        data=data,\n        tokenizer=mock_tokenizer,\n        prompt_style=Style(),\n        mask_prompt=mask_prompt,\n        ignore_index=ignore_index,\n        max_seq_length=max_seq_length,\n    )\n    assert len(dataset) == len(data)\n\n    expected_input_ids = torch.tensor([73, 110, 58, 32, 70, 111, 111, 32, 79, 117, 116, 58, 66, 97, 114, 1])\n    # If prompt is not masked, labels == input_ids\n    expected_labels = (\n        torch.tensor([i, i, i, i, i, i, i, i, i, i, i, i, 66, 97, 114, 1]) if mask_prompt else expected_input_ids\n    )\n\n    if max_seq_length == -1:\n        assert torch.equal(dataset[0][\"input_ids\"], expected_input_ids)\n        assert torch.equal(dataset[0][\"labels\"], expected_labels)\n    else:\n        assert torch.equal(dataset[0][\"input_ids\"], expected_input_ids[:max_seq_length])\n        assert torch.equal(dataset[0][\"labels\"], expected_labels[:max_seq_length])\n\n\n@pytest.mark.parametrize(\"ignore_index\", [-1, -100])\n@pytest.mark.parametrize(\"pad_id\", [0, 100])\ndef test_sft_collate_fn_padding(pad_id, ignore_index):\n    collate = get_sft_collate_fn(pad_id=pad_id, 
ignore_index=ignore_index)\n    samples = [\n        {\n            \"input_ids\": torch.tensor([1, 2, 3]),\n            \"labels\": torch.tensor([10, 20, 30]),\n            \"token_counts\": {\"raw\": 3, \"raw_plus_prompt_template\": 25},\n        },\n        {\n            \"input_ids\": torch.tensor([4, 5, 6, 7, 8]),\n            \"labels\": torch.tensor([40, 50, 60, 70, 80]),\n            \"token_counts\": {\"raw\": 5, \"raw_plus_prompt_template\": 27},\n        },\n    ]\n    expected = {\n        \"input_ids\": torch.tensor([[1, 2, 3, pad_id, pad_id], [4, 5, 6, 7, 8]]),\n        \"labels\": torch.tensor([[10, 20, 30, ignore_index, ignore_index], [40, 50, 60, 70, 80]]),\n        \"token_counts\": {\"raw\": torch.tensor([[3], [5]]), \"raw_plus_prompt_template\": torch.tensor([[25], [27]])},\n    }\n    batch = collate(samples)\n    assert all(torch.equal(batch[k], expected[k]) for k in (\"input_ids\", \"labels\"))\n    for key in (\"raw\", \"raw_plus_prompt_template\"):\n        assert torch.equal(batch[\"token_counts\"][key], expected[\"token_counts\"][key]), f\"Token count mismatch for {key}\"\n\n\ndef test_sft_collate_fn_truncation():\n    collate = get_sft_collate_fn(max_seq_length=2)\n    samples = [\n        {\n            \"input_ids\": torch.tensor([1, 2, 3]),\n            \"labels\": torch.tensor([10, 20, 30]),\n            \"token_counts\": {\"raw\": 3, \"raw_plus_prompt_template\": 25},\n        },\n        {\n            \"input_ids\": torch.tensor([4, 5, 6, 7, 8]),\n            \"labels\": torch.tensor([40, 50, 60, 70, 80]),\n            \"token_counts\": {\"raw\": 5, \"raw_plus_prompt_template\": 27},\n        },\n    ]\n    expected = {\n        \"input_ids\": torch.tensor([[1, 2], [4, 5]]),\n        \"labels\": torch.tensor([[10, 20], [40, 50]]),\n        \"token_counts\": {\"raw\": torch.tensor([[3], [5]]), \"raw_plus_prompt_template\": torch.tensor([[25], [27]])},\n    }\n    batch = collate(samples)\n    assert all(torch.equal(batch[k], 
expected[k]) for k in (\"input_ids\", \"labels\"))\n    for key in (\"raw\", \"raw_plus_prompt_template\"):\n        assert torch.equal(batch[\"token_counts\"][key], expected[\"token_counts\"][key]), f\"Token count mismatch for {key}\"\n"
  },
  {
    "path": "tests/data/test_deita.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nfrom unittest import mock\n\nfrom litgpt.data import Deita, SFTDataset\nfrom litgpt.data.deita import format_dataset\nfrom litgpt.prompts import Alpaca as AlpacaPromptStyle\n\n\ndef test_format_dataset():\n    data = [\n        {\n            \"prompt\": \"prompt1\",\n            \"prompt_id\": \"1\",\n            \"messages\": [\n                {\"content\": \"question1\", \"role\": \"user\"},\n                {\"content\": \"response1\", \"role\": \"assistant\"},\n                {\"content\": \"question2\", \"role\": \"user\"},\n                {\"content\": \"response2\", \"role\": \"assistant\"},\n            ],\n        },\n        {\n            \"prompt\": \"prompt2\",\n            \"prompt_id\": \"2\",\n            \"messages\": [\n                {\"content\": \"question3\", \"role\": \"user\"},\n                {\"content\": \"response3\", \"role\": \"assistant\"},\n                {\"content\": \"question4\", \"role\": \"user\"},\n                {\"content\": \"response4\", \"role\": \"assistant\"},\n            ],\n        },\n    ]\n\n    assert format_dataset(data, include_multi_turn_conversations=False) == [\n        {\"instruction\": \"question1\", \"output\": \"response1\", \"input\": \"\"},\n        {\"instruction\": \"question3\", \"output\": \"response3\", \"input\": \"\"},\n    ]\n    assert format_dataset(data, include_multi_turn_conversations=True) == [\n        {\"instruction\": \"question1\", \"output\": \"response1\", \"input\": \"\"},\n        {\"instruction\": \"question2\", \"output\": \"response2\", \"input\": \"\"},\n        {\"instruction\": \"question3\", \"output\": \"response3\", \"input\": \"\"},\n        {\"instruction\": \"question4\", \"output\": \"response4\", \"input\": \"\"},\n    ]\n\n\n@mock.patch(\"litgpt.data.deita.format_dataset\")\n@mock.patch(\"datasets.load_dataset\")\ndef test_deita(_, format_dataset_mock, 
mock_tokenizer, tmp_path):\n    format_dataset_mock.return_value = [\n        {\"instruction\": \"inst1\", \"output\": \"out1\"},\n        {\"instruction\": \"inst2\", \"output\": \"out2\"},\n        {\"instruction\": \"inst3\", \"output\": \"out3\"},\n    ]\n\n    deita = Deita(num_workers=0, download_dir=tmp_path)\n    assert isinstance(deita.prompt_style, AlpacaPromptStyle)\n    deita.connect(mock_tokenizer, batch_size=2, max_seq_length=10)\n    deita.prepare_data()\n    deita.setup()\n\n    train_dataloader = deita.train_dataloader()\n    assert isinstance(train_dataloader.dataset, SFTDataset)\n    assert len(train_dataloader) == 2\n\n    val_dataloader = deita.val_dataloader()\n    assert isinstance(val_dataloader.dataset, SFTDataset)\n    assert len(val_dataloader) == 2\n\n    assert isinstance(train_dataloader.dataset.prompt_style, AlpacaPromptStyle)\n    assert isinstance(val_dataloader.dataset.prompt_style, AlpacaPromptStyle)\n\n    # has attributes from super class `LightningDataModule`\n    assert deita.prepare_data_per_node\n"
  },
  {
    "path": "tests/data/test_json.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport json\nfrom typing import Optional\n\nimport pytest\n\nfrom litgpt.data import JSON\nfrom litgpt.prompts import PromptStyle\n\n\n@pytest.mark.parametrize(\"as_jsonl\", [False, True])\ndef test_json(as_jsonl, tmp_path, mock_tokenizer):\n    class Style(PromptStyle):\n        def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs) -> str:\n            return f\"X: {prompt} {kwargs['input']} Y:\"\n\n    json_path = tmp_path / (\"data.jsonl\" if as_jsonl else \"data.json\")\n    mock_data = [\n        {\"instruction\": \"Add\", \"input\": \"2+2\", \"output\": \"4\"},\n        {\"instruction\": \"Subtract\", \"input\": \"5-3\", \"output\": \"2\"},\n        {\"instruction\": \"Multiply\", \"input\": \"6*4\", \"output\": \"24\"},\n        {\"instruction\": \"Divide\", \"input\": \"10/2\", \"output\": \"5\"},\n        {\"instruction\": \"Exponentiate\", \"input\": \"2^3\", \"output\": \"8\"},\n        {\"instruction\": \"Square root\", \"input\": \"√9\", \"output\": \"3\"},\n    ]\n\n    with open(json_path, \"w\", encoding=\"utf-8\") as fp:\n        if as_jsonl:\n            for line in mock_data:\n                json.dump(line, fp)\n                fp.write(\"\\n\")\n        else:\n            json.dump(mock_data, fp)\n\n    data = JSON(json_path, val_split_fraction=0.5, prompt_style=Style(), num_workers=0)\n    data.connect(tokenizer=mock_tokenizer, batch_size=2)\n    data.prepare_data()  # does nothing\n    data.setup()\n\n    train_dataloader = data.train_dataloader()\n    val_dataloader = data.val_dataloader()\n\n    assert len(train_dataloader) == 2\n    assert len(val_dataloader) == 2\n\n    train_data = list(train_dataloader)\n    val_data = list(val_dataloader)\n\n    assert train_data[0][\"input_ids\"].size(0) == 2\n    assert train_data[1][\"input_ids\"].size(0) == 1\n    assert val_data[0][\"input_ids\"].size(0) == 2\n    assert 
val_data[1][\"input_ids\"].size(0) == 1\n\n    assert mock_tokenizer.decode(train_data[0][\"input_ids\"][0]).startswith(\"X: Divide 10/2 Y:5\")\n    assert mock_tokenizer.decode(train_data[0][\"input_ids\"][1]).startswith(\"X: Add 2+2 Y:4\")\n    assert mock_tokenizer.decode(train_data[1][\"input_ids\"][0]).startswith(\"X: Multiply 6*4 Y:24\")\n\n    assert mock_tokenizer.decode(val_data[0][\"input_ids\"][0]).startswith(\"X: Exponentiate 2^3 Y:8\")\n    assert mock_tokenizer.decode(val_data[0][\"input_ids\"][1]).startswith(\"X: Subtract 5-3 Y:2\")\n    assert mock_tokenizer.decode(val_data[1][\"input_ids\"][0]).startswith(\"X: Square root √9 Y:3\")\n\n    assert isinstance(train_dataloader.dataset.prompt_style, Style)\n    assert isinstance(val_dataloader.dataset.prompt_style, Style)\n\n    # has attributes from super class `LightningDataModule`\n    assert data.prepare_data_per_node\n\n\ndef test_json_input_validation(tmp_path):\n    with pytest.raises(FileNotFoundError, match=\"The `json_path` must be a file or a directory\"):\n        JSON(tmp_path / \"not exist\")\n\n    with pytest.raises(ValueError, match=\"`val_split_fraction` should not be set\"):\n        JSON(tmp_path, val_split_fraction=0.5)\n\n    data = JSON(tmp_path)\n    data.prepare_data()  # does nothing\n\n    # Empty directory\n    with pytest.raises(FileNotFoundError, match=\"must be a file or a directory containing\"):\n        data.setup()\n\n    # Only train.json exists\n    (tmp_path / \"train.json\").touch()\n    with pytest.raises(FileNotFoundError, match=\"must be a file or a directory containing\"):\n        data.setup()\n\n    # When a single file is passed without val_split_fraction, it defaults to 0.05 and warns.\n    with pytest.warns(UserWarning, match=\"Defaulting to `val_split_fraction=0.05`\"):\n        data = JSON(tmp_path / \"train.json\", val_split_fraction=None)\n    assert data.val_split_fraction == 0.05\n\n\n@pytest.mark.parametrize(\"as_jsonl\", [False, True])\ndef 
test_json_with_splits(as_jsonl, tmp_path, mock_tokenizer):\n    mock_train_data = [\n        {\"instruction\": \"Add\", \"input\": \"2+2\", \"output\": \"4\"},\n        {\"instruction\": \"Subtract\", \"input\": \"5-3\", \"output\": \"2\"},\n        {\"instruction\": \"Exponentiate\", \"input\": \"2^3\", \"output\": \"8\"},\n    ]\n    mock_test_data = [\n        {\"instruction\": \"Multiply\", \"input\": \"6*4\", \"output\": \"24\"},\n        {\"instruction\": \"Divide\", \"input\": \"10/2\", \"output\": \"5\"},\n    ]\n\n    train_file = tmp_path / (\"train.jsonl\" if as_jsonl else \"train.json\")\n    val_file = tmp_path / (\"val.jsonl\" if as_jsonl else \"val.json\")\n\n    with open(train_file, \"w\", encoding=\"utf-8\") as fp:\n        if as_jsonl:\n            for line in mock_train_data:\n                json.dump(line, fp)\n                fp.write(\"\\n\")\n        else:\n            json.dump(mock_train_data, fp)\n    with open(val_file, \"w\", encoding=\"utf-8\") as fp:\n        if as_jsonl:\n            for line in mock_test_data:\n                json.dump(line, fp)\n                fp.write(\"\\n\")\n        else:\n            json.dump(mock_test_data, fp)\n\n    data = JSON(tmp_path, num_workers=0)\n    data.connect(tokenizer=mock_tokenizer, batch_size=2)\n    data.prepare_data()  # does nothing\n    data.setup()\n\n    train_dataloader = data.train_dataloader()\n    val_dataloader = data.val_dataloader()\n\n    assert len(train_dataloader) == 2\n    assert len(val_dataloader) == 1\n"
  },
  {
    "path": "tests/data/test_lit_data.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport sys\nfrom unittest import mock\nfrom unittest.mock import ANY\n\nimport pytest\n\nfrom litgpt.data import LitData\n\n\n@pytest.mark.skipif(sys.platform == \"win32\", reason=\"Needs to implement platform agnostic path/url joining\")\n@mock.patch(\"litgpt.data.lit_data.LitData._dataloader\")\ndef test_input_dir_and_splits(dl_mock, tmp_path):\n    with pytest.raises(ValueError, match=\"If provided `split_names` must be a tuple of two strings\"):\n        LitData(data_path=tmp_path, split_names=(\"train\",))\n\n    # local dir, no splits\n    data = LitData(data_path=tmp_path)\n    data.train_dataloader()\n    dl_mock.assert_called_with(input_dir=str(tmp_path), train=True)\n    data.val_dataloader()\n    dl_mock.assert_called_with(input_dir=str(tmp_path), train=False)\n\n    # local dir, splits\n    data = LitData(data_path=tmp_path, split_names=(\"train\", \"val\"))\n    data.train_dataloader()\n    dl_mock.assert_called_with(input_dir=str(tmp_path / \"train\"), train=True)\n    data.val_dataloader()\n    dl_mock.assert_called_with(input_dir=str(tmp_path / \"val\"), train=False)\n\n    # remote dir, splits\n    data = LitData(data_path=\"s3://mydataset/data\", split_names=(\"train\", \"val\"))\n    data.train_dataloader()\n    dl_mock.assert_called_with(input_dir=\"s3://mydataset/data/train\", train=True)\n    data.val_dataloader()\n    dl_mock.assert_called_with(input_dir=\"s3://mydataset/data/val\", train=False)\n\n\n@pytest.mark.skipif(sys.platform == \"win32\", reason=\"Needs to implement platform agnostic path/url joining\")\n@mock.patch(\"litdata.streaming.StreamingDataset\")\n@mock.patch(\"litdata.streaming.StreamingDataLoader\")\ndef test_dataset_args(streaming_dataloader_mock, streaming_dataset_mock, tmp_path):\n    data = LitData(data_path=tmp_path, seed=1000)\n    data.train_dataloader()\n    streaming_dataset_mock.assert_called_with(\n        
input_dir=str(tmp_path),\n        item_loader=ANY,\n        shuffle=True,\n        seed=1000,\n    )\n    streaming_dataloader_mock.assert_called_with(\n        streaming_dataset_mock(),\n        batch_size=1,\n        pin_memory=True,\n        num_workers=8,\n        drop_last=True,\n    )\n"
  },
  {
    "path": "tests/data/test_longform.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nfrom litgpt.data import LongForm\nfrom litgpt.prompts import Longform as LongFormPromptStyle\n\n\ndef test_longform(mock_tokenizer, longform_path):\n    longform = LongForm(download_dir=longform_path, num_workers=0)\n    assert isinstance(longform.prompt_style, LongFormPromptStyle)\n    longform.connect(mock_tokenizer, batch_size=2, max_seq_length=10)\n    longform.prepare_data()\n    longform.setup()\n\n    train_dataloader = longform.train_dataloader()\n    val_dataloader = longform.val_dataloader()\n\n    assert len(train_dataloader) == 9\n    assert len(val_dataloader) == 5\n\n    train_batch = next(iter(train_dataloader))\n    val_batch = next(iter(val_dataloader))\n\n    assert train_batch.keys() == val_batch.keys() == {\"input_ids\", \"labels\", \"token_counts\"}\n    for key in [\"input_ids\", \"labels\"]:\n        assert train_batch[key].shape == (2, 10), f\"Unexpected shape for train_batch[{key}]\"\n        assert val_batch[key].shape == (2, 10), f\"Unexpected shape for val_batch[{key}]\"\n\n    assert isinstance(train_dataloader.dataset.prompt_style, LongFormPromptStyle)\n    assert isinstance(val_dataloader.dataset.prompt_style, LongFormPromptStyle)\n\n    # has attributes from super class `LightningDataModule`\n    assert longform.prepare_data_per_node\n"
  },
  {
    "path": "tests/data/test_openwebtext.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport sys\nfrom unittest import mock\nfrom unittest.mock import ANY, call\n\nimport pytest\nfrom litdata.streaming import StreamingDataLoader, StreamingDataset\nfrom torch.utils.data import DataLoader\n\nfrom litgpt.data import OpenWebText\n\n\n@pytest.mark.skipif(sys.platform == \"win32\", reason=\"Not in the mood to add Windows support right now.\")\n@mock.patch(\"litdata.optimize\")\n@mock.patch(\"litdata.streaming.dataset.subsample_streaming_dataset\", return_value=([], []))\n@mock.patch(\"datasets.load_dataset\")\ndef test_openwebtext(_, __, optimize_mock, tmp_path, mock_tokenizer):\n    data = OpenWebText(data_path=(tmp_path / \"openwebtext\"))\n    assert data.seq_length == 2048\n    assert data.batch_size == 1\n\n    data.connect(tokenizer=mock_tokenizer, batch_size=2, max_seq_length=1024)\n    assert data.seq_length == 1025\n    assert data.batch_size == 2\n\n    # Data does not exist, preprocess it\n    data.prepare_data()\n    optimize_mock.assert_has_calls(\n        [\n            call(\n                fn=ANY,\n                num_workers=ANY,\n                inputs=[],\n                output_dir=str(tmp_path / \"openwebtext\" / \"train\"),\n                chunk_bytes=\"200MB\",\n            ),\n            call(\n                fn=ANY,\n                num_workers=ANY,\n                inputs=[],\n                output_dir=str(tmp_path / \"openwebtext\" / \"val\"),\n                chunk_bytes=\"200MB\",\n            ),\n        ]\n    )\n    optimize_mock.reset_mock()\n\n    # Data exists, already preprocessed\n    (tmp_path / \"openwebtext\" / \"train\").mkdir(parents=True)\n    (tmp_path / \"openwebtext\" / \"val\").mkdir(parents=True)\n    data.prepare_data()\n    optimize_mock.assert_not_called()\n\n    data.setup()\n\n    train_dataloader = data.train_dataloader()\n    assert isinstance(train_dataloader, StreamingDataLoader)\n    assert 
isinstance(train_dataloader.dataset, StreamingDataset)\n\n    val_dataloader = data.val_dataloader()\n    assert isinstance(val_dataloader, DataLoader)\n    assert isinstance(val_dataloader.dataset, StreamingDataset)\n\n    # has attributes from super class `LightningDataModule`\n    assert data.prepare_data_per_node\n"
  },
  {
    "path": "tests/data/test_textfiles.py",
    "content": "import json\n\nimport torch\nfrom litdata import TokensLoader, optimize\nfrom torch.utils._pytree import tree_map\n\nfrom litgpt.data.text_files import TextFiles\n\n\nclass Tokenizer:\n    bos_id = 0\n\n    def encode(self, text, bos, eos):\n        assert bos\n        assert not eos\n        return [self.bos_id] + [ord(c) for c in text]\n\n\ndef tokenize(data):\n    for story in data:\n        yield torch.tensor(story)\n\n\ndef fake_chunk(path, data):\n    optimize(\n        fn=tokenize,\n        inputs=[data] * len(data),\n        output_dir=str(path),\n        num_workers=1,\n        chunk_bytes=\"200MB\",\n        item_loader=TokensLoader(),\n    )\n\n\ndef test_textfiles_datamodule(tmp_path):\n    from litgpt.data.text_files import TextFiles\n\n    data_dir = tmp_path / \"textfiles\"\n    datamodule = TextFiles(train_data_path=data_dir, num_workers=1)\n    datamodule.connect(max_seq_length=2, tokenizer=Tokenizer())\n\n    # simulate `datamodule.prepare_data`\n    train_data_dir = data_dir / \"train\"\n    train_data_dir.mkdir(parents=True)\n    fake_chunk(train_data_dir, [[12], [0, 23, 15, 63, 0], [73, 5, 0, 1, 1999, 0, 13]])\n    datamodule.setup()\n\n    tr_dataloader = datamodule.train_dataloader()\n    tr_dataloader.shuffle = False\n\n    actual = tree_map(torch.Tensor.tolist, list(tr_dataloader))\n\n    # there is 1 sample per index in the data (13)\n    assert actual == [\n        [[73, 5, 0]],\n        [[12, 0, 23]],\n        [[5, 0, 1]],\n        [[0, 73, 5]],\n        [[1999, 0, 13]],\n        [[0, 1, 1999]],\n        [[1, 1999, 0]],\n        [[0, 23, 15]],\n        [[13, 12, 0]],\n        [[63, 0, 73]],\n        [[23, 15, 63]],\n        [[15, 63, 0]],\n        [[0, 13, 12]],\n    ]\n\n\nclass MockTokenizer:\n    bos_id = 0\n    eos_id = 1\n    use_bos = True\n\n    def encode(self, text, bos=True, eos=False, device=None, max_length=-1):\n        # Simple: map each character to its ordinal + 2\n        tokens = [ord(c) + 2 for c in 
text]\n        if bos:\n            tokens = [self.bos_id] + tokens\n        if eos:\n            tokens.append(self.eos_id)\n        if max_length > 0:\n            tokens = tokens[:max_length]\n        return torch.tensor(tokens, dtype=torch.long, device=device)\n\n    def decode(self, tensor):\n        ids = tensor.tolist() if tensor.ndim > 0 else [tensor.item()]\n        chars = []\n        for tid in ids:\n            if tid == self.bos_id:\n                chars.append(\"<BOS>\")\n            elif tid == self.eos_id:\n                chars.append(\"<EOS>\")\n            else:\n                chars.append(chr(tid - 2))\n        return \"\".join(chars)\n\n    def decode_stream(self, token_stream, device=None):\n        for token in token_stream:\n            yield self.decode(token)\n\n    @property\n    def vocab_size(self):\n        return 130\n\n\ndef test_textfiles_token_loader(tmp_path):\n    # Create the directory for text files\n    data_dir = tmp_path / \"textfiles\"\n    data_dir.mkdir(parents=True, exist_ok=True)\n\n    # Write sample training data to the directory\n    sample_texts = [\"hello world\", \"foo bar\", \"lorem ipsum\"]\n    for i, text in enumerate(sample_texts):\n        (data_dir / f\"{i}.txt\").write_text(text)\n\n    datamodule = TextFiles(train_data_path=data_dir, num_workers=1)\n    datamodule.connect(max_seq_length=2, tokenizer=MockTokenizer())\n    datamodule.prepare_data()\n\n    # ensure training set uses tokens loader\n    index_json = data_dir / \"train\" / \"index.json\"\n    assert index_json.exists()\n    meta = json.loads(index_json.read_text())\n    assert meta[\"config\"][\"item_loader\"] == \"TokensLoader\"\n\n    # ensure validation set uses tokens loader\n    index_json = data_dir / \"val\" / \"index.json\"\n    assert index_json.exists()\n    meta = json.loads(index_json.read_text())\n    assert meta[\"config\"][\"item_loader\"] == \"TokensLoader\"\n"
  },
  {
    "path": "tests/data/test_tinyllama.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nfrom unittest import mock\n\nimport pytest\nfrom litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset\nfrom torch.utils.data import DataLoader\n\nfrom litgpt.data import TinyLlama\n\n\n@mock.patch(\"litdata.streaming.dataset.subsample_streaming_dataset\", return_value=([], []))\ndef test_tinyllama(_, tmp_path):\n    data = TinyLlama(data_path=(tmp_path / \"data\"))\n    assert data.seq_length == 2048\n    assert data.batch_size == 1\n\n    data.connect(batch_size=2, max_seq_length=1024)\n    assert data.seq_length == 1025\n    assert data.batch_size == 2\n\n    with pytest.raises(FileNotFoundError, match=\"The directory .*data/slimpajama/train does not exist\"):\n        data.prepare_data()\n\n    (tmp_path / \"data\" / \"slimpajama\" / \"train\").mkdir(parents=True)\n    (tmp_path / \"data\" / \"slimpajama\" / \"val\").mkdir(parents=True)\n    (tmp_path / \"data\" / \"starcoder\").mkdir(parents=True)\n\n    data.prepare_data()\n    data.setup()\n\n    train_dataloader = data.train_dataloader()\n    assert isinstance(train_dataloader, StreamingDataLoader)\n    assert isinstance(train_dataloader.dataset, CombinedStreamingDataset)\n\n    val_dataloader = data.val_dataloader()\n    assert isinstance(val_dataloader, DataLoader)\n    assert isinstance(val_dataloader.dataset, StreamingDataset)\n\n    # has attributes from super class `LightningDataModule`\n    assert data.prepare_data_per_node\n"
  },
  {
    "path": "tests/data/test_tinystories.py",
    "content": "import json\n\nimport pytest\nimport torch\nfrom litdata import optimize\nfrom litdata.streaming import StreamingDataset, TokensLoader\nfrom torch.utils._pytree import tree_map\n\n\ndef tokenize(data):\n    for story in data:\n        yield torch.tensor(story)\n\n\ndef fake_chunk(path, data):\n    optimize(\n        fn=tokenize,\n        inputs=[data] * len(data),\n        output_dir=str(path),\n        num_workers=1,\n        chunk_bytes=\"200MB\",\n        item_loader=TokensLoader(),\n    )\n\n\n@pytest.mark.parametrize(\n    (\"max_seq_len\", \"expected\"),\n    [\n        (2, [[0, 23, 15], [63, 0, 73], [5, 0, 1], [1999, 0, 13]]),\n        (5, [[0, 23, 15, 63, 0, 73], [5, 0, 1, 1999, 0, 13]]),\n        (6, [[0, 23, 15, 63, 0, 73, 5]]),\n        (7, [[0, 23, 15, 63, 0, 73, 5, 0]]),\n    ],\n)\ndef test_pretok_dataset(tmp_path, max_seq_len, expected):\n    fake_data = [0, 23, 15, 63, 0, 73, 5, 0, 1, 1999, 0, 13]\n    assert len(fake_data) == 12\n    fake_chunk(tmp_path, [fake_data])\n\n    dataset = StreamingDataset(\n        input_dir=str(tmp_path), item_loader=TokensLoader(block_size=max_seq_len + 1), shuffle=False, drop_last=False\n    )\n    actual = tree_map(torch.Tensor.tolist, list(dataset))\n    assert actual == expected\n\n\ndef test_tokenize(tmp_path, monkeypatch):\n    from litgpt.data.tinystories import tokenize\n\n    story1, story2 = \"foo bar\", \"    fun    \"\n    data = [{\"story\": story1}, {\"story\": story2}]\n    shard_path = tmp_path / \"data.json\"\n    with open(shard_path, \"w\", encoding=\"utf-8\") as f:\n        json.dump(data, f)\n\n    class Tokenizer:\n        bos_id = 0\n\n        def encode(self, text, bos, eos):\n            assert bos\n            assert not eos\n            return [self.bos_id] + [ord(c) for c in text]\n\n    monkeypatch.setenv(\"DATA_OPTIMIZER_GLOBAL_RANK\", \"0\")\n    monkeypatch.setenv(\"DATA_OPTIMIZER_NUM_WORKERS\", \"1\")\n    data = tokenize(str(shard_path), Tokenizer())\n    assert 
list(data) == [[0, 102, 111, 111, 32, 98, 97, 114], [0, 102, 117, 110]]\n\n\ndef test_tinystories_datamodule(tmp_path):\n    from litgpt.data.tinystories import TinyStories\n\n    data_dir = tmp_path / \"tinystories\"\n\n    datamodule = TinyStories(data_dir, seed=42, num_workers=1)\n    datamodule.connect(max_seq_length=2)\n\n    # simulate `datamodule.prepare_data`\n    train_data_dir = data_dir / \"train\"\n    train_data_dir.mkdir(parents=True)\n    fake_chunk(train_data_dir, [[12], [0, 23, 15, 63, 0], [73, 5, 0, 1, 1999, 0, 13]])\n\n    datamodule.setup()\n\n    tr_dataloader = datamodule.train_dataloader()\n    tr_dataloader.shuffle = False\n\n    actual = tree_map(torch.Tensor.tolist, list(tr_dataloader))\n\n    # there is 1 sample per index in the data (13)\n    assert actual == [\n        [[73, 5, 0]],\n        [[12, 0, 23]],\n        [[5, 0, 1]],\n        [[0, 73, 5]],\n        [[1999, 0, 13]],\n        [[0, 1, 1999]],\n        [[1, 1999, 0]],\n        [[0, 23, 15]],\n        [[13, 12, 0]],\n        [[63, 0, 73]],\n        [[23, 15, 63]],\n        [[15, 63, 0]],\n        [[0, 13, 12]],\n    ]\n"
  },
  {
    "path": "tests/ext_thunder/__init__.py",
    "content": "import sys\nfrom pathlib import Path\n\n# support running without installing as a package, adding extensions to the Python path\nwd = Path(__file__).parent.parent.parent.resolve()\nif wd.is_dir():\n    sys.path.append(str(wd))\nelse:\n    import warnings\n\n    warnings.warn(f\"Could not find extensions directory at {wd}\")\n"
  },
  {
    "path": "tests/ext_thunder/test_thunder_distributed.py",
    "content": "import os\nimport sys\nfrom pathlib import Path\nfrom typing import Optional, Tuple, Union\n\nimport pytest\nimport torch\nfrom lightning.fabric import Fabric\nfrom lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3\n\nfrom litgpt.constants import _THUNDER_AVAILABLE\nfrom litgpt.utils import _RunIf\n\n# support running without installing as a package\nwd = Path(__file__).parent.parent.resolve()\nsys.path.append(str(wd))\n\nif _THUNDER_AVAILABLE:\n    from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy\n    from extensions.thunder.strategies.thunder_fsdp import ThunderFSDPStrategy\n\n\n@_RunIf(thunder=True)\ndef test_thunder_strategy_ddp_input_parsing():\n    with pytest.raises(ValueError, match=\"doesn't have an effect with `jit=False\"):\n        ThunderDDPStrategy(jit=False, executors=(\"python\",))\n\n\n@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)\n@pytest.mark.parametrize(\"choice\", [\"ddp\", \"fsdp\"])\n@pytest.mark.xfail(TypeError, reason=\"temporally disabled until resolved with Thunder\")\ndef test_no_backward_sync_thunder(choice):\n    if choice == \"ddp\":\n        strategy = ThunderDDPStrategy()\n    elif choice == \"fsdp\":\n        strategy = ThunderFSDPStrategy()\n    else:\n        raise ValueError(f\"Invalid choice: {choice}\")\n\n    fabric = Fabric(devices=2, accelerator=\"cuda\", strategy=strategy)\n    fabric.launch()\n\n    # account for sharding in the case of FSDP\n    out_features = 1 if \"ddp\" in choice else fabric.world_size\n\n    model = torch.nn.Linear(1, out_features, bias=False, device=fabric.device)\n    x = torch.randn(1, 1, device=fabric.device)\n    model = fabric.setup(model)\n\n    # 6 iters, 3 grad accumulation iters\n    for i, enabled in enumerate((True, True, False, True, True, False), 1):\n        x = torch.tensor([i * (fabric.local_rank + 1)], device=fabric.device, dtype=torch.float32)\n\n        with fabric.no_backward_sync(model, enabled):\n            y = 
model(x)\n            fabric.backward(y.sum())\n        if not enabled:\n            # Math for the first 3 iters\n            #\n            # DistributedDataParallel\n            # (1*1+2*1+3*1 + 1*2+2*2+3*2) / 2       = 9\n            #  ^^^^^^^^^^^   ^^^^^^^^^^^  ^^^\n            #  rank0         rank1        allreduce\n            #\n            # thunder.distributed.ddp\n            # ((1*1+2*1) + (1*2+2*2)) / 2        + (3*1 + 3*2)  / 2        = 9\n            #   ^^^^^^^     ^^^^^^^   ^^^           ^^^   ^^^   ^^^\n            #   rank0       rank1     allreduce1    rank0 rank1 allreduce2\n            assert model.weight.grad.shape.numel() == 1, model.weight.grad.shape\n            assert model.weight.grad.item() == (9.0 if i == 3 else 22.5)\n            assert not hasattr(model.weight, \"_thunder_fsdp_unsharded_grad\")\n            model.weight.grad = None\n        elif choice == \"fsdp\":\n            assert model.weight._thunder_fsdp_unsharded_grad.shape == (2, 1)\n            assert model.weight.grad is None\n\n\n@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)\n@pytest.mark.parametrize(\"jit\", (False, True))\n@pytest.mark.xfail(TypeError, reason=\"temporally disabled until resolved with Thunder\")\ndef test_jit_ddp_before_setup(jit):\n    import thunder\n\n    fabric = Fabric(devices=2, accelerator=\"cuda\", strategy=ThunderDDPStrategy(jit=jit))\n    fabric.launch()\n\n    x = torch.randn(1, 1, device=fabric.device)\n    model = torch.nn.Linear(1, 2, bias=False, device=fabric.device)\n\n    tmodel = thunder.jit(model)\n    fmodel = fabric.setup(tmodel)\n    fmodel(x)\n\n    assert \"all_reduce\" in thunder.last_backward_traces(tmodel)[-1].python()\n\n\n@_RunIf(min_cuda_gpus=1, thunder=True)\ndef test_strategy_ddp_setup_already_traced():\n    import thunder\n\n    device = torch.device(\"cuda\")\n    x = torch.randn(1, 1, device=device)\n    model = torch.nn.Linear(1, 2, bias=False, device=device)\n\n    strategy = ThunderDDPStrategy()\n\n    
tmodel = thunder.jit(model)\n    tmodel(x)\n    with pytest.raises(RuntimeError, match=\"already called\"):\n        strategy.setup_module(tmodel)\n\n\n@_RunIf(thunder=True)\ndef test_thunder_strategy_fsdp_input_parsing():\n    from thunder.distributed import FSDPBucketingStrategy, FSDPType\n\n    strategy = ThunderFSDPStrategy(bucketing_strategy=\"BlOcK\", executors=(\"python\",), sharding_strategy=\"zero3\")\n\n    assert strategy.bucketing_strategy is FSDPBucketingStrategy.BLOCK\n    assert strategy.sharding_strategy is FSDPType.ZERO3\n\n    with pytest.raises(ValueError, match=\"doesn't have an effect with `jit=False\"):\n        ThunderFSDPStrategy(jit=False, executors=(\"python\",))\n\n\n@_RunIf(thunder=True)\ndef test_save_checkpoint_invalid_settings_raise(tmp_path):\n    strategy = ThunderFSDPStrategy(state_dict_type=\"full\")\n    with pytest.raises(TypeError, match=\"not supported\"):\n        strategy.save_checkpoint(tmp_path, {}, storage_options=object())\n\n    with pytest.raises(IsADirectoryError, match=\"path exists\"):\n        strategy.save_checkpoint(tmp_path, {})\n\n    model = torch.nn.Linear(1, 1)\n    with pytest.raises(ValueError, match=\"Could not find\"):\n        strategy.save_checkpoint(tmp_path / \"foo\", {})\n\n    model.use_fsdp = True\n    with pytest.raises(ValueError, match=\"Found multiple\"):\n        strategy.save_checkpoint(tmp_path / \"foo\", {\"model1\": model, \"model2\": model})\n\n    with pytest.raises(ValueError, match=\"at least a model\"):\n        strategy.load_checkpoint(tmp_path / \"foo\", {})\n\n    with pytest.raises(ValueError, match=\"must be a single file\"):\n        strategy.load_checkpoint(tmp_path, model)\n\n    optimizer = torch.optim.Adam(model.parameters())\n    with pytest.raises(NotImplementedError, match=\"not supported\"):\n        strategy.load_checkpoint(tmp_path, optimizer)\n\n    with pytest.raises(ValueError, match=\"Found multiple\"):\n        strategy.load_checkpoint(tmp_path / \"foo\", 
{\"model1\": model, \"model2\": model})\n\n    with pytest.raises(ValueError, match=\"Could not find\"):\n        strategy.load_checkpoint(tmp_path / \"foo\", {\"foo\": 1})\n\n\nclass Submodule(torch.nn.Module):\n    def __init__(self, h: int):\n        super().__init__()\n        self.l = torch.nn.Linear(4, h * 2, bias=False)\n\n    def forward(self, x):\n        # defined just because preprocessing fails otherwise\n        ...\n\n\nclass MyModel(torch.nn.Module):\n    def __init__(self, h: int):\n        super().__init__()\n        self.register_buffer(\"buf\", torch.tensor(0))\n        self.l = torch.nn.Linear(2, h)\n        self.inner = Submodule(h)\n\n    def forward(self):\n        # defined just because preprocessing fails otherwise\n        ...\n\n    def reset_parameters(self):\n        self.buf = torch.empty_like(self.buf)\n\n\n@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)\n@pytest.mark.xfail(TypeError, reason=\"temporally disabled until resolved with Thunder\")\ndef test_materialize_meta_tensors():\n    strategy = ThunderFSDPStrategy()\n    fabric = Fabric(accelerator=\"cuda\", devices=2, strategy=strategy)\n    fabric.launch()\n\n    with fabric.init_module(empty_init=True):\n        model = MyModel(2)\n\n    model = fabric.setup(model)\n    # all parameters were moved\n    assert len(list(model.parameters())) == 3\n    assert all(p.device.type == \"cuda\" for p in model.parameters())\n    # buffers were moved too\n    assert model.buf.device.type == \"cuda\"\n\n\nclass StatefulThing:\n    def state_dict(self):\n        return {\"thing\": 1}\n\n    def load_state_dict(self, state_dict):\n        assert state_dict == self.state_dict()\n\n\nclass TensorLike:\n    def __init__(self, device: Optional[Union[str, torch.device]] = None, shape: Optional[Tuple[int, ...]] = None):\n        self.device = torch.device(device) if device is not None else None\n        self.shape = torch.Size(shape) if shape is not None else None\n\n    def __eq__(self, 
other):\n        return (\n            isinstance(other, torch.Tensor)\n            and (self.device is None or other.device == self.device)\n            and (self.shape is None or other.shape == self.shape)\n        )\n\n\n@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)\n@pytest.mark.xfail(TypeError, reason=\"temporally disabled until resolved with Thunder\")\ndef test_save_load_full_checkpoint(tmp_path):\n    strategy = ThunderFSDPStrategy(state_dict_type=\"full\", broadcast_from=0)\n    fabric = Fabric(accelerator=\"cuda\", devices=2, strategy=strategy)\n    fabric.launch()\n\n    model = MyModel(4)\n    expected = model.state_dict()\n\n    # save a sharded model\n    model = fabric.setup(model)\n    state = {\"model\": model, \"stateful\": StatefulThing(), \"primitive\": 123}\n    checkpoint_path = tmp_path / \"foo\"\n    fabric.save(checkpoint_path, state)\n\n    # assert the file contents\n    if fabric.global_rank == 0:\n        checkpoint = torch.load(checkpoint_path)\n        # cpu_offload is enabled by default\n        assert checkpoint == {\n            \"model\": {\n                \"buf\": TensorLike(\"cpu\", tuple()),\n                \"inner.l.weight\": TensorLike(\"cpu\", (8, 4)),\n                \"l.bias\": TensorLike(\"cpu\", (4,)),\n                \"l.weight\": TensorLike(\"cpu\", (4, 2)),\n            },\n            \"stateful\": {\"thing\": 1},\n            \"primitive\": 123,\n        }\n        torch.testing.assert_close(checkpoint[\"model\"], expected)\n\n    # load its weights into a different sharded model\n    model = MyModel(4)\n    model = fabric.setup(model)\n    state = {\"model\": model, \"stateful\": StatefulThing(), \"primitive\": 321}\n    fabric.load(checkpoint_path, state)\n\n    from thunder.distributed import _unshard_params\n\n    # unshard this model's parameters to compare with the original state dict before sharding\n    _unshard_params(model, model.process_group_for_ddp, True)\n    # we loaded rank 0's weights, 
so this would fail in the other ranks\n    if fabric.global_rank == 0:\n        actual = model.state_dict()\n        # `_unshard_params` doesn't offload buffers at the moment\n        assert actual[\"buf\"].device.type == \"cuda\"\n        actual[\"buf\"] = actual[\"buf\"].to(device=\"cpu\")\n        torch.testing.assert_close(actual, expected)\n    assert state[\"primitive\"] == 123\n\n\n@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)\n@pytest.mark.xfail(TypeError, reason=\"temporally disabled until resolved with Thunder\")\ndef test_load_full_checkpoint_only_model(tmp_path):\n    strategy = ThunderFSDPStrategy()\n    fabric = Fabric(accelerator=\"cuda\", devices=2, strategy=strategy)\n    fabric.launch()\n\n    checkpoint_path = tmp_path / \"foo\"\n    checkpoint_path = fabric.broadcast(checkpoint_path)\n    if fabric.global_rank == 0:\n        model = MyModel(4)\n        expected = model.state_dict()\n        torch.save(expected, checkpoint_path)\n    fabric.barrier()\n    expected = torch.load(checkpoint_path)\n\n    # before sharding\n    model = MyModel(4)\n    fabric.load_raw(checkpoint_path, model)\n    torch.testing.assert_close(model.state_dict(), expected)\n\n    # after sharding\n    model = MyModel(4)\n    model = fabric.setup(model)\n    fabric.load_raw(checkpoint_path, model)\n    from thunder.distributed import _unshard_params\n\n    # unshard this model's parameters to compare with the original state dict before sharding\n    _unshard_params(model, model.process_group_for_ddp, True)\n    actual = model.state_dict()\n    # `_unshard_params` doesn't offload buffers at the moment\n    assert actual[\"buf\"].device.type == \"cuda\"\n    actual[\"buf\"] = actual[\"buf\"].to(device=\"cpu\")\n    torch.testing.assert_close(actual, expected)\n\n\ndef distributed_ckpt_to_regular(path):\n    \"\"\"From ``torch.distributed.checkpoint.format_utils.dcp_to_torch_save``.\"\"\"\n    from torch.distributed.checkpoint import FileSystemReader\n    from 
torch.distributed.checkpoint.state_dict_loader import _load_state_dict\n\n    if _TORCH_GREATER_EQUAL_2_3:\n        from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner\n    else:\n        from torch.distributed.checkpoint._traverse import set_element\n        from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner\n        from torch.distributed.checkpoint.metadata import TensorStorageMetadata\n\n        class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):\n            def __init__(self, *args, **kwargs):\n                super().__init__(*args, **kwargs)\n\n            def set_up_planner(self, state_dict, metadata, is_coordinator):\n                assert not state_dict\n                # rebuild the state dict from the metadata\n                for k, v in metadata.state_dict_metadata.items():\n                    if isinstance(v, TensorStorageMetadata):\n                        v = torch.empty(v.size, dtype=v.properties.dtype)\n                    if k in metadata.planner_data:\n                        set_element(state_dict, metadata.planner_data[k], v)\n                    else:\n                        state_dict[k] = v\n                super().set_up_planner(state_dict, metadata, is_coordinator)\n\n    state_dict = {}\n    storage_reader = FileSystemReader(path)\n    _load_state_dict(state_dict, storage_reader=storage_reader, planner=_EmptyStateDictLoadPlanner(), no_dist=True)\n    return state_dict\n\n\n@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)\n@pytest.mark.xfail(TypeError, reason=\"temporally disabled until resolved with Thunder\")\ndef test_save_load_sharded_checkpoint(tmp_path):\n    strategy = ThunderFSDPStrategy(state_dict_type=\"sharded\", broadcast_from=0)\n    fabric = Fabric(accelerator=\"cuda\", devices=2, strategy=strategy)\n    fabric.launch()\n\n    model = MyModel(4)\n    expected = model.state_dict()\n\n    # save a sharded model\n    model = fabric.setup(model)\n    state = 
{\"model\": model, \"stateful\": StatefulThing(), \"primitive\": 123}\n    fabric.save(tmp_path, state)\n\n    # assert the file contents\n    if fabric.global_rank == 0:\n        assert set(os.listdir(tmp_path)) == {\"meta.pt\", \"__1_0.distcp\", \"__0_0.distcp\", \".metadata\"}\n\n        metadata = torch.load(tmp_path / \"meta.pt\")\n        assert metadata == {\"stateful\": {\"thing\": 1}, \"primitive\": 123}\n\n        checkpoint = distributed_ckpt_to_regular(tmp_path)\n        # cpu_offload is enabled by default\n        assert checkpoint == {\n            \"model\": {\n                \"buf\": TensorLike(\"cpu\", tuple()),\n                \"inner.l.weight\": TensorLike(\"cpu\", (8, 4)),\n                \"l.bias\": TensorLike(\"cpu\", (4,)),\n                \"l.weight\": TensorLike(\"cpu\", (4, 2)),\n            }\n        }\n        torch.testing.assert_close(checkpoint[\"model\"], expected)\n\n    # load its weights into a different sharded model\n    model = MyModel(4)\n    model = fabric.setup(model)\n    state = {\"model\": model, \"stateful\": StatefulThing(), \"primitive\": 321}\n    fabric.load(tmp_path, state)\n\n    from thunder.distributed import _unshard_params\n\n    # unshard this model's parameters to compare with the original state dict before sharding\n    _unshard_params(model, model.process_group_for_ddp, True)\n    # we loaded rank 0's weights, so this would fail in the other ranks\n    if fabric.global_rank == 0:\n        actual = model.state_dict()\n        # `_unshard_params` doesn't offload buffers at the moment\n        assert actual[\"buf\"].device.type == \"cuda\"\n        actual[\"buf\"] = actual[\"buf\"].to(device=\"cpu\")\n        torch.testing.assert_close(actual, expected)\n    assert state[\"primitive\"] == 123\n\n\n@_RunIf(min_cuda_gpus=2, thunder=True, standalone=True)\n@pytest.mark.parametrize(\"jit\", (False, True))\n@pytest.mark.xfail(TypeError, reason=\"temporally disabled until resolved with Thunder\")\ndef 
test_jit_fsdp_before_setup(jit):\n    import thunder\n\n    fabric = Fabric(devices=2, accelerator=\"cuda\", strategy=ThunderFSDPStrategy(jit=jit))\n    fabric.launch()\n\n    x = torch.randn(1, 1, device=fabric.device)\n    model = torch.nn.Linear(1, 2, bias=False, device=fabric.device)\n\n    tmodel = thunder.jit(model)\n    fmodel = fabric.setup(tmodel)\n    fmodel(x)\n\n    assert \"all_gather\" in thunder.last_traces(tmodel)[-1].python()\n\n\n@_RunIf(min_cuda_gpus=1, thunder=True)\ndef test_strategy_fsdp_setup_already_traced():\n    import thunder\n\n    device = torch.device(\"cuda\")\n    x = torch.randn(1, 1, device=device)\n    model = torch.nn.Linear(1, 2, bias=False, device=device)\n\n    strategy = ThunderFSDPStrategy()\n\n    tmodel = thunder.jit(model)\n    tmodel(x)\n    with pytest.raises(RuntimeError, match=\"already called\"):\n        strategy.setup_module(tmodel)\n"
  },
  {
    "path": "tests/ext_thunder/test_thunder_networks.py",
    "content": "\"\"\"Run thunder tests as part of LitGPT CI\"\"\"\n\nfrom litgpt.constants import _THUNDER_AVAILABLE\n\nif _THUNDER_AVAILABLE:\n    from thunder.tests.test_networks import *  # noqa: F403\nelse:\n    print(\"Skipping test_thunder_networks.py (thunder not available)\")\n"
  },
  {
    "path": "tests/ext_thunder/test_thunder_pretrain.py",
    "content": "import os\nfrom contextlib import redirect_stdout\nfrom io import StringIO\nfrom unittest.mock import Mock\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom litgpt import Config\nfrom litgpt.args import EvalArgs, TrainArgs\nfrom litgpt.constants import _THUNDER_AVAILABLE\nfrom litgpt.utils import _RunIf\n\nif _THUNDER_AVAILABLE:\n    import extensions.thunder.pretrain as thunder_pretrain\n\n\n@_RunIf(min_cuda_gpus=1, thunder=True)\ndef test_pretrain_thunder(tmp_path, monkeypatch):\n    model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)\n\n    dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])\n    dataloader = DataLoader(dataset)\n    monkeypatch.setattr(thunder_pretrain, \"get_dataloaders\", Mock(return_value=(dataloader, dataloader)))\n    monkeypatch.setattr(thunder_pretrain, \"save_hyperparameters\", Mock())\n\n    out_dir = tmp_path / \"out\"\n    stdout = StringIO()\n    with redirect_stdout(stdout):\n        thunder_pretrain.setup(\n            devices=1,\n            model_config=model_config,\n            out_dir=out_dir,\n            train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0),\n            eval=EvalArgs(interval=1, max_iters=1),\n            optimizer=\"AdamW\",\n        )\n\n    out_dir_contents = set(os.listdir(out_dir))\n    checkpoint_dirs = {\"step-00000001\", \"step-00000002\", \"step-00000003\", \"step-00000004\"}\n    assert checkpoint_dirs.issubset(out_dir_contents)\n    assert all((out_dir / p).is_dir() for p in checkpoint_dirs)\n    for checkpoint_dir in checkpoint_dirs:\n        # the `tokenizer_dir` is None by default, so only 'lit_model.pth' shows here\n        assert set(os.listdir(out_dir / checkpoint_dir)) == {\"lit_model.pth\", \"model_config.yaml\"}\n\n    assert (out_dir / \"logs\" / \"tensorboard\" / \"version_0\").is_dir()\n\n    logs = stdout.getvalue()\n    assert logs.count(\"(step)\") == 4\n    
assert logs.count(\"val loss\") == 4\n    assert \"Total parameters: 1,888\" in logs\n"
  },
  {
    "path": "tests/ext_thunder/test_unsloth_executor.py",
    "content": "import pytest\nimport torch\n\nfrom litgpt import GPT, Config\nfrom litgpt.model import apply_rope, build_rope_cache\nfrom litgpt.utils import _RunIf, chunked_cross_entropy\n\n\n@_RunIf(min_cuda_gpus=1, thunder=True)\n@pytest.mark.parametrize(\"reduction\", [\"none\", \"mean\"])\ndef test_unsloth_cross_entropy(reduction):\n    import thunder\n\n    from extensions.thunder.unsloth.executor import unsloth_ex\n\n    logits = torch.randn(64, 128, device=\"cuda\", requires_grad=True)\n    labels = torch.randint(128, (64,), device=\"cuda\")\n\n    def foo(logits, labels):\n        # this is the variant supported by unsloth.\n        # if different arguments are used, the implementation would not be lowered to unsloth and instead would get\n        # decomposed\n        return torch.nn.functional.cross_entropy(logits, labels, reduction=reduction, ignore_index=-100)\n\n    cfoo = thunder.jit(foo, executors=[unsloth_ex])\n    actual = cfoo(logits, labels)\n    trace_str = str(thunder.last_traces(cfoo)[-1])\n    assert \"unsloth_cross_entropy\" in trace_str and \"backward\" not in trace_str\n    trace_str = str(thunder.last_backward_traces(cfoo)[-1])\n    assert \"unsloth_cross_entropy_backward\" in trace_str\n\n    expected = foo(logits, labels)\n    torch.testing.assert_close(actual, expected)\n\n    (actual_grad,) = torch.autograd.grad(actual.sum(), logits)\n    trace_str = str(thunder.last_backward_traces(cfoo)[-1])\n    assert \"unsloth_cross_entropy_backward\" in trace_str\n    out = foo(logits, labels)\n    assert logits.grad is None\n    (expected_grad,) = torch.autograd.grad(out.sum(), logits)\n    torch.testing.assert_close(actual_grad, expected_grad)\n\n\n@pytest.mark.skip(reason=\"out of date\")\n@_RunIf(min_cuda_gpus=1, thunder=True)\ndef test_unsloth_rope():\n    import thunder\n\n    from extensions.thunder.unsloth.executor import unsloth_ex\n\n    B, nh, T, hs = 2, 32, 64, 16\n    cos, sin = build_rope_cache(T, hs, device=\"cuda\")\n    cos = 
cos.unsqueeze(0)\n    sin = sin.unsqueeze(0)\n    q = torch.rand((B, nh, T, hs), device=\"cuda\", requires_grad=True)\n\n    def foo(x, cos, sin):\n        return apply_rope(x, cos, sin)\n\n    cfoo = thunder.jit(foo, executors=[unsloth_ex])\n    actual = cfoo(q, cos, sin)\n    trace_str = str(thunder.last_traces(cfoo)[-1])\n    assert \"unsloth_apply_rope\" in trace_str and \"backward\" not in trace_str\n    trace_str = str(thunder.last_backward_traces(cfoo)[-1])\n    assert \"unsloth_apply_rope_backward\" in trace_str\n\n    expected = foo(q, cos, sin)\n    torch.testing.assert_close(actual, expected)\n\n    (actual_grad,) = torch.autograd.grad(actual.sum(), q)\n    (expected_grad,) = torch.autograd.grad(expected.sum(), q)\n    torch.testing.assert_close(actual_grad, expected_grad)\n\n\n@_RunIf(min_cuda_gpus=1, thunder=True)\ndef test_unsloth_swiglu():\n    import thunder\n\n    from extensions.thunder.unsloth.executor import ThunderLLaMAMLP, unsloth_ex\n    from litgpt import Config\n    from litgpt.model import LLaMAMLP\n\n    config = Config.from_name(\"Llama-2-7b-hf\")\n    with torch.device(\"cuda\"):\n        x = torch.randn(2, 16, config.n_embd, requires_grad=True)\n        mlp = LLaMAMLP(config)\n    # monkeypatching was successful\n    assert isinstance(mlp, ThunderLLaMAMLP)\n\n    cmlp = thunder.jit(mlp, executors=[unsloth_ex])\n    actual = cmlp(x)\n    trace_str = str(thunder.last_traces(cmlp)[-1])\n    assert \"unsloth_swiglu\" in trace_str and \"backward\" not in trace_str\n    trace_str = str(thunder.last_backward_traces(cmlp)[-1])\n    assert \"unsloth_swiglu_backward\" in trace_str\n\n    expected = mlp(x)\n    torch.testing.assert_close(actual, expected)\n\n    (actual_grad,) = torch.autograd.grad(actual.sum(), x)\n    (expected_grad,) = torch.autograd.grad(expected.sum(), x)\n    torch.testing.assert_close(actual_grad, expected_grad)\n\n\n@_RunIf(min_cuda_gpus=1, thunder=True)\ndef test_unsloth_gpt():\n    import thunder\n\n    from 
extensions.thunder.unsloth.executor import unsloth_ex\n\n    def forward_and_loss(model, input_ids, targets):\n        logits = model(input_ids)\n        return chunked_cross_entropy(logits, targets, chunk_size=0)\n\n    cfn = thunder.jit(forward_and_loss, executors=[unsloth_ex])\n\n    device = torch.device(\"cuda\")\n    config = Config(\n        vocab_size=320,\n        padding_multiple=64,\n        n_layer=2,\n        n_head=4,\n        n_embd=64,\n        rotary_percentage=1.0,\n        parallel_residual=False,\n        bias=False,\n        norm_class_name=\"RMSNorm\",\n        mlp_class_name=\"LLaMAMLP\",\n        intermediate_size=1376,\n    )\n    with device:\n        model = GPT(config)\n        input_ids = torch.randint(1, 10, (2, 3))\n        targets = torch.randint(0, 10, (2, 3))\n\n    loss = cfn(model, input_ids, targets)\n    assert isinstance(loss, torch.Tensor)\n\n    fwd = thunder.last_traces(cfn)\n    bwd = thunder.last_backward_traces(cfn)\n    fwd_str, bwd_str = fwd[-1].python(), bwd[-1].python()\n\n    assert \"unsloth_cross_entropy\" in fwd_str\n    assert \"unsloth_cross_entropy_backward\" in bwd_str\n    assert \"unsloth_apply_rope\" in fwd_str\n    assert \"unsloth_apply_rope_backward\" in bwd_str\n    assert \"unsloth_swiglu\" in fwd_str\n    assert \"unsloth_swiglu_backward\" in bwd_str\n"
  },
  {
    "path": "tests/generate/__init__.py",
    "content": ""
  },
  {
    "path": "tests/generate/test_adapter.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\nimport re\nimport subprocess\nimport sys\nfrom contextlib import redirect_stderr, redirect_stdout\nfrom io import StringIO\nfrom unittest.mock import ANY, Mock, call\n\nimport pytest\nimport torch\nimport yaml\n\nskip_in_ci_on_macos = pytest.mark.skipif(\n    sys.platform == \"darwin\" and os.getenv(\"GITHUB_ACTIONS\") == \"true\",\n    reason=\"Skipped on macOS in CI environment because CI machine does not have enough memory to run this test.\",\n)\n\n\n@skip_in_ci_on_macos\n@pytest.mark.parametrize(\"version\", (\"v1\", \"v2\"))\ndef test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like):\n    if version == \"v1\":\n        import litgpt.generate.adapter as generate\n    else:\n        import litgpt.generate.adapter_v2 as generate\n\n    config_path = fake_checkpoint_dir / \"model_config.yaml\"\n    config = {\"block_size\": 128, \"vocab_size\": 50, \"n_layer\": 2, \"n_head\": 4, \"n_embd\": 8, \"rotary_percentage\": 1}\n    config_path.write_text(yaml.dump(config))\n\n    monkeypatch.setattr(generate, \"lazy_load\", Mock())\n    monkeypatch.setattr(generate.GPT, \"load_state_dict\", Mock())\n    tokenizer_mock = Mock()\n    tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]])\n    tokenizer_mock.return_value.decode.return_value = \"### Response:foo bar baz\"\n    monkeypatch.setattr(generate, \"Tokenizer\", tokenizer_mock)\n    generate_mock = Mock()\n    generate_mock.return_value = torch.tensor([[3, 2, 1]])\n    monkeypatch.setattr(generate, \"generate\", generate_mock)\n\n    num_samples = 1\n    out, err = StringIO(), StringIO()\n    with redirect_stdout(out), redirect_stderr(err):\n        generate.main(temperature=2.0, top_k=2, top_p=0.9, checkpoint_dir=fake_checkpoint_dir)\n\n    assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples\n    assert 
torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)\n    assert (\n        generate_mock.mock_calls\n        == [call(ANY, tensor_like, 101, temperature=2.0, top_k=2, top_p=0.9, eos_id=ANY)] * num_samples\n    )\n\n    expected_output = \"foo bar baz\\n\" * num_samples\n    # Allow for the config to be printed before the expected repeated strings.\n    pattern = rf\".*^{re.escape(expected_output.strip())}$.*\"\n    assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE)\n\n    err_value = err.getvalue()\n    expected_parts = [\n        \"'padded_vocab_size': 512\",\n        \"'n_layer': 2\",\n        \"'n_head': 4\",\n        \"'head_size': 2\",\n        \"'n_embd': 8\",\n    ]\n    assert all(part in err_value for part in expected_parts)\n\n\n@pytest.mark.parametrize(\"version\", (\"\", \"_v2\"))\ndef test_cli(version):\n    args = [\"litgpt\", f\"generate_adapter{version}\", \"-h\"]\n    output = subprocess.check_output(args)\n    output = str(output.decode())\n    assert \"For models finetuned with\" in output\n"
  },
  {
    "path": "tests/generate/test_main.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\nimport re\nimport subprocess\nimport sys\nfrom contextlib import redirect_stderr, redirect_stdout\nfrom io import StringIO\nfrom unittest import mock\nfrom unittest.mock import ANY, Mock, call\n\nimport pytest\nimport torch\nimport yaml\n\nimport litgpt.generate.base as generate\nfrom litgpt import GPT, Config\nfrom litgpt.generate.base import sample\n\nskip_in_ci_on_macos = pytest.mark.skipif(\n    sys.platform == \"darwin\" and os.getenv(\"GITHUB_ACTIONS\") == \"true\",\n    reason=\"Skipped on macOS in CI environment because CI machine does not have enough memory to run this test.\",\n)\n\n\n@pytest.mark.parametrize(\n    \"max_seq_length\", (pytest.param(10, marks=pytest.mark.xfail(raises=NotImplementedError, strict=True)), 20 + 5)\n)\ndef test_generate(max_seq_length):\n    import lightning as L\n\n    L.seed_everything(1234)\n\n    T = 5\n    input_idx = torch.arange(0, T)\n\n    config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8)\n    model = GPT(config)\n    model.max_seq_length = max_seq_length\n    model.set_kv_cache(batch_size=1)\n    max_new_tokens = 20\n\n    multinomial_results = []\n\n    def multinomial(*args, **kwargs):\n        out = torch.multinomial(*args, **kwargs, num_samples=1)\n        multinomial_results.append(out)\n        return out\n\n    with mock.patch(\"litgpt.generate.base.multinomial_num_samples_1\", multinomial):\n        out = generate.generate(model, input_idx, T + max_new_tokens, top_k=1)\n\n    assert out.size(0) == T + max_new_tokens, (out.size(0), T + max_new_tokens)\n    multinomial_results = torch.hstack(multinomial_results)\n    expected = torch.cat((input_idx, multinomial_results))\n    assert out.shape == expected.shape, (out.shape, expected.shape)\n    torch.testing.assert_close(out, expected)\n\n\n@skip_in_ci_on_macos\ndef test_main(fake_checkpoint_dir, monkeypatch, tensor_like):\n 
   config_path = fake_checkpoint_dir / \"model_config.yaml\"\n    config = {\"block_size\": 128, \"vocab_size\": 50, \"n_layer\": 2, \"n_head\": 4, \"n_embd\": 8, \"rotary_percentage\": 1}\n    config_path.write_text(yaml.dump(config))\n\n    module_mock = Mock()\n    module_mock.config.block_size = 128\n    load_mock = Mock()\n    load_mock.return_value = load_mock\n    monkeypatch.setattr(generate, \"load_checkpoint\", load_mock)\n    tokenizer_mock = Mock()\n    tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3])\n    tokenizer_mock.return_value.decode.return_value = \"foo bar baz\"\n    monkeypatch.setattr(generate, \"Tokenizer\", tokenizer_mock)\n    generate_mock = Mock()\n    generate_mock.return_value = torch.tensor([3, 2, 1])\n    monkeypatch.setattr(generate, \"generate\", generate_mock)\n\n    num_samples = 2\n    out, err = StringIO(), StringIO()\n    with redirect_stdout(out), redirect_stderr(err):\n        generate.main(temperature=2.0, top_k=2, top_p=0.9, num_samples=num_samples, checkpoint_dir=fake_checkpoint_dir)\n\n    assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples\n    assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)\n    assert (\n        generate_mock.mock_calls\n        == [call(ANY, tensor_like, 53, temperature=2.0, top_k=2, top_p=0.9, eos_id=tokenizer_mock.return_value.eos_id)]\n        * num_samples\n    )\n    expected_output = \"foo bar baz\\n\" * num_samples\n    # Allow for the config to be printed before the expected repeated strings.\n    pattern = rf\".*^{re.escape(expected_output.strip())}$.*\"\n    assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE)\n\n    err_value = err.getvalue()\n    expected_parts = [\n        \"'padded_vocab_size': 512\",\n        \"'n_layer': 2\",\n        \"'n_head': 4\",\n    ]\n    assert all(part in err_value for part in expected_parts)\n\n\ndef test_cli():\n    args = [\"litgpt\", 
\"generate\", \"-h\"]\n    output = subprocess.check_output(args)\n    output = str(output.decode())\n    assert \"Default generation option\" in output\n\n\n@pytest.mark.parametrize(\"temperature\", (0.0, 1.0, 0.5))\ndef test_sample(temperature):\n    # shape: 2x3x5\n    logits = torch.tensor(\n        [\n            [[24, 4, 98, 77, 47], [65, 70, 32, 67, 24], [92, 32, 88, 36, 62]],\n            [[85, 79, 57, 68, 50], [89, 46, 72, 45, 32], [68, 96, 68, 24, 36]],\n        ],\n        dtype=torch.float32,\n    )\n    token = sample(logits, temperature=temperature, top_p=0.8)\n\n    assert token.shape == (1,)\n    # sample is batch size 1 only for now - this should be [0, 1] once batched generation is supported\n    assert token.tolist() == [0]\n\n\ndef test_generate_different_results_with_different_top_p():\n    config = Config(block_size=128, vocab_size=16, n_layer=1, n_head=4, n_embd=8)\n    model = GPT(config)\n    model.max_seq_length = 50\n    model.set_kv_cache(batch_size=1)\n\n    torch.manual_seed(123)\n    input_idx = torch.randint(10, size=(1,))\n\n    torch.manual_seed(123)\n    output1 = generate.generate(model, input_idx, 20, top_p=1.0)\n    torch.manual_seed(123)\n    output2 = generate.generate(model, input_idx, 20, top_p=0.1)\n\n    assert not torch.equal(output1, output2)\n"
  },
  {
    "path": "tests/generate/test_sequentially.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport itertools\nimport subprocess\nimport sys\nfrom dataclasses import asdict\nfrom pathlib import Path\nfrom re import escape\n\nimport pytest\nimport torch\nimport yaml\nfrom lightning import Fabric\n\nfrom litgpt import Config\nfrom litgpt.generate.sequentially import (\n    chunk_sizes,\n    layer_to_device,\n    replace_device,\n    sequential,\n)\nfrom litgpt.model import GPT, Block\nfrom litgpt.scripts.download import download_from_hub\nfrom litgpt.utils import _RunIf\n\nfrom .utils import find_forward_hooks\n\n\n@pytest.mark.parametrize(\n    (\"n_layer\", \"devices\", \"expected\"),\n    [\n        (6, 1, {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0}),\n        (6, 2, {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}),\n        (6, 3, {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}),\n        (6, 4, {0: 0, 1: 1, 2: 2, 3: 2, 4: 3, 5: 3}),\n        (6, 5, {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4}),\n        (6, 6, {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5}),\n    ],\n)\ndef test_layer_to_device(n_layer, devices, expected):\n    with torch.device(\"meta\"):\n        model = GPT.from_name(\"pythia-14m\", n_layer=n_layer)\n\n    c_sizes = chunk_sizes(n_layer, devices)\n    actual = layer_to_device(model, Block, chunk_sizes=c_sizes)\n    expected = {f\"transformer.h.{i}\": v for i, v in expected.items()}\n    assert actual == expected\n\n\ndef path_to_device(model):\n    return {k: str(v.device) for k, v in itertools.chain(model.named_parameters(), model.named_buffers())}\n\n\ndef test_replace_device():\n    class Submodule(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.register_buffer(\"foo\", torch.tensor(1, device=\"cpu\"))\n            self.register_buffer(\"bar\", torch.tensor(1, device=\"cpu\"))\n\n    class MyModel(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.modules = torch.nn.ModuleDict(\n      
          {\n                    \"module1\": torch.nn.Linear(1, 1, bias=True, device=\"meta\"),\n                    \"module2\": torch.nn.Linear(1, 1, bias=False, device=\"cpu\"),\n                }\n            )\n            self.submodule = Submodule()\n\n    model = MyModel()\n    assert path_to_device(model) == {\n        \"modules.module1.bias\": \"meta\",\n        \"modules.module1.weight\": \"meta\",\n        \"modules.module2.weight\": \"cpu\",\n        \"submodule.bar\": \"cpu\",\n        \"submodule.foo\": \"cpu\",\n    }\n    model = replace_device(model, torch.device(\"cpu\"), torch.device(\"meta\"))\n    assert path_to_device(model) == {\n        \"modules.module1.bias\": \"meta\",\n        \"modules.module1.weight\": \"meta\",\n        \"modules.module2.weight\": \"meta\",\n        \"submodule.bar\": \"meta\",\n        \"submodule.foo\": \"meta\",\n    }\n\n    model = MyModel()\n    model.submodule.bar = model.submodule.bar.to(\"meta\")\n    with pytest.raises(\n        ValueError,\n        match=escape(\"multiple devices: {'submodule.foo': device(type='cpu'), 'submodule.bar': device(type='meta')}\"),\n    ):\n        replace_device(model, torch.device(\"cpu\"), torch.device(\"meta\"))\n\n\ndef _test_model_1device(accelerator):\n    fabric = Fabric(accelerator=accelerator, devices=1)\n    with torch.device(\"meta\"):\n        model = GPT.from_name(\"pythia-14m\", n_layer=2)\n    model = sequential(model, fabric.device, 15, 1)\n\n    device_str = str(fabric.device)\n    assert path_to_device(model) == {\n        \"cos\": device_str,\n        \"sin\": device_str,\n        \"lm_head.weight\": device_str,\n        \"transformer.h.0.attn.qkv.bias\": device_str,\n        \"transformer.h.0.attn.qkv.weight\": device_str,\n        \"transformer.h.0.attn.proj.bias\": device_str,\n        \"transformer.h.0.attn.proj.weight\": device_str,\n        \"transformer.h.0.mlp.fc.bias\": device_str,\n        \"transformer.h.0.mlp.fc.weight\": device_str,\n        
\"transformer.h.0.mlp.proj.bias\": device_str,\n        \"transformer.h.0.mlp.proj.weight\": device_str,\n        \"transformer.h.0.norm_1.bias\": device_str,\n        \"transformer.h.0.norm_1.weight\": device_str,\n        \"transformer.h.0.norm_2.bias\": device_str,\n        \"transformer.h.0.norm_2.weight\": device_str,\n        \"transformer.h.0.attn.kv_cache.k\": device_str,\n        \"transformer.h.0.attn.kv_cache.v\": device_str,\n        \"transformer.h.1.attn.qkv.bias\": device_str,\n        \"transformer.h.1.attn.qkv.weight\": device_str,\n        \"transformer.h.1.attn.proj.bias\": device_str,\n        \"transformer.h.1.attn.proj.weight\": device_str,\n        \"transformer.h.1.mlp.fc.bias\": device_str,\n        \"transformer.h.1.mlp.fc.weight\": device_str,\n        \"transformer.h.1.mlp.proj.bias\": device_str,\n        \"transformer.h.1.mlp.proj.weight\": device_str,\n        \"transformer.h.1.norm_1.bias\": device_str,\n        \"transformer.h.1.norm_1.weight\": device_str,\n        \"transformer.h.1.norm_2.bias\": device_str,\n        \"transformer.h.1.norm_2.weight\": device_str,\n        \"transformer.h.1.attn.kv_cache.k\": device_str,\n        \"transformer.h.1.attn.kv_cache.v\": device_str,\n        \"transformer.ln_f.bias\": device_str,\n        \"transformer.ln_f.weight\": device_str,\n        \"transformer.wte.weight\": device_str,\n    }\n    assert model.max_seq_length == 15\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_model_1device_cuda():\n    _test_model_1device(\"cuda\")\n\n\ndef test_model_1device_cpu():\n    _test_model_1device(\"cpu\")\n\n\n@_RunIf(min_cuda_gpus=2)\ndef test_model_forward_hooks():\n    fabric = Fabric(accelerator=\"cuda\", devices=1)\n    with torch.device(\"meta\"):\n        model = GPT.from_name(\"pythia-14m\")  # 6 layers\n    model = sequential(model, fabric.device, max_seq_length=15, devices=2)\n\n    hooks = find_forward_hooks(model)\n    actual = path_to_device(model)\n    assert actual == {\n        
\"lm_head.weight\": \"cuda:0\",\n        \"transformer.wte.weight\": \"cuda:0\",\n        \"transformer.h.0.norm_1.weight\": \"cuda:0\",\n        \"transformer.h.0.norm_1.bias\": \"cuda:0\",\n        \"transformer.h.0.attn.qkv.weight\": \"cuda:0\",\n        \"transformer.h.0.attn.qkv.bias\": \"cuda:0\",\n        \"transformer.h.0.attn.proj.weight\": \"cuda:0\",\n        \"transformer.h.0.attn.proj.bias\": \"cuda:0\",\n        \"transformer.h.0.norm_2.weight\": \"cuda:0\",\n        \"transformer.h.0.norm_2.bias\": \"cuda:0\",\n        \"transformer.h.0.mlp.fc.weight\": \"cuda:0\",\n        \"transformer.h.0.mlp.fc.bias\": \"cuda:0\",\n        \"transformer.h.0.mlp.proj.weight\": \"cuda:0\",\n        \"transformer.h.0.mlp.proj.bias\": \"cuda:0\",\n        \"transformer.h.1.norm_1.weight\": \"cuda:0\",\n        \"transformer.h.1.norm_1.bias\": \"cuda:0\",\n        \"transformer.h.1.attn.qkv.weight\": \"cuda:0\",\n        \"transformer.h.1.attn.qkv.bias\": \"cuda:0\",\n        \"transformer.h.1.attn.proj.weight\": \"cuda:0\",\n        \"transformer.h.1.attn.proj.bias\": \"cuda:0\",\n        \"transformer.h.1.norm_2.weight\": \"cuda:0\",\n        \"transformer.h.1.norm_2.bias\": \"cuda:0\",\n        \"transformer.h.1.mlp.fc.weight\": \"cuda:0\",\n        \"transformer.h.1.mlp.fc.bias\": \"cuda:0\",\n        \"transformer.h.1.mlp.proj.weight\": \"cuda:0\",\n        \"transformer.h.1.mlp.proj.bias\": \"cuda:0\",\n        \"transformer.h.2.norm_1.weight\": \"cuda:0\",\n        \"transformer.h.2.norm_1.bias\": \"cuda:0\",\n        \"transformer.h.2.attn.qkv.weight\": \"cuda:0\",\n        \"transformer.h.2.attn.qkv.bias\": \"cuda:0\",\n        \"transformer.h.2.attn.proj.weight\": \"cuda:0\",\n        \"transformer.h.2.attn.proj.bias\": \"cuda:0\",\n        \"transformer.h.2.norm_2.weight\": \"cuda:0\",\n        \"transformer.h.2.norm_2.bias\": \"cuda:0\",\n        \"transformer.h.2.mlp.fc.weight\": \"cuda:0\",\n        \"transformer.h.2.mlp.fc.bias\": \"cuda:0\",\n        
\"transformer.h.2.mlp.proj.weight\": \"cuda:0\",\n        \"transformer.h.2.mlp.proj.bias\": \"cuda:0\",\n        \"transformer.h.3.norm_1.weight\": \"cuda:1\",\n        \"transformer.h.3.norm_1.bias\": \"cuda:1\",\n        \"transformer.h.3.attn.qkv.weight\": \"cuda:1\",\n        \"transformer.h.3.attn.qkv.bias\": \"cuda:1\",\n        \"transformer.h.3.attn.proj.weight\": \"cuda:1\",\n        \"transformer.h.3.attn.proj.bias\": \"cuda:1\",\n        \"transformer.h.3.norm_2.weight\": \"cuda:1\",\n        \"transformer.h.3.norm_2.bias\": \"cuda:1\",\n        \"transformer.h.3.mlp.fc.weight\": \"cuda:1\",\n        \"transformer.h.3.mlp.fc.bias\": \"cuda:1\",\n        \"transformer.h.3.mlp.proj.weight\": \"cuda:1\",\n        \"transformer.h.3.mlp.proj.bias\": \"cuda:1\",\n        \"transformer.h.4.norm_1.weight\": \"cuda:1\",\n        \"transformer.h.4.norm_1.bias\": \"cuda:1\",\n        \"transformer.h.4.attn.qkv.weight\": \"cuda:1\",\n        \"transformer.h.4.attn.qkv.bias\": \"cuda:1\",\n        \"transformer.h.4.attn.proj.weight\": \"cuda:1\",\n        \"transformer.h.4.attn.proj.bias\": \"cuda:1\",\n        \"transformer.h.4.norm_2.weight\": \"cuda:1\",\n        \"transformer.h.4.norm_2.bias\": \"cuda:1\",\n        \"transformer.h.4.mlp.fc.weight\": \"cuda:1\",\n        \"transformer.h.4.mlp.fc.bias\": \"cuda:1\",\n        \"transformer.h.4.mlp.proj.weight\": \"cuda:1\",\n        \"transformer.h.4.mlp.proj.bias\": \"cuda:1\",\n        \"transformer.h.5.norm_1.weight\": \"cuda:1\",\n        \"transformer.h.5.norm_1.bias\": \"cuda:1\",\n        \"transformer.h.5.attn.qkv.weight\": \"cuda:1\",\n        \"transformer.h.5.attn.qkv.bias\": \"cuda:1\",\n        \"transformer.h.5.attn.proj.weight\": \"cuda:1\",\n        \"transformer.h.5.attn.proj.bias\": \"cuda:1\",\n        \"transformer.h.5.norm_2.weight\": \"cuda:1\",\n        \"transformer.h.5.norm_2.bias\": \"cuda:1\",\n        \"transformer.h.5.mlp.fc.weight\": \"cuda:1\",\n        
\"transformer.h.5.mlp.fc.bias\": \"cuda:1\",\n        \"transformer.h.5.mlp.proj.weight\": \"cuda:1\",\n        \"transformer.h.5.mlp.proj.bias\": \"cuda:1\",\n        \"transformer.ln_f.weight\": \"cuda:0\",\n        \"transformer.ln_f.bias\": \"cuda:0\",\n        \"cos\": \"cuda:0\",\n        \"sin\": \"cuda:0\",\n        \"transformer.h.0.attn.kv_cache.k\": \"cuda:0\",\n        \"transformer.h.0.attn.kv_cache.v\": \"cuda:0\",\n        \"transformer.h.1.attn.kv_cache.k\": \"cuda:0\",\n        \"transformer.h.1.attn.kv_cache.v\": \"cuda:0\",\n        \"transformer.h.2.attn.kv_cache.k\": \"cuda:0\",\n        \"transformer.h.2.attn.kv_cache.v\": \"cuda:0\",\n        \"transformer.h.3.attn.kv_cache.k\": \"cuda:1\",\n        \"transformer.h.3.attn.kv_cache.v\": \"cuda:1\",\n        \"transformer.h.4.attn.kv_cache.k\": \"cuda:1\",\n        \"transformer.h.4.attn.kv_cache.v\": \"cuda:1\",\n        \"transformer.h.5.attn.kv_cache.k\": \"cuda:1\",\n        \"transformer.h.5.attn.kv_cache.v\": \"cuda:1\",\n    }\n    assert hooks == {\n        \"transformer.h.3\": [(\"forward_pre_hook\", \"move_block_input\", (torch.device(type=\"cuda\", index=1),), {})],\n        \"transformer.h.4\": [(\"forward_pre_hook\", \"move_block_input\", (torch.device(type=\"cuda\", index=1),), {})],\n        \"transformer.h.5\": [\n            (\"forward_pre_hook\", \"move_block_input\", (torch.device(type=\"cuda\", index=1),), {}),\n            (\"forward_hook\", \"move_block_output\", (torch.device(type=\"cuda\", index=0),), {}),\n        ],\n    }\n\n\nroot = Path(__file__).parent.parent.resolve()\n\n\n@_RunIf(min_cuda_gpus=2)\n@pytest.mark.flaky(reruns=5, reruns_delay=2)\ndef test_base_with_sequentially(tmp_path):\n    # download the tokenizer\n    download_from_hub(repo_id=\"EleutherAI/pythia-14m\", tokenizer_only=True, checkpoint_dir=tmp_path)\n    checkpoint_dir = tmp_path / \"EleutherAI/pythia-14m\"\n    # save the config\n    config = Config.from_name(\"pythia-14m\")\n    (checkpoint_dir 
/ \"model_config.yaml\").write_text(yaml.dump(asdict(config)))\n    # create a state dict to load from\n    torch.save(GPT(config).state_dict(), checkpoint_dir / \"lit_model.pth\")\n\n    args = [\n        str(checkpoint_dir),\n        \"--num_samples=1\",\n        \"--max_new_tokens=10\",\n        \"--precision=16-true\",\n        \"--temperature=0.0\",\n    ]\n    env = {\"CUDA_VISIBLE_DEVICES\": \"0,1\"}\n    sequential_stdout = subprocess.check_output(\n        [sys.executable, \"-m\", \"litgpt\", \"generate_sequentially\", *args],\n        env=env,\n        cwd=root,\n    ).decode()\n\n    assert \"What food do llamas eat?\" in sequential_stdout\n\n\ndef test_cli():\n    args = [\"litgpt\", \"generate_sequentially\", \"-h\"]\n    output = subprocess.check_output(args)\n    output = str(output.decode())\n    assert \"Generation script that partitions layers across\" in output\n"
  },
  {
    "path": "tests/generate/test_tp.py",
    "content": "import subprocess\nimport sys\nfrom dataclasses import asdict, replace\nfrom pathlib import Path\nfrom unittest.mock import Mock\n\nimport pytest\nimport torch\nimport yaml\n\nfrom litgpt import GPT, Config\nfrom litgpt.generate.tp import tensor_parallel, tensor_parallel_linear\nfrom litgpt.scripts.download import download_from_hub\nfrom litgpt.utils import _RunIf\n\nfrom .utils import find_forward_hooks\n\n\ndef test_tensor_parallel_linear():\n    fabric = Mock()\n    fabric.world_size = 4\n    fabric.global_rank = 2\n\n    def get_linear(bias=True):\n        linear = torch.nn.Linear(8, 8, bias=bias)\n        linear.weight.data = torch.arange(64, dtype=torch.float32).reshape(8, 8)\n        if bias:\n            linear.bias.data = torch.arange(8, dtype=torch.float32)\n        return linear\n\n    linear = get_linear()\n    tensor_parallel_linear(fabric, linear, \"colwise\")\n    expected = torch.arange(32, 48, dtype=torch.float32).reshape(2, 8)\n    torch.testing.assert_close(linear.weight, expected)\n    expected = torch.arange(4, 6, dtype=torch.float32)\n    torch.testing.assert_close(linear.bias, expected)\n\n    linear = get_linear(bias=False)\n    tensor_parallel_linear(fabric, linear, \"rowwise\")\n    expected = torch.arange(4, 62, 8, dtype=torch.float32).reshape(8, 1)\n    expected = torch.cat([expected, expected + 1], dim=1)\n    torch.testing.assert_close(linear.weight, expected)\n    assert linear.bias is None\n\n\n@pytest.mark.parametrize(\n    (\"name\", \"expected\"),\n    [\n        (\n            \"Llama-2-70b-hf\",\n            {\n                \"transformer.h.0.attn\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.0.mlp\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.1.attn\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.1.mlp\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                
\"transformer.h.2.attn\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.2.mlp\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n            },\n        ),\n        (\n            \"falcon-180B\",\n            {\n                \"transformer.h.0.attn\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.0.mlp\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.1.attn\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.1.mlp\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.2.attn\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.2.mlp\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n            },\n        ),\n        (\n            \"Mixtral-8x7B-v0.1\",\n            {\n                \"transformer.h.0.attn\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.0.mlp.experts.0\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.0.mlp.experts.1\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.1.attn\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.1.mlp.experts.0\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.1.mlp.experts.1\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.2.attn\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.2.mlp.experts.0\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n                \"transformer.h.2.mlp.experts.1\": [(\"forward_hook\", \"all_reduce_output\", (8,), {})],\n            },\n        ),\n    ],\n)\ndef test_tensor_parallel_llama(name, expected):\n    fabric = Mock()\n    
fabric.world_size = 8\n    fabric.global_rank = 1\n\n    with torch.device(\"meta\"):\n        model = GPT.from_name(name, n_layer=3, n_expert=2)\n    config = replace(model.config)  # make a copy\n\n    model = tensor_parallel(fabric, model)\n\n    hooks = find_forward_hooks(model)\n    assert hooks == expected\n\n    assert model.config.n_embd * 8 == config.n_embd\n    assert model.config.n_head * 8 == config.n_head\n    assert model.config.n_query_groups * 8 == config.n_query_groups\n\n\nroot = Path(__file__).parent.parent.resolve()\n\n\n@_RunIf(min_cuda_gpus=2)\ndef test_tp(tmp_path):\n    # download the tokenizer\n    download_from_hub(repo_id=\"EleutherAI/pythia-14m\", tokenizer_only=True, checkpoint_dir=tmp_path)\n    checkpoint_dir = tmp_path / \"EleutherAI/pythia-14m\"\n    # save the config\n    config = Config.from_name(\"pythia-14m\")\n    (checkpoint_dir / \"model_config.yaml\").write_text(yaml.dump(asdict(config)))\n    # create a state dict to load from\n    torch.save(GPT(config).state_dict(), checkpoint_dir / \"lit_model.pth\")\n\n    args = [\n        str(checkpoint_dir),\n        \"--num_samples=1\",\n        \"--max_new_tokens=10\",\n        \"--precision=16-true\",\n        \"--temperature=0.0\",\n    ]\n    env = {\"CUDA_VISIBLE_DEVICES\": \"0,1\"}\n    tp_stdout = subprocess.check_output(\n        [sys.executable, \"-m\", \"litgpt\", \"generate_tp\", *args], env=env, cwd=root\n    ).decode()\n\n    # there is some unaccounted randomness so cannot compare the output with that of `generate/base.py`\n    assert \"What food do llamas eat?\" in tp_stdout\n\n\ndef test_cli():\n    args = [\"litgpt\", \"generate_tp\", \"-h\"]\n    output = subprocess.check_output(args)\n    output = str(output.decode())\n    assert \"Generation script that uses tensor parallelism\" in output\n"
  },
  {
    "path": "tests/generate/utils.py",
    "content": "from collections import defaultdict\n\n\ndef find_forward_hooks(module):\n    mapping = defaultdict(list)\n    for name, submodule in module.named_modules():\n        for hook in submodule._forward_pre_hooks.values():\n            hook_data = (\"forward_pre_hook\", hook.func.__name__, hook.args, hook.keywords)\n            mapping[name].append(hook_data)\n        for hook in submodule._forward_hooks.values():\n            hook_data = (\"forward_hook\", hook.func.__name__, hook.args, hook.keywords)\n            mapping[name].append(hook_data)\n    return dict(mapping)\n"
  },
  {
    "path": "tests/test_adapter.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport os\nfrom contextlib import redirect_stdout\nfrom copy import deepcopy\nfrom dataclasses import asdict\nfrom io import StringIO\nfrom unittest import mock\nfrom unittest.mock import Mock\n\nimport pytest\nimport torch\nimport yaml\nfrom lightning import Fabric\nfrom lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision\nfrom lightning.fabric.wrappers import _FabricOptimizer\nfrom torch._dynamo.backends import debugging\nfrom transformers.models.gemma import GemmaConfig, GemmaForCausalLM\nfrom transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM\nfrom transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig\n\nimport litgpt.adapter as gpt_adapter\nimport litgpt.finetune.adapter as module\nimport litgpt.model as gpt\nfrom litgpt.adapter import GPT, CausalSelfAttention, Config, adapter_filter\nfrom litgpt.args import EvalArgs, TrainArgs\nfrom litgpt.data import Alpaca\nfrom litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_gemma_3, copy_weights_hf_llama\nfrom litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved\nfrom litgpt.utils import _RunIf\n\n\ndef test_config_identical():\n    name = \"pythia-14m\"\n    base_config = asdict(gpt.Config.from_name(name))\n    adapter_config = asdict(gpt_adapter.Config.from_name(name))\n    del adapter_config[\"adapter_prompt_length\"]\n    del adapter_config[\"adapter_start_layer\"]\n    assert adapter_config == base_config\n\n    with Fabric(accelerator=\"cpu\").init_module(empty_init=True):\n        base_model = gpt.GPT.from_name(name)\n        adapter_model = gpt_adapter.GPT.from_name(name)\n    assert adapter_model.lm_head.weight.shape == base_model.lm_head.weight.shape\n\n\ndef test_adapter_filter(tmp_path):\n    fabric = Fabric(devices=1)\n    model = GPT.from_name(\"pythia-14m\", 
n_layer=4)\n    save_path = tmp_path / \"model.pth\"\n    fabric.save(save_path, {\"model\": model}, filter={\"model\": adapter_filter})\n    saved = torch.load(save_path)[\"model\"]\n\n    expected = {\n        \"transformer.h.2.attn.adapter_wte.weight\",\n        \"transformer.h.2.attn.gating_factor\",\n        \"transformer.h.3.attn.adapter_wte.weight\",\n        \"transformer.h.3.attn.gating_factor\",\n    }\n    assert set(saved) == expected\n\n\n@mock.patch.dict(os.environ, {\"LT_ACCELERATOR\": \"cpu\"})\ndef test_adapter_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):\n    model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0)\n    (fake_checkpoint_dir / \"model_config.yaml\").write_text(yaml.dump(model_config))\n\n    monkeypatch.setattr(module, \"load_checkpoint\", Mock())\n\n    tokenizer_mock = Mock()\n    tokenizer_mock.return_value = tokenizer_mock\n    tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])\n    monkeypatch.setattr(module, \"Tokenizer\", tokenizer_mock)\n\n    out_dir = tmp_path / \"out\"\n    stdout = StringIO()\n    with redirect_stdout(stdout), mock.patch(\"sys.argv\", [\"adapter.py\", str(fake_checkpoint_dir)]):\n        module.setup(\n            fake_checkpoint_dir,\n            data=Alpaca(\n                download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0\n            ),\n            out_dir=out_dir,\n            precision=\"32-true\",\n            train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1),\n            eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),\n        )\n\n    out_dir_contents = set(os.listdir(out_dir))\n    checkpoint_dirs = {\"step-000002\", \"step-000004\", \"step-000006\", \"final\"}\n    assert checkpoint_dirs.issubset(out_dir_contents)\n    assert all((out_dir / p).is_dir() for p in checkpoint_dirs)\n    for 
checkpoint_dir in checkpoint_dirs:\n        assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {\n            \"lit_model.pth.adapter\",\n            \"model_config.yaml\",\n            \"tokenizer_config.json\",\n            \"tokenizer.json\",\n            \"hyperparameters.yaml\",\n            \"prompt_style.yaml\",\n        }\n    assert (out_dir / \"logs\" / \"csv\" / \"version_0\" / \"metrics.csv\").is_file()\n\n    logs = stdout.getvalue()\n    assert logs.count(\"(step)\") == 6\n    assert logs.count(\"val loss\") == 4  # 3 validations + 1 final validation\n    assert logs.count(\"Final evaluation\") == 1\n    assert \"of trainable parameters: 168\" in logs\n\n\ndef test_adapter_gpt_init_weights():\n    config = Config(n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1, adapter_start_layer=0)\n    model = GPT(config)\n    param = model.transformer.h[0].attn.gating_factor\n\n    assert (param == 0).all()\n    torch.nn.init.constant_(param, 1.23)\n    assert (param != 0).any()\n    model.apply(model._init_weights)\n    assert (param == 0).all()\n\n\n@_RunIf(dynamo=True)\n@torch.inference_mode()\ndef test_adapter_compile():\n    model = GPT.from_name(\"pythia-14m\", n_layer=3)\n    x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64)\n\n    explanation = torch._dynamo.explain(model)(x)\n    assert isinstance(explanation, debugging.ExplainOutput)\n    assert explanation.graph_count == 1\n    assert explanation.graph_break_count == 0\n\n    model = GPT(model.config)\n    model.set_kv_cache(2)\n    input_pos = torch.arange(model.config.block_size)\n    explanation = torch._dynamo.explain(model)(x, input_pos)\n    assert isinstance(explanation, debugging.ExplainOutput)\n    assert explanation.graph_count == 1\n    assert explanation.graph_break_count == 0\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path):\n    if not 
_BITSANDBYTES_AVAILABLE:\n        pytest.skip(\"BNB not available\")\n\n    from bitsandbytes.optim import PagedAdamW\n\n    model_config = dict(\n        block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0, bias=True\n    )\n    (fake_checkpoint_dir / \"model_config.yaml\").write_text(yaml.dump(model_config))\n\n    tokenizer_mock = Mock()\n    tokenizer_mock.return_value = tokenizer_mock\n    tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])\n    monkeypatch.setattr(module, \"Tokenizer\", tokenizer_mock)\n\n    monkeypatch.setattr(module, \"load_checkpoint\", Mock())\n    train_mock = Mock()\n    train_mock.return_value = {\n        \"raw_tokens\": 1000,\n        \"raw_tokens_plus_prompt_template\": 1100,\n        \"raw_tokens_plus_prompt_template_and_padding\": 1200,\n    }\n    monkeypatch.setattr(module, \"fit\", train_mock)\n\n    stdout = StringIO()\n    with redirect_stdout(stdout), mock.patch(\"sys.argv\", [\"adapter.py\", str(fake_checkpoint_dir)]):\n        module.setup(\n            fake_checkpoint_dir,\n            data=Alpaca(\n                download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0\n            ),\n            precision=\"16-true\",\n            quantize=\"bnb.nf4-dq\",\n            out_dir=tmp_path,\n        )\n\n    _, kwargs = train_mock.call_args\n    fabric = kwargs[\"fabric\"]\n    model = kwargs[\"model\"]\n    optimizer = kwargs[\"optimizer\"]\n    assert isinstance(fabric.strategy.precision, BitsandbytesPrecision)\n    assert isinstance(optimizer, _FabricOptimizer)\n    assert isinstance(optimizer._optimizer, PagedAdamW)\n\n    dtype_to_name = {\"torch.uint8\": set(), \"torch.float16\": set()}\n    for name, layer in model.named_parameters():\n        name = name[len(\"_forward_module.\") :]\n        dtype_to_name[str(layer.dtype)].add(name)\n    assert dtype_to_name == {\n        \"torch.float16\": {\n            
\"transformer.wte.weight\",\n            \"transformer.wte.norm.weight\",\n            \"transformer.wte.norm.bias\",\n            \"transformer.h.0.norm_1.weight\",\n            \"transformer.h.0.norm_1.bias\",\n            \"transformer.h.0.attn.gating_factor\",\n            \"transformer.h.0.attn.qkv.bias\",\n            \"transformer.h.0.attn.proj.bias\",\n            \"transformer.h.0.attn.adapter_wte.weight\",\n            \"transformer.h.0.norm_2.weight\",\n            \"transformer.h.0.norm_2.bias\",\n            \"transformer.h.0.mlp.fc.bias\",\n            \"transformer.h.0.mlp.proj.bias\",\n            \"transformer.h.1.norm_1.weight\",\n            \"transformer.h.1.norm_1.bias\",\n            \"transformer.h.1.attn.gating_factor\",\n            \"transformer.h.1.attn.qkv.bias\",\n            \"transformer.h.1.attn.proj.bias\",\n            \"transformer.h.1.attn.adapter_wte.weight\",\n            \"transformer.h.1.norm_2.weight\",\n            \"transformer.h.1.norm_2.bias\",\n            \"transformer.h.1.mlp.fc.bias\",\n            \"transformer.h.1.mlp.proj.bias\",\n            \"transformer.ln_f.weight\",\n            \"transformer.ln_f.bias\",\n        },\n        \"torch.uint8\": {\n            \"lm_head.weight\",\n            \"transformer.h.0.attn.qkv.weight\",\n            \"transformer.h.0.attn.proj.weight\",\n            \"transformer.h.0.mlp.fc.weight\",\n            \"transformer.h.0.mlp.proj.weight\",\n            \"transformer.h.1.attn.qkv.weight\",\n            \"transformer.h.1.attn.proj.weight\",\n            \"transformer.h.1.mlp.fc.weight\",\n            \"transformer.h.1.mlp.proj.weight\",\n        },\n    }\n\n    assert {p.name for p in tmp_path.rglob(\"*.pth.adapter\")} == {\"lit_model.pth.adapter\"}\n    state_dict = torch.load(tmp_path / \"final\" / \"lit_model.pth.adapter\")\n    assert len(state_dict) == 1\n    dtype_to_name = {\"torch.float16\": set()}\n    for name, layer in state_dict[\"model\"].items():\n        
dtype_to_name[str(layer.dtype)].add(name)\n    assert dtype_to_name == {\n        \"torch.float16\": {\n            \"transformer.h.0.attn.adapter_wte.weight\",\n            \"transformer.h.0.attn.gating_factor\",\n            \"transformer.h.1.attn.adapter_wte.weight\",\n            \"transformer.h.1.attn.gating_factor\",\n        }\n    }\n\n    logs = stdout.getvalue()\n    assert \"of trainable parameters: 168\" in logs\n    assert \"of non-trainable parameters: 1,888\" in logs\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", [\"gemma-2b\", \"gemma-7b\"])\ndef test_against_hf_gemma(model_name):\n    device = torch.device(\"cpu\")\n    dtype = torch.float32\n    T = 5\n    ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)\n    theirs_config = GemmaConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = GemmaForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], 
dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"gemma-2-9b\", \"gemma-2-27b\"))\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_gemma_2(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        sliding_window_size=T // 2,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    theirs_config = Gemma2Config(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        sliding_window=ours_config.sliding_window_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n        attn_logit_softcapping=ours_config.attention_logit_softcapping,\n        
final_logit_softcapping=ours_config.final_logit_softcapping,\n        initializer_range=1.0,  # to make the affect of attention_logit_softcapping more prominent\n        attn_implementation=\"eager\",\n        query_pre_attn_scalar=ours_config.attention_scores_scalar,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = Gemma2ForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n    copy_weights_gemma_2({}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y, atol=1e-4, rtol=1e-5)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"gemma-3-1b-it\", \"gemma-3-4b-it\", \"gemma-3-12b-it\", \"gemma-3-27b-it\"))\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_gemma_3(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        sliding_window_size=T // 2,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    theirs_config = Gemma3TextConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        sliding_window=ours_config.sliding_window_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n        attn_logit_softcapping=ours_config.attention_logit_softcapping,\n        final_logit_softcapping=ours_config.final_logit_softcapping,\n        initializer_range=1.0,  # to make the affect of attention_logit_softcapping more prominent\n        attn_implementation=\"eager\",\n        query_pre_attn_scalar=ours_config.attention_scores_scalar,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = Gemma3ForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n    copy_weights_gemma_3({}, state_dict, 
theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y, atol=1e-4, rtol=1e-5)\n\n\ndef test_load_legacy_state_dict():\n    \"\"\"Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.\"\"\"\n    config = Config(\n        n_embd=32,\n        n_head=4,\n        head_size=8,\n        n_query_groups=4,\n        bias=True,\n    )\n\n    attention_1 = CausalSelfAttention(config=config, block_idx=0)\n\n    # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight`\n    # and make them interleaved\n    state_dict = deepcopy(attention_1.state_dict())\n    state_dict[\"attn.weight\"] = make_qkv_interleaved(state_dict.pop(\"qkv.weight\"), config)\n    state_dict[\"attn.bias\"] = make_qkv_interleaved(state_dict.pop(\"qkv.bias\"), config)\n\n    attention_2 = CausalSelfAttention(config=config, block_idx=0)\n    attention_2.load_state_dict(state_dict)\n"
  },
  {
    "path": "tests/test_adapter_v2.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport os\nfrom contextlib import redirect_stdout\nfrom copy import deepcopy\nfrom io import StringIO\nfrom unittest import mock\nfrom unittest.mock import Mock\n\nimport pytest\nimport torch\nimport yaml\nfrom lightning import Fabric\nfrom lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision\nfrom lightning.fabric.wrappers import _FabricOptimizer\nfrom torch._dynamo.backends import debugging\nfrom transformers.models.gemma import GemmaConfig, GemmaForCausalLM\nfrom transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM\nfrom transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig\nfrom transformers.models.mixtral import MixtralConfig, MixtralForCausalLM\n\nimport litgpt.config as config_module\nimport litgpt.finetune.adapter_v2 as module\nfrom litgpt.adapter_v2 import GPT as AdapterV2GPT\nfrom litgpt.adapter_v2 import CausalSelfAttention, Config, adapter_filter\nfrom litgpt.args import EvalArgs, TrainArgs\nfrom litgpt.data import Alpaca\nfrom litgpt.model import GPT as BaseGPT\nfrom litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_gemma_3, copy_weights_hf_llama\nfrom litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved\nfrom litgpt.utils import _RunIf\n\n\ndef test_config_identical():\n    name = \"pythia-14m\"\n    with Fabric(accelerator=\"cpu\").init_module(empty_init=True):\n        base_model = BaseGPT.from_name(name)\n        adapter_model = AdapterV2GPT.from_name(name)\n\n    assert not hasattr(base_model.transformer.h[2].attn.qkv, \"adapter_bias\")\n    assert not hasattr(base_model.transformer.h[2].attn.qkv, \"adapter_scale\")\n    assert hasattr(adapter_model.transformer.h[2].attn.qkv, \"adapter_bias\")\n    assert hasattr(adapter_model.transformer.h[2].attn.qkv, \"adapter_scale\")\n\n\ndef 
test_adapter_v2_filter(tmp_path):\n    fabric = Fabric(devices=1)\n    model = AdapterV2GPT.from_name(\"pythia-14m\", n_layer=3)\n    save_path = tmp_path / \"model.pth\"\n    fabric.save(save_path, {\"model\": model}, filter={\"model\": adapter_filter})\n    saved = torch.load(save_path)[\"model\"]\n\n    expected = {\n        \"lm_head.adapter_bias\",\n        \"lm_head.adapter_scale\",\n        \"transformer.ln_f.bias\",\n        \"transformer.ln_f.weight\",\n        \"transformer.h.2.attn.adapter_wte.weight\",\n        \"transformer.h.2.attn.gating_factor\",\n    }\n    for layer in range(3):\n        for param in (\n            \"attn.qkv.adapter_bias\",\n            \"attn.qkv.adapter_scale\",\n            \"attn.proj.adapter_bias\",\n            \"attn.proj.adapter_scale\",\n            \"mlp.fc.adapter_bias\",\n            \"mlp.fc.adapter_scale\",\n            \"mlp.proj.adapter_bias\",\n            \"mlp.proj.adapter_scale\",\n            \"norm_1.bias\",\n            \"norm_1.weight\",\n            \"norm_2.bias\",\n            \"norm_2.weight\",\n        ):\n            expected.add(f\"transformer.h.{layer}.{param}\")\n    assert set(saved) == expected\n\n\n@mock.patch.dict(os.environ, {\"LT_ACCELERATOR\": \"cpu\"})\ndef test_adapter_v2_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):\n    model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0)\n    (fake_checkpoint_dir / \"model_config.yaml\").write_text(yaml.dump(model_config))\n\n    monkeypatch.setattr(module, \"load_checkpoint\", Mock())\n\n    tokenizer_mock = Mock()\n    tokenizer_mock.return_value = tokenizer_mock\n    tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])\n    monkeypatch.setattr(module, \"Tokenizer\", tokenizer_mock)\n\n    out_dir = tmp_path / \"out\"\n    stdout = StringIO()\n    with redirect_stdout(stdout), mock.patch(\"sys.argv\", [\"adapter_v2.py\", str(fake_checkpoint_dir)]):\n        
module.setup(\n            fake_checkpoint_dir,\n            data=Alpaca(\n                download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0\n            ),\n            out_dir=out_dir,\n            precision=\"32-true\",\n            train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1),\n            eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),\n        )\n\n    out_dir_contents = set(os.listdir(out_dir))\n    checkpoint_dirs = {\"step-000002\", \"step-000004\", \"step-000006\", \"final\"}\n    assert checkpoint_dirs.issubset(out_dir_contents)\n    assert all((out_dir / p).is_dir() for p in checkpoint_dirs)\n    for checkpoint_dir in checkpoint_dirs:\n        assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {\n            \"lit_model.pth.adapter_v2\",\n            \"model_config.yaml\",\n            \"tokenizer_config.json\",\n            \"tokenizer.json\",\n            \"hyperparameters.yaml\",\n            \"prompt_style.yaml\",\n        }\n    assert (out_dir / \"logs\" / \"csv\" / \"version_0\" / \"metrics.csv\").is_file()\n\n    logs = stdout.getvalue()\n    assert logs.count(\"(step)\") == 6\n    assert logs.count(\"val loss\") == 4  # 3 validations + 1 final validation\n    assert logs.count(\"Final evaluation\") == 1\n    assert \"of trainable parameters: 552\" in logs\n\n\ndef test_adapter_v2_gpt_init_weights():\n    config = Config(n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1, adapter_start_layer=0)\n    model = AdapterV2GPT(config)\n\n    for param in (model.transformer.h[0].attn.gating_factor, model.lm_head.adapter_bias):\n        assert (param == 0).all()\n        torch.nn.init.constant_(param, 1.23)\n        assert (param != 0).any()\n        model.apply(model._init_weights)\n        assert (param == 0).all()\n\n\n@pytest.mark.parametrize(\"name\", [c[\"name\"] for c in config_module.configs])\ndef 
test_base_model_can_be_adapter_v2_loaded(name):\n    kwargs = {\"n_layer\": 2, \"n_head\": 8, \"n_query_groups\": 4, \"n_embd\": 16, \"padded_vocab_size\": 32}\n    base_model = BaseGPT.from_name(name, **kwargs)\n    base_model_state_dict = base_model.state_dict()\n    lora_model = AdapterV2GPT.from_name(name, **kwargs, adapter_start_layer=0)\n    keys = lora_model.load_state_dict(base_model_state_dict, strict=False)\n    assert not keys.unexpected_keys\n    for k in keys.missing_keys:\n        assert adapter_filter(k, None)\n\n\n@_RunIf(dynamo=True)\n@torch.inference_mode()\ndef test_adapter_v2_compile():\n    model = AdapterV2GPT.from_name(\"pythia-14m\", n_layer=3)\n    x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64)\n\n    explanation = torch._dynamo.explain(model)(x)\n    assert isinstance(explanation, debugging.ExplainOutput)\n    assert explanation.graph_count == 1\n    assert explanation.graph_break_count == 0\n\n    model = AdapterV2GPT(model.config)\n    model.set_kv_cache(2)\n    input_pos = torch.arange(model.config.block_size)\n    explanation = torch._dynamo.explain(model)(x, input_pos)\n    assert isinstance(explanation, debugging.ExplainOutput)\n    assert explanation.graph_count == 1\n    assert explanation.graph_break_count == 0\n\n\n@torch.inference_mode()\ndef test_against_hf_mixtral():\n    device = torch.device(\"cpu\")\n    dtype = torch.float32\n    ours_config = Config.from_name(\n        \"Mixtral-8x7B-Instruct-v0.1\",\n        padded_vocab_size=10000,\n        n_layer=2,\n        n_embd=32,\n        n_head=8,\n        n_query_groups=2,\n        intermediate_size=86,\n        n_expert=4,\n    )\n    T = 5\n    theirs_config = MixtralConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        
max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        num_local_experts=ours_config.n_expert,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = MixtralForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = AdapterV2GPT(ours_config).to(device)\n    # strict=False because missing keys due to adapter weights not contained in state dict\n    ours_model.load_state_dict(state_dict, strict=False)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", [\"gemma-2b\", \"gemma-7b\"])\ndef test_against_hf_gemma(model_name):\n    device = torch.device(\"cpu\")\n    dtype = torch.float32\n    T = 5\n    ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)\n    theirs_config = GemmaConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n    )\n   
 assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = GemmaForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = AdapterV2GPT(ours_config).to(device)\n    keys = ours_model.load_state_dict(state_dict, strict=False)\n    assert not keys.unexpected_keys\n    for k in keys.missing_keys:\n        assert adapter_filter(k, None)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"gemma-2-9b\", \"gemma-2-27b\"))\ndef test_against_original_gemma_2(model_name):\n    device = torch.device(\"cpu\")\n    dtype = torch.float32\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        sliding_window_size=T // 2,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    theirs_config = Gemma2Config(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        sliding_window=ours_config.sliding_window_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        
tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n        attn_logit_softcapping=ours_config.attention_logit_softcapping,\n        final_logit_softcapping=ours_config.final_logit_softcapping,\n        initializer_range=1.0,  # to make the affect of attention_logit_softcapping more prominent\n        attn_implementation=\"eager\",\n        query_pre_attn_scalar=ours_config.attention_scores_scalar,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = Gemma2ForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n    copy_weights_gemma_2({}, state_dict, theirs_state_dict)\n    ours_model = AdapterV2GPT(ours_config).to(device)\n    keys = ours_model.load_state_dict(state_dict, strict=False)\n    assert not keys.unexpected_keys\n    for k in keys.missing_keys:\n        assert adapter_filter(k, None)\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(\n        # some macOS devices have numerical differences, hence the tol bump\n        ours_y,\n        theirs_y,\n        atol=1e-4,\n        rtol=1e-5,\n    )\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"gemma-3-1b-it\", \"gemma-3-4b-it\", \"gemma-3-12b-it\", \"gemma-3-27b-it\"))\ndef test_against_original_gemma_3(model_name):\n    device = torch.device(\"cpu\")\n    dtype = torch.float32\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        sliding_window_size=T // 2,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n\n    
theirs_config = Gemma3TextConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        sliding_window=ours_config.sliding_window_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n        attn_implementation=\"eager\",\n        query_pre_attn_scalar=ours_config.attention_scores_scalar,\n        rope_scaling={\"factor\": 8.0, \"rope_type\": \"linear\"},\n        rope_local_base_freq=ours_config.rope_local_base_freq,\n    )\n\n    theirs_model = Gemma3ForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n\n    copy_weights_gemma_3({}, state_dict, theirs_state_dict)\n    ours_model = AdapterV2GPT(ours_config).to(device)\n    keys = ours_model.load_state_dict(state_dict, strict=False)\n    assert not keys.unexpected_keys\n    for k in keys.missing_keys:\n        assert adapter_filter(k, None)\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(\n        ours_y, theirs_y, rtol=3e-5, atol=3e-5\n    )  # some macOS devices have numerical differences, hence the tol bump\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, 
fake_checkpoint_dir, alpaca_path):\n    if not _BITSANDBYTES_AVAILABLE:\n        pytest.skip(\"BNB not available\")\n\n    from bitsandbytes.optim import PagedAdamW\n\n    model_config = dict(\n        block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8, adapter_start_layer=0, bias=True\n    )\n    (fake_checkpoint_dir / \"model_config.yaml\").write_text(yaml.dump(model_config))\n\n    tokenizer_mock = Mock()\n    tokenizer_mock.return_value = tokenizer_mock\n    tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])\n    monkeypatch.setattr(module, \"Tokenizer\", tokenizer_mock)\n\n    monkeypatch.setattr(module, \"load_checkpoint\", Mock())\n    train_mock = Mock()\n    train_mock.return_value = {\n        \"raw_tokens\": 1000,\n        \"raw_tokens_plus_prompt_template\": 1100,\n        \"raw_tokens_plus_prompt_template_and_padding\": 1200,\n    }\n    monkeypatch.setattr(module, \"fit\", train_mock)\n\n    stdout = StringIO()\n    with redirect_stdout(stdout), mock.patch(\"sys.argv\", [\"adapter_v2.py\", str(fake_checkpoint_dir)]):\n        module.setup(\n            fake_checkpoint_dir,\n            data=Alpaca(\n                download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0\n            ),\n            precision=\"16-true\",\n            quantize=\"bnb.nf4-dq\",\n            out_dir=tmp_path,\n        )\n\n    _, kwargs = train_mock.call_args\n    fabric = kwargs[\"fabric\"]\n    model = kwargs[\"model\"]\n    optimizer = kwargs[\"optimizer\"]\n    assert isinstance(fabric.strategy.precision, BitsandbytesPrecision)\n    assert isinstance(optimizer, _FabricOptimizer)\n    assert isinstance(optimizer._optimizer, PagedAdamW)\n\n    dtype_to_name = {\"torch.uint8\": set(), \"torch.float16\": set()}\n    for name, layer in model.named_parameters():\n        name = name[len(\"_forward_module.\") :]\n        dtype_to_name[str(layer.dtype)].add(name)\n    assert dtype_to_name == {\n        
\"torch.uint8\": {\n            \"transformer.h.0.mlp.fc.linear.weight\",\n            \"transformer.h.1.mlp.proj.linear.weight\",\n            \"transformer.h.1.attn.qkv.linear.weight\",\n            \"transformer.h.0.attn.proj.linear.weight\",\n            \"lm_head.linear.weight\",\n            \"transformer.h.1.attn.proj.linear.weight\",\n            \"transformer.h.0.mlp.proj.linear.weight\",\n            \"transformer.h.0.attn.qkv.linear.weight\",\n            \"transformer.h.1.mlp.fc.linear.weight\",\n        },\n        \"torch.float16\": {\n            \"transformer.h.1.attn.qkv.adapter_bias\",\n            \"transformer.h.1.mlp.proj.adapter_bias\",\n            \"transformer.h.0.attn.qkv.adapter_bias\",\n            \"transformer.h.0.norm_1.bias\",\n            \"transformer.h.0.attn.qkv.linear.bias\",\n            \"transformer.h.1.attn.adapter_wte.weight\",\n            \"transformer.ln_f.weight\",\n            \"transformer.h.0.mlp.fc.linear.bias\",\n            \"transformer.h.0.mlp.proj.linear.bias\",\n            \"transformer.h.1.mlp.fc.linear.bias\",\n            \"transformer.h.0.attn.proj.adapter_scale\",\n            \"transformer.h.0.attn.qkv.adapter_scale\",\n            \"transformer.h.1.norm_2.bias\",\n            \"transformer.h.1.attn.proj.adapter_scale\",\n            \"transformer.h.0.norm_2.bias\",\n            \"transformer.h.0.mlp.fc.adapter_scale\",\n            \"transformer.h.0.attn.proj.linear.bias\",\n            \"transformer.h.1.attn.proj.linear.bias\",\n            \"transformer.h.1.norm_1.bias\",\n            \"transformer.h.0.norm_1.weight\",\n            \"transformer.h.1.attn.proj.adapter_bias\",\n            \"transformer.h.0.mlp.proj.adapter_scale\",\n            \"transformer.h.0.mlp.proj.adapter_bias\",\n            \"transformer.h.1.mlp.fc.adapter_bias\",\n            \"transformer.h.1.mlp.proj.adapter_scale\",\n            \"transformer.h.1.attn.gating_factor\",\n            \"transformer.h.1.norm_1.weight\",\n      
      \"transformer.ln_f.bias\",\n            \"transformer.h.0.mlp.fc.adapter_bias\",\n            \"lm_head.adapter_scale\",\n            \"lm_head.adapter_bias\",\n            \"transformer.h.1.norm_2.weight\",\n            \"transformer.h.0.attn.adapter_wte.weight\",\n            \"transformer.h.1.attn.qkv.adapter_scale\",\n            \"transformer.h.1.mlp.fc.adapter_scale\",\n            \"transformer.h.1.attn.qkv.linear.bias\",\n            \"transformer.wte.weight\",\n            \"transformer.wte.norm.weight\",\n            \"transformer.wte.norm.bias\",\n            \"transformer.h.0.norm_2.weight\",\n            \"transformer.h.1.mlp.proj.linear.bias\",\n            \"transformer.h.0.attn.gating_factor\",\n            \"transformer.h.0.attn.proj.adapter_bias\",\n        },\n    }\n\n    assert {p.name for p in tmp_path.rglob(\"*.pth.adapter_v2\")} == {\"lit_model.pth.adapter_v2\"}\n    state_dict = torch.load(tmp_path / \"final\" / \"lit_model.pth.adapter_v2\")\n    assert len(state_dict) == 1\n    dtype_to_name = {\"torch.float16\": set()}\n    for name, layer in state_dict[\"model\"].items():\n        dtype_to_name[str(layer.dtype)].add(name)\n    assert dtype_to_name == {\n        \"torch.float16\": {\n            \"transformer.h.1.attn.adapter_wte.weight\",\n            \"transformer.h.1.attn.proj.adapter_bias\",\n            \"transformer.h.1.mlp.fc.adapter_scale\",\n            \"lm_head.adapter_bias\",\n            \"transformer.h.0.mlp.proj.adapter_scale\",\n            \"transformer.ln_f.bias\",\n            \"lm_head.adapter_scale\",\n            \"transformer.h.1.norm_2.weight\",\n            \"transformer.h.0.attn.qkv.adapter_scale\",\n            \"transformer.h.0.mlp.proj.adapter_bias\",\n            \"transformer.h.0.attn.gating_factor\",\n            \"transformer.h.1.norm_1.bias\",\n            \"transformer.h.1.mlp.fc.adapter_bias\",\n            \"transformer.h.1.mlp.proj.adapter_scale\",\n            
\"transformer.h.0.mlp.fc.adapter_scale\",\n            \"transformer.h.1.attn.qkv.adapter_bias\",\n            \"transformer.h.0.norm_2.weight\",\n            \"transformer.h.1.norm_2.bias\",\n            \"transformer.h.0.norm_1.weight\",\n            \"transformer.h.0.attn.proj.adapter_scale\",\n            \"transformer.h.1.mlp.proj.adapter_bias\",\n            \"transformer.h.0.attn.qkv.adapter_bias\",\n            \"transformer.h.0.attn.adapter_wte.weight\",\n            \"transformer.ln_f.weight\",\n            \"transformer.h.1.attn.gating_factor\",\n            \"transformer.h.0.mlp.fc.adapter_bias\",\n            \"transformer.h.1.attn.proj.adapter_scale\",\n            \"transformer.h.0.attn.proj.adapter_bias\",\n            \"transformer.h.0.norm_1.bias\",\n            \"transformer.h.0.norm_2.bias\",\n            \"transformer.h.1.norm_1.weight\",\n            \"transformer.h.1.attn.qkv.adapter_scale\",\n        }\n    }\n\n    logs = stdout.getvalue()\n    assert \"of trainable parameters: 552\" in logs\n    assert \"of non-trainable parameters: 1,808\" in logs\n\n\ndef test_load_legacy_state_dict():\n    \"\"\"Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.\"\"\"\n    config = Config(\n        n_embd=32,\n        n_head=4,\n        head_size=8,\n        n_query_groups=4,\n        bias=True,\n    )\n\n    attention_1 = CausalSelfAttention(config=config, block_idx=0)\n\n    # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight`\n    # and make them interleaved\n    state_dict = deepcopy(attention_1.state_dict())\n    state_dict[\"attn.linear.weight\"] = make_qkv_interleaved(state_dict.pop(\"qkv.linear.weight\"), config)\n    state_dict[\"attn.linear.bias\"] = make_qkv_interleaved(state_dict.pop(\"qkv.linear.bias\"), config)\n\n    attention_2 = CausalSelfAttention(config=config, block_idx=0)\n    
attention_2.load_state_dict(state_dict)\n"
  },
  {
    "path": "tests/test_api.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\nimport re\nimport sys\nfrom collections import OrderedDict\nfrom pathlib import Path\nfrom unittest.mock import MagicMock, patch\n\nimport pytest\nimport torch\nfrom lightning.fabric.accelerators import CUDAAccelerator\n\nfrom litgpt.api import LLM, benchmark_dict_to_markdown_table, calculate_number_of_devices\nfrom litgpt.scripts.download import download_from_hub\nfrom litgpt.utils import _RunIf\n\nskip_in_ci_on_macos = pytest.mark.skipif(\n    sys.platform == \"darwin\" and os.getenv(\"GITHUB_ACTIONS\") == \"true\",\n    reason=\"Skipped on macOS in CI environment because CI machine does not have enough memory to run this test.\",\n)\n\n\nif sys.platform == \"darwin\" and os.getenv(\"GITHUB_ACTIONS\") == \"true\":\n    USE_MPS = False\nelif torch.backends.mps.is_available():\n    USE_MPS = True\nelse:\n    USE_MPS = False\n\n\n@pytest.fixture\ndef mock_llm():\n    llm = MagicMock(spec=LLM)\n    llm.model = MagicMock()\n    llm.preprocessor = MagicMock()\n    llm.prompt_style = MagicMock()\n    llm.checkpoint_dir = MagicMock()\n    llm.fabric = MagicMock()\n    return llm\n\n\ndef test_load_model(mock_llm):\n    assert isinstance(mock_llm, LLM)\n    assert mock_llm.model is not None\n    assert mock_llm.preprocessor is not None\n    assert mock_llm.prompt_style is not None\n    assert mock_llm.checkpoint_dir is not None\n    assert mock_llm.fabric is not None\n\n\ndef test_generate(mock_llm):\n    prompt = \"What do Llamas eat?\"\n    mock_llm.generate.return_value = prompt + \" Mock output\"\n    output = mock_llm.generate(prompt, max_new_tokens=10, temperature=0.8, top_k=5)\n    assert isinstance(output, str)\n    assert len(output) > len(prompt)\n\n\ndef test_stream_generate(mock_llm):\n    prompt = \"What do Llamas eat?\"\n\n    def iterator():\n        outputs = (prompt + \" Mock output\").split()\n        yield from outputs\n\n    
mock_llm.generate.return_value = iterator()\n    output = mock_llm.generate(prompt, max_new_tokens=10, temperature=0.8, top_k=5, stream=True)\n    result = \"\".join([out for out in output])\n    assert len(result) > len(prompt)\n\n\ndef test_generate_token_ids(mock_llm):\n    prompt = \"What do Llamas eat?\"\n    mock_output_ids = MagicMock(spec=torch.Tensor)\n    mock_output_ids.shape = [len(prompt) + 10]\n    mock_llm.generate.return_value = mock_output_ids\n    output_ids = mock_llm.generate(prompt, max_new_tokens=10, return_as_token_ids=True)\n    assert isinstance(output_ids, torch.Tensor)\n    assert output_ids.shape[0] > len(prompt)\n\n\ndef test_calculate_number_of_devices():\n    assert calculate_number_of_devices(1) == 1\n    assert calculate_number_of_devices([0, 1, 2]) == 3\n    assert calculate_number_of_devices(None) == 0\n\n\ndef test_llm_load_random_init(tmp_path):\n    download_from_hub(repo_id=\"EleutherAI/pythia-14m\", tokenizer_only=True, checkpoint_dir=tmp_path)\n\n    torch.manual_seed(123)\n    with patch(\"torch.backends.mps.is_available\", return_value=USE_MPS):\n        llm = LLM.load(model=\"pythia-160m\", init=\"random\", tokenizer_dir=Path(tmp_path / \"EleutherAI/pythia-14m\"))\n\n    input_text = \"some text text\"\n    output_text = llm.generate(input_text, max_new_tokens=15)\n    ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text))\n    assert ln <= 15\n\n    # The following below tests that generate works with different prompt lengths\n    # after the kv cache was set\n\n    input_text = \"some text\"\n    output_text = llm.generate(input_text, max_new_tokens=15)\n    ln = len(llm.preprocessor.tokenizer.encode(output_text)) - len(llm.preprocessor.tokenizer.encode(input_text))\n    assert ln <= 15\n\n    input_text = \"some text text text\"\n    output_text = llm.generate(input_text, max_new_tokens=15)\n    ln = len(llm.preprocessor.tokenizer.encode(output_text)) - 
len(llm.preprocessor.tokenizer.encode(input_text))\n    assert ln <= 15\n\n\ndef test_llm_load_hub_init(tmp_path):\n    torch.manual_seed(123)\n    with patch(\"torch.backends.mps.is_available\", return_value=USE_MPS):\n        llm = LLM.load(model=\"EleutherAI/pythia-14m\", init=\"pretrained\")\n\n    text_1 = llm.generate(\"text\", max_new_tokens=10, top_k=1)\n    assert len(text_1) > 0\n\n    text_2 = llm.generate(\"text\", max_new_tokens=10, top_k=1, stream=True)\n    text_2 = \"\".join(list(text_2))\n    assert text_1 == text_2, (text_1, text_2)\n\n\ndef test_model_not_initialized(tmp_path):\n    llm = LLM.load(model=\"EleutherAI/pythia-14m\", init=\"pretrained\", distribute=None)\n    s = \"The model is not initialized yet; use the .distribute() or .trainer_setup() method to initialize the model.\"\n    with pytest.raises(AttributeError, match=re.escape(s)):\n        llm.generate(\"text\")\n\n    llm = LLM.load(model=\"EleutherAI/pythia-14m\", tokenizer_dir=\"EleutherAI/pythia-14m\", init=\"random\", distribute=None)\n    s = \"The model is not initialized yet; use the .distribute() or .trainer_setup() method to initialize the model.\"\n    with pytest.raises(AttributeError, match=re.escape(s)):\n        llm.generate(\"text\")\n\n\n@_RunIf(min_cuda_gpus=2)\ndef test_more_than_1_device_for_sequential_gpu(tmp_path):\n    device_count = CUDAAccelerator.auto_device_count()\n\n    if device_count <= 2:\n        model_name = \"EleutherAI/pythia-14m\"\n    else:\n        model_name = \"EleutherAI/pythia-160m\"\n    with patch(\"torch.backends.mps.is_available\", return_value=USE_MPS):\n        llm = LLM.load(\n            model=model_name,\n        )\n\n    with pytest.raises(\n        NotImplementedError,\n        match=\"Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'.\",\n    ):\n        llm.distribute(devices=2)\n\n    llm.distribute(devices=2, generate_strategy=\"sequential\")\n    assert 
isinstance(llm.generate(\"What do llamas eat?\"), str)\n    assert str(llm.model.transformer.h[0].mlp.fc.weight.device) == \"cuda:0\"\n    last_layer_idx = len(llm.model.transformer.h) - 1\n    assert str(llm.model.transformer.h[last_layer_idx].mlp.fc.weight.device) == \"cuda:1\"\n\n    # Also check with default (devices=\"auto\") setting\n    llm.distribute(generate_strategy=\"sequential\")\n    assert isinstance(llm.generate(\"What do llamas eat?\"), str)\n    assert str(llm.model.transformer.h[0].mlp.fc.weight.device) == \"cuda:0\"\n    assert str(llm.model.transformer.h[last_layer_idx].mlp.fc.weight.device) == f\"cuda:{device_count - 1}\"\n\n\n@_RunIf(min_cuda_gpus=2)\n@pytest.mark.skipif(bool(os.getenv(\"SKIP_WITH_CI\")), reason=\"Skip this test in CI due to ...\")\ndef test_more_than_1_device_for_tensor_parallel_gpu(tmp_path):\n    with patch(\"torch.backends.mps.is_available\", return_value=USE_MPS):\n        llm = LLM.load(model=\"EleutherAI/pythia-14m\")\n\n    # this crashes the CI, maybe because of process forking; works fine locally though\n    llm.distribute(devices=2, generate_strategy=\"tensor_parallel\")\n    assert isinstance(llm.generate(\"What do llamas eat?\"), str)\n\n\n@_RunIf(min_cuda_gpus=1)\n@pytest.mark.parametrize(\"strategy\", (\"sequential\", \"tensor_parallel\"))\n@pytest.mark.xfail(\n    NotADirectoryError, reason=\"This test is expected to fail due to a NotADirectoryError.\", strict=False\n)\ndef test_sequential_tp_incompatibility_with_random_weights(strategy, tmp_path):\n    with patch(\"torch.backends.mps.is_available\", return_value=USE_MPS):\n        llm = LLM.load(model=\"EleutherAI/pythia-14m\", tokenizer_dir=\"EleutherAI/pythia-14m\", init=\"random\")\n    with pytest.raises(\n        NotImplementedError,\n        match=re.escape(\n            \"The LLM was initialized with init='random' but .distribute() currently only supports pretrained weights.\"\n        ),\n    ):\n        llm.distribute(devices=1, 
generate_strategy=strategy)\n\n\n@pytest.mark.parametrize(\"strategy\", (\"sequential\", \"tensor_parallel\"))\ndef test_sequential_tp_cpu(strategy, tmp_path):\n    with patch(\"torch.backends.mps.is_available\", return_value=USE_MPS):\n        llm = LLM.load(\n            model=\"EleutherAI/pythia-14m\",\n            distribute=None,\n        )\n    with pytest.raises(\n        NotImplementedError, match=f\"generate_strategy='{strategy}' is only supported for accelerator='cuda'|'gpu'.\"\n    ):\n        llm.distribute(devices=1, accelerator=\"cpu\", generate_strategy=strategy)\n\n\ndef test_initialization_for_trainer(tmp_path):\n    llm = LLM.load(model=\"EleutherAI/pythia-14m\", distribute=None)\n    s = \"The model is not initialized yet; use the .distribute() or .trainer_setup() method to initialize the model.\"\n    with pytest.raises(AttributeError, match=re.escape(s)):\n        llm.generate(\"hello world\")\n\n    llm.trainer_setup()\n    llm.model.to(llm.preprocessor.device)\n    assert isinstance(llm.generate(\"hello world\"), str)\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_quantization_is_applied(tmp_path):\n    with patch(\"torch.backends.mps.is_available\", return_value=USE_MPS):\n        llm = LLM.load(\n            model=\"EleutherAI/pythia-14m\",\n        )\n    llm.distribute(devices=1, quantize=\"bnb.nf4\", precision=\"bf16-true\")\n    strtype = str(type(llm.model.lm_head))\n    assert \"NF4Linear\" in strtype, strtype\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_fixed_kv_cache(tmp_path):\n    with patch(\"torch.backends.mps.is_available\", return_value=USE_MPS):\n        llm = LLM.load(\n            model=\"EleutherAI/pythia-14m\",\n        )\n    llm.distribute(devices=1, fixed_kv_cache_size=100)\n\n    # Request too many tokens\n    with pytest.raises(NotImplementedError, match=\"max_seq_length 512 needs to be >= 9223372036854775809\"):\n        _ = llm.generate(\"hello world\", max_new_tokens=2**63)\n\n\ndef test_invalid_accelerator(tmp_path):\n    
llm = LLM.load(model=\"EleutherAI/pythia-14m\", distribute=None)\n    with pytest.raises(ValueError, match=\"Invalid accelerator\"):\n        llm.distribute(accelerator=\"invalid\")\n\n\ndef test_returned_benchmark_dir(tmp_path):\n    with patch(\"torch.backends.mps.is_available\", return_value=USE_MPS):\n        llm = LLM.load(\n            model=\"EleutherAI/pythia-14m\",\n        )\n\n    text, bench_d = llm.benchmark(prompt=\"hello world\")\n    assert isinstance(bench_d[\"Inference speed in tokens/sec\"], list)\n    assert len(bench_d[\"Inference speed in tokens/sec\"]) == 1\n    assert isinstance(bench_d[\"Inference speed in tokens/sec\"][0], float)\n\n    text, bench_d = llm.benchmark(prompt=\"hello world\", stream=True)\n    assert isinstance(bench_d[\"Inference speed in tokens/sec\"], list)\n    assert len(bench_d[\"Inference speed in tokens/sec\"]) == 1\n    assert isinstance(bench_d[\"Inference speed in tokens/sec\"][0], float)\n\n    text, bench_d = llm.benchmark(num_iterations=10, prompt=\"hello world\", stream=True)\n    assert isinstance(bench_d[\"Inference speed in tokens/sec\"], list)\n    assert len(bench_d[\"Inference speed in tokens/sec\"]) == 10\n    assert isinstance(bench_d[\"Inference speed in tokens/sec\"][0], float)\n\n\ndef test_benchmark_dict_to_markdown_table_single_values():\n    bench_d = {\n        \"Inference speed in tokens/sec\": [17.617540650112936],\n        \"Seconds to first token\": [0.6533610639999097],\n        \"Seconds total\": [1.4758019020000575],\n        \"Tokens generated\": [26],\n        \"Total GPU memory allocated in GB\": [5.923729408],\n    }\n\n    expected_output = (\n        \"| Metric                              | Mean                        | Std Dev                     |\\n\"\n        \"|-------------------------------------|-----------------------------|-----------------------------|\\n\"\n        \"| Inference speed in tokens/sec       | 17.62                       | nan                         
|\\n\"\n        \"| Seconds to first token              | 0.65                        | nan                         |\\n\"\n        \"| Seconds total                       | 1.48                        | nan                         |\\n\"\n        \"| Tokens generated                    | 26.00                       | nan                         |\\n\"\n        \"| Total GPU memory allocated in GB    | 5.92                        | nan                         |\\n\"\n    )\n\n    assert benchmark_dict_to_markdown_table(bench_d) == expected_output\n\n\ndef test_benchmark_dict_to_markdown_table_multiple_values():\n    bench_d_list = {\n        \"Inference speed in tokens/sec\": [\n            17.034547562152305,\n            32.8974175404589,\n            33.04784205046782,\n            32.445697744648584,\n            33.204480197756396,\n            32.64187570945661,\n            33.21232058140845,\n            32.69377798373551,\n            32.92351459309756,\n            32.48909032591177,\n        ],\n        \"Seconds to first token\": [\n            0.7403525039999295,\n            0.022901020000063,\n            0.02335712100011733,\n            0.022969672000272112,\n            0.022788318000039,\n            0.02365505999978268,\n            0.02320190000000366,\n            0.022791139999753796,\n            0.022871761999795126,\n            0.023060415999680117,\n        ],\n        \"Seconds total\": [\n            1.5263099829999192,\n            0.7903355929997815,\n            0.7867382069998712,\n            0.8013389080001616,\n            0.7830268640000213,\n            0.7965228539997042,\n            0.7828420160003589,\n            0.7952583520000189,\n            0.7897091279996857,\n            0.8002686360000553,\n        ],\n        \"Tokens generated\": [26, 26, 26, 26, 26, 26, 26, 26, 26, 26],\n        \"Total GPU memory allocated in GB\": [\n            5.923729408,\n            5.923729408,\n            5.923729408,\n            
5.923729408,\n            5.923729408,\n            5.923729408,\n            5.923729408,\n            5.923729408,\n            5.923729408,\n            5.923729408,\n        ],\n    }\n\n    expected_output = (\n        \"| Metric                              | Mean                        | Std Dev                     |\\n\"\n        \"|-------------------------------------|-----------------------------|-----------------------------|\\n\"\n        \"| Inference speed in tokens/sec       | 31.26                       | 5.01                        |\\n\"\n        \"| Seconds to first token              | 0.09                        | 0.23                        |\\n\"\n        \"| Seconds total                       | 0.87                        | 0.23                        |\\n\"\n        \"| Tokens generated                    | 26.00                       | 0.00                        |\\n\"\n        \"| Total GPU memory allocated in GB    | 5.92                        | 0.00                        |\\n\"\n    )\n\n    assert benchmark_dict_to_markdown_table(bench_d_list) == expected_output\n\n\ndef test_state_dict(tmp_path):\n    with patch(\"torch.backends.mps.is_available\", return_value=USE_MPS):\n        llm = LLM.load(\n            model=\"EleutherAI/pythia-14m\",\n        )\n    assert isinstance(llm.state_dict(), OrderedDict)\n    assert llm.state_dict()[\"lm_head.weight\"].shape == torch.Size([50304, 128])\n\n\ndef test_save_method(tmp_path):\n    with patch(\"torch.backends.mps.is_available\", return_value=USE_MPS):\n        llm = LLM.load(\n            model=\"EleutherAI/pythia-14m\",\n        )\n\n    target_dir = \"saved_model\"\n    llm.save(target_dir)\n\n    expected_files = [\n        \"config.json\",\n        \"generation_config.json\",\n        \"lit_model.pth\",\n        \"model_config.yaml\",\n        \"prompt_style.yaml\",\n        \"tokenizer_config.json\",\n        \"tokenizer.json\",\n    ]\n\n    files_in_directory = 
os.listdir(target_dir)\n    for file_name in expected_files:\n        assert file_name in files_in_directory, f\"{file_name} is missing from {target_dir}\"\n\n\ndef test_forward_method(tmp_path):\n    with patch(\"torch.backends.mps.is_available\", return_value=USE_MPS):\n        llm = LLM.load(\n            model=\"EleutherAI/pythia-14m\",\n        )\n    inputs = torch.ones(6, 128, dtype=torch.int64).to(next(llm.model.parameters()).device)\n\n    assert llm(inputs).shape == torch.Size([6, 128, 50304])\n    logits, loss = llm(inputs, target_ids=inputs)\n    assert logits.shape == torch.Size([6, 128, 50304])\n    assert isinstance(loss.item(), float)\n\n\n@skip_in_ci_on_macos  # The macOS CI machine segfaults here (it works fine locally though)\ndef test_precision_selection(tmp_path):\n    llm = LLM.load(model=\"EleutherAI/pythia-14m\", init=\"pretrained\")\n\n    llm.distribute(precision=\"16-true\")\n    assert llm.model._forward_module.lm_head.weight.dtype == torch.float16, (\n        f\"Expected float16, but got {llm.model._forward_module.lm_head.weight.dtype}\"\n    )\n"
  },
  {
    "path": "tests/test_args.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport pytest\n\nfrom litgpt.args import TrainArgs\n\n\ndef test_compute_warmup_iters():\n    # warmup disabled\n    train = TrainArgs(lr_warmup_steps=0, lr_warmup_fraction=0)\n    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 0\n\n    # lr_warmup_steps and lr_warmup_fraction both are not allowed\n    with pytest.raises(ValueError, match=\"Can't provide both `--train.lr_warmup_fraction`\"):\n        TrainArgs(lr_warmup_steps=1, lr_warmup_fraction=0.2)\n\n    # lr_warmup_fraction invalid range\n    with pytest.raises(ValueError, match=\" must be between 0 and 1\"):\n        TrainArgs(lr_warmup_steps=0, lr_warmup_fraction=1.1)\n\n    # lr_warmup_steps\n    train = TrainArgs(global_batch_size=1, micro_batch_size=1, lr_warmup_steps=100, lr_warmup_fraction=0)\n    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 100\n    # lr_warmup_steps multiplied by accumulation factor\n    train.global_batch_size = 4\n    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 400\n    assert train.warmup_iters(devices=2, num_nodes=1, max_iters=1000, train_dataloader=range(10)) == 200\n    # lr_warmup_steps truncated by max iters\n    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=120, train_dataloader=range(10)) == 120\n\n    # lr_warmup_fraction\n    train = TrainArgs(global_batch_size=1, micro_batch_size=1, lr_warmup_steps=0, lr_warmup_fraction=0.3)\n    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, train_dataloader=range(100)) == 30\n    # lr_warmup_fraction truncated by max iters\n    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=20, train_dataloader=range(100)) == 20\n    # lr_warmup_fraction rounds up\n    assert train.warmup_iters(devices=1, num_nodes=1, max_iters=1000, 
train_dataloader=range(5)) == 2\n"
  },
  {
    "path": "tests/test_batch.py",
    "content": "import warnings\nfrom pathlib import Path\n\nimport lightning as L\nimport pytest\nimport torch\n\nimport litgpt\nfrom litgpt.api import GPT, LLM\nfrom litgpt.generate.base import (\n    batched_generate_fn,\n    batched_next_token,\n    generate_fn,\n    next_token,\n)\nfrom litgpt.scripts.download import download_from_hub\nfrom litgpt.utils import _RunIf\n\nwarnings.filterwarnings(\"ignore\")\n\n\ndef create_llm(tmp_path, batch_size, max_seq_length, device) -> tuple[LLM, GPT]:\n    L.seed_everything(42)\n\n    model_name = \"microsoft/phi-2\"\n    download_from_hub(repo_id=model_name, tokenizer_only=True, checkpoint_dir=tmp_path)\n\n    llm: LLM = LLM.load(\n        model_name,\n        tokenizer_dir=Path(tmp_path / model_name),\n        init=\"random\",\n    )\n    model: GPT = llm.model\n    model.set_kv_cache(batch_size=batch_size, max_seq_length=max_seq_length, device=device)\n\n    return llm, model\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"Test requires a GPU.\")\ndef test_batched_equivalence(tmp_path):\n    model_name = \"microsoft/phi-2\"\n    download_from_hub(repo_id=model_name, tokenizer_only=True, checkpoint_dir=tmp_path)\n\n    device = \"cuda:0\"\n    batch_size = 3\n    sample_kwargs = {\"top_k\": 1}\n\n    llm: LLM = LLM.load(\n        model_name,\n        tokenizer_dir=Path(tmp_path / model_name),\n        init=\"random\",\n    )\n    model: GPT = llm.model\n    model.set_kv_cache(batch_size=1, max_seq_length=50, device=device)\n\n    input_pos_1 = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device=device)\n    input_pos_2 = torch.tensor([10], dtype=torch.int64, device=device)\n\n    x = torch.tensor(\n        [43993, 25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410],\n        device=device,\n        dtype=torch.int64,\n    )\n\n    batch_x1 = torch.stack([x] * batch_size, dim=0)\n\n    # Single token generation baseline\n    tok_1 = next_token(model, input_pos_1, x.unsqueeze(0), 
**sample_kwargs)\n    tok_2 = next_token(model, input_pos_2, tok_1.unsqueeze(0), **sample_kwargs)\n\n    assert tok_1.ndim == 1\n    assert tok_2.ndim == 1\n    assert tok_1.size(0) == 1\n    assert tok_2.size(0) == 1\n\n    # Switch to batched generation\n    model.clear_kv_cache()\n    model.set_kv_cache(batch_size=batch_size, max_seq_length=50, device=\"cuda:0\")\n\n    toks_1: torch.Tensor = batched_next_token(model, input_pos_1, batch_x1, sample_kwargs)\n    toks_2: torch.Tensor = batched_next_token(model, input_pos_2, toks_1, sample_kwargs)\n\n    assert toks_1.ndim == 2\n    assert toks_2.ndim == 2\n    assert toks_1.size(0) == batch_size\n    assert toks_2.size(0) == batch_size\n\n    # Assert that single and batched next token generation are equivalent\n    assert all(t == tok_1 for t in toks_1), f\"{tok_1} != {toks_1}\"\n    assert all(t == tok_2 for t in toks_2), f\"{tok_2} != {toks_2}\"\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_simple_batch():\n    old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32\n    torch.backends.cuda.matmul.allow_tf32 = False\n    config = litgpt.Config.from_name(\"microsoft/phi-2\", padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=256)\n    with torch.device(\"cuda\"):\n        m = litgpt.GPT(config).requires_grad_(False).eval()\n        x0 = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 7]])\n        input_pos0 = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 2]])\n        x1 = torch.tensor([[1], [2]])\n        input_pos1 = torch.tensor([[4], [3]])\n\n    with torch.device(\"cuda\"):\n        m.set_kv_cache(2)\n    outs0 = m(x0, input_pos0)\n    outs1 = m(x1, input_pos1)\n\n    with torch.device(\"cuda\"):\n        m.set_kv_cache(1)\n\n    outs0_ref0 = m(x0[:1], input_pos0[0])\n    outs1_ref0 = m(x1[:1], input_pos1[0])\n\n    with torch.device(\"cuda\"):\n        m.set_kv_cache(1)\n\n    outs0_ref1 = m(x0[1:], input_pos0[1])\n    outs1_ref1 = m(x1[1:], input_pos1[1])\n\n    outs0_ref = torch.cat([outs0_ref0, outs0_ref1])\n    outs1_ref = 
torch.cat([outs1_ref0, outs1_ref1])\n\n    print(outs0_ref - outs0)\n    print(outs0.shape)\n    torch.testing.assert_close(outs0, outs0_ref)\n    torch.testing.assert_close(outs1, outs1_ref)\n    torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_batch_generate(tmp_path):\n    torch.use_deterministic_algorithms(True)\n\n    device = \"cuda:0\"\n    batch_size = 3\n    sample_kwargs = {\"top_k\": 1}\n    llm, model = create_llm(tmp_path, batch_size, 50, device)\n\n    batch_x = torch.tensor(\n        [\n            [43993, 25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410],\n            [25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596],\n            [1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596, 7596],\n        ],\n        device=device,\n        dtype=torch.int64,\n    )\n\n    # Generate tokens\n    tokens = []\n    for l in batched_generate_fn(\n        model,\n        prompts=batch_x,\n        max_returned_tokens=50,\n        sample_args=sample_kwargs,\n        include_prompt=True,\n        include_eos=False,\n    ):\n        tokens.append([t.item() if t is not None else None for t in l])\n\n    def find_unique_stop(triplets):\n        # Initialize a dictionary to count all number occurrences\n        number_count = {}\n\n        # Count occurrences of each number across all positions\n        for triplet in triplets:\n            for num in triplet:\n                number_count[num] = number_count.get(num, 0) + 1\n\n        # Initialize lists to store unique numbers for each position\n        unique_first = []\n        unique_second = []\n        unique_third = []\n\n        # Check each triplet\n        for a, b, c in triplets:\n            if number_count[a] == 1:\n                unique_first.append(a)\n            if number_count[b] == 1:\n                unique_second.append(b)\n            if number_count[c] == 1:\n                unique_third.append(c)\n\n        import random  # Seeded 
earlier\n\n        random.shuffle(unique_first)\n        random.shuffle(unique_second)\n        random.shuffle(unique_third)\n        return [unique_first[0], unique_second[0], unique_third[0]]\n\n    # Now that we know the randomly generated tokens, sample some tokens to stop each stream at.\n    stops = find_unique_stop(tokens[batch_x.size(1) :])\n    first_stream = [t[0] for t in tokens if t[0] is not None]\n    second_stream = [t[1] for t in tokens if t[1] is not None]\n    third_stream = [t[2] for t in tokens if t[2] is not None]\n\n    # Let's slice the streams at the stop tokens.\n    stop_idxes = [\n        first_stream.index(stops[0]),\n        second_stream.index(stops[1]),\n        third_stream.index(stops[2]),\n    ]\n\n    # While we're at it, grab the last token that would be generated before stopping.\n    last_tokens = [\n        first_stream[stop_idxes[0] - 1],\n        second_stream[stop_idxes[1] - 1],\n        third_stream[stop_idxes[2] - 1],\n    ]\n\n    for t in tokens:\n        print(t)\n\n    # Now we generate again, stopping early at the stop tokens.\n    tokens = []\n    for l in batched_generate_fn(\n        model,\n        prompts=batch_x,\n        max_returned_tokens=50,\n        stop_tokens=[(s,) for s in stops],\n        sample_args=sample_kwargs,\n        include_prompt=True,\n        include_eos=False,\n    ):\n        tokens.append([t.item() if t is not None else None for t in l])\n\n    # Finally, assert that the streams are correct.\n\n    first_stream = [t[0] for t in tokens if t[0] is not None]\n    print(first_stream)\n    print(len(first_stream), stop_idxes[0])\n    assert len(first_stream) == stop_idxes[0]\n    assert first_stream[-1] == last_tokens[0]\n\n    second_stream = [t[1] for t in tokens if t[1] is not None]\n    print(second_stream)\n    print(len(second_stream), stop_idxes[1])\n    assert len(second_stream) == stop_idxes[1]\n    assert second_stream[-1] == last_tokens[1]\n\n    third_stream = [t[2] for t in tokens 
if t[2] is not None]\n    print(third_stream)\n    print(len(third_stream), stop_idxes[2])\n    assert len(third_stream) == stop_idxes[2]\n    assert third_stream[-1] == last_tokens[2]\n\n    torch.use_deterministic_algorithms(False)\n\n    # for t in llm.tokenizer.decode_stream([torch.tensor(i) for i in first_stream]):\n    #    print(t, end=\"\", flush=True)\n    # print()\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_batch_generate_equivalence(tmp_path):\n    torch.use_deterministic_algorithms(True)\n\n    device = \"cuda:0\"\n    batch_size = 3\n    sample_kwargs = {\"top_k\": 1}\n    llm, model = create_llm(tmp_path, batch_size, 50, device)\n\n    batch_x = torch.tensor(\n        [\n            [43993, 25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410],\n            [25, 1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596],\n            [1867, 466, 32660, 17485, 4483, 30, 198, 26410, 7596, 7596],\n        ],\n        device=device,\n        dtype=torch.int64,\n    )\n\n    # The other test tests the stop_tokens functionality much more exhaustively, we'll just generate and compare 50 tokens here.\n\n    batch_tokens = []\n    for l in batched_generate_fn(\n        model,\n        prompts=batch_x,\n        max_returned_tokens=50,\n        sample_args=sample_kwargs,\n        include_prompt=False,\n        include_eos=False,\n    ):\n        batch_tokens.append([t.item() if t is not None else None for t in l])\n\n    first_stream = [t[0] for t in batch_tokens if t[0] is not None]\n\n    batch_size = 1\n    llm, model = create_llm(tmp_path, batch_size, 50, device)\n\n    tokens = []\n    for t in generate_fn(\n        model,\n        prompt=batch_x[0],\n        max_returned_tokens=50,\n        include_prompt=False,\n        include_eos=False,\n        **sample_kwargs,\n    ):\n        if t.size(0) == 1:\n            tokens.append(t.item())\n        else:\n            tokens.extend(t.tolist())\n\n    torch.use_deterministic_algorithms(False)\n\n    # TODO: (apaz-cli) This 
consistency test doesn't actually work at the moment. It's inconsistent.\n    # The output is really close... Something is going on here. For the moment, maybe this is close enough?\n    # Enough at least that we can start prototyping.\n\n    print(first_stream)\n    print(tokens)\n    # assert first_stream == tokens\n"
  },
  {
    "path": "tests/test_chat.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport os\nimport re\nimport subprocess\nimport sys\nfrom contextlib import redirect_stderr, redirect_stdout\nfrom io import StringIO\nfrom itertools import repeat\nfrom pathlib import Path\nfrom typing import Iterable, Iterator\nfrom unittest.mock import MagicMock, Mock, patch\n\nimport pytest\nimport torch\nimport yaml\n\nimport litgpt.chat.base as chat\nimport litgpt.generate.base as generate\nfrom litgpt import Config, Tokenizer\nfrom litgpt.utils import auto_download_checkpoint, save_config\n\nskip_in_ci_on_macos = pytest.mark.skipif(\n    sys.platform == \"darwin\" and os.getenv(\"GITHUB_ACTIONS\") == \"true\",\n    reason=\"Skipped on macOS in CI environment because CI machine does not have enough memory to run this test.\",\n)\n\n\n@pytest.mark.parametrize(\n    (\"generated\", \"stop_tokens\", \"expected\"),\n    [\n        (repeat(1), (), [1] * 8),\n        ([1, 2, 3, 0], ([0],), [1, 2, 3]),\n        ([1, 2, 3, 0], ([9], [2, 4], [1, 2, 3, 0]), []),\n        ([1, 2, 3, 0, 0], ([0, 0, 0], [0, 0]), [1, 2, 3]),\n        ([3, 1, 2], ([1, 2], [3]), []),\n        ([1, 2, 3, 0, 3, 2, 1, 0], ([4, 3, 2, 1], [2, 4]), [1, 2, 3, 0, 3, 2, 1, 0]),\n    ],\n)\ndef test_generate(monkeypatch, generated, stop_tokens, expected):\n    import lightning as L\n\n    L.seed_everything(1234)\n\n    input_idx = torch.tensor([5, 3])\n    max_returned_tokens = len(input_idx) + 8\n    model = MagicMock()\n    model.config.block_size = 100\n    model.max_seq_length = 100\n    it = iter(generated)\n\n    def multinomial(*_, **__):\n        out = next(it)\n        return torch.tensor([out])\n\n    monkeypatch.setattr(generate, \"multinomial_num_samples_1\", multinomial)\n    actual = chat.generate(model, input_idx, max_returned_tokens, stop_tokens=stop_tokens)\n    actual = list(actual)\n\n    assert len(actual) == len(expected), (actual, expected)\n    if not actual:\n        assert 
actual == expected, (actual, expected)\n    else:\n        for t in actual:\n            assert t.dtype == torch.long, t.dtype\n        actual_list = torch.cat(actual).tolist()\n        assert actual_list == expected, (actual_list, expected)\n\n\ndef test_decode():\n    checkpoint_dir = auto_download_checkpoint(\"EleutherAI/pythia-14m\")\n    tokenizer = Tokenizer(checkpoint_dir)\n\n    text = (\n        \"Hello World! This a bunch of text. Lorem ipsum dolor sit amet, \"\n        \"consectetur adipiscing elit, sed do eiusmod tempor incididunt \"\n        \"ut labore et dolore magna aliqua.\"\n    )\n\n    encoded: torch.Tensor = tokenizer.encode(text)\n    encoded_stream: Iterable[torch.Tensor] = torch.tensor_split(encoded, encoded.shape[0], dim=0)\n\n    decoded_stream: Iterator[str] = tokenizer.decode_stream(encoded_stream)\n    decoded: str = \"\".join(decoded_stream)\n\n    # Note that encoded and decoded text will not always be character for character identical.\n    # Indeed, sometimes it is not. 
But that tends to be because of special cases, and this is not\n    # one of those.\n    assert text == decoded, (text, decoded)\n\n\n@skip_in_ci_on_macos\n@patch(\"litgpt.chat.base.input\")\n@pytest.mark.parametrize(\"stop_iteration\", [KeyboardInterrupt, \"\"])\ndef test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, tensor_like):\n    # these values will be iteratively provided for each `input()` call\n    mocked_input.side_effect = [\"Hello\", stop_iteration]\n\n    config_path = fake_checkpoint_dir / \"model_config.yaml\"\n    config = {\n        \"name\": \"Llama 3\",\n        \"block_size\": 128,\n        \"vocab_size\": 50,\n        \"n_layer\": 2,\n        \"n_head\": 4,\n        \"n_embd\": 8,\n        \"rotary_percentage\": 1,\n    }\n    config_path.write_text(yaml.dump(config))\n\n    load_mock = Mock()\n    load_mock.return_value = load_mock\n    monkeypatch.setattr(chat, \"load_checkpoint\", load_mock)\n    tokenizer_mock = Mock()\n    tokenizer_mock.return_value.backend = \"sentencepiece\"\n    tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3])\n    tokenizer_mock.return_value.decode_stream.return_value = \"foo bar baz\"\n    monkeypatch.setattr(chat, \"Tokenizer\", tokenizer_mock)\n    generate_mock = MagicMock()\n    generate_mock.__iter__.return_value = [torch.tensor([3, 2, 1])]\n    monkeypatch.setattr(chat, \"generate\", generate_mock)\n\n    out, err = StringIO(), StringIO()\n    with redirect_stdout(out), redirect_stderr(err):\n        chat.main(temperature=2.0, max_new_tokens=10, top_k=2, top_p=0.9, checkpoint_dir=fake_checkpoint_dir)\n\n    # decoding is done per each generated item\n    assert len(tokenizer_mock.return_value.decode_stream.mock_calls) == 1\n    assert tokenizer_mock.return_value.decode_stream.call_args[0][0] is generate_mock.return_value  # Now a Mock\n\n    # Assert that the generated result is printed to stdout\n    assert re.match(r\".*Now chatting with Llama 3.*>> .*Reply: 
foo bar baz\", out.getvalue(), re.DOTALL), out.getvalue()\n\n\ndef test_cli():\n    args = [\"litgpt\", \"chat\", \"-h\"]\n    output = subprocess.check_output(args)\n    output = str(output.decode())\n    assert \"Chat with a model\" in output\n\n\n@skip_in_ci_on_macos\n@patch(\"litgpt.chat.base.input\")\n@patch(\"litgpt.chat.base.merge_lora\")\ndef test_merge_lora_if_needed(mocked_merge_lora, mocked_input, fake_checkpoint_dir, monkeypatch, tensor_like):\n    # these values will be iteratively provided for each `input()` call\n    mocked_input.side_effect = [\"\"]\n\n    # pretend there is an unmerged LORA checkpoint\n    os.rename(fake_checkpoint_dir / \"lit_model.pth\", fake_checkpoint_dir / \"lit_model.pth.lora\")\n    mocked_merge_lora.side_effect = lambda _: Path(fake_checkpoint_dir / \"lit_model.pth\").touch()\n\n    config = Config.from_name(\"pythia-14m\")\n    save_config(config, fake_checkpoint_dir)\n    monkeypatch.setattr(chat, \"load_checkpoint\", Mock())\n    monkeypatch.setattr(chat, \"Tokenizer\", Mock())\n\n    out, err = StringIO(), StringIO()\n    with redirect_stdout(out), redirect_stderr(err):\n        chat.main(checkpoint_dir=fake_checkpoint_dir)\n\n    assert re.match(r\".*Merging LoRA weights with the base model\\..*\", out.getvalue(), re.DOTALL)\n    mocked_merge_lora.assert_called_once()\n\n\n@skip_in_ci_on_macos\ndef test_litgpt_chat_endtoend():\n    from litgpt.chat.base import main\n\n    checkpoint_dir = auto_download_checkpoint(\"EleutherAI/pythia-14m\")\n\n    # Patch input() and redirect stdout. 
Raise to exit the repl.\n    simulated_input = Mock(side_effect=[\"input\", KeyboardInterrupt])\n    captured_output = StringIO()\n    with patch(\"builtins.input\", simulated_input):\n        with redirect_stdout(captured_output):\n            try:\n                main(checkpoint_dir=checkpoint_dir, max_new_tokens=256, top_k=1)\n            except KeyboardInterrupt:\n                pass\n\n    # pythia-14m is not instruct-tuned, so it does not give an \"answer\" per se, but a continuation.\n    output = captured_output.getvalue()\n    assert \">> Reply: \" in output, f\"Expected reply not found. Got:\\n{output}\"\n    # Verify the model actually generated some text after the reply prompt\n    reply_start = output.index(\">> Reply: \") + len(\">> Reply: \")\n    assert len(output[reply_start:].strip()) > 0, f\"Expected non-empty reply. Got:\\n{output}\"\n    assert simulated_input.call_count == 2\n\n\n@skip_in_ci_on_macos\ndef test_litgpt_generate_endtoend():\n    from litgpt.generate.base import main\n\n    checkpoint_dir = auto_download_checkpoint(\"EleutherAI/pythia-14m\")\n\n    captured_output = StringIO()\n    with redirect_stdout(captured_output):\n        try:\n            main(checkpoint_dir=checkpoint_dir, prompt=\"Hello World\", max_new_tokens=256, top_k=1)\n        except KeyboardInterrupt:\n            pass\n\n    # pythia-14m is not instruct-tuned, so it does not give an \"answer\" per se, but a continuation.\n    assert \"Hello World!\" in captured_output.getvalue(), (\n        f\"Expected output not found. Got:\\n{captured_output.getvalue()}\"\n    )\n"
  },
  {
    "path": "tests/test_ci.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nfrom lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE\n\nfrom litgpt.utils import _RunIf\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_gpu_ci_installs_bitsandbytes():\n    assert _BITSANDBYTES_AVAILABLE, str(_BITSANDBYTES_AVAILABLE)\n"
  },
  {
    "path": "tests/test_cli.py",
    "content": "import sys\nfrom contextlib import redirect_stdout\nfrom io import StringIO\nfrom unittest import mock\n\nimport pytest\nfrom packaging.version import Version\n\nfrom litgpt.__main__ import main\n\n\ndef test_cli():\n    out = StringIO()\n    with pytest.raises(SystemExit), redirect_stdout(out), mock.patch(\"sys.argv\", [\"litgpt\", \"-h\"]):\n        main()\n    out = out.getvalue()\n    assert \"usage: litgpt\" in out\n    assert (\n        \"{download,chat,finetune,finetune_lora,finetune_full,finetune_adapter,finetune_adapter_v2,\"\n        \"pretrain,generate,generate_full,generate_adapter,generate_adapter_v2,generate_sequentially,\"\n        \"generate_speculatively,generate_tp,convert_to_litgpt,convert_from_litgpt,convert_pretrained_checkpoint,\"\n        \"merge_lora,evaluate,serve}\" in out\n    )\n    assert (\n        \"\"\"Available subcommands:\n    download            Download weights or tokenizer data from the Hugging\n                        Face Hub.\n    chat                Chat with a model.\"\"\"\n        in out\n    )\n    assert \"\"\"evaluate            Evaluate a model with the LM Evaluation Harness.\"\"\" in out\n    assert \"\"\"serve               Serve a LitGPT model using LitServe.\"\"\" in out\n    out = StringIO()\n    with pytest.raises(SystemExit), redirect_stdout(out), mock.patch(\"sys.argv\", [\"litgpt\", \"finetune_lora\", \"-h\"]):\n        main()\n    out = out.getvalue()\n    assert (\n        \"\"\"--lora_alpha LORA_ALPHA\n                        The LoRA alpha. 
(type: int, default: 16)\"\"\"\n        in out\n    )\n\n    if Version(f\"{sys.version_info.major}.{sys.version_info.minor}\") < Version(\"3.9\"):\n        # python 3.8 prints `Union[int, null]` instead of `Optional[int]`\n        return\n\n    out = StringIO()\n    with pytest.raises(SystemExit), redirect_stdout(out), mock.patch(\"sys.argv\", [\"litgpt\", \"pretrain\", \"-h\"]):\n        main()\n    out = out.getvalue()\n    print(out)\n    assert (\n        \"\"\"--train.max_tokens MAX_TOKENS\n                        Total number of tokens to train on (type:\n                        Optional[int], default: 3000000000000)\"\"\"\n        in out\n    )\n\n\ndef test_pretrain_allows_max_steps():\n    # Ensure --train.max_steps is accepted by the CLI for pretrain\n    # and only emits a warning instead of raising a validation error.\n    args = [\n        \"litgpt\",\n        \"pretrain\",\n        \"pythia-14m\",\n        \"--train.max_steps=1\",\n        \"--out_dir=out/test-cli\",\n    ]\n\n    with pytest.warns(UserWarning, match=\"max_steps\"):\n        try:\n            with mock.patch(\"sys.argv\", args):\n                main()\n        except Exception:\n            pass\n\n\ndef test_rewrite_finetune_command():\n    out1 = StringIO()\n    with pytest.raises(SystemExit), redirect_stdout(out1), mock.patch(\"sys.argv\", [\"litgpt\", \"fineune\", \"-h\"]):\n        main()\n    out2 = StringIO()\n    with pytest.raises(SystemExit), redirect_stdout(out2), mock.patch(\"sys.argv\", [\"litgpt\", \"fineune_lora\", \"-h\"]):\n        main()\n    assert out1.getvalue() == out2.getvalue()\n"
  },
  {
    "path": "tests/test_config.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport pytest\nimport yaml\n\nimport litgpt.config as config_module\nfrom litgpt import Config\nfrom litgpt.config import find_multiple\n\n\ndef test_config():\n    config = Config()\n    assert config.name == \"\"\n    assert config.block_size == 4096\n\n    config = Config(block_size=2048)\n    assert config.block_size == 2048\n\n    config = Config.from_name(\"pythia-14m\")\n    assert config.block_size == 512\n\n    config = Config.from_name(\"pythia-14m\", block_size=4096)\n    assert config.block_size == 4096\n\n    config = Config(hf_config={\"name\": \"pythia-14m\"})\n    assert config.name == \"pythia-14m\"\n\n\ndef test_from_hf_name():\n    # by short-hand name\n    config0 = Config.from_name(\"tiny-llama-1.1b\")\n    # or by huggingface hub repo name\n    config1 = Config.from_name(\"TinyLlama-1.1B-intermediate-step-1431k-3T\")\n    assert config0 is not None\n    assert config1 is not None\n    assert config0 == config1\n\n\ndef test_nonexisting_name():\n    with pytest.raises(ValueError, match=\"'invalid-model-name' is not a supported config name\"):\n        Config.from_name(\"invalid-model-name\")\n\n\n@pytest.mark.parametrize(\"config\", config_module.configs, ids=[c[\"name\"] for c in config_module.configs])\ndef test_short_and_hf_names_are_equal_unless_on_purpose(config):\n    # by short-hand name\n    config0 = Config.from_name(config[\"name\"])\n    # or by huggingface hub repo name\n    config1 = Config.from_name(config[\"hf_config\"][\"name\"])\n    assert config0.name == config1.name\n\n\ndef test_from_hf_name_with_org_string():\n    # Test case 1: valid input\n    config0 = Config.from_name(\"tiny-llama-1.1b\")\n    config1 = Config.from_name(\"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\")\n    assert config0 is not None\n    assert config1 is not None\n    assert config0 == config1\n\n    # Test case 2: invalid input - org not 
found\n    with pytest.raises(\n        ValueError, match=\"'UnknownOrg/TinyLlama-1.1B-intermediate-step-1431k-3T' is not a supported config name\"\n    ):\n        Config.from_name(\"UnknownOrg/TinyLlama-1.1B-intermediate-step-1431k-3T\")\n\n    # Test case 3: invalid input - name not found\n    with pytest.raises(ValueError, match=\"'TinyLlama/TinyLlama-XYZ' is not a supported config name\"):\n        Config.from_name(\"TinyLlama/TinyLlama-XYZ\")\n\n\ndef test_from_checkpoint(tmp_path):\n    # 1. Neither `model_config.yaml` nor matching config exists.\n    with pytest.raises(FileNotFoundError, match=\"neither 'model_config.yaml' nor matching config exists\"):\n        Config.from_checkpoint(tmp_path / \"non_existing_checkpoint\")\n\n    # 2. If `model_config.yaml` doesn't exist, but there is a matching config in `litgpt/config.py`.\n    config = Config.from_checkpoint(tmp_path / \"pythia-14m\")\n    assert config.name == \"pythia-14m\"\n    assert config.block_size == 512\n    assert config.n_layer == 6\n\n    # 3. If only `model_config.yaml` exists.\n    config_data = {\"name\": \"pythia-14m\", \"block_size\": 24, \"n_layer\": 2}\n    with open(tmp_path / \"model_config.yaml\", \"w\", encoding=\"utf-8\") as file:\n        yaml.dump(config_data, file)\n    config = Config.from_checkpoint(tmp_path)\n    assert config.name == \"pythia-14m\"\n    assert config.block_size == 24\n    assert config.n_layer == 2\n\n    # 4. Both `model_config.yaml` and a matching config exist, but `model_config.yaml` supersedes matching config\n    (tmp_path / \"pythia-14m\").mkdir()\n    with open(tmp_path / \"pythia-14m/model_config.yaml\", \"w\", encoding=\"utf-8\") as file:\n        yaml.dump(config_data, file)\n    config = Config.from_checkpoint(tmp_path / \"pythia-14m\")\n    assert config.name == \"pythia-14m\"\n    assert config.block_size == 24\n    assert config.n_layer == 2\n\n\n@pytest.mark.parametrize(\"head_size\", [None, 128])\ndef test_head_size(head_size):\n    config = Config(head_size)\n\n    assert config.head_size == head_size or config.n_embd // config.n_head\n\n\ndef test_find_multiple():\n    assert find_multiple(17, 5) == 20\n    assert find_multiple(30, 7) == 35\n    assert find_multiple(10, 2) == 10\n    assert find_multiple(5, 10) == 10\n    assert find_multiple(50254, 128) == 50304\n    assert find_multiple(50254, 256) == 50432\n    assert find_multiple(50254, 512) == 50688\n"
  },
  {
    "path": "tests/test_config_hub.py",
    "content": "import importlib\nimport importlib.util\nfrom pathlib import Path\nfrom unittest import mock\nfrom unittest.mock import Mock\n\nimport pytest\nfrom lightning.fabric.plugins import Precision\n\nfrom litgpt import Config\nfrom litgpt.utils import CLI\n\nfixed_pairs = [\n    (\"litgpt/pretrain.py\", \"pretrain/debug.yaml\"),\n    (\"litgpt/pretrain.py\", \"pretrain/tinyllama.yaml\"),\n    (\"litgpt/pretrain.py\", \"pretrain/tinystories.yaml\"),\n    (\n        \"litgpt/pretrain.py\",\n        \"https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/pretrain/tinystories.yaml\",\n    ),\n]\n\nconfig_hub_path = Path(__file__).parent.parent / \"config_hub\" / \"finetune\"\nmodel_pairs = []\n\nfor model_dir in config_hub_path.iterdir():\n    if model_dir.is_dir():\n        model_name = model_dir.name\n        for yaml_file in model_dir.glob(\"*.yaml\"):\n            config_name = yaml_file.stem\n            python_file = \"litgpt/finetune/full.py\" if config_name == \"full\" else \"litgpt/finetune/lora.py\"\n            relative_yaml_path = yaml_file.relative_to(config_hub_path.parent)\n            model_pairs.append((python_file, str(relative_yaml_path)))\n\nall_pairs = fixed_pairs + model_pairs\n\n\n@pytest.mark.parametrize((\"script_file\", \"config_file\"), all_pairs)\ndef test_config_help(script_file, config_file, monkeypatch):\n    \"\"\"Test that configs validate against the signature in the scripts.\"\"\"\n    script_file = Path(__file__).parent.parent / script_file\n    assert script_file.is_file()\n    if \"http\" not in str(config_file):\n        config_file = Path(__file__).parent.parent / \"config_hub\" / config_file\n        assert config_file.is_file()\n\n    spec = importlib.util.spec_from_file_location(str(script_file.parent.name), script_file)\n    module = importlib.util.module_from_spec(spec)\n    spec.loader.exec_module(module)\n\n    monkeypatch.setattr(module, \"main\", Mock())\n    monkeypatch.setattr(module, 
\"Tokenizer\", Mock())\n    monkeypatch.setattr(module, \"BitsandbytesPrecision\", Mock(return_value=Precision()), raising=False)\n    monkeypatch.setattr(module, \"Config\", Mock(return_value=Config.from_name(\"pythia-14m\")))\n    monkeypatch.setattr(module, \"check_valid_checkpoint_dir\", Mock(), raising=False)\n\n    try:\n        with mock.patch(\"sys.argv\", [script_file.name, \"--config\", str(config_file), \"--devices\", \"1\"]):\n            CLI(module.setup)\n            module.main.assert_called_once()\n    except FileNotFoundError:\n        pass\n        # FileNotFound occurs here because we have not downloaded the model weights referenced in the config files\n        # which is ok because here we just want to validate the config file itself.\n"
  },
  {
    "path": "tests/test_deepseek_moe.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport pytest\nimport torch\nfrom transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM\n\nfrom litgpt import Config\nfrom litgpt.model import GPT, LLaMAMLP\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"batch_size\", (1, 2))\n@pytest.mark.parametrize(\"seq_len\", (8, 16))\n@pytest.mark.parametrize(\"device\", [torch.device(\"cpu\")])\ndef test_deepseek_moe_litgpt_vs_hf(batch_size, seq_len, device):\n    \"\"\"Test MOE litgpt vs hf\"\"\"\n    config_litgpt = Config(\n        padded_vocab_size=10000,\n        n_layer=2,\n        vocab_size=10000,\n        n_embd=64,\n        n_head=4,\n        n_query_groups=4,\n        head_size=16,\n        norm_eps=1e-6,\n        bias=False,\n        latent_attention={\n            \"q_lora_rank\": 32,\n            \"kv_lora_rank\": 16,\n            \"qk_rope_head_dim\": 8,\n            \"qk_nope_head_dim\": 8,\n            \"v_head_dim\": 16,\n        },\n        n_expert=16,\n        n_shared_expert=1,\n        n_expert_per_token=2,\n        n_expert_groups=4,\n        n_topk_groups=2,\n        n_topk_scores_per_group=2,  # Note: Deepseek hardcodes this to `2`\n        first_k_dense_replace=1,\n        routed_scaling_factor=2.5,\n        norm_topk_prob=True,\n        moe_intermediate_size=20,\n        mlp_class_name=\"LLaMAMoE\",\n    )\n\n    config_hf = DeepseekV3Config(\n        padded_vocab_size=10000,\n        num_hidden_layers=2,\n        vocab_size=10000,\n        hidden_size=64,\n        num_attention_heads=4,\n        num_key_value_heads=4,\n        q_lora_rank=32,\n        kv_lora_rank=16,\n        qk_rope_head_dim=8,\n        qk_nope_head_dim=8,\n        v_head_dim=16,\n        rope_interleave=False,\n        first_k_dense_replace=1,\n        routed_scaling_factor=2.5,\n        norm_topk_prob=True,\n        n_routed_experts=config_litgpt.n_expert,\n        
n_shared_experts=config_litgpt.n_shared_expert,\n        num_experts_per_tok=config_litgpt.n_expert_per_token,\n        n_group=config_litgpt.n_expert_groups,\n        topk_group=config_litgpt.n_topk_groups,\n        moe_intermediate_size=config_litgpt.moe_intermediate_size,\n    )\n\n    model_litgpt = GPT(config_litgpt).to(device)\n    model_litgpt.apply(model_litgpt._init_weights)\n\n    mlp_litgpt = model_litgpt.transformer.h[0].mlp\n    assert isinstance(mlp_litgpt, LLaMAMLP)  # Test first_k_dense_replace (k=1)\n\n    moe_litgpt = model_litgpt.transformer.h[1].mlp\n    model_hf = DeepseekV3ForCausalLM(config_hf).to(device)\n    moe_hf = model_hf.model.layers[1].mlp\n\n    moe_litgpt.eval()\n    moe_hf.eval()\n\n    sync_weights(moe_litgpt, moe_hf)\n\n    hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device)\n\n    output_litgpt = moe_litgpt(hidden_states)\n    output_hf = moe_hf(hidden_states)\n\n    assert torch.allclose(output_litgpt, output_hf, atol=1e-5)\n\n\ndef sync_weights(litgpt_model, hf_model):\n    print(\"Synchronizing MoE weights...\")\n\n    with torch.no_grad():\n        if hasattr(litgpt_model, \"gate\"):\n            if hasattr(litgpt_model.gate, \"weight\"):\n                hf_model.gate.weight.copy_(litgpt_model.gate.weight)\n            if hasattr(litgpt_model.gate, \"e_score_correction_bias\"):\n                hf_model.gate.e_score_correction_bias.copy_(litgpt_model.gate.e_score_correction_bias)\n\n        for i, (litgpt_expert, hf_expert) in enumerate(zip(litgpt_model.experts, hf_model.experts)):\n            hf_expert.gate_proj.weight.copy_(litgpt_expert.fc_1.weight)\n            hf_expert.up_proj.weight.copy_(litgpt_expert.fc_2.weight)\n            hf_expert.down_proj.weight.copy_(litgpt_expert.proj.weight)\n\n        if hasattr(litgpt_model, \"shared_experts\") and hasattr(hf_model, \"shared_experts\"):\n            
hf_model.shared_experts.gate_proj.weight.copy_(litgpt_model.shared_experts.fc_1.weight)\n            hf_model.shared_experts.up_proj.weight.copy_(litgpt_model.shared_experts.fc_2.weight)\n            hf_model.shared_experts.down_proj.weight.copy_(litgpt_model.shared_experts.proj.weight)\n\n    print(\"MoE weight synchronization complete.\")\n"
  },
  {
    "path": "tests/test_distributed.py",
    "content": "import pytest\nimport torch\nfrom lightning import Fabric\n\nfrom litgpt.utils import _RunIf\n\n\n@_RunIf(min_cuda_gpus=2, standalone=True)\n@pytest.mark.parametrize(\"strategy\", [\"ddp\", \"fsdp\"])\ndef test_no_backward_sync(strategy):\n    fabric = Fabric(devices=2, accelerator=\"cuda\", strategy=strategy)\n    fabric.launch()\n\n    # account for sharding in the case of FSDP\n    out_features = 1 if \"ddp\" in strategy else fabric.world_size\n\n    model = torch.nn.Linear(1, out_features, bias=False, device=fabric.device)\n    x = torch.randn(1, 1, device=fabric.device)\n    model = fabric.setup(model)\n\n    # 6 iters, 3 grad accumulation iters\n    for i, enabled in enumerate((True, True, False, True, True, False), 1):\n        x = torch.tensor([i * (fabric.local_rank + 1)], device=fabric.device, dtype=torch.float32)\n\n        with fabric.no_backward_sync(model, enabled):\n            y = model(x)\n            fabric.backward(y.sum())\n        if not enabled:\n            # Math for the first 3 iters\n            #\n            # DistributedDataParallel\n            # (1*1+2*1+3*1 + 1*2+2*2+3*2) / 2       = 9\n            #  ^^^^^^^^^^^   ^^^^^^^^^^^  ^^^\n            #  rank0         rank1        allreduce\n            #\n            # thunder.distributed.ddp\n            # ((1*1+2*1) + (1*2+2*2)) / 2        + (3*1 + 3*2)  / 2        = 9\n            #   ^^^^^^^     ^^^^^^^   ^^^           ^^^   ^^^   ^^^\n            #   rank0       rank1     allreduce1    rank0 rank1 allreduce2\n            assert model.weight.grad.shape.numel() == 1, model.weight.grad.shape\n            assert model.weight.grad.item() == (9.0 if i == 3 else 22.5)\n            model.weight.grad = None\n"
  },
  {
    "path": "tests/test_evaluate.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport subprocess\nfrom contextlib import redirect_stdout\nfrom dataclasses import asdict\nfrom io import StringIO\nfrom unittest import mock\n\nimport pytest\nimport torch\nimport yaml\n\nimport litgpt.eval.evaluate as module\nfrom litgpt import GPT, Config\nfrom litgpt.scripts.download import download_from_hub\n\n\n@pytest.mark.flaky(reruns=3)\ndef test_evaluate_script(tmp_path):\n    ours_config = Config.from_name(\"pythia-14m\")\n    download_from_hub(repo_id=\"EleutherAI/pythia-14m\", tokenizer_only=True, checkpoint_dir=tmp_path)\n    checkpoint_dir = tmp_path / \"EleutherAI\" / \"pythia-14m\"\n    ours_model = GPT(ours_config)\n    torch.save(ours_model.state_dict(), checkpoint_dir / \"lit_model.pth\")\n    with open(checkpoint_dir / \"model_config.yaml\", \"w\", encoding=\"utf-8\") as fp:\n        yaml.dump(asdict(ours_config), fp)\n\n    stdout = StringIO()\n    with redirect_stdout(stdout), mock.patch(\"sys.argv\", [\"eval/evaluate.py\"]):\n        with pytest.raises(ValueError) as excinfo:\n            module.convert_and_evaluate(\n                checkpoint_dir,\n                out_dir=tmp_path / \"out_dir\",\n                device=None,\n                dtype=torch.float32,\n                limit=5,\n                tasks=\"logiqa\",\n                batch_size=0,  # Test for non-positive integer\n            )\n        assert \"batch_size must be a positive integer, 'auto', or in the format 'auto:N'.\" in str(excinfo.value)\n\n        with pytest.raises(ValueError) as excinfo:\n            module.convert_and_evaluate(\n                checkpoint_dir,\n                out_dir=tmp_path / \"out_dir\",\n                device=None,\n                dtype=torch.float32,\n                limit=5,\n                tasks=\"logiqa\",\n                batch_size=\"invalid\",  # Test for invalid string\n            )\n        assert \"batch_size must be a 
positive integer, 'auto', or in the format 'auto:N'.\" in str(excinfo.value)\n\n    stdout = StringIO()\n    with redirect_stdout(stdout), mock.patch(\"sys.argv\", [\"eval/evaluate.py\"]):\n        module.convert_and_evaluate(\n            checkpoint_dir,\n            out_dir=tmp_path / \"out_dir\",\n            device=None,\n            dtype=torch.float32,\n            limit=5,\n            tasks=\"logiqa\",\n            batch_size=1,  # Valid case\n        )\n    stdout = stdout.getvalue()\n    assert (tmp_path / \"out_dir\" / \"results.json\").is_file()\n    assert \"logiqa\" in stdout\n    assert \"Metric\" in stdout\n    assert \"Loading checkpoint shards\" not in stdout\n\n\ndef test_cli():\n    args = [\"litgpt\", \"evaluate\", \"-h\"]\n    output = subprocess.check_output(args)\n    output = str(output.decode())\n    assert \"Evaluate a model with the LM Evaluation Harness\" in output\n"
  },
  {
    "path": "tests/test_full.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\nfrom contextlib import redirect_stdout\nfrom io import StringIO\nfrom unittest import mock\nfrom unittest.mock import Mock\n\nimport torch\nimport yaml\n\nimport litgpt.finetune.full as module\nfrom litgpt.args import EvalArgs, TrainArgs\nfrom litgpt.data import Alpaca\n\n\n@mock.patch.dict(os.environ, {\"LT_ACCELERATOR\": \"cpu\"})\ndef test_full_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):\n    model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)\n    (fake_checkpoint_dir / \"model_config.yaml\").write_text(yaml.dump(model_config))\n    monkeypatch.setattr(module, \"load_checkpoint\", Mock())\n\n    tokenizer_mock = Mock()\n    tokenizer_mock.return_value = tokenizer_mock\n    tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])\n    monkeypatch.setattr(module, \"Tokenizer\", tokenizer_mock)\n\n    out_dir = tmp_path / \"out\"\n    setup_args = (fake_checkpoint_dir,)\n    setup_kwargs = dict(\n        data=Alpaca(download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0),\n        out_dir=out_dir,\n        precision=\"32-true\",\n        train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1),\n        eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),\n    )\n    stdout = StringIO()\n    with redirect_stdout(stdout), mock.patch(\"sys.argv\", [\"full.py\", str(fake_checkpoint_dir)]):\n        module.setup(*setup_args, **setup_kwargs)\n\n    out_dir_contents = set(os.listdir(out_dir))\n    checkpoint_dirs = {\"step-000002\", \"step-000004\", \"step-000006\", \"final\"}\n    assert checkpoint_dirs.issubset(out_dir_contents)\n    assert all((out_dir / p).is_dir() for p in checkpoint_dirs)\n    for checkpoint_dir in checkpoint_dirs:\n        assert set(os.listdir(out_dir / checkpoint_dir)) == {\n        
    \"lit_model.pth\",\n            \"model_config.yaml\",\n            \"tokenizer_config.json\",\n            \"tokenizer.json\",\n            \"hyperparameters.yaml\",\n            \"prompt_style.yaml\",\n        }\n    assert (out_dir / \"logs\" / \"csv\" / \"version_0\" / \"metrics.csv\").is_file()\n\n    logs = stdout.getvalue()\n    assert logs.count(\"(step)\") == 6\n    assert logs.count(\"val loss\") == 4  # 3 validations + 1 final validation\n    assert logs.count(\"Final evaluation\") == 1\n    assert \"of trainable parameters: 1,888\" in logs\n\n    # Resume training and do 2 steps more\n    setup_kwargs[\"train\"].max_steps = 8\n    setup_kwargs[\"resume\"] = True\n    stdout = StringIO()\n    with redirect_stdout(stdout), mock.patch(\"sys.argv\", [\"full.py\", str(fake_checkpoint_dir)]):\n        module.setup(*setup_args, **setup_kwargs)\n    logs = stdout.getvalue()\n    assert f\"Resuming training from {out_dir / 'step-000006' / 'lit_model.pth'}\" in logs\n    assert logs.count(\"(step)\") == 2\n    assert out_dir / \"step-000008\" in set(out_dir.iterdir())\n"
  },
  {
    "path": "tests/test_generate_speculatively.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport re\nimport subprocess\nfrom contextlib import redirect_stderr, redirect_stdout\nfrom io import StringIO\nfrom unittest.mock import ANY, Mock, call\n\nimport pytest\nimport torch\nimport yaml\nfrom torch import nn\n\nimport litgpt.generate.speculative_decoding as generate\nfrom litgpt import GPT, Config\nfrom litgpt.utils import _RunIf\n\n\ndef test_speculative_decoding_target_never_accepts_draft_tokens():\n    class DraftModel(nn.Module):\n        def forward(self, **kwargs):\n            return torch.tensor([1, 2, 3, 4, 5, 0, 0, 0, 0, 0], dtype=torch.float)[None, None, ...]  # (B, T, C)\n\n    class TargetModel(nn.Module):\n        def forward(self, idx, **kwargs):\n            _, T = idx.shape\n            return torch.tensor([[0, 0, 0, 0, 0, 6, 7, 8, 9, 10]] * T, dtype=torch.float)[None, ...]  # (B, T, C)\n\n    draft_model = DraftModel()\n    target_model = TargetModel()\n\n    token = torch.tensor([-1])\n    input_pos = torch.tensor([0])\n    sample_kwargs = dict(top_k=None, top_p=0.0, temperature=0.0)  # to make sampling consistent\n    output = generate.speculative_decoding(\n        draft_model, target_model, token, input_pos, input_pos, speculative_k=3, **sample_kwargs\n    )\n\n    # target model never accepts draft model's output, thus the output of the `speculative_decoding`\n    # is a single token sampled from the target model\n    assert len(output) == 1\n    assert output > 5\n\n\ndef test_speculative_decoding_target_always_accepts_draft_tokens():\n    class DraftModel(nn.Module):\n        def forward(self, **kwargs):\n            return torch.tensor([0, 0, 3, 4, 5, 6, 7, 8, 0, 0], dtype=torch.float)[None, None, ...]  # (B, T, C)\n\n    class TargetModel(nn.Module):\n        def forward(self, idx, **kwargs):\n            _, T = idx.shape\n            return torch.tensor([[0, 0, 3, 4, 5, 6, 7, 8, 0, 0]] * T, dtype=torch.float)[None, ...]  
# (B, T, C)\n\n    draft_model = DraftModel()\n    target_model = TargetModel()\n\n    token = torch.tensor([-1])\n    input_pos = torch.tensor([0])\n    sample_kwargs = dict(top_k=None, top_p=0.0, temperature=0.0)  # to make sampling consistent\n    output = generate.speculative_decoding(\n        draft_model, target_model, token, input_pos, input_pos, speculative_k=3, **sample_kwargs\n    )\n\n    # target model always accepts draft model's output, thus the output of the `speculative_decoding`\n    # is 4 tokens (3 accepted draft tokens + 1 sampled from target model's output)\n    assert len(output) == 4\n    assert torch.all((output >= 3) & (output <= 8))\n\n\ndef test_speculative_decoding_target_sometimes_accepts_draft_tokens():\n    class DraftModel(nn.Module):\n        def forward(self, **kwargs):\n            return torch.tensor([0, 0, 3, 4, 10, 9, 7, 8, 0, 0], dtype=torch.float)[None, None, ...]  # (B, T, C)\n\n    class TargetModel(nn.Module):\n        def forward(self, idx, **kwargs):\n            return torch.tensor(\n                [\n                    [0, 0, 0, 0, 10, 9, 0, 0, 0, 0],\n                    [0, 0, 0, 0, 10, 9, 0, 0, 0, 0],\n                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 10],\n                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 10],\n                ],\n                dtype=torch.float,\n            )[None, ...]  
# (B, T, C)\n\n    draft_model = DraftModel()\n    target_model = TargetModel()\n\n    token = torch.tensor([-1])\n    input_pos = torch.tensor([0])\n    sample_kwargs = dict(top_k=None, top_p=0.0, temperature=0.0)  # to make sampling consistent\n    output = generate.speculative_decoding(\n        draft_model, target_model, token, input_pos, input_pos, speculative_k=3, **sample_kwargs\n    )\n\n    # target model accepts only 2 out of 3 draft model's output, thus the output of the `speculative_decoding`\n    # is 3 tokens (2 accepted draft tokens + 1 sampled from adjusted distribution)\n    assert len(output) == 3\n    assert torch.equal(output, torch.tensor([4, 4, 9]))\n\n\n@pytest.mark.parametrize(\"max_seq_length\", (10, 15, 20, 25))\n@pytest.mark.parametrize(\"speculative_k\", (1, 2, 3))\ndef test_generate(max_seq_length, speculative_k):\n    # create a prompt\n    T = 5\n    input_idx = torch.arange(0, T)\n    max_new_tokens = max_seq_length - T\n\n    # prepare models\n    draft_model = GPT(Config(vocab_size=16, block_size=64, n_layer=1, n_head=4, n_embd=8))\n    target_model = GPT(Config(vocab_size=16, block_size=128, n_layer=2, n_head=8, n_embd=16))\n    for model in (draft_model, target_model):\n        model.max_seq_length = max_seq_length\n        model.set_kv_cache(batch_size=1)\n\n    # generate tokens\n    out, acceptance_rate = generate.generate(\n        draft_model, target_model, input_idx, T + max_new_tokens, top_k=1, speculative_k=speculative_k\n    )\n\n    # validate\n    assert out.size(0) == T + max_new_tokens - 1, (out.size(0), T + max_new_tokens - 1)\n    assert 0.0 <= acceptance_rate <= 1.0\n\n\n@_RunIf(min_cuda_gpus=1)  # speculative decoding makes sense only on a GPU\ndef test_main(fake_checkpoint_dir, monkeypatch, tensor_like):\n    # prepare configs for draft and target models\n    draft_model_dir = fake_checkpoint_dir / \"draft_model\"\n    draft_model_dir.mkdir()\n    target_model_dir = fake_checkpoint_dir / \"target_model\"\n    
target_model_dir.mkdir()\n\n    draft_model_config = dict(vocab_size=16, block_size=64, n_layer=1, n_head=4, n_embd=8)\n    target_model_config = dict(vocab_size=16, block_size=128, n_layer=2, n_head=8, n_embd=16)\n\n    (draft_model_dir / \"model_config.yaml\").write_text(yaml.dump(draft_model_config))\n    (target_model_dir / \"model_config.yaml\").write_text(yaml.dump(target_model_config))\n\n    # create empty files required for validation\n    for model_dir in (draft_model_dir, target_model_dir):\n        (model_dir / \"tokenizer.json\").touch()\n        (model_dir / \"tokenizer_config.json\").touch()\n        (model_dir / \"lit_model.pth\").touch()\n\n    # mock functions\n    module_mock = Mock()\n    module_mock.config.block_size = 128\n    load_mock = Mock()\n    load_mock.return_value = load_mock\n    monkeypatch.setattr(generate, \"load_checkpoint\", load_mock)\n    tokenizer_mock = Mock()\n    tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3])\n    tokenizer_mock.return_value.decode.return_value = \"foo bar baz\"\n    monkeypatch.setattr(generate, \"Tokenizer\", tokenizer_mock)\n    generate_mock = Mock()\n    generated_tokens = torch.tensor([3, 2, 1])\n    acceptance_rate = 0.0\n    generate_mock.return_value = (generated_tokens, acceptance_rate)\n    monkeypatch.setattr(generate, \"generate\", generate_mock)\n\n    # do the sampling\n    num_samples = 2\n    out, err = StringIO(), StringIO()\n    with redirect_stdout(out), redirect_stderr(err):\n        generate.main(\n            draft_model_checkpoint_dir=draft_model_dir,\n            target_model_checkpoint_dir=target_model_dir,\n            temperature=2.0,\n            top_k=2,\n            top_p=0.9,\n            num_samples=num_samples,\n        )\n\n    assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples\n    assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value[0])\n    assert (\n        
generate_mock.mock_calls\n        == [\n            call(\n                ANY,\n                ANY,\n                tensor_like,\n                53,\n                temperature=2.0,\n                top_k=2,\n                top_p=0.9,\n                stop_tokens=[tokenizer_mock.return_value.eos_id],\n                speculative_k=3,\n            )\n        ]\n        * num_samples\n    )\n    expected_output = \"foo bar baz\\nAcceptance rate: 0.00%\\n\" * num_samples\n    # Allow for the config to be printed before the expected repeated strings.\n    pattern = rf\".*^{re.escape(expected_output.strip())}$.*\"\n    assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE)\n\n    err_value = err.getvalue()\n    expected_parts = [\n        \"'padded_vocab_size': 512\",\n        \"'n_layer': 2\",\n        \"'n_head': 4\",\n    ]\n    assert all(part in err_value for part in expected_parts)\n\n\ndef test_cli():\n    args = [\"litgpt\", \"generate_speculatively\", \"-h\"]\n    output = subprocess.check_output(args)\n    output = str(output.decode())\n    assert \"Default generation option\" in output\n"
  },
  {
    "path": "tests/test_lora.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport os\nfrom contextlib import redirect_stdout\nfrom copy import deepcopy\nfrom io import StringIO\nfrom itertools import product\nfrom unittest import mock\nfrom unittest.mock import Mock\n\nimport pytest\nimport torch\nimport yaml\nfrom lightning import Fabric\nfrom lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision\nfrom lightning.fabric.wrappers import _FabricOptimizer\nfrom torch._dynamo.backends import debugging\nfrom torch.distributed.device_mesh import init_device_mesh\nfrom torch.nn import functional as F\nfrom transformers.models.gemma import GemmaConfig, GemmaForCausalLM\nfrom transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM\nfrom transformers.models.gemma3 import Gemma3ForCausalLM, Gemma3TextConfig\nfrom transformers.models.mixtral import MixtralConfig, MixtralForCausalLM\n\nimport litgpt.config as config_module\nimport litgpt.finetune.lora as module\nfrom litgpt.args import EvalArgs, TrainArgs\nfrom litgpt.data import Alpaca\nfrom litgpt.lora import GPT as LoRAGPT\nfrom litgpt.lora import (\n    CausalSelfAttention,\n    Config,\n    LoRALinear,\n    LoRAQKVLinear,\n    lora_filter,\n    mark_only_lora_as_trainable,\n    merge_lora_weights,\n)\nfrom litgpt.lora import CausalSelfAttention as LoRACausalSelfAttention\nfrom litgpt.model import GPT as BaseGPT\nfrom litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_gemma_3, copy_weights_hf_llama\nfrom litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved\nfrom litgpt.utils import _RunIf\n\n\ndef test_lora_layer_replacement():\n    config = Config(n_layer=2, n_head=4, n_embd=8, block_size=8, vocab_size=8, lora_r=8, lora_alpha=8, lora_dropout=0.1)\n    model = LoRAGPT(config)\n\n    assert isinstance(model.transformer.h[0].attn, LoRACausalSelfAttention)\n    assert 
isinstance(model.transformer.h[1].attn, LoRACausalSelfAttention)\n    assert isinstance(model.lm_head, LoRALinear)\n    assert isinstance(model.transformer.h[0].mlp.proj, LoRALinear)\n\n\ndef test_lora_merge():\n    config = Config(\n        n_layer=1,\n        n_head=2,\n        n_embd=8,\n        block_size=8,\n        vocab_size=8,\n        lora_r=8,\n        lora_alpha=8,\n        lora_dropout=0.1,\n        lora_query=True,\n        lora_value=True,\n        lora_projection=True,\n    )\n    model = LoRAGPT(config)\n    model.train()\n    attn_proj = model.transformer.h[0].attn.proj\n\n    initial_weight = attn_proj.linear.weight.clone()\n    assert torch.equal(attn_proj.linear.weight, initial_weight)\n\n    # perform an update to the LoRA weights\n    mark_only_lora_as_trainable(model)\n    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)\n    y = model(torch.randint(0, 8, size=(2, 4), dtype=torch.int64))\n    y.sum().backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    # the weight remains unchanged (only lora A and B change)\n    assert torch.equal(attn_proj.linear.weight, initial_weight)\n\n    # calling merge() multiple times in a row should not merge multiple times\n    merge_lora_weights(model)\n    assert attn_proj.merged\n    weight_after = attn_proj.linear.weight.clone()\n    merge_lora_weights(model)\n    merge_lora_weights(model)\n    assert torch.equal(attn_proj.linear.weight, weight_after)\n\n    # check that `W_after = W_initial + (A x B)`\n    delta_w = attn_proj.get_lora_AB()\n    torch.testing.assert_close(weight_after, initial_weight + delta_w)\n\n\ndef test_lora_mqa_gqa():\n    # MHA\n    config = Config(\n        n_layer=1,\n        n_head=4,\n        n_embd=8,\n        block_size=1,\n        vocab_size=1,\n        lora_r=2,\n        lora_alpha=8,\n        lora_dropout=0.1,\n        lora_query=True,\n        lora_value=True,\n    )\n    assert config.n_query_groups == config.n_head\n    model = LoRAGPT(config)\n    attn = 
model.transformer.h[0].attn.qkv\n    for p in attn.linear.parameters():\n        torch.nn.init.zeros_(p)\n    torch.nn.init.ones_(attn.lora_B)\n    lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23]\n    assert attn.linear.weight.shape == (24, 8)\n    assert attn.lora_A.shape == (4, 8)\n    assert attn.lora_B.shape == (16, 2)\n    assert torch.equal(attn.lora_ind, torch.tensor(lora_ind))\n    x = torch.randint(0, 8, size=(3, 5, 16), dtype=torch.int64)\n    assert attn.zero_pad(x).shape == (3, 5, 24)\n    bsz, ctx_len, in_dim = 2, 30, 8\n    x_in = torch.randn(bsz, ctx_len, in_dim)\n    out = attn(x_in)\n    non_lora_ind = list(set(range(24)).difference(lora_ind))\n    assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind)\n    assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0\n\n    # MQA\n    config.n_query_groups = 1\n    model = LoRAGPT(config)\n    attn = model.transformer.h[0].attn.qkv\n    for p in attn.linear.parameters():\n        torch.nn.init.zeros_(p)\n    torch.nn.init.ones_(attn.lora_B)\n    lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 10, 11]\n    assert attn.linear.weight.shape == (12, 8)\n    assert attn.lora_A.shape == (4, 8)\n    assert attn.lora_B.shape == (10, 2)\n    assert torch.equal(attn.lora_ind, torch.tensor(lora_ind))\n    x = torch.randint(0, 8, size=(3, 5, 10), dtype=torch.int64)\n    assert attn.zero_pad(x).shape == (3, 5, 12)\n    bsz, ctx_len, in_dim = 2, 30, 8\n    x_in = torch.randn(bsz, ctx_len, in_dim)\n    out = attn(x_in)\n    non_lora_ind = list(set(range(12)).difference(lora_ind))\n    assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind)\n    assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0\n\n    # GQA\n    config.n_query_groups = 2\n    model = LoRAGPT(config)\n    attn = model.transformer.h[0].attn.qkv\n    for p in attn.linear.parameters():\n        torch.nn.init.zeros_(p)\n    torch.nn.init.ones_(attn.lora_B)\n    lora_ind = [0, 1, 2, 3, 4, 5, 
6, 7, 12, 13, 14, 15]\n    assert attn.linear.weight.shape == (16, 8)\n    assert attn.lora_A.shape == (4, 8)\n    assert attn.lora_B.shape == (12, 2)\n    assert torch.equal(attn.lora_ind, torch.tensor(lora_ind))\n    x = torch.randint(0, 8, size=(3, 5, 12), dtype=torch.int64)\n    assert attn.zero_pad(x).shape == (3, 5, 16)\n    bsz, ctx_len, in_dim = 2, 30, 8\n    x_in = torch.randn(bsz, ctx_len, in_dim)\n    out = attn(x_in)\n    non_lora_ind = list(set(range(16)).difference(lora_ind))\n    assert torch.count_nonzero(out[:, :, lora_ind]) == bsz * ctx_len * len(lora_ind)\n    assert torch.count_nonzero(out[:, :, non_lora_ind]) == 0\n\n\n@pytest.mark.parametrize(\n    \"n_head, n_query_groups, enable_lora\",\n    [\n        (4, 2, (True, False, True)),  # GQA: Q+V only\n        (4, 1, (False, True, True)),  # MQA: K+V only\n        (4, 2, (True, True, False)),  # GQA: Q+K only\n        (8, 2, (True, True, True)),  # GQA: all enabled, different ratio\n        (4, 4, (False, False, True)),  # MHA: V only\n    ],\n)\ndef test_lora_ind_correctness(n_head, n_query_groups, enable_lora):\n    \"\"\"Verify lora_ind correctly partitions Q, K, V regions using head_size-based sizes.\"\"\"\n    n_embd = 16\n    config = Config(\n        n_layer=1,\n        n_head=n_head,\n        n_embd=n_embd,\n        block_size=1,\n        vocab_size=1,\n        lora_r=2,\n        lora_alpha=8,\n        lora_dropout=0.0,\n        lora_query=enable_lora[0],\n        lora_key=enable_lora[1],\n        lora_value=enable_lora[2],\n        n_query_groups=n_query_groups,\n    )\n    model = LoRAGPT(config)\n    attn = model.transformer.h[0].attn.qkv\n\n    head_size = n_embd // n_head\n    q_size = head_size * n_head\n    kv_size = head_size * n_query_groups\n\n    expected_ind = []\n    if enable_lora[0]:\n        expected_ind.extend(range(0, q_size))\n    if enable_lora[1]:\n        expected_ind.extend(range(q_size, q_size + kv_size))\n    if enable_lora[2]:\n        
expected_ind.extend(range(q_size + kv_size, q_size + 2 * kv_size))\n\n    assert torch.equal(attn.lora_ind, torch.tensor(expected_ind))\n\n    # Verify zero_pad output dimension matches full QKV size\n    total_qkv = q_size + 2 * kv_size\n    lora_out_dim = sum(attn.qkv_shapes)\n    x = torch.randn(1, 1, lora_out_dim)\n    assert attn.zero_pad(x).shape[-1] == total_qkv\n\n\ndef test_lora_filter(tmp_path):\n    fabric = Fabric(devices=1)\n    model = LoRAGPT.from_name(\"pythia-14m\", n_layer=3, lora_r=1, lora_query=True, lora_value=True)\n    save_path = tmp_path / \"model.pth\"\n    fabric.save(save_path, {\"model\": model}, filter={\"model\": lora_filter})\n    saved = torch.load(save_path)[\"model\"]\n\n    expected = {\n        \"transformer.h.1.attn.qkv.lora_B\",\n        \"transformer.h.2.attn.qkv.lora_B\",\n        \"transformer.h.2.attn.qkv.lora_A\",\n        \"transformer.h.1.attn.qkv.lora_A\",\n        \"transformer.h.0.attn.qkv.lora_A\",\n        \"transformer.h.0.attn.qkv.lora_B\",\n    }\n    assert set(saved) == expected\n\n\n@mock.patch.dict(os.environ, {\"LT_ACCELERATOR\": \"cpu\"})\ndef test_lora_script(tmp_path, fake_checkpoint_dir, monkeypatch, alpaca_path):\n    model_config = dict(block_size=128, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)\n    (fake_checkpoint_dir / \"model_config.yaml\").write_text(yaml.dump(model_config))\n    monkeypatch.setattr(module, \"load_checkpoint\", Mock())\n    monkeypatch.setattr(module, \"merge_lora\", Mock())\n\n    tokenizer_mock = Mock()\n    tokenizer_mock.return_value = tokenizer_mock\n    tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])\n    monkeypatch.setattr(module, \"Tokenizer\", tokenizer_mock)\n\n    out_dir = tmp_path / \"out\"\n    stdout = StringIO()\n    with redirect_stdout(stdout), mock.patch(\"sys.argv\", [\"lora.py\", str(fake_checkpoint_dir)]):\n        module.setup(\n            fake_checkpoint_dir,\n            data=Alpaca(\n                
download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0\n            ),\n            out_dir=out_dir,\n            precision=\"32-true\",\n            train=TrainArgs(global_batch_size=1, save_interval=2, epochs=1, max_steps=6, micro_batch_size=1),\n            eval=EvalArgs(interval=2, max_iters=2, max_new_tokens=1),\n        )\n\n    out_dir_contents = set(os.listdir(out_dir))\n    checkpoint_dirs = {\"step-000002\", \"step-000004\", \"step-000006\", \"final\"}\n    assert checkpoint_dirs.issubset(out_dir_contents)\n    assert all((out_dir / p).is_dir() for p in checkpoint_dirs)\n    for checkpoint_dir in checkpoint_dirs:\n        assert {p.name for p in (out_dir / checkpoint_dir).iterdir()} == {\n            \"lit_model.pth.lora\",\n            \"model_config.yaml\",\n            \"tokenizer_config.json\",\n            \"tokenizer.json\",\n            \"hyperparameters.yaml\",\n            \"prompt_style.yaml\",\n        }\n    assert (out_dir / \"logs\" / \"csv\" / \"version_0\" / \"metrics.csv\").is_file()\n\n    logs = stdout.getvalue()\n    assert logs.count(\"(step)\") == 6\n    assert logs.count(\"val loss\") == 4  # 3 validations + 1 final validation\n    assert logs.count(\"Final evaluation\") == 1\n    assert \"of trainable parameters: 512\" in logs\n\n\ndef test_lora_init_when_linear_overridden():\n    class MyLinear(torch.nn.Linear):\n        def __init__(self, *args, **kwargs):\n            # this needs to be implemented to demonstrate the failure\n            super().__init__(*args, **kwargs)\n\n    original_linear = torch.nn.Linear\n    # Our bnb does this sort of monkey patching\n    torch.nn.Linear = MyLinear\n    layer = LoRAQKVLinear(1, 1, 1, 1, 1)\n    assert isinstance(layer.linear, original_linear)\n    torch.nn.Linear = original_linear\n\n\n@pytest.mark.parametrize(\n    (\"apply_to\", \"target_layer_names\", \"mlp_class_name\"),\n    (\n        (\"lora_projection\", \"transformer.h.0.attn.proj\", 
\"GptNeoxMLP\"),\n        (\"lora_mlp\", {\"transformer.h.0.mlp.fc\", \"transformer.h.0.mlp.proj\"}, \"GptNeoxMLP\"),\n        (\"lora_head\", \"lm_head\", \"GptNeoxMLP\"),\n        (\"lora_projection\", \"transformer.h.0.attn.proj\", \"LLaMAMLP\"),\n        (\"lora_mlp\", {\"transformer.h.0.mlp.fc_1\", \"transformer.h.0.mlp.fc_2\", \"transformer.h.0.mlp.proj\"}, \"LLaMAMLP\"),\n        (\"lora_head\", \"lm_head\", \"LLaMAMLP\"),\n    ),\n)\ndef test_lora_linear_utilization(apply_to, target_layer_names, mlp_class_name):\n    config = Config(\n        n_layer=1,\n        n_head=4,\n        n_embd=8,\n        block_size=1,\n        vocab_size=1,\n        lora_r=2,\n        lora_alpha=8,\n        lora_dropout=0.1,\n        mlp_class_name=mlp_class_name,\n        intermediate_size=8 * 3,\n        **{apply_to: True},\n    )\n    model = LoRAGPT(config)\n    state_dict = model.state_dict()\n\n    if isinstance(target_layer_names, str):\n        target_layer_names = {target_layer_names}\n    lora_sublayers = (\".lora_A\", \".lora_B\")\n\n    # check that all the target layers have LoRA weights\n    for layer_name in target_layer_names:\n        for lora_sublayer in lora_sublayers:\n            assert layer_name + lora_sublayer in state_dict\n\n    # check that only target layers have LoRA weights\n    lora_params = [k for k in state_dict if k.endswith(lora_sublayers)]\n    lora_params = {k[:-7] for k in lora_params}\n    assert lora_params == target_layer_names\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    \"apply_to\", (None, \"lora_query\", \"lora_key\", \"lora_value\", \"lora_projection\", \"lora_mlp\", \"lora_head\")\n)\ndef test_lora_gpt_apply_lora_forward_no_exception(apply_to):\n    config = Config(n_layer=1, n_head=4, n_embd=8, block_size=1, vocab_size=1, lora_r=2, lora_alpha=8, lora_dropout=0.1)\n    if apply_to:\n        setattr(config, apply_to, True)\n    input_ids = torch.tensor([[1]])\n    model = LoRAGPT(config)\n    model.eval()\n\n    
model(input_ids)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"n_query_groups\", (1, 2, 3, 6))\n@pytest.mark.parametrize(\"apply_to\", product((False, True), repeat=3))\ndef test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, apply_to):\n    keys = (\"lora_query\", \"lora_key\", \"lora_value\")\n    values = apply_to\n    apply_to = dict(zip(keys, values))\n\n    config = Config(\n        n_layer=1,\n        n_head=6,\n        n_embd=12,\n        block_size=1,\n        vocab_size=1,\n        lora_r=2,\n        lora_alpha=8,\n        lora_dropout=0.1,\n        n_query_groups=n_query_groups,\n        **apply_to,\n    )\n    model = LoRAGPT(config)\n    merge_lora_weights(model)\n    input_ids = torch.tensor([[1]])\n    model(input_ids)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"head_size\", (1, 2, 4))\n@pytest.mark.parametrize(\"n_head\", (1, 2, 3, 6, 12))\n@pytest.mark.parametrize(\n    \"enable_lora\",\n    [\n        (False, False, True),\n        (False, True, False),\n        (False, True, True),\n        (True, False, False),\n        (True, False, True),\n        (True, True, False),\n        (True, True, True),\n    ],\n)\ndef test_lora_qkv_linear_compare_conv1d(head_size, n_head, enable_lora):\n    C = 12\n    layer = LoRAQKVLinear(\n        C, 3 * C, head_size=head_size, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora\n    )\n    x = torch.randn((1, 1, C))\n    a = F.linear(x, layer.lora_A).transpose(-2, -1)  # after_A\n    b = layer.lora_B.data.unsqueeze(-1)\n\n    # original PyTorch conv1d function output\n    conv1d_pytorch = F.conv1d(a, b, groups=sum(layer.enable_lora))\n\n    # custom conv1d\n    conv1d_custom = layer.conv1d(a, b)\n\n    # custom conv1d forced to split, apply and concat tensors\n    layer.n_head = layer.n_query_groups + 1\n    conv1d_custom_forced = layer.conv1d(a, b)\n\n    assert torch.allclose(conv1d_pytorch, conv1d_custom)\n    assert torch.allclose(conv1d_pytorch, 
conv1d_custom_forced)\n\n\n@pytest.mark.parametrize((\"rank\", \"expected_merged\"), ((0, False), (1, True)))\ndef test_lora_linear_weights_merged_status(rank, expected_merged):\n    layer = LoRALinear(10, 10, r=rank)\n    assert not layer.merged\n    layer.merge()\n    assert layer.merged == expected_merged\n\n\n@pytest.mark.parametrize(\n    (\"rank\", \"enable_lora\", \"expected_merged\"),\n    ((0, True, False), (1, True, True), (0, False, False), (1, False, False)),\n)\ndef test_lora_qkv_linear_weights_merged_status(rank, enable_lora, expected_merged):\n    C = 10\n    layer = LoRAQKVLinear(C, 3 * C, head_size=5, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora)\n    assert not layer.merged\n    layer.merge()\n    assert layer.merged == expected_merged\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_lora_merge_with_bitsandbytes():\n    if not _BITSANDBYTES_AVAILABLE:\n        pytest.skip(\"BNB not available\")\n    import bitsandbytes as bnb\n\n    config = Config(\n        n_layer=1,\n        n_head=2,\n        n_embd=8,\n        block_size=8,\n        vocab_size=8,\n        lora_r=8,\n        lora_alpha=8,\n        lora_dropout=0.1,\n        lora_query=True,\n        lora_value=True,\n        lora_projection=True,\n    )\n    fabric = Fabric(devices=1, plugins=BitsandbytesPrecision(\"nf4\", dtype=torch.bfloat16, ignore_modules={\"lm_head\"}))\n    model = LoRAGPT(config)\n    mark_only_lora_as_trainable(model)\n\n    from bitsandbytes.optim import PagedAdamW\n\n    optimizer = PagedAdamW(model.parameters(), lr=1.0)\n    model, optimizer = fabric.setup(model, optimizer)\n\n    model.train()\n\n    attn_proj = model.transformer.h[0].attn.proj\n    initial_weight = attn_proj.linear.weight.clone()\n    initial_weight_kwargs = attn_proj.linear.weight.__dict__\n\n    # this was skipped\n    assert model.lm_head.linear.weight.dtype is torch.float32\n    assert attn_proj.linear.weight.dtype is torch.uint8\n\n    # perform an update to the LoRA weights\n    y = 
model(torch.randint(0, 8, size=(2, 4), dtype=torch.int64, device=fabric.device))\n    loss = y.sum()\n    fabric.backward(loss)\n    optimizer.step()\n    optimizer.zero_grad()\n    # the weight remains unchanged (only lora A and B change)\n    assert torch.equal(attn_proj.linear.weight, initial_weight)\n\n    # calling merge() multiple times in a row should not merge multiple times\n    merge_lora_weights(model)\n    assert attn_proj.merged\n    weight_after = attn_proj.linear.weight.clone()\n    merge_lora_weights(model)\n    merge_lora_weights(model)\n    assert torch.equal(attn_proj.linear.weight, weight_after)\n\n    # check that `W_after = W_initial + (A x B)`\n    delta_w = attn_proj.get_lora_AB()\n    # dequantize initial weight and sum with delta_w\n    initial_weight_data = (\n        bnb.functional.dequantize_4bit(initial_weight.data, initial_weight_kwargs[\"quant_state\"]) + delta_w\n    )\n    # quantize again\n    initial_weight_data = bnb.nn.Params4bit(\n        initial_weight_data.to(\"cpu\"), requires_grad=False, **initial_weight_kwargs\n    ).to(initial_weight.device)\n    torch.testing.assert_close(weight_after, initial_weight_data)\n\n\ndef test_lora_gpt_init_weights():\n    config = Config(n_layer=1, n_head=6, n_embd=12, block_size=1, vocab_size=1, lora_r=2, lora_alpha=8, lora_head=True)\n    model = LoRAGPT(config)\n    param = model.lm_head.lora_B.data\n\n    assert (param == 0).all()\n    torch.nn.init.constant_(param, 1.23)\n    assert (param != 0).any()\n    model.apply(model._init_weights)\n    assert (param == 0).all()\n\n\n@pytest.mark.parametrize(\"name\", [c[\"name\"] for c in config_module.configs])\ndef test_base_model_can_be_lora_loaded(name):\n    kwargs = {\"n_layer\": 2, \"n_head\": 8, \"n_query_groups\": 4, \"n_embd\": 16, \"padded_vocab_size\": 32}\n    base_model = BaseGPT.from_name(name, **kwargs)\n    base_model_state_dict = base_model.state_dict()\n    lora_model = LoRAGPT.from_name(\n        name,\n        **kwargs,\n     
   lora_r=1,\n        lora_query=True,\n        lora_key=True,\n        lora_value=True,\n        lora_projection=True,\n        lora_mlp=True,\n        lora_head=True,\n    )\n    keys = lora_model.load_state_dict(base_model_state_dict, strict=False)\n    assert not keys.unexpected_keys\n    for k in keys.missing_keys:\n        assert lora_filter(k, None)\n\n\n@_RunIf(dynamo=True)\n@torch.inference_mode()\ndef test_lora_compile():\n    model = LoRAGPT.from_name(\n        \"pythia-14m\",\n        n_layer=3,\n        lora_r=8,\n        lora_alpha=8,\n        lora_dropout=0.1,\n        lora_query=True,\n        lora_key=True,\n        lora_value=True,\n        lora_projection=True,\n        lora_mlp=True,\n        lora_head=True,\n    )\n    x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64)\n\n    explanation = torch._dynamo.explain(model)(x)\n    assert isinstance(explanation, debugging.ExplainOutput)\n    assert explanation.graph_count == 1\n    assert explanation.graph_break_count == 0\n\n    model = LoRAGPT(model.config)\n    model.set_kv_cache(2)\n    input_pos = torch.arange(model.config.block_size)\n    explanation = torch._dynamo.explain(model)(x, input_pos)\n    assert isinstance(explanation, debugging.ExplainOutput)\n    assert explanation.graph_count == 1\n    assert explanation.graph_break_count == 0\n\n\n@torch.inference_mode()\ndef test_against_hf_mixtral():\n    device = torch.device(\"cpu\")\n    dtype = torch.float32\n    ours_config = Config.from_name(\n        \"Mixtral-8x7B-Instruct-v0.1\",\n        padded_vocab_size=10000,\n        n_layer=2,\n        n_embd=32,\n        n_head=8,\n        n_query_groups=2,\n        intermediate_size=86,\n        n_expert=4,\n        lora_r=1,\n        lora_key=True,\n        lora_value=True,\n    )\n    T = 5\n    theirs_config = MixtralConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        
num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        num_local_experts=ours_config.n_expert,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = MixtralForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = LoRAGPT(ours_config).to(device)\n    keys = ours_model.load_state_dict(state_dict, strict=False)\n    assert not keys.unexpected_keys\n    for k in keys.missing_keys:\n        assert lora_filter(k, None)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", [\"gemma-2b\", \"gemma-7b\"])\ndef test_against_hf_gemma(model_name):\n    device = torch.device(\"cpu\")\n    dtype = torch.float32\n    T = 5\n    ours_config = Config.from_name(\n        model_name,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        head_size=4,\n        intermediate_size=86,\n        lora_r=1,\n        lora_query=True,\n        lora_key=True,\n        lora_value=True,\n    )\n    theirs_config = GemmaConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        
intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = GemmaForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = LoRAGPT(ours_config).to(device)\n    keys = ours_model.load_state_dict(state_dict, strict=False)\n    assert not keys.unexpected_keys\n    for k in keys.missing_keys:\n        assert lora_filter(k, None)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"gemma-2-9b\", \"gemma-2-27b\"))\ndef test_against_original_gemma_2(model_name):\n    device = torch.device(\"cpu\")\n    dtype = torch.float32\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        sliding_window_size=T // 2,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    theirs_config = Gemma2Config(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        
intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        sliding_window=ours_config.sliding_window_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n        attn_logit_softcapping=ours_config.attention_logit_softcapping,\n        final_logit_softcapping=ours_config.final_logit_softcapping,\n        initializer_range=1.0,  # to make the affect of attention_logit_softcapping more prominent\n        attn_implementation=\"eager\",\n        query_pre_attn_scalar=ours_config.attention_scores_scalar,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = Gemma2ForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n    copy_weights_gemma_2({}, state_dict, theirs_state_dict)\n    ours_model = LoRAGPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y, atol=1e-4, rtol=1e-5)\n\n\n@torch.inference_mode()\n@pytest.mark.flaky(reruns=3)\n@pytest.mark.parametrize(\"model_name\", (\"gemma-3-1b-it\", \"gemma-3-4b-it\", \"gemma-3-12b-it\", \"gemma-3-27b-it\"))\ndef test_against_original_gemma_3(model_name):\n    device = torch.device(\"cpu\")\n    dtype = torch.float32\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        
sliding_window_size=T // 2,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    theirs_config = Gemma3TextConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        sliding_window=ours_config.sliding_window_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n        attn_logit_softcapping=ours_config.attention_logit_softcapping,\n        final_logit_softcapping=ours_config.final_logit_softcapping,\n        initializer_range=1.0,  # to make the affect of attention_logit_softcapping more prominent\n        attn_implementation=\"eager\",\n        query_pre_attn_scalar=ours_config.attention_scores_scalar,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = Gemma3ForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n    copy_weights_gemma_3({}, state_dict, theirs_state_dict)\n    ours_model = LoRAGPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, 
atol=3e-5)\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path):\n    if not _BITSANDBYTES_AVAILABLE:\n        pytest.skip(\"BNB not available\")\n\n    from bitsandbytes.optim import PagedAdamW\n\n    model_config = dict(\n        block_size=128,\n        n_layer=2,\n        n_embd=8,\n        n_head=4,\n        padded_vocab_size=8,\n        bias=True,\n        lora_r=8,\n        lora_alpha=8,\n        lora_dropout=0.1,\n        lora_query=True,\n        lora_value=True,\n        lora_projection=True,\n    )\n    (fake_checkpoint_dir / \"model_config.yaml\").write_text(yaml.dump(model_config))\n\n    tokenizer_mock = Mock()\n    tokenizer_mock.return_value = tokenizer_mock\n    tokenizer_mock.encode = lambda *_, **__: torch.tensor([3, 2, 1])\n    monkeypatch.setattr(module, \"Tokenizer\", tokenizer_mock)\n\n    monkeypatch.setattr(module, \"load_checkpoint\", Mock())\n    monkeypatch.setattr(module, \"merge_lora\", Mock())\n    train_mock = Mock()\n    train_mock.return_value = {\n        \"raw_tokens\": 1000,\n        \"raw_tokens_plus_prompt_template\": 1100,\n        \"raw_tokens_plus_prompt_template_and_padding\": 1200,\n    }\n    monkeypatch.setattr(module, \"fit\", train_mock)\n\n    stdout = StringIO()\n    with redirect_stdout(stdout), mock.patch(\"sys.argv\", [\"full.py\", str(fake_checkpoint_dir)]):\n        module.setup(\n            fake_checkpoint_dir,\n            data=Alpaca(\n                download_dir=alpaca_path.parent, file_name=alpaca_path.name, val_split_fraction=0.5, num_workers=0\n            ),\n            out_dir=tmp_path,\n            precision=\"16-true\",\n            quantize=\"bnb.nf4-dq\",\n        )\n\n    _, kwargs = train_mock.call_args\n    fabric = kwargs[\"fabric\"]\n    model = kwargs[\"model\"]\n    optimizer = kwargs[\"optimizer\"]\n    model.transformer.wte = model.transformer.wte.half()\n    assert isinstance(fabric.strategy.precision, 
BitsandbytesPrecision)\n    assert isinstance(optimizer, _FabricOptimizer)\n    assert isinstance(optimizer._optimizer, PagedAdamW)\n\n    dtype_to_name = {\"torch.uint8\": set(), \"torch.float16\": set()}\n    for name, layer in model.named_parameters():\n        name = name[len(\"_forward_module.\") :]\n        dtype_to_name[str(layer.dtype)].add(name)\n    assert dtype_to_name == {\n        \"torch.uint8\": {\n            \"transformer.h.0.attn.qkv.linear.weight\",\n            \"transformer.h.0.attn.proj.linear.weight\",\n            \"transformer.h.0.mlp.fc.linear.weight\",\n            \"transformer.h.1.mlp.proj.linear.weight\",\n            \"transformer.h.0.mlp.proj.linear.weight\",\n            \"transformer.h.1.attn.qkv.linear.weight\",\n            \"lm_head.linear.weight\",\n            \"transformer.h.1.attn.proj.linear.weight\",\n            \"transformer.h.1.mlp.fc.linear.weight\",\n        },\n        \"torch.float16\": {\n            \"transformer.h.0.attn.qkv.lora_B\",\n            \"transformer.h.0.norm_2.weight\",\n            \"transformer.wte.weight\",\n            \"transformer.wte.norm.weight\",\n            \"transformer.wte.norm.bias\",\n            \"transformer.h.1.mlp.fc.linear.bias\",\n            \"transformer.ln_f.bias\",\n            \"transformer.h.1.attn.qkv.lora_B\",\n            \"transformer.h.1.attn.proj.linear.bias\",\n            \"transformer.h.1.norm_1.weight\",\n            \"transformer.h.1.attn.qkv.linear.bias\",\n            \"transformer.h.1.attn.qkv.lora_A\",\n            \"transformer.h.1.norm_1.bias\",\n            \"transformer.h.1.norm_2.bias\",\n            \"transformer.h.0.attn.proj.linear.bias\",\n            \"transformer.h.0.norm_1.bias\",\n            \"transformer.h.0.mlp.proj.linear.bias\",\n            \"transformer.h.0.mlp.fc.linear.bias\",\n            \"transformer.h.0.norm_2.bias\",\n            \"transformer.ln_f.weight\",\n            \"transformer.h.0.attn.qkv.lora_A\",\n            
\"transformer.h.1.norm_2.weight\",\n            \"transformer.h.1.mlp.proj.linear.bias\",\n            \"transformer.h.0.norm_1.weight\",\n            \"transformer.h.0.attn.qkv.linear.bias\",\n        },\n    }\n\n    assert {p.name for p in tmp_path.rglob(\"*.lora\")} == {\"lit_model.pth.lora\"}\n    state_dict = torch.load(tmp_path / \"final\" / \"lit_model.pth.lora\")\n    assert len(state_dict) == 1\n    dtype_to_name = {\"torch.float16\": set()}\n    for name, layer in state_dict[\"model\"].items():\n        dtype_to_name[str(layer.dtype)].add(name)\n    assert dtype_to_name == {\n        \"torch.float16\": {\n            \"transformer.h.1.attn.qkv.lora_A\",\n            \"transformer.h.0.attn.qkv.lora_A\",\n            \"transformer.h.0.attn.qkv.lora_B\",\n            \"transformer.h.1.attn.qkv.lora_B\",\n        }\n    }\n\n    logs = stdout.getvalue()\n    assert \"of trainable parameters: 512\" in logs\n    assert \"of non-trainable parameters: 1,888\" in logs\n\n\n@_RunIf(standalone=True, min_cuda_gpus=2)\ndef test_lora_model_fsdp_init():\n    config = Config(\n        n_layer=1,\n        n_head=2,\n        n_embd=8,\n        block_size=8,\n        vocab_size=8,\n        lora_r=8,\n        lora_alpha=8,\n        lora_dropout=0.1,\n        lora_query=True,\n        lora_value=False,\n        lora_projection=True,\n    )\n    fabric = Fabric(devices=2, strategy=\"fsdp\", precision=\"16-true\")\n    fabric.launch()\n    with fabric.init_module(empty_init=True):\n        model = LoRAGPT(config)\n    x = torch.randint(0, config.padded_vocab_size, size=(2, config.block_size), dtype=torch.int64, device=fabric.device)\n    model = fabric.setup(model)\n    y = model(x)\n    assert y.shape == torch.Size([2, 8, 512])\n\n    # verify that all the parameters, buffers and other attributes aren't on `meta` device\n    for m in model.modules():\n        for p_name, parameter in m.named_parameters():\n            assert not parameter.is_meta, f\"Parameter `{p_name}` 
isn't materialized.\"\n        for b_name, buffer in m._buffers.items():\n            assert not buffer.is_meta, f\"Buffer `{b_name}` isn't materialized.\"\n        for attr_name, attr_value in m.__dict__.items():\n            if isinstance(attr_value, torch.Tensor):\n                assert not attr_value.is_meta, f\"Attribute `{attr_name}` isn't materialized.\"\n\n\ndef test_zero_pad_cpu_and_mocked_mps():\n    head_size = 64\n    n_head = 12\n    n_query_groups = 3\n    in_features = 128\n    q_size = head_size * n_head\n    kv_size = head_size * n_query_groups\n    out_features = q_size + 2 * kv_size\n    enable_lora = [True, False, True]\n    r = 4\n\n    model = LoRAQKVLinear(\n        in_features=in_features,\n        out_features=out_features,\n        head_size=head_size,\n        n_head=n_head,\n        n_query_groups=n_query_groups,\n        r=r,\n        enable_lora=enable_lora,\n    )\n\n    batch_size = 64\n    seq_len = 64\n    # embed_dim = sum of enabled qkv shapes: Q (q_size) + V (kv_size)\n    embed_dim = q_size + kv_size\n    x = torch.randn(batch_size, seq_len, embed_dim)\n\n    result_cpu = model.zero_pad(x)\n\n    with mock.patch(\"torch.backends.mps.is_available\", return_value=True):\n        with mock.patch(\"torch.Tensor.device\", new_callable=mock.PropertyMock) as mock_device:\n            mock_device.return_value = torch.device(\"mps\")\n\n            result_mps = model.zero_pad(x)\n\n            assert result_cpu.shape == result_mps.shape, \"Shape mismatch between CPU and MPS\"\n            assert torch.allclose(result_cpu, result_mps), \"Tensor values mismatch between CPU and MPS\"\n\n\ndef test_load_legacy_state_dict():\n    \"\"\"Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.\"\"\"\n    config = Config(\n        n_embd=32, n_head=4, head_size=8, n_query_groups=4, bias=True, lora_r=8, lora_alpha=16, lora_dropout=0.1\n    )\n\n    attention_1 = 
CausalSelfAttention(config=config, block_idx=0)\n\n    # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight`\n    # and make them interleaved\n    state_dict = deepcopy(attention_1.state_dict())\n    state_dict[\"attn.linear.weight\"] = make_qkv_interleaved(state_dict.pop(\"qkv.linear.weight\"), config)\n    state_dict[\"attn.linear.bias\"] = make_qkv_interleaved(state_dict.pop(\"qkv.linear.bias\"), config)\n\n    attention_2 = CausalSelfAttention(config=config, block_idx=0)\n    attention_2.load_state_dict(state_dict)\n\n\n@_RunIf(standalone=True, min_cuda_gpus=2)\ndef test_parallelize_fn():\n    from litgpt.finetune.lora import parallelize_fn\n\n    config = Config(\n        n_layer=2,\n        n_head=4,\n        n_embd=32,\n        block_size=8,\n        vocab_size=8,\n        lora_r=4,\n        lora_alpha=8,\n        lora_dropout=0.1,\n        lora_query=True,\n        lora_value=True,\n        lora_projection=True,\n    )\n\n    fabric = Fabric(devices=2, strategy=\"fsdp\", precision=\"16-true\")\n    fabric.launch()\n\n    model = LoRAGPT(config)\n    mark_only_lora_as_trainable(model)\n\n    # create device mesh for data parallel\n    device_mesh = init_device_mesh(\n        device_type=fabric.device.type,\n        mesh_shape=(2, 1),\n        mesh_dim_names=(\"data_parallel\", \"tensor_parallel\"),\n    )\n\n    # test with activation checkpointing enabled (default)\n    parallelized_model = parallelize_fn(model, device_mesh, activation_checkpointing=True)\n\n    # verify the model is still functional\n    assert parallelized_model is not None\n    assert isinstance(parallelized_model, LoRAGPT)\n\n    parallelized_model = parallelized_model.to(fabric.device)\n\n    # test forward pass to ensure the parallelized model works\n    x = torch.randint(0, config.padded_vocab_size, size=(1, config.block_size), dtype=torch.int64, device=fabric.device)\n\n    # verify forward pass works\n    with torch.no_grad():\n   
     output = parallelized_model(x)\n        assert output.shape == (1, config.block_size, config.padded_vocab_size)\n\n    # test with activation checkpointing disabled\n    model_no_checkpoint = LoRAGPT(config)\n    mark_only_lora_as_trainable(model_no_checkpoint)\n\n    parallelized_model_no_checkpoint = parallelize_fn(model_no_checkpoint, device_mesh, activation_checkpointing=False)\n\n    # verify the model is still functional\n    assert parallelized_model_no_checkpoint is not None\n    assert isinstance(parallelized_model_no_checkpoint, LoRAGPT)\n\n    # test forward pass to ensure the parallelized model works\n    parallelized_model_no_checkpoint = parallelized_model_no_checkpoint.to(fabric.device)\n\n    with torch.no_grad():\n        output = parallelized_model_no_checkpoint(x)\n        assert output.shape == (1, config.block_size, config.padded_vocab_size)\n\n    # verify that all parameters are properly distributed (not on meta device)\n    for mod in parallelized_model.modules():\n        for param_name, param in mod.named_parameters():\n            if param.requires_grad:  # Only check trainable parameters (LoRA parameters)\n                assert not param.is_meta, f\"Parameter `{param_name}` should not be on meta device\"\n                assert param.device.type == \"cuda\", f\"Parameter `{param_name}` should be on CUDA device\"\n\n\n@_RunIf(standalone=True, min_cuda_gpus=2)\ndef test_load_from_full_model_state_dict():\n    from litgpt.finetune.lora import parallelize_fn\n    from litgpt.utils import load_from_full_model_state_dict\n\n    config = Config(\n        n_layer=2,\n        n_head=4,\n        n_embd=32,\n        block_size=8,\n        vocab_size=8,\n        lora_r=4,\n        lora_alpha=8,\n        lora_dropout=0.1,\n        lora_query=True,\n        lora_value=True,\n        lora_projection=True,\n        lora_mlp=True,\n        lora_head=True,\n    )\n\n    # set up distributed environment with FSDP\n    fabric = Fabric(devices=2, 
strategy=\"fsdp\", precision=\"16-true\")\n    fabric.launch()\n\n    # create a reference model to get the full state dict\n    reference_model = LoRAGPT(config)\n    mark_only_lora_as_trainable(reference_model)\n\n    # initialize the reference model with some values\n    with torch.no_grad():\n        for param in reference_model.parameters():\n            if param.requires_grad:\n                param.fill_(0.1)\n\n    # get the full state dict (simulating a checkpoint)\n    full_state_dict = {}\n    for name, param in reference_model.named_parameters():\n        # Convert parameters to checkpoint format (what load_from_full_model_state_dict expects)\n        if \"norm\" not in name and \"wte\" not in name and \"ln_f\" not in name:\n            # For linear layers, remove .linear from the name to simulate checkpoint format\n            checkpoint_name = name.replace(\".linear.weight\", \".weight\").replace(\".linear.bias\", \".bias\")\n        else:\n            # For norm, embedding, and layer norm layers, keep the original name\n            checkpoint_name = name\n        full_state_dict[checkpoint_name] = param.detach().clone()\n\n    # create distributed model\n    model = LoRAGPT(config)\n    mark_only_lora_as_trainable(model)\n\n    # set up device mesh for distributed model\n    device_mesh = init_device_mesh(\n        device_type=fabric.device.type,\n        mesh_shape=(2, 1),\n        mesh_dim_names=(\"data_parallel\", \"tensor_parallel\"),\n    )\n    model = parallelize_fn(model, device_mesh, activation_checkpointing=False)\n    model = model.to(fabric.device)\n\n    # test with default parameters (strict=False, cpu_offload=False)\n    result = load_from_full_model_state_dict(\n        model=model,\n        full_sd=full_state_dict,\n        device=fabric.device,\n        strict=False,\n        cpu_offload=False,\n    )\n\n    # verify that the function returns the missing/unexpected keys\n    assert hasattr(result, \"missing_keys\")\n    assert 
hasattr(result, \"unexpected_keys\")\n\n    # verify that parameters are loaded correctly\n    for name, param in model.named_parameters():\n        if param.requires_grad:\n            # Check that parameter is not on meta device\n            assert not param.is_meta, f\"Parameter {name} should not be on meta device\"\n            # Check that parameter is on the correct device\n            assert param.device.type == \"cuda\", f\"Parameter {name} should be on CUDA device\"\n\n    # test with cpu_offload=True\n    model_cpu_offload = LoRAGPT(config)\n    mark_only_lora_as_trainable(model_cpu_offload)\n    model_cpu_offload = parallelize_fn(model_cpu_offload, device_mesh, activation_checkpointing=False)\n    model_cpu_offload = model_cpu_offload.to(fabric.device)\n\n    result_cpu_offload = load_from_full_model_state_dict(\n        model=model_cpu_offload,\n        full_sd=full_state_dict,\n        device=fabric.device,\n        strict=False,\n        cpu_offload=True,\n    )\n\n    # verify that parameters are loaded correctly with CPU offload\n    for name, param in model_cpu_offload.named_parameters():\n        if param.requires_grad:\n            # Check that parameter is not on meta device\n            assert not param.is_meta, f\"Parameter {name} should not be on meta device\"\n            # With cpu_offload, parameters might be on CPU\n            assert param.device.type in [\"cpu\", \"cuda\"], f\"Parameter {name} should be on CPU or CUDA device\"\n\n    # test with strict=True\n    model_strict = LoRAGPT(config)\n    mark_only_lora_as_trainable(model_strict)\n    model_strict = parallelize_fn(model_strict, device_mesh, activation_checkpointing=False)\n    model_strict = model_strict.to(fabric.device)\n\n    try:\n        result_strict = load_from_full_model_state_dict(\n            model=model_strict,\n            full_sd=full_state_dict,\n            device=fabric.device,\n            strict=True,\n            cpu_offload=False,\n        )\n        # If 
strict loading succeeds, verify parameters\n        for name, param in model_strict.named_parameters():\n            if param.requires_grad:\n                assert not param.is_meta, f\"Parameter {name} should not be on meta device\"\n                assert param.device.type == \"cuda\", f\"Parameter {name} should be on CUDA device\"\n    except RuntimeError as e:\n        # strict=True might fail if there are missing keys, which is expected behavior\n        assert \"Missing key(s)\" in str(e) or \"Unexpected key(s)\" in str(e)\n\n    # test forward pass to ensure model still works after loading\n    x = torch.randint(0, config.padded_vocab_size, size=(1, config.block_size), dtype=torch.int64, device=fabric.device)\n\n    with torch.no_grad():\n        output = model(x)\n        assert output.shape == (1, config.block_size, config.padded_vocab_size)\n\n        output_cpu_offload = model_cpu_offload(x)\n        assert output_cpu_offload.shape == (1, config.block_size, config.padded_vocab_size)\n"
  },
  {
    "path": "tests/test_merge_lora.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\nimport shutil\nfrom contextlib import redirect_stdout\nfrom io import StringIO\nfrom pathlib import Path\nfrom unittest import mock\n\nimport pytest\nimport torch\nimport yaml\n\nfrom litgpt.lora import GPT as LoRAGPT\nfrom litgpt.lora import lora_filter\nfrom litgpt.model import GPT\nfrom litgpt.scripts.merge_lora import load_lora_metadata, merge_lora\n\n\n@mock.patch.dict(os.environ, {\"LT_ACCELERATOR\": \"cpu\"})\n@pytest.mark.parametrize(\n    (\"pretrained_dtype\", \"lora_dtype\"), [(None, None), (torch.float16, torch.float32), (torch.float16, torch.bfloat16)]\n)\ndef test_merge_lora(tmp_path, fake_checkpoint_dir, pretrained_dtype, lora_dtype):\n    pretrained_checkpoint_dir = tmp_path / \"pretrained\"\n    lora_checkpoint_dir = tmp_path / \"lora\"\n    shutil.copytree(fake_checkpoint_dir, pretrained_checkpoint_dir)\n    shutil.copytree(fake_checkpoint_dir, lora_checkpoint_dir)\n    (lora_checkpoint_dir / \"lit_model.pth\").unlink()  # should not already exist\n    shutil.rmtree(tmp_path / \"checkpoints\")\n\n    # Create a fake pretrained checkpoint\n    config = dict(block_size=128, padded_vocab_size=256, n_layer=3, n_head=8, n_embd=16)\n    with open(pretrained_checkpoint_dir / \"model_config.yaml\", \"w\", encoding=\"utf-8\") as fp:\n        yaml.dump(config, fp)\n    base_model = GPT.from_name(\"pythia-14m\", **config).to(dtype=pretrained_dtype)\n    state_dict = base_model.state_dict()\n    assert len(state_dict) == 40\n    torch.save(state_dict, pretrained_checkpoint_dir / \"lit_model.pth\")\n\n    # Create a fake LoRA checkpoint\n    lora_kwargs = dict(lora_r=8, lora_alpha=16, lora_dropout=0.05, lora_query=True, lora_value=True)\n    lora_model = LoRAGPT.from_name(\"pythia-14m\", **config, **lora_kwargs).to(dtype=lora_dtype)\n    state_dict = {k: v for k, v in lora_model.state_dict().items() if lora_filter(k, v)}\n    assert len(state_dict) 
== 6\n    torch.save(state_dict, lora_checkpoint_dir / \"lit_model.pth.lora\")\n    hparams = dict(checkpoint_dir=str(pretrained_checkpoint_dir), **lora_kwargs)\n    with open(lora_checkpoint_dir / \"hyperparameters.yaml\", \"w\", encoding=\"utf-8\") as file:\n        yaml.dump(hparams, file)\n    shutil.copyfile(pretrained_checkpoint_dir / \"model_config.yaml\", lora_checkpoint_dir / \"model_config.yaml\")\n\n    assert set(os.listdir(tmp_path)) == {\"lora\", \"pretrained\"}\n    merge_lora(lora_checkpoint_dir)\n    assert set(os.listdir(tmp_path)) == {\"lora\", \"pretrained\"}\n    assert set(os.listdir(lora_checkpoint_dir)) == {\n        \"model_config.yaml\",\n        \"lit_model.pth\",\n        \"lit_model.pth.lora\",\n        \"tokenizer.json\",\n        \"tokenizer_config.json\",\n        \"hyperparameters.yaml\",\n    }\n\n    # Assert that the merged weights can be loaded back into the base model\n    merged = torch.load(lora_checkpoint_dir / \"lit_model.pth\")\n    keys = base_model.load_state_dict(merged, strict=True)\n    assert not keys.missing_keys\n    assert not keys.unexpected_keys\n\n    # Attempt to merge again\n    stdout = StringIO()\n    with redirect_stdout(stdout):\n        merge_lora(lora_checkpoint_dir)\n    assert \"LoRA weights have already been merged\" in stdout.getvalue()\n\n\ndef test_load_lora_metadata(fake_checkpoint_dir):\n    assert not (fake_checkpoint_dir / \"hyperparameters.yaml\").is_file()\n    with pytest.raises(FileNotFoundError, match=\"missing a `hyperparameters.yaml` file\"):\n        load_lora_metadata(fake_checkpoint_dir)\n\n    hparams = dict(precision=\"bf16-mixed\", checkpoint_dir=\"checkpoints/meta-llama/Llama-2-7b\", lora_r=8, lora_alpha=16)\n    with open(fake_checkpoint_dir / \"hyperparameters.yaml\", \"w\", encoding=\"utf-8\") as file:\n        yaml.dump(hparams, file)\n\n    lora_args, pretrained_dir, precision = load_lora_metadata(fake_checkpoint_dir)\n    assert lora_args == dict(lora_r=8, lora_alpha=16)\n  
  assert pretrained_dir == Path(\"checkpoints/meta-llama/Llama-2-7b\")\n    assert precision == \"bf16-mixed\"\n"
  },
  {
    "path": "tests/test_model.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nfrom copy import deepcopy\nfrom functools import partial\nfrom unittest import mock\n\nimport pytest\nimport torch\nfrom lightning import Fabric\nfrom lightning.fabric.utilities.imports import _IS_WINDOWS\nfrom lightning.fabric.utilities.init import _materialize_meta_tensors\nfrom torch._dynamo.backends import debugging\nfrom torch.backends.cuda import (\n    SDPAParams,\n    SDPBackend,\n    can_use_efficient_attention,\n    can_use_flash_attention,\n    flash_sdp_enabled,\n    math_sdp_enabled,\n    mem_efficient_sdp_enabled,\n)\nfrom transformers import AutoConfig, AutoModelForCausalLM\nfrom transformers.models.falcon import FalconConfig, FalconForCausalLM\nfrom transformers.models.gemma import GemmaConfig, GemmaForCausalLM\nfrom transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM\nfrom transformers.models.gemma3 import Gemma3Config, Gemma3ForCausalLM, Gemma3ForConditionalGeneration, Gemma3TextConfig\nfrom transformers.models.gpt_neox import GPTNeoXConfig, GPTNeoXForCausalLM\nfrom transformers.models.llama import LlamaConfig, LlamaForCausalLM\nfrom transformers.models.mistral import MistralConfig, MistralForCausalLM\nfrom transformers.models.mixtral import MixtralConfig, MixtralForCausalLM\nfrom transformers.models.olmo import OlmoConfig, OlmoForCausalLM\nfrom transformers.models.olmo2 import Olmo2Config, Olmo2ForCausalLM\nfrom transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM\nfrom transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM\nfrom transformers.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM\n\nimport litgpt.config as config_module\nfrom litgpt import GPT, Config\nfrom litgpt.model import CausalSelfAttention, batched_index_copy_\nfrom litgpt.scripts.convert_hf_checkpoint import (\n    copy_weights_falcon,\n    copy_weights_gemma_2,\n    copy_weights_gemma_3,\n    copy_weights_gpt_neox,\n    
copy_weights_hf_llama,\n    copy_weights_olmo2,\n    copy_weights_phi,\n    copy_weights_qwen_2_5,\n    copy_weights_qwen_3,\n)\nfrom litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved\nfrom litgpt.utils import _RunIf\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"rotary_pct\", (0.25, 1))\n@pytest.mark.parametrize(\"batch_size\", (1, 3))\n@pytest.mark.parametrize(\"n_embd\", (16, 32))\n@pytest.mark.parametrize(\"parallel_residual\", (False, True))\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residual, device, dtype) -> None:\n    torch.set_default_dtype(dtype)\n\n    ours_config = Config(\n        block_size=64,\n        vocab_size=100,\n        n_layer=4,\n        n_head=8,\n        n_embd=n_embd,\n        rotary_percentage=rotary_pct,\n        parallel_residual=parallel_residual,\n    )\n    assert ours_config.padded_vocab_size == 512\n    theirs_config = GPTNeoXConfig(\n        hidden_act=\"gelu\",\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        initializer_range=0.02,\n        intermediate_size=ours_config.intermediate_size,\n        layer_norm_eps=ours_config.norm_eps,\n        max_position_embeddings=ours_config.block_size,\n        rotary_emb_base=10000,\n        rotary_pct=ours_config.rotary_percentage,\n        vocab_size=ours_config.padded_vocab_size,\n        
use_parallel_residual=ours_config.parallel_residual,\n        attn_implementation=\"eager\",\n    )\n\n    state_dict = {}\n    theirs_model = GPTNeoXForCausalLM(theirs_config).to(device)\n    # load the hf initialization into our model\n    copy_weights_gpt_neox(ours_config, state_dict, theirs_model.state_dict())\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    token_sample = torch.randint(\n        0, ours_config.padded_vocab_size, size=(batch_size, ours_config.block_size), dtype=torch.int64, device=device\n    )\n\n    theirs = theirs_model(token_sample)[\"logits\"]\n    ours = ours_model(token_sample)\n    torch.testing.assert_close(ours, theirs, rtol=1e-2, atol=1e-2)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    \"kwargs\",\n    [\n        dict(name=\"falcon-180B\", n_layer=2, n_head=8, n_query_groups=4, n_embd=32),\n        dict(name=\"falcon-40b\", n_layer=2, n_head=8, n_query_groups=4, n_embd=32),\n    ],\n)\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_hf_falcon(kwargs, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    ours_config = Config.from_name(**kwargs)\n    theirs_config = FalconConfig(\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_kv_heads=ours_config.n_query_groups,\n        num_hidden_layers=ours_config.n_layer,\n        parallel_attn=ours_config.parallel_residual,\n        vocab_size=ours_config.padded_vocab_size,\n        bias=ours_config.bias,\n        new_decoder_architecture=True,\n    )\n\n    theirs_model = FalconForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_falcon(ours_config, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"]\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_open_llama_3b(device, dtype):\n    torch.set_default_dtype(dtype)\n\n    ours_config = Config.from_name(\"open_llama_3b\", n_layer=2, n_head=8, n_embd=32, intermediate_size=86)\n    T = 5\n    theirs_config = LlamaConfig(\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = LlamaForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    \"ours_kwargs\",\n    [\n        {\"name\": \"Llama-2-7b-hf\"},\n        {\"name\": \"CodeLlama-7b-hf\"},\n        {\"name\": \"Llama-2-70b-chat-hf\", \"n_query_groups\": 1},\n        {\"name\": \"Llama-3-8B\"},\n        {\"name\": \"Llama-3-8B-Instruct\"},\n        {\"name\": \"Llama-3.1-405B\", \"n_query_groups\": 4},\n        {\"name\": \"Llama-3.1-8B\"},\n        {\"name\": \"Llama-3.1-8B-Instruct\"},\n        {\"name\": \"Llama-3.2-1B\"},\n        {\"name\": \"Llama-3.2-3B\"},\n        {\"name\": 
\"Llama-3.3-70B-Instruct\"},\n        {\"name\": \"R1-Distill-Llama-8B\"},\n        {\"name\": \"R1-Distill-Llama-70B\"},\n    ],\n)\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_hf_llama_2_and_3(ours_kwargs, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    ours_config = Config.from_name(\n        padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs\n    )\n    T = 5\n    theirs_config = LlamaConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = LlamaForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y 
= theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"phi-1_5\", \"phi-2\"))\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1)],\n        ),\n    ],\n)\ndef test_against_hf_phi(model_name, device, dtype):\n    from transformers.models.phi.configuration_phi import PhiConfig\n    from transformers.models.phi.modeling_phi import PhiForCausalLM\n\n    torch.set_default_dtype(dtype)\n\n    ours_config = Config.from_name(\n        model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5\n    )\n    T = 5\n    theirs_config = PhiConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        max_position_embeddings=ours_config.block_size,\n        hidden_size=ours_config.n_embd,\n        intermediate_size=ours_config.intermediate_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        partial_rotary_factor=ours_config.rotary_percentage,\n        torch_dtype=dtype,\n    )\n\n    theirs_model = PhiForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_phi(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, 
theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    \"model_name\",\n    (\n        \"Phi-3-mini-4k-instruct\",\n        \"Phi-3-mini-128k-instruct\",\n        \"Phi-3.5-mini-instruct\",\n        \"phi-4\",\n        \"Phi-4-mini-instruct\",\n        \"Phi-4-reasoning\",\n        \"Phi-4-mini-reasoning\",\n    ),\n)\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[pytest.mark.xfail(raises=AssertionError, strict=False), _RunIf(min_cuda_gpus=1)],\n        ),\n    ],\n)\ndef test_against_hf_phi_3(model_name, device, dtype):\n    from transformers.models.phi3.configuration_phi3 import Phi3Config\n    from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM\n\n    torch.set_default_dtype(dtype)\n\n    ours_config = Config.from_name(\n        model_name,\n        padded_vocab_size=10000,\n        n_layer=2,\n        n_head=4,\n        n_query_groups=4,\n        n_embd=256,\n    )\n    T = 5\n    theirs_config = Phi3Config(\n        attention_bias=ours_config.bias,\n        head_dim=ours_config.head_size,\n        hidden_size=ours_config.n_embd,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        num_key_value_heads=ours_config.n_query_groups,\n        pad_token_id=ours_config.padded_vocab_size - 1,\n        partial_rotary_factor=ours_config.rotary_percentage,\n        rms_norm_eps=ours_config.norm_eps,\n        rope_theta=ours_config.rope_base,\n        torch_dtype=dtype,\n        vocab_size=ours_config.padded_vocab_size,\n    )\n\n    theirs_model = Phi3ForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_phi(ours_config, {}, state_dict, 
theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\n@pytest.mark.parametrize(\"model_name\", [\"Mistral-7B-Instruct-v0.1\", \"Mistral-7B-v0.1\"])\ndef test_against_mistral_hf_models(device, dtype, model_name):\n    torch.set_default_dtype(dtype)\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        padded_vocab_size=10000,\n        block_size=T,\n        sliding_window_size=T // 2,\n        n_layer=2,\n        n_embd=32,\n        n_head=8,\n        n_query_groups=2,\n        intermediate_size=86,\n    )\n\n    theirs_config = MistralConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attn_implementation=\"eager\",\n        
sliding_window=ours_config.sliding_window_size,\n    )\n\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = MistralForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_mathstral_hf_models(device, dtype):\n    torch.set_default_dtype(dtype)\n\n    ours_config = Config.from_name(\n        \"Mathstral-7B-v0.1\",\n        padded_vocab_size=10000,\n        n_layer=2,\n        n_embd=32,\n        n_head=8,\n        n_query_groups=2,\n        intermediate_size=86,\n    )\n\n    T = 5\n    theirs_config = MistralConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n    )\n\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = MistralForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"Mixtral-8x7B-Instruct-v0.1\", \"Mixtral-8x22B-Instruct-v0.1\"))\ndef test_against_hf_mixtral(model_name):\n    device = torch.device(\"cpu\")\n    dtype = torch.float32\n    ours_config = 
Config.from_name(\n        model_name,\n        padded_vocab_size=10000,\n        n_layer=2,\n        n_embd=32,\n        n_head=8,\n        n_query_groups=2,\n        intermediate_size=86,\n        n_expert=4,\n    )\n    T = 5\n    theirs_config = MixtralConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        num_local_experts=ours_config.n_expert,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = MixtralForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"OLMo-1B-hf\", \"OLMo-7B-hf\"))\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_olmo(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    ours_config = Config.from_name(\n        model_name,\n        padded_vocab_size=10000,\n        n_layer=2,\n        n_head=8,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    T = 5\n    theirs_config = OlmoConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        intermediate_size=ours_config.intermediate_size,\n        num_hidden_layers=ours_config.n_layer,\n        num_attention_heads=ours_config.n_head,\n        num_key_value_heads=ours_config.n_query_groups,\n        max_positional_embeddings=T,\n        attention_bias=ours_config.bias,\n        rope_theta=ours_config.rope_base,\n        tie_word_embeddings=(model_name == \"OLMo-1B-hf\"),\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = OlmoForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"OLMo-2-1124-7B\", \"OLMo-2-1124-13B\"))\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            
torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_olmo2(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    ours_config = Config.from_name(\n        model_name,\n        padded_vocab_size=10000,\n        n_layer=2,\n        n_head=8,\n        n_embd=32,\n        n_query_groups=2,\n        intermediate_size=86,\n    )\n    T = 5\n    theirs_config = Olmo2Config(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        intermediate_size=ours_config.intermediate_size,\n        num_hidden_layers=ours_config.n_layer,\n        num_attention_heads=ours_config.n_head,\n        num_key_value_heads=ours_config.n_query_groups,\n        max_positional_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        attention_bias=ours_config.bias,\n        rope_theta=ours_config.rope_base,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = Olmo2ForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_olmo2(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        
(torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_stablelm_zephyr_3b(device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 5\n    ours_config = Config.from_name(\"stablelm-zephyr-3b\", n_layer=2, n_head=16, n_embd=32, intermediate_size=86)\n    theirs_config = AutoConfig.from_pretrained(\n        \"stabilityai/stablelm-zephyr-3b\",\n        trust_remote_code=True,\n        num_hidden_layers=ours_config.n_layer,\n        num_attention_heads=ours_config.n_head,\n        num_key_value_heads=ours_config.n_head,\n        hidden_size=ours_config.n_embd,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        torch_dtype=dtype,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = AutoModelForCausalLM.from_config(theirs_config, trust_remote_code=True).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", [\"gemma-2b\", \"gemma-7b\"])\n@pytest.mark.parametrize(\n    (\"device\", 
\"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_gemma(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 5\n    ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86)\n    theirs_config = GemmaConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = GemmaForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y 
= theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"gemma-2-9b\", \"gemma-2-27b\"))\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_gemma_2(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        sliding_window_size=T // 2,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    theirs_config = Gemma2Config(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        sliding_window=ours_config.sliding_window_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n        attn_logit_softcapping=ours_config.attention_logit_softcapping,\n        final_logit_softcapping=ours_config.final_logit_softcapping,\n        initializer_range=1.0,  # to make the effect of 
attention_logit_softcapping more prominent\n        attn_implementation=\"eager\",\n        query_pre_attn_scalar=ours_config.attention_scores_scalar,\n    )\n\n    theirs_model = Gemma2ForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n    copy_weights_gemma_2({}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", [\"gemma-3-1b-it\", \"gemma-3-4b-it\", \"gemma-3-12b-it\", \"gemma-3-27b-it\"])\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_gemma_3(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        sliding_window_size=T // 2,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n\n    theirs_config = Gemma3TextConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        sliding_window=ours_config.sliding_window_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n        tie_word_embeddings=True,\n        hidden_act=\"gelu_pytorch_tanh\",\n        attn_implementation=\"eager\",\n        query_pre_attn_scalar=ours_config.attention_scores_scalar,\n        rope_scaling={\"factor\": 8.0, \"rope_type\": \"linear\"},\n        rope_local_base_freq=ours_config.rope_local_base_freq,\n    )\n\n    theirs_model = Gemma3ForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n\n    copy_weights_gemma_3({}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.randint(low=0, 
high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", [\"gemma-3-4b-it\", \"gemma-3-12b-it\", \"gemma-3-27b-it\"])\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_multimodal_gemma_3(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        sliding_window_size=T // 2,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n\n    theirs_config = Gemma3Config(\n        Gemma3TextConfig(\n            vocab_size=ours_config.padded_vocab_size,\n            hidden_size=ours_config.n_embd,\n            head_dim=ours_config.head_size,\n            num_attention_heads=ours_config.n_head,\n            num_hidden_layers=ours_config.n_layer,\n            intermediate_size=ours_config.intermediate_size,\n            max_position_embeddings=ours_config.block_size,\n            sliding_window=ours_config.sliding_window_size,\n            rms_norm_eps=ours_config.norm_eps,\n            num_key_value_heads=ours_config.n_query_groups,\n            rope_theta=ours_config.rope_base,\n            attention_bias=ours_config.bias,\n           
 tie_word_embeddings=True,\n            hidden_act=\"gelu_pytorch_tanh\",\n            attn_implementation=\"eager\",\n            query_pre_attn_scalar=ours_config.attention_scores_scalar,\n            rope_scaling={\"factor\": 8.0, \"rope_type\": \"linear\"},\n            rope_local_base_freq=ours_config.rope_local_base_freq,\n        )\n    )\n\n    theirs_model = Gemma3ForConditionalGeneration(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n\n    state_dict = {}\n\n    copy_weights_gemma_3({}, state_dict, theirs_state_dict, config=ours_config)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y, rtol=3e-5, atol=3e-5)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    \"model_name\", [\"Qwen2.5-1.5B\", \"Qwen2.5-Coder-1.5B\", \"Qwen2.5-Math-1.5B\", \"QwQ-32B-Preview\", \"QwQ-32B\"]\n)\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_qwen_2_5(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    theirs_config = Qwen2Config(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.attn_bias,\n        tie_word_embeddings=True,\n    )\n\n    theirs_model = Qwen2ForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    # Gemma weights are shipped without `lm_head.weight`\n    theirs_state_dict.pop(\"lm_head.weight\")\n    state_dict = {}\n    copy_weights_qwen_2_5(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    \"model_name\",\n    [\n        \"Qwen3-0.6B\",\n        
\"Qwen3-8B\",\n        \"Qwen3-4B-Base\",\n        \"Qwen3-14B-Base\",\n        \"Qwen3-32B\",\n        \"Qwen3-4B-Thinking-2507\",\n        \"Qwen3-4B-Instruct-2507\",\n    ],\n)\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_qwen_3(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n    )\n    theirs_config = Qwen3Config(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        tie_word_embeddings=False,\n    )\n\n    theirs_model = Qwen3ForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_qwen_3(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), 
device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    \"model_name\", [\"Qwen3-30B-A3B\", \"Qwen3-235B-A22B\", \"Qwen3-235B-A22B-Thinking-2507\", \"Qwen3-235B-A22B-Instruct-2507\"]\n)\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_qwen_3_moe(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    T = 20\n    ours_config = Config.from_name(\n        model_name,\n        block_size=T,\n        n_layer=2,\n        n_head=16,\n        n_embd=32,\n        intermediate_size=86,\n        moe_intermediate_size=20,\n        n_expert=4,\n        n_expert_per_token=2,\n    )\n    theirs_config = Qwen3MoeConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        head_dim=ours_config.head_size,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        moe_intermediate_size=ours_config.moe_intermediate_size,\n        max_position_embeddings=ours_config.block_size,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        tie_word_embeddings=False,\n        
num_experts=ours_config.n_expert,\n        num_experts_per_tok=ours_config.n_expert_per_token,\n        norm_topk_prob=True,\n    )\n\n    theirs_model = Qwen3MoeForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_qwen_3(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"salamandra-2b\", \"salamandra-7b\"))\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. 
additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_salamandra(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    ours_config = Config.from_name(\n        model_name,\n        padded_vocab_size=10000,\n        n_layer=2,\n        n_head=8,\n        n_embd=32,\n        n_query_groups=2,\n        intermediate_size=86,\n    )\n    T = 5\n    theirs_config = LlamaConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = LlamaForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"SmolLM2-135M\", \"SmolLM2-360M\", \"SmolLM2-1.7B\"))\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n   
     pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_original_smollm2(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    ours_config = Config.from_name(\n        model_name,\n        padded_vocab_size=10000,\n        n_layer=2,\n        n_head=8,\n        n_embd=32,\n        n_query_groups=2,\n        intermediate_size=86,\n    )\n    T = 5\n    theirs_config = LlamaConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = LlamaForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, 
theirs_y)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"model_name\", (\"Falcon3-1B-Base\", \"Falcon3-7B-Base\"))\n@pytest.mark.parametrize(\n    (\"device\", \"dtype\"),\n    [\n        (torch.device(\"cpu\"), torch.float32),\n        pytest.param(\n            torch.device(\"cuda\"),\n            torch.float16,\n            marks=[\n                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input\n                # is slightly different\n                pytest.mark.xfail(raises=AssertionError, strict=False),\n                _RunIf(min_cuda_gpus=1),\n            ],\n        ),\n    ],\n)\ndef test_against_hf_falcon3(model_name, device, dtype):\n    torch.set_default_dtype(dtype)\n\n    ours_config = Config.from_name(\n        model_name,\n        padded_vocab_size=10000,\n        n_layer=2,\n        n_head=8,\n        n_embd=32,\n        n_query_groups=2,\n        intermediate_size=86,\n    )\n    T = 5\n    theirs_config = LlamaConfig(\n        vocab_size=ours_config.padded_vocab_size,\n        hidden_size=ours_config.n_embd,\n        num_attention_heads=ours_config.n_head,\n        num_hidden_layers=ours_config.n_layer,\n        intermediate_size=ours_config.intermediate_size,\n        max_position_embeddings=T,\n        rms_norm_eps=ours_config.norm_eps,\n        num_key_value_heads=ours_config.n_query_groups,\n        rope_theta=ours_config.rope_base,\n        attention_bias=ours_config.bias,\n    )\n    assert ours_config.intermediate_size == theirs_config.intermediate_size\n\n    theirs_model = LlamaForCausalLM(theirs_config).to(device)\n    theirs_state_dict = theirs_model.state_dict()\n    state_dict = {}\n    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)\n    ours_model = GPT(ours_config).to(device)\n    ours_model.load_state_dict(state_dict)\n\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert 
x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n\n\n@_RunIf(dynamo=True)\n@torch.inference_mode()\ndef test_model_compile():\n    model = GPT.from_name(\"pythia-14m\", n_layer=3)\n    x = torch.randint(model.config.vocab_size, size=(2, model.config.block_size), dtype=torch.int64)\n\n    explanation = torch._dynamo.explain(model)(x)\n    assert isinstance(explanation, debugging.ExplainOutput)\n    assert explanation.graph_count == 1\n    assert explanation.graph_break_count == 0\n\n    model = GPT(model.config)\n    model.set_kv_cache(2)\n    input_pos = torch.arange(model.config.block_size)\n    explanation = torch._dynamo.explain(model)(x, input_pos)\n    assert isinstance(explanation, debugging.ExplainOutput)\n    assert explanation.graph_count == 1\n    assert explanation.graph_break_count == 0\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\n    \"max_seq_length\", (25, pytest.param(23, marks=pytest.mark.xfail(raises=IndexError, strict=True)))\n)\n@pytest.mark.flaky(reruns=5)\ndef test_kv_cache(max_seq_length):\n    config = Config(block_size=25, padded_vocab_size=5, n_layer=2, n_head=2, n_embd=8)\n    model = GPT(config)\n    idx = torch.randint(0, model.config.padded_vocab_size, (1, 5))\n    max_new_tokens = 20\n    model.max_seq_length = max_seq_length\n    model.set_kv_cache(1)\n\n    def generate(logits):\n        logits = logits[:, -1:]\n        probs = torch.nn.functional.softmax(logits, dim=-1)\n        return torch.argmax(probs).unsqueeze(0).unsqueeze(0)\n\n    x_no_cache = idx\n    x_cache = idx\n    input_pos = torch.arange(0, 5)\n    for _ in range(max_new_tokens):\n        logits_no_cache = model(x_no_cache[:, -max_seq_length:])\n        out_no_cache = generate(logits_no_cache)\n\n        logits_cache = model(x_cache, input_pos)\n        out_cache = generate(logits_cache)\n\n        
torch.testing.assert_close(out_no_cache, out_cache, rtol=0, atol=0)\n\n        x_no_cache = torch.cat((x_no_cache, out_no_cache), dim=1)\n        x_cache = out_cache\n        input_pos = input_pos[-1:] + 1\n\n\n@torch.inference_mode()\ndef test_model_kv_cache_amp():\n    config = Config.from_name(\"pythia-14m\", n_layer=2)\n    model = GPT(config)\n    encoded = torch.arange(45)\n    model.set_kv_cache(batch_size=1)\n    with torch.autocast(\"cpu\", torch.bfloat16):\n        output = model(encoded.unsqueeze(0), encoded)\n    assert output.dtype is torch.bfloat16\n\n\n@pytest.mark.parametrize(\"model_name\", [\"pythia-14m\", \"gemma-3-1b-it\"])\ndef test_rope_cache_length(model_name):\n    config = Config.from_name(model_name, n_layer=2)\n    model = GPT(config)\n    model.max_seq_length = 128\n\n    rope_len = model.rope_cache_length()\n    assert rope_len == config.rope_n_elem\n\n    # Verify it works with set_kv_cache\n    model.set_kv_cache(batch_size=1)\n    assert model.transformer.h[0].attn.kv_cache is not None\n\n\n# https://github.com/pytorch/pytorch/blob/ad3572a5d/torch/testing/_internal/common_cuda.py#L31-L34\nSUPPORTS_FLASH_ATTENTION = (\n    torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0) and not _IS_WINDOWS\n)\n\n\n@_RunIf(min_cuda_gpus=1)\n@pytest.mark.parametrize(\"config\", deepcopy(config_module.configs), ids=[c[\"name\"] for c in config_module.configs])\n@torch.inference_mode()\ndef test_sdpa_choice(config):\n    if config[\"name\"].startswith(\"Gemma-2-\"):\n        pytest.skip(\"Gemma 2 doesn't support SDPA\")\n\n    torch.set_default_dtype(torch.float16)\n\n    def assert_sdpa_backend(original_fn, q, k, v, mask):\n        # SDPAParams gained an additional argument in PyTorch 2.5\n        args = []\n        if hasattr(SDPAParams, \"enable_gqa\"):\n            args.append(False)\n        params = SDPAParams(q, k, v, mask, 0.0, True, *args)\n        if expected is SDPBackend.FLASH_ATTENTION:\n            assert 
flash_sdp_enabled(), \"flash_sdp_enabled() is False\"\n            if config.sliding_window_size is None:\n                assert can_use_flash_attention(params, True), \"can_use_flash_attention(params, True) is False\"\n        elif expected is SDPBackend.EFFICIENT_ATTENTION:\n            assert mem_efficient_sdp_enabled(), \"mem_efficient_sdp_enabled() is False\"\n            assert can_use_efficient_attention(params, True), \"can_use_efficient_attention(params, True) is False\"\n        elif expected is SDPBackend.MATH:\n            assert math_sdp_enabled(), \"math_sdp_enabled() is False\"\n        else:\n            raise NotImplementedError\n        return original_fn(q, k, v, mask)\n\n    config[\"n_layer\"] = 1\n    config = config_module.Config(**config)\n\n    try:\n        with torch.device(\"cuda\"):\n            model = GPT(config)\n            x = torch.randint(0, 10, (2, 16), dtype=torch.int32)\n    except torch.cuda.OutOfMemoryError:\n        # best effort, if the GPU can load it\n        pytest.xfail()\n\n    for h in model.transformer.h:\n        h.attn.scaled_dot_product_attention = partial(assert_sdpa_backend, h.attn.scaled_dot_product_attention)\n\n    if SUPPORTS_FLASH_ATTENTION:\n        expected = SDPBackend.FLASH_ATTENTION\n        with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):\n            model(x)\n\n    expected = SDPBackend.EFFICIENT_ATTENTION if config.head_size % 8 == 0 else SDPBackend.MATH\n    with torch.backends.cuda.sdp_kernel(enable_flash=False):\n        model(x)\n\n\n@_RunIf(min_cuda_gpus=1)\n@pytest.mark.parametrize(\"config\", deepcopy(config_module.configs), ids=[c[\"name\"] for c in config_module.configs])\n@torch.inference_mode()\ndef test_sdpa_choice_kv_cache(config):\n    torch.set_default_dtype(torch.float16)\n\n    def assert_sdpa_backend(original_fn, q, k, v, mask):\n        # SDPAParams gained an additional argument in PyTorch 2.5\n        args = []\n        if hasattr(SDPAParams, \"enable_gqa\"):\n 
           args.append(False)\n        params = SDPAParams(q, k, v, mask, 0.0, True, *args)\n        if expected is SDPBackend.FLASH_ATTENTION:\n            assert flash_sdp_enabled()\n            assert can_use_flash_attention(params, True)\n        elif expected is SDPBackend.EFFICIENT_ATTENTION:\n            assert mem_efficient_sdp_enabled()\n            assert can_use_efficient_attention(params, True)\n        elif expected is SDPBackend.MATH:\n            assert math_sdp_enabled()\n        else:\n            raise NotImplementedError\n        return original_fn(q, k, v, mask)\n\n    config[\"n_layer\"] = 1\n    config = config_module.Config(**config)\n\n    try:\n        with torch.device(\"cuda\"):\n            model = GPT(config)\n            model.max_seq_length = 1\n            model.set_kv_cache(2)\n            x = torch.randint(0, 10, (2, 1), dtype=torch.int32)\n            input_pos = torch.tensor([0], dtype=torch.long)\n    except torch.cuda.OutOfMemoryError:\n        # best effort, if the GPU can load it\n        pytest.xfail()\n\n    for h in model.transformer.h:\n        h.attn.scaled_dot_product_attention = partial(assert_sdpa_backend, h.attn.scaled_dot_product_attention)\n\n    if SUPPORTS_FLASH_ATTENTION:\n        # flash attention does not support an attention mask\n        expected = SDPBackend.MATH\n        with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):\n            model(x, input_pos)\n\n    expected = (\n        SDPBackend.EFFICIENT_ATTENTION if config.head_size % 8 == 0 and config.n_query_groups != 1 else SDPBackend.MATH\n    )\n    with torch.backends.cuda.sdp_kernel(enable_flash=False):\n        model(x, input_pos)\n\n\n@_RunIf(min_cuda_gpus=2, standalone=True)\ndef test_rope_init_under_fsdp():\n    \"\"\"Check that the rope cache is properly initialized\"\"\"\n    fabric = Fabric(devices=2, strategy=\"fsdp\", accelerator=\"cuda\")\n    fabric.launch()\n\n    with fabric.init_module(empty_init=True):\n        model = 
GPT.from_name(\"pythia-14m\", n_layer=1)\n    assert model.cos.device.type == \"meta\"\n    assert model.sin.device.type == \"meta\"\n\n    model = fabric.setup(model)\n    assert model.cos.device.type == \"cuda\"\n    assert model.sin.device.type == \"cuda\"\n    cos, sin = model.rope_cache(device=fabric.device)\n    torch.testing.assert_close(model.cos, cos)\n    torch.testing.assert_close(model.sin, sin)\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_reset_parameters_device():\n    with torch.device(\"meta\"):\n        model = GPT.from_name(\"pythia-14m\", n_layer=1)\n    _materialize_meta_tensors(model, torch.device(\"cuda\"))\n    model.reset_parameters()\n    assert model.cos.device.type == \"cuda\"\n\n\ndef test_batched_index_copy_modes():\n    # Mock the torch.backends.mps.is_available() function to simulate MPS availability\n    with mock.patch(\"torch.backends.mps.is_available\", return_value=True):\n        # Mock the device type to simulate the \"mps\" device\n        with mock.patch(\"torch.Tensor.device\", new_callable=mock.PropertyMock) as mock_device:\n            mock_device.return_value = torch.device(\"mps\")\n\n            # Test case when idx.dim() == 1\n            t_original_1 = torch.randn(3, 5)\n            dim_1 = 0\n            idx_1 = torch.tensor([0, 2])\n            val_1 = torch.randn(2, 5)\n\n            t1_cpu = t_original_1.clone()\n            t1_mps = t_original_1.clone()\n\n            # Perform the index copy on CPU\n            batched_index_copy_(t1_cpu, dim_1, idx_1, val_1)\n\n            # Simulate the MPS index copy\n            idx_1_mps = idx_1\n            val_1_mps = val_1\n            batched_index_copy_(t1_mps, dim_1, idx_1_mps, val_1_mps)\n            assert torch.allclose(t1_cpu, t1_mps), \"Mismatch with idx.dim() == 1 on mocked MPS\"\n\n            # Test case when idx.dim() == 2\n            t_original_2 = torch.randn(2, 5, 4)\n            dim_2 = 1\n            idx_2 = torch.tensor([[0, 2], [1, 3]])\n            val_2 
= torch.randn(2, 2, 4)\n\n            t2_cpu = t_original_2.clone()\n            t2_mps = t_original_2.clone()\n\n            # Perform the index copy on CPU\n            batched_index_copy_(t2_cpu, dim_2, idx_2, val_2)\n\n            # Simulate the MPS index copy\n            idx_2_mps = idx_2\n            val_2_mps = val_2\n            batched_index_copy_(t2_mps, dim_2, idx_2_mps, val_2_mps)\n            assert torch.allclose(t2_cpu, t2_mps), \"Mismatch with idx.dim() == 2 on mocked MPS\"\n\n            # Additional test with negative dimension\n            t_original_3 = torch.randn(2, 3, 4)\n            dim_3 = -2\n            idx_3 = torch.tensor([[0, 1], [1, 2]])\n            val_3 = torch.randn(2, 2, 4)\n\n            t3_cpu = t_original_3.clone()\n            t3_mps = t_original_3.clone()\n\n            # Perform the index copy on CPU\n            batched_index_copy_(t3_cpu, dim_3, idx_3, val_3)\n\n            # Simulate the MPS index copy\n            idx_3_mps = idx_3\n            val_3_mps = val_3\n            batched_index_copy_(t3_mps, dim_3, idx_3_mps, val_3_mps)\n            assert torch.allclose(t3_cpu, t3_mps), \"Mismatch with negative dimension on mocked MPS\"\n\n\ndef test_load_legacy_state_dict():\n    \"\"\"Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.\"\"\"\n    config = Config(\n        n_embd=32,\n        n_head=4,\n        head_size=8,\n        n_query_groups=4,\n        bias=True,\n    )\n\n    attention_1 = CausalSelfAttention(config=config, block_idx=0)\n\n    # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight`\n    # and make them interleaved\n    state_dict = deepcopy(attention_1.state_dict())\n    state_dict[\"attn.weight\"] = make_qkv_interleaved(state_dict.pop(\"qkv.weight\"), config)\n    state_dict[\"attn.bias\"] = make_qkv_interleaved(state_dict.pop(\"qkv.bias\"), config)\n\n    
attention_2 = CausalSelfAttention(config=config, block_idx=0)\n    attention_2.load_state_dict(state_dict)\n\n\n@pytest.mark.parametrize(\"n_query_groups\", (1, 2, 4, 8))\n@torch.inference_mode()\ndef test_kv_cache_buffer_shape(n_query_groups):\n    batch_size = 3\n    max_seq_length = 23\n    config = Config(\n        block_size=25,\n        padded_vocab_size=5,\n        n_layer=2,\n        n_head=8,\n        n_embd=16,\n        n_query_groups=n_query_groups,\n    )\n    model = GPT(config)\n    model.max_seq_length = max_seq_length\n    model.set_kv_cache(batch_size)\n    required_shape = (batch_size, n_query_groups, max_seq_length, config.head_size)\n    for block in model.transformer.h:\n        kv_cache = block.attn.kv_cache\n        assert kv_cache is not None\n        assert kv_cache.k.shape == required_shape\n        assert kv_cache.v.shape == required_shape\n\n\n@pytest.mark.parametrize((\"rotary_percentage\", \"final_dim\"), ((0.75, 3), (0.25, 2)))\n@torch.inference_mode()\ndef test_rope_cos_sin_shapes_if_rope_n_elem_is_odd(rotary_percentage, final_dim):\n    batch_size = 3\n    config = Config(\n        block_size=25,\n        padded_vocab_size=5,\n        n_layer=2,\n        n_head=4,\n        n_embd=16,\n        rotary_percentage=rotary_percentage,\n    )\n    model = GPT(config)\n    required_shape = (config.block_size, final_dim)\n    assert model.cos.shape == required_shape\n    assert model.sin.shape == required_shape\n\n\ndef test_forward_with_without_input_pos_maxp1():\n    batch_size = 3\n    config = Config(\n        block_size=25,\n        padded_vocab_size=5,\n        n_layer=2,\n        n_head=8,\n        n_embd=16,\n    )\n    model = GPT(config)\n    model.set_kv_cache(batch_size)\n    idx = torch.randint(0, config.padded_vocab_size, (1, 10))\n    input_pos = torch.arange(1, 11)\n    input_pos_maxp1 = 11\n    logits_with_maxp1 = model(idx, input_pos, input_pos_maxp1=input_pos_maxp1)\n    logits_no_maxp1 = model(idx, input_pos)\n    
torch.testing.assert_close(logits_with_maxp1, logits_no_maxp1)\n"
  },
  {
    "path": "tests/test_multihead_latent_attention.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport pytest\nimport torch\nfrom transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM\n\nfrom litgpt import Config\nfrom litgpt.model import MultiheadLatentAttention\n\n\n@torch.inference_mode()\ndef test_multihead_latent_attention_kv_cache():\n    \"\"\"Test KV cache functionality\"\"\"\n    config = Config(\n        block_size=32,\n        n_embd=64,\n        n_head=4,\n        n_query_groups=4,\n        head_size=16,\n        latent_attention={\n            \"q_lora_rank\": 32,\n            \"kv_lora_rank\": 16,\n            \"qk_rope_head_dim\": 8,\n            \"qk_nope_head_dim\": 8,\n            \"v_head_dim\": 16,\n        },\n    )\n\n    mla = MultiheadLatentAttention(config, block_idx=0)\n\n    # Build KV cache\n    kv_cache = mla.build_kv_cache(batch_size=2, max_seq_length=32, device=torch.device(\"cpu\"), dtype=torch.float32)\n\n    # Check cache shapes\n    assert kv_cache.k.shape == (2, config.n_head, 32, config.qk_head_dim)\n    assert kv_cache.v.shape == (2, config.n_head, 32, config.v_head_dim)\n\n\n@torch.inference_mode()\ndef test_multihead_latent_attention_with_mask():\n    \"\"\"Test attention with causal mask\"\"\"\n    config = Config(\n        n_embd=64,\n        n_head=4,\n        n_query_groups=4,\n        head_size=16,\n        latent_attention={\n            \"q_lora_rank\": 32,\n            \"kv_lora_rank\": 16,\n            \"qk_rope_head_dim\": 8,\n            \"qk_nope_head_dim\": 8,\n            \"v_head_dim\": 16,\n        },\n    )\n\n    mla = MultiheadLatentAttention(config, block_idx=0)\n\n    batch_size, seq_len = 1, 8\n    x = torch.randn(batch_size, seq_len, config.n_embd)\n    cos = torch.randn(1, seq_len, config.qk_rope_head_dim)\n    sin = torch.randn(1, seq_len, config.qk_rope_head_dim)\n\n    # Create causal mask\n    mask = torch.ones(seq_len, seq_len, dtype=x.dtype).triu(diagonal=1)\n    
mask.masked_fill_(mask.bool(), float(\"-inf\"))\n    mask = mask.view(1, 1, seq_len, seq_len)\n\n    # Forward pass with mask\n    output = mla(x, cos, sin, mask=mask)\n\n    assert output.shape == (batch_size, seq_len, config.n_embd)\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"batch_size\", (1, 2))\n@pytest.mark.parametrize(\"seq_len\", (8, 16))\n@pytest.mark.parametrize(\"device\", [torch.device(\"cpu\")])\ndef test_multihead_latent_attention_litgpt_vs_hf(batch_size, seq_len, device):\n    \"\"\"Test MLA litgpt vs hf\"\"\"\n    config_litgpt = Config(\n        n_embd=64,\n        n_head=4,\n        n_query_groups=4,\n        head_size=16,\n        norm_eps=1e-6,\n        bias=False,\n        latent_attention={\n            \"q_lora_rank\": 32,\n            \"kv_lora_rank\": 16,\n            \"qk_rope_head_dim\": 8,\n            \"qk_nope_head_dim\": 8,\n            \"v_head_dim\": 16,\n        },\n    )\n\n    config_hf = DeepseekV3Config(\n        padded_vocab_size=10000,\n        num_hidden_layers=1,\n        vocab_size=10000,\n        hidden_size=64,\n        num_attention_heads=4,\n        num_key_value_heads=4,\n        q_lora_rank=32,\n        kv_lora_rank=16,\n        qk_rope_head_dim=8,\n        qk_nope_head_dim=8,\n        v_head_dim=16,\n        rope_interleave=False,\n    )\n\n    mla_litgpt = MultiheadLatentAttention(config_litgpt, block_idx=0).to(device)\n    model_hf = DeepseekV3ForCausalLM(config_hf).to(device)\n    mla_hf = model_hf.model.layers[0].self_attn\n\n    mla_litgpt.eval()\n    mla_hf.eval()\n\n    sync_weights(mla_litgpt, mla_hf)\n\n    hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device)\n\n    # Prepare RoPE sin/cos tables\n    rope_head_dim = config_litgpt.latent_attention[\"qk_rope_head_dim\"]\n    cos = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype)\n    sin = torch.randn(batch_size, seq_len, rope_head_dim, device=device, 
dtype=hidden_states.dtype)\n\n    causal_mask = torch.triu(\n        torch.full((seq_len, seq_len), float(\"-inf\"), device=device, dtype=hidden_states.dtype), diagonal=1\n    )\n    attention_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)\n\n    # Run forward passes\n    output_litgpt = mla_litgpt(hidden_states, cos, sin)\n    output_hf = mla_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask)[0]\n\n    assert torch.allclose(output_litgpt, output_hf, atol=1e-5)\n\n\ndef sync_weights(litgpt_model, hf_model):\n    \"\"\"Copies weights from lit-gpt model to HF model.\"\"\"\n    print(\"Synchronizing weights...\")\n    with torch.no_grad():\n        hf_model.q_a_proj.weight.copy_(litgpt_model.q_a_proj.weight)\n        hf_model.q_a_layernorm.weight.copy_(litgpt_model.q_a_norm.weight)\n        hf_model.q_b_proj.weight.copy_(litgpt_model.q_b_proj.weight)\n        hf_model.kv_a_proj_with_mqa.weight.copy_(litgpt_model.kv_a_proj_with_mqa.weight)\n        hf_model.kv_a_layernorm.weight.copy_(litgpt_model.kv_a_norm.weight)\n        hf_model.kv_b_proj.weight.copy_(litgpt_model.kv_b_proj.weight)\n        hf_model.o_proj.weight.copy_(litgpt_model.proj.weight)\n    print(\"Synchronization complete.\")\n"
  },
  {
    "path": "tests/test_pretrain.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\nfrom contextlib import redirect_stdout\nfrom io import StringIO\nfrom unittest import mock\nfrom unittest.mock import ANY, Mock\n\nimport pytest\nimport torch\nfrom lightning.fabric.strategies import FSDPStrategy, SingleDeviceStrategy\nfrom torch.utils.data import DataLoader\n\nfrom litgpt import pretrain\nfrom litgpt.args import EvalArgs, TrainArgs\nfrom litgpt.config import Config\nfrom litgpt.pretrain import initialize_weights\nfrom litgpt.utils import _RunIf\n\n\n@_RunIf(min_cuda_gpus=1, standalone=True)\n@mock.patch(\"litgpt.pretrain.save_hyperparameters\")\ndef test_optimizer_args(_, tmp_path):\n    model_config = Config(block_size=2, n_layer=2, n_embd=4, n_head=2, padded_vocab_size=8)\n\n    dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])\n    dataloader = DataLoader(dataset)\n    pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader))\n\n    for i in (\"AdamW\", \"SGD\", \"RMSprop\"):\n        pretrain.setup(\n            \"pythia-14m\",\n            devices=1,\n            optimizer=\"RMSprop\",\n            model_config=model_config,\n            out_dir=tmp_path,\n            train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0),\n            eval=EvalArgs(interval=1, max_iters=1, final_validation=False),\n        )\n\n\n@_RunIf(min_cuda_gpus=2, standalone=True)\n# If we were to use `save_hyperparameters()`, we would have to patch `sys.argv` or otherwise\n# the CLI would capture pytest args, but unfortunately patching would mess with subprocess\n# launching, so we need to mock `save_hyperparameters()`\n@mock.patch(\"litgpt.pretrain.save_hyperparameters\")\n# todo: it expects exactly 2 GPUs and has strange failing for validated 4 # GPUs, so we temporarily mark it as xfail\n@pytest.mark.xfail(condition=torch.cuda.device_count() != 2, reason=\"This test is flaky, expects 
exactly 2 GPUs\")\ndef test_pretrain(_, tmp_path):\n    model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)\n\n    dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])\n    dataloader = DataLoader(dataset)\n    pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader))\n\n    out_dir = tmp_path / \"out\"\n    stdout = StringIO()\n    with redirect_stdout(stdout):\n        pretrain.setup(\n            \"pythia-14m\",\n            devices=2,\n            model_config=model_config,\n            out_dir=out_dir,\n            train=TrainArgs(global_batch_size=2, max_tokens=16, save_interval=1, micro_batch_size=1, max_norm=1.0),\n            eval=EvalArgs(interval=1, max_iters=1, final_validation=False),\n        )\n\n    if torch.distributed.get_rank() == 0:\n        # tmp_path is not the same across all ranks, run assert only on rank 0\n        out_dir_contents = set(os.listdir(out_dir))\n        checkpoint_dirs = {\"step-00000001\", \"step-00000002\", \"step-00000003\", \"step-00000004\", \"final\"}\n        assert checkpoint_dirs.issubset(out_dir_contents)\n        assert all((out_dir / p).is_dir() for p in checkpoint_dirs)\n        for checkpoint_dir in checkpoint_dirs:\n            # the `tokenizer_dir` is None by default, so only 'lit_model.pth' shows here\n            assert set(os.listdir(out_dir / checkpoint_dir)) == {\"lit_model.pth\", \"model_config.yaml\"}\n\n        assert (out_dir / \"logs\" / \"tensorboard\" / \"version_0\").is_dir()\n\n        # logs only appear on rank 0\n        logs = stdout.getvalue()\n        assert logs.count(\"(step)\") == 4\n        assert logs.count(\"val loss\") == 4\n        assert \"Total parameters: 1,888\" in logs\n\n    torch.distributed.barrier()\n\n\n@_RunIf(min_cuda_gpus=2, standalone=True)\n@mock.patch(\"litgpt.pretrain.L.Fabric.load_raw\")\n# See comment in `test_pretrain` why we need to mock 
`save_hyperparameters()`\n@mock.patch(\"litgpt.pretrain.save_hyperparameters\")\ndef test_initial_checkpoint_dir(_, load_mock, tmp_path):\n    model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)\n\n    dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])\n    dataloader = DataLoader(dataset)\n    pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader))\n    pretrain.fit = Mock()\n\n    pretrain.setup(\n        \"pythia-14m\",\n        initial_checkpoint_dir=tmp_path,\n        devices=torch.cuda.device_count(),\n        model_config=model_config,\n        out_dir=tmp_path,\n    )\n\n    load_mock.assert_called_once_with(tmp_path / \"lit_model.pth\", ANY)\n\n\n@pytest.mark.parametrize((\"strategy\", \"expected\"), [(SingleDeviceStrategy, True), (FSDPStrategy, False)])\ndef test_initialize_weights(strategy, expected):\n    fabric_mock = Mock()\n    fabric_mock.strategy = Mock(spec=strategy)\n\n    class Child(torch.nn.Module):\n        pass\n\n    class Parent(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.child = Child()\n\n    model = Parent()\n    model.reset_parameters = Mock()\n    model.child.reset_parameters = Mock()\n\n    initialize_weights(fabric_mock, model, n_layer=2, n_embd=8)\n    assert model.reset_parameters.call_count == int(expected)\n    assert model.child.reset_parameters.call_count == int(expected)\n"
  },
  {
    "path": "tests/test_prompts.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nfrom typing import Optional\n\nimport pytest\nimport yaml\n\nimport litgpt.config\nfrom litgpt import Config\nfrom litgpt.prompts import (\n    Alpaca,\n    Default,\n    Llama3,\n    Phi3,\n    PromptStyle,\n    has_prompt_style,\n    load_prompt_style,\n    prompt_styles,\n    save_prompt_style,\n)\n\n\ndef test_default_prompt_style(mock_tokenizer):\n    prompt_style = Default()\n    prompt = \"This is a test prompt.\"\n    assert prompt_style.apply(prompt) == prompt\n    assert prompt_style.stop_tokens(mock_tokenizer) == ([mock_tokenizer.eos_id],)\n\n\n@pytest.mark.parametrize(\"sys_prompt\", [None, \"You are a helpful coding assistant.\"])\ndef test_sys_prompt(mock_tokenizer, sys_prompt: Optional[str]):\n    prompt_style = Phi3()\n    prompt = \"This is a test prompt.\"\n    default_sys_prompt = \"You are a helpful assistant.\"\n    response = f\"<|system|>\\n{sys_prompt or default_sys_prompt}<|end|>\\n<|user|>\\n{prompt}<|end|>\\n<|assistant|>\\n\"\n    assert prompt_style.apply(prompt, sys_prompt=sys_prompt) == response\n    assert prompt_style.stop_tokens(mock_tokenizer) == ([mock_tokenizer.eos_id],)\n\n\n@pytest.mark.parametrize(\"sys_prompt\", [None, \"You are a helpful coding assistant.\"])\ndef test_sys_prompt_with_kwargs(mock_tokenizer, sys_prompt: Optional[str]):\n    prompt_style = Phi3()\n    prompt = \"This is a test prompt.\"\n    default_sys_prompt = \"You are a helpful assistant.\"\n    response = f\"<|system|>\\n{sys_prompt or default_sys_prompt}<|end|>\\n<|user|>\\n{prompt}<|end|>\\n<|assistant|>\\n\"\n    assert prompt_style.apply(prompt, sys_prompt=sys_prompt, test=1) == response\n    assert prompt_style.stop_tokens(mock_tokenizer) == ([mock_tokenizer.eos_id],)\n\n\ndef test_prompt_style_from_name():\n    for style_name in prompt_styles:\n        assert isinstance(PromptStyle.from_name(style_name), prompt_styles[style_name])\n\n\ndef 
test_prompt_style_from_config():\n    model_names = [\n        \"stablelm-tuned-alpha-3b\",\n        \"stablelm-tuned-alpha-7b\",\n        \"stablelm-zephyr-3b\",\n        \"stablecode-instruct-alpha-3b\",\n        \"falcon-7b-instruct\",\n        \"falcon-40b-instruct\",\n        \"Llama-2-7b-chat-hf\",\n        \"Llama-2-13b-chat-hf\",\n        \"Llama-2-70b-chat-hf\",\n        \"Llama-3-8B-Instruct\",\n        \"Llama-3-70B-Instruct\",\n        \"Llama-3.1-405B-Instruct\",\n        \"Gemma-2b-it\",\n        \"Gemma-7b-it\",\n        \"FreeWilly2\",\n        \"CodeLlama-7b-Instruct-hf\",\n        \"CodeLlama-13b-Instruct-hf\",\n        \"CodeLlama-34b-Instruct-hf\",\n        \"CodeLlama-70b-Instruct-hf\",\n        \"phi-1_5\",\n        \"phi-2\",\n        \"Phi-3-mini-4k-instruct\",\n        \"Mistral-7B-Instruct-v0.1\",\n        \"Mistral-7B-Instruct-v0.2\",\n        \"tiny-llama-1.1b-chat\",\n        \"Llama-2-7b-chat-hf-function-calling-v2\",\n    ]\n\n    for c in litgpt.config.platypus:\n        model_names.append(c[\"name\"])\n\n    for model_name in model_names:\n        # by asserting the returned style is not the Default, we show that at least one of the regex patterns matched\n        assert not isinstance(PromptStyle.from_config(Config.from_name(model_name)), Default)\n\n\ndef test_apply_prompts():\n    prompt = \"Is a coconut a nut or a fruit?\"\n    inp = \"Optional input\"\n\n    for style in prompt_styles.values():\n        output = style().apply(prompt, input=inp)\n        assert prompt in output\n        if isinstance(style, Alpaca):\n            assert inp in output\n\n\nclass CustomPromptStyle(PromptStyle):\n    def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs) -> str:\n        return prompt\n\n\ndef test_save_load_prompt_style(tmp_path):\n    # Save and load a built-in style\n    checkpoint_dir = tmp_path / \"checkpoint\"\n    checkpoint_dir.mkdir()\n    assert not has_prompt_style(checkpoint_dir)\n    
save_prompt_style(\"alpaca\", checkpoint_dir)\n    assert has_prompt_style(checkpoint_dir)\n    with open(checkpoint_dir / \"prompt_style.yaml\", encoding=\"utf-8\") as file:\n        contents = yaml.safe_load(file)\n    assert contents == {\"class_path\": \"litgpt.prompts.Alpaca\"}\n    loaded = load_prompt_style(checkpoint_dir)\n    assert isinstance(loaded, Alpaca)\n\n    # Save a custom style\n    checkpoint_dir = tmp_path / \"custom\"\n    checkpoint_dir.mkdir()\n    save_prompt_style(CustomPromptStyle(), checkpoint_dir)\n    with open(checkpoint_dir / \"prompt_style.yaml\", encoding=\"utf-8\") as file:\n        contents = yaml.safe_load(file)\n    assert contents == {\"class_path\": \"test_prompts.CustomPromptStyle\"}\n    loaded = load_prompt_style(checkpoint_dir)\n    assert isinstance(loaded, CustomPromptStyle)\n\n\ndef test_multiturn_prompt():\n    prompt = \"What is the capital of France?\"\n    msgs = [{\"role\": \"user\", \"content\": prompt}]\n    style = Llama3()\n    simple_output = style.apply(prompt)\n    multiturn_output = style.apply(msgs)\n    assert simple_output == multiturn_output\n\n    # override system prompt\n    msgs = [{\"role\": \"system\", \"content\": \"You are not a helpful assistant.\"}, {\"role\": \"user\", \"content\": prompt}]\n    with_system_multiturn_output = style.apply(msgs)\n    assert \"You are not a helpful assistant.\" in with_system_multiturn_output\n\n    # use default system prompt\n    msgs = [\n        {\"role\": \"user\", \"content\": prompt},\n    ]\n    wo_system_multiturn_output = style.apply(msgs)\n    assert \"You are a helpful assistant.\" in wo_system_multiturn_output\n\n    # Longer turn\n    msgs = [\n        {\"role\": \"system\", \"content\": \"You are a helpful AI assistant for travel tips and recommendations\"},\n        {\"role\": \"user\", \"content\": \"What is France's capital?\"},\n        {\"role\": \"assistant\", \"content\": \"Bonjour! 
The capital of France is Paris!\"},\n        {\"role\": \"user\", \"content\": \"What can I do there?\"},\n    ]\n    multiturn_output = style.apply(msgs)\n\n    assert (\n        multiturn_output\n        == \"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI assistant for travel tips and recommendations<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nBonjour! The capital of France is Paris!<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat can I do there?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n\"\"\"\n    )\n\n    # Longer list without \"system\"\n    msgs = [\n        {\"role\": \"user\", \"content\": \"What is France's capital?\"},\n        {\"role\": \"assistant\", \"content\": \"Bonjour! The capital of France is Paris!\"},\n        {\"role\": \"user\", \"content\": \"What can I do there?\"},\n    ]\n    multiturn_output = style.apply(msgs)\n\n    assert (\n        multiturn_output\n        == \"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nBonjour! The capital of France is Paris!<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat can I do there?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n\"\"\"\n    )\n\n    # {random} string format shouldn't lead to key error\n    content = \"this is {random} {system} {user}\"\n    msgs = [{\"role\": \"user\", \"content\": content}]\n    output = style.apply(msgs)\n    simple_output = style.apply(content)\n    assert output == simple_output\n"
  },
  {
    "path": "tests/test_readme.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\nimport platform\nimport subprocess\nimport sys\nimport threading\nimport time\nfrom pathlib import Path\nfrom unittest import mock\n\nimport pytest\nimport requests\nfrom urllib3.exceptions import MaxRetryError\n\nfrom litgpt.utils import _RunIf, kill_process_tree\n\nREPO_ID = Path(\"EleutherAI/pythia-14m\")\nCUSTOM_TEXTS_DIR = Path(\"custom_texts\")\n\n\ndef run_command(command):\n    try:\n        result = subprocess.run(command, capture_output=True, text=True, check=True)\n        return result.stdout\n    except subprocess.CalledProcessError as e:\n        error_message = (\n            f\"Command '{' '.join(command)}' failed with exit status {e.returncode}\\n\"\n            f\"Output:\\n{e.stdout}\\n\"\n            f\"Error:\\n{e.stderr}\"\n        )\n        # You can either print the message, log it, or raise an exception with it\n        print(error_message)\n        raise RuntimeError(error_message) from None\n\n\ndef _wait_and_check_response(waiting: int = 30):\n    response_status_code, err = -1, None\n    for _ in range(waiting):\n        try:\n            response = requests.get(\"http://127.0.0.1:8000\", timeout=1)\n            response_status_code = response.status_code\n        except (MaxRetryError, requests.exceptions.ConnectionError) as ex:\n            response_status_code = -1\n            err = str(ex)\n        if response_status_code == 200:\n            break\n        time.sleep(1)\n    assert response_status_code == 200, \"Server did not respond as expected. 
Error: {err}\"\n\n\n@pytest.mark.dependency()\n@pytest.mark.flaky(reruns=5, reruns_delay=2)\ndef test_download_model():\n    repo_id = str(REPO_ID).replace(\"\\\\\", \"/\")  # fix for Windows CI\n    command = [\"litgpt\", \"download\", str(repo_id)]\n    output = run_command(command)\n\n    s = Path(\"checkpoints\") / repo_id\n    assert f\"Saving converted checkpoint to {str(s)}\" in output\n    assert (\"checkpoints\" / REPO_ID).exists()\n\n    # Also test valid but unsupported repo IDs\n    command = [\"litgpt\", \"download\", \"CohereForAI/aya-23-8B\"]\n    output = run_command(command)\n    assert \"Unsupported `repo_id`\" in output\n\n\n@pytest.mark.dependency()\n@pytest.mark.flaky(reruns=5, reruns_delay=2)\ndef test_download_books():\n    CUSTOM_TEXTS_DIR.mkdir(parents=True, exist_ok=True)\n\n    books = [\n        (\"https://www.gutenberg.org/cache/epub/24440/pg24440.txt\", \"book1.txt\"),\n        (\"https://www.gutenberg.org/cache/epub/26393/pg26393.txt\", \"book2.txt\"),\n    ]\n    for url, filename in books:\n        subprocess.run([\"curl\", url, \"--output\", str(CUSTOM_TEXTS_DIR / filename)], check=True)\n        # Verify each book is downloaded\n        assert (CUSTOM_TEXTS_DIR / filename).exists(), f\"{filename} not downloaded\"\n\n\n@mock.patch.dict(os.environ, {\"LT_ACCELERATOR\": \"cpu\"})\n@pytest.mark.dependency(depends=[\"test_download_model\"])\ndef test_chat_with_model():\n    command = [\"litgpt\", \"generate\", \"checkpoints\" / REPO_ID]\n    prompt = \"What do Llamas eat?\"\n    result = subprocess.run(command, input=prompt, text=True, capture_output=True, check=True)\n    assert \"What food do llamas eat?\" in result.stdout\n\n\n@_RunIf(min_cuda_gpus=1)\n@pytest.mark.dependency(depends=[\"test_download_model\"])\ndef test_chat_with_quantized_model():\n    command = [\"litgpt\", \"generate\", \"checkpoints\" / REPO_ID, \"--quantize\", \"bnb.nf4\", \"--precision\", \"bf16-true\"]\n    prompt = \"What do Llamas eat?\"\n    result = 
subprocess.run(command, input=prompt, text=True, capture_output=True, check=True)\n    assert \"What food do llamas eat?\" in result.stdout, result.stdout\n\n\n@mock.patch.dict(os.environ, {\"LT_ACCELERATOR\": \"cpu\"})\n@pytest.mark.dependency(depends=[\"test_download_model\"])\n@pytest.mark.timeout(300)\ndef test_finetune_model(tmp_path):\n    OUT_DIR = tmp_path / \"out\" / \"lora\"\n    DATASET_PATH = tmp_path / \"custom_finetuning_dataset.json\"\n    CHECKPOINT_DIR = \"checkpoints\" / REPO_ID\n\n    download_command = [\n        \"curl\",\n        \"-L\",\n        \"https://huggingface.co/datasets/medalpaca/medical_meadow_health_advice/raw/main/medical_meadow_health_advice.json\",\n        \"-o\",\n        str(DATASET_PATH),\n    ]\n    subprocess.run(download_command, check=True)\n\n    assert DATASET_PATH.exists(), \"Dataset file not downloaded\"\n\n    finetune_command = [\n        \"litgpt\",\n        \"finetune_lora\",\n        str(CHECKPOINT_DIR),\n        \"--lora_r\",\n        \"1\",\n        \"--data\",\n        \"JSON\",\n        \"--data.json_path\",\n        str(DATASET_PATH),\n        \"--data.val_split_fraction\",\n        \"0.00001\",  # Keep small because new final validation is expensive\n        \"--train.max_steps\",\n        \"1\",\n        \"--out_dir\",\n        str(OUT_DIR),\n    ]\n    run_command(finetune_command)\n\n    generated_out_dir = OUT_DIR / \"final\"\n    assert generated_out_dir.exists(), f\"Finetuning output directory ({generated_out_dir}) was not created\"\n    model_file = OUT_DIR / \"final\" / \"lit_model.pth\"\n    assert model_file.exists(), f\"Model file ({model_file}) was not created\"\n\n\n@pytest.mark.skipif(\n    sys.platform.startswith(\"win\") or sys.platform == \"darwin\", reason=\"`torch.compile` is not supported on this OS.\"\n)\n@mock.patch.dict(os.environ, {\"LT_ACCELERATOR\": \"cpu\"})\n@pytest.mark.dependency(depends=[\"test_download_model\", \"test_download_books\"])\ndef test_pretrain_model(tmp_path):\n  
  OUT_DIR = tmp_path / \"out\" / \"custom_pretrained\"\n    pretrain_command = [\n        \"litgpt\",\n        \"pretrain\",\n        \"pythia-14m\",\n        \"--tokenizer_dir\",\n        str(\"checkpoints\" / REPO_ID),\n        \"--data\",\n        \"TextFiles\",\n        \"--data.train_data_path\",\n        str(CUSTOM_TEXTS_DIR),\n        \"--train.max_tokens\",\n        \"100\",  # to accelerate things for CI\n        \"--eval.max_iters\",\n        \"1\",  # to accelerate things for CI\n        \"--out_dir\",\n        str(OUT_DIR),\n    ]\n    output = run_command(pretrain_command)\n\n    assert \"Warning: Preprocessed training data found\" not in output\n    out_dir_path = OUT_DIR / \"final\"\n    assert out_dir_path.exists(), f\"Pretraining output directory ({out_dir_path}) was not created\"\n    out_model_path = OUT_DIR / \"final\" / \"lit_model.pth\"\n    assert out_model_path.exists(), f\"Model file ({out_model_path}) was not created\"\n\n    # Test that warning is displayed when running it a second time\n    output = run_command(pretrain_command)\n    assert \"Warning: Preprocessed training data found\" in output\n\n\n@pytest.mark.skipif(\n    sys.platform.startswith(\"win\") or sys.platform == \"darwin\", reason=\"`torch.compile` is not supported on this OS.\"\n)\n@mock.patch.dict(os.environ, {\"LT_ACCELERATOR\": \"cpu\"})\n@pytest.mark.dependency(depends=[\"test_download_model\", \"test_download_books\"])\ndef test_continue_pretrain_model(tmp_path):\n    OUT_DIR = tmp_path / \"out\" / \"custom_continue_pretrained\"\n    pretrain_command = [\n        \"litgpt\",\n        \"pretrain\",\n        \"pythia-14m\",\n        \"--initial_checkpoint\",\n        str(\"checkpoints\" / REPO_ID),\n        \"--tokenizer_dir\",\n        str(\"checkpoints\" / REPO_ID),\n        \"--data\",\n        \"TextFiles\",\n        \"--data.train_data_path\",\n        str(CUSTOM_TEXTS_DIR),\n        \"--train.max_tokens\",\n        \"100\",  # to accelerate things for CI\n        
\"--eval.max_iters\",\n        \"1\",  # to accelerate things for CI\n        \"--out_dir\",\n        str(OUT_DIR),\n    ]\n    run_command(pretrain_command)\n\n    generated_out_dir = OUT_DIR / \"final\"\n    assert generated_out_dir.exists(), f\"Continued pretraining directory ({generated_out_dir}) was not created\"\n    model_file = OUT_DIR / \"final\" / \"lit_model.pth\"\n    assert model_file.exists(), f\"Model file ({model_file}) was not created\"\n\n\n@pytest.mark.dependency(depends=[\"test_download_model\"])\n# todo: try to resolve this issue\n@pytest.mark.xfail(condition=platform.system() == \"Darwin\", reason=\"it passes locally but having some issues on CI\")\ndef test_serve():\n    CHECKPOINT_DIR = str(\"checkpoints\" / REPO_ID)\n    run_command = [\"litgpt\", \"serve\", str(CHECKPOINT_DIR)]\n\n    process = None\n\n    def run_server():\n        nonlocal process\n        try:\n            process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)\n            stdout, stderr = process.communicate(timeout=60)\n        except subprocess.TimeoutExpired:\n            print(\"Server start-up timeout expired\")\n\n    server_thread = threading.Thread(target=run_server)\n    server_thread.start()\n\n    _wait_and_check_response()\n\n    if process:\n        kill_process_tree(process.pid)\n    server_thread.join()\n"
  },
  {
    "path": "tests/test_rope.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport torch\nfrom transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXConfig, GPTNeoXRotaryEmbedding\nfrom transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_gptneo\nfrom transformers.models.llama.configuration_llama import LlamaConfig\nfrom transformers.models.llama.modeling_llama import LlamaRotaryEmbedding\nfrom transformers.models.llama.modeling_llama import apply_rotary_pos_emb as apply_rotary_pos_emb_llama\n\nfrom litgpt.model import apply_rope, build_rope_cache\n\n\n@torch.inference_mode()\ndef test_rope_gptneox():\n    bs, seq_len, n_head, n_embed = 1, 6, 2, 8\n    head_size = n_embed // n_head  # 4\n    x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()\n    position_ids = torch.arange(seq_len).unsqueeze(0)\n\n    config = GPTNeoXConfig(num_attention_heads=n_head, hidden_size=head_size * n_embed)\n    theirs_rot_emb = GPTNeoXRotaryEmbedding(config)\n    theirs_cos, theirs_sin = theirs_rot_emb(x, position_ids)\n\n    ours_cos_cached, ours_sin_cached = build_rope_cache(seq_len, head_size, device=x.device)\n    ours_cos_cached = ours_cos_cached.unsqueeze(0)\n    ours_sin_cached = ours_sin_cached.unsqueeze(0)\n    torch.testing.assert_close(ours_cos_cached, theirs_cos)\n    torch.testing.assert_close(ours_sin_cached, theirs_sin)\n\n    ours_x_rope = apply_rope(x, ours_cos_cached, ours_sin_cached)\n    theirs_x_rope, _ = apply_rotary_pos_emb_gptneo(x, x, theirs_cos, theirs_sin, position_ids)\n    torch.testing.assert_close(ours_x_rope, theirs_x_rope)\n\n\n@torch.inference_mode()\ndef test_rope_llama_2():\n    head_dim = 64\n    rope_theta = 10_000\n\n    ##################################\n    # Compare cos and sin\n    ##################################\n    # transformer rope\n    their_rope_config = {\n        \"rope_type\": \"default\",\n    }\n    config = 
LlamaConfig(head_dim=head_dim, rope_theta=rope_theta, rope_scaling=their_rope_config)\n\n    rot_emb = LlamaRotaryEmbedding(config=config)\n    batch_size, seq_len = 1, 10\n    qk_tensor = torch.randn(batch_size, seq_len, head_dim)\n    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)\n    theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)\n\n    # our rope\n    ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta)\n    ours_cos = ours_cos.unsqueeze(0)\n    ours_sin = ours_sin.unsqueeze(0)\n    torch.testing.assert_close(theirs_cos, ours_cos)\n    torch.testing.assert_close(theirs_sin, ours_sin)\n\n    ##################################\n    # Compare rotated tensors\n    ##################################\n    # Settings\n    num_heads = 4\n\n    # Dummy query and key tensors\n    torch.manual_seed(123)\n    queries = torch.randn(batch_size, num_heads, seq_len, head_dim)\n    keys = torch.randn(batch_size, num_heads, seq_len, head_dim)\n\n    ours_q_rot = apply_rope(queries, ours_cos, ours_sin)\n    ours_k_rot = apply_rope(keys, ours_cos, ours_sin)\n    theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)\n    torch.testing.assert_close(theirs_q_rot, ours_q_rot)\n    torch.testing.assert_close(theirs_k_rot, ours_k_rot)\n\n\n# See https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json for settings\n@torch.inference_mode()\ndef test_rope_llama_3():\n    head_dim = 64\n    rope_theta = 50_000\n\n    ##################################\n    # Compare cos and sin\n    ##################################\n    # transformer rope\n    their_rope_config = {\n        \"rope_type\": \"default\",\n    }\n    config = LlamaConfig(head_dim=head_dim, rope_theta=rope_theta, rope_scaling=their_rope_config)\n\n    rot_emb = LlamaRotaryEmbedding(config=config)\n    batch_size, seq_len = 1, 10\n    qk_tensor = torch.randn(batch_size, seq_len, head_dim)\n    position_ids = 
torch.arange(seq_len, dtype=torch.long).unsqueeze(0)\n    theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)\n\n    # our rope\n    ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta)\n    ours_cos = ours_cos.unsqueeze(0)\n    ours_sin = ours_sin.unsqueeze(0)\n    torch.testing.assert_close(theirs_cos, ours_cos)\n    torch.testing.assert_close(theirs_sin, ours_sin)\n\n    ##################################\n    # Compare rotated tensors\n    ##################################\n    # Settings\n    num_heads = 4\n\n    # Dummy query and key tensors\n    torch.manual_seed(123)\n    queries = torch.randn(batch_size, num_heads, seq_len, head_dim)\n    keys = torch.randn(batch_size, num_heads, seq_len, head_dim)\n\n    ours_q_rot = apply_rope(queries, ours_cos, ours_sin)\n    ours_k_rot = apply_rope(keys, ours_cos, ours_sin)\n    theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)\n    torch.testing.assert_close(theirs_q_rot, ours_q_rot)\n    torch.testing.assert_close(theirs_k_rot, ours_k_rot)\n\n\n# See https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json for settings\n@torch.inference_mode()\ndef test_rope_llama_3_1():\n    head_dim = 32\n    rope_theta = 50_000\n\n    their_rope_config = {\n        \"factor\": 8.0,\n        \"low_freq_factor\": 1.0,\n        \"high_freq_factor\": 4.0,\n        \"original_max_position_embeddings\": 8192,\n        \"rope_type\": \"llama3\",\n    }\n\n    our_rope_config = {\"factor\": 8.0, \"low_freq_factor\": 1.0, \"high_freq_factor\": 4.0, \"original_max_seq_len\": 8192}\n\n    config = LlamaConfig(rope_theta=rope_theta, rope_scaling=their_rope_config, head_dim=head_dim)\n\n    ##################################\n    # Compare cos and sin\n    ##################################\n    # transformer rope\n    rot_emb = LlamaRotaryEmbedding(config=config)\n    batch_size, seq_len = 1, 131_072\n    qk_tensor = torch.randn(batch_size, seq_len, 
head_dim)\n    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)\n    theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)\n\n    # our rope\n    ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)\n    ours_cos = ours_cos.unsqueeze(0)\n    ours_sin = ours_sin.unsqueeze(0)\n    torch.testing.assert_close(theirs_cos, ours_cos)\n    torch.testing.assert_close(theirs_sin, ours_sin)\n\n    ##################################\n    # Compare rotated tensors\n    ##################################\n    # Settings\n    num_heads = 4\n\n    # Dummy query and key tensors\n    torch.manual_seed(123)\n    queries = torch.randn(batch_size, num_heads, seq_len, head_dim)\n    keys = torch.randn(batch_size, num_heads, seq_len, head_dim)\n\n    ours_q_rot = apply_rope(queries, ours_cos, ours_sin)\n    ours_k_rot = apply_rope(keys, ours_cos, ours_sin)\n    theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)\n    torch.testing.assert_close(theirs_q_rot, ours_q_rot)\n    torch.testing.assert_close(theirs_k_rot, ours_k_rot)\n\n\n# See https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json for settings\n@torch.inference_mode()\ndef test_rope_llama_3_2():\n    head_dim = 32\n    rope_theta = 50_000\n\n    their_rope_config = {\n        \"factor\": 32.0,\n        \"low_freq_factor\": 1.0,\n        \"high_freq_factor\": 4.0,\n        \"original_max_position_embeddings\": 8192,\n        \"rope_type\": \"llama3\",\n    }\n\n    our_rope_config = {\"factor\": 32.0, \"low_freq_factor\": 1.0, \"high_freq_factor\": 4.0, \"original_max_seq_len\": 8192}\n\n    config = LlamaConfig(rope_theta=rope_theta, rope_scaling=their_rope_config, head_dim=head_dim)\n\n    ##################################\n    # Compare cos and sin\n    ##################################\n    # transformer rope\n    rot_emb = LlamaRotaryEmbedding(config=config)\n    batch_size, seq_len = 
1, 131_072\n    qk_tensor = torch.randn(batch_size, seq_len, head_dim)\n    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)\n    theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)\n\n    # our rope\n    ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)\n    ours_cos = ours_cos.unsqueeze(0)\n    ours_sin = ours_sin.unsqueeze(0)\n    torch.testing.assert_close(theirs_cos, ours_cos)\n    torch.testing.assert_close(theirs_sin, ours_sin)\n\n    ##################################\n    # Compare rotated tensors\n    ##################################\n    # Settings\n    num_heads = 4\n\n    # Dummy query and key tensors\n    torch.manual_seed(123)\n    queries = torch.randn(batch_size, num_heads, seq_len, head_dim)\n    keys = torch.randn(batch_size, num_heads, seq_len, head_dim)\n\n    ours_q_rot = apply_rope(queries, ours_cos, ours_sin)\n    ours_k_rot = apply_rope(keys, ours_cos, ours_sin)\n    theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)\n    torch.testing.assert_close(theirs_q_rot, ours_q_rot)\n    torch.testing.assert_close(theirs_k_rot, ours_k_rot)\n\n\n# See https://huggingface.co/google/gemma-3-27b-it/blob/main/config.json for settings\n@torch.inference_mode()\ndef test_rope_gemma_3():\n    from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig\n    from transformers.models.gemma3.modeling_gemma3 import Gemma3RotaryEmbedding, apply_rotary_pos_emb\n\n    head_dim = 32\n    rope_theta = 50_000\n    their_rope_config = {\n        \"factor\": 8.0,\n        \"rope_type\": \"linear\",\n    }\n\n    our_rope_config = {\"factor\": 8.0}\n\n    ##################################\n    # Compare cos and sin\n    ##################################\n    # transformer rope\n    config = Gemma3TextConfig(rope_theta=rope_theta, rope_scaling=their_rope_config, head_dim=head_dim)\n    rot_emb = 
Gemma3RotaryEmbedding(config=config)\n    batch_size, seq_len = 1, 10\n    qk_tensor = torch.randn(batch_size, seq_len, head_dim)\n    position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)\n    theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)\n\n    # our rope\n    ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)\n    ours_cos = ours_cos.unsqueeze(0)\n    ours_sin = ours_sin.unsqueeze(0)\n    torch.testing.assert_close(theirs_cos, ours_cos)\n    torch.testing.assert_close(theirs_sin, ours_sin)\n\n    ##################################\n    # Compare rotated tensors\n    ##################################\n    # Settings\n    num_heads = 4\n\n    # Dummy query and key tensors\n    torch.manual_seed(123)\n    queries = torch.randn(batch_size, num_heads, seq_len, head_dim)\n    keys = torch.randn(batch_size, num_heads, seq_len, head_dim)\n\n    ours_q_rot = apply_rope(queries, ours_cos, ours_sin)\n    ours_k_rot = apply_rope(keys, ours_cos, ours_sin)\n    theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb(queries, keys, theirs_cos, theirs_sin)\n    torch.testing.assert_close(theirs_q_rot, ours_q_rot)\n    torch.testing.assert_close(theirs_k_rot, ours_k_rot)\n\n\n@torch.inference_mode()\ndef test_rope_cos_sin_shapes_if_rope_n_elem_is_odd():\n    bs, seq_len, n_head, n_embed = 1, 6, 2, 8\n    head_size = n_embed // n_head  # 4\n    rotary_percentage = 0.75\n    rope_n_elem = int(head_size * rotary_percentage)  # 3\n    ours_cos, ours_sin = build_rope_cache(seq_len, rope_n_elem)\n    required_shape = (seq_len, rope_n_elem)\n    assert ours_cos.shape == required_shape\n    assert ours_sin.shape == required_shape\n    # Special case: If `rope_n_elem == 1`, the shape is extended. 
This is to\n    # accommodate a current bug in Hugging Face, ensuring that other unit tests\n    # pass.\n    # https://github.com/huggingface/transformers/issues/35233\n    rotary_percentage = 0.25\n    rope_n_elem = int(head_size * rotary_percentage)  # 1\n    ours_cos, ours_sin = build_rope_cache(seq_len, rope_n_elem)\n    required_shape = (seq_len, rope_n_elem + 1)\n    assert ours_cos.shape == required_shape\n    assert ours_sin.shape == required_shape\n"
  },
  {
    "path": "tests/test_serve.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport json\nimport platform\nimport shutil\nimport subprocess\nimport threading\nimport time\nfrom dataclasses import asdict\n\nimport pytest\nimport requests\nimport torch\nimport yaml\nfrom lightning.fabric import seed_everything\nfrom urllib3.exceptions import MaxRetryError\n\nfrom litgpt import GPT, Config\nfrom litgpt.scripts.download import download_from_hub\nfrom litgpt.utils import _RunIf, kill_process_tree\n\n\ndef _wait_and_check_response(waiting: int = 30):\n    response_status_code, err = -1, None\n    for _ in range(waiting):\n        try:\n            response = requests.get(\"http://127.0.0.1:8000\", timeout=10)\n            response_status_code = response.status_code\n        except (MaxRetryError, requests.exceptions.ConnectionError) as ex:\n            response_status_code = -1\n            err = str(ex)\n        if response_status_code == 200:\n            break\n        time.sleep(1)\n    assert response_status_code == 200, f\"Server did not respond as expected. 
Error: {err}\"\n\n\n# todo: try to resolve this issue\n@pytest.mark.flaky(reruns=2, reruns_delay=30)\n@pytest.mark.xfail(condition=platform.system() == \"Darwin\", reason=\"it passes locally but having some issues on CI\")\ndef test_simple(tmp_path):\n    seed_everything(123)\n    ours_config = Config.from_name(\"pythia-14m\")\n    download_from_hub(repo_id=\"EleutherAI/pythia-14m\", tokenizer_only=True, checkpoint_dir=tmp_path)\n    shutil.move(str(tmp_path / \"EleutherAI\" / \"pythia-14m\" / \"tokenizer.json\"), str(tmp_path))\n    shutil.move(str(tmp_path / \"EleutherAI\" / \"pythia-14m\" / \"tokenizer_config.json\"), str(tmp_path))\n    ours_model = GPT(ours_config)\n    checkpoint_path = tmp_path / \"lit_model.pth\"\n    torch.save(ours_model.state_dict(), checkpoint_path)\n    config_path = tmp_path / \"model_config.yaml\"\n    with open(config_path, \"w\", encoding=\"utf-8\") as fp:\n        yaml.dump(asdict(ours_config), fp)\n\n    run_command = [\"litgpt\", \"serve\", tmp_path]\n\n    process = None\n\n    def run_server():\n        nonlocal process\n        try:\n            process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True)\n        except subprocess.TimeoutExpired:\n            print(\"Server start-up timeout expired\")\n\n    server_thread = threading.Thread(target=run_server)\n    server_thread.start()\n\n    _wait_and_check_response(waiting=60)\n\n    if process:\n        kill_process_tree(process.pid)\n    server_thread.join()\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_quantize(tmp_path):\n    seed_everything(123)\n    ours_config = Config.from_name(\"pythia-14m\")\n    download_from_hub(repo_id=\"EleutherAI/pythia-14m\", tokenizer_only=True, checkpoint_dir=tmp_path)\n    shutil.move(str(tmp_path / \"EleutherAI\" / \"pythia-14m\" / \"tokenizer.json\"), str(tmp_path))\n    shutil.move(str(tmp_path / \"EleutherAI\" / \"pythia-14m\" / \"tokenizer_config.json\"), str(tmp_path))\n    ours_model = GPT(ours_config)\n    
checkpoint_path = tmp_path / \"lit_model.pth\"\n    torch.save(ours_model.state_dict(), checkpoint_path)\n    config_path = tmp_path / \"model_config.yaml\"\n    with open(config_path, \"w\", encoding=\"utf-8\") as fp:\n        yaml.dump(asdict(ours_config), fp)\n\n    run_command = [\"litgpt\", \"serve\", tmp_path, \"--quantize\", \"bnb.nf4\"]\n\n    process = None\n\n    def run_server():\n        nonlocal process\n        try:\n            process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True)\n        except subprocess.TimeoutExpired:\n            print(\"Server start-up timeout expired\")\n\n    server_thread = threading.Thread(target=run_server)\n    server_thread.start()\n\n    _wait_and_check_response()\n\n    if process:\n        kill_process_tree(process.pid)\n    server_thread.join()\n\n\n@_RunIf(min_cuda_gpus=2)\ndef test_multi_gpu_serve(tmp_path):\n    seed_everything(123)\n    ours_config = Config.from_name(\"pythia-14m\")\n    download_from_hub(repo_id=\"EleutherAI/pythia-14m\", tokenizer_only=True, checkpoint_dir=tmp_path)\n    shutil.move(str(tmp_path / \"EleutherAI\" / \"pythia-14m\" / \"tokenizer.json\"), str(tmp_path))\n    shutil.move(str(tmp_path / \"EleutherAI\" / \"pythia-14m\" / \"tokenizer_config.json\"), str(tmp_path))\n    ours_model = GPT(ours_config)\n    checkpoint_path = tmp_path / \"lit_model.pth\"\n    torch.save(ours_model.state_dict(), checkpoint_path)\n    config_path = tmp_path / \"model_config.yaml\"\n    with open(config_path, \"w\", encoding=\"utf-8\") as fp:\n        yaml.dump(asdict(ours_config), fp)\n\n    run_command = [\"litgpt\", \"serve\", tmp_path, \"--devices\", \"2\"]\n\n    process = None\n\n    def run_server():\n        nonlocal process\n        try:\n            process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True)\n        except subprocess.TimeoutExpired:\n            print(\"Server start-up timeout expired\")\n\n    server_thread = 
threading.Thread(target=run_server)\n    server_thread.start()\n\n    _wait_and_check_response()\n\n    if process:\n        kill_process_tree(process.pid)\n    server_thread.join()\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_serve_with_openai_spec_missing_chat_template(tmp_path):\n    seed_everything(123)\n    ours_config = Config.from_name(\"pythia-14m\")\n    download_from_hub(repo_id=\"EleutherAI/pythia-14m\", tokenizer_only=True, checkpoint_dir=tmp_path)\n    shutil.move(str(tmp_path / \"EleutherAI\" / \"pythia-14m\" / \"tokenizer.json\"), str(tmp_path))\n    shutil.move(str(tmp_path / \"EleutherAI\" / \"pythia-14m\" / \"tokenizer_config.json\"), str(tmp_path))\n    ours_model = GPT(ours_config)\n    checkpoint_path = tmp_path / \"lit_model.pth\"\n    torch.save(ours_model.state_dict(), checkpoint_path)\n    config_path = tmp_path / \"model_config.yaml\"\n    with open(config_path, \"w\", encoding=\"utf-8\") as fp:\n        yaml.dump(asdict(ours_config), fp)\n\n    run_command = [\"litgpt\", \"serve\", tmp_path, \"--openai_spec\", \"true\"]\n\n    process = None\n\n    def run_server():\n        nonlocal process\n        try:\n            process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True)\n        except subprocess.TimeoutExpired:\n            print(\"Server start-up timeout expired\")\n\n    server_thread = threading.Thread(target=run_server)\n    server_thread.start()\n\n    _wait_and_check_response()\n\n    if process:\n        kill_process_tree(process.pid)\n    server_thread.join()\n\n\n@_RunIf(min_cuda_gpus=1)\ndef test_serve_with_openai_spec(tmp_path):\n    seed_everything(123)\n    ours_config = Config.from_name(\"SmolLM2-135M-Instruct\")\n    download_from_hub(repo_id=\"HuggingFaceTB/SmolLM2-135M-Instruct\", tokenizer_only=True, checkpoint_dir=tmp_path)\n    shutil.move(str(tmp_path / \"HuggingFaceTB\" / \"SmolLM2-135M-Instruct\" / \"tokenizer.json\"), str(tmp_path))\n    shutil.move(str(tmp_path / \"HuggingFaceTB\" / 
\"SmolLM2-135M-Instruct\" / \"tokenizer_config.json\"), str(tmp_path))\n    ours_model = GPT(ours_config)\n    checkpoint_path = tmp_path / \"lit_model.pth\"\n    torch.save(ours_model.state_dict(), checkpoint_path)\n    config_path = tmp_path / \"model_config.yaml\"\n    with open(config_path, \"w\", encoding=\"utf-8\") as fp:\n        yaml.dump(asdict(ours_config), fp)\n\n    run_command = [\"litgpt\", \"serve\", tmp_path, \"--openai_spec\", \"true\"]\n\n    process = None\n\n    def run_server():\n        nonlocal process\n        try:\n            process = subprocess.Popen(run_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)\n        except subprocess.TimeoutExpired:\n            print(\"Server start-up timeout expired\")\n\n    server_thread = threading.Thread(target=run_server)\n    server_thread.start()\n\n    _wait_and_check_response()\n\n    try:\n        # Test server health\n        response = requests.get(\"http://127.0.0.1:8000/health\")\n        assert response.status_code == 200, f\"Server health check failed with status code {response.status_code}\"\n        assert response.text == \"ok\", \"Server did not respond as expected.\"\n\n        # Test non-streaming chat completion\n        response = requests.post(\n            \"http://127.0.0.1:8000/v1/chat/completions\",\n            json={\n                \"model\": \"SmolLM2-135M-Instruct\",\n                \"messages\": [{\"role\": \"user\", \"content\": \"Hello!\"}],\n            },\n        )\n        assert response.status_code == 200, (\n            f\"Non-streaming chat completion failed with status code {response.status_code}\"\n        )\n        response_json = response.json()\n        assert \"choices\" in response_json, \"Response JSON does not contain 'choices'.\"\n        assert \"message\" in response_json[\"choices\"][0], \"Response JSON does not contain 'message' in 'choices'.\"\n        assert \"content\" in response_json[\"choices\"][0][\"message\"], (\n       
     \"Response JSON does not contain 'content' in 'message'.\"\n        )\n        assert response_json[\"choices\"][0][\"message\"][\"content\"], \"Content is empty in the response.\"\n\n        # Test streaming chat completion\n        stream_response = requests.post(\n            \"http://127.0.0.1:8000/v1/chat/completions\",\n            json={\n                \"model\": \"SmolLM2-135M-Instruct\",\n                \"messages\": [{\"role\": \"user\", \"content\": \"Hello!\"}],\n                \"stream\": True,\n            },\n        )\n        assert stream_response.status_code == 200, (\n            f\"Streaming chat completion failed with status code {stream_response.status_code}\"\n        )\n        for line in stream_response.iter_lines():\n            decoded = line.decode(\"utf-8\").replace(\"data: \", \"\").replace(\"[DONE]\", \"\").strip()\n            if decoded:\n                data = json.loads(decoded)\n                assert \"choices\" in data, \"Response JSON does not contain 'choices'.\"\n                assert \"delta\" in data[\"choices\"][0], \"Response JSON does not contain 'delta' in 'choices'.\"\n                assert \"content\" in data[\"choices\"][0][\"delta\"], \"Response JSON does not contain 'content' in 'delta'.\"\n    finally:\n        if process:\n            kill_process_tree(process.pid)\n        server_thread.join()\n\n\n@pytest.mark.parametrize(\n    \"generate_strategy\",\n    [\n        pytest.param(\"sequential\", marks=_RunIf(min_cuda_gpus=1)),\n        pytest.param(\"tensor_parallel\", marks=_RunIf(min_cuda_gpus=2)),\n    ],\n)\ndef test_serve_with_generate_strategy(tmp_path, generate_strategy):\n    seed_everything(123)\n    ours_config = Config.from_name(\"pythia-14m\")\n    download_from_hub(repo_id=\"EleutherAI/pythia-14m\", tokenizer_only=True, checkpoint_dir=tmp_path)\n    shutil.move(str(tmp_path / \"EleutherAI\" / \"pythia-14m\" / \"tokenizer.json\"), str(tmp_path))\n    shutil.move(str(tmp_path / 
\"EleutherAI\" / \"pythia-14m\" / \"tokenizer_config.json\"), str(tmp_path))\n    ours_model = GPT(ours_config)\n    checkpoint_path = tmp_path / \"lit_model.pth\"\n    torch.save(ours_model.state_dict(), checkpoint_path)\n    config_path = tmp_path / \"model_config.yaml\"\n    with open(config_path, \"w\", encoding=\"utf-8\") as fp:\n        yaml.dump(asdict(ours_config), fp)\n\n    # Test with generate strategy\n    run_command = [\"litgpt\", \"serve\", tmp_path, \"--generate_strategy\", generate_strategy]\n\n    process = None\n\n    def run_server():\n        nonlocal process\n        try:\n            process = subprocess.Popen(run_command, stdout=None, stderr=None, text=True)\n        except subprocess.TimeoutExpired:\n            print(\"Server start-up timeout expired\")\n\n    server_thread = threading.Thread(target=run_server)\n    server_thread.start()\n\n    _wait_and_check_response()\n\n    if process:\n        kill_process_tree(process.pid)\n    server_thread.join()\n"
  },
  {
    "path": "tests/test_tokenizer.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport os\nimport shutil\nimport warnings\nfrom types import SimpleNamespace\nfrom unittest import mock\n\nimport pytest\nfrom tokenizers import Tokenizer as HFTokenizer\nfrom tokenizers.models import BPE\nfrom transformers import AutoTokenizer\nfrom transformers.utils import cached_file\n\nimport litgpt.config as config_module\nfrom litgpt import PromptStyle, Tokenizer\n\n\n# @pytest.mark.flaky(reruns=3, rerun_except=[\"AssertionError\", \"assert\", \"TypeError\"])\n@pytest.mark.flaky(reruns=3, reruns_delay=120)\n@pytest.mark.parametrize(\"config\", config_module.configs, ids=[c[\"hf_config\"][\"name\"] for c in config_module.configs])\ndef test_tokenizer_against_hf(config, tmp_path):\n    config = config_module.Config(**config)\n\n    repo_id = f\"{config.hf_config['org']}/{config.hf_config['name']}\"\n    theirs = AutoTokenizer.from_pretrained(repo_id, token=os.getenv(\"HF_TOKEN\"))\n\n    # create a checkpoint directory that points to the HF files\n    hf_files = {}\n    for filename in (\"tokenizer.json\", \"generation_config.json\", \"tokenizer.model\", \"tokenizer_config.json\"):\n        try:  # download the HF tokenizer config\n            hf_file = cached_file(path_or_repo_id=repo_id, filename=filename)\n            hf_files[filename] = str(hf_file)\n        except Exception as ex:\n            warnings.warn(str(ex), RuntimeWarning)\n    if \"tokenizer.json\" not in hf_files and \"tokenizer.model\" not in hf_files:\n        raise ConnectionError(\"Unable to download any tokenizer files from HF\")\n\n    # Create a clean, model-specific subdirectory for this test run.\n    # This avoids errors if previous runs or retries left files behind, ensuring the directory is always ready for fresh downloads and comparisons.\n    model_dir = tmp_path / config.hf_config[\"name\"]\n    if model_dir.exists():\n        shutil.rmtree(model_dir)\n    os.makedirs(model_dir, 
exist_ok=True)\n\n    for filename, hf_file in hf_files.items():\n        shutil.copy(hf_file, model_dir / filename)\n\n    ours = Tokenizer(model_dir)\n\n    assert ours.vocab_size == theirs.vocab_size\n    if config.name == \"Mixtral-8x22B-v0.1\":\n        pytest.xfail(reason=\"Mixtral certainly lists 32000 vocab in its config\")\n    else:\n        assert ours.vocab_size == config.vocab_size\n\n    if config.name.startswith((\"falcon\", \"stablecode\", \"Qwen2.5\", \"QwQ\", \"Qwen3\")):\n        # even though their config defines it, it's set as None in HF\n        assert isinstance(ours.bos_id, int)\n        assert theirs.bos_token_id is None\n    elif config.name.startswith(\"Falcon3\"):\n        if isinstance(ours.bos_id, int):\n            assert theirs.bos_token_id is None\n        else:\n            assert ours.bos_id == theirs.bos_token_id is None\n    else:\n        assert ours.bos_id == theirs.bos_token_id\n\n    if config.name.startswith(\"stablecode\"):\n        # even though their config defines it, it's set as None in HF\n        assert ours.eos_id == 0\n        assert ours.eos_id == theirs.eos_token_id or theirs.eos_token_id is None\n    else:\n        assert ours.eos_id == theirs.eos_token_id\n\n    prompt = \"Hello, readers of this test!\"\n    prompt = PromptStyle.from_config(config).apply(prompt)\n    actual = ours.encode(prompt)\n    expected = theirs.encode(prompt)\n    assert actual.tolist() == expected\n    assert ours.decode(actual) == theirs.decode(expected, skip_special_tokens=True)\n\n    if not config.name.startswith((\"Mistral\", \"Mixtral\")):\n        decoded_output = \"\".join([ours.decode(x) for x in actual])\n        if ours.apply_decoding_fix and decoded_output[0] == \" \":\n            decoded_output = decoded_output[1:]  # the \"hack\" adds an empty space to the beginning\n        assert decoded_output == ours.decode(actual), type(theirs)\n\n\ndef test_tokenizer_input_validation():\n    with pytest.raises(NotADirectoryError, 
match=\"The checkpoint directory does not exist\"):\n        Tokenizer(\"cocofruit\")\n\n\n@pytest.mark.parametrize(\"use_bos_by_default\", (True, False))\n@pytest.mark.parametrize(\"encode_use_bos\", (None, True, False))\n@pytest.mark.parametrize(\"encode_use_eos\", (True, False))\n@pytest.mark.parametrize(\"processor_returns_bos\", (True, False))\n@pytest.mark.parametrize(\"fake_return_ids\", ([], [34, 8, 17, 2]))\ndef test_tokenizer_bos_eos(\n    tmp_path, use_bos_by_default, encode_use_bos, encode_use_eos, processor_returns_bos, fake_return_ids\n):\n    # let `Tokenizers` create a proper (albeit empty) vocab in json format\n    HFTokenizer(BPE()).save(str(tmp_path / \"tokenizer.json\"))\n\n    tokenizer = Tokenizer(tmp_path)\n    tokenizer.bos_id = 0\n    tokenizer.eos_id = 1\n    tokenizer.use_bos = use_bos_by_default\n\n    if processor_returns_bos:\n        fake_return_ids = [tokenizer.bos_id] + fake_return_ids\n    fake_return_ids = SimpleNamespace(**dict(ids=fake_return_ids))\n\n    with mock.patch.object(tokenizer.processor, \"encode\", return_value=fake_return_ids):\n        tokens = tokenizer.encode(\"Hello world\", bos=encode_use_bos, eos=encode_use_eos).tolist()\n\n    if encode_use_bos or (encode_use_bos is None and use_bos_by_default):\n        assert tokens[0] == tokenizer.bos_id\n    else:\n        assert not tokens or tokens[0] != tokenizer.bos_id\n\n    if encode_use_eos:\n        assert tokens[-1] == tokenizer.eos_id\n    else:\n        assert not tokens or tokens[-1] != tokenizer.eos_id\n\n    # both `bos` and `eos` should either not be found or occur only once at the beginning (bos)\n    # or at the end (eos) of the tokens sequence\n    assert max([id for id, token in enumerate(tokens) if token == tokenizer.bos_id], default=0) == 0\n    assert max([id for id, token in enumerate(tokens[::-1]) if token == tokenizer.eos_id], default=0) == 0\n"
  },
  {
    "path": "tests/test_trainer_support.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport os\nfrom pathlib import Path\n\nimport lightning as L\nimport pytest\nimport torch\n\nfrom litgpt.api import LLM\nfrom litgpt.data import Alpaca2k\nfrom litgpt.utils import _RunIf\n\nREPO_ID = Path(\"EleutherAI/pythia-14m\")\n\n\nclass LitLLM(L.LightningModule):\n    def __init__(self, checkpoint_dir, tokenizer_dir=None, trainer_ckpt_path=None):\n        super().__init__()\n\n        self.llm = LLM.load(checkpoint_dir, tokenizer_dir=tokenizer_dir, distribute=None)\n        self.trainer_ckpt_path = trainer_ckpt_path\n\n    def setup(self, stage):\n        self.llm.trainer_setup(trainer_ckpt=self.trainer_ckpt_path)\n\n    def training_step(self, batch):\n        logits, loss = self.llm(input_ids=batch[\"input_ids\"], target_ids=batch[\"labels\"])\n        self.log(\"train_loss\", loss, prog_bar=True)\n        return loss\n\n    def validation_step(self, batch):\n        logits, loss = self.llm(input_ids=batch[\"input_ids\"], target_ids=batch[\"labels\"])\n        self.log(\"validation_loss\", loss, prog_bar=True)\n        return loss\n\n    def configure_optimizers(self):\n        warmup_steps = 10\n        optimizer = torch.optim.AdamW(self.llm.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))\n        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)\n        return [optimizer], [scheduler]\n\n\n@pytest.mark.dependency()\ndef test_download_model():\n    LLM.load(model=\"EleutherAI/pythia-14m\", distribute=None)\n\n\n@pytest.mark.dependency(depends=[\"test_download_model\"])\n@_RunIf(min_cuda_gpus=1)\ndef test_usecase1_pretraining_from_random_weights(tmp_path):\n    llm = LLM.load(\"EleutherAI/pythia-14m\", tokenizer_dir=\"EleutherAI/pythia-14m\", init=\"random\")\n    llm.save(\"pythia-14m-random-weights\")\n    del llm\n\n    lit_model = LitLLM(checkpoint_dir=\"pythia-14m-random-weights\", 
tokenizer_dir=\"EleutherAI/pythia-14m\")\n    data = Alpaca2k()\n\n    data.connect(lit_model.llm.tokenizer, batch_size=4, max_seq_length=128)\n\n    trainer = L.Trainer(\n        max_epochs=1,\n        overfit_batches=2,\n        precision=\"bf16-true\",\n    )\n    trainer.fit(lit_model, data)\n\n    lit_model.llm.model.to(lit_model.llm.preprocessor.device)\n    text = lit_model.llm.generate(\"hello world\")\n    assert isinstance(text, str)\n\n\n@pytest.mark.dependency(depends=[\"test_download_model\"])\n@_RunIf(min_cuda_gpus=1)\ndef test_usecase2_continued_pretraining_from_checkpoint(tmp_path):\n    lit_model = LitLLM(checkpoint_dir=\"EleutherAI/pythia-14m\")\n    data = Alpaca2k()\n\n    data.connect(lit_model.llm.tokenizer, batch_size=4, max_seq_length=128)\n\n    trainer = L.Trainer(\n        accelerator=\"cuda\",\n        max_epochs=1,\n        precision=\"bf16-true\",\n    )\n    trainer.fit(lit_model, data)\n\n    lit_model.llm.model.to(lit_model.llm.preprocessor.device)\n    text = lit_model.llm.generate(\"hello world\")\n    assert isinstance(text, str)\n\n\n@pytest.mark.dependency(depends=[\"test_download_model\", \"test_usecase2_continued_pretraining_from_checkpoint\"])\n@_RunIf(min_cuda_gpus=1)\ndef test_usecase3_resume_from_trainer_checkpoint(tmp_path):\n    def find_latest_checkpoint(directory):\n        latest_checkpoint = None\n        latest_time = 0\n\n        for root, _, files in os.walk(directory):\n            for file in files:\n                if file.endswith(\".ckpt\"):\n                    file_path = os.path.join(root, file)\n                    file_time = os.path.getmtime(file_path)\n                    if file_time > latest_time:\n                        latest_time = file_time\n                        latest_checkpoint = file_path\n\n        return latest_checkpoint\n\n    lit_model = LitLLM(\n        checkpoint_dir=\"EleutherAI/pythia-14m\", trainer_ckpt_path=find_latest_checkpoint(\"lightning_logs\")\n    )\n\n    data = 
Alpaca2k()\n    data.connect(lit_model.llm.tokenizer, batch_size=4, max_seq_length=128)\n\n    trainer = L.Trainer(\n        accelerator=\"cuda\",\n        max_epochs=1,\n        precision=\"bf16-true\",\n    )\n    trainer.fit(lit_model, data)\n\n    lit_model.llm.model.to(lit_model.llm.preprocessor.device)\n    text = lit_model.llm.generate(\"hello world\")\n    assert isinstance(text, str)\n\n\n@pytest.mark.dependency(depends=[\"test_download_model\", \"test_usecase2_continued_pretraining_from_checkpoint\"])\n@_RunIf(min_cuda_gpus=1)\ndef test_usecase4_manually_save_and_resume(tmp_path):\n    lit_model = LitLLM(checkpoint_dir=\"EleutherAI/pythia-14m\")\n    data = Alpaca2k()\n\n    data.connect(lit_model.llm.tokenizer, batch_size=4, max_seq_length=128)\n\n    trainer = L.Trainer(\n        accelerator=\"cuda\",\n        max_epochs=1,\n        precision=\"bf16-true\",\n    )\n    trainer.fit(lit_model, data)\n\n    lit_model.llm.model.to(lit_model.llm.preprocessor.device)\n    text = lit_model.llm.generate(\"hello world\")\n    assert isinstance(text, str)\n\n    lit_model.llm.save(\"finetuned_checkpoint\")\n\n    del lit_model\n    lit_model = LitLLM(checkpoint_dir=\"finetuned_checkpoint\")\n\n    trainer.fit(lit_model, data)\n\n    lit_model.llm.model.to(lit_model.llm.preprocessor.device)\n    text = lit_model.llm.generate(\"hello world\")\n    assert isinstance(text, str)\n"
  },
  {
    "path": "tests/test_types.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nfrom typing import get_args\n\nfrom litgpt.constants import _SUPPORTED_LOGGERS\nfrom litgpt.types import LoggerChoice\n\n\ndef test_logger_types_match_constants():\n    \"\"\"Ensure LoggerChoice and _SUPPORTED_LOGGERS stay synchronized.\"\"\"\n    logger_choice_args = get_args(LoggerChoice)\n    assert logger_choice_args == _SUPPORTED_LOGGERS, (\n        f\"LoggerChoice type args {logger_choice_args} != \"\n        f\"_SUPPORTED_LOGGERS {_SUPPORTED_LOGGERS}. \"\n        f\"These must stay synchronized. Update both litgpt/types.py and \"\n        f\"litgpt/constants.py when adding new loggers.\"\n    )\n"
  },
  {
    "path": "tests/test_utils.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\nimport os\nfrom contextlib import redirect_stderr\nfrom dataclasses import asdict\nfrom io import StringIO\nfrom pathlib import Path\nfrom tempfile import NamedTemporaryFile, TemporaryDirectory\nfrom unittest import mock\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nimport yaml\nfrom lightning import Fabric\nfrom lightning.fabric.loggers import CSVLogger, TensorBoardLogger\nfrom lightning.fabric.plugins import BitsandbytesPrecision\nfrom lightning.pytorch.loggers import LitLogger, MLFlowLogger, WandbLogger\n\nfrom litgpt import GPT\nfrom litgpt.args import TrainArgs\nfrom litgpt.constants import (\n    _LITLOGGER_AVAILABLE,\n    _MLFLOW_AVAILABLE,\n    _MLFLOW_SKINNY_AVAILABLE,\n    _TENSORBOARD_AVAILABLE,\n    _WANDB_AVAILABLE,\n)\nfrom litgpt.parser_config import save_hyperparameters\nfrom litgpt.utils import (\n    CLI,\n    CycleIterator,\n    _RunIf,\n    capture_hparams,\n    check_file_size_on_cpu_and_warn,\n    check_nvlink_connectivity,\n    check_valid_checkpoint_dir,\n    choose_logger,\n    chunked_cross_entropy,\n    copy_config_files,\n    extend_checkpoint_dir,\n    find_resume_path,\n    fix_and_load_json,\n    incremental_save,\n    init_out_dir,\n    instantiate_bnb_optimizer,\n    instantiate_torch_optimizer,\n    num_parameters,\n    parse_devices,\n    select_sft_generate_example,\n)\n\n\n# match fails on windows. 
why did they have to use backslashes?\n@_RunIf(skip_windows=True)\ndef test_check_valid_checkpoint_dir(tmp_path):\n    os.chdir(tmp_path)\n\n    out = StringIO()\n    with pytest.raises(SystemExit), redirect_stderr(out):\n        check_valid_checkpoint_dir(tmp_path)\n    out = out.getvalue().strip()\n    expected = f\"\"\"\ncheckpoint_dir '{str(tmp_path.absolute())}' is missing the files: ['lit_model.pth', 'model_config.yaml', 'tokenizer.json OR tokenizer.model', 'tokenizer_config.json'].\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n\nSee all download options by running:\n litgpt download\n    \"\"\".strip()\n    assert out == expected\n\n    out = StringIO()\n    checkpoint_dir = tmp_path / \"checkpoints\" / \"stabilityai\" / \"stablelm-base-alpha-3b\"\n    with pytest.raises(SystemExit), redirect_stderr(out):\n        check_valid_checkpoint_dir(checkpoint_dir)\n    out = out.getvalue().strip()\n    expected = f\"\"\"\ncheckpoint_dir '{str(checkpoint_dir.absolute())}' is not a checkpoint directory.\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n\nSee all download options by running:\n litgpt download\n    \"\"\".strip()\n    assert out == expected\n\n    out = StringIO()\n    checkpoint_dir.mkdir(parents=True)\n    foo_checkpoint_dir = tmp_path / \"foo\"\n    with pytest.raises(SystemExit), redirect_stderr(out):\n        check_valid_checkpoint_dir(foo_checkpoint_dir)\n    out = out.getvalue().strip()\n    expected = f\"\"\"\ncheckpoint_dir '{str(foo_checkpoint_dir.absolute())}' is not a checkpoint directory.\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n\nYou have downloaded locally:\n'{str(checkpoint_dir.absolute())}'\n\nSee all download options by running:\n litgpt download\n    \"\"\".strip()\n    assert out == expected\n\n\ndef test_incremental_write(tmp_path):\n    sd = {str(k): torch.randn(5, 10) for k in range(3)}\n    
sd[\"0\"].someattr = 1\n    sd_expected = {k: v.clone() for k, v in sd.items()}\n    fn = str(tmp_path / \"test.pt\")\n    with incremental_save(fn) as f:\n        sd[\"0\"] = f.store_early(sd[\"0\"])\n        sd[\"2\"] = f.store_early(sd[\"2\"])\n        f.save(sd)\n    sd_actual = torch.load(fn)\n    assert sd_actual.keys() == sd_expected.keys()\n    assert sd_actual[\"0\"].someattr == 1  # requires PyTorch 2.0+\n    for k, v_expected in sd_expected.items():\n        v_actual = sd_actual[k]\n        torch.testing.assert_close(v_expected, v_actual)\n    sd_actual = torch.load(fn, weights_only=True)\n    assert sd_actual.keys() == sd_expected.keys()\n    assert sd_actual[\"0\"].someattr == 1  # requires PyTorch 2.0+\n    for k, v_expected in sd_expected.items():\n        v_actual = sd_actual[k]\n        torch.testing.assert_close(v_expected, v_actual)\n\n\n@pytest.mark.parametrize(\"B\", (1, 2))\n@pytest.mark.parametrize(\"ignore_index\", (None, -1, -2, -100))\ndef test_chunked_cross_entropy(ignore_index, B):\n    V = 50\n    T = 25\n    regular_logits = torch.randn(B, T, V)\n    targets = torch.randint(0, V, (B, T))\n\n    if ignore_index is not None:\n        targets[:, [1, 4, 10, 19]] = ignore_index\n\n    baseline_loss = F.cross_entropy(\n        regular_logits.reshape(-1, regular_logits.size(-1)),\n        targets.reshape(-1),\n        ignore_index=(ignore_index if ignore_index is not None else -100),\n    )\n\n    ignore_index = ignore_index if ignore_index is not None else -100\n    regular_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=0, ignore_index=ignore_index)\n    assert torch.equal(baseline_loss, regular_loss)\n    assert regular_loss.numel() == 1\n\n    chunked_loss = chunked_cross_entropy(regular_logits, targets, chunk_size=10, ignore_index=ignore_index)\n    torch.testing.assert_close(chunked_loss, regular_loss)\n    torch.testing.assert_close(chunked_loss, baseline_loss)\n\n    logit_chunk_size = 6\n    assert T % 
logit_chunk_size != 0  # ensure leftover\n    chunked_logits = list(regular_logits.split(logit_chunk_size, dim=1))\n    chunked_loss = chunked_cross_entropy(chunked_logits, targets, chunk_size=0, ignore_index=ignore_index)\n    torch.testing.assert_close(chunked_loss, regular_loss)\n    torch.testing.assert_close(chunked_loss, baseline_loss)\n\n    chunked_loss = chunked_cross_entropy(chunked_logits, targets, chunk_size=10, ignore_index=ignore_index)\n    torch.testing.assert_close(chunked_loss, regular_loss)\n    torch.testing.assert_close(chunked_loss, baseline_loss)\n\n\ndef test_num_parameters():\n    model = torch.nn.Linear(2, 2)\n    assert num_parameters(model) == 6\n    assert num_parameters(model, requires_grad=True) == 6\n    assert num_parameters(model, requires_grad=False) == 0\n\n    model = torch.nn.Linear(2, 2)\n    model.bias.requires_grad = False\n    assert num_parameters(model) == 6\n    assert num_parameters(model, requires_grad=True) == 4\n    assert num_parameters(model, requires_grad=False) == 2\n\n\n@_RunIf(min_cuda_gpus=1)\n@pytest.mark.parametrize(\"mode\", [\"nf4\", \"nf4-dq\", \"fp4\", \"fp4-dq\", \"int8\", \"int8-training\"])\ndef test_num_parameters_bitsandbytes(mode):\n    plugin = BitsandbytesPrecision(mode=mode)\n    fabric = Fabric(plugins=plugin, accelerator=\"cuda\", devices=1)\n\n    model = torch.nn.Linear(10, 10)\n    model = fabric.setup(model)\n    assert num_parameters(model) == 110\n\n    with fabric.init_module(empty_init=True):\n        model = GPT.from_name(\"pythia-14m\")\n    assert num_parameters(model) == 14067712\n\n\ndef test_cycle_iterator():\n    iterator = CycleIterator([])\n    with pytest.raises(StopIteration):\n        next(iterator)\n\n    iterator = CycleIterator(range(3))\n    assert iterator.epoch == 0\n    assert next(iterator) == 0\n    assert iterator.epoch == 0\n    assert next(iterator) == 1\n    assert iterator.epoch == 0\n    assert next(iterator) == 2\n    assert iterator.epoch == 0\n    assert 
next(iterator) == 0\n    assert iterator.epoch == 1\n\n\ndef test_parse_devices():\n    with pytest.raises(ValueError, match=\"must be 'auto' or a positive integer\"):\n        assert parse_devices(0)\n    with pytest.raises(ValueError, match=\"must be 'auto' or a positive integer\"):\n        assert parse_devices(-2)\n\n    with mock.patch(\"litgpt.utils.torch.cuda.device_count\", return_value=0):\n        assert parse_devices(\"auto\") == 1  # CPU\n        assert parse_devices(10) == 10  # leave validation up to Fabric later on\n    with mock.patch(\"litgpt.utils.torch.cuda.device_count\", return_value=1):\n        assert parse_devices(\"auto\") == 1  # CUDA\n    with mock.patch(\"litgpt.utils.torch.cuda.device_count\", return_value=3):\n        assert parse_devices(\"auto\") == 3\n        assert parse_devices(-1) == 3\n\n    assert parse_devices(5) == 5\n\n\ndef test_copy_config_files(fake_checkpoint_dir, tmp_path):\n    copy_config_files(fake_checkpoint_dir, tmp_path)\n    expected = {\"model_config.yaml\", \"tokenizer_config.json\", \"tokenizer.json\"}\n    contents = set(os.listdir(tmp_path))\n    assert expected.issubset(contents)\n\n\ndef test_capture_hparams():\n    integer = 1\n    string = \"string\"\n    boolean = True\n    none = None\n    path = Path(\"/path\")\n    dataclass = TrainArgs()\n    other = torch.nn.Linear(1, 1)\n    hparams = capture_hparams()\n    assert hparams == {\n        \"integer\": integer,\n        \"string\": string,\n        \"boolean\": boolean,\n        \"none\": none,\n        \"path\": path,\n        \"dataclass\": asdict(dataclass),\n        \"other\": str(other),\n    }\n\n\ndef _test_function(out_dir: Path, foo: bool = False, bar: int = 1):\n    save_hyperparameters(_test_function, out_dir)\n\n\ndef test_save_hyperparameters(tmp_path):\n    with mock.patch(\"sys.argv\", [\"any.py\", str(tmp_path), \"--foo\", \"True\"]):\n        CLI(_test_function)\n\n    with open(tmp_path / \"hyperparameters.yaml\", encoding=\"utf-8\") 
as file:\n        hparams = yaml.full_load(file)\n\n    assert hparams[\"out_dir\"] == str(tmp_path)\n    assert hparams[\"foo\"] is True\n    assert hparams[\"bar\"] == 1\n\n\ndef _test_function2(out_dir: Path, foo: bool = False, bar: int = 1):\n    assert False, \"I only exist as a signature, but I should not run.\"\n\n\n@pytest.mark.parametrize(\n    \"command\",\n    [\n        \"any.py\",\n        \"litgpt finetune\",\n        \"litgpt finetune_full\",\n        \"litgpt finetune_lora\",\n        \"litgpt finetune_adapter\",\n        \"litgpt finetune_adapter_v2\",\n        \"litgpt pretrain\",\n    ],\n)\ndef test_save_hyperparameters_known_commands(command, tmp_path):\n    with mock.patch(\"sys.argv\", [*command.split(\" \"), str(tmp_path), \"--foo\", \"True\"]):\n        save_hyperparameters(_test_function2, tmp_path)\n\n    with open(tmp_path / \"hyperparameters.yaml\", encoding=\"utf-8\") as file:\n        hparams = yaml.full_load(file)\n\n    assert hparams[\"out_dir\"] == str(tmp_path)\n    assert hparams[\"foo\"] is True\n    assert hparams[\"bar\"] == 1\n\n\ndef test_choose_logger(tmp_path):\n    assert isinstance(choose_logger(\"csv\", out_dir=tmp_path, name=\"csv\"), CSVLogger)\n    if _TENSORBOARD_AVAILABLE:\n        assert isinstance(choose_logger(\"tensorboard\", out_dir=tmp_path, name=\"tb\"), TensorBoardLogger)\n    if _WANDB_AVAILABLE:\n        assert isinstance(choose_logger(\"wandb\", out_dir=tmp_path, name=\"wandb\"), WandbLogger)\n    if _MLFLOW_AVAILABLE or _MLFLOW_SKINNY_AVAILABLE:\n        assert isinstance(choose_logger(\"mlflow\", out_dir=tmp_path, name=\"wandb\"), MLFlowLogger)\n    if _LITLOGGER_AVAILABLE:\n        assert isinstance(choose_logger(\"litlogger\", out_dir=tmp_path, name=\"litlogger\"), LitLogger)\n    with pytest.raises(ValueError, match=\"`--logger_name=foo` is not a valid option.\"):\n        choose_logger(\"foo\", out_dir=tmp_path, name=\"foo\")\n\n\n@pytest.mark.parametrize(\n    \"path_type, input_path, 
expected\",\n    [\n        (\"relative\", \"some/relative/path\", \"some/relative/path\"),\n        (\"absolute\", \"/usr/absolute/path\", \"/usr/absolute/path\"),\n        (\"env_relative\", \"some/relative/path\", \"prefix/some/relative/path\"),\n        (\"env_absolute\", \"/usr/absolute/path\", \"/usr/absolute/path\"),\n    ],\n)\ndef test_init_out_dir(path_type, input_path, expected):\n    if path_type.startswith(\"env_\"):\n        with mock.patch.dict(os.environ, {\"LIGHTNING_ARTIFACTS_DIR\": \"prefix\"}):\n            result = init_out_dir(input_path)\n            assert result == Path(expected), f\"Failed for {path_type} with input {input_path} (result {result})\"\n    else:\n        result = init_out_dir(input_path)\n        if \"LIGHTNING_ARTIFACTS_DIR\" not in os.environ:\n            assert result == Path(expected), f\"Failed for {path_type} with input {input_path} (result {result})\"\n        else:\n            assert result == Path(os.getenv(\"LIGHTNING_ARTIFACTS_DIR\")) / expected, (\n                f\"Failed for {path_type} with input {input_path} (result {result})\"\n            )\n\n\ndef test_find_resume_path(tmp_path):\n    assert find_resume_path(resume=None, out_dir=Path(\"does/not/exist\")) is None\n    assert find_resume_path(resume=Path(\"does/not/exist\"), out_dir=Path(\"does/not/matter\")) == Path(\"does/not/exist\")\n    assert find_resume_path(resume=(tmp_path / \"checkpoint.pt\"), out_dir=Path(\"does/not/matter\")) == (\n        tmp_path / \"checkpoint.pt\"\n    )\n\n    # `resume='auto'` does not enforce the checkpoint to exist\n    assert find_resume_path(resume=\"auto\", out_dir=Path(\"does/not/exist\")) is None\n\n    # `resume=True` requires a checkpoint to exist\n    with pytest.raises(FileNotFoundError, match=\"You passed `--resume=True`, but no checkpoint file was found\"):\n        find_resume_path(resume=True, out_dir=Path(\"does/not/exist\"))\n    with pytest.raises(FileNotFoundError, match=\"You passed `--resume=True`, 
but no checkpoint file was found\"):\n        find_resume_path(resume=True, out_dir=tmp_path)\n\n    (tmp_path / \"step-001\").mkdir()\n    (tmp_path / \"step-001\" / \"lit_model.pth\").touch()\n    (tmp_path / \"step-002\").mkdir()\n    (tmp_path / \"step-002\" / \"lit_model.pth\").touch()\n    (tmp_path / \"step-003\").mkdir()\n    (tmp_path / \"step-003\" / \"lit_model.pth\").touch()\n\n    assert find_resume_path(resume=True, out_dir=tmp_path) == (tmp_path / \"step-003\" / \"lit_model.pth\")\n    assert find_resume_path(resume=\"auto\", out_dir=tmp_path) == (tmp_path / \"step-003\" / \"lit_model.pth\")\n\n\n@pytest.fixture\ndef model_parameters():\n    return [torch.nn.Parameter(torch.randn(2, 2))]\n\n\ndef test_instantiate_bnb_optimizer_with_str(model_parameters):\n    import bitsandbytes as bnb\n\n    with mock.patch(\"litgpt.utils.get_argument_names\", return_value={\"lr\", \"eps\", \"weight_decay\"}):\n        optimizer = instantiate_bnb_optimizer(\"AdamW\", model_parameters)\n        assert isinstance(optimizer, bnb.optim.adamw.PagedAdamW)\n\n\ndef test_instantiate_bnb_optimizer_with_dict(model_parameters):\n    import bitsandbytes as bnb\n\n    optimizer_dict = {\"class_path\": \"AdamW\", \"init_args\": {\"lr\": 0.01}}\n    with mock.patch(\"litgpt.utils.get_argument_names\", return_value={\"lr\", \"eps\", \"weight_decay\"}):\n        optimizer = instantiate_bnb_optimizer(optimizer_dict, model_parameters)\n        assert isinstance(optimizer, bnb.optim.adamw.PagedAdamW)\n        assert optimizer.param_groups[0][\"lr\"] == 0.01\n\n\ndef test_instantiate_bnb_optimizer_with_invalid_str(model_parameters):\n    with pytest.raises(ValueError, match=\"only supports the AdamW\"):\n        instantiate_bnb_optimizer(\"SGD\", model_parameters)\n\n\ndef test_instantiate_torch_optimizer_with_str(model_parameters):\n    optimizer = instantiate_torch_optimizer(\"Adam\", model_parameters, lr=0.01)\n    assert isinstance(optimizer, torch.optim.Adam)\n    assert 
optimizer.param_groups[0][\"lr\"] == 0.01\n\n\ndef test_instantiate_torch_optimizer_with_class(model_parameters):\n    optimizer = instantiate_torch_optimizer(\n        {\"class_path\": \"torch.optim.Adam\", \"init_args\": {\"lr\": 123}}, model_parameters, lr=0.02\n    )\n    assert isinstance(optimizer, torch.optim.Adam)\n    # init args gets overridden\n    assert optimizer.param_groups[0][\"lr\"] == 0.02\n\n\n@pytest.mark.parametrize(\n    \"input_path, expected\",\n    [\n        (Path(\"checkpoints/my_model\"), Path(\"checkpoints/my_model\")),\n        (Path(\"checkpoints/my_model\"), Path(\"./checkpoints/my_model\")),\n    ],\n)\ndef test_extend_checkpoint_dir_is_prefixed(input_path, expected):\n    original_dir = Path.cwd()  # Save the current directory\n    with TemporaryDirectory() as tmp_dir:\n        os.chdir(tmp_dir)\n\n        try:\n            if not input_path.is_absolute():\n                input_path = Path(tmp_dir) / input_path\n            if not expected.is_absolute():\n                expected = Path(tmp_dir) / expected\n            input_path.parent.mkdir(parents=True, exist_ok=True)\n            input_path.touch(exist_ok=True)\n            assert extend_checkpoint_dir(input_path) == expected\n        finally:\n            os.chdir(original_dir)  # Reset the current directory\n\n\n@pytest.mark.parametrize(\n    \"input_path, expected\",\n    [\n        (Path(\"my_model\"), Path(\"checkpoints/my_model\")),\n        (Path(\"my_model\"), Path(\"./checkpoints/my_model\")),\n    ],\n)\ndef test_extend_checkpoint_dir(input_path, expected):\n    original_dir = Path.cwd()  # Save the current directory\n    with TemporaryDirectory() as tmp_dir:\n        os.chdir(tmp_dir)\n\n        try:\n            if not input_path.is_absolute():\n                input_path = Path(tmp_dir) / \"checkpoints\" / input_path\n            if not expected.is_absolute():\n                expected = Path(tmp_dir) / expected\n            input_path.parent.mkdir(parents=True, 
exist_ok=True)\n            input_path.touch(exist_ok=True)\n            assert extend_checkpoint_dir(input_path) == expected\n        finally:\n            os.chdir(original_dir)  # Reset the current directory\n\n\n@pytest.mark.parametrize(\n    \"input_path, expected\",\n    [\n        (Path(\"my_model\"), Path(\"my_model\")),\n        (Path(\"/my_model\"), Path(\"/my_model\")),\n    ],\n)\ndef test_extend_checkpoint_dir_dont_exist(input_path, expected):\n    assert extend_checkpoint_dir(input_path) == expected\n\n\ndef test_file_size_below_limit_on_cpu():\n    # Test file size below limit on CPU\n    with NamedTemporaryFile() as temp_file:\n        with mock.patch(\"os.path.getsize\", return_value=4_000_000_000):\n            size = check_file_size_on_cpu_and_warn(temp_file.name, \"cpu\")\n            assert size == 4_000_000_000\n\n\ndef test_file_size_above_limit_on_cpu():\n    # Test file size above limit on CPU\n    with NamedTemporaryFile() as temp_file:\n        with mock.patch(\"os.path.getsize\", return_value=4_600_000_000):\n            with pytest.warns(UserWarning) as record:\n                size = check_file_size_on_cpu_and_warn(temp_file.name, \"cpu\")\n            assert size == 4_600_000_000\n            assert \"over 4.2 GB\" in str(record[0].message)\n\n\ndef test_file_size_above_limit_on_gpu():\n    # Test file size above limit on GPU should not warn\n    with NamedTemporaryFile() as temp_file:\n        with mock.patch(\"os.path.getsize\", return_value=4_600_000_000):\n            size = check_file_size_on_cpu_and_warn(temp_file.name, \"gpu\")\n            assert size == 4_600_000_000\n\n\n@pytest.fixture\ndef mock_cuda_is_available_true(monkeypatch):\n    \"\"\"Fixture to mock torch.cuda.is_available() to return True.\"\"\"\n    monkeypatch.setattr(torch.cuda, \"is_available\", lambda: True)\n\n\n@pytest.fixture\ndef mock_nvidia_device_properties(monkeypatch):\n    \"\"\"Fixture to mock torch.cuda.get_device_properties() for NVIDIA 
GPUs.\"\"\"\n    mock_device_properties = mock.MagicMock(name=\"GPU Device\", spec=[\"name\"])\n    mock_device_properties.name = \"NVIDIA RTX A6000\"\n    monkeypatch.setattr(torch.cuda, \"get_device_properties\", lambda idx: mock_device_properties)\n\n\n@pytest.fixture\ndef mock_amd_device_properties(monkeypatch):\n    \"\"\"Fixture to mock torch.cuda.get_device_properties() for AMD GPUs.\"\"\"\n    mock_device_properties = mock.MagicMock(name=\"GPU Device\", spec=[\"name\"])\n    mock_device_properties.name = \"AMD Instinct MI250X\"\n    monkeypatch.setattr(torch.cuda, \"get_device_properties\", lambda idx: mock_device_properties)\n\n\n@pytest.fixture\ndef all_nvlink_connected_output():\n    return mock.MagicMock(\n        stdout=\"\"\"        GPU0\tGPU1\tGPU2\tGPU3\nGPU0\tX\tNV12\tNV12\tNV12\nGPU1\tNV12\tX\tNV12\tNV12\nGPU2\tNV12\tNV12\tX\tNV12\nGPU3\tNV12\tNV12\tNV12\tX\"\"\",\n        returncode=0,\n    )\n\n\n@mock.patch(\"subprocess.run\")\ndef test_all_nvlink_connected(\n    mock_run, all_nvlink_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties\n):\n    mock_run.return_value = all_nvlink_connected_output\n    with mock.patch(\"builtins.print\") as mock_print:\n        check_nvlink_connectivity()\n        mock_print.assert_any_call(\"All GPUs are fully connected via NVLink.\")\n\n\n@pytest.fixture\ndef nvlink_partially_connected_output():\n    return mock.MagicMock(\n        stdout=\"\"\"        GPU0    GPU1    GPU2    GPU3    CPU Affinity\nGPU0     X      NV1     SYS     SYS     0-7\nGPU1    NV1      X      SYS     SYS     0-7\nGPU2    SYS     SYS      X      NV1     8-15\nGPU3    SYS     SYS     NV1      X      8-15\n\nLegend:\n  X   = Self\n  NV1 = Connected via NVLink with 1 hop\n  SYS = Connected via the PCIe or CPU subsystem\"\"\",\n        returncode=0,\n    )\n\n\n@mock.patch(\"subprocess.run\")\ndef test_nvlink_partially_connected_output(\n    mock_run, nvlink_partially_connected_output, mock_cuda_is_available_true, 
mock_nvidia_device_properties\n):\n    mock_run.return_value = nvlink_partially_connected_output\n    with mock.patch(\"builtins.print\") as mock_print:\n        check_nvlink_connectivity()\n        mock_print.assert_any_call(\n            \"Warning: Not all GPUs are fully connected via NVLink. Some GPUs are connected via slower interfaces. \"\n            \"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance.\"\n        )\n\n\n@pytest.fixture\ndef nvlink_not_connected_output():\n    return mock.MagicMock(\n        stdout=\"\"\"        GPU0    GPU1    GPU2    GPU3    CPU Affinity    NUMA Affinity   GPU NUMA ID\nGPU0     X      PHB     PHB     PHB     0-47    0               N/A\nGPU1    PHB      X      PHB     PHB     0-47    0               N/A\nGPU2    PHB     PHB      X      PHB     0-47    0               N/A\nGPU3    PHB     PHB     PHB      X      0-47    0               N/A\n\nLegend:\n\n  X    = Self\n  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)\n  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node\n  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)\n  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)\n  PIX  = Connection traversing at most a single PCIe bridge\n  NV#  = Connection traversing a bonded set of # NVLinks\"\"\",\n        returncode=0,\n    )\n\n\n@mock.patch(\"subprocess.run\")\ndef test_nvlink_not_connected_output(\n    mock_run, nvlink_not_connected_output, mock_cuda_is_available_true, mock_nvidia_device_properties\n):\n    mock_run.return_value = nvlink_not_connected_output\n    with mock.patch(\"builtins.print\") as mock_print:\n        check_nvlink_connectivity()\n        mock_print.assert_any_call(\n            \"Warning: Not all GPUs are fully connected via NVLink. 
Some GPUs are connected via slower interfaces. \"\n            \"It is recommended to switch to a different machine with faster GPU connections for optimal multi-GPU training performance.\"\n        )\n\n\n@pytest.fixture\ndef nvlink_all_gpu_connected_but_other_connected_output():\n    return mock.MagicMock(\n        stdout=\"\"\"\tGPU0\tGPU1\tGPU2\tGPU3\tGPU4\tGPU5\tGPU6\tGPU7\tNIC0\tNIC1\tNIC2\tNIC3\tNIC4\tNIC5\tNIC6\tNIC7\tNIC8\tNIC9\tCPU Affinity\tNUMA Affinity\tGPU NUMA ID\nGPU0\tX \tNV12\tNV12\tNV12\tNV12\tNV12\tNV12\tNV12\tSYS\tSYS\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\t0-63,128-191\t0\t\tN/A\nGPU1\tNV12\tX \tNV12\tNV12\tNV12\tNV12\tNV12\tNV12\tSYS\tSYS\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\t0-63,128-191\t0\t\tN/A\nGPU2\tNV12\tNV12\tX \tNV12\tNV12\tNV12\tNV12\tNV12\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\t0-63,128-191\t0\t\tN/A\nGPU3\tNV12\tNV12\tNV12\tX \tNV12\tNV12\tNV12\tNV12\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\t0-63,128-191\t0\t\tN/A\nGPU4\tNV12\tNV12\tNV12\tNV12\tX \tNV12\tNV12\tNV12\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tPXB\tPXB\t64-127,192-254\t1\t\tN/A\nGPU5\tNV12\tNV12\tNV12\tNV12\tNV12\tX \tNV12\tNV12\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tPXB\tPXB\t64-127,192-254\t1\t\tN/A\nGPU6\tNV12\tNV12\tNV12\tNV12\tNV12\tNV12\tX \tNV12\tSYS\tSYS\tSYS\tSYS\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\t64-127,192-254\t1\t\tN/A\nGPU7\tNV12\tNV12\tNV12\tNV12\tNV12\tNV12\tNV12\tX \tSYS\tSYS\tSYS\tSYS\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\t64-127,192-254\t1\t\tN/A\nNIC0\tSYS\tSYS\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\tX \tPIX\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\nNIC1\tSYS\tSYS\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\tPIX\tX \tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\nNIC2\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tX \tPXB\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\nNIC3\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tPXB\tX \tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\nNIC4\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\tX 
\tPXB\tSYS\tSYS\tSYS\tSYS\nNIC5\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\tPXB\tX \tSYS\tSYS\tSYS\tSYS\nNIC6\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tX \tPIX\tSYS\tSYS\nNIC7\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tPIX\tX \tSYS\tSYS\nNIC8\tSYS\tSYS\tSYS\tSYS\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tX \tPXB\nNIC9\tSYS\tSYS\tSYS\tSYS\tPXB\tPXB\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tSYS\tPXB\tX\n\nLegend:\n\n  X    = Self\n  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)\n  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node\n  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)\n  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)\n  PIX  = Connection traversing at most a single PCIe bridge\n  NV#  = Connection traversing a bonded set of # NVLinks\n\nNIC Legend:\n\n  NIC0: mlx5_0\n  NIC1: mlx5_1\n  NIC2: mlx5_2\n  NIC3: mlx5_3\n  NIC4: mlx5_4\n  NIC5: mlx5_5\n  NIC6: mlx5_6\n  NIC7: mlx5_7\n  NIC8: mlx5_8\n  NIC9: mlx5_9\n\n\"\"\",\n        returncode=0,\n    )\n\n\n@mock.patch(\"subprocess.run\")\ndef test_nvlink_all_gpu_connected_but_other_connected_output(\n    mock_run,\n    nvlink_all_gpu_connected_but_other_connected_output,\n    mock_cuda_is_available_true,\n    mock_nvidia_device_properties,\n):\n    mock_run.return_value = nvlink_all_gpu_connected_but_other_connected_output\n    with mock.patch(\"builtins.print\") as mock_print:\n        check_nvlink_connectivity()\n    mock_print.assert_any_call(\"All GPUs are fully connected via NVLink.\")\n\n\n@pytest.fixture\ndef nvidia_smi_nvlink_output_dual_gpu_no_numa():\n    return mock.MagicMock(\n        stdout=\"\"\"\n        GPU0    GPU1    CPU Affinity    NUMA Affinity   GPU NUMA ID\nGPU0     X      NV1     0-15    0              
 N/A\nGPU1    NV1      X      0-15    0               N/A\n\nLegend:\n\n  X    = Self\n  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)\n  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node\n  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)\n  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)\n  PIX  = Connection traversing at most a single PCIe bridge\n  NV#  = Connection traversing a bonded set of # NVLinks\n    \"\"\",\n        returncode=0,\n    )\n\n\n@mock.patch(\"subprocess.run\")\ndef test_check_nvlink_connectivity__returns_fully_connected_when_nvidia_all_nvlink_two_gpus(\n    mock_run, nvidia_smi_nvlink_output_dual_gpu_no_numa, mock_cuda_is_available_true, mock_nvidia_device_properties\n):\n    mock_run.return_value = nvidia_smi_nvlink_output_dual_gpu_no_numa\n    with mock.patch(\"builtins.print\") as mock_print:\n        check_nvlink_connectivity()\n        mock_print.assert_any_call(\"All GPUs are fully connected via NVLink.\")\n\n\n@pytest.fixture\ndef rocm_smi_xgmi_output_multi_gpu():\n    \"\"\"\n    rocm-smi --showtopotype on ROCm 6.0.3+\n    \"\"\"\n    return mock.MagicMock(\n        stdout=\"\"\"\n=============================== ROCm System Management Interface ============================\n=============================== Link Type between two GPUs ===============================\n       GPU0         GPU1         GPU2         GPU3         GPU4         GPU5         GPU6         GPU7\nGPU0   0            XGMI         XGMI         XGMI         XGMI         XGMI         XGMI         XGMI\nGPU1   XGMI         0            XGMI         XGMI         XGMI         XGMI         XGMI         XGMI\nGPU2   XGMI         XGMI         0            XGMI         XGMI         XGMI         XGMI         XGMI\nGPU3   XGMI         XGMI         XGMI         0            XGMI         
XGMI         XGMI         XGMI\nGPU4   XGMI         XGMI         XGMI         XGMI         0            XGMI         XGMI         XGMI\nGPU5   XGMI         XGMI         XGMI         XGMI         XGMI         0            XGMI         XGMI\nGPU6   XGMI         XGMI         XGMI         XGMI         XGMI         XGMI         0            XGMI\nGPU7   XGMI         XGMI         XGMI         XGMI         XGMI         XGMI         XGMI         0\n================================== End of ROCm SMI Log ===================================\n    \"\"\",\n        returncode=0,\n    )\n\n\n@mock.patch(\"subprocess.run\")\ndef test_check_nvlink_connectivity__returns_fully_connected_when_amd_all_xgmi_8_gpus(\n    mock_run, rocm_smi_xgmi_output_multi_gpu, mock_cuda_is_available_true, mock_amd_device_properties\n):\n    mock_run.return_value = rocm_smi_xgmi_output_multi_gpu\n    with mock.patch(\"builtins.print\") as mock_print:\n        check_nvlink_connectivity()\n        mock_print.assert_any_call(\"All GPUs are fully connected via XGMI.\")\n\n\n@mock.patch(\"subprocess.run\")\ndef test_check_nvlink_connectivity__returns_no_gpus_when_no_gpus(mock_run, monkeypatch):\n    monkeypatch.setattr(torch.cuda, \"is_available\", lambda: False)\n    with mock.patch(\"builtins.print\") as mock_print:\n        check_nvlink_connectivity()\n        mock_print.assert_any_call(\"No GPUs available\")\n\n\n@mock.patch(\"subprocess.run\")\ndef test_check_nvlink_connectivity__returns_unrecognized_vendor_when_unrecognized_vendor(\n    mock_run, monkeypatch, mock_cuda_is_available_true\n):\n    mock_device_properties = mock.MagicMock(name=\"GPU Device\", spec=[\"name\"])\n    mock_device_properties.name = \"GARAGE DIY HYPERSCALER GPU\"\n    monkeypatch.setattr(torch.cuda, \"get_device_properties\", lambda idx: mock_device_properties)\n    with mock.patch(\"builtins.print\") as mock_print:\n        check_nvlink_connectivity()\n        mock_print.assert_any_call(\"Unrecognized GPU vendor: GARAGE DIY 
HYPERSCALER GPU\")\n\n\ndef test_fix_and_load_json():\n    # Test 1: Invalid JSON string with a trailing comma\n    invalid_json_trailing_comma = \"\"\"\n    {\n      \"_from_model_config\": true,\n      \"bos_token_id\": 128000,\n      \"eos_token_id\": 128001,\n      \"transformers_version\": \"4.45.0.dev0\",\n      \"do_sample\": true,\n      \"temperature\": 0.6,\n      \"top_p\": 0.9,\n    }\n    \"\"\"\n\n    expected_output_trailing_comma = {\n        \"_from_model_config\": True,\n        \"bos_token_id\": 128000,\n        \"eos_token_id\": 128001,\n        \"transformers_version\": \"4.45.0.dev0\",\n        \"do_sample\": True,\n        \"temperature\": 0.6,\n        \"top_p\": 0.9,\n    }\n\n    result_trailing_comma = fix_and_load_json(invalid_json_trailing_comma)\n    assert result_trailing_comma == expected_output_trailing_comma\n\n    # Test 2: Invalid JSON string with missing commas between properties\n    invalid_json_missing_commas = \"\"\"\n    {\n      \"_from_model_config\": true,\n      \"bos_token_id\": 128000,\n      \"eos_token_id\": 128001,\n      \"transformers_version\": \"4.45.0.dev0\"\n      \"do_sample\": true,\n      \"temperature\": 0.6,\n      \"top_p\": 0.9,\n    }\n    \"\"\"\n\n    expected_output_missing_commas = {\n        \"_from_model_config\": True,\n        \"bos_token_id\": 128000,\n        \"eos_token_id\": 128001,\n        \"transformers_version\": \"4.45.0.dev0\",\n        \"do_sample\": True,\n        \"temperature\": 0.6,\n        \"top_p\": 0.9,\n    }\n\n    result_missing_commas = fix_and_load_json(invalid_json_missing_commas)\n    assert result_missing_commas == expected_output_missing_commas\n\n\ndef test_select_sft_generate_example():\n    eval_mock = mock.MagicMock()\n    data_mock = mock.MagicMock()\n\n    test_dataset = {\"data\": [{\"instruction\": \"Test instruction 1\"}, {\"instruction\": \"Test instruction 2\"}]}\n    train_dataset = {\"data\": [{\"instruction\": \"Train instruction 1\"}, 
{\"instruction\": \"Train instruction 2\"}]}\n\n    data_mock.test_dataset.data = test_dataset[\"data\"]\n    data_mock.train_dataset.data = train_dataset[\"data\"]\n\n    # Test \"first\" instruction from test dataset\n    eval_mock.evaluate_example = \"first\"\n    instruction = select_sft_generate_example(eval_mock, data_mock)\n    assert instruction == \"Test instruction 1\"\n\n    # Test \"first\" instruction from train dataset when test dataset is empty\n    data_mock.test_dataset.data = []\n    instruction = select_sft_generate_example(eval_mock, data_mock)\n    assert instruction == \"Train instruction 1\"\n\n    # Test random selection from test dataset\n    eval_mock.evaluate_example = \"random\"\n    data_mock.test_dataset.data = [{\"instruction\": \"Test instruction 1\"}, {\"instruction\": \"Test instruction 2\"}]\n    with mock.patch(\"random.randint\", return_value=1):\n        instruction = select_sft_generate_example(eval_mock, data_mock)\n        assert instruction == \"Test instruction 2\"\n\n    # Test random selection from train dataset when test dataset is empty\n    data_mock.test_dataset.data = []\n    with mock.patch(\"random.randint\", return_value=1):\n        instruction = select_sft_generate_example(eval_mock, data_mock)\n        assert instruction == \"Train instruction 2\"\n\n    # Test specific index from test dataset\n    eval_mock.evaluate_example = 1\n    data_mock.test_dataset.data = [{\"instruction\": \"Test instruction 1\"}, {\"instruction\": \"Test instruction 2\"}]\n    instruction = select_sft_generate_example(eval_mock, data_mock)\n    assert instruction == \"Test instruction 2\"\n\n    # Test specific index from train dataset when test dataset has fewer elements\n    data_mock.test_dataset.data = [{\"instruction\": \"Test instruction 1\"}]\n    instruction = select_sft_generate_example(eval_mock, data_mock)\n    assert instruction == \"Train instruction 2\"\n\n    # Test out-of-range index\n    eval_mock.evaluate_example = 
2\n    data_mock.test_dataset.data = [{\"instruction\": \"Test instruction 1\"}]\n    data_mock.train_dataset.data = [{\"instruction\": \"Train instruction 1\"}]\n    with pytest.raises(IndexError):\n        select_sft_generate_example(eval_mock, data_mock)\n\n    # Test unknown evaluation type\n    eval_mock.evaluate_example = \"unknown\"\n    with pytest.raises(ValueError):\n        select_sft_generate_example(eval_mock, data_mock)\n"
  },
  {
    "path": "tests/test_yarn.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport pytest\nimport torch\nfrom transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM\n\nfrom litgpt import Config\nfrom litgpt.model import Block\n\n\n@torch.inference_mode()\n@pytest.mark.parametrize(\"batch_size\", (1, 2))\n@pytest.mark.parametrize(\"seq_len\", (8, 16))\n@pytest.mark.parametrize(\"device\", [torch.device(\"cpu\")])\ndef test_deepseek_v3_block_with_yarn(batch_size, seq_len, device):\n    \"\"\"Test DeepSeek V3 block (attention + MLP + norms) with YaRN RoPE scaling - litgpt vs hf\"\"\"\n    # Use layer_idx=0 to test dense MLP instead of MoE\n    layer_idx = 0\n\n    # YaRN configuration\n    yarn_config = dict(\n        factor=8.0,\n        beta_fast=32.0,\n        beta_slow=1.0,\n        original_max_seq_len=4096,\n        mscale=1.0,\n        mscale_all_dim=0.8,\n    )\n\n    config_litgpt = Config(\n        n_embd=64,\n        n_head=4,\n        n_query_groups=4,\n        head_size=16,\n        norm_eps=1e-6,\n        norm_class_name=\"RMSNorm\",\n        bias=False,\n        parallel_residual=False,\n        mlp_class_name=\"LLaMAMoE\",\n        intermediate_size=128,\n        rope_interleave=True,\n        rope_adjustments=yarn_config,  # YaRN config\n        latent_attention={\n            \"q_lora_rank\": 32,\n            \"kv_lora_rank\": 16,\n            \"qk_rope_head_dim\": 8,\n            \"qk_nope_head_dim\": 8,\n            \"v_head_dim\": 16,\n        },\n        first_k_dense_replace=3,  # Use dense MLP for first 3 layers\n    )\n\n    # HF config with YaRN\n    rope_parameters = {\n        \"type\": \"yarn\",\n        \"rope_theta\": 10000.0,\n        \"factor\": yarn_config[\"factor\"],\n        \"beta_fast\": yarn_config[\"beta_fast\"],\n        \"beta_slow\": yarn_config[\"beta_slow\"],\n        \"original_max_position_embeddings\": yarn_config[\"original_max_seq_len\"],\n        \"mscale\": 
yarn_config[\"mscale\"],\n        \"mscale_all_dim\": yarn_config[\"mscale_all_dim\"],\n    }\n\n    config_hf = DeepseekV3Config(\n        padded_vocab_size=10000,\n        num_hidden_layers=1,\n        vocab_size=10000,\n        hidden_size=64,\n        intermediate_size=128,\n        num_attention_heads=4,\n        num_key_value_heads=4,\n        q_lora_rank=32,\n        kv_lora_rank=16,\n        qk_rope_head_dim=8,\n        qk_nope_head_dim=8,\n        v_head_dim=16,\n        rope_interleave=True,\n        first_k_dense_replace=3,\n        rms_norm_eps=1e-6,\n        rope_scaling=rope_parameters,  # YaRN config\n    )\n\n    # Debug: Check if HF config has rope_parameters\n    print(\"\\n=== HF Config Debug ===\")\n    print(f\"config_hf.rope_parameters: {config_hf.rope_scaling}\")\n\n    block_litgpt = Block(config_litgpt, block_idx=layer_idx).to(device)\n    model_hf = DeepseekV3ForCausalLM(config_hf).to(device)\n    block_hf = model_hf.model.layers[layer_idx]\n\n    block_litgpt.eval()\n    block_hf.eval()\n\n    sync_block_weights(block_litgpt, block_hf)\n\n    hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device)\n\n    # Prepare RoPE sin/cos tables using YaRN computation\n    from litgpt.model import build_rope_cache\n\n    rope_head_dim = config_litgpt.latent_attention[\"qk_rope_head_dim\"]\n\n    # Build YaRN RoPE cache for LitGPT\n    cos_litgpt, sin_litgpt = build_rope_cache(\n        seq_len=seq_len,\n        n_elem=rope_head_dim,\n        device=device,\n        base=config_litgpt.rope_base,\n        extra_config={\n            \"factor\": yarn_config[\"factor\"],\n            \"beta_fast\": yarn_config[\"beta_fast\"],\n            \"beta_slow\": yarn_config[\"beta_slow\"],\n            \"original_max_seq_len\": yarn_config[\"original_max_seq_len\"],\n            \"mscale\": yarn_config[\"mscale\"],\n            \"mscale_all_dim\": yarn_config[\"mscale_all_dim\"],\n        },\n    )\n\n    # Get YaRN RoPE embeddings 
from HF (rotary_emb is on model level, not layer level)\n    rotary_emb = model_hf.model.rotary_emb\n    position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)\n    cos_hf, sin_hf = rotary_emb(hidden_states, position_ids)\n\n    # Expand dimensions for batch and broadcast\n    cos_litgpt = cos_litgpt.unsqueeze(0).expand(batch_size, -1, -1)\n    sin_litgpt = sin_litgpt.unsqueeze(0).expand(batch_size, -1, -1)\n\n    # Compare RoPE embeddings first\n    print(\"\\n=== RoPE Embeddings Comparison ===\")\n    print(f\"LitGPT cos/sin shape: {cos_litgpt.shape}, {sin_litgpt.shape}\")\n    print(f\"HF cos/sin shape: {cos_hf.shape}, {sin_hf.shape}\")\n    print(f\"Cos max diff: {(cos_litgpt - cos_hf).abs().max()}\")\n    print(f\"Sin max diff: {(sin_litgpt - sin_hf).abs().max()}\")\n    print(f\"\\nLitGPT cos sample [0,0,:4]: {cos_litgpt[0, 0, :4]}\")\n    print(f\"HF cos sample [0,0,:4]: {cos_hf[0, 0, :4]}\")\n    print(f\"LitGPT cos min/max: {cos_litgpt.min():.4f} / {cos_litgpt.max():.4f}\")\n    print(f\"HF cos min/max: {cos_hf.min():.4f} / {cos_hf.max():.4f}\")\n\n    # Check inv_freq from both\n    print(\"\\n=== Checking inv_freq ===\")\n    print(f\"HF rotary_emb.inv_freq shape: {rotary_emb.inv_freq.shape}\")\n    print(f\"HF inv_freq: {rotary_emb.inv_freq}\")\n    print(f\"HF attention_scaling: {rotary_emb.attention_scaling}\")\n\n    # Use the same embeddings for both (LitGPT's)\n    cos = cos_litgpt\n    sin = sin_litgpt\n\n    causal_mask = torch.triu(\n        torch.full((seq_len, seq_len), float(\"-inf\"), device=device, dtype=hidden_states.dtype), diagonal=1\n    )\n    attention_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)\n\n    # Run forward passes\n    output_litgpt = block_litgpt(hidden_states, cos, sin)\n    output_hf = block_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask)\n    if isinstance(output_hf, tuple):\n        output_hf = output_hf[0]\n\n    max_diff = 
(output_litgpt - output_hf).abs().max()\n    print(\"\\n=== DEBUG INFO ===\")\n    print(f\"Max diff: {max_diff}\")\n    print(f\"Output litgpt mean: {output_litgpt.mean()}, std: {output_litgpt.std()}\")\n    print(f\"Output hf mean: {output_hf.mean()}, std: {output_hf.std()}\")\n    print(f\"Cos/sin shape: {cos.shape}, {sin.shape}\")\n    print(f\"Hidden states shape: {hidden_states.shape}\")\n\n    # Check if the issue is in attention or MLP\n    if hasattr(output_litgpt, \"shape\") and hasattr(output_hf, \"shape\"):\n        if output_litgpt.shape != output_hf.shape:\n            print(f\"Shape mismatch! litgpt: {output_litgpt.shape}, hf: {output_hf.shape}\")\n\n    assert torch.allclose(output_litgpt, output_hf, atol=1e-5, rtol=1e-4), f\"FAILED: Max diff: {max_diff}\"\n\n\ndef sync_weights(litgpt_model, hf_model):\n    \"\"\"Copies weights from lit-gpt model to HF model.\"\"\"\n    print(\"Synchronizing weights...\")\n    with torch.no_grad():\n        hf_model.q_a_proj.weight.copy_(litgpt_model.q_a_proj.weight)\n        hf_model.q_a_layernorm.weight.copy_(litgpt_model.q_a_norm.weight)\n        hf_model.q_b_proj.weight.copy_(litgpt_model.q_b_proj.weight)\n        hf_model.kv_a_proj_with_mqa.weight.copy_(litgpt_model.kv_a_proj_with_mqa.weight)\n        hf_model.kv_a_layernorm.weight.copy_(litgpt_model.kv_a_norm.weight)\n        hf_model.kv_b_proj.weight.copy_(litgpt_model.kv_b_proj.weight)\n        hf_model.o_proj.weight.copy_(litgpt_model.proj.weight)\n    print(\"Synchronization complete.\")\n\n\ndef sync_block_weights(block_litgpt, block_hf):\n    \"\"\"Synchronize all weights from LitGPT block to HF block.\"\"\"\n    print(\"Synchronizing block weights...\")\n    with torch.no_grad():\n        # Sync attention weights\n        sync_weights(block_litgpt.attn, block_hf.self_attn)\n\n        # Sync MLP weights (assumes dense MLP, not MoE)\n        block_hf.mlp.gate_proj.weight.copy_(block_litgpt.mlp.fc_1.weight)\n        
block_hf.mlp.up_proj.weight.copy_(block_litgpt.mlp.fc_2.weight)\n        block_hf.mlp.down_proj.weight.copy_(block_litgpt.mlp.proj.weight)\n\n        # Sync normalization layers\n        block_hf.input_layernorm.weight.copy_(block_litgpt.norm_1.weight)\n        block_hf.post_attention_layernorm.weight.copy_(block_litgpt.norm_2.weight)\n\n    print(\"Block synchronization complete.\")\n"
  },
  {
    "path": "tutorials/0_to_litgpt.md",
    "content": "# Zero to LitGPT: Getting Started with Pretraining, Finetuning, and Using LLMs\n\n\n\nThis tutorial walks you through the main features and usage patterns for ⚡️LitGPT, a library for pretraining, finetuning, and using LLMs that focuses on an efficient user experience while being developer-friendly.\n\nThe topics, following the installation of LitGPT, are in chronological order, reflecting the steps in an LLM lifecycle: Pretraining → Finetuning → Inference.\n\n&nbsp;\n\n<img src=\"images/0_to_litgpt/usage.webp\" width=500>\n\n&nbsp;\n\n<img src=\"images/0_to_litgpt/commands.webp\" width=300>\n\n&nbsp;\n\nHowever, it is also possible, and even common, to use and deploy models with LitGPT without pretraining and finetuning. So, if you are not interested in pretraining and finetuning, please feel free to skip these sections.\n\n\n\n\n\n&nbsp;\n## Install LitGPT\n\nLitGPT is available as a Python library from the PyPI package repository, and we recommend installing it using Python's `pip` installer module, including all required package dependencies:\n\n```bash\npip install 'litgpt[all]'\n```\n\nAlternatively, if you are a researcher or developer planning to make changes to LitGPT, you can clone the GitHub repository and install it from a local folder as follows:\n\n```\ngit clone https://github.com/Lightning-AI/litgpt.git\ncd litgpt\npip install -e '.[all]'\n```\n\n\n&nbsp;\n## Pretrain LLMs\n\nPretraining LLMs requires substantial compute resources and time commitment. For that reason, most researchers and practitioners prefer to skip this step and continue with the *Download pretrained model weights* section instead.\n\nHowever, if you feel adventurous and want to pretrain your own LLM, here's how.\n\nFirst, we have to decide which type of model architecture we want to use. 
We list the available architectures by using the `pretrain` command with `list` as an argument:\n\n```bash\nlitgpt pretrain list\n```\n\nThis prints a list of all available model architectures in alphabetical order:\n\n```\nCamel-Platypus2-13B\nCamel-Platypus2-70B\nCodeLlama-13b-Python-hf\n...\nEleutherAI/pythia-410m\n...\nvicuna-13b-v1.3\nvicuna-13b-v1.5\nvicuna-13b-v1.5-16k\nvicuna-33b-v1.3\nvicuna-7b-v1.3\nvicuna-7b-v1.5\nvicuna-7b-v1.5-16k\n```\n\nSuppose we want to pretrain the small 1.1B parameter `tiny-llama-1.1b` model. Before starting pretraining, we must also choose and download a tokenizer.\n\nWe can download a tokenizer via the `download` command. Note that running `litgpt download list` will also print a list of all available models and tokenizers to download.\n\nTo filter for specific models, e.g., TinyLlama, we can use the `grep` command in our terminal:\n\n```bash\nlitgpt download list | grep  TinyLlama\n```\n\nThis prints\n\n```\nTinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\nTinyLlama/TinyLlama-1.1B-Chat-v1.0\n```\n\nLet's now download the tokenizer corresponding to `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T` that we can then use to pretrain the TinyLlama model:\n\n```\nlitgpt download \\\n   TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \\\n   --tokenizer_only true\n```\n\n(when specified)\n\n&nbsp;\n\n<img src=\"images/0_to_litgpt/pretrain.webp\" width=400>\n\n&nbsp;\n\nNext, we can pretrain the model on the OpenWebText dataset with the default setting as follows:\n\n```bash\nlitgpt pretrain tiny-llama-1.1b \\\n  --data OpenWebText \\\n  --tokenizer_dir TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n```\n\nIf you are interested in additional settings, you can use the help command as follows:\n\n```\nlitgpt pretrain --help\n```\n\n&nbsp;\n\n> [!TIP]\n> Above, we only covered the most basic commands for pretraining a model using LitGPT. 
We highly recommend checking the resources below if you are interested in pretraining a model.\n\n&nbsp;\n\n**More information and additional resources**\n\n- [tutorials/pretrain](./pretrain.md): General information about pretraining in LitGPT\n- [tutorials/pretrain_tinyllama](./pretrain_tinyllama.md): A tutorial for pretraining a 1.1B TinyLlama model on 3 trillion tokens\n- [config_hub/pretrain](../config_hub/pretrain): Pre-made config files for pretraining that work well out of the box\n- Project templates in reproducible environments with multi-GPU and multi-node support:\n  - [Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset)\n  - [Pretrain LLMs - TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b)\n  - [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b)\n\n\n&nbsp;\n## Download pretrained model weights\n\nMost practical use cases, like LLM inference (/chat) or finetuning, involve using pretrained model weights. 
LitGPT supports a large number of model weights, which can be listed by executing the `download` with `list` as an argument:\n\n```bash\nlitgpt download list\n```\n\nThis will print a (long) list of all supported pretrained models (abbreviated for readability below):\n\n```\n..\ngoogle/gemma-2b\n...\nmeta-llama/Llama-2-7b-hf\n...\nmicrosoft/phi-2\n...\nmistralai/Mixtral-8x7B-Instruct-v0.1\n...\n```\n\nTo download the model weights, provide one of the model strings above as input argument:\n\n```bash\nlitgpt download microsoft/phi-2\n```\n\n```\nmodel-00001-of-00002.safetensors: 100%|████████████████████████████████| 5.00G/5.00G [00:40<00:00, 124MB/s]\nmodel-00002-of-00002.safetensors: 100%|████████████████████████████████| 564M/564M [00:01<00:00, 330MB/s]\ntokenizer.json: 100%|██████████████████████████████████████████████████| 2.11M/2.11M [00:00<00:00, 54.0MB/s]\n...\nConverting checkpoint files to LitGPT format.\nProcessing checkpoints/microsoft/phi-2/model-00001-of-00002.bin\n...\nSaving converted checkpoint to checkpoints/microsoft/phi-2\n```\n\n\n&nbsp;\n\n> [!TIP]\n> Note that some models, such as Llama 2, require that you accept Meta AI's terms of service for this model, and you need to use a special access token via the `litgpt download ... --access_token ...` option. For more information, visit the respective Model Hub website, e.g., [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf). 
The access token can be created under your Model Hub in the `Profile > Access Tokens` menu.\n\n&nbsp;\n\n\nBy default, the weights are going to be stored in a `./checkpoints` subdirectory:\n\n```bash\nls -lh checkpoints/microsoft/phi-2/\n```\n\n```\ntotal 11G\n-rw-r--r-- 1 sebastian sebastian  863 Mar 19 21:14 config.json\n-rw-r--r-- 1 sebastian sebastian  124 Mar 19 21:14 generation_config.json\n-rw-r--r-- 1 sebastian sebastian 5.2G Mar 19 21:15 lit_model.pth\n-rw-r--r-- 1 sebastian sebastian 4.7G Mar 19 21:15 model-00001-of-00002.bin\n-rw-r--r-- 1 sebastian sebastian 538M Mar 19 21:15 model-00002-of-00002.bin\n-rw-r--r-- 1 sebastian sebastian  528 Mar 19 21:15 model_config.yaml\n-rw-r--r-- 1 sebastian sebastian 2.1M Mar 19 21:14 tokenizer.json\n-rw-r--r-- 1 sebastian sebastian 7.2K Mar 19 21:14 tokenizer_config.json\n```\n\nThe model is now ready for inference and chat, for example, using the `chat` command on the checkpoint directory:\n\n```bash\nlitgpt chat microsoft/phi-2\n```\n\n```\nNow chatting with phi-2.\nTo exit, press 'Enter' on an empty prompt.\n\nSeed set to 1234\n>> Prompt: Why are LLMs so useful?\n>> Reply:  When building applications or operating systems, you can use LLMs to know how a computer should respond to your commands. 
This can make your programs run faster and more efficiently.\n\nTime for inference: 1.26 sec total, 27.81 tokens/sec, 35 tokens\n\n>> Prompt:\n```\n&nbsp;\n\n> [!TIP]\n> Use `--multiline true` to support prompts that require multiple input lines.\n\n<br>\n\n&nbsp;\n**More information and additional resources**\n\n- [tutorials/download_model_weights](download_model_weights.md): A more comprehensive download tutorial, tips for GPU memory limitations, and more\n\n\n&nbsp;\n## Finetune LLMs\n\nLitGPT supports several methods of supervised instruction finetuning, which allows you to finetune models to follow instructions.\n\nDatasets for Instruction-finetuning are usually formatted in the following way:\n\n&nbsp;\n\n<img src=\"images/0_to_litgpt/instruction-1.webp\" width=400>\n\n&nbsp;\n\nAlternatively, datasets for instruction finetuning can also contain an `'input'` field:\n\nIn an instruction-finetuning context, \"full\" finetuning means updating all model parameters as opposed to only a subset. Adapter and LoRA (short for low-rank adaptation) are methods for parameter-efficient finetuning that only require updating a small fraction of the model weights.\n\n&nbsp;\n\n<img src=\"images/0_to_litgpt/finetune.webp\" width=400>\n\n&nbsp;\n\nParameter-efficient finetuning is much more resource-efficient and cheaper than full finetuning, and it often results in the same good performance on downstream tasks.\n\nIn the following example, we will use LoRA for finetuning, which is one of the most popular LLM finetuning methods. (For more information on how LoRA works, please see [Code LoRA from Scratch](https://lightning.ai/lightning-ai/studios/code-lora-from-scratch).)\n\nBefore we start, we have to download a model as explained in the previous \"Download pretrained model\" section above:\n\n```bash\nlitgpt download microsoft/phi-2\n```\n\nThe LitGPT interface can be used via command line arguments and configuration files. 
We recommend starting with the configuration files from the [config_hub](../config_hub) and either modifying them directly or overriding specific settings via the command line. For example, we can use the following setting to train the downloaded 2.7B parameter `microsoft/phi-2` model, where we set `--max_steps 5` for a quick test run.\n\nIf you have downloaded or cloned the LitGPT repository, you can provide the `config` file via a relative path:\n\n```bash\nlitgpt finetune_lora microsoft/phi-2\\\n  --config config_hub/finetune/phi-2/lora.yaml \\\n  --train.max_steps 5\n```\n\nAlternatively, you can provide a URL:\n\n```bash\nlitgpt finetune_lora microsoft/phi-2\\\n  --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/phi-2/lora.yaml \\\n  --train.max_steps 5\n```\n\n\n&nbsp;\n\n\n> [!TIP]\n> Note that the config file above will finetune the model on the `Alpaca2k` dataset on 1 GPU and save the resulting files in an `out/finetune/lora-phi-2` directory. 
All of these settings can be changed via a respective command line argument or by changing the config file.\n> To see more options, execute `litgpt finetune_lora --help`.\n\n&nbsp;\n\nRunning the previous finetuning command will initiate the finetuning process, which should only take about a minute on a GPU due to the `--train.max_steps 5` setting.\n\n```\n{'checkpoint_dir': PosixPath('checkpoints/microsoft/phi-2'),  # TODO\n 'data': Alpaca2k(mask_prompt=False,\n                  val_split_fraction=0.03847,\n                  prompt_style=<litgpt.prompts.Alpaca object at 0x7f5fa2867e80>,\n                  ignore_index=-100,\n                  seed=42,\n                  num_workers=4,\n                  download_dir=PosixPath('data/alpaca2k')),\n 'devices': 1,\n 'eval': EvalArgs(interval=100, max_new_tokens=100, max_iters=100),\n 'logger_name': 'csv',\n 'lora_alpha': 16,\n 'lora_dropout': 0.05,\n 'lora_head': True,\n 'lora_key': True,\n 'lora_mlp': True,\n 'lora_projection': True,\n 'lora_query': True,\n 'lora_r': 8,\n 'lora_value': True,\n 'num_nodes': 1,\n 'out_dir': PosixPath('out/finetune/lora-phi-2'),\n 'precision': 'bf16-true',\n 'quantize': None,\n 'seed': 1337,\n 'train': TrainArgs(save_interval=800,\n                    log_interval=1,\n                    global_batch_size=8,\n                    micro_batch_size=4,\n                    lr_warmup_steps=10,\n                    epochs=1,\n                    max_tokens=None,\n                    max_steps=5,\n                    max_seq_length=512,\n                    tie_embeddings=None,\n                    learning_rate=0.0002,\n                    weight_decay=0.0,\n                    beta1=0.9,\n                    beta2=0.95,\n                    max_norm=None,\n                    min_lr=6e-05)}\nSeed set to 1337\nNumber of trainable parameters: 12,226,560\nNumber of non-trainable parameters: 2,779,683,840\nThe longest sequence length in the train data is 512, the model's maximum sequence length 
is 512 and context length is 2048\nValidating ...\nRecommend a movie for me to watch during the weekend and explain the reason.\nBelow is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nRecommend a movie for me to watch during the weekend and explain the reason.\n\n### Response:\nI recommend you watch \"Parasite\" because it's a critically acclaimed movie that won multiple awards, including the Academy Award for Best Picture. It's a thought-provoking and suspenseful film that will keep you on the edge of your seat. The movie also tackles social and economic inequalities, making it a must-watch for anyone interested in meaningful storytelling.\n\n/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MeanMetric was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n  warnings.warn(*args, **kwargs)  # noqa: B028\nMissing logger folder: out/finetune/lora-phi-2/logs/csv\nEpoch 1 | iter 1 step 0 | loss train: 1.646, val: n/a | iter time: 820.31 ms\nEpoch 1 | iter 2 step 1 | loss train: 1.660, val: n/a | iter time: 548.72 ms (step)\nEpoch 1 | iter 3 step 1 | loss train: 1.687, val: n/a | iter time: 300.07 ms\nEpoch 1 | iter 4 step 2 | loss train: 1.597, val: n/a | iter time: 595.27 ms (step)\nEpoch 1 | iter 5 step 2 | loss train: 1.640, val: n/a | iter time: 260.75 ms\nEpoch 1 | iter 6 step 3 | loss train: 1.703, val: n/a | iter time: 568.22 ms (step)\nEpoch 1 | iter 7 step 3 | loss train: 1.678, val: n/a | iter time: 511.70 ms\nEpoch 1 | iter 8 step 4 | loss train: 1.741, val: n/a | iter time: 514.14 ms (step)\nEpoch 1 | iter 9 step 4 | loss train: 1.689, val: n/a | iter time: 423.59 ms\nEpoch 1 | iter 10 step 5 | loss train: 1.524, val: n/a | iter time: 603.03 ms (step)\nTraining time: 11.20s\nMemory used: 13.90 GB\nSaving LoRA weights to 
'out/finetune/lora-phi-2/final/lit_model.pth.lora'\nSaved merged weights to 'out/finetune/lora-phi-2/final/lit_model.pth'\n```\n\nNotice that the LoRA script saves both the LoRA weights (`'out/finetune/lora-phi-2/final/lit_model.pth.lora'`) and the LoRA weight merged back into the original model (`'out/finetune/lora-phi-2/final/lit_model.pth'`) for convenience. This allows us to use the finetuned model via the `chat` function directly:\n\n```bash\nlitgpt chat out/finetune/lora-phi-2/final/\n```\n\n```\nNow chatting with phi-2.\nTo exit, press 'Enter' on an empty prompt.\n\nSeed set to 1234\n>> Prompt: Why are LLMs so useful?\n>> Reply: LLMs are useful because they can be trained to perform various natural language tasks, such as language translation, text generation, and question-answering. They are also able to understand the context of the input data, which makes them particularly useful for tasks such as sentiment analysis and text summarization. Additionally, because LLMs can learn from large amounts of data, they are able to generalize well and perform well on new data.\n\nTime for inference: 2.15 sec total, 39.57 tokens/sec, 85 tokens\n\n>> Prompt:\n```\n\n\n\n&nbsp;\n\n**More information and additional resources**\n\n- [tutorials/prepare_dataset](prepare_dataset.md): A summary of all out-of-the-box supported datasets in LitGPT and utilities for preparing custom datasets\n- [tutorials/finetune](finetune.md): An overview of the different finetuning methods supported in LitGPT\n- [tutorials/finetune_full](finetune_full.md): A tutorial on full-parameter finetuning\n- [tutorials/finetune_lora](finetune_lora.md): Options for parameter-efficient finetuning with LoRA and QLoRA\n- [tutorials/finetune_adapter](finetune_adapter.md): A description of the parameter-efficient Llama-Adapter methods supported in LitGPT\n- [tutorials/oom](oom.md): Tips for dealing with out-of-memory (OOM) errors\n- [config_hub/finetune](../config_hub/finetune): Pre-made config files for 
finetuning that work well out of the box\n\n&nbsp;\n## LLM inference\n\nTo use a downloaded or finetuned model for chat, you only need to provide the corresponding checkpoint directory containing the model and tokenizer files. For example, to chat with the phi-2 model from Microsoft, download it as follows, as described in the \"Download pretrained model\" section:\n\n```bash\nlitgpt download microsoft/phi-2\n```\n\n```\nmodel-00001-of-00002.safetensors: 100%|████████████████████████████████| 5.00G/5.00G [00:40<00:00, 124MB/s]\nmodel-00002-of-00002.safetensors: 100%|████████████████████████████████| 564M/564M [00:01<00:00, 330MB/s]\ntokenizer.json: 100%|██████████████████████████████████████████████████| 2.11M/2.11M [00:00<00:00, 54.0MB/s]\n...\nConverting checkpoint files to LitGPT format.\nProcessing checkpoints/microsoft/phi-2/model-00001-of-00002.bin\n...\nSaving converted checkpoint to checkpoints/microsoft/phi-2\n```\n\n\n\nThen, chat with the model using the following command:\n\n```bash\nlitgpt chat microsoft/phi-2\n```\n\n```\nNow chatting with phi-2.\nTo exit, press 'Enter' on an empty prompt.\n\nSeed set to 1234\n>> Prompt: What is the main difference between a large language model and a traditional search engine?\n>> Reply:  A large language model uses deep learning algorithms to analyze and generate natural language, while a traditional search engine uses algorithms to retrieve information from web pages.\n\nTime for inference: 1.14 sec total, 26.26 tokens/sec, 30 tokens\n```\n\n> [!TIP]\n> Most model weights are already represented in an efficient bfloat16 format. However, if the model currently exceeds your GPU memory, you can try to pass the `--precision bf16-true` option. 
In addition, you can check the quantization documentation for further optimization, which is linked below.\n\n\n&nbsp;\n**More information and additional resources**\n\n- [tutorials/inference](inference.md): Chat and inference tutorial\n- [tutorials/quantize](quantize.md): Quantizing models to reduce GPU memory requirements\n\n\n&nbsp;\n## Using the LitGPT Python API for Inference\n\nThe previous section explained how to use the `litgpt chat` command line interface for inference. Alternatively, LitGPT also offers a Python API approach to generate text using an LLM:\n\n```python\nfrom litgpt import LLM\n\nllm = LLM.load(\"microsoft/phi-2\")\ntext = llm.generate(\"What do Llamas eat?\", top_k=1, max_new_tokens=30)\nprint(text)\n```\n\nNote that if you pass a supported model name to `LLM.load()`, as shown above, it will download the model from the HF hub if it doesn't exist locally yet (use `litgpt download list` on the command line to get a list of all currently supported models).\n\nAlternatively, to load a model from a local path, just provide the corresponding path as input to the `load` method:\n\n```python\nllm = LLM.load(\"path/to/my/local/checkpoint\")\n```\n\n&nbsp;\n**More information and additional resources**\n\n- [tutorials/python-api](python-api.md): The LitGPT Python API documentation\n\n\n\n&nbsp;\n## Evaluating models\n\nLitGPT comes with a handy `litgpt evaluate` command to evaluate models with [Eleuther AI's Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness). 
For example, to evaluate the previously downloaded `microsoft/phi-2` model on several tasks available from the Evaluation Harness, you can use the following command:\n\n\n```bash\nlitgpt evaluate microsoft/phi-2 \\\n  --batch_size 16 \\\n  --tasks \"hellaswag,gsm8k,truthfulqa_mc2,mmlu,winogrande,arc_challenge\"\n```\n\n(A list of supported tasks can be found [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md).)\n\n\n&nbsp;\n## Deploy LLMs\n\nYou can deploy LitGPT LLMs using your tool of choice. Below is an example using LitGPT's built-in serving capabilities:\n\n\n```bash\n# 1) Download a pretrained model (alternatively, use your own finetuned model)\nlitgpt download microsoft/phi-2\n\n# 2) Start the server\nlitgpt serve microsoft/phi-2\n```\n\n```python\n# 3) Use the server (in a separate session)\nimport requests, json\nresponse = requests.post(\n    \"http://127.0.0.1:8000/predict\",\n    json={\"prompt\": \"Fix typos in the following sentence: Example input\"}\n)\nprint(response.json()[\"output\"])\n```\n\nThis prints:\n\n```\nInstruct: Fix typos in the following sentence: Example input\nOutput: Example input.\n```\n\n\n&nbsp;\n**More information and additional resources**\n\n- [tutorials/deploy](deploy.md): A full deployment tutorial and example\n\n\n&nbsp;\n## Converting LitGPT model weights to `safetensors` format\n\nSometimes, it can be useful to convert LitGPT model weights for third-party and external tools. 
For example, we can convert a LitGPT model to the Hugging Face format and save it via `.safetensors` files, which we can do as follows:\n\n```bash\nlitgpt convert_from_litgpt microsoft/phi-2 out/converted_model/\n```\n\nCertain tools like the `.from_pretrained` method in Hugging Face `transformers` also require the original `config.json` file that originally came with the downloaded model:\n\n```bash\ncp checkpoints/microsoft/phi-2/config.json out/converted_model/config.json\n```\n\nYou can now load the model into a Hugging Face transformers model and save it in a `.safetensors` format as follows:\n\n```python\nimport torch\nfrom transformers import AutoModel\n\n# Load model\nstate_dict = torch.load('out/converted_model/model.pth')\nmodel = AutoModel.from_pretrained(\n    \"microsoft/phi-2\", state_dict=state_dict\n)\n\n# Save .safetensors files\nmodel.save_pretrained(\"out/converted_model/\")\n```\n\n```\n⚡ ~/litgpt ls -lh out/converted_model\ntotal 16G\n-rwxr--r-- 1 sebastian sebastian  891 Mar 20 17:08 config.json\n-rw-r--r-- 1 sebastian sebastian 4.7G Mar 20 17:08 model-00001-of-00003.safetensors\n-rw-r--r-- 1 sebastian sebastian 4.7G Mar 20 17:09 model-00002-of-00003.safetensors\n-rw-r--r-- 1 sebastian sebastian 601M Mar 20 17:09 model-00003-of-00003.safetensors\n-rw-r--r-- 1 sebastian sebastian 5.2G Mar 20 16:30 model.pth\n-rw-r--r-- 1 sebastian sebastian  33K Mar 20 17:09 model.safetensors.index.json\n```\n\nYou can then use the model with external tools, for example, Eleuther AI's [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) (see the `lm_eval` installation instructions [here](https://github.com/EleutherAI/lm-evaluation-harness?tab=readme-ov-file#install)).\n\nThe LM Evaluation Harness requires a tokenizer to be present in the model checkpoint folder, which we can copy from the original download checkpoint:\n\n```bash\n# Copy the tokenizer needed by the Eval Harness\ncp 
checkpoints/microsoft/phi-2/tokenizer* \\\n    out/converted_model\n```\n\nThen, we can run the Evaluation Harness as follows:\n\n```bash\nlm_eval --model hf \\\n    --model_args pretrained=\"out/converted_model\" \\\n    --tasks \"hellaswag,gsm8k,truthfulqa_mc2,mmlu,winogrande,arc_challenge\" \\\n    --device \"cuda:0\" \\\n    --batch_size 4\n```\n\n&nbsp;\n\n> [!TIP]\n> The Evaluation Harness tasks above are those used in the Open LLM Leaderboard. You can find a list of all supported tasks [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md).\n\n\n\n&nbsp;\n**More information and additional resources**\n\n- [tutorials/convert_lit_models](./convert_lit_models.md): Tutorial on converting LitGPT weights\n\n\n\n&nbsp;\n\n## Get involved!\n\nWe appreciate your feedback and contributions. If you have feature requests, questions, or want to contribute code or config files, please don't hesitate to use the [GitHub Issue](https://github.com/Lightning-AI/litgpt/issues) tracker.\n\nWe welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment.\n\n&nbsp;\n\n> [!TIP]\n> Unsure about contributing? Check out our [How to Contribute to LitGPT](https://lightning.ai/pages/community/tutorial/how-to-contribute-to-litgpt/) guide.\n\n&nbsp;\n\nIf you have general questions about building with LitGPT, please [join our Discord](https://discord.gg/VptPCZkGNa).\n"
  },
  {
    "path": "tutorials/convert_hf_checkpoint.md",
    "content": "# Converting Hugging Face Transformers to LitGPT weights\n\nBy default, the `litgpt download` command converts the downloaded HF checkpoint files into a LitGPT compatible format after downloading. For example,\n\n```bash\nlitgpt download EleutherAI/pythia-14m\n```\n\ncreates the following files:\n\n```\ncheckpoints/\n└── EleutherAI/\n    └── pythia-14m/\n        ├── config.json\n        ├── generation_config.json\n        ├── model_config.yaml      # LitGPT specific file\n        ├── lit_model.pth          # LitGPT specific file\n        ├── pytorch_model.bin\n        ├── tokenizer.json\n        └── tokenizer_config.json\n```\n\n\n\nTo disable the automatic conversion, which is useful for development and debugging purposes, you can run the `litgpt download` command with the `--convert_checkpoint false` flag. This will only download the checkpoint files but will not convert them for use in LitGPT:\n\n```bash\nrm -rf checkpoints/EleutherAI/pythia-14m\n\nlitgpt download EleutherAI/pythia-14m \\\n  --convert_checkpoint false\n\nls checkpoints/EleutherAI/pythia-14m\n```\n\n```\ncheckpoints/\n└── EleutherAI/\n    └── pythia-14m/\n        ├── config.json\n        ├── generation_config.json\n        ├── pytorch_model.bin\n        ├── tokenizer.json\n        └── tokenizer_config.json\n```\n\nThe required `model_config.yaml` and `lit_model.pth` files can then be manually generated via the `litgpt/scripts/convert_hf_checkpoint.py` script:\n\n```bash\nlitgpt convert_to_litgpt checkpoints/EleutherAI/pythia-14m\n```\n"
  },
  {
    "path": "tutorials/convert_lit_models.md",
    "content": "## Converting LitGPT weights to Hugging Face Transformers\n\nLitGPT weights need to be converted to a format that Hugging Face understands with a [conversion script](../litgpt/scripts/convert_lit_checkpoint.py) before our scripts can run.\n\nWe provide a helpful command to convert LitGPT models back to their equivalent Hugging Face Transformers format:\n\n```bash\nlitgpt convert_from_litgpt checkpoint_dir converted_dir\n```\n\nThese paths are just placeholders; you will need to customize them based on which finetuning or pretraining command you ran and its configuration.\n\n### Loading converted LitGPT checkpoints into transformers\n\n\nTo load the converted checkpoint from a local directory with `AutoModel.from_pretrained`, first copy the `config.json` file that came with the original downloaded checkpoint into the directory containing the converted weights. For example,\n\n```bash\ncp checkpoints/repo_id/config.json converted/config.json\n```\n\nThen, you can load the checkpoint file in a Python session as follows:\n\n```python\nimport torch\nfrom transformers import AutoModel\n\n\nstate_dict = torch.load(\"output_dir/model.pth\")\nmodel = AutoModel.from_pretrained(\n    \"output_dir/\", local_files_only=True, state_dict=state_dict\n)\n```\n\nAlternatively, you can also load the model without copying the `config.json` file as follows:\n\n```python\nmodel = AutoModel.from_pretrained(\"online_repo_id\", state_dict=state_dict)\n```\n\n\n\n### Merging LoRA weights\n\nPlease note that if you want to convert a model that has been finetuned using an adapter like LoRA, these weights should be [merged](../litgpt/scripts/merge_lora.py) into the checkpoint prior to converting.\n\n```sh\nlitgpt merge_lora path/to/lora/checkpoint_dir\n```\n\n<br>\n<br>\n\n# A finetuning and conversion tutorial\n\nThis section contains a reproducible example for finetuning a LitGPT model and converting it back into a HF `transformers` model.\n\n1. 
Download a model of interest:\n\nFor convenience, we first specify an environment variable (optional) to avoid copying and pasting the whole path:\n\n```bash\nexport repo_id=TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n```\n\nInstead of using TinyLlama, you can replace the `repo_id` target with any other model repository\nspecifier that is currently supported by LitGPT. You can get a list of supported repository specifiers\nby running `litgpt/scripts/download.py` without any additional arguments.\n\nThen, we download the model we specified via `$repo_id` above:\n\n```bash\nlitgpt download $repo_id\n```\n\n2. Finetune the model:\n\n\n```bash\nexport finetuned_dir=out/lit-finetuned-model\n\nlitgpt finetune_lora $repo_id \\\n   --out_dir $finetuned_dir \\\n   --train.epochs 1 \\\n   --data Alpaca\n```\n\n3. Merge LoRA weights:\n\nNote that this step only applies if the model was finetuned with `lora.py` above and not when `full.py` was used for finetuning.\n\n```bash\nlitgpt merge_lora $finetuned_dir/final\n```\n\n\n4. Convert the finetuned model back into a HF format:\n\n```bash\nlitgpt convert_from_litgpt $finetuned_dir/final/ out/hf-tinyllama/converted\n```\n\n\n5. Load the model into a `transformers` model:\n\n```python\nimport torch\nfrom transformers import AutoModel\n\nstate_dict = torch.load('out/hf-tinyllama/converted/model.pth')\nmodel = AutoModel.from_pretrained(\"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\", state_dict=state_dict)\n```\n\n&nbsp;\n## Using the LM Evaluation Harness\n\nTo evaluate LitGPT models, use the integrated evaluation utilities based on Eleuther AI's LM Evaluation Harness. For more information, please see the [evaluation](evaluation.md) documentation.\n\nAlternatively, if you wish to use converted LitGPT models with the LM Evaluation Harness from [Eleuther AI's GitHub repository](https://github.com/EleutherAI/lm-evaluation-harness), you can use the following steps.\n\n1. 
Follow the instructions above to load the model into a Hugging Face transformers model.\n\n2. Create a `model.safetensors` file:\n\n```python\nmodel.save_pretrained(\"out/hf-tinyllama/converted/\")\n```\n\n3. Copy the tokenizer files into the model-containing directory:\n\n```bash\ncp checkpoints/$repo_id/tokenizer* out/hf-tinyllama/converted\n```\n\n4. Run the evaluation harness, for example:\n\n```bash\nlm_eval --model hf \\\n    --model_args pretrained=out/hf-tinyllama/converted \\\n    --tasks \"hellaswag,gsm8k,truthfulqa_mc2,mmlu,winogrande,arc_challenge\" \\\n    --device \"cuda:0\" \\\n    --batch_size 4\n```\n"
  },
  {
    "path": "tutorials/deploy.md",
    "content": "# Serve and Deploy LLMs\n\nThis document shows how you can serve a LitGPT model for deployment.\n\n\n&nbsp;\n## Serve an LLM with LitServe\n\nThis section illustrates how we can use `litgpt serve` to set up a minimal and highly scalable inference server for a phi-2 LLM.\n\n\n&nbsp;\n### Step 1: Start the inference server\n\n\n```bash\n# 1) Download a pretrained model (alternatively, use your own finetuned model)\nlitgpt download microsoft/phi-2\n\n# 2) Start the server\nlitgpt serve microsoft/phi-2\n```\n\n> [!TIP]\n> Use `litgpt serve --help` to display additional options, including the port, devices, LLM temperature setting, and more.\n\n\n&nbsp;\n### Step 2: Query the inference server\n\nYou can now send requests to the inference server you started in Step 1. For example, in a new Python session, we can send requests to the inference server as follows:\n\n\n```python\nimport requests, json\n\nresponse = requests.post(\n    \"http://127.0.0.1:8000/predict\",\n    json={\"prompt\": \"Fix typos in the following sentence: Example input\"}\n)\n\nprint(response.json()[\"output\"])\n```\n\nExecuting the code above prints the following output:\n\n```\nExample input.\n```\n\n&nbsp;\n### Optional: Use the streaming mode\n\nThe 2-step procedure described above returns the complete response all at once. 
If you want to stream the response on a token-by-token basis, start the server with the streaming option enabled:\n\n```bash\nlitgpt serve microsoft/phi-2 --stream true\n```\n\nThen, use the following updated code to query the inference server:\n\n```python\nimport requests, json\n\nresponse = requests.post(\n    \"http://127.0.0.1:8000/predict\",\n    json={\"prompt\": \"Fix typos in the following sentence: Example input\"},\n    stream=True\n)\n\n# stream the response\nfor line in response.iter_lines(decode_unicode=True):\n    if line:\n        print(json.loads(line)[\"output\"], end=\"\")\n```\n\n```\nSure, here is the corrected sentence:\n\nExample input\n```\n\n&nbsp;\n## Serve an LLM with OpenAI-compatible API\n\nLitGPT provides OpenAI-compatible endpoints that allow you to use the OpenAI SDK or any OpenAI-compatible client to interact with your models. This is useful for integrating LitGPT into existing applications that use the OpenAI API.\n\n&nbsp;\n### Step 1: Start the server with OpenAI specification\n\n```bash\n# 1) Download a pretrained model (alternatively, use your own finetuned model)\nlitgpt download HuggingFaceTB/SmolLM2-135M-Instruct\n\n# 2) Start the server with OpenAI-compatible endpoints\nlitgpt serve HuggingFaceTB/SmolLM2-135M-Instruct --openai_spec true\n```\n\n> [!TIP]\n> The `--openai_spec true` flag enables OpenAI-compatible endpoints at `/v1/chat/completions` instead of the default `/predict` endpoint.\n\n&nbsp;\n### Step 2: Query using OpenAI-compatible endpoints\n\nYou can now send requests to the OpenAI-compatible endpoint using curl:\n\n```bash\ncurl -X POST http://127.0.0.1:8000/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"SmolLM2-135M-Instruct\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"Hello! 
How are you?\"}]\n  }'\n```\n\nOr use the OpenAI Python SDK:\n\n```python\nfrom openai import OpenAI\n\n# Configure the client to use your local LitGPT server\nclient = OpenAI(\n    base_url=\"http://127.0.0.1:8000/v1\",\n    api_key=\"not-needed\"  # LitGPT doesn't require authentication by default\n)\n\nresponse = client.chat.completions.create(\n    model=\"SmolLM2-135M-Instruct\",\n    messages=[\n        {\"role\": \"user\", \"content\": \"Hello! How are you?\"}\n    ]\n)\n\nprint(response.choices[0].message.content)\n```\n\n&nbsp;\n## Serve an LLM UI with Chainlit\n\nIf you are interested in developing a simple ChatGPT-like UI prototype, see the Chainlit tutorial in the following Studio:\n\n<a target=\"_blank\" href=\"https://lightning.ai/lightning-ai/studios/chatgpt-like-llm-uis-via-chainlit\">\n  <img src=\"https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg\" alt=\"Open In Studio\"/>\n</a>\n"
  },
  {
    "path": "tutorials/developer-docs/README.md",
    "content": "LitGPT developer documentation files.\n"
  },
  {
    "path": "tutorials/developer-docs/adding-models.md",
    "content": "# Adding New Models\n\nThis document provides an overview and explanation of how new LLM architectures and model weights can be added to LitGPT.\n\n&nbsp;\n\n> [!NOTE]\n> One of the design focus areas of LitGPT is to provide efficient readable code. At the same time, LitGPT aims to support selected LLMs that are useful to the community. LitGPT aims to reuse and share as much code as possible between different LLMs to strike a balance between code readability and enabling support for various LLMs. In short, we try to minimize writing custom code for a given LLM and aim for code reuse.\n\n\n&nbsp;\n\n&nbsp;\n## 1. Discuss the LLM to be added\n\nAs an open-source project, we appreciate your contributions! However, before you begin putting valuable time and work into a contribution, ideally, open an issue to discuss whether support for a certain model is within the project's scope.\n\n&nbsp;\n## 2. Set up your development environment\n\nClone the repository:\n\n```bash\ngit clone https://github.com/Lightning-AI/litgpt.git\n```\n\nThen, install it with the \"editable\" mode for development:\n\n```bash\ncd litgpt\npip install litgpt -e \".[all]\"\n```\n\n&nbsp;\n## 3. Update the config file\n\nUpdate the [litgpt/config.py](../../litgpt/config.py) config file, adding the new model configuration there. 
It's easiest to start with the most similar model, copy the configuration, and then modify it according to the `config.json` file on the HF hub.\n\nFor example, suppose an entry for Llama 3 8B already exists and you want to add support for Llama 3 70B.\n\nCopy the Llama 3 8B entry:\n\n```python\n # https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json\n dict(\n     name=\"Llama-3-8B{}\",\n     hf_config=dict(org=\"meta-llama\", name=\"Meta-Llama-3-8B{}\"),\n     vocab_size=128256,\n     padding_multiple=64,\n     n_layer=32,\n     n_head=32,\n     n_query_groups=8,\n     rotary_percentage=1.0,\n     parallel_residual=False,\n     bias=False,\n     norm_class_name=\"RMSNorm\",\n     mlp_class_name=\"LLaMAMLP\",\n     intermediate_size=14336,\n     rope_base=500000,\n ),\n```\n\nThen create the entry for the 70B model. Here, make sure you update the values according to the `config.json` file available on the HF hub:\n\n```python\n# https://huggingface.co/meta-llama/Meta-Llama-3-70B/blob/main/config.json\n dict(\n     name=\"Llama-3-70B{}\",\n     hf_config=dict(org=\"meta-llama\", name=\"Meta-Llama-3-70B{}\"),\n     vocab_size=128256,\n     padding_multiple=64,\n     n_layer=80,\n     n_head=64,\n     n_embd=8192,\n     n_query_groups=8,\n     rotary_percentage=1.0,\n     parallel_residual=False,\n     bias=False,\n     norm_class_name=\"RMSNorm\",\n     mlp_class_name=\"LLaMAMLP\",\n     intermediate_size=28672,\n     rope_base=500000,\n ),\n```\n\n&nbsp;\n\n> [!NOTE]\n> Some models may require you to implement a new MLP class analogous to `class LLaMAMLP`.\n> A more or less reliable indicator is the presence of a `modeling.py` file in the model's original repository.\n> If this file exists, it suggests that this model requires custom code.\n> This will then also require additional changes beyond simply updating\n> the configuration in LitGPT's `config.py`.\n\n&nbsp;\n## 4. 
Try downloading the model\n\nAfter making the modifications above, try downloading the model:\n\n```bash\nlitgpt download meta-llama/Meta-Llama-3-70B --access_token ...\n```\n\n&nbsp;\n\n> [!NOTE]\n> Not all models require an access token\n\n&nbsp;\n\nIf the conversion following the download fails, proceed with the next section.\n\n&nbsp;\n## 5. Update the checkpoint conversion script\n\nIf the `litgpt download ...` command from the previous section failed, you may have to adjust the checkpoint conversion script: [litgpt/scripts/convert_hf_checkpoint.py](../../litgpt/scripts/convert_hf_checkpoint.py).\n\nHere, you may have to adjust or implement a new `def copy_weights_hf_...` function.\n\nYou can test the updated conversion code without needing to redownload the weights as follows:\n\n```bash\npython litgpt/scripts/convert_hf_checkpoint.py meta-llama/Meta-Llama-3-70B\n```\n\n&nbsp;\n## 6. Add the Prompt Style\n\nIf you are adding a new model class, find out its prompt style. First, check [litgpt/prompts.py](../../litgpt/prompts.py) if a similar prompt style template already exists. 
For Llama 3, this is as follows:\n\n```python\nclass Llama3(PromptStyle):\n     def apply(self, prompt: str, **kwargs: str) -> str:\n         # https://github.com/meta-llama/llama3/blob/359887376f0aaf30e433f23e25df858d8c2a9833/llama/tokenizer.py#L202-L229\n         return (\n             \"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\n\"\n             \"You are a helpful assistant.<|eot_id|>\\n\"  # The system prompt is optional\n             \"<|start_header_id|>user<|end_header_id|>\\n\\n\"\n             f\"{prompt}<|eot_id|>\\n\"\n             \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n         )\n\n     def stop_tokens(self, tokenizer: \"Tokenizer\") -> Tuple[List[int], ...]:\n         return (\n             [tokenizer.eos_id],\n             [tokenizer.token_to_id(\"<|eot_id|>\")],\n         )\n```\n\nIf your model requires a different prompt template, create a new `PromptStyle` class.\n\nThen, in the same file, update the `prompt_styles` dictionary:\n\n```python\nprompt_styles: Dict[str, Type[PromptStyle]] = {\n    ...\n    \"llama3\": Llama3,\n}\n```\n\nFinally, also in the same file, update the `model_name_to_prompt_style` function:\n\n```python\ndef model_name_to_prompt_style(model_name: str) -> PromptStyle:\n    ...\n    if re.search(\"Llama-3.*-Instruct\", model_name):\n        return Llama3()\n```\n\n&nbsp;\n## 7. Try using the model for inference\n\nNext, use the model to see if inference works:\n\n```bash\nlitgpt generate meta-llama/Meta-Llama-3-70B\n```\n\n&nbsp;\n\n> [!NOTE]\n> If you notice that the model produces nonsensical outputs, you need to double-check the config file and find out if there are incorrect values or other problems. The next section on adding unit tests may offer additional pointers.\n\n&nbsp;\n\n&nbsp;\n## 8. 
Add unit tests\n\n&nbsp;\n### 8.1 Add model unit tests\n\nOpen the [`tests/test_model.py`](../../tests/test_model.py) file and add a new `def test_against_hf_...` function using one of the existing functions as a template. For instance,\n\n```python\ndef test_against_hf_llama2(ours_kwargs, device, dtype):\n...\n    # test end to end\n    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)\n    assert x.size(1) == T\n    ours_y = ours_model(x)\n    theirs_y = theirs_model(x)[\"logits\"].to(dtype)  # HF converts logits to float\n    torch.testing.assert_close(ours_y, theirs_y)\n```\n\nIf the\n\n```bash\nlitgpt generate meta-llama/Meta-Llama-3-70B\n```\n\ncommand from the previous section produces incoherent text, this function can be a helpful guide for debugging. For this, modify the implementations in the `transformers` and `litgpt` packages (in your local installation) to inspect or print out the intermediate values at a layer. We recommend starting with the embedding layers and then going through one layer at a time to find out where the values differ, which gives you pointers for debugging.\n\nRun the unit test via\n\n```bash\npytest tests/test_model.py::test_against_hf_...\n```\n\n&nbsp;\n### 8.2 Add prompt style unit test\n\nOpen the [`tests/test_prompts.py`](../../tests/test_prompts.py) file and add a test for the respective prompts you added earlier, if applicable. For example,\n\n\n```python\ndef test_prompt_style_from_config():\n    model_names = [\n        ...\n        \"Llama-3-70B-Instruct\",\n        ...\n    ]\n```\n\nRun the unit test via\n\n```bash\npytest tests/test_prompts.py\n```\n\n&nbsp;\n## 9. Try finetuning the model\n\nNow, try finetuning the model:\n\n```bash\nlitgpt finetune meta-llama/Meta-Llama-3-70B --train.max_steps 10\n```\n\n&nbsp;\n## 10. 
Update the documentation\n\nFinally, update the documentation files.\n\n&nbsp;\n### 10.1 Update the README file\n\nUpdate the \"All Models\" table in the [README.md](../../README.md) file.\n\n&nbsp;\n### 10.2 Update the download tutorials\n\nAdd the new model to the model table at the top as well as to the list under `litgpt download list`.\n"
  },
  {
    "path": "tutorials/developer-docs/python-api.md",
    "content": "# LitGPT High-level Python API\n\nThis is a work-in-progress draft for a high-level LitGPT Python API.\n\n&nbsp;\n## Model loading & saving\n\nThe `LLM.load` command loads an `llm` object, which contains both the model object (a PyTorch module) and a preprocessor.\n\n```python\nfrom litgpt import LLM\n\nllm = LLM.load(\n    model=\"url | local_path\",\n    # high-level users only need to care about these:\n    memory_reduction=\"none | medium | strong\",\n    # advanced options for technical users:\n    source=\"hf | local | other\",\n    quantize=\"bnb.nf4\",\n    precision=\"bf16-true\",\n    device=\"auto | cuda | cpu\",\n)\n```\n\nHere,\n\n- `llm.model` contains the PyTorch module\n- `llm.preprocessor.tokenizer` contains the tokenizer\n\nThe `llm.save` command saves the model weights, tokenizer, and configuration information.\n\n\n```python\nllm.save(checkpoint_dir, format=\"lightning | ollama | hf\")\n```\n\n\n&nbsp;\n## Inference / Chat\n\n```\nresponse = llm.generate(\n    prompt=\"What do Llamas eat?\",\n    temperature=0.1,\n    top_p=0.8,\n    ...\n)\n```\n\n\n&nbsp;\n## Dataset\n\nThe `llm.prepare_dataset` command prepares a dataset for training.\n\n```\nllm.download_dataset(\n    URL,\n    ...\n)\n```\n\n```\ndataset = llm.prepare_dataset(\n    path,\n    task=\"pretrain | instruction_finetune\",\n    test_portion=0.1,\n    ...\n)\n```\n\n&nbsp;\n## Training\n\n\n```python\nllm.instruction_finetune(\n    config=None,\n    dataset=dataset,\n    max_iter=10,\n    method=\"full | lora | adapter | adapter_v2\"\n)\n```\n\n```python\nllm.pretrain(config=None, dataset=dataset, max_iter=10, ...)\n```\n\n&nbsp;\n## Serving\n\n\n```python\nllm.serve(port=8000)\n```\n\nThen in another Python session:\n\n```python\nimport requests, json\n\nresponse = requests.post(\n    \"http://127.0.0.1:8000/predict\",\n    json={\"prompt\": \"Fix typos in the following sentence: Example input\"}\n)\n\nprint(response.json()[\"output\"])\n```\n"
  },
  {
    "path": "tutorials/download_model_weights.md",
    "content": "# Download Model Weights with LitGPT\n\nLitGPT supports a variety of LLM architectures with publicly available weights. You can download model weights and access a list of supported models using the `litgpt download list` command.\n\n&nbsp;\n\n\n| Model | Model size | Author | Reference |\n|----|----|----|----|\n| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma)                                                                 |\n| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950)                                                                   |\n| Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/)                                                                                             |\n| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm)      |\n| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae)                                                                                              |\n| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3)                                                                                              |\n| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models)                 |\n| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 
2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2)                                  |\n| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf)                                       |\n| Gemma 2 | 2B, 9B, 27B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf)                              |\n| Gemma 3 | 1B, 4B, 12B, 27B | Google | [Google Team, Google Deepmind](https://arxiv.org/pdf/2503.19786)\n| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288)                                                                           |\n| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3)                                                                                   |\n| Llama 3.1 | 8B, 70B, 405B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3)                                                                           |\n| Llama 3.2 | 1B, 3B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md)                                    |\n| Llama 3.3 | 70B | Meta AI | [Meta AI 2024](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct)                                                                                 |\n| Llama 3.1 Nemotron | 70B | NVIDIA | [NVIDIA AI 2024](https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct/modelcard) |\n| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/)                                                                       |\n| Mathstral | 7B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mathstral/)                                                                        |\n| MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama)\n| Mixtral MoE | 8x7B 
| Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/)                                                                     |\n| Mistral | 7B, 123B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/)                                                                        |\n| Mixtral MoE | 8x22B | Mistral AI | [Mistral AI 2024](https://mistral.ai/news/mixtral-8x22b/)                                                                         |\n| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch)                                                                          |\n| OLMo | 1B, 7B | Allen Institute for AI (AI2) | [Groeneveld et al. 2024](https://aclanthology.org/2024.acl-long.841/)     |\n| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama)                                                         |\n| Phi 1.5 & 2 | 1.3B, 2.7B | Microsoft Research  | [Li et al. 2023](https://arxiv.org/abs/2309.05463)                                                                          |\n| Phi 3 & 3.5 | 3.8B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2404.14219)\n| Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905)                                                                            |\n| Phi 4 Mini Instruct | 3.8B | Microsoft Research | [Microsoft 2025](https://arxiv.org/abs/2503.01743)                                           |\n| Phi 4 Mini Reasoning | 3.8B | Microsoft Research | [Xu, Peng et al. 2025](https://arxiv.org/abs/2504.21233)                                           |\n| Phi 4 Reasoning | 3.8B | Microsoft Research | [Abdin et al. 2025](https://arxiv.org/abs/2504.21318)                                           |\n| Phi 4 Reasoning Plus | 3.8B | Microsoft Research | [Abdin et al. 
2025](https://arxiv.org/abs/2504.21318)                                           |\n| Platypus | 7B, 13B, 70B |  Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317)                                                               |\n| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373)                                            |\n| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/)                                               |\n| Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186)                                          |\n| Qwen2.5 1M (Long Context) | 7B, 14B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwen2.5-1m/)                                          |\n| Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122)                                          |\n| QwQ | 32B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwq-32b/)                                                                         |\n| QwQ-Preview | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/)                                                                         |\n| Qwen3 | 0.6B, 1.7B, 4B{Hybrid, Thinking-2507, Instruct-2507}, 8B, 14B, 32B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/)                                                                         |\n| Qwen3 MoE | 30B{Hybrid, Thinking-2507, Instruct-2507}, 235B{Hybrid, Thinking-2507, Instruct-2507} | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/)                                                                         |\n| R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 
2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf)                                                                         |\n| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1)                                                                 |\n| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm)                                                               |\n| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding)                                                  |\n| Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra)                                                                         |\n| StableLM  | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM)                                                                    |\n| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding)                                             |\n| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama)                                                                         |\n| Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/)                                                                          |                                                            |\n\n&nbsp;\n\n## General Instructions\n\n### 1. 
List Available Models\n\nTo see all supported models, run the following command:\n\n```bash\nlitgpt download list\n```\n\nThe output is shown below:\n\n```\nallenai/OLMo-1B-hf\nallenai/OLMo-7B-hf\nallenai/OLMo-7B-Instruct-hf\nbsc-lt/salamandra-2b\nbsc-lt/salamandra-2b-instruct\nbsc-lt/salamandra-7b\nbsc-lt/salamandra-7b-instruct\ncodellama/CodeLlama-13b-hf\ncodellama/CodeLlama-13b-Instruct-hf\ncodellama/CodeLlama-13b-Python-hf\ncodellama/CodeLlama-34b-hf\ncodellama/CodeLlama-34b-Instruct-hf\ncodellama/CodeLlama-34b-Python-hf\ncodellama/CodeLlama-70b-hf\ncodellama/CodeLlama-70b-Instruct-hf\ncodellama/CodeLlama-70b-Python-hf\ncodellama/CodeLlama-7b-hf\ncodellama/CodeLlama-7b-Instruct-hf\ncodellama/CodeLlama-7b-Python-hf\ndatabricks/dolly-v2-12b\ndatabricks/dolly-v2-3b\ndatabricks/dolly-v2-7b\ndeepseek-ai/DeepSeek-R1-Distill-Llama-8B\ndeepseek-ai/DeepSeek-R1-Distill-Llama-70B\nEleutherAI/pythia-1.4b\nEleutherAI/pythia-1.4b-deduped\nEleutherAI/pythia-12b\nEleutherAI/pythia-12b-deduped\nEleutherAI/pythia-14m\nEleutherAI/pythia-160m\nEleutherAI/pythia-160m-deduped\nEleutherAI/pythia-1b\nEleutherAI/pythia-1b-deduped\nEleutherAI/pythia-2.8b\nEleutherAI/pythia-2.8b-deduped\nEleutherAI/pythia-31m\nEleutherAI/pythia-410m\nEleutherAI/pythia-410m-deduped\nEleutherAI/pythia-6.9b\nEleutherAI/pythia-6.9b-deduped\nEleutherAI/pythia-70m\nEleutherAI/pythia-70m-deduped\ngarage-bAInd/Camel-Platypus2-13B\ngarage-bAInd/Camel-Platypus2-70B\ngarage-bAInd/Platypus-30B\ngarage-bAInd/Platypus2-13B\ngarage-bAInd/Platypus2-70B\ngarage-bAInd/Platypus2-70B-instruct\ngarage-bAInd/Platypus2-7B\ngarage-bAInd/Stable-Platypus2-13B\ngoogle/codegemma-7b-it\ngoogle/gemma-3-27b-it\ngoogle/gemma-3-12b-it\ngoogle/gemma-3-4b-it\ngoogle/gemma-3-1b-it\ngoogle/gemma-2-27b\ngoogle/gemma-2-27b-it\ngoogle/gemma-2-2b\ngoogle/gemma-2-2b-it\ngoogle/gemma-2-9b\ngoogle/gemma-2-9b-it\ngoogle/gemma-2b\ngoogle/gemma-2b-it\ngoogle/gemma-7b\ngoogle/gemma-7b-it\nh2oai/h2o-danube2-1.8b-chat\nHuggingFaceTB/SmolLM2-135M\nHugging
FaceTB/SmolLM2-135M-Instruct\nHuggingFaceTB/SmolLM2-360M\nHuggingFaceTB/SmolLM2-360M-Instruct\nHuggingFaceTB/SmolLM2-1.7B\nHuggingFaceTB/SmolLM2-1.7B-Instruct\nlmsys/longchat-13b-16k\nlmsys/longchat-7b-16k\nlmsys/vicuna-13b-v1.3\nlmsys/vicuna-13b-v1.5\nlmsys/vicuna-13b-v1.5-16k\nlmsys/vicuna-33b-v1.3\nlmsys/vicuna-7b-v1.3\nlmsys/vicuna-7b-v1.5\nlmsys/vicuna-7b-v1.5-16k\nmeta-llama/Llama-2-13b-chat-hf\nmeta-llama/Llama-2-13b-hf\nmeta-llama/Llama-2-70b-chat-hf\nmeta-llama/Llama-2-70b-hf\nmeta-llama/Llama-2-7b-chat-hf\nmeta-llama/Llama-2-7b-hf\nmeta-llama/Llama-3.2-1B\nmeta-llama/Llama-3.2-1B-Instruct\nmeta-llama/Llama-3.2-3B\nmeta-llama/Llama-3.2-3B-Instruct\nmeta-llama/Llama-3.3-70B-Instruct\nmeta-llama/Meta-Llama-3-70B\nmeta-llama/Meta-Llama-3-70B-Instruct\nmeta-llama/Meta-Llama-3-8B\nmeta-llama/Meta-Llama-3-8B-Instruct\nmeta-llama/Meta-Llama-3.1-405B\nmeta-llama/Meta-Llama-3.1-405B-Instruct\nmeta-llama/Meta-Llama-3.1-70B\nmeta-llama/Meta-Llama-3.1-70B-Instruct\nmeta-llama/Meta-Llama-3.1-8B\nmeta-llama/Meta-Llama-3.1-8B-Instruct\nmicrosoft/phi-1_5\nmicrosoft/phi-2\nmicrosoft/Phi-3-mini-128k-instruct\nmicrosoft/Phi-3-mini-4k-instruct\nmicrosoft/Phi-3.5-mini-instruct\nmicrosoft/phi-4\nmicrosoft/Phi-4-mini-instruct\nmistralai/mathstral-7B-v0.1\nmistralai/Mistral-7B-Instruct-v0.1\nmistralai/Mistral-7B-Instruct-v0.2\nmistralai/Mistral-7B-Instruct-v0.3\nmistralai/Mistral-7B-v0.1\nmistralai/Mistral-7B-v0.3\nmistralai/Mistral-Large-Instruct-2407\nmistralai/Mistral-Large-Instruct-2411\nmistralai/Mixtral-8x7B-Instruct-v0.1\nmistralai/Mixtral-8x7B-v0.1\nmistralai/Mixtral-8x22B-Instruct-v0.1\nmistralai/Mixtral-8x22B-v0.1\nNousResearch/Nous-Hermes-13b\nNousResearch/Nous-Hermes-llama-2-7b\nNousResearch/Nous-Hermes-Llama2-13b\nnvidia/Llama-3.1-Nemotron-70B-Instruct-HF\nopenlm-research/open_llama_13b\nopenlm-research/open_llama_3b\nopenlm-research/open_llama_7b\nQwen/Qwen2.5-0.5B\nQwen/Qwen2.5-0.5B-Instruct\nQwen/Qwen2.5-1.5B\nQwen/Qwen2.5-1.5B-Instruct\nQwen/Qwen2.5-3B\nQwen/Qwen2
.5-3B-Instruct\nQwen/Qwen2.5-7B\nQwen/Qwen2.5-7B-Instruct\nQwen/Qwen2.5-7B-Instruct-1M\nQwen/Qwen2.5-14B\nQwen/Qwen2.5-14B-Instruct\nQwen/Qwen2.5-14B-Instruct-1M\nQwen/Qwen2.5-32B\nQwen/Qwen2.5-32B-Instruct\nQwen/Qwen2.5-72B\nQwen/Qwen2.5-72B-Instruct\nQwen/Qwen2.5-Coder-0.5B\nQwen/Qwen2.5-Coder-0.5B-Instruct\nQwen/Qwen2.5-Coder-1.5B\nQwen/Qwen2.5-Coder-1.5B-Instruct\nQwen/Qwen2.5-Coder-3B\nQwen/Qwen2.5-Coder-3B-Instruct\nQwen/Qwen2.5-Coder-7B\nQwen/Qwen2.5-Coder-7B-Instruct\nQwen/Qwen2.5-Coder-14B\nQwen/Qwen2.5-Coder-14B-Instruct\nQwen/Qwen2.5-Coder-32B\nQwen/Qwen2.5-Coder-32B-Instruct\nQwen/Qwen2.5-Math-1.5B\nQwen/Qwen2.5-Math-1.5B-Instruct\nQwen/Qwen2.5-Math-7B\nQwen/Qwen2.5-Math-7B-Instruct\nQwen/Qwen2.5-Math-72B\nQwen/Qwen2.5-Math-72B-Instruct\nQwen/Qwen3-0.6B\nQwen/Qwen3-0.6B-Base\nQwen/Qwen3-1.7B\nQwen/Qwen3-1.7B-Base\nQwen/Qwen3-4B\nQwen/Qwen3-4B-Base\nQwen/Qwen3-8B\nQwen/Qwen3-8B-Base\nQwen/Qwen3-14B\nQwen/Qwen3-14B-Base\nQwen/Qwen3-32B\nQwen/Qwen3-30B-A3B\nQwen/Qwen3-30B-A3B-Base\nQwen/Qwen3-235B-A22B\nQwen/Qwen3-4B-Thinking-2507\nQwen/Qwen3-4B-Instruct-2507\nQwen/Qwen3-30B-A3B-Thinking-2507\nQwen/Qwen3-30B-A3B-Instruct-2507\nQwen/Qwen3-235B-A22B-Thinking-2507\nQwen/Qwen3-235B-A22B-Instruct-2507\nQwen/QwQ-32B\nQwen/QwQ-32B-Preview\nstabilityai/FreeWilly2\nstabilityai/stable-code-3b\nstabilityai/stablecode-completion-alpha-3b\nstabilityai/stablecode-completion-alpha-3b-4k\nstabilityai/stablecode-instruct-alpha-3b\nstabilityai/stablelm-3b-4e1t\nstabilityai/stablelm-base-alpha-3b\nstabilityai/stablelm-base-alpha-7b\nstabilityai/stablelm-tuned-alpha-3b\nstabilityai/stablelm-tuned-alpha-7b\nstabilityai/stablelm-zephyr-3b\ntiiuae/falcon-180B\ntiiuae/falcon-180B-chat\ntiiuae/falcon-40b\ntiiuae/falcon-40b-instruct\ntiiuae/falcon-7b\ntiiuae/falcon-7b-instruct\ntiiuae/Falcon3-1B-Base\ntiiuae/Falcon3-1B-Instruct\ntiiuae/Falcon3-3B-Base\ntiiuae/Falcon3-3B-Instruct\ntiiuae/Falcon3-7B-Base\ntiiuae/Falcon3-7B-Instruct\ntiiuae/Falcon3-10B-Base\ntiiuae/Falcon3-10B-Instruct
\nTinyLlama/TinyLlama-1.1B-Chat-v1.0\nTinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\ntogethercomputer/LLaMA-2-7B-32K\ntogethercomputer/RedPajama-INCITE-7B-Base\ntogethercomputer/RedPajama-INCITE-7B-Chat\ntogethercomputer/RedPajama-INCITE-7B-Instruct\ntogethercomputer/RedPajama-INCITE-Base-3B-v1\ntogethercomputer/RedPajama-INCITE-Base-7B-v0.1\ntogethercomputer/RedPajama-INCITE-Chat-3B-v1\ntogethercomputer/RedPajama-INCITE-Chat-7B-v0.1\ntogethercomputer/RedPajama-INCITE-Instruct-3B-v1\ntogethercomputer/RedPajama-INCITE-Instruct-7B-v0.1\nTrelis/Llama-2-7b-chat-hf-function-calling-v2\nunsloth/Mistral-7B-v0.2\n```\n\n&nbsp;\n\n> [!TIP]\n> To sort the list above by model name after the `/`, use `litgpt download list | sort -f -t'/' -k2`.\n\n&nbsp;\n\n> [!NOTE]\n> If you want to adopt a model variant that is not listed in the table above but has a similar architecture as one of the supported models, you can use this model by using the `--model_name` argument as shown below:\n>\n> ```bash\n> litgpt download NousResearch/Hermes-2-Pro-Mistral-7B \\\n>  --model_name Mistral-7B-v0.1\n> ```\n\n&nbsp;\n\n### 2. Download Model Weights\n\nTo download the weights for a specific model, provide a `<repo_id>` with the model's repository ID. For example:\n\n```bash\nlitgpt download <repo_id>\n```\n\nThis command downloads the model checkpoint into the `checkpoints/` directory.\n\n&nbsp;\n\n### 3. Additional Help\n\nFor more options, add the `--help` flag when running the script:\n\n```bash\nlitgpt download --help\n```\n\n&nbsp;\n\n### 4. Run the Model\n\nAfter conversion, run the model with the given checkpoint path as input, adjusting `repo_id` accordingly:\n\n```bash\nlitgpt chat <repo_id>\n```\n\n&nbsp;\n\n## Tinyllama Example\n\nThis section shows a typical end-to-end example for downloading and using TinyLlama:\n\n1. 
List available TinyLlama checkpoints:\n\n```bash\nlitgpt download list | grep Tiny\n```\n\n```\nTinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\nTinyLlama/TinyLlama-1.1B-Chat-v1.0\n```\n\n2. Download a TinyLlama checkpoint:\n\n```bash\nexport repo_id=TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\nlitgpt download $repo_id\n```\n\n3. Use the TinyLlama model:\n\n```bash\nlitgpt chat $repo_id\n```\n\n&nbsp;\n## Specific models and access tokens\n\nNote that certain models require that you've been granted access to the weights on the Hugging Face Hub.\n\nFor example, to get access to the Gemma 2B model, you can do so by following the steps at <https://huggingface.co/google/gemma-2b>. After access is granted, you can find your HF hub token in <https://huggingface.co/settings/tokens>.\n\nOnce you've been granted access and obtained the access token you need to pass the additional `--access_token`:\n\n```bash\nlitgpt download google/gemma-2b \\\n  --access_token your_hf_token\n```\n\n&nbsp;\n\n## Finetunes and Other Model Variants\n\nSometimes you want to download the weights of a finetune of one of the models listed above. To do this, you need to manually specify the `model_name` associated to the config to use. For example:\n\n```bash\nlitgpt download NousResearch/Hermes-2-Pro-Mistral-7B \\\n  --model_name Mistral-7B-v0.1\n```\n\n&nbsp;\n\n## Tips for GPU Memory Limitations\n\nThe `litgpt download` command will automatically convert the downloaded model checkpoint into a LitGPT-compatible format. 
In case this conversion fails due to GPU memory constraints, you can try to reduce the memory requirements by passing the `--dtype bf16-true` flag to convert all parameters into this smaller precision (however, note that most model weights are already in a bfloat16 format, so it may not have any effect):\n\n```bash\nlitgpt download <repo_id> \\\n  --dtype bf16-true\n```\n\n(If your GPU does not support the bfloat16 format, you can also try a regular 16-bit float format via `--dtype 16-true`.)\n\n&nbsp;\n\n## Converting Checkpoints Manually\n\nFor development purposes, for example, when adding or experimenting with new model configurations, it may be beneficial to split the weight download and model conversion into two separate steps.\n\nYou can do this by passing the `--convert_checkpoint false` option to the download script:\n\n```bash\nlitgpt download <repo_id> \\\n  --convert_checkpoint false\n```\n\nand then calling the `convert_to_litgpt` command:\n\n```bash\nlitgpt convert_to_litgpt <repo_id>\n```\n\n&nbsp;\n\n## Downloading Tokenizers Only\n\nIn some cases we don't need the model weights, for example, when we are pretraining a model from scratch instead of finetuning it. For cases like this, you can use the `--tokenizer_only` flag to only download a model's tokenizer, which can then be used in the pretraining scripts:\n\n```bash\nlitgpt download TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \\\n  --tokenizer_only true\n```\n\nand\n\n```bash\nlitgpt pretrain tiny-llama-1.1b \\\n  --data ... \\\n  --tokenizer_dir TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T/\n```\n"
  },
  {
    "path": "tutorials/evaluation.md",
    "content": "# LLM Evaluation\n\n&nbsp;\n\n## Using lm-evaluation-harness\n\nYou can evaluate LitGPT using [EleutherAI's lm-eval](https://github.com/EleutherAI/lm-evaluation-harness) framework with a large number of different evaluation tasks.\n\nYou need to install the `lm-eval` framework first:\n\n```bash\npip install lm_eval\n```\n\n&nbsp;\n\n### Evaluating LitGPT base models\n\nSuppose you downloaded a base model that we want to evaluate. Here, we use the `microsoft/phi-2` model:\n\n```bash\nlitgpt download microsoft/phi-2\n```\n\nThe download command above will save the model to the `checkpoints/microsoft/phi-2` directory, which we can\nspecify in the following evaluation command:\n\n\n```\nlitgpt evaluate microsoft/phi-2/ \\\n  --batch_size 4 \\\n  --tasks \"hellaswag,truthfulqa_mc2,mmlu\" \\\n  --out_dir evaluate_model/\n```\n\nThe resulting output is as follows:\n\n```\n...\n|---------------------------------------|-------|------|-----:|--------|-----:|---|-----:|\n...\n|truthfulqa_mc2                         |      2|none  |     0|acc     |0.4656|±  |0.0164|\n|hellaswag                              |      1|none  |     0|acc     |0.2569|±  |0.0044|\n|                                       |       |none  |     0|acc_norm|0.2632|±  |0.0044|\n\n|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|\n|------------------|-------|------|-----:|------|-----:|---|-----:|\n|mmlu              |N/A    |none  |     0|acc   |0.2434|±  |0.0036|\n| - humanities     |N/A    |none  |     0|acc   |0.2578|±  |0.0064|\n| - other          |N/A    |none  |     0|acc   |0.2401|±  |0.0077|\n| - social_sciences|N/A    |none  |     0|acc   |0.2301|±  |0.0076|\n| - stem           |N/A    |none  |     0|acc   |0.2382|±  |0.0076|\n```\n\n\nPlease note that the `litgpt evaluate` command run an internal model conversion.\nThis is only necessary the first time you want to evaluate a model, and it will skip the\nconversion steps if you run the `litgpt evaluate` on the same 
checkpoint directory again.\n\nIn some cases, for example, if you modified the model in the `checkpoint_dir` since the first `litgpt evaluate`\ncall, you need to use the `--force_conversion` flag to update the files used by `litgpt evaluate` accordingly:\n\n```\nlitgpt evaluate microsoft/phi-2/ \\\n  --batch_size 4 \\\n  --out_dir evaluate_model/ \\\n  --tasks \"hellaswag,truthfulqa_mc2,mmlu\" \\\n  --force_conversion true\n```\n\n&nbsp;\n\n> [!TIP]\n> Run `litgpt evaluate list` to print a list\n> of the supported tasks. To filter for a specific subset of tasks, e.g., MMLU, use `litgpt evaluate list | grep mmlu`.\n\n> [!TIP]\n> The evaluation may take a long time, and for testing purposes, you may want to reduce the number of tasks\n> or set a limit for the number of examples per task, for example, `--limit 10`.\n\n\n\n\n&nbsp;\n\n### Evaluating LoRA-finetuned LLMs\n\nNo further conversion is necessary when evaluating LoRA-finetuned models as the `finetune_lora` command already prepares the necessary merged model files:\n\n```bash\nlitgpt finetune_lora microsoft/phi-2 \\\n  --out_dir lora_model\n```\n\n&nbsp;\n\n```bash\nlitgpt evaluate lora_model/final \\\n  --batch_size 4 \\\n  --tasks \"hellaswag,truthfulqa_mc2,mmlu\" \\\n  --out_dir evaluate_model/\n```\n\n\n&nbsp;\n\n### Evaluating on a custom test set\n\nThere is currently no built-in function to evaluate models on custom test sets. 
However, this section describes a general approach that users can take to evaluate the responses of a model using another LLM.\n\nSuppose you have a test dataset with the following structure:\n\n```python\ntest_data = [\n    {\n        \"instruction\": \"Name the author of 'Pride and Prejudice'.\",\n        \"input\": \"\",\n        \"output\": \"Jane Austen.\"\n    },\n    {\n        \"instruction\": \"Pick out the adjective from the following list.\",\n        \"input\": \"run, tall, quickly\",\n        \"output\": \"The correct adjective from the list is 'tall.'\"\n    },\n]\n```\n\nFor simplicity, the dictionary above only contains two entries. In practice, it is recommended to use test datasets that contain at least 100 entries (ideally 1000 or more).\n\nIf your dataset is stored in JSON format, use the following code to load it:\n\n```python\nwith open(\"test_data.json\", \"r\") as file:\n    test_data = json.load(file)\n```\n\nNext, it is recommended to format the dataset according to a prompt style. For example, to use the `Alpaca` prompt style, use the following code:\n\n```python\nfrom litgpt.prompts import Alpaca\n\nprompt_style = Alpaca()\nprompt_style.apply(prompt=test_data[0][\"instruction\"], **test_data[0])\n```\n\nwhich returns\n\n```\n\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n\\n### Instruction:\\nName the author of 'Pride and Prejudice'.\\n\\n### Response:\\n\n```\n\nNext, load the LLM you want to evaluate. 
For this example, we use `phi-2`:\n\n```python\nfrom litgpt import LLM\n\nllm = LLM.load(\"microsoft/phi-2\")\n```\n\nThen, using the loaded model, we add the test set responses to the dataset:\n\n\n```python\nfrom tqdm import trange\n\n\nfor i in trange(len(test_data)):\n    response = llm.generate(prompt_style.apply(prompt=test_data[i][\"instruction\"], **test_data[i]))\n    test_data[i][\"response\"] = response\n```\n\nNext, we use a second LLM to calculate the response quality on a scale from 0 to 100. It is recommended to use the 70B Llama 3 instruction-fintuned model for this task, or the smaller 8B Llama 3 model, which is more resource-efficient:\n\n\n```python\ndel llm # delete previous `llm` to free up GPU memory\nscorer = LLM.load(\"meta-llama/Meta-Llama-3-8B-Instruct\", access_token=\"...\")\n```\n\nThen, based on this LLM, we calculate the response quality with the following function:\n\n```python\nfrom tqdm import tqdm\n\n\ndef generate_model_scores(data_dict, model, response_field=\"response\", target_field=\"output\"):\n    scores = []\n    for entry in tqdm(data_dict, desc=\"Scoring entries\"):\n        prompt = (\n            f\"Given the input `{format_input(entry)}` \"\n            f\"and correct output `{entry[target_field]}`, \"\n            f\"score the model response `{entry[response_field]}`\"\n            f\" on a scale from 0 to 100, where 100 is the best score. 
\"\n            f\"Respond with the integer number only.\"\n        )\n        score = model.generate(prompt, max_new_tokens=50)\n        try:\n            scores.append(int(score))\n        except ValueError:\n            continue\n\n    return scores\n```\n\n\n```python\nscores = generate_model_scores(test_data, model=scorer)\nprint(f\"\\n{llm}\")\nprint(f\"Number of scores: {len(scores)} of {len(test_data)}\")\nprint(f\"Average score: {sum(scores)/len(scores):.2f}\\n\")\n```\n\nThis will print out the average score on all test set entries:\n\n```\nScoring entries: 100%|██████████| 2/2 [00:00<00:00,  4.37it/s]\n\nNumber of scores: 2 of 2\nAverage score: 47.50\n```\n"
  },
  {
    "path": "tutorials/examples/ptl-trainer/README.md",
    "content": "## Minimal PyTorch Lightning Trainer Example\n\n\n\nThe script in this folder provides minimal examples showing how to train a LitGPT model using LitGPT's `GPT` class with the [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning) Trainer.\n\nYou can run the scripts as follows:\n\n&nbsp\n## Small 160M model:\n\n```bash\n# Download the Pythia model\nlitgpt download EleutherAI/pythia-160m\n\npython litgpt_ptl_small.py\n```\n\n&nbsp\n## Medium-sized 8B model:\n\n```bash\n# Download the Llama 3.1 model\nlitgpt download meta-llama/Meta-Llama-3.1-8B --access_token hf_...\n\npython litgpt_ptl_medium.py\n```\n"
  },
  {
    "path": "tutorials/examples/ptl-trainer/litgpt_ptl_medium.py",
    "content": "import lightning as L\nimport torch\n\nimport litgpt\nfrom litgpt.data import Alpaca2k\nfrom litgpt.lora import GPT, merge_lora_weights\n\n\nclass LitLLM(L.LightningModule):\n    def __init__(self):\n        super().__init__()\n        self.model = GPT.from_name(\n            name=\"Llama-3.1-8B\",\n            lora_r=32,\n            lora_alpha=16,\n            lora_dropout=0.05,\n            lora_key=False,\n            lora_value=True,\n        )\n        litgpt.lora.mark_only_lora_as_trainable(self.model)\n\n    def on_train_start(self):\n        state_dict = torch.load(\"checkpoints/meta-llama/Meta-Llama-3.1-8B/lit_model.pth\", mmap=True)\n        self.model.load_state_dict(state_dict, strict=False)\n\n    def training_step(self, batch):\n        input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n        logits = self.model(input_ids)\n        loss = litgpt.utils.chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:])\n        self.log(\"train_loss\", loss, prog_bar=True)\n        return loss\n\n    def configure_optimizers(self):\n        warmup_steps = 10\n        optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))\n        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)\n        return [optimizer], [scheduler]\n\n\nif __name__ == \"__main__\":\n    data = Alpaca2k()\n    tokenizer = litgpt.Tokenizer(\"checkpoints/meta-llama/Meta-Llama-3.1-8B\")\n    data.connect(tokenizer, batch_size=1, max_seq_length=512)\n\n    trainer = L.Trainer(\n        devices=1,\n        max_epochs=2,\n        accumulate_grad_batches=8,\n        precision=\"bf16-true\",\n    )\n    with trainer.init_module(empty_init=True):\n        model = LitLLM()\n\n    trainer.fit(model, data)\n\n    # Save final checkpoint\n    merge_lora_weights(model.model)\n    trainer.save_checkpoint(\"checkpoints/finetuned.ckpt\", weights_only=True)\n"
  },
  {
    "path": "tutorials/examples/ptl-trainer/litgpt_ptl_small.py",
    "content": "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n\nimport lightning as L\nimport torch\n\nfrom litgpt import LLM\nfrom litgpt.data import Alpaca2k\n\n\nclass LitLLM(L.LightningModule):\n    def __init__(self, checkpoint_dir, tokenizer_dir=None, trainer_ckpt_path=None):\n        super().__init__()\n\n        self.llm = LLM.load(checkpoint_dir, tokenizer_dir=tokenizer_dir, distribute=None)\n        self.trainer_ckpt_path = trainer_ckpt_path\n\n    def setup(self, stage):\n        self.llm.trainer_setup(trainer_ckpt=self.trainer_ckpt_path)\n\n    def training_step(self, batch):\n        logits, loss = self.llm(input_ids=batch[\"input_ids\"], target_ids=batch[\"labels\"])\n        self.log(\"train_loss\", loss, prog_bar=True)\n        return loss\n\n    def validation_step(self, batch):\n        logits, loss = self.llm(input_ids=batch[\"input_ids\"], target_ids=batch[\"labels\"])\n        self.log(\"validation_loss\", loss, prog_bar=True)\n        return loss\n\n    def configure_optimizers(self):\n        warmup_steps = 10\n        optimizer = torch.optim.AdamW(self.llm.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))\n        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)\n        return [optimizer], [scheduler]\n\n\nif __name__ == \"__main__\":\n    batch_size = 8\n    accumulate_grad_batches = 1\n\n    #########################################################\n    # Use case 1: Pretraining from random weights\n    #########################################################\n\n    llm = LLM.load(\"EleutherAI/pythia-160m\", tokenizer_dir=\"EleutherAI/pythia-160m\", init=\"random\")\n    llm.save(\"pythia-160m-random-weights\")\n    del llm\n\n    lit_model = LitLLM(checkpoint_dir=\"pythia-160m-random-weights\", tokenizer_dir=\"EleutherAI/pythia-160m\")\n    data = Alpaca2k()\n\n    data.connect(lit_model.llm.tokenizer, batch_size=batch_size, 
max_seq_length=512)\n\n    trainer = L.Trainer(\n        devices=1,\n        accelerator=\"cuda\",\n        max_epochs=1,\n        accumulate_grad_batches=accumulate_grad_batches,\n        precision=\"bf16-true\",\n    )\n    trainer.fit(lit_model, data)\n\n    lit_model.llm.model.to(lit_model.llm.preprocessor.device)\n    lit_model.llm.generate(\"hello world\")\n\n    del lit_model\n\n    #############################################################################\n    # Use case 2: Continued pretraining / finetuning from downloaded checkpoint\n    #############################################################################\n\n    lit_model = LitLLM(checkpoint_dir=\"EleutherAI/pythia-160m\")\n    data = Alpaca2k()\n\n    data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)\n\n    trainer = L.Trainer(\n        devices=1,\n        accelerator=\"cuda\",\n        max_epochs=1,\n        accumulate_grad_batches=accumulate_grad_batches,\n        precision=\"bf16-true\",\n    )\n    trainer.fit(lit_model, data)\n\n    lit_model.llm.model.to(lit_model.llm.preprocessor.device)\n    lit_model.llm.generate(\"hello world\")\n\n    del lit_model\n\n    #########################################################\n    # Use case 3: Resume training from Trainer checkpoint\n    #########################################################\n\n    import os\n\n    def find_latest_checkpoint(directory):\n        latest_checkpoint = None\n        latest_time = 0\n\n        for root, _, files in os.walk(directory):\n            for file in files:\n                if file.endswith(\".ckpt\"):\n                    file_path = os.path.join(root, file)\n                    file_time = os.path.getmtime(file_path)\n                    if file_time > latest_time:\n                        latest_time = file_time\n                        latest_checkpoint = file_path\n\n        return latest_checkpoint\n\n    lit_model = LitLLM(\n        
checkpoint_dir=\"EleutherAI/pythia-160m\", trainer_ckpt_path=find_latest_checkpoint(\"lightning_logs\")\n    )\n\n    data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)\n\n    trainer = L.Trainer(\n        devices=1,\n        accelerator=\"cuda\",\n        max_epochs=1,\n        accumulate_grad_batches=accumulate_grad_batches,\n        precision=\"bf16-true\",\n    )\n    trainer.fit(lit_model, data)\n\n    lit_model.llm.model.to(lit_model.llm.preprocessor.device)\n    lit_model.llm.generate(\"hello world\")\n\n    #################################################################\n    # Use case 4: Resume training after saving a checkpoint manually\n    #################################################################\n\n    lit_model.llm.save(\"finetuned_checkpoint\")\n    del lit_model\n    lit_model = LitLLM(checkpoint_dir=\"finetuned_checkpoint\")\n\n    data.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)\n\n    trainer = L.Trainer(\n        devices=1,\n        accelerator=\"cuda\",\n        max_epochs=1,\n        accumulate_grad_batches=accumulate_grad_batches,\n        precision=\"bf16-true\",\n    )\n    trainer.fit(lit_model, data)\n\n    lit_model.llm.model.to(lit_model.llm.preprocessor.device)\n    lit_model.llm.generate(\"hello world\")\n"
  },
  {
    "path": "tutorials/finetune.md",
    "content": "# Finetuning\n\nWe provide a simple finetuning commands (`litgpt finetune_*`) that instruction-finetune a pretrained model on datasets such as [Alpaca](https://github.com/tatsu-lab/stanford_alpaca), [Dolly](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm), and others. For more information on the supported instruction datasets and how to prepare your own custom datasets, please see the [tutorials/prepare_dataset](prepare_dataset.md) tutorials.\n\nLitGPT currently supports the following finetuning methods:\n\n```bash\nlitgpt finetune_full\nlitgpt finetune_lora\nlitgpt finetune_adapter\nlitgpt finetune_adapter_v2\n```\n\n&nbsp;\n> [!TIP]\n> To install all required dependencies before finetuning, first run `pip install \"litgpt[all]\"`.\n&nbsp;\n\n\nThe following section provides more details about these methods, including links for additional resources.\n\n\n&nbsp;\n## LitGPT finetuning commands\n\nThe section below provides additional information on the available and links to further resources.\n\n&nbsp;\n### Full finetuning\n\n```bash\nlitgpt finetune_full\n```\n\nThis method trains all model weight parameters and is the most memory-intensive finetuning technique in LitGPT.\n\n**More information and resources:**\n\n- the LitGPT [tutorials/finetune_full](finetune_full.md) tutorial\n\n\n&nbsp;\n### LoRA and QLoRA finetuning\n\n```bash\nlitgpt finetune_lora stabilityai/stablelm-base-alpha-3b\n```\n\nLoRA and QLoRA are parameter-efficient finetuning technique that only require updating a small number of parameters, which makes this a more memory-efficienty alternative to full finetuning.\n\n**More information and resources:**\n\n- the LitGPT [tutorials/finetune_lora](finetune_lora.md) tutorial\n- the LoRA paper by ([Hu et al. 
2021](https://arxiv.org/abs/2106.09685))\n- the conceptual tutorial [Parameter-Efficient LLM Finetuning With Low-Rank Adaptation (LoRA)](https://lightning.ai/pages/community/tutorial/lora-llm/)\n\n\n&nbsp;\n### Adapter finetuning\n\n```bash\nlitgpt finetune_adapter stabilityai/stablelm-base-alpha-3b\n```\n\nor\n\n```bash\nlitgpt finetune_adapter_v2 stabilityai/stablelm-base-alpha-3b\n```\n\nSimilar to LoRA, adapter finetuning is a parameter-efficient finetuning technique that only requires training a small subset of weight parameters, making this finetuning method more memory-efficient than full-parameter finetuning.\n\n**More information and resources:**\n\n- the LitGPT [tutorials/finetune_adapter](finetune_adapter.md) tutorial\n- the Llama-Adapter ([Zhang et al. 2023](https://arxiv.org/abs/2303.16199)) and Llama-Adapter v2 ([Gao et al. 2023](https://arxiv.org/abs/2304.15010)) papers that originally introduced these methods\n- the conceptual tutorial [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/)\n"
  },
  {
    "path": "tutorials/finetune_adapter.md",
    "content": "# Finetuning with Adapter\n\nAdapter, first introduced for the LLaMA model as [LLaMA-Adapter](https://arxiv.org/abs/2303.16199), is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in an LLM. In total, there are only ~500k parameters to update during finetuning in StableLM 3B, which significantly reduces the memory footprint and speeds up training.\n\nWe are able to demonstrate instruction-finetuning LitGPT StableLM 3B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3060 GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour.\n\nIf you are new to Adapter and are interested to learn more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful.\n\nLLaMA-Adapter v2 extends the original LLaMA-Adapter idea by adding trainable bias and scale parameters to each linear layer in the transformer. Furthermore, LLaMA-Adapter v2 makes the normalization layers trainable. Where the StableLM 3B model has 500k trainable parameters with GPT v1, GPT-Adapter v2 adds an additional 1.5 M trainable parameter for the bias and scale parameters and ~300k trainable parameters for the normalization layers. So, adapter v2 has ~2.3 M trainable parameters in total.\n\n## Preparation\n\nThe steps here only need to be done once:\n\n1. Follow the instructions in the [README](../README.md) to install the dependencies.\n2. 
Download and convert the weights following our [guide](download_model_weights.md).\n\nLitGPT provides common datasets for finetuning, such as Alpaca, LIMA, Dolly, and more.\nYou can optionally [prepare your own dataset](#tune-on-your-dataset).\nFor more information about dataset preparation, also see the [prepare_dataset.md](./prepare_dataset.md) tutorial.\n\n## Running the finetuning\n\n```bash\nlitgpt finetune_adapter stabilityai/stablelm-base-alpha-3b \\\n  --data Alpaca\n```\n\nor for Adapter V2\n\n```bash\nlitgpt finetune_adapter_v2 stabilityai/stablelm-base-alpha-3b \\\n  --data Alpaca\n```\n\nThe finetuning requires at least one GPU with ~12 GB memory.\nYou can speed up training by passing the `devices` argument to the script to utilize more GPUs if available.\nDepending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.\nTo fit Adapter V2 into 12 GB of memory, set `--train.micro_batch_size 2`.\n\nFor example, the following settings will let you finetune the model in under 1 hour:\n\n```bash\n--devices 4 --train.micro_batch_size 4\n```\n\nThis script will save checkpoints periodically to the `out_dir` directory. 
If you are finetuning different models or on your own dataset, you can specify an output directory with your preferred name:\n\n```bash\nlitgpt finetune_adapter stabilityai/stablelm-base-alpha-3b \\\n  --data Alpaca \\\n  --out_dir out/adapter/my-model-finetuned\n```\n\nor for Adapter V2\n\n```bash\nlitgpt finetune_adapter_v2 stabilityai/stablelm-base-alpha-3b \\\n  --data Alpaca \\\n  --out_dir out/adapter_v2/my-model-finetuned\n```\n\nIf your GPU does not support `bfloat16`, you can pass the `--precision 32-true` argument.\nFor instance, to fine-tune on MPS (the GPU on modern Macs), you can run\n\n```bash\nlitgpt finetune_adapter stabilityai/stablelm-base-alpha-3b \\\n  --data Alpaca \\\n  --out_dir out/adapter/my-model-finetuned \\\n  --precision 32-true\n```\n\nNote that `mps` as the accelerator will be picked up automatically by Fabric when running on a modern Mac.\n\n### Quantization\n\nOptionally, finetuning using quantization can be enabled via the `--quantize` flag, for example using the 4-bit NormalFloat data type:\n\n```bash\nlitgpt finetune_adapter stabilityai/stablelm-base-alpha-3b \\\n  --quantize \"bnb.nf4\"\n```\n\nor using `adapter_v2` with double-quantization:\n\n```bash\nlitgpt finetune_adapter_v2 stabilityai/stablelm-base-alpha-3b \\\n  --quantize \"bnb.nf4-dq\"\n```\n\nFor additional benchmarks and resource requirements, please see the [Resource Tables](resource-tables.md).\n\n## Test the model\n\nYou can test the finetuned model with your own instructions by running:\n\n```bash\nlitgpt generate_adapter stabilityai/stablelm-base-alpha-3b \\\n    --prompt \"Recommend a movie to watch on the weekend.\"\n```\n\nor for Adapter V2\n\n```bash\nlitgpt generate_adapter_v2 stabilityai/stablelm-base-alpha-3b \\\n    --prompt \"Recommend a movie to watch on the weekend.\"\n\n```\n\nOutput:\n\n```text\nA good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...\n```\n\nIf your GPU supports 
`bfloat16`, the script will automatically use it.\n\n## Tune on your dataset\n\nYou can easily train on your own instruction dataset saved in JSON format.\n\n1. Create a JSON file in which each row holds one instruction-response pair.\n   A row has an entry for 'instruction' and 'output', and optionally 'input'. Note that currently, the 'input' field is only used in the Alpaca chat template. If you are using the Alpaca template, 'input' can be the empty string if the instruction doesn't require a context.\n   Below is an example JSON file:\n\n    ```text\n    [\n        {\n            \"instruction\": \"Arrange the given numbers in ascending order.\",\n            \"input\": \"2, 4, 0, 8, 3\", // Optional: only used in Alpaca chat template\n            \"output\": \"0, 2, 3, 4, 8\"\n        },\n        ...\n    ]\n    ```\n\n2. Run `litgpt finetune_adapter` or `litgpt finetune_adapter_v2` by passing in the location of your data (and optionally other parameters):\n\n    ```bash\n    litgpt finetune_adapter tiiuae/falcon-7b \\\n        --data JSON \\\n        --data.json_path data/mydata.json \\\n        --out_dir data/mydata-finetuned\n    ```\n"
  },
  {
    "path": "tutorials/finetune_full.md",
    "content": "# Finetuning the whole model\n\nIf you are interested in parameter-efficient finetuning, check out [finetune_adapter.md](finetune_adapter.md). In contrast to parameter-efficient finetuning, this \"full\" approach finetunes all model parameters, which is substantially more expensive. It may only be recommended as a baseline for comparison studies.\n\n## Preparation\n\nThe steps here only need to be done once:\n\n1. Follow the instructions in the [README](../README.md) to install the dependencies.\n2. Download and convert the weights following our [guide](download_model_weights.md).\n\nLitGPT provides common datasets for finetuning, such as Alpaca, LIMA, Dolly, and more.\nYou can optionally [prepare your own dataset](#tune-on-your-dataset).\nFor more information about dataset preparation, also see the [prepare_dataset.md](./prepare_dataset.md) tutorial.\n\n## Running the finetuning\n\n```bash\nlitgpt finetune_full tiiuae/falcon-7b \\\n  --data Alpaca \\\n```\n\nFinetuning the falcon-7b model requires at least 8 GPUs with ~40 GB memory each.\n\nYou can speed up training by passing the `devices` argument to the script to utilize more GPUs if available.\nDepending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.\n\nThis script will save checkpoints periodically to the `out_dir` directory. 
If you are finetuning different models or on your own dataset, you can specify an output directory with your preferred name:\n\n```bash\nlitgpt finetune_full tiiuae/falcon-7b \\\n  --data Alpaca \\\n  --out_dir out/full/my-model-finetuned\n```\n\nIf your GPU does not support `bfloat16`, you can pass the `--precision 32-true` argument.\nFor instance, to fine-tune on MPS (the GPU on modern Macs), you can run\n\n```bash\nlitgpt finetune_full tiiuae/falcon-7b \\\n  --data Alpaca \\\n  --out_dir out/full/my-model-finetuned \\\n  --precision 32-true\n```\n\nNote that `mps` as the accelerator will be picked up automatically by Fabric when running on a modern Mac.\n\n## Test the model\n\nYou can test the finetuned model with your own instructions by running:\n\n```bash\nlitgpt generate tiiuae/falcon-7b \\\n    --prompt \"Recommend a movie to watch on the weekend.\" \\\n    --finetuned_path out/full/my-model-finetuned/lit_model_finetuned.pth\n```\n\nOutput:\n\n```text\nA good movie to watch on the weekend would be The Lion King, since it's a classic family film that everyone can enjoy...\n```\n\nIf your GPU supports `bfloat16`, the script will automatically use it.\n\n## Tune on your dataset\n\nYou can easily train on your own instruction dataset saved in JSON format.\n\n1. Create a JSON file in which each row holds one instruction-response pair.\n   A row has an entry for 'instruction' and 'output', and optionally 'input'. Note that currently, the 'input' field is only used in the Alpaca chat template. If you are using the Alpaca template, 'input' can be the empty string if the instruction doesn't require a context.\n   Below is an example json file:\n\n    ```text\n    [\n        {\n            \"instruction\": \"Arrange the given numbers in ascending order.\",\n            \"input\": \"2, 4, 0, 8, 3\", // Optional: only used in Alpaca chat template\n            \"output\": \"0, 2, 3, 4, 8\"\n        },\n        ...\n    ]\n    ```\n\n2. 
Run `litgpt finetune_full` by passing in the location of your data (and optionally other parameters):\n\n    ```bash\n    litgpt finetune_full tiiuae/falcon-7b \\\n        --data JSON \\\n        --data.json_path data/mydata.json \\\n        --out_dir data/mydata-finetuned\n    ```\n"
  },
  {
    "path": "tutorials/finetune_lora.md",
    "content": "# Finetuning with LoRA / QLoRA\n\n[Low-rank adaption (LoRA)](https://arxiv.org/abs/2106.09685) is a technique to approximate the update to the linear layers in a LLM with a low-rank matrix factorization. This significantly reduces the number of trainable parameters and speeds up training with little impact on the final performance of the model.\nWe demonstrate this method by instruction-finetuning LitGPT StableLM 3B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3090 (24GB) GPU** with CUDA 11.8.\n\n&nbsp;\n\n## Preparation\n\nThe steps here only need to be done once:\n\n1. Follow the instructions in the [README](../README.md) to install the dependencies.\n2. Download and convert the weights and save them in the `./checkpoints` folder.\n   Weights can be downloaded following the instructions in the [download_model_weights](download_model_weights.md) documentation:\n\nLitGPT provides common datasets for finetuning, such as Alpaca, LIMA, Dolly, and more.\nYou can optionally [prepare your own dataset](#tune-on-your-dataset).\nFor more information about dataset preparation, also see the [prepare_dataset.md](./prepare_dataset.md) tutorial.\n\n&nbsp;\n\n## Running the Finetuning\n\n```bash\nlitgpt finetune_lora stabilityai/stablelm-base-alpha-3b \\\n  --data Alpaca\n```\n\nThe finetuning requires at least one GPU with ~24 GB memory (RTX 3090).\n\nThis script will save checkpoints periodically to the folder `out/`.\n\n> [!NOTE]\n> LoRA can be applied to not only `query`, `key` or `value` matrices, but also to `projection`, `mlp` and classification `head`.\n> According to [QLoRA](https://arxiv.org/abs/2305.14314) paper (section 4): \"LoRA on all linear transformer block layers are required to match full finetuning performance\".\n> By default LoRA is applied only to the `query` and `value` matrices. 
In order to apply LoRA to other weight matrices - change the arguments to `litgpt/finetune/lora.py` accordingly.\n\nOptionally, finetuning using 4-bit quantization (as in QLoRA) can be enabled via the `--quantize` flag, for example using the 4-bit NormalFloat data type:\n\n```bash\nlitgpt finetune_lora stabilityai/stablelm-base-alpha-3b \\\n  --quantize \"bnb.nf4\"\n```\n\nand optionally with double-quantization:\n\n```bash\nlitgpt finetune_lora stabilityai/stablelm-base-alpha-3b \\\n  --quantize \"bnb.nf4-dq\"\n```\n\nThe table below lists a comparison with different settings on a StableLM 3B model finetuned with LoRA on Alpaca for 1,000 iterations using a microbatch size of 1:\n\n| Settings                                    | Training Memory | Training Time |  Inference Memory |\n|---------------------------------------------|-----------------|---------------|-------------------|\n| Default (bf16-mixed)                        | 26.92 GB        | 1.34 min      | 21.43 GB          |\n| --precision bf16-true                       | 9.69 GB         | 1.24 min      | 7.30 GB           |\n| --precision bf16-true --quantize bnb.nf4    | 6.35 GB         | 1.82 min      | 3.20 GB           |\n| --precision bf16-true --quantize bnb.nf4-dq | 6.19 GB         | 1.87 min      | 3.04 GB           |\n\nThe advantages of QLoRA-style quantization are more pronounced in larger models, such as Llama 2 7B. 
The table below summarizes the results for Llama 2 7B on Alpaca for 1,000 iterations using a microbatch size of 1:\n\n| Settings                                    | Training Memory  | Training Time | Inference Memory |\n|---------------------------------------------|------------------|---------------|------------------|\n| Default (bf16-mixed)                        | OutOfMemoryError | N/A           | 40.21 GB         |\n| --precision bf16-true                       | 21.30 GB         | 2.36 min      | 13.52 GB         |\n| --precision bf16-true --quantize bnb.nf4    | 14.14 GB         | 3.68 min      | 4.57 GB          |\n| --precision bf16-true --quantize bnb.nf4-dq | 13.84 GB         | 3.83 min      | 4.26 GB          |\n\nFor additional benchmarks and resource requirements, please see the [Resource Tables](resource-tables.md).\n\n&nbsp;\n\n## Test the Model\n\nYou can test the finetuned model with your own instructions by running:\n\n```bash\nlitgpt generate \"out/lora/final\" \\\n  --prompt \"Recommend a movie to watch on the weekend.\"\n```\n\nOutput:\n\n```text\nI would recommend the movie The Martian (2015). It is a sci-fi movie starring Matt Damon that follows the story of...\n```\n\nIf your GPU supports `bfloat16`, you can additionally pass `--precision \"bf16-true\"` to bring the memory consumption down to ~7.6 GB for StableLM-3B (versus ~15.2 GB for `--precision \"32-true\"`). In addition, you may use quantization methods; for example, `--precision \"bf16-true\" --quantize \"bnb.nf4\"` brings the memory consumption further down to ~4.4 GB for StableLM-3B.\n\n&nbsp;\n\n## Tune on Your Dataset\n\nYou can easily train on your own instruction dataset saved in JSON format.\n\n1. Create a JSON file in which each row holds one instruction-response pair.\n   A row has an entry for 'instruction' and 'output', and optionally 'input'. Note that currently, the 'input' field is only used in the Alpaca chat template. 
If you are using the Alpaca template, 'input' can be the empty string if the instruction doesn't require a context.\n   Below is an example json file:\n\n    ```text\n    [\n        {\n            \"instruction\": \"Arrange the given numbers in ascending order.\",\n            \"input\": \"2, 4, 0, 8, 3\", // Optional: only used in Alpaca chat template\n            \"output\": \"0, 2, 3, 4, 8\"\n        },\n        ...\n    ]\n    ```\n\n2. Run `litgpt finetune_lora` by passing in the location of your data (and optionally other parameters):\n\n    ```bash\n    litgpt finetune_lora checkpoints/stabilityai/stablelm-base-alpha-3b \\\n        --data JSON \\\n        --data.json_path data/mydata.json \\\n        --out_dir out_dir/mydata-finetuned\n    ```\n\n3. Test and use the finetuned model:\n\n    ```bash\n    litgpt chat out_dir/mydata-finetuned/final\n    ```\n\nor\n\n    ```bash\n    litgpt serve out_dir/mydata-finetuned/final\n    ```\n\n\n\n&nbsp;\n\n## Merging LoRA Weights (Optional)\n\nFinetuning a model with LoRA generates a `lit_model.pth.lora` file.\nThis file exclusively contains the LoRA weights, which are much smaller than the original model checkpoint to conserve storage space.\n\n> [!NOTE]\n> LitGPT will automatically merge the checkpoint for you if you use it in any of the inference commands, such as `litgpt generate` or `litgpt chat`.\n> Manual merging is only necessary if you want to use the checkpoint outside LitGPT.\n\nIf desired, there is the option to merge these LoRA weights manually into the original model's checkpoint, which creates a full `lit_model.pth` checkpoint.\nThe advantage of this merging process is to streamline inference operations, as it eliminates the need to dynamically incorporate the LoRA weights during runtime, which can improve inference speed.\n\nFor example, after finetuning produced a checkpoint folder `out/lora/step-002000`, merge it as follows:\n\n```bash\nlitgpt merge_lora \"out/lora/step-002000\"\n```\nThe command 
above creates a full `lit_model.pth` checkpoint file.\n"
  },
  {
    "path": "tutorials/full_finetune_example.py",
    "content": "\"\"\"\nThis script is meant to be the simplest possible starting point for full finetuning a GPT model using lightning fabric with code (not CLI).\n\n- no checkpoints\n- no out dir\n- no precision\n- no resume\n- no train/eval args (or any args in general)\n- no logger (only to terminal)\n- no grad accumulation\nand no other fancy stuff.\n\nTo add all the above stuff, you can slowly add them in yourself by looking at the code in litgpt/finetune/full.py or the docs for litgpt/fabric.\n\"\"\"\n\nimport os\n\nimport lightning as L\nimport torch\nimport torch.nn as nn\n\nfrom litgpt.data import Alpaca\nfrom litgpt.model import GPT, Config\nfrom litgpt.tokenizer import Tokenizer\nfrom litgpt.utils import num_parameters\n\n# training params/args\nSEED = 1337\nMODEL_NAME = \"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\"  # try also \"stabilityai/stablelm-base-alpha-3b\"!\nBATCH_SIZE = 4\nLR_WARMUP_STEPS = 100\nMAX_STEPS = 601\n\n\ndef validate(model, val_dataloader):\n    model.eval()\n    loss = 0\n    with torch.no_grad():\n        for batch in val_dataloader:\n            input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n            logits = model(input_ids)\n            logits = logits.reshape(-1, logits.size(-1))\n            targets = targets.reshape(-1)\n            loss += nn.functional.cross_entropy(logits[..., :-1, :], targets[..., 1:])\n    fabric.print(f\"Validation loss: {loss / len(val_dataloader)}\")\n\n\ndef train(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader):\n    for iter_num, batch in enumerate(train_dataloader):\n        input_ids, targets = batch[\"input_ids\"], batch[\"labels\"]\n\n        # get model preds (logits)\n        logits = model(input_ids)\n        logits = logits.reshape(-1, logits.size(-1))\n\n        # get loss\n        targets = targets.reshape(-1)\n        loss = nn.functional.cross_entropy(logits[..., :-1, :], targets[..., 1:])\n\n        # update weights\n        
fabric.backward(loss)\n        optimizer.step()\n        optimizer.zero_grad()\n        scheduler.step()\n\n        # print train loss every 100 steps\n        if iter_num % 100 == 0 or iter_num == 0:\n            fabric.print(f\"Train iter {iter_num} -  loss {loss}\")\n\n        # validate every 300 steps\n        if iter_num % 300 == 0 or iter_num == 0:\n            validate(model, val_dataloader)\n            model.train()\n        iter_num += 1\n\n        if iter_num >= MAX_STEPS:\n            break\n\n\ndef main(fabric):\n    fabric.seed_everything(SEED)\n\n    # setup data, make tokenizer and make dataloaders\n    data = Alpaca()\n    tokenizer = Tokenizer(checkpoint_dir=f\"checkpoints/{MODEL_NAME}\")\n    data.connect(tokenizer=tokenizer, batch_size=BATCH_SIZE, max_seq_length=1024)\n    data.setup()\n    train_dataloader = data.train_dataloader()\n    val_dataloader = data.val_dataloader()\n    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)\n\n    # print how many steps in an epoch\n    fabric.print(f\"Steps in an epoch: {len(train_dataloader)}\")\n\n    # setup model\n    config = Config.from_file(f\"checkpoints/{MODEL_NAME}/model_config.yaml\")\n    model = GPT(config)\n    fabric.print(f\"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}\")\n    model = fabric.setup(model)\n\n    # setup optimizer\n    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.02, betas=(0.9, 0.95))\n    optimizer = fabric.setup_optimizers(optimizer)\n\n    # setup lr scheduler\n    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / LR_WARMUP_STEPS)\n    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(MAX_STEPS - LR_WARMUP_STEPS))\n    scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[LR_WARMUP_STEPS])\n\n    # Start training!!!\n    train(fabric, model, optimizer, scheduler, 
train_dataloader, val_dataloader)\n\n\nif __name__ == \"__main__\":\n    # check that the model exists (downloaded to ./checkpoints/)\n    if not os.path.exists(f\"checkpoints/{MODEL_NAME}\"):\n        print(f\"Model {MODEL_NAME} not found. Please download it using `litgpt download --repo {MODEL_NAME}`\")\n        exit()\n\n    ### Setup and launch\n    fabric = L.Fabric(devices=\"auto\", strategy=\"auto\")\n    fabric.launch(main)\n"
  },
  {
    "path": "tutorials/inference.md",
    "content": "# Inference\n\nWe demonstrate how to run inference (next token prediction) with the GPT base model in the [`litgpt generate`](../litgpt/generate/base.py) command:\n\n```bash\nlitgpt generate stabilityai/stablelm-base-alpha-3b \\\n  --prompt \"Hello, my name is\"\n```\n\nOutput:\n\n```text\nHello, my name is Levi Durrer, I'm an Austrian journalist - Chairman of the Press Blair Party, with 37 years in the Press Blair International, and two years in the Spectre of Austerity for the other. I'm crossing my fingers that you will feel\n```\n\nThe script assumes you have downloaded and converted the weights as described [here](download_model_weights.md).\n\nThis will run the 3B pre-trained model and require ~7 GB of GPU memory using the `bfloat16` datatype.\n\n## Run interactively\n\nYou can also chat with the model interactively:\n\n```bash\nlitgpt chat stabilityai/stablelm-tuned-alpha-3b\n```\n\nThis script can work with any checkpoint. For the best chat-like experience, we recommend using it with a checkpoints\nfine-tuned for chatting such as `stabilityai/stablelm-tuned-alpha-3b` or `togethercomputer/RedPajama-INCITE-Chat-3B-v1`.\n\n> [!TIP]\n> Use `--multiline true` to work with inputs that span multiple lines.\n\n\n## Run a large model on one smaller device\n\nCheck out our [quantization tutorial](quantize.md).\n\n## Run a large model on multiple smaller devices\n\nWe offer two scripts to leverage multiple devices for inference.\n\n### [`litgpt generate_sequentially`](../litgpt/generate/sequentially.py)\n\nAllows you to run models that wouldn't fit in a single card by partitioning the transformer blocks across all your devices and running them sequentially.\n\nFor instance, `meta-llama/Llama-2-70b-chat-hf` would require ~140 GB of GPU memory to load on a single device, plus the memory for activations.\nWith 80 transformer layers, we could partition them across 8, 5, 4, or 2 devices.\n\n```shell\nlitgpt generate_sequentially 
meta-llama/Llama-2-70b-chat-hf \\\n  --max_new_tokens 256 \\\n  --num_samples 2\n```\n\nUsing A100 40GB GPUs, we need to use at least 4. You can control the number of devices by setting the `CUDA_VISIBLE_DEVICES=` environment variable.\n\n| Devices | Max GPU RAM | Token/sec |\n|---------|-------------|-----------|\n| 2       | OOM         | -         |\n| 4       | 35.64 GB    | 7.55      |\n| 5       | 28.72 GB    | 7.49      |\n| 8       | 18.35 GB    | 7.47      |\n\nNote that the memory usage will also depend on the `max_new_tokens` value used.\n\nThe script also supports quantization, using 4-bit precision, we can now use 2 GPUs\n\n```shell\nlitgpt generate_sequentially meta-llama/Llama-2-70b-chat-hf \\\n  --max_new_tokens 256 \\\n  --num_samples 2 \\\n  --quantize bnb.nf4-dq\n```\n\n| Devices | Max GPU RAM | Token/sec |\n|---------|-------------|-----------|\n| 2       | 20.00 GB    | 8.63      |\n| 4       | 10.80 GB    | 8.23      |\n| 5       | 8.96 GB     | 8.10      |\n| 8       | 6.23 GB     | 8.18      |\n\nSmaller devices can also be used to run inference with this technique.\n\n### [`litgpt generate_tp`](../litgpt/generate/tp.py)\n\nUses tensor parallelism (TP) to run models that wouldn't fit in a single card by sharding the MLP and Attention QKV linear layers across all your devices.\n\nFor instance, `meta-llama/Llama-2-70b-chat-hf` would require ~140 GB of GPU memory to load on a single device, plus the memory for activations.\nThe requirement is that the intermediate size (for the MLP) and the QKV size (for attention) is divisible by the number of devices.\nWith an intermediate size of 28672, we can use 2, 4, 7, or 8 devices. 
With a QKV size of 10240 we can use 2, 4, 5, or 8 devices.\nSince the script is configured to shard both, the intersection is used: we can only use 2, 4, or 8 devices.\n\n```shell\nlitgpt generate_tp meta-llama/Llama-2-70b-chat-hf \\\n  --max_new_tokens 256 \\\n  --num_samples 2\n```\n\nUsing A100 40GB GPUs, we need to use at least 4. You can control the number of devices by setting the `CUDA_VISIBLE_DEVICES=` environment variable.\n\n| Devices | Max GPU RAM | Token/sec |\n|---------|-------------|-----------|\n| 2       | OOM         | -         |\n| 4       | 35.46 GB    | 9.33      |\n| 8       | 18.19 GB    | 8.61      |\n\nNote that the memory usage will also depend on the `max_new_tokens` value used.\n\nThe script also supports quantization, using 4-bit precision, we can now use 2 GPUs\n\n```shell\nlitgpt generate_tp meta-llama/Llama-2-70b-chat-hf \\\n  --max_new_tokens 256 \\\n  --num_samples 2 \\\n  --quantize bnb.nf4-dq\n```\n\n| Devices | Max GPU RAM | Token/sec |\n|---------|-------------|-----------|\n| 2       | 19.79 GB    | 6.72      |\n| 4       | 10.73 GB    | 6.48      |\n| 8       | 6.15 GB     | 6.20      |\n\nSmaller devices can also be used to run inference with this technique.\n"
  },
  {
    "path": "tutorials/mkdocs.yml",
    "content": "site_name: LitGPT Tutorials\n\nplugins:\n  - pagetree\n\ntheme:\n  name: material\n"
  },
  {
    "path": "tutorials/oom.md",
    "content": "## Dealing with out-of-memory (OOM) errors\n\nIf you got this error while running a script\n\n```bash\nOutOfMemoryError: CUDA out of memory. Tried to allocate 2.22 GiB. GPU 0 has a total capacity of 79.15 GiB of which 228.38 MiB is free. Including non-PyTorch memory, this process\nhas 78.93 GiB memory in use. Of the allocated memory 76.28 GiB is allocated by PyTorch, and 2.14 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory\nis large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n```\n\nit means that your GPU memory size wasn't big enough for the model and script configuration.\n\nHere's a few things you can try:\n\n### Reduce the micro batch size\n\nAdjust the `--train.micro_batch_size` argument in the fine-tuning and pretraining scripts. This variable determines the number of samples loaded per iteration.\n\nA smaller value will simply load fewer samples simultaneously. The minimum value is 1.\n\nExperiment with different micro batch sizes to find a balance between memory consumption and computational efficiency. Smaller micro batch sizes consume less memory but may result in slower training convergence. Conversely, larger micro batch sizes require more memory but can accelerate training speed.\n\n### Reduce the model's context length\n\nThe context length (`block_size` in the code) plays a significant role in running models with attention.\n\n* The pretraining scripts are configured to use the full context length of the model to train.\n* The finetuning scripts are configured to use the longest sample length of the training data to avoid allocating unnecessary memory (`--train.max_seq_length` argument).\n  If that's longer than the model's context length, an error is raised. If you try to run a batch that is longer than this, an error is raised.\n\nHowever, your hardware may not support such large context lengths. 
Here's what you can do:\n\n* For the pretraining scripts, you can simply reduce the `Config(block_size=...)` value.\n* For the finetuning scripts, you can trim the length of the samples in your dataset.\n  All the finetuning scripts expose a `--train.max_seq_length=...` argument. This might also be useful in cases where\n  sample lengths are highly unbalanced, as the presence of a single very long sample would incur a larger memory usage for all other\n  shorter samples. For example, the median length of the samples in Alpaca is 110 tokens. Truncating the Alpaca dataset to 256 max tokens reduces the memory requirements of a Falcon 7B model from 23.52 GB to 15.73 GB. For more information about the dataset truncation, please see the *Truncating datasets* section in the [prepare_dataset.md](prepare_dataset.md) tutorial.\n\nKeep in mind that reducing the context length will affect the modelling performance on text sequences longer than the limit.\n\n### Use lower precision\n\nOur scripts expose the `--precision` argument, which directly impacts the memory usage.\n\nUsing true lower precision (`16-true`, `bf16-true`) reduces the memory usage by half compared to `32-true`; however,\nthe model might start producing NaNs due to the limited range of representable values.\n\nMixed precision training (`16-mixed`, `bf16-mixed`) provides better stability but offers limited memory reduction.\n\n### Do sharding across multiple GPUs\n\nFor exceptionally large models, the aforementioned techniques might still not suffice. If you have multiple GPUs available,\nyou can trade off memory for speed by changing the `--devices 1` argument in the scripts. 
Enabling this option enables a parallelism technique (FSDP), sharding the memory across different GPUs.\n\nThe default configuration already uses activation checkpointing, but you can enable CPU offloading by changing the `cpu_offload=False` argument in the scripts.\n\n### Try a different optimizer\n\nOur scripts use the [`AdamW` optimizer](https://pytorch.org/docs/main/generated/torch.optim.AdamW.html).\nIt maintains 2 states for each trainable parameter of the model, meaning that the optimizer memory is double compared to\nan optimizer like [`SGD`](https://pytorch.org/docs/main/generated/torch.optim.SGD.html).\n\nYou can try replacing it with your optimizer of choice that is lighter in memory requirements. Keep in mind that different optimizers have distinct optimization behaviors, so it's essential to assess their impact on the training process and model performance.\nAn example would be the recently published [Sophia](https://arxiv.org/abs/2305.14342) or [Lion](https://arxiv.org/abs/2302.06675) optimizers.\n\nThis suggestion is particularly relevant for pretraining, as the trainable parameters in the model represent a small\nsubset of the total in the fine-tuning scripts.\n"
  },
  {
    "path": "tutorials/prepare_dataset.md",
    "content": "# Preparing Datasets\n\nBelow is a table of all datasets that are currently supported in LitGPT:\n\n| Name         | Task        | Size                | Reference Repo                                                                       | Paper / Blog                                                                                                              | Data License                                                                                                                                                                                                     |\n|--------------|-------------|---------------------|--------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| Alpaca       | Finetuning  | 51,759 samples      | [URL](https://github.com/tatsu-lab/stanford_alpaca)                                  | [URL](https://crfm.stanford.edu/2023/03/13/alpaca.html)                                                                   | Attribution-NonCommercial 4.0 International, [URL](https://crfm.stanford.edu/2023/03/13/alpaca.html)                                                                                                             |\n| Alpaca-2k    | Finetuning  | 2000 samples        | [URL](https://huggingface.co/datasets/mhenrichsen/alpaca_2k_test)                    | See Alpaca above                                                                                                          | See Alpaca Above                                                                                                                                                                                
                 |\n| Alpaca-GPT4  | Finetuning  | 52,002 samples      | [URL](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)                    | [URL](https://arxiv.org/abs/2304.03277)                                                                                   | Attribution-NonCommercial 4.0 International, [URL](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/blob/main/DATA_LICENSE)                                                                           |\n| Alpaca Libre | Finetuning  | 55,370 samples      | [URL](https://github.com/mobarski/alpaca-libre)                                      | -                                                                                                                         | CC0/MIT,  [URL](https://github.com/mobarski/alpaca-libre)                                                                                                                                                        |\n| Deita        | Finetuning  | 9,500 samples       | [URL](https://huggingface.co/datasets/HuggingFaceH4/deita-10k-v0-sft/tree/main/data) | [URL](https://arxiv.org/abs/2312.15685)                                                                                   | MIT [URL](https://huggingface.co/datasets/hkust-nlp/deita-10k-v0/blob/main/README.md)                                                                                                                            |\n| Dolly        | Finetuning  | 15,011 samples      | [URL](https://github.com/databrickslabs/dolly/tree/master/data)                      | [URL](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm)              | CC-BY-SA, [URL](https://github.com/databrickslabs/dolly#model-overview)                                                                                                                                          |\n| FLAN         | Finetuning  | 1,753,240 samples   | 
[URL](https://huggingface.co/datasets/Muennighoff/flan)                              | [URL](https://blog.research.google/2023/02/the-flan-collection-advancing-open.html)                                       | Subset dependent                                                                                                                                                                                                 |\n| LongForm     | Finetuning  | 23,652 samples      | [URL](https://github.com/akoksal/LongForm)                                           | [URL](https://arxiv.org/abs/2304.08460)                                                                                   | No information provided and subset-dependent, [URL](https://github.com/akoksal/LongForm)                                                                                                                         |\n| LIMA         | Finetuning  | 1,084 samples       | [URL](https://huggingface.co/datasets/GAIR/lima)                                     | [URL](https://arxiv.org/abs/2305.11206)                                                                                   | \"If the source data of LIMA has a stricter license than CC BY-NC-SA, the LIMA dataset follows the same. 
Otherwise, it follows the CC BY-NC-SA license\", [URL](https://huggingface.co/datasets/GAIR/lima#license) |\n| OpenWeb Text | Pretraining | 8,013,769 documents | [URL](https://github.com/jcpeterson/openwebtext)                                     | [URL](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) | Unspecified                                                                                                                                                                                                      |\n| TinyLlama    | Pretraining | 1 T tokens          | [URL](https://github.com/jzhang38/TinyLlama)                                         | [URL](https://arxiv.org/abs/2401.02385)                                                                                   |                                                                                                                                                                                                                  |\n| TinyStories  | Pretraining | 4,967,871 stories   | [URL](https://huggingface.co/datasets/roneneldan/TinyStories)                        | [URL](https://arxiv.org/abs/2305.07759)                                                                                   | CDLA-Sharing-1.0                                                                                                                                                                                                 |\n\n&nbsp;\n\n## Preparation\n\nThe steps here only need to be done once before preparing the finetuning datasets in the following subsections:\n\n1. Follow the instructions in the [README](../README.md) to install the dependencies.\n2. 
Download and convert the weights following our [guide](download_model_weights.md).\n\nFor the following examples, we will focus on finetuning with the `litgpt finetune_lora` command and use a Falcon 7B model.\nHowever, the same steps apply to all other models and finetuning scripts.\nPlease read the [tutorials/finetune_*.md](.) documents for more information about finetuning models.\n\n&nbsp;\n\n> [!IMPORTANT]\n> By default, the maximum sequence length is obtained from the model configuration file. In case you run into out-of-memory errors, especially in the cases of LIMA and Dolly,\n> you can try to lower the context length by setting the `--train.max_seq_length` parameter, for example, `litgpt finetune_lora --train.max_seq_length 256`. For more information on truncating datasets, see the *Truncating datasets* section in the Alpaca section near the top of this article.\n\n&nbsp;\n\n### Alpaca\n\nThe Alpaca dataset consists of 52,000 instructions and demonstrations produced by OpenAI's text-davinci-003 engine. This data is used in instruction-tuning, helping improve the performance of language models to follow instructions.\n\nIn its development, the creators leveraged the data generation methodology from the [Self-Instruct framework](https://github.com/yizhongw/self-instruct).\n\nThe original [Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html) dataset can be used as follows:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data Alpaca\n```\n\n&nbsp;\n\n> [!TIP]\n> Use `litgpt finetune --data.help Alpaca` to list additional dataset-specific command line options.\n\n&nbsp;\n\n#### Truncating datasets\n\nBy default, the finetuning scripts will determine the size of the longest tokenized sample in the dataset to determine the block size. However, if you are willing to truncate a few examples in the training set, you can reduce the computational resource requirements significantly. 
For instance, you can set a sequence length threshold via `--train.max_seq_length`. We can determine an appropriate maximum sequence length by considering the distribution of the data sample lengths shown in the histogram below.\n\n<img src=\"images/prepare_dataset/alpaca.jpg\" width=400px>\n\nIn this case, a cut-off of 256 may be a reasonable choice:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data Alpaca \\\n  --train.max_seq_length 256\n```\n\nFor comparison, the Falcon 7B model requires 23.52 GB of memory for the original Alpaca dataset and 15.73 GB of memory for the truncated Alpaca dataset when finetuning with LoRA using a micro batch size of 1 and bfloat16 precision.\n\n&nbsp;\n\n### Alpaca-2k\n\n[Alpaca-2k](https://huggingface.co/datasets/mhenrichsen/alpaca_2k_test) is a smaller, 2000-sample subset of Alpaca described above.\n\n```bash\nlitgpt finetune_lora \"tiiuae/falcon-7b\" \\\n  --data Alpaca2k\n```\n\n&nbsp;\n\n> [!TIP]\n> Use `litgpt finetune --data.help Alpaca2k` to list additional dataset-specific command line options.\n\n&nbsp;\n\nThe Alpaca-2k dataset distribution is shown below.\n\n<img src=\"images/prepare_dataset/alpaca-2k.jpg\" width=400px>\n\n\n### Alpaca-GPT4\n\nThe Alpaca-GPT4 dataset was built by using the prompts of the original Alpaca dataset and generating the responses via GPT-4. 
The\ndataset consists of 52,000 instructions and responses.\n\nThe original [Alpaca-GPT4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) dataset can be used as follows:\n\n```bash\nlitgpt finetune_lora \"tiiuae/falcon-7b\" \\\n  --data AlpacaGPT4\n```\n\n&nbsp;\n\n> [!TIP]\n> Use `litgpt finetune --data.help AlpacaGPT4` to list additional dataset-specific command line options.\n\n&nbsp;\n\nThe Alpaca-GPT4 dataset distribution is shown below.\n\n<img src=\"images/prepare_dataset/alpacagpt4.jpg\" width=400px>\n\n&nbsp;\n\n### Alpaca Libre\n\n[Alpaca Libre](https://github.com/mobarski/alpaca-libre) is a reimplementation of, or alternative to, Alpaca using the same formatting.\n\nTo use Alpaca Libre instead of the original Alpaca dataset, use the following command:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data Alpaca \\\n  --data.file_url \"https://raw.githubusercontent.com/mobarski/alpaca-libre/main/data/output/alpaca_libre_ok_tasks_v4.json\" \\\n  --data.file_name \"alpaca_libre_data_cleaned_archive.json\"\n```\n\n&nbsp;\n\n> [!TIP]\n> Use `litgpt finetune --data.help Alpaca` to list additional dataset-specific command line options.\n\n&nbsp;\n\nThe Alpaca Libre dataset distribution is shown below.\n\n<img src=\"images/prepare_dataset/alpaca_libre.jpg\" width=400px>\n\nYou may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information.) 
For this dataset, a cut-off of 256 may be a good choice:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data Alpaca \\\n  --data.file_url \"https://raw.githubusercontent.com/mobarski/alpaca-libre/main/data/output/alpaca_libre_ok_tasks_v4.json\" \\\n  --data.file_name \"alpaca_libre_data_cleaned_archive.json\" \\\n  --train.max_seq_length 256\n```\n\n&nbsp;\n\n### Deita\n\nThe Deita dataset (short for Data-Efficient Instruction Tuning for Alignment) is a collection of 9500 prompts and responses, as described in the [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685) paper.\nUsing Falcon 7b as an example, we can use the dataset as follows:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data Deita\n```\n\n&nbsp;\n\n\n> [!TIP]\n> Use `litgpt finetune --data.help Deita` to list additional dataset-specific command line options.\n\n&nbsp;\n\nDeita contains multiturn conversations. By default, only the first instruction-response pairs from\neach of these multiturn conversations are included. If you want to override this behavior and include the follow-up instructions\nand responses, set `--data.include_multiturn_conversations True`, which will include all multiturn conversations as regular\nprompt-response pairs. Including the multiturn answers, the dataset consists of 209,272 prompt-response pairs.\n\nThe Deita dataset distribution without including multiturn conversations is shown below.\n\n<img src=\"images/prepare_dataset/deita.jpg\" width=400px>\n\nThe Deita dataset distribution including multiturn conversations is depicted in the following histogram.\n\n<img src=\"images/prepare_dataset/deita-multiturn.jpg\" width=400px>\n\nYou may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information.) 
For this dataset, a cut-off of 512 may be a good choice:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data Deita \\\n  --train.max_seq_length 512\n```\n\n&nbsp;\n\n### Dolly\n\nThe Dolly dataset is a publicly available collection of 15k instruction-following entries created by Databricks. It spans multiple behavioral domains, as described in the [InstructGPT paper](https://arxiv.org/abs/2203.02155). These include areas like brainstorming, classification, closed QA, content creation, information retrieval, open QA, and summary generation.\n\nThe usage is similar to the Alpaca dataset described above. Using Falcon 7b as an example, we can use the dataset as follows:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data Dolly\n```\n\n&nbsp;\n\n> [!TIP]\n> Use `litgpt finetune --data.help Dolly` to list additional dataset-specific command line options.\n\n&nbsp;\n\nThe Dolly dataset distribution is shown below.\n\n<img src=\"images/prepare_dataset/dolly.jpg\" width=400px>\n\nYou may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information.) For this dataset, a cut-off of 512 may be a good choice:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data Dolly \\\n  --train.max_seq_length 512\n```\n\n&nbsp;\n\n### LongForm\n\nLongForm is a semi-synthetic dataset based on raw text corpora for which the instructions were generated via an LLM. For more details about the instruction-generation process, please refer to the [LongForm research paper](https://arxiv.org/abs/2304.08460) by Köksal et al. According to the research paper, a Llama 7B model trained on LongForm achieves substantially better performance than the same Llama model trained on the 2x larger Alpaca dataset.\n\nLongForm consists of 23,652 training samples, 2,042 validation samples, and 2,045 test samples. 
(In LitGPT, the validation samples are currently not used.)\n\nThe more detailed dataset composition is as follows based on a table taken from the [dataset repository](https://github.com/akoksal/LongForm):\n\n| **Type**               | **Source**     | **Number of Examples** |\n|------------------------|----------------|------------------------|\n| **Corpora**            | C4             | 10,000                 |\n|                        | Wikipedia      | 5,000                  |\n| **Structured Corpora** | Stack Exchange | 4,380                  |\n|                        | WikiHow        | 2,500                  |\n| **Tasks**              | NIv2           | 3,684                  |\n|                        | Big Bench      | 600                    |\n|                        | BEA-GEC        | 1,203                  |\n|                        | Enron          | 372                    |\n| **Total**              |                | 27,739                 |\n|                        |                |                        |\n| **Train**              |                | 23,652                 |\n| **Validation**         |                | 2,042                  |\n| **Test**               |                | 2,045                  |\n\nLicense information is not provided but would depend on the individual subsets listed above.\n\nThe LongForm dataset distribution is shown below.\n\n<img src=\"images/prepare_dataset/longform.jpg\" width=400px>\n\nYou may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information.) 
For this dataset, a cut-off of 1500 may be a good choice:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data LongForm \\\n  --train.max_seq_length 1500\n```\n\n&nbsp;\n\n> [!TIP]\n> Use `litgpt finetune --data.help LongForm` to list additional dataset-specific command line options.\n\n&nbsp;\n\n&nbsp;\n\n### LIMA\n\nThe LIMA dataset is a collection of 1,000 carefully curated prompts and responses, as described in the [LIMA: Less Is More for Alignment](https://arxiv.org/abs/2305.11206) paper. The dataset is sourced from three community Q&A websites: Stack Exchange, wikiHow, and the Pushshift Reddit Dataset. In addition, it also contains prompts and answers written and collected by the authors of the LIMA paper.\n\nThe usage is similar to the Dolly dataset described above except that it requires a Hugging Face access token that you need to copy & paste from your Hugging Face account. Using Falcon 7b as an example, we can use the dataset as follows:\n\n```bash\nexport HF_TOKEN=\"insert_your_huggingface_token_here\"\n\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data LIMA\n```\n\n&nbsp;\n\n> [!TIP]\n> Use `litgpt finetune --data.help LIMA` to list additional dataset-specific command line options.\n\n&nbsp;\n\nLIMA contains a handful of multiturn conversations. By default, only the first instruction-response pairs from\neach of these multiturn conversations are included. If you want to override this behavior and include the follow-up instructions\nand responses, set `--data.include_multiturn_conversations True`.\n\nThe LIMA dataset distribution is shown below.\n\n<img src=\"images/prepare_dataset/lima.jpg\" width=400px>\n\nYou may want to consider truncating the dataset (see the *Truncating datasets* discussion in the Alpaca section for more information.) 
For this dataset, a cut-off of 512 may be a good choice:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data LIMA \\\n  --train.max_seq_length 512\n```\n\n&nbsp;\n\n### FLAN\n\nFLAN is a collection of several dataset subsets by Google. In particular, the provided script in LitGPT loads the subsets from\n[here](https://huggingface.co/datasets/Muennighoff/flan).\n\nBy default, all subsets (1,386,050 samples) and validation sets (367,190 samples) are combined into a single dataset:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data FLAN\n```\n\nHowever, you can also select individual subsets via comma-separated strings as follows:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data FLAN \\\n  --data.subsets \"aeslc_10templates,ag_news_subset_10templates,anli_r1_10templates\"\n```\n\n&nbsp;\n\n> [!TIP]\n> Use `litgpt finetune --data.help FLAN` to list additional dataset-specific command line options.\n\n&nbsp;\n\nYou can find a list of all 66 supported subsets [here](https://huggingface.co/datasets/Muennighoff/flan).\n\n&nbsp;\n\n## Preparing Custom Datasets for Instruction Finetuning\n\nThe models in LitGPT expect datasets for instruction finetuning in the following format:\n\n```text\n[\n    {\n        \"instruction\": \"Write a limerick about a\n                        pelican.\",\n        \"input\": \"\",\n        \"output\": \"There once was a pelican so fine,\n                   \\nHis beak was as colorful as\n                   sunshine,\\nHe would fish all day,\\nIn\n                   a very unique way,\\nThis pelican was\n                   truly divine!\\n\\n\\n\"\n    },\n    {\n        \"instruction\": \"Identify the odd one out from\n                        the group.\",\n        \"input\": \"Carrot, Apple, Banana, Grape\",\n        \"output\": \"Carrot\\n\\n\"\n    },\n]\n```\n\n(Note that depending on the task, the `\"input\"` text can be an empty string, as shown above.)\n\nYou can use your own data in LitGPT by either 
reading in a JSON file in the format shown above or by implementing a custom `DataModule`.\n\n&nbsp;\n\n### Preparing Custom Datasets From a JSON File\n\nYou can prepare a custom dataset using a JSON file where each row is a dictionary with these keys:\n\n- `instruction`: A string describing the task.\n- `input`: A string holding a special input value for the instruction. This applies to some samples, and in others, this is empty (empty string).\n- `output`: The expected response.\n\n> If any of the fields are missing, then the script will fail to read the dataset.\n\nThen simply run any of the finetuning scripts with this input:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data JSON \\\n  --data.json_path path/to/your/data.json \\\n  --data.val_split_fraction 0.1\n```\n\nYou can also customize how the dataset is read by using these additional parameters:\n\n- `val_split_fraction`: The fraction of the data to split off for validation. Defaults to `0.1`.\n\n- `seed`: The seed value to reproduce the same random splits for train and test data.\n\n- `mask_inputs`: Whether to mask the prompt section from the label (with `ignore_index`).\n\n- `ignore_index`: The index to use for labels that should be ignored. 
Defaults to `-100` (used when `mask_inputs` is `True`).\n\nTo use the settings described above, you can add the respective command line arguments when calling the finetuning scripts as shown in the example below:\n\n```bash\nlitgpt finetune_lora tiiuae/falcon-7b \\\n  --data JSON \\\n  --data.json_path path/to/your/data.json \\\n  --data.val_split_fraction 0.1 \\\n  --data.seed 42 \\\n  --data.mask_inputs False \\\n  --data.ignore_index -100\n```\n\nYou can also pass a directory containing a `train.json` and `val.json` to `--data.json_path` to define a fixed train/val split.\n\n&nbsp;\n\n> [!TIP]\n> Use `litgpt finetune --data.help JSON` to list additional dataset-specific command line options.\n\n&nbsp;\n\n### Preparing Custom Datasets Using DataModule\n\nIf you don't have a JSON file following the format described in the previous section, the easiest way to prepare a new dataset is to copy and modify one of the existing data modules in LitGPT:\n\n- [`litgpt/data/alpaca.py`](https://github.com/Lightning-AI/litgpt/blob/main/litgpt/data/alpaca.py) (if you plan to load a dataset from a JSON file);\n- [`litgpt/data/lima.py`](https://github.com/Lightning-AI/litgpt/blob/main/litgpt/data/lima.py) (if you plan to load a dataset using the `datasets` Python library).\n\nNote that you only need to modify a small fraction of the code file, namely the portion that downloads and formats the training data (see the `prepare_data` and `setup()` methods).\n\n&nbsp;\n\n## Preparing Pretraining Datasets\n\nIn addition to the finetuning dataset described above, LitGPT also supports several datasets for pretraining. The pretraining datasets are described in more detail in the following separate tutorial documents:\n\n- [Pretrain TinyLlama on Slimpajama and Starcoder](./pretrain_tinyllama.md)\n"
  },
  {
    "path": "tutorials/pretrain.md",
    "content": "# Pretrain LLMs with LitGPT\n\n\nThis document explains how to pretrain LLMs using LitGPT.\n\n&nbsp;\n## Using the `litgpt pretrain` command\n\nYou can pretrain models in LitGPT using the `litgpt pretrain` API starting with any of the available architectures listed by calling `litgpt pretrain list` without any additional arguments:\n\n&nbsp;\n> [!TIP]\n> To install all required dependencies before pretraining, first run `pip install \"litgpt[all]\"`.\n&nbsp;\n\n```bash\nlitgpt pretrain list\n```\n\nShown below is an abbreviated list:\n\n```\nValueError: Please specify --model_name <model_name>. Available values:\nCamel-Platypus2-13B\n...\nGemma-2b\n...\nLlama-2-7b-hf\n...\nMixtral-8x7B-v0.1\n...\npythia-14m\n```\n\nFor demonstration purposes, we can pretrain a small 14 million-parameter Pythia model on the small TinyStories dataset using the [debug.yaml config file](https://github.com/Lightning-AI/litgpt/blob/main/config_hub/pretrain/debug.yaml) as follows:\n\n```bash\nlitgpt pretrain pythia-14m \\\n   --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/pretrain/debug.yaml\n```\n\n\n&nbsp;\n## Pretrain on custom data\n\nThe simplest way to get started with pretraining on a small custom dataset is by using the `TextFiles` data module, which lets you pretrain a dataset from a folder containing plain text files.\n\n&nbsp;\n\n> [!NOTE]\n> This approach adds a beginning-of-sequence token at the beginning of each text file. However, it otherwise assumes that you have already cleaned the text files, for example, removing any unwanted characters and inserting beginning-of-sequence and end-of-sequence tokens if applicable in case a text file conists of multiple documents.\n\n&nbsp;\n\n> [!WARNING]\n> Using this approach is only recommended for small datasets. 
Since text data is highly compressible, it is often stored in compressed format, and often in file formats where documents can be loaded row by row without having to load entire files at once. In other words, storing the data in plain text files, as this `TextFiles` approach does, is only feasible when the dataset is small.\n> For datasets that take up multiple gigabytes, we recommend preprocessing it with [LitData](https://github.com/Lightning-AI/litdata) and then reading it from a local directory or S3 connection using `--data LitData`.\n\n&nbsp;\n\nFor instance, assume you stored a number of text files in a `custom_pretraining_data` folder (we recommend avoiding small files and concatenating them to files of at least 50 MB for efficiency):\n\n```bash\n~ ls -lh custom_pretraining_data\ntotal 3225M\n-rw-r--r-- 1 sebastian 50M Apr  2 18:31 combined_1.txt\n-rw-r--r-- 1 sebastian 50M Apr  2 18:31 combined_2.txt\n-rw-r--r-- 1 sebastian 50M Apr  2 18:31 combined_3.txt\n-rw-r--r-- 1 sebastian 50M Apr  2 18:31 combined_4.txt\n-rw-r--r-- 1 sebastian 50M Apr  2 18:31 combined_5.txt\n...\n```\n\nYou can then use the `TextFiles` API to pretrain a model (here a small `pythia-14m` model for illustration purposes) from scratch as follows:\n\n```bash\nlitgpt download EleutherAI/pythia-14m \\\n  --tokenizer_only true\n\nlitgpt pretrain pythia-14m \\\n   --tokenizer_dir EleutherAI/pythia-14m \\\n   --data TextFiles \\\n   --data.train_data_path custom_pretraining_data \\\n   --train.lr_warmup_steps=200 \\\n   --optimizer AdamW \\\n   --optimizer.lr 0.005\n```\n\n&nbsp;\n> [!TIP]\n> Use the `litgpt pretrain --data.help TextFiles` command to list additional dataset options.\n&nbsp;\n\n\n&nbsp;\n## Continued pretraining on custom data\n\nOften, it makes sense to adopt an existing pretrained model and further pretrain it on our own custom data. 
The existing pretrained model can be either our own pretrained model or a model downloaded from a model hub.\n\nThe following subsections illustrate three typical scenarios:\n\n1. Starting from a downloaded base model\n2. Continuing the pretraining after interruption\n3. Further pretraining on a different dataset\n\n&nbsp;\n\n> [!NOTE]\n> This approach assumes that you have already cleaned the text files, for example, removing any unwanted characters and inserting beginning-of-sequence and end-of-sequence tokens if applicable.\n\n&nbsp;\n\n> [!WARNING]\n> Using this approach is only recommended for small datasets. Since text data is highly compressible, it is often stored in compressed format, and often in file formats where documents can be loaded row by row without having to load entire files at once. In other words, storing the data in plain text files, as this `TextFiles` approach does, is only feasible when the dataset is small.\n> For datasets that take up multiple gigabytes, we recommend preprocessing it with [LitData](https://github.com/Lightning-AI/litdata) and then reading it from a local directory or S3 connection using `--data LitData --data.path path/to/your/data`.\n\n\n&nbsp;\n### 1) Continued pretraining when starting from a downloaded base model\n\n\nFor instance, let's assume we download a Pythia model:\n\n```bash\nlitgpt download EleutherAI/pythia-160m\n```\n\nNext, assume we have a custom dataset stored in text files similar to the *Pretrain on custom data* section above. 
We can further pretrain the Pythia model via the `--initial_checkpoint_dir` setting as follows:\n\n```bash\nlitgpt pretrain pythia-160m \\\n   --initial_checkpoint_dir EleutherAI/pythia-160m \\\n   --tokenizer_dir EleutherAI/pythia-160m \\\n   --out_dir ./new_pretrained_checkpoint \\\n   --data TextFiles \\\n   --data.train_data_path custom_pretraining_data \\\n   --train.max_tokens 1_000_000\n```\n\n&nbsp;\n> [!TIP]\n> Use the `litgpt pretrain --data.help TextFiles` command to list additional dataset options.\n\n\n&nbsp;\n### 2) Continued pretraining after interruption\n\nIn case you interrupted a training run, you can continue it with the `--resume` option, for example:\n\n```bash\nlitgpt pretrain pythia-160m \\\n   --resume \"auto\" \\\n   --tokenizer_dir EleutherAI/pythia-160m \\\n   --out_dir ./new_pretrained_checkpoint \\\n   --data TextFiles \\\n   --data.train_data_path custom_pretraining_data \\\n   --train.max_tokens 1_000_000\n```\n\n&nbsp;\n### 3) Continued pretraining on a new dataset\n\nSuppose you pretrained a model using the examples above. 
To further pretrain the model on a new dataset, you first need to convert the pretrained checkpoint via the following command:\n\n```bash\nlitgpt convert_pretrained_checkpoint ./new_pretrained_checkpoint/final ./new_pretrained_checkpoint_converted\n```\n\nThen, you can pretrain the converted model on the new dataset as follows:\n\n```bash\nlitgpt pretrain pythia-160m \\\n   --initial_checkpoint_dir ./new_pretrained_checkpoint_converted \\\n   --tokenizer_dir EleutherAI/pythia-160m \\\n   --out_dir ./new_pretrained_checkpoint_2 \\\n   --data TextFiles \\\n   --data.train_data_path custom_pretraining_data_2 \\\n   --train.max_tokens 1_000_000\n```\n\n\n&nbsp;\n## Pretrain a 1.1B TinyLlama model\n\nYou can find an end-to-end LitGPT tutorial for pretraining a TinyLlama model using LitGPT [here](pretrain_tinyllama.md).\n\n\n&nbsp;\n## Optimize LitGPT pretraining with Lightning Thunder\n\n[Lightning Thunder](https://github.com/Lightning-AI/lightning-thunder) is a source-to-source compiler for PyTorch, which is fully compatible with LitGPT. 
In experiments, Thunder resulted in a 40% speed-up compared to using regular PyTorch when finetuning a 7B Llama 2 model.\n\nFor more information, see the [Lightning Thunder extension README](https://github.com/Lightning-AI/lightning-thunder).\n\n\n&nbsp;\n## Project templates\n\nThe following [Lightning Studio](https://lightning.ai/lightning-ai/studios) templates provide LitGPT pretraining projects in reproducible environments with multi-GPU and multi-node support:\n&nbsp;\n\n|                                                                                                                                                                                                                                                                                                                                             |                                                                                                                                                                                                                                                                                                                                                |\n|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| <p align=\"left\">[Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset) <br> [<img 
src=\"https://pl-public-data.s3.amazonaws.com/assets_litgpt/readme/3.webp\" width=\"300\"></p>](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset)         | [Pretrain LLMs - TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b) <br> <p align=\"left\">[<img src=\"https://pl-public-data.s3.amazonaws.com/assets_litgpt/readme/4.webp\" width=\"300\"></p>](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b)                                        |\n| [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b) <br> <p align=\"left\">[<img src=\"https://pl-public-data.s3.amazonaws.com/assets_litgpt/readme/1.webp\" width=\"300\"></p>](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b) | |\n|                                                                                                                                                                                                                                                                                                                                             |\n"
  },
  {
    "path": "tutorials/pretrain_tinyllama.md",
    "content": "# Pretrain TinyLlama\n\nThis tutorial will walk you through pretraining [TinyLlama](https://github.com/jzhang38/TinyLlama/).\n\n> [!TIP]\n> To get started with zero setup, clone the [TinyLlama studio on Lightning AI](https://lightning.ai/lightning-ai/studios/llm-pretrain-tinyllama-1-1b).\n\n&nbsp;\n## What's TinyLlama?\n\n[TinyLlama](https://github.com/jzhang38/TinyLlama/) is architecturally the same as Meta AI's LLama 2, but only has 1.1B parameters and is instead trained on multiple epochs on a mix of [SlimPajama](https://huggingface.co/datasets/cerebras/SlimPajama-627B) and [Starcoder](https://huggingface.co/datasets/bigcode/starcoderdata) datasets.\n\nHere is a quick fact sheet:\n\n| Name                          | Description                                                                                                                                                  |\n|-------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| Parameters                    | 1.1B                                                                                                                                                         |\n| Model Size                    | Layers: 22, Heads: 32, Query Groups: 4, Embedding Size: 2048, Intermediate Size: 5632                                                                        |\n| Sequence Length               | 2048                                                                                                                                                         |\n| Learning Rate                 | 4e-4                                                                                                                                                         |\n| Learning Rate Schedule        | Cosine with 2000 warmup steps                                                              
                                                                  |\n| Training Data                 | [SlimPajama](https://huggingface.co/datasets/cerebras/slimpajama-627b) (893 GB), [Starcoder](https://huggingface.co/datasets/bigcode/starcoderdata) (290 GB) |\n| Combined Dataset Size         | Around 950B tokens                                                                                                                                           |\n| Total Tokens During Training  | 3 trillion (3 epochs)                                                                                                                                        |\n| Time to complete training     | ~ 4 weeks with 64 A100 GPUs                                                                                                                                  |\n| Model FLOPs Utilization (MFU) | 52%                                                                                                                                                          |\n\n(this table was sourced from the author's [README](https://github.com/jzhang38/TinyLlama/))\n\n&nbsp;\n## Download datasets\n\nYou can download the data using git lfs:\n\n```bash\n# Make sure you have git-lfs installed (https://git-lfs.com):\nsudo apt install git-lfs\n```\n\n```bash\ngit clone https://huggingface.co/datasets/cerebras/slimpajama-627b data/slimpajama-raw\ngit clone https://huggingface.co/datasets/bigcode/starcoderdata data/starcoderdata-raw\n```\n\nAround 1.2 TB of disk space is required to store both datasets.\n\n&nbsp;\n## Prepare the datasets for training\n\nIn order to start pretraining litgpt on it, you need to read, tokenize, and write the data in binary chunks. 
This will leverage the `litdata` optimization pipeline and streaming dataset.\n\nFirst, install additional dependencies for preprocessing:\n\n```bash\npip install '.[all]'\n```\n\nYou will need to have the tokenizer config available:\n\n```bash\nlitgpt download meta-llama/Llama-2-7b-hf \\\n   --access_token your_hf_token \\\n   --tokenizer_only true\n```\n\nThen, run the preprocessing script for each dataset and split.\nYou will require **1.1 TB** of disk space for Starcoder and **2.5** TB of space for the SlimPajama dataset.\n\n**Starcoder:**\n\n```bash\npython litgpt/data/prepare_starcoder.py \\\n  --input_dir data/starcoderdata-raw \\\n  --output_dir data/starcoder \\\n  --tokenizer_path checkpoints/meta-llama/Llama-2-7b-hf\n```\n\n**SlimPajama:**\n\n```bash\npython litgpt/data/prepare_slimpajama.py \\\n  --input_dir data/slimpajama-raw/validation \\\n  --output_dir data/slimpajama/val \\\n  --tokenizer_path checkpoints/meta-llama/Llama-2-7b-hf\n\npython litgpt/data/prepare_slimpajama.py \\\n  --input_dir data/slimpajama-raw/test \\\n  --output_dir data/slimpajama/test \\\n  --tokenizer_path checkpoints/meta-llama/Llama-2-7b-hf\n\npython litgpt/data/prepare_slimpajama.py \\\n  --input_dir data/slimpajama-raw/train \\\n  --output_dir data/slimpajama/train \\\n  --tokenizer_path checkpoints/meta-llama/Llama-2-7b-hf\n```\n\nIf you want to run on a small slice of the datasets first, pass the flag `--fast_dev_run=true` to the commands above.\nIn the above we are assuming that you will be using the same tokenizer as used in LlaMA/TinyLlama, but any trained [SentencePiece](https://github.com/google/sentencepiece) tokenizer with a 32000 vocabulary size will do here.\n\n&nbsp;\n## Pretraining\n\nRunning the pretraining script with its default settings requires at least 8 A100 GPUs.\n\n```bash\nlitgpt pretrain --config config_hub/pretrain/tinyllama.yaml\n```\n\n&nbsp;\n> [!TIP]\n> Use the `litgpt pretrain --data.help TinyLlama` command to list additional dataset 
options.\n&nbsp;\n\n\nThe script will save checkpoints periodically to the folder `out/`.\nBy default, the `pretrain` script will pretrain the model with FSDP in\n`bfloat16` mixed precision and gradient accumulation.\n\nNote that `pretrain` is not actually a model-specific training script, so feel free to [try other configurations](../config_hub)\nor change the model type and size by passing a different string to the model name argument, for example:\n\n```shell\nlitgpt pretrain Gemma-2b\n```\n\nThe currently supported model names can be listed by executing `litgpt pretrain` without any additional arguments.\n\nKeep in mind that training with a single machine will take weeks. To speed up the process, you'll need access to a cluster.\nOnce you're in a cluster, you can follow [these instructions](https://lightning.ai/docs/fabric/stable/fundamentals/launch.html#launch-on-a-cluster)\nto launch the script across machines:\n\n- [Lightning AI](https://lightning.ai/docs/fabric/stable/guide/multi_node/cloud.html)\n- [SLURM cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html)\n- [Barebones cluster](https://lightning.ai/docs/fabric/stable/guide/multi_node/barebones.html)\n- [MPI](https://lightning.ai/docs/fabric/stable/guide/multi_node/other.html)\n\nThe script exposes several hyperparameters you can tweak through the command line.\n\nFor instance, `--train.micro_batch_size` should be adjusted so the process will use the available\nGPU memory. 
For more tips to avoid out-of-memory issues, please also see the more detailed\n[Dealing with out-of-memory (OOM) errors](oom.md) guide.\n\nLast, logging is kept minimal in the script, but for long-running experiments we recommend switching to a proper experiment tracker.\nLitGPT supports multiple experiment trackers including:\n\n- **TensorBoard** (default): Local visualization with TensorBoard\n- **CSV Logger**: Simple local logging to CSV files\n- **WandB**: Cloud-based experiment tracking with Weights & Biases\n- **MLflow**: MLflow experiment tracking\n- **[LitLogger](https://github.com/Lightning-AI/LitLogger)**: Lightning.ai's native experiment tracking (set `--logger_name=litlogger`)\n\nAs an example, we included WandB (set `--logger_name=wandb`) to show how you can integrate any experiment tracking framework.\nFor reference, [here are the loss curves for our reproduction](https://api.wandb.ai/links/awaelchli/y7pzdpwy).\n\n&nbsp;\n## Resume training\n\nThe checkpoints saved during pretraining contain all the information to resume if needed.\nSimply rerun the script with the `--resume` argument added:\n\n```bash\nlitgpt pretrain tiny-llama\\\n  --config config_hub/pretrain/tinyllama.yaml \\\n  --resume out/pretrain/tiny-llama/step-00060500\n```\n**Important:** Each checkpoint is a directory. 
Point to the directory, not the 'lit_model.pth' file inside of it.\n\n&nbsp;\n> [!TIP]\n> Use the `litgpt pretrain --data.help TinyLlama` command to list additional dataset options.\n&nbsp;\n\n\n&nbsp;\n## Export checkpoints\n\nAfter training is completed, you can convert the checkpoint to a format that can be loaded for evaluation, inference, finetuning etc.\n\n```bash\nlitgpt convert_pretrained_checkpoint out/pretrain/tiny-llama/step-00060500 \\\n  --output_dir checkpoints/tiny-llama/final\n```\n\nAfter conversion, the output folder will contain these files:\n```\ncheckpoints/tiny-llama/final\n├── model_config.yaml\n├── lit_model.pth\n├── tokenizer_config.json\n├── tokenizer.json\n└── tokenizer.model\n```\n\nYou can then use this checkpoint folder to run [evaluation](evaluation.md), [inference](inference.md), [finetuning](finetune_lora.md) or [process the checkpoint further](convert_lit_models.md).\n\n\n&nbsp;\n## Project templates\n\nThe following [Lightning Studio](https://lightning.ai/lightning-ai/studios) templates provide LitGPT pretraining projects in reproducible environments with multi-GPU and multi-node support:\n&nbsp;\n\n|                                                                                                                                                                                                                                                                                                                                             |                                                                                                                                                                                                                                                                                                                                                
|\n|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| <p align=\"left\">[Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset) <br> [<img src=\"https://pl-public-data.s3.amazonaws.com/assets_litgpt/readme/3.webp\" width=\"300\"></p>](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset)         | [Pretrain LLMs - TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b) <br> <p align=\"left\">[<img src=\"https://pl-public-data.s3.amazonaws.com/assets_litgpt/readme/4.webp\" width=\"300\"></p>](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b)                                        |\n| [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b) <br> <p align=\"left\">[<img src=\"https://pl-public-data.s3.amazonaws.com/assets_litgpt/readme/1.webp\" width=\"300\"></p>](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b) | |\n|                                                                                                                                                                                                                                                                                                  
                                           |\n"
  },
  {
    "path": "tutorials/python-api.md",
    "content": "# LitGPT Python API\n\nThis is a work-in-progress draft describing the current LitGPT Python API (experimental and subject to change).\n\n\n## Model loading\n\nUse the `LLM.load` method to load a model from a LitGPT model checkpoint folder. For example, consider loading a Phi-2 model. If a given checkpoint directory `\"microsoft/phi-2\"` does not exist as a local checkpoint directory, the model will be downloaded automatically from the HF Hub (assuming that `\"microsoft/phi-2\"` is a valid repository name):\n\n```python\nfrom litgpt import LLM\n\nllm_1 = LLM.load(\"microsoft/phi-2\")\n```\n\n```\nconfig.json: 100%|████████████████████████████████████████████████| 735/735 [00:00<00:00, 7.75MB/s]\ngeneration_config.json: 100%|█████████████████████████████████████| 124/124 [00:00<00:00, 2.06MB/s]\nmodel-00001-of-00002.safetensors: 100%|███████████████████████████| 5.00G/5.00G [00:12<00:00, 397MB/s]\nmodel-00002-of-00002.safetensors: 100%|███████████████████████████| 564M/564M [00:01<00:00, 421MB/s]\nmodel.safetensors.index.json: 100%|███████████████████████████████| 35.7k/35.7k [00:00<00:00, 115MB/s]\ntokenizer.json: 100%|█████████████████████████████████████████████| 2.11M/2.11M [00:00<00:00, 21.5MB/s]\ntokenizer_config.json: 100%|██████████████████████████████████████| 7.34k/7.34k [00:00<00:00, 80.6MB/s]\n```\n\n&nbsp;\n> [!NOTE]\n> To get a list of all supported models, execute `litgpt download list` in the command line terminal.\n&nbsp;\n<br>\n\n\nIf you attempt to load the model again, LitGPT will load this model from a local directory since it's already been downloaded:\n\n```python\nllm_2 = LLM.load(\"microsoft/phi-2\")\n```\n\n\nIf you created a pretrained or finetuned model checkpoint via LitGPT, you can load it in a similar fashion:\n\n```python\nmy_llm = LLM.load(\"path/to/my/local/checkpoint\")\n```\n\n\n\n\n&nbsp;\n## Generate/Chat\n\nGenerate output using the `.generate` method:\n\n```python\nfrom litgpt import LLM\n\nllm = 
LLM.load(\"microsoft/phi-2\")\n\ntext = llm.generate(\"What do Llamas eat?\", top_k=1, max_new_tokens=30)\nprint(text)\n```\n\n```\nLlamas are herbivores and primarily eat grass, leaves, and shrubs. They have a specialized digestive system that allows them to efficiently extract\n```\n\nAlternatively, stream the response one token at a time:\n\n```python\nresult = llm.generate(\"What do Llamas eat?\", top_k=1, max_new_tokens=30, stream=True)\nfor e in result:\n    print(e, end=\"\", flush=True)\n```\n\n```\nLlamas are herbivores and primarily eat grass, leaves, and shrubs. They have a specialized digestive system that allows them to efficiently extract\n```\n\n\n&nbsp;\n## Saving models\n\nAfter finetuning or modifying a model, you can save it to disk using the `.save()` method:\n\n```python\nfrom litgpt import LLM\n\nllm = LLM.load(\"microsoft/phi-2\")\n# ... perform finetuning or modifications ...\nllm.save(\"path/to/save/directory\")\n```\n\nThe saved checkpoint can then be loaded later:\n\n```python\nllm = LLM.load(\"path/to/save/directory\")\n```\n\n\n&nbsp;\n## Random weights\n\nTo start with random weights, for example, if you plan a pretraining script, initialize the model with `init=\"random\"`. Note that this requires passing a `tokenizer_dir` that contains a valid tokenizer file.\n\n```python\nfrom litgpt.api import LLM\nllm = LLM.load(\"pythia-160m\", init=\"random\", tokenizer_dir=\"EleutherAI/pythia-160m\")\n```\n\n\n\n&nbsp;\n## Multi-GPU strategies\n\nBy default, the model is loaded onto a single GPU. Optionally, you can use the `.distribute()` method with the \"sequential\" or \"tensor_parallel\" `generate_strategy` settings.\n\n### Sequential strategy\n\nThe `generate_strategy=\"sequential\"` setting loads different parts of the models onto different GPUs. The goal behind this strategy is to support models that cannot fit into single-GPU memory. 
(Note that if you have a model that can fit onto a single GPU, this sequential strategy will be slower.)\n\n```python\nfrom litgpt.api import LLM\n\nllm = LLM.load(\n    \"microsoft/phi-2\",\n    distribute=None\n)\n\nllm.distribute(\n    generate_strategy=\"sequential\",\n    devices=4,  # Optional setting, otherwise uses all available GPUs\n    fixed_kv_cache_size=256  # Optionally use a small kv-cache to further reduce memory usage\n)\n```\n\n```\nUsing 4 devices\nMoving '_forward_module.transformer.h.31' to cuda:3: 100%|██████████| 32/32 [00:00<00:00, 32.71it/s]\n```\n\nAfter initializing the model, the model can be used via the `generate` method similar to the default `generate_strategy` setting:\n\n```python\ntext = llm.generate(\"What do llamas eat?\", max_new_tokens=100)\nprint(text)\n```\n\n```\n Llamas are herbivores and their diet consists mainly of grasses, plants, and leaves.\n```\n\n&nbsp;\n### Tensor parallel strategy\n\nThe sequential strategy explained in the previous subsection distributes the model sequentially across GPUs, which allows users to load models that would not fit onto a single GPU. However, due to this method's sequential nature, processing is naturally slower than parallel processing.\n\nTo take advantage of parallel processing via tensor parallelism, you can use the `generate_strategy=\"tensor_parallel\"` setting. 
However, this method has downsides: the initial setup may be slower for large models, and it cannot run in interactive processes such as Jupyter notebooks.\n\n```python\nfrom litgpt.api import LLM\n\n\nif __name__ == \"__main__\":\n\n    llm = LLM.load(\n        model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n        distribute=None\n    )\n\n    llm.distribute(generate_strategy=\"tensor_parallel\", devices=4)\n\n    print(llm.generate(prompt=\"What do llamas eat?\"))\n    print(llm.generate(prompt=\"What is 1+2?\", top_k=1))\n```\n\n\n&nbsp;\n## Speed and resource estimates\n\nUse the `.benchmark()` method to compare the computational performance of different settings. The `.benchmark()` method takes the same arguments as the `.generate()` method. For example, we can estimate the speed and GPU memory consumption as follows (the resulting numbers were obtained on an A10G GPU):\n\n```python\nfrom litgpt.api import LLM\nfrom pprint import pprint\n\nllm = LLM.load(\n    model=\"microsoft/phi-2\",\n    distribute=None\n)\n\nllm.distribute(fixed_kv_cache_size=500)\n\ntext, bench_d = llm.benchmark(prompt=\"What do llamas eat?\", top_k=1, stream=True)\nprint(text)\npprint(bench_d)\n\n\n# Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a specialized\n# digestive system that allows them to efficiently extract nutrients from plant material.\n\n# Using 1 device(s)\n#  Llamas are herbivores and primarily eat grass, leaves, and shrubs. 
They have a unique digestive system that allows them to efficiently extract nutrients from tough plant material.\n\n# {'Inference speed in tokens/sec': [17.617540650112936],\n#  'Seconds to first token': [0.6533610639999097],\n#  'Seconds total': [1.4758019020000575],\n#  'Tokens generated': [26],\n#  'Total GPU memory allocated in GB': [5.923729408]}\n```\n\nTo get more reliable estimates, it's recommended to repeat the benchmark for multiple iterations via `num_iterations=10`:\n\n```python\ntext, bench_d = llm.benchmark(num_iterations=10, prompt=\"What do llamas eat?\", top_k=1, stream=True)\nprint(text)\npprint(bench_d)\n\n# Using 1 device(s)\n#  Llamas are herbivores and primarily eat grass, leaves, and shrubs. They have a unique digestive system that allows them to efficiently extract nutrients from tough plant material.\n\n# {'Inference speed in tokens/sec': [17.08638672485105,\n#                                    31.79908547222976,\n#                                    32.83646959864293,\n#                                    32.95994240022436,\n#                                    33.01563039816964,\n#                                    32.85263413816648,\n#                                    32.82712094713627,\n#                                    32.69216141907453,\n#                                    31.52431714347663,\n#                                    32.56752130561681],\n#  'Seconds to first token': [0.7278506560005553,\n#                             0.022963577999689733,\n#                             0.02399449199947412,\n#                             0.022921959999621322,\n# ...\n```\n\nAs one can see, the first iteration may take longer due to warmup times. 
So, it's recommended to discard the first iteration:\n\n```python\nfor key in bench_d:\n    bench_d[key] = bench_d[key][1:]\n```\n\nFor better visualization, you can use the `benchmark_dict_to_markdown_table` function:\n\n```python\nfrom litgpt.api import benchmark_dict_to_markdown_table\n\nprint(benchmark_dict_to_markdown_table(bench_d))\n```\n\n| Metric                              | Mean                        | Std Dev                     |\n|-------------------------------------|-----------------------------|-----------------------------|\n| Seconds total                       | 0.80                        | 0.01                        |\n| Seconds to first token              | 0.02                        | 0.00                        |\n| Tokens generated                    | 26.00                       | 0.00                        |\n| Inference speed in tokens/sec       | 32.56                       | 0.50                        |\n| Total GPU memory allocated in GB    | 5.92                        | 0.00                        |\n\n\n&nbsp;\n# PyTorch Lightning Trainer support\n\nYou can use the LitGPT `LLM` class with the [PyTorch Lightning Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html) to pretrain and finetune models.\n\nThe examples below show the usage via a simple 160 million parameter model for demonstration purposes to be able to quickly try it out. 
However, you can replace the `EleutherAI/pythia-160m` model with any model supported by LitGPT (you can find a list of supported models by executing `litgpt download list` or visiting the [model weight docs](download_model_weights.md)).\n\n&nbsp;\n## Step 1: Define a `LightningModule`\n\nFirst, we define a `LightningModule` similar to what we would do when working with other types of neural networks in PyTorch Lightning:\n\n\n```python\nimport torch\nimport litgpt\nfrom litgpt import LLM\nfrom litgpt.data import Alpaca2k\nimport lightning as L\n\n\nclass LitLLM(L.LightningModule):\n    def __init__(self, checkpoint_dir, tokenizer_dir=None, trainer_ckpt_path=None):\n        super().__init__()\n\n        self.llm = LLM.load(checkpoint_dir, tokenizer_dir=tokenizer_dir, distribute=None)\n        self.trainer_ckpt_path = trainer_ckpt_path\n\n    def setup(self, stage):\n        self.llm.trainer_setup(trainer_ckpt=self.trainer_ckpt_path)\n\n    def training_step(self, batch):\n        logits, loss = self.llm(input_ids=batch[\"input_ids\"], target_ids=batch[\"labels\"])\n        self.log(\"train_loss\", loss, prog_bar=True)\n        return loss\n\n    def validation_step(self, batch):\n        logits, loss = self.llm(input_ids=batch[\"input_ids\"], target_ids=batch[\"labels\"])\n        self.log(\"validation_loss\", loss, prog_bar=True)\n        return loss\n\n    def configure_optimizers(self):\n        warmup_steps = 10\n        optimizer = torch.optim.AdamW(self.llm.model.parameters(), lr=0.0002, weight_decay=0.0, betas=(0.9, 0.95))\n        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / warmup_steps)\n        return [optimizer], [scheduler]\n```\n\nIn the code example above, note how we set `distribute=None` in `llm.load()` in the `__init__` method. This step is necessary because we want to let the PyTorch Lightning Trainer handle the GPU devices. 
We then call `self.llm.trainer_setup` in the `setup()` method, which adjusts the LitGPT settings to be compatible with the Trainer. Other than that, everything else looks like a standard `LightningModule`.\n\nNext, we have a selection of different use cases, but first, let's set some general settings to specify the batch size and gradient accumulation steps:\n\n```python\nbatch_size = 8\naccumulate_grad_batches = 1\n```\n\nFor larger models, you may want to decrease the batch size and increase the number of accumulation steps. (Setting `accumulate_grad_batches = 1` effectively disables gradient accumulation, and it is only shown here for reference in case you wish to change this setting.)\n\n## Step 2: Using the Trainer\n\n&nbsp;\n### Use case 1: Pretraining from random weights\n\nIn case you plan to train a model from scratch (not recommended over finetuning because training a model from scratch in general requires substantial time and resources), you can do it as follows:\n\n```python\n# Create model with random as opposed to pretrained weights\nllm = LLM.load(\"EleutherAI/pythia-160m\", tokenizer_dir=\"EleutherAI/pythia-160m\", init=\"random\")\nllm.save(\"pythia-160m-random-weights\")\ndel llm\n\nlit_model = LitLLM(checkpoint_dir=\"pythia-160m-random-weights\", tokenizer_dir=\"EleutherAI/pythia-160m\")\ndata = Alpaca2k()\n\ndata.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)\n\ntrainer = L.Trainer(\n    devices=1,\n    accelerator=\"cuda\",\n    max_epochs=1,\n    accumulate_grad_batches=accumulate_grad_batches,\n    precision=\"bf16-true\",\n)\ntrainer.fit(lit_model, data)\n\nlit_model.llm.model.to(lit_model.llm.preprocessor.device)\nlit_model.llm.generate(\"hello world\")\n```\n\n&nbsp;\n### Use case 2: Continued pretraining or finetuning a downloaded model\n\nThe continued pretraining or finetuning from a downloaded model checkpoint is similar to the example above, except that we can skip the initial steps of instantiating a 
model with random weights.\n\n```python\n\nlit_model = LitLLM(checkpoint_dir=\"EleutherAI/pythia-160m\")\ndata = Alpaca2k()\n\ndata.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)\n\ntrainer = L.Trainer(\n    devices=1,\n    accelerator=\"cuda\",\n    max_epochs=1,\n    accumulate_grad_batches=accumulate_grad_batches,\n    precision=\"bf16-true\",\n)\ntrainer.fit(lit_model, data)\n\nlit_model.llm.model.to(lit_model.llm.preprocessor.device)\nlit_model.llm.generate(\"hello world\")\n```\n\n&nbsp;\n### Use case 3: Resume training from Trainer checkpoint\n\nSuppose you trained a model and decide to follow up with a few additional training rounds. This can be achieved as follows by loading an existing Trainer checkpoint:\n\n```python\n\nimport os\n\ndef find_latest_checkpoint(directory):\n    latest_checkpoint = None\n    latest_time = 0\n\n    for root, _, files in os.walk(directory):\n        for file in files:\n            if file.endswith('.ckpt'):\n                file_path = os.path.join(root, file)\n                file_time = os.path.getmtime(file_path)\n                if file_time > latest_time:\n                    latest_time = file_time\n                    latest_checkpoint = file_path\n\n    return latest_checkpoint\n\nlit_model = LitLLM(checkpoint_dir=\"EleutherAI/pythia-160m\", trainer_ckpt_path=find_latest_checkpoint(\"lightning_logs\"))\n\ndata.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)\n\ntrainer = L.Trainer(\n    devices=1,\n    accelerator=\"cuda\",\n    max_epochs=1,\n    accumulate_grad_batches=accumulate_grad_batches,\n    precision=\"bf16-true\",\n)\ntrainer.fit(lit_model, data)\n\nlit_model.llm.model.to(lit_model.llm.preprocessor.device)\nlit_model.llm.generate(\"hello world\")\n```\n\n&nbsp;\n### Use case 4: Resume training after saving a checkpoint manually\n\nThis example illustrates how we can save a LitGPT checkpoint from a previous training run that we can load and use later. 
Note that compared to using the Trainer checkpoint in the previous section, the model saved via this approach also contains the tokenizer and other relevant files. Hence, this approach does not require the original `\"EleutherAI/pythia-160m\"` model checkpoint directory.\n\n```python\nlit_model.llm.save(\"finetuned_checkpoint\")\ndel lit_model\nlit_model = LitLLM(checkpoint_dir=\"finetuned_checkpoint\")\n\ndata.connect(lit_model.llm.tokenizer, batch_size=batch_size, max_seq_length=512)\n\ntrainer = L.Trainer(\n    devices=1,\n    accelerator=\"cuda\",\n    max_epochs=1,\n    accumulate_grad_batches=accumulate_grad_batches,\n    precision=\"bf16-true\",\n)\ntrainer.fit(lit_model, data)\n\nlit_model.llm.model.to(lit_model.llm.preprocessor.device)\nlit_model.llm.generate(\"hello world\")\n```\n"
  },
  {
    "path": "tutorials/quantize.md",
    "content": "# Quantize the model\n\nThis document provides different strategies for quantizing the various models available in LitGPT to reduce GPU memory usage, which is useful for running larger models on certain GPU hardware.\n\n**All the examples below were run on an A100 40GB GPU with CUDA 12.1.**\n\n> [!NOTE]\n> Quantization also supports finetuning via [QLoRA](finetune_lora.md)\n\n## Baseline\n\nIt's useful to start with a baseline to have a reference point for memory savings via the various quantization methods.\n\n```bash\nlitgpt generate tiiuae/falcon-7b \\\n  --precision 32-true \\\n  --max_new_tokens 256\n...\nTime for inference 1: 6.93 sec total, 36.96 tokens/sec.\nMemory used: 28.95 GB\n```\n\nFirst, using a lower precision compared to 32-bit float can result in two times reduced memory consumption. You can either try setting `--precision 16-true` for regular 16-bit precision or  `--precision bf16-true` if your GPU supports brain-float 16-bit precision. ([This brief video](https://lightning.ai/courses/deep-learning-fundamentals/9.0-overview-techniques-for-speeding-up-model-training/unit-9.1-accelerated-model-training-via-mixed-precision-training/) explains the difference between regular 16-bit and bf16-bit precision.)\n\nIn short, when `--precision bf16-true` or `--precision 16-true` is used, the model weights will automatically be converted and consume less memory.\nHowever, this might not be enough for large models or when using GPUs with limited memory.\n\n```bash\nlitgpt generate tiiuae/falcon-7b \\\n  --precision bf16-true \\\n  --max_new_tokens 256\n...\nTime for inference 1: 5.37 sec total, 47.66 tokens/sec.\nMemory used: 14.50 GB\n```\n\nTo reduce the memory requirements further, LitGPT supports several quantization techniques, which are shown below.\n\n> [!TIP]\n> Most quantization examples below also use the `--precision bf16-true` setting explained above. 
If your GPU does not support `bfloat16`, you can change it to `--precision 16-true`.\n\n## `bnb.nf4`\n\nEnabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2305.14314v1) to learn more about how it works.\n\n> [!IMPORTANT]\n> `bitsandbytes` only supports `CUDA` devices and the `Linux` operating system.\n> Windows users should use [WSL2](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl).\n\nUses the normalized float 4 (nf4) data type. This is recommended over \"fp4\" based on the paper's experimental results and theoretical analysis.\n\n```bash\npip install bitsandbytes\n\nlitgpt generate tiiuae/falcon-7b \\\n  --quantize bnb.nf4 \\\n  --precision bf16-true \\\n  --max_new_tokens 256\n...\nTime for inference 1: 6.80 sec total, 37.62 tokens/sec\nMemory used: 5.72 GB\n```\n\n## `bnb.nf4-dq`\n\nEnabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2305.14314v1) to learn more about how it works.\n\n\"dq\" stands for \"Double Quantization\", which reduces the average memory footprint by quantizing the quantization constants.\nOn average, this saves about 0.37 bits per parameter (approximately 3 GB for a 65B model).\n\n```bash\npip install bitsandbytes\n\nlitgpt generate tiiuae/falcon-7b \\\n  --quantize bnb.nf4-dq \\\n  --precision bf16-true \\\n  --max_new_tokens 256\n\n...\nTime for inference 1: 8.09 sec total, 30.87 tokens/sec\nMemory used: 5.38 GB\n```\n\n## `bnb.fp4`\n\nEnabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). 
Check out the [paper](https://arxiv.org/abs/2305.14314v1) to learn more about how it works.\n\nUses pure FP4 quantization.\n\n```bash\npip install bitsandbytes\n\nlitgpt generate tiiuae/falcon-7b \\\n  --quantize bnb.fp4 \\\n  --precision bf16-true \\\n  --max_new_tokens 256\n...\nTime for inference 1: 6.92 sec total, 36.98 tokens/sec\nMemory used: 5.72 GB\n```\n\n## `bnb.fp4-dq`\n\nEnabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2305.14314v1) to learn more about how it works.\n\n\"dq\" stands for \"Double Quantization\", which reduces the average memory footprint by quantizing the quantization constants.\nOn average, this saves about 0.37 bits per parameter (approximately 3 GB for a 65B model).\n\n```bash\npip install bitsandbytes\n\nlitgpt generate tiiuae/falcon-7b \\\n  --quantize bnb.fp4-dq \\\n  --precision bf16-true \\\n  --max_new_tokens 256\n...\nTime for inference 1: 10.02 sec total, 25.54 tokens/sec\nMemory used: 5.38 GB\n```\n\n## `bnb.int8`\n\nEnabled with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). Check out the [paper](https://arxiv.org/abs/2110.02861) to learn more about how it works.\n\n```bash\npip install bitsandbytes\n\nlitgpt generate tiiuae/falcon-7b \\\n  --quantize bnb.int8 \\\n  --precision 16-true \\\n  --max_new_tokens 256\n...\nTime for inference 1: 20.22 sec total, 12.66 tokens/sec\nMemory used: 8.70 GB\n```\n"
  },
  {
    "path": "tutorials/resource-tables.md",
    "content": "# Resource Tables\n\n- Last updated: 10/20/2023\n- LitGPT version: commit 8641822\n- Hardware: NVIDIA A100-SXM4-40GB\n- OS: Ubuntu 22.04.3 LTS (x86_64)\n- Nvidia driver version: 525.125.06\n- Relevant libraries\n  - PyTorch 2.1.0+cu121\n  - Bitsandbytes 0.41.1\n\nThis document provides an overview and examples of hardware requirements when running models in LitGPT.\n\nFor additional tips on lowering the GPU memory footprint, please also see the [Dealing with out-of-memory (OOM) errors](oom.md) document.\n\nAll experiments were run using 16-bit brain floating point precision (`--precision bf16-true`). If your GPU does not support brain floating point precision, you can use regular 16-bit floating point precision (`--precision 16-true`).\n\nAll experiments were conducted using the Alpaca dataset with its default length. Note that due to different tokenizers being used by the different models, the number of tokens in the longest training example differs based on the model:\n\n- phi1.5: 1044 tokens\n- StableLM Alpha: 1034 tokens\n- Llama 2: 1304 tokens\n- Falcon 1079 tokens\n\nNote that the number of tokens in the training set does not affect the supported context width (block size) of the models, which is as follows:\n\n- phi1.5: 2048 tokens\n- StableLM 3B Alpha: 4096 tokens\n- Llama 2: 4048 tokens\n- Falcon: 2048 tokens\n- CodeLlama 13B: 16384 tokens\n\n&nbsp;\n\n## Finetuning with LoRA on 1 GPU\n\nThe following experiments were conducted on 1xA100 with a minibatch size of 128 using the `litgpt finetune_lora` command.\n\n| Size  | Model          | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations |\n|-------|----------------|--------------|-----------------|----------------------|-------------|--------------------|\n| 1.3 B | phi-1.5        | None         | 1               | 1,572,864            | 4.82 GB     | 1.62 min           |\n| 1.3 B | phi-1.5        | bnb.nf4      | 1               | 1,572,864            
| 3.78 GB     | 1.77 min           |\n| 1.3 B | phi-1.5        | bnb.nf4-dq   | 1               | 1,572,864            | 3.72 GB     | 1.87 min           |\n| 1.3 B | phi-1.5        | None         | 2               | 1,572,864            | 6.76 GB     | 1.65 min           |\n| 1.3 B | phi-1.5        | None         | 4               | 1,572,864            | 10.68 GB    | 1.70 min           |\n|       |                |              |                 |                      |             |                    |\n| 3 B   | StableLM Alpha | None         | 1               | 2,097,152            | 9.69 GB     | 1.24 min           |\n| 3 B   | StableLM Alpha | bnb.nf4      | 1               | 2,097,152            | 6.35 GB     | 1.82 min           |\n| 3 B   | StableLM Alpha | bnb.nf4-dq   | 1               | 2,097,152            | 6.19 GB     | 1.87 min           |\n| 3 B   | StableLM Alpha | None         | 2               | 2,097,152            | 12.10 GB    | 1.33 min           |\n| 3 B   | StableLM Alpha | None         | 4               | 2,097,152            | 16.92 GB    | 1.50 min           |\n|       |                |              |                 |                      |             |                    |\n| 7 B   | Llama 2        | None         | 1               | 4,194,304            | 21.30 GB    | 2.36 min           |\n| 7 B   | Llama 2        | bnb.nf4      | 1               | 4,194,304            | 14.14 GB    | 3.68 min           |\n| 7 B   | Llama 2        | bnb.nf4-dq   | 1               | 4,194,304            | 13.84 GB    | 3.83 min           |\n| 7 B   | Llama 2        | None         | 2               | 4,194,304            | 29.07 GB    | 2.52 min           |\n| 7 B   | Llama 2        | None         | 4               | 4,194,304            | OOM         | -                  |\n|       |                |              |                 |                      |             |                    |\n| 13 B  | Llama 2        | None         | 1               
| 6,553,600            | 38.12 GB    | 3.19 min           |\n| 13 B  | Llama 2        | bnb.nf4      | 1               | 6,553,600            | 23.14 GB    | 6.38 min           |\n| 13 B  | Llama 2        | bnb.nf4-dq   | 1               | 6,553,600            | 22.55 GB    | 6.55 min           |\n| 13 B  | Llama 2        | None         | 2               | 6,553,600            | OOM         | -                  |\n| 13 B  | Llama 2        | None         | 4               | 6,553,600            | OOM         | -                  |\n|       |                |              |                 |                      |             |                    |\n| 40 B  | Falcon         | None         | 1               | 12,042,240           | OOM         | -                  |\n| 40 B  | Falcon         | bnb.nf4      | 1               | 12,042,240           | OOM         | -                  |\n| 40 B  | Falcon         | bnb.nf4-dq   | 1               | 12,042,240           | OOM         | -                  |\n\n&nbsp;\n\n## Finetuning with Adapter on 1 GPU\n\nThe following experiments were conducted on 1xA100 with a minibatch size of 128 using the `litgpt finetune_adapter` command.\n\n| Size | Model          | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations |\n|------|----------------|--------------|-----------------|----------------------|-------------|--------------------|\n| 3 B  | StableLM Alpha | None         | 1               | 573,888              | 9.10 GB     | 0.74 min           |\n| 3 B  | StableLM Alpha | bnb.nf4      | 1               | 573,888              | 5.65 GB     | 1.38 min           |\n| 3 B  | StableLM Alpha | bnb.nf4-dq   | 1               | 573,888              | 5.48 GB     | 1.46 min           |\n|      |                |              |                 |                      |             |                    |\n| 7 B  | Llama 2        | None         | 1               | 1,229,760            | 19.98 GB    | 
1.50 min           |\n| 7 B  | Llama 2        | bnb.nf4      | 1               | 1,229,760            | 12.68 GB    | 2.93 min           |\n| 7 B  | Llama 2        | bnb.nf4-dq   | 1               | 1,229,760            | 12.38 GB    | 3.00 min           |\n\nThe same config, but using the `litgpt finetune_adapter_v2` command.\n\n| Size | Model          | Quantization | Microbatch size | Trainable parameters | Max GPU RAM | Time 1k iterations |\n|------|----------------|--------------|-----------------|----------------------|-------------|--------------------|\n| 3 B  | StableLM Alpha | None         | 1               | 2,125,248            | 10.71 GB    | 0.87 min           |\n| 3 B  | StableLM Alpha | bnb.nf4      | 1               | 2,125,248            | 7.41 GB     | 1.59 min           |\n| 3 B  | StableLM Alpha | bnb.nf4-dq   | 1               | 2,125,248            | 7.25 GB     | 1.62 min           |\n|      |                |              |                 |                      |             |                    |\n| 7 B  | Llama 2        | None         | 1               | 4,279,744            | 25.51 GB    | 1.81 min           |\n| 7 B  | Llama 2        | bnb.nf4      | 1               | 4,279,744            | 18.30 GB    | 3.23 min           |\n| 7 B  | Llama 2        | bnb.nf4-dq   | 1               | 4,279,744            | 17.98 GB    | 3.32 min           |\n\n&nbsp;\n\n## Finetuning with LoRA on Multiple GPUs\n\nThe following experiments were conducted on multiple A100 GPUs with a minibatch size of 128 using the `litgpt finetune_lora` command.\n\n| Size  | Model          | Quantization | Microbatch size | Trainable parameters | GPU      | Max GPU RAM | Time 1k iterations |\n|-------|----------------|--------------|-----------------|----------------------|----------|-------------|--------------------|\n| 1.3 B | phi-1.5        | None         | 1               | 1,572,864            | 2 x A100 | 4.86 GB     | 3.81 min           |\n| 1.3 B | phi-1.5      
  | bnb.nf4      | 1               | 1,572,864            | 2 x A100 | N/A         | -                  |\n| 1.3 B | phi-1.5        | bnb.nf4-dq   | 1               | 1,572,864            | 2 x A100 | N/A         | -                  |\n| 1.3 B | phi-1.5        | None         | 2               | 1,572,864            | 2 x A100 | 5.05 GB     | 3.63 min           |\n| 1.3 B | phi-1.5        | None         | 4               | 1,572,864            | 2 x A100 | 5.88 GB     | 3.64 min           |\n|       |                |              |                 |                      |          |             |                    |\n| 3 B   | StableLM Alpha | None         | 1               | 2,097,152            | 2 x A100 | 12.75 GB    | 2.92 min           |\n| 3 B   | StableLM Alpha | None         | 2               | 2,097,152            | 2 x A100 | 12.94 GB    | 3.06 min           |\n| 3 B   | StableLM Alpha | None         | 4               | 2,097,152            | 2 x A100 | 13.45 GB    | 3.86 min           |\n|       |                |              |                 |                      |          |             |                    |\n| 7 B   | Llama 2        | None         | 1               | 4,194,304            | 2 x A100 | 22.18 GB    | 5.93 min           |\n| 7 B   | Llama 2        | None         | 2               | 4,194,304            | 2 x A100 | 22.47 GB    | 6.48 min           |\n| 7 B   | Llama 2        | None         | 4               | 4,194,304            | 2 x A100 | 23.39 GB    | 8.66 min           |\n|       |                |              |                 |                      |          |             |                    |\n| 13 B  | Llama 2        | None         | 1               | 6,553,600            | 2 x A100 | OOM         | -                  |\n| 13 B  | Llama 2        | bnb.nf4      | 1               | 6,553,600            | 2 x A100 | N/A         | -                  |\n| 13 B  | Llama 2        | bnb.nf4-dq   | 1               | 6,553,600    
        | 2 x A100 | N/A         | -                  |\n|       |                |              |                 |                      |          |             |                    |\n| 13 B  | Llama 2        | None         | 1               | 6,553,600            | 4 x A100 | 35.57 GB    | 10.25 min          |\n| 40 B  | Falcon         | None         | 1               | 12,042,240           | 4 x A100 | OOM         | -                  |\n\n&nbsp;\n\n## Single-GPU Inference\n\n| Size  | Model          | Quantization | GPU      | Max GPU RAM                               | Token/sec |\n|-------|----------------|--------------|----------|-------------------------------------------|-----------|\n| 1.3 B | phi-1.5        | None         | 1 x A100 | 2.86 GB                                   | 42.56     |\n| 1.3 B | phi-1.5        | bnb.nf4      | 1 x A100 | 1.39 GB                                   | 22.89     |\n| 1.3 B | phi-1.5        | bnb.nf4-dq   | 1 x A100 | 1.33 GB                                   | 22.75     |\n|       |                |              |          |                                           |           |\n| 3 B   | StableLM Alpha | None         | 1 x A100 | 7.30 GB                                   | 49.01     |\n| 3 B   | StableLM Alpha | bnb.nf4      | 1 x A100 | 3.20 GB                                   | 29.04     |\n| 3 B   | StableLM Alpha | bnb.nf4-dq   | 1 x A100 | 3.04 GB                                   | 27.15     |\n|       |                |              |          |                                           |           |\n| 7 B   | Llama 2        | None         | 1 x A100 | 13.52 GB                                  | 30.97     |\n| 7 B   | Llama 2        | bnb.nf4      | 1 x A100 | 4.57 GB                                   | 19.98     |\n| 7 B   | Llama 2        | bnb.nf4-dq   | 1 x A100 | 4.26 GB                                   | 17.3      |\n|       |                |              |          |                                
           |           |\n| 13 B  | Llama 2        | None         | 1 x A100 | 26.21 GB                                  | 24.82     |\n| 13 B  | Llama 2        | bnb.nf4      | 1 x A100 | 8.32 GB                                   | 16.73     |\n| 13 B  | Llama 2        | bnb.nf4-dq   | 1 x A100 | 7.72 GB                                   | 14.43     |\n|       |                |              |          |                                           |           |\n| 34 B  | CodeLlama      | None         | 1 x A100 | OOM                                       | -         |\n| 34 B  | CodeLlama      | bnb.nf4      | 1 x A100 | 20.52 GB                                  | 14.32     |\n| 34 B  | CodeLlama      | bnb.nf4-dq   | 1 x A100 | 18.95 GB                                  | 12.37     |\n|       |                |              |          |                                           |           |\n| 40 B  | Falcon         | None         | 1 x A100 | OOM                                       | -         |\n| 40 B  | Falcon         | bnb.nf4      | 1 x A100 | 26.55 GB                                  | 13.25     |\n| 40 B  | Falcon         | bnb.nf4-dq   | 1 x A100 | 24.63 GB                                  | 11.64     |\n|       |                |              |          |                                           |           |\n| 70 B  | Llama 2        | None         | 1 x A100 | OOM                                       | -         |\n| 70 B  | Llama 2        | bnb.nf4      | 1 x A100 | CUDA error: CUBLAS_STATUS_NOT_INITIALIZED | -         |\n| 70 B  | Llama 2        | bnb.nf4-dq   | 1 x A100 | 37.21 GB                                  | 7.97      |\n"
  }
]